be/src/exprs/function/function_paimon.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "exprs/function/simple_function_factory.h" |
19 | | |
20 | | #ifdef WITH_PAIMON_CPP |
21 | | |
22 | | #include <arrow/array.h> |
23 | | #include <arrow/c/bridge.h> |
24 | | #include <arrow/record_batch.h> |
25 | | |
26 | | #include <memory> |
27 | | #include <string> |
28 | | #include <unordered_map> |
29 | | #include <utility> |
30 | | #include <vector> |
31 | | |
32 | | #include "core/column/column_const.h" |
33 | | #include "core/column/column_vector.h" |
34 | | #include "core/data_type/data_type_number.h" |
35 | | #include "exprs/function/function.h" |
36 | | #include "format/arrow/arrow_block_convertor.h" |
37 | | #include "format/arrow/arrow_row_batch.h" |
38 | | #include "format/parquet/arrow_memory_pool.h" |
39 | | #include "paimon/utils/bucket_id_calculator.h" |
40 | | #include "runtime/query_context.h" |
41 | | #include "runtime/runtime_state.h" |
42 | | #include "vec/sink/writer/paimon/paimon_doris_memory_pool.h" |
43 | | |
44 | | namespace doris::vectorized { |
45 | | |
46 | | struct PaimonBucketIdState { |
47 | | int32_t bucket_num = 0; |
48 | | std::shared_ptr<::paimon::MemoryPool> pool; |
49 | | std::unique_ptr<::paimon::BucketIdCalculator> calculator; |
50 | | }; |
51 | | |
52 | | class FunctionPaimonBucketId final : public IFunction { |
53 | | public: |
54 | | static constexpr auto name = "paimon_bucket_id"; |
55 | 2 | static FunctionPtr create() { return std::make_shared<FunctionPaimonBucketId>(); } |
56 | 0 | String get_name() const override { return name; } |
57 | 1 | bool is_variadic() const override { return true; } |
58 | 0 | size_t get_number_of_arguments() const override { return 0; } |
59 | 0 | bool use_default_implementation_for_constants() const override { return false; } |
60 | | |
61 | 0 | DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { |
62 | 0 | return std::make_shared<DataTypeInt32>(); |
63 | 0 | } |
64 | | |
65 | 0 | Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { |
66 | 0 | if (scope != FunctionContext::THREAD_LOCAL) { |
67 | 0 | return Status::OK(); |
68 | 0 | } |
69 | 0 | const int num_args = context->get_num_args(); |
70 | 0 | if (num_args < 2) { |
71 | 0 | return Status::InvalidArgument("paimon_bucket_id requires at least 2 arguments"); |
72 | 0 | } |
73 | 0 | const int bucket_num_arg_idx = num_args - 1; |
74 | 0 | if (!context->is_col_constant(bucket_num_arg_idx)) { |
75 | 0 | return Status::InvalidArgument("paimon_bucket_id requires constant bucket_num"); |
76 | 0 | } |
77 | 0 | int64_t bucket_num64 = 0; |
78 | 0 | if (!context->get_constant_col(bucket_num_arg_idx)->column_ptr->is_null_at(0)) { |
79 | 0 | bucket_num64 = context->get_constant_col(bucket_num_arg_idx)->column_ptr->get_int(0); |
80 | 0 | } |
81 | 0 | if (bucket_num64 <= 0 || bucket_num64 > std::numeric_limits<int32_t>::max()) { |
82 | 0 | return Status::InvalidArgument("invalid paimon bucket_num {}", bucket_num64); |
83 | 0 | } |
84 | 0 | auto st = std::make_shared<PaimonBucketIdState>(); |
85 | 0 | st->bucket_num = static_cast<int32_t>(bucket_num64); |
86 | 0 | st->pool = std::make_shared<PaimonDorisMemoryPool>(context->state()->query_mem_tracker()); |
87 | 0 | auto calc_res = ::paimon::BucketIdCalculator::Create(false, st->bucket_num, st->pool); |
88 | 0 | if (!calc_res.ok()) { |
89 | 0 | return Status::InternalError("failed to create paimon bucket calculator: {}", |
90 | 0 | calc_res.status().ToString()); |
91 | 0 | } |
92 | 0 | st->calculator = std::move(calc_res).value(); |
93 | 0 | context->set_function_state(scope, st); |
94 | 0 | return Status::OK(); |
95 | 0 | } |
96 | | |
97 | 0 | Status close(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { |
98 | 0 | if (scope == FunctionContext::THREAD_LOCAL) { |
99 | 0 | context->set_function_state(scope, nullptr); |
100 | 0 | } |
101 | 0 | return Status::OK(); |
102 | 0 | } |
103 | | |
104 | | Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, |
105 | 0 | uint32_t result, size_t input_rows_count) const override { |
106 | 0 | auto* st = reinterpret_cast<PaimonBucketIdState*>( |
107 | 0 | context->get_function_state(FunctionContext::THREAD_LOCAL)); |
108 | 0 | if (st == nullptr || st->calculator == nullptr) { |
109 | 0 | return Status::InternalError("paimon_bucket_id state is not initialized"); |
110 | 0 | } |
111 | 0 | if (arguments.size() < 2) { |
112 | 0 | return Status::InvalidArgument("paimon_bucket_id requires at least 2 arguments"); |
113 | 0 | } |
114 | | |
115 | 0 | Block key_block; |
116 | 0 | key_block.reserve(arguments.size() - 1); |
117 | 0 | for (size_t i = 0; i + 1 < arguments.size(); ++i) { |
118 | 0 | const auto& col = block.get_by_position(arguments[i]); |
119 | 0 | key_block.insert(col); |
120 | 0 | } |
121 | |
|
122 | 0 | ArrowMemoryPool<> arrow_pool; |
123 | 0 | std::shared_ptr<arrow::Schema> arrow_schema; |
124 | 0 | RETURN_IF_ERROR(get_arrow_schema_from_block(key_block, &arrow_schema, |
125 | 0 | context->state()->timezone())); |
126 | 0 | std::shared_ptr<arrow::RecordBatch> record_batch; |
127 | 0 | RETURN_IF_ERROR(convert_to_arrow_batch(key_block, arrow_schema, &arrow_pool, &record_batch, |
128 | 0 | context->state()->timezone_obj())); |
129 | | |
130 | 0 | std::vector<std::shared_ptr<arrow::Field>> bucket_fields; |
131 | 0 | std::vector<std::shared_ptr<arrow::Array>> bucket_columns; |
132 | 0 | bucket_fields.reserve(record_batch->num_columns()); |
133 | 0 | bucket_columns.reserve(record_batch->num_columns()); |
134 | 0 | for (int i = 0; i < record_batch->num_columns(); ++i) { |
135 | 0 | bucket_fields.push_back(arrow_schema->field(i)); |
136 | 0 | bucket_columns.push_back(record_batch->column(i)); |
137 | 0 | } |
138 | |
|
139 | 0 | auto bucket_struct_res = arrow::StructArray::Make(bucket_columns, bucket_fields); |
140 | 0 | if (!bucket_struct_res.ok()) { |
141 | 0 | return Status::InternalError("failed to build bucket struct array: {}", |
142 | 0 | bucket_struct_res.status().ToString()); |
143 | 0 | } |
144 | 0 | std::shared_ptr<arrow::Array> bucket_struct = bucket_struct_res.ValueOrDie(); |
145 | 0 | std::shared_ptr<arrow::Schema> bucket_schema = arrow::schema(bucket_fields); |
146 | |
|
147 | 0 | ArrowArray c_bucket_array; |
148 | 0 | auto arrow_status = arrow::ExportArray(*bucket_struct, &c_bucket_array); |
149 | 0 | if (!arrow_status.ok()) { |
150 | 0 | return Status::InternalError("failed to export bucket arrow array: {}", |
151 | 0 | arrow_status.ToString()); |
152 | 0 | } |
153 | 0 | ArrowSchema c_bucket_schema; |
154 | 0 | arrow_status = arrow::ExportSchema(*bucket_schema, &c_bucket_schema); |
155 | 0 | if (!arrow_status.ok()) { |
156 | 0 | if (c_bucket_array.release) { |
157 | 0 | c_bucket_array.release(&c_bucket_array); |
158 | 0 | } |
159 | 0 | return Status::InternalError("failed to export bucket arrow schema: {}", |
160 | 0 | arrow_status.ToString()); |
161 | 0 | } |
162 | | |
163 | 0 | std::vector<int32_t> bucket_ids(input_rows_count, -1); |
164 | 0 | auto paimon_st = st->calculator->CalculateBucketIds(&c_bucket_array, &c_bucket_schema, |
165 | 0 | bucket_ids.data()); |
166 | 0 | if (c_bucket_array.release) { |
167 | 0 | c_bucket_array.release(&c_bucket_array); |
168 | 0 | } |
169 | 0 | if (c_bucket_schema.release) { |
170 | 0 | c_bucket_schema.release(&c_bucket_schema); |
171 | 0 | } |
172 | 0 | if (!paimon_st.ok()) { |
173 | 0 | return Status::InternalError("failed to calculate paimon bucket ids: {}", |
174 | 0 | paimon_st.ToString()); |
175 | 0 | } |
176 | | |
177 | 0 | auto col_res = ColumnInt32::create(input_rows_count); |
178 | 0 | auto& res_data = col_res->get_data(); |
179 | 0 | for (size_t i = 0; i < input_rows_count; ++i) { |
180 | 0 | res_data[i] = bucket_ids[i]; |
181 | 0 | } |
182 | 0 | block.replace_by_position(result, std::move(col_res)); |
183 | 0 | return Status::OK(); |
184 | 0 | } |
185 | | }; |
186 | | |
187 | 1 | void register_function_paimon(SimpleFunctionFactory& factory) { |
188 | 1 | factory.register_function<FunctionPaimonBucketId>(); |
189 | 1 | } |
190 | | |
191 | | } // namespace doris::vectorized |
192 | | |
193 | | #else |
194 | | |
195 | | namespace doris::vectorized { |
196 | | void register_function_paimon(SimpleFunctionFactory&) {} |
197 | | } // namespace doris::vectorized |
198 | | |
199 | | #endif |