be/src/exprs/aggregate/aggregate_function_python_udaf.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "exprs/aggregate/aggregate_function_python_udaf.h" |
19 | | |
20 | | #include <arrow/array.h> |
21 | | #include <arrow/memory_pool.h> |
22 | | #include <arrow/record_batch.h> |
23 | | #include <arrow/type.h> |
24 | | #include <fmt/format.h> |
25 | | |
26 | | #include "common/exception.h" |
27 | | #include "common/logging.h" |
28 | | #include "core/block/block.h" |
29 | | #include "core/column/column_nullable.h" |
30 | | #include "core/column/column_vector.h" |
31 | | #include "core/data_type/data_type_factory.hpp" |
32 | | #include "core/data_type/data_type_nullable.h" |
33 | | #include "core/data_type/define_primitive_type.h" |
34 | | #include "format/arrow/arrow_block_convertor.h" |
35 | | #include "format/arrow/arrow_row_batch.h" |
36 | | #include "runtime/user_function_cache.h" |
37 | | #include "udf/python/python_env.h" |
38 | | #include "udf/python/python_server.h" |
39 | | #include "util/timezone_utils.h" |
40 | | |
41 | | namespace doris { |
42 | | |
43 | 0 | Status AggregatePythonUDAFData::create(int64_t place) { |
44 | 0 | DCHECK(client) << "Client must be set before calling create"; |
45 | 0 | RETURN_IF_ERROR(client->create(place)); |
46 | 0 | return Status::OK(); |
47 | 0 | } |
48 | | |
// Accumulates rows [row_num_start, row_num_end) of `columns` into the single
// Python-side state identified by `place_id` (single-place mode).
// The argument columns are wrapped into a Block without copying, converted to an
// Arrow RecordBatch restricted to the requested row range, and shipped via RPC.
// Returns any conversion or RPC error from the Python server.
Status AggregatePythonUDAFData::add(int64_t place_id, const IColumn** columns,
                                    int64_t row_num_start, int64_t row_num_end,
                                    const DataTypes& argument_types) {
    DCHECK(client) << "Client must be set before calling add";

    // Zero-copy: Use full columns with range specification
    Block input_block;
    for (size_t i = 0; i < argument_types.size(); ++i) {
        // Column names are positional ("0", "1", ...) — Python side matches by position.
        input_block.insert(
                ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i], std::to_string(i)));
    }

    // Derive the Arrow schema from the block; timestamps use the default time zone.
    std::shared_ptr<arrow::Schema> schema;
    RETURN_IF_ERROR(
            get_arrow_schema_from_block(input_block, &schema, TimezoneUtils::default_time_zone));
    cctz::time_zone timezone_obj;
    TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, timezone_obj);

    std::shared_ptr<arrow::RecordBatch> batch;
    // Zero-copy: convert only the specified range
    RETURN_IF_ERROR(convert_to_arrow_batch(input_block, schema, arrow::default_memory_pool(),
                                           &batch, timezone_obj, row_num_start, row_num_end));
    // Send the batch (already sliced in convert_to_arrow_batch)
    // Single place mode: no places column needed
    RETURN_IF_ERROR(client->accumulate(place_id, true, *batch, 0, batch->num_rows()));
    return Status::OK();
}
76 | | |
// Accumulates rows [start, end) of `columns` into per-row aggregate states
// (GROUP BY mode). The destination state for row i is `places[i] + place_offset`;
// those addresses are shipped to Python in a synthetic BIGINT "places" column so
// one RPC can feed many states at once.
// Preconditions (DCHECKed): start < end <= num_rows, and every argument column
// has exactly num_rows rows.
Status AggregatePythonUDAFData::add_batch(AggregateDataPtr* places, size_t place_offset,
                                          size_t num_rows, const IColumn** columns,
                                          const DataTypes& argument_types, size_t start,
                                          size_t end) {
    DCHECK(client) << "Client must be set before calling add_batch";
    DCHECK(end > start) << "end must be greater than start";
    DCHECK(end <= num_rows) << "end must not exceed num_rows";

    size_t slice_rows = end - start;
    Block input_block;
    for (size_t i = 0; i < argument_types.size(); ++i) {
        DCHECK(columns[i]->size() == num_rows) << "Column size must match num_rows";
        // Positional column names; full columns inserted, range applied at conversion.
        input_block.insert(
                ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i], std::to_string(i)));
    }

    auto places_col = ColumnInt64::create(num_rows);
    auto& places_data = places_col->get_data();

    // Fill places column with place IDs for the slice [start, end)
    // NOTE(review): entries outside [start, end) are left as produced by
    // ColumnInt64::create(num_rows). They are never transmitted because the Arrow
    // conversion below slices exactly [start, end) — confirm that invariant holds
    // if the conversion range ever changes.
    for (size_t i = start; i < end; ++i) {
        places_data[i] = reinterpret_cast<int64_t>(places[i] + place_offset);
    }

    // Cached non-nullable BIGINT type for the synthetic column (thread-safe
    // function-local static initialization).
    static DataTypePtr places_type =
            DataTypeFactory::instance().create_data_type(PrimitiveType::TYPE_BIGINT, false);
    input_block.insert(ColumnWithTypeAndName(std::move(places_col), places_type, "places"));

    std::shared_ptr<arrow::Schema> schema;
    RETURN_IF_ERROR(
            get_arrow_schema_from_block(input_block, &schema, TimezoneUtils::default_time_zone));
    cctz::time_zone timezone_obj;
    TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, timezone_obj);

    std::shared_ptr<arrow::RecordBatch> batch;
    // Zero-copy: convert only the [start, end) range
    // This slice includes the places column automatically
    RETURN_IF_ERROR(convert_to_arrow_batch(input_block, schema, arrow::default_memory_pool(),
                                           &batch, timezone_obj, start, end));
    // Send entire batch (already contains places column) to Python
    // place_id=0 is ignored when is_single_place=false
    RETURN_IF_ERROR(client->accumulate(0, false, *batch, 0, slice_rows));
    return Status::OK();
}
121 | | |
122 | 0 | Status AggregatePythonUDAFData::merge(const AggregatePythonUDAFData& rhs, int64_t place) { |
123 | 0 | DCHECK(client) << "Client must be set before calling merge"; |
124 | | |
125 | | // Get serialized state from rhs (already stored in serialize_data by read()) |
126 | 0 | auto serialized_state = arrow::Buffer::Wrap( |
127 | 0 | reinterpret_cast<const uint8_t*>(rhs.serialize_data.data()), rhs.serialize_data.size()); |
128 | 0 | RETURN_IF_ERROR(client->merge(place, serialized_state)); |
129 | 0 | return Status::OK(); |
130 | 0 | } |
131 | | |
132 | 0 | Status AggregatePythonUDAFData::write(BufferWritable& buf, int64_t place) const { |
133 | 0 | DCHECK(client) << "Client must be set before calling write"; |
134 | | |
135 | | // Serialize state from Python server |
136 | 0 | std::shared_ptr<arrow::Buffer> serialized_state; |
137 | 0 | RETURN_IF_ERROR(client->serialize(place, &serialized_state)); |
138 | 0 | const char* data = reinterpret_cast<const char*>(serialized_state->data()); |
139 | 0 | size_t size = serialized_state->size(); |
140 | 0 | buf.write_binary(StringRef {data, size}); |
141 | 0 | return Status::OK(); |
142 | 0 | } |
143 | | |
// Reads one length-prefixed binary value from `buf` into `serialize_data`.
// The bytes are not deserialized here; merge() forwards them to the Python
// server later (see deserialize_and_merge flow).
void AggregatePythonUDAFData::read(BufferReadable& buf) {
    // Read serialized state from buffer into serialize_data
    // This will be used later by merge() in deserialize_and_merge()
    buf.read_binary(serialize_data);
}
149 | | |
150 | 0 | Status AggregatePythonUDAFData::reset(int64_t place) { |
151 | 0 | DCHECK(client) << "Client must be set before calling reset"; |
152 | 0 | RETURN_IF_ERROR(client->reset(place)); |
153 | | // After reset, state still exists but is back to initial state |
154 | 0 | return Status::OK(); |
155 | 0 | } |
156 | | |
157 | 0 | Status AggregatePythonUDAFData::destroy(int64_t place) { |
158 | 0 | DCHECK(client) << "Client must be set before calling destroy"; |
159 | 0 | RETURN_IF_ERROR(client->destroy(place)); |
160 | 0 | return Status::OK(); |
161 | 0 | } |
162 | | |
163 | | Status AggregatePythonUDAFData::get(IColumn& to, const DataTypePtr& result_type, |
164 | 0 | int64_t place) const { |
165 | 0 | DCHECK(client) << "Client must be set before calling get"; |
166 | | |
167 | | // Get final result from Python server |
168 | 0 | std::shared_ptr<arrow::RecordBatch> result; |
169 | 0 | RETURN_IF_ERROR(client->finalize(place, &result)); |
170 | | |
171 | | // Convert Arrow RecordBatch to Block |
172 | 0 | Block result_block; |
173 | 0 | DataTypes types = {result_type}; |
174 | 0 | cctz::time_zone timezone_obj; |
175 | 0 | TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, timezone_obj); |
176 | 0 | RETURN_IF_ERROR(convert_from_arrow_batch(result, types, &result_block, timezone_obj)); |
177 | | |
178 | | // Insert the result value into output column |
179 | 0 | if (result_block.rows() != 1) { |
180 | 0 | return Status::InternalError("Expected 1 row in result block, got {}", result_block.rows()); |
181 | 0 | } |
182 | | |
183 | 0 | auto& result_column = result_block.get_by_position(0).column; |
184 | 0 | to.insert_from(*result_column, 0); |
185 | 0 | return Status::OK(); |
186 | 0 | } |
187 | | |
// Builds _func_meta from the Thrift function descriptor (_fn), resolves the
// Python runtime version, validates the metadata, and (for module-based UDFs)
// resolves the cached local path of the module file.
// Returns InvalidArgument when the symbol or runtime version is missing, or any
// error from metadata validation / function-cache lookup.
Status AggregatePythonUDAF::open() {
    // Build function metadata from TFunction
    _func_meta.id = _fn.id;
    _func_meta.name = _fn.name.function_name;

    // For UDAF, symbol is in aggregate_fn
    if (_fn.__isset.aggregate_fn && _fn.aggregate_fn.__isset.symbol) {
        _func_meta.symbol = _fn.aggregate_fn.symbol;
    } else {
        return Status::InvalidArgument("Python UDAF symbol is not set");
    }

    // Determine load type (inline code or module)
    if (!_fn.function_code.empty()) {
        _func_meta.type = PythonUDFLoadType::INLINE;
        _func_meta.location = "inline";
        _func_meta.inline_code = _fn.function_code;
    } else if (!_fn.hdfs_location.empty()) {
        _func_meta.type = PythonUDFLoadType::MODULE;
        _func_meta.location = _fn.hdfs_location;
        _func_meta.checksum = _fn.checksum;
    } else {
        // NOTE(review): no code and no location leaves type UNKNOWN; presumably
        // _func_meta.check() below rejects it — confirm, otherwise this path
        // silently proceeds with an unusable function.
        _func_meta.type = PythonUDFLoadType::UNKNOWN;
        _func_meta.location = "unknown";
    }

    _func_meta.input_types = argument_types;
    _func_meta.return_type = _return_type;
    _func_meta.client_type = PythonClientType::UDAF;

    // Get Python version
    if (_fn.__isset.runtime_version && !_fn.runtime_version.empty()) {
        RETURN_IF_ERROR(PythonVersionManager::instance().get_version(_fn.runtime_version,
                                                                     &_python_version));
    } else {
        return Status::InvalidArgument("Python UDAF runtime version is not set");
    }

    _func_meta.runtime_version = _python_version.full_version;
    RETURN_IF_ERROR(_func_meta.check());
    // Set after check(); result nullability mirrors the declared return type.
    _func_meta.always_nullable = _return_type->is_nullable();

    LOG(INFO) << fmt::format("Creating Python UDAF: {}, runtime_version: {}, func_meta: {}",
                             _fn.name.function_name, _python_version.to_string(),
                             _func_meta.to_string());

    if (_func_meta.type == PythonUDFLoadType::MODULE) {
        // Resolve the HDFS location to a locally cached path (downloads if needed);
        // overwrites _func_meta.location in place.
        RETURN_IF_ERROR(UserFunctionCache::instance()->get_pypath(
                _func_meta.id, _func_meta.location, _func_meta.checksum, &_func_meta.location));
    }

    return Status::OK();
}
241 | | |
// Constructs the aggregate state at `place`: lazily builds the shared Arrow
// schema (once per function object), placement-news the Data struct, acquires a
// Python client, and creates the server-side state keyed by the place address.
// Throws doris::Exception on any failure; the partially built Data is destroyed
// first so the caller never sees a half-initialized state.
void AggregatePythonUDAF::create(AggregateDataPtr __restrict place) const {
    // Build the Arrow schema exactly once per function instance (thread-safe).
    std::call_once(_schema_init_flag, [this]() {
        std::vector<std::shared_ptr<arrow::Field>> fields;

        std::string timezone = TimezoneUtils::default_time_zone;
        for (size_t i = 0; i < argument_types.size(); ++i) {
            std::shared_ptr<arrow::DataType> arrow_type;
            Status st = convert_to_arrow_type(argument_types[i], &arrow_type, timezone);
            if (!st.ok()) {
                // call_once offers no status channel; surface the error by throwing.
                throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                       "Failed to convert argument type {} to Arrow type: {}", i,
                                       st.to_string());
            }
            // Positional field names match the positional column names used in add().
            fields.push_back(arrow::field(std::to_string(i), arrow_type));
        }

        // Add places column for GROUP BY aggregation (always included, NULL in single-place mode)
        fields.push_back(arrow::field("places", arrow::int64()));
        // Add binary_data column for merge operations
        fields.push_back(arrow::field("binary_data", arrow::binary()));
        _schema = arrow::schema(fields);
    });

    // Initialize the data structure
    new (place) Data();
    DCHECK(reinterpret_cast<Data*>(place)) << "Place must not be null";

    // Acquire (or spawn) a Python worker client bound to this function's metadata.
    if (Status st = PythonServerManager::instance().get_client(
                _func_meta, _python_version, &(this->data(place).client), _schema);
        UNLIKELY(!st.ok())) {
        // Undo the placement-new before throwing so no zombie state remains.
        this->data(place).~Data();
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Failed to get Python UDAF client: {}",
                               st.to_string());
    }

    // Initialize UDAF state in Python server
    // The raw place address doubles as the server-side state id.
    int64_t place_id = reinterpret_cast<int64_t>(place);
    if (Status st = this->data(place).create(place_id); UNLIKELY(!st.ok())) {
        this->data(place).~Data();
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
    }
}
284 | | |
// Destroys the aggregate state at `place`: tears down the server-side state
// (best effort — failures are only logged), releases the client handle, then
// runs the Data destructor. Must not throw (noexcept), so all exceptions are
// swallowed and logged.
void AggregatePythonUDAF::destroy(AggregateDataPtr __restrict place) const noexcept {
    try {
        int64_t place_id = reinterpret_cast<int64_t>(place);

        // Destroy state in Python server
        if (this->data(place).client) {
            Status st = this->data(place).destroy(place_id);
            if (UNLIKELY(!st.ok())) {
                // Best effort: a failed server-side destroy must not abort teardown.
                LOG(WARNING) << "Failed to destroy Python UDAF state for place_id=" << place_id
                             << ", function=" << _func_meta.name << ": " << st.to_string();
            }

            this->data(place).client.reset();
        }

        this->data(place).~Data();
    } catch (const std::exception& e) {
        // NOTE(review): if an exception escapes before ~Data() above, the Data
        // destructor is skipped here — presumably acceptable during teardown,
        // but worth confirming.
        LOG(ERROR) << "Exception in AggregatePythonUDAF::destroy: " << e.what();
    } catch (...) {
        LOG(ERROR) << "Unknown exception in AggregatePythonUDAF::destroy";
    }
}
307 | | |
308 | | void AggregatePythonUDAF::add(AggregateDataPtr __restrict place, const IColumn** columns, |
309 | 0 | ssize_t row_num, Arena&) const { |
310 | 0 | int64_t place_id = reinterpret_cast<int64_t>(place); |
311 | 0 | Status st = this->data(place).add(place_id, columns, row_num, row_num + 1, argument_types); |
312 | 0 | if (UNLIKELY(!st.ok())) { |
313 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
314 | 0 | } |
315 | 0 | } |
316 | | |
// Accumulates a batch of rows where each row may target a different aggregate
// state (GROUP BY). Rows are partitioned into maximal consecutive segments whose
// states live on the same Python worker process, and each segment is sent as one
// RPC carrying the per-row place ids.
// Throws doris::Exception if any segment fails to send.
void AggregatePythonUDAF::add_batch(size_t batch_size, AggregateDataPtr* places,
                                    size_t place_offset, const IColumn** columns, Arena&,
                                    bool /*agg_many*/) const {
    if (batch_size == 0) return;

    size_t start = 0;
    while (start < batch_size) {
        // Get the starting place for this segment
        AggregateDataPtr start_place = places[start] + place_offset;
        auto& start_place_data = this->data(start_place);
        // Get the process for this segment
        // NOTE(review): assumes every place already has a live client (set in
        // create()); a null client here would dereference null — confirm the
        // lifecycle guarantees this.
        const auto* current_process = start_place_data.client->get_process().get();

        // Scan forward to find the end of this consecutive segment (same process)
        size_t end = start + 1;
        while (end < batch_size) {
            AggregateDataPtr end_place = places[end] + place_offset;
            auto& end_place_data = this->data(end_place);
            const auto* next_process = end_place_data.client->get_process().get();
            // If different process, end the current segment
            if (*next_process != *current_process) break;
            ++end;
        }

        // Send this segment to Python with zero-copy
        // Pass places array and let add_batch construct place_ids on-demand
        Status st = start_place_data.add_batch(places, place_offset, batch_size, columns,
                                               argument_types, start, end);

        if (UNLIKELY(!st.ok())) {
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                   "Failed to send segment to Python: " + st.to_string());
        }

        start = end;
    }
}
354 | | |
355 | | void AggregatePythonUDAF::add_batch_single_place(size_t batch_size, AggregateDataPtr place, |
356 | 0 | const IColumn** columns, Arena&) const { |
357 | 0 | int64_t place_id = reinterpret_cast<int64_t>(place); |
358 | 0 | Status st = this->data(place).add(place_id, columns, 0, batch_size, argument_types); |
359 | 0 | if (UNLIKELY(!st.ok())) { |
360 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
361 | 0 | } |
362 | 0 | } |
363 | | |
364 | | void AggregatePythonUDAF::add_range_single_place(int64_t partition_start, int64_t partition_end, |
365 | | int64_t frame_start, int64_t frame_end, |
366 | | AggregateDataPtr place, const IColumn** columns, |
367 | | Arena& arena, UInt8* current_window_empty, |
368 | 0 | UInt8* current_window_has_inited) const { |
369 | | // Calculate actual frame range |
370 | 0 | frame_start = std::max<int64_t>(frame_start, partition_start); |
371 | 0 | frame_end = std::min<int64_t>(frame_end, partition_end); |
372 | |
|
373 | 0 | if (frame_start >= frame_end) { |
374 | 0 | if (!*current_window_has_inited) { |
375 | 0 | *current_window_empty = true; |
376 | 0 | } |
377 | 0 | return; |
378 | 0 | } |
379 | | |
380 | 0 | int64_t place_id = reinterpret_cast<int64_t>(place); |
381 | 0 | Status st = this->data(place).add(place_id, columns, frame_start, frame_end, argument_types); |
382 | 0 | if (UNLIKELY(!st.ok())) { |
383 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
384 | 0 | } |
385 | | |
386 | 0 | *current_window_empty = false; |
387 | 0 | *current_window_has_inited = true; |
388 | 0 | } |
389 | | |
390 | 0 | void AggregatePythonUDAF::reset(AggregateDataPtr place) const { |
391 | 0 | int64_t place_id = reinterpret_cast<int64_t>(place); |
392 | 0 | Status st = this->data(place).reset(place_id); |
393 | 0 | if (UNLIKELY(!st.ok())) { |
394 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
395 | 0 | } |
396 | 0 | } |
397 | | |
398 | | void AggregatePythonUDAF::merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, |
399 | 0 | Arena&) const { |
400 | 0 | int64_t place_id = reinterpret_cast<int64_t>(place); |
401 | 0 | Status st = this->data(place).merge(this->data(rhs), place_id); |
402 | 0 | if (UNLIKELY(!st.ok())) { |
403 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
404 | 0 | } |
405 | 0 | } |
406 | | |
407 | | void AggregatePythonUDAF::serialize(ConstAggregateDataPtr __restrict place, |
408 | 0 | BufferWritable& buf) const { |
409 | 0 | int64_t place_id = reinterpret_cast<int64_t>(place); |
410 | 0 | Status st = this->data(place).write(buf, place_id); |
411 | 0 | if (UNLIKELY(!st.ok())) { |
412 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
413 | 0 | } |
414 | 0 | } |
415 | | |
// Reads serialized state bytes from `buf` into the Data at `place`.
// The bytes are stored locally and forwarded to the Python server only when
// merge() is called afterwards.
void AggregatePythonUDAF::deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                                      Arena&) const {
    this->data(place).read(buf);
}
420 | | |
421 | | void AggregatePythonUDAF::insert_result_into(ConstAggregateDataPtr __restrict place, |
422 | 0 | IColumn& to) const { |
423 | 0 | int64_t place_id = reinterpret_cast<int64_t>(place); |
424 | 0 | Status st = this->data(place).get(to, _return_type, place_id); |
425 | 0 | if (UNLIKELY(!st.ok())) { |
426 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
427 | 0 | } |
428 | 0 | } |
429 | | |
430 | | } // namespace doris |