be/src/udf/python/python_udaf_client.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
#include "udf/python/python_udaf_client.h"

#include <arrow/array/builder_base.h>
#include <arrow/array/builder_binary.h>
#include <arrow/array/builder_primitive.h>
#include <arrow/buffer.h>
#include <arrow/flight/client.h>
#include <arrow/flight/server.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/writer.h>
#include <arrow/record_batch.h>
#include <arrow/type.h>

#include "common/compiler_util.h"
#include "common/status.h"
#include "format/arrow/arrow_utils.h"
#include "udf/python/python_udf_meta.h"
#include "udf/python/python_udf_runtime.h"
35 | | |
36 | | namespace doris { |
37 | | |
38 | | // Unified response structure for UDAF operations |
39 | | // Arrow Schema: [success: bool, rows_processed: int64, data: binary] |
40 | | // Different operations use different fields: |
41 | | // - CREATE/MERGE/RESET/DESTROY: use success only |
42 | | // - ACCUMULATE: use success + rows_processed (number of rows processed) |
43 | | // - SERIALIZE: use success + data (serialized_state) |
44 | | // - FINALIZE: use success + data (serialized result, may be null) |
45 | | // |
46 | | // This unified schema allows all operations to return consistent format, |
47 | | // solving Arrow Flight's limitation that all responses must have the same schema. |
// NOTE: every response batch is validated against this schema in
// _send_request(); presumably the Python UDAF server emits exactly this
// layout — keep the two sides in sync when changing fields.
static const std::shared_ptr<arrow::Schema> kUnifiedUDAFResponseSchema = arrow::schema({
        arrow::field("success", arrow::boolean()),
        arrow::field("rows_processed", arrow::int64()),
        arrow::field("serialized_data", arrow::binary()),
});
53 | | |
54 | | Status PythonUDAFClient::create(const PythonUDFMeta& func_meta, ProcessPtr process, |
55 | | const std::shared_ptr<arrow::Schema>& data_schema, |
56 | 3.76k | PythonUDAFClientPtr* client) { |
57 | 3.76k | PythonUDAFClientPtr python_udaf_client = std::make_shared<PythonUDAFClient>(); |
58 | 3.76k | RETURN_IF_ERROR(python_udaf_client->init(func_meta, std::move(process), data_schema)); |
59 | 3.76k | *client = std::move(python_udaf_client); |
60 | 3.76k | return Status::OK(); |
61 | 3.76k | } |
62 | | |
// Initializes the client: stores the request schema used for every batch sent
// to the Python UDAF server ([argument_types..., places, binary_data] — see
// _create_data_request_batch), then delegates process/connection setup to the
// base PythonClient.
Status PythonUDAFClient::init(const PythonUDFMeta& func_meta, ProcessPtr process,
                              const std::shared_ptr<arrow::Schema>& data_schema) {
    // Capture the schema before base init so it is ready as soon as the
    // stream can be used.
    _schema = data_schema;
    return PythonClient::init(func_meta, std::move(process));
}
68 | | |
69 | 3.75k | Status PythonUDAFClient::create(int64_t place_id) { |
70 | 3.75k | std::shared_ptr<arrow::RecordBatch> request_batch; |
71 | 3.75k | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
72 | | |
73 | 3.75k | UDAFMetadata metadata { |
74 | 3.75k | .meta_version = UDAF_METADATA_VERSION, |
75 | 3.75k | .operation = static_cast<uint8_t>(UDAFOperation::CREATE), |
76 | 3.75k | .is_single_place = 0, |
77 | 3.75k | .place_id = place_id, |
78 | 3.75k | .row_start = 0, |
79 | 3.75k | .row_end = 0, |
80 | 3.75k | }; |
81 | | |
82 | 3.75k | std::shared_ptr<arrow::RecordBatch> response_batch; |
83 | 3.75k | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response_batch)); |
84 | | |
85 | | // Parse unified response_batch: [success: bool, rows_processed: int64, serialized_data: binary] |
86 | 3.75k | if (response_batch->num_rows() != 1) { |
87 | 0 | return Status::InternalError("Invalid CREATE response_batch: expected 1 row"); |
88 | 0 | } |
89 | | |
90 | 3.75k | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response_batch->column(0)); |
91 | 3.75k | if (!success_array->Value(0)) { |
92 | 0 | return Status::InternalError("CREATE operation failed for place_id={}", place_id); |
93 | 0 | } |
94 | | |
95 | 3.75k | _created_place_id = place_id; |
96 | 3.75k | return Status::OK(); |
97 | 3.75k | } |
98 | | |
// Sends a batch of input rows to the server to be folded into the aggregation
// state(s).
//
// Modes:
// - single-place: all rows in [row_start, row_end) accumulate into `place_id`;
// - multi-place: the last input column must be named "places", and each row is
//   routed to its own state on the server side.
//
// Returns InvalidArgument for a bad row range, InternalError if the server
// reports failure or processed fewer rows than requested.
Status PythonUDAFClient::accumulate(int64_t place_id, bool is_single_place,
                                    const arrow::RecordBatch& input, int64_t row_start,
                                    int64_t row_end) {
    // Validate input parameters: [row_start, row_end) must be a well-formed
    // sub-range of the input batch.
    if (UNLIKELY(row_start < 0 || row_end < row_start || row_end > input.num_rows())) {
        return Status::InvalidArgument(
                "Invalid row range: row_start={}, row_end={}, input.num_rows={}", row_start,
                row_end, input.num_rows());
    }

    // In multi-place mode, input RecordBatch must contain "places" column as last column
    if (UNLIKELY(!is_single_place &&
                 (input.num_columns() == 0 ||
                  input.schema()->field(input.num_columns() - 1)->name() != "places"))) {
        return Status::InternalError(
                "In multi-place mode, input RecordBatch must contain 'places' column as the "
                "last column");
    }

    // Create request batch: input data + NULL binary_data column (the binary
    // column is only meaningful for MERGE, but the stream schema is fixed).
    std::shared_ptr<arrow::RecordBatch> request_batch;
    RETURN_IF_ERROR(_create_data_request_batch(input, &request_batch));

    // Create metadata structure describing the operation and row window.
    UDAFMetadata metadata {
            .meta_version = UDAF_METADATA_VERSION,
            .operation = static_cast<uint8_t>(UDAFOperation::ACCUMULATE),
            .is_single_place = static_cast<uint8_t>(is_single_place ? 1 : 0),
            .place_id = place_id,
            .row_start = row_start,
            .row_end = row_end,
    };

    // Send to server with metadata in app_metadata
    std::shared_ptr<arrow::RecordBatch> response;
    RETURN_IF_ERROR(_send_request(metadata, request_batch, &response));

    // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary]
    if (response->num_rows() != 1) {
        return Status::InternalError("Invalid ACCUMULATE response: expected 1 row");
    }

    auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0));
    auto rows_processed_array = std::static_pointer_cast<arrow::Int64Array>(response->column(1));

    if (!success_array->Value(0)) {
        return Status::InternalError("ACCUMULATE operation failed for place_id={}", place_id);
    }

    // Cast to uint8_t* first to avoid UBSAN misaligned pointer errors;
    // memcpy (rather than a direct int64_t load) keeps the read
    // alignment-safe regardless of the Arrow buffer's alignment.
    const uint8_t* raw_ptr = reinterpret_cast<const uint8_t*>(rows_processed_array->raw_values());
    if (raw_ptr == nullptr) {
        return Status::InternalError("ACCUMULATE response has null rows_processed array");
    }
    int64_t rows_processed;
    memcpy(&rows_processed, raw_ptr, sizeof(int64_t));

    int64_t expected_rows = row_end - row_start;

    // The server must have consumed at least the requested window; anything
    // less means rows were silently dropped.
    if (rows_processed < expected_rows) {
        return Status::InternalError(
                "ACCUMULATE operation only processed {} out of {} rows for place_id={}",
                rows_processed, expected_rows, place_id);
    }
    return Status::OK();
}
165 | | |
166 | | Status PythonUDAFClient::serialize(int64_t place_id, |
167 | 766 | std::shared_ptr<arrow::Buffer>* serialized_state) { |
168 | 766 | std::shared_ptr<arrow::RecordBatch> request_batch; |
169 | 766 | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
170 | | |
171 | 766 | UDAFMetadata metadata { |
172 | 766 | .meta_version = UDAF_METADATA_VERSION, |
173 | 766 | .operation = static_cast<uint8_t>(UDAFOperation::SERIALIZE), |
174 | 766 | .is_single_place = 0, |
175 | 766 | .place_id = place_id, |
176 | 766 | .row_start = 0, |
177 | 766 | .row_end = 0, |
178 | 766 | }; |
179 | | |
180 | 766 | std::shared_ptr<arrow::RecordBatch> response; |
181 | 766 | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response)); |
182 | | |
183 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
184 | 766 | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
185 | 766 | auto data_array = std::static_pointer_cast<arrow::BinaryArray>(response->column(2)); |
186 | | |
187 | 766 | if (!success_array->Value(0)) { |
188 | 0 | return Status::InternalError("SERIALIZE operation failed for place_id={}", place_id); |
189 | 0 | } |
190 | | |
191 | | // Cast to uint8_t* first to avoid UBSAN misaligned pointer errors |
192 | 766 | const uint8_t* offsets = reinterpret_cast<const uint8_t*>(data_array->raw_value_offsets()); |
193 | 766 | if (offsets == nullptr) { |
194 | 0 | return Status::InternalError("SERIALIZE response has null offsets"); |
195 | 0 | } |
196 | 766 | int32_t offset_start, offset_end; |
197 | 766 | memcpy(&offset_start, offsets, sizeof(int32_t)); |
198 | 766 | memcpy(&offset_end, offsets + sizeof(int32_t), sizeof(int32_t)); |
199 | | |
200 | 766 | int32_t length = offset_end - offset_start; |
201 | | |
202 | 766 | if (length == 0) { |
203 | 0 | return Status::InternalError("SERIALIZE operation returned empty data for place_id={}", |
204 | 0 | place_id); |
205 | 0 | } |
206 | | |
207 | 766 | const uint8_t* data = data_array->value_data()->data() + offset_start; |
208 | 766 | *serialized_state = arrow::Buffer::Wrap(data, length); |
209 | 766 | return Status::OK(); |
210 | 766 | } |
211 | | |
212 | | Status PythonUDAFClient::merge(int64_t place_id, |
213 | 766 | const std::shared_ptr<arrow::Buffer>& serialized_state) { |
214 | 766 | std::shared_ptr<arrow::RecordBatch> request_batch; |
215 | 766 | RETURN_IF_ERROR(_create_binary_request_batch(serialized_state, &request_batch)); |
216 | | |
217 | 766 | UDAFMetadata metadata { |
218 | 766 | .meta_version = UDAF_METADATA_VERSION, |
219 | 766 | .operation = static_cast<uint8_t>(UDAFOperation::MERGE), |
220 | 766 | .is_single_place = 0, |
221 | 766 | .place_id = place_id, |
222 | 766 | .row_start = 0, |
223 | 766 | .row_end = 0, |
224 | 766 | }; |
225 | | |
226 | 766 | std::shared_ptr<arrow::RecordBatch> response; |
227 | 766 | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response)); |
228 | | |
229 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
230 | 766 | if (response->num_rows() != 1) { |
231 | 0 | return Status::InternalError("Invalid MERGE response: expected 1 row"); |
232 | 0 | } |
233 | | |
234 | 766 | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
235 | 766 | if (!success_array->Value(0)) { |
236 | 0 | return Status::InternalError("MERGE operation failed for place_id={}", place_id); |
237 | 0 | } |
238 | | |
239 | 766 | return Status::OK(); |
240 | 766 | } |
241 | | |
242 | 2.26k | Status PythonUDAFClient::finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output) { |
243 | 2.26k | std::shared_ptr<arrow::RecordBatch> request_batch; |
244 | 2.26k | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
245 | | |
246 | 2.26k | UDAFMetadata metadata { |
247 | 2.26k | .meta_version = UDAF_METADATA_VERSION, |
248 | 2.26k | .operation = static_cast<uint8_t>(UDAFOperation::FINALIZE), |
249 | 2.26k | .is_single_place = 0, |
250 | 2.26k | .place_id = place_id, |
251 | 2.26k | .row_start = 0, |
252 | 2.26k | .row_end = 0, |
253 | 2.26k | }; |
254 | | |
255 | 2.26k | std::shared_ptr<arrow::RecordBatch> response_batch; |
256 | 2.26k | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response_batch)); |
257 | | |
258 | | // Parse unified response_batch: [success: bool, rows_processed: int64, serialized_data: binary] |
259 | 2.26k | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response_batch->column(0)); |
260 | 2.26k | auto data_array = std::static_pointer_cast<arrow::BinaryArray>(response_batch->column(2)); |
261 | | |
262 | 2.26k | if (!success_array->Value(0)) { |
263 | 0 | return Status::InternalError("FINALIZE operation failed for place_id={}", place_id); |
264 | 0 | } |
265 | | |
266 | | // Cast to uint8_t* first to avoid UBSAN misaligned pointer errors |
267 | 2.26k | const uint8_t* offsets = reinterpret_cast<const uint8_t*>(data_array->raw_value_offsets()); |
268 | 2.26k | if (offsets == nullptr) { |
269 | 0 | return Status::InternalError("FINALIZE response has null offsets"); |
270 | 0 | } |
271 | 2.26k | int32_t offset_start, offset_end; |
272 | 2.26k | memcpy(&offset_start, offsets, sizeof(int32_t)); |
273 | 2.26k | memcpy(&offset_end, offsets + sizeof(int32_t), sizeof(int32_t)); |
274 | | |
275 | 2.26k | int32_t length = offset_end - offset_start; |
276 | | |
277 | 2.26k | if (length == 0) { |
278 | 0 | return Status::InternalError("FINALIZE operation returned empty data for place_id={}", |
279 | 0 | place_id); |
280 | 0 | } |
281 | | |
282 | 2.26k | const uint8_t* data = data_array->value_data()->data() + offset_start; |
283 | 2.26k | auto buffer = arrow::Buffer::Wrap(data, length); |
284 | 2.26k | auto input_stream = std::make_shared<arrow::io::BufferReader>(buffer); |
285 | | |
286 | 2.26k | auto reader_result = arrow::ipc::RecordBatchStreamReader::Open(input_stream); |
287 | 2.26k | if (UNLIKELY(!reader_result.ok())) { |
288 | 0 | return Status::InternalError("Failed to deserialize FINALIZE result: {}", |
289 | 0 | reader_result.status().message()); |
290 | 0 | } |
291 | 2.26k | auto reader = std::move(reader_result).ValueOrDie(); |
292 | | |
293 | 2.26k | auto batch_result = reader->Next(); |
294 | 2.26k | if (UNLIKELY(!batch_result.ok())) { |
295 | 0 | return Status::InternalError("Failed to read FINALIZE result: {}", |
296 | 0 | batch_result.status().message()); |
297 | 0 | } |
298 | | |
299 | 2.26k | *output = std::move(batch_result).ValueOrDie(); |
300 | | |
301 | 2.26k | return Status::OK(); |
302 | 2.26k | } |
303 | | |
304 | 670 | Status PythonUDAFClient::reset(int64_t place_id) { |
305 | 670 | std::shared_ptr<arrow::RecordBatch> request_batch; |
306 | 670 | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
307 | | |
308 | 670 | UDAFMetadata metadata { |
309 | 670 | .meta_version = UDAF_METADATA_VERSION, |
310 | 670 | .operation = static_cast<uint8_t>(UDAFOperation::RESET), |
311 | 670 | .is_single_place = 0, |
312 | 670 | .place_id = place_id, |
313 | 670 | .row_start = 0, |
314 | 670 | .row_end = 0, |
315 | 670 | }; |
316 | | |
317 | 670 | std::shared_ptr<arrow::RecordBatch> response; |
318 | 670 | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response)); |
319 | | |
320 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
321 | 670 | if (response->num_rows() != 1) { |
322 | 0 | return Status::InternalError("Invalid RESET response: expected 1 row"); |
323 | 0 | } |
324 | | |
325 | 670 | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
326 | 670 | if (!success_array->Value(0)) { |
327 | 0 | return Status::InternalError("RESET operation failed for place_id={}", place_id); |
328 | 0 | } |
329 | | |
330 | 670 | return Status::OK(); |
331 | 670 | } |
332 | | |
333 | 3.75k | Status PythonUDAFClient::destroy(int64_t place_id) { |
334 | 3.75k | std::shared_ptr<arrow::RecordBatch> request_batch; |
335 | 3.75k | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
336 | | |
337 | 3.75k | UDAFMetadata metadata { |
338 | 3.75k | .meta_version = UDAF_METADATA_VERSION, |
339 | 3.75k | .operation = static_cast<uint8_t>(UDAFOperation::DESTROY), |
340 | 3.75k | .is_single_place = 0, |
341 | 3.75k | .place_id = place_id, |
342 | 3.75k | .row_start = 0, |
343 | 3.75k | .row_end = 0, |
344 | 3.75k | }; |
345 | | |
346 | 3.75k | std::shared_ptr<arrow::RecordBatch> response; |
347 | 3.75k | Status st = _send_request(metadata, request_batch, &response); |
348 | | |
349 | | // Always clear tracking, even if RPC failed |
350 | 3.75k | _created_place_id.reset(); |
351 | | |
352 | 3.75k | if (!st.ok()) { |
353 | 0 | LOG(WARNING) << "Failed to destroy place_id=" << place_id << ": " << st.to_string(); |
354 | 0 | return st; |
355 | 0 | } |
356 | | |
357 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
358 | 3.75k | if (response->num_rows() != 1) { |
359 | 0 | return Status::InternalError("Invalid DESTROY response: expected 1 row"); |
360 | 0 | } |
361 | | |
362 | 3.75k | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
363 | | |
364 | 3.75k | if (!success_array->Value(0)) { |
365 | 0 | LOG(WARNING) << "DESTROY operation failed for place_id=" << place_id; |
366 | 0 | return Status::InternalError("DESTROY operation failed for place_id={}", place_id); |
367 | 0 | } |
368 | | |
369 | 3.75k | return Status::OK(); |
370 | 3.75k | } |
371 | | |
// Closes the client. If a place created via create() was never destroyed,
// it is destroyed here as a best-effort cleanup before delegating stream
// teardown to PythonClient::close().
Status PythonUDAFClient::close() {
    // Nothing to do if init() never completed or the stream is already gone.
    if (!_inited || !_writer) return Status::OK();

    // Destroy the place if it exists (cleanup on client destruction)
    if (_created_place_id.has_value()) {
        int64_t place_id = _created_place_id.value();
        Status st = destroy(place_id);
        if (!st.ok()) {
            LOG(WARNING) << "Failed to destroy place_id=" << place_id
                         << " during close: " << st.to_string();
            // Clear tracking even on failure to prevent issues in base class close.
            // NOTE(review): destroy() already resets this unconditionally, so
            // this is a defensive no-op kept for safety.
            _created_place_id.reset();
        }
    }

    return PythonClient::close();
}
389 | | |
// Performs one request/response round-trip on the shared Flight exchange
// stream: writes `request_batch` with the UDAF metadata attached as
// app_metadata, reads exactly one chunk back, and validates it against the
// unified response schema [success, rows_processed, serialized_data].
//
// Serialized by _operation_mutex so concurrent operations cannot interleave
// their request/response pairs on the single stream.
Status PythonUDAFClient::_send_request(const UDAFMetadata& metadata,
                                       const std::shared_ptr<arrow::RecordBatch>& request_batch,
                                       std::shared_ptr<arrow::RecordBatch>* response_batch) {
    DCHECK(response_batch != nullptr);

    // Create app_metadata buffer from metadata struct. The buffer is a
    // non-owning view of the caller's stack object; it is only used by the
    // WriteWithMetadata call below, while `metadata` is still alive.
    // NOTE(review): assumes the Flight writer does not retain the metadata
    // buffer past the call — confirm against the writer implementation.
    auto app_metadata =
            arrow::Buffer::Wrap(reinterpret_cast<const uint8_t*>(&metadata), sizeof(metadata));

    std::lock_guard<std::mutex> lock(_operation_mutex);

    // Check if writer/reader are still valid (they could be reset by handle_error)
    if (UNLIKELY(!_writer || !_reader)) {
        return Status::InternalError("{} writer/reader have been closed due to previous error",
                                     _operation_name);
    }

    // Begin stream on first call (using data schema: argument_types + places + binary_data).
    // The mutex guarantees only one caller observes _begin == false.
    if (UNLIKELY(!_begin)) {
        auto begin_res = _writer->Begin(_schema);
        if (!begin_res.ok()) {
            return handle_error(begin_res);
        }
        _begin = true;
    }

    // Write batch with metadata in app_metadata
    auto write_res = _writer->WriteWithMetadata(*request_batch, app_metadata);
    if (!write_res.ok()) {
        return handle_error(write_res);
    }

    // Read unified response: [success: bool, rows_processed: int64, serialized_data: binary]
    auto read_res = _reader->Next();
    if (!read_res.ok()) {
        return handle_error(read_res.status());
    }

    arrow::flight::FlightStreamChunk chunk = std::move(*read_res);
    if (!chunk.data) {
        return Status::InternalError("Received empty RecordBatch from {} server", _operation_name);
    }

    // Validate unified response schema so downstream column casts are safe.
    if (!chunk.data->schema()->Equals(kUnifiedUDAFResponseSchema)) {
        return Status::InternalError(
                "Invalid response schema: expected [success: bool, rows_processed: int64, "
                "serialized_data: binary], got {}",
                chunk.data->schema()->ToString());
    }

    *response_batch = std::move(chunk.data);
    return Status::OK();
}
444 | | |
445 | | Status PythonUDAFClient::_create_data_request_batch(const arrow::RecordBatch& input_data, |
446 | 2.54k | std::shared_ptr<arrow::RecordBatch>* out) { |
447 | | // Determine if input has places column |
448 | 2.54k | int num_input_columns = input_data.num_columns(); |
449 | 2.54k | bool has_places = false; |
450 | 2.54k | if (num_input_columns > 0 && |
451 | 2.54k | input_data.schema()->field(num_input_columns - 1)->name() == "places") { |
452 | 712 | has_places = true; |
453 | 712 | } |
454 | | |
455 | | // Expected schema structure: [argument_types..., places, binary_data] |
456 | | // - Input in single-place mode: [argument_types...] |
457 | | // - Input in multi-place mode: [argument_types..., places] |
458 | 2.54k | std::vector<std::shared_ptr<arrow::Array>> columns; |
459 | | // Copy argument_types columns |
460 | 2.54k | int num_arg_columns = has_places ? (num_input_columns - 1) : num_input_columns; |
461 | | |
462 | 5.77k | for (int i = 0; i < num_arg_columns; ++i) { |
463 | 3.23k | columns.push_back(input_data.column(i)); |
464 | 3.23k | } |
465 | | |
466 | | // Add places column |
467 | 2.54k | if (has_places) { |
468 | | // Use existing places column from input |
469 | 712 | columns.push_back(input_data.column(num_input_columns - 1)); |
470 | 1.83k | } else { |
471 | | // Create NULL places column for single-place mode |
472 | 1.83k | arrow::Int64Builder places_builder; |
473 | 1.83k | std::shared_ptr<arrow::Array> places_array; |
474 | 1.83k | RETURN_DORIS_STATUS_IF_ERROR(places_builder.AppendNulls(input_data.num_rows())); |
475 | 1.83k | RETURN_DORIS_STATUS_IF_ERROR(places_builder.Finish(&places_array)); |
476 | 1.83k | columns.push_back(places_array); |
477 | 1.83k | } |
478 | | |
479 | | // Add NULL binary_data column |
480 | 2.54k | arrow::BinaryBuilder binary_builder; |
481 | 2.54k | std::shared_ptr<arrow::Array> binary_array; |
482 | 2.54k | RETURN_DORIS_STATUS_IF_ERROR(binary_builder.AppendNulls(input_data.num_rows())); |
483 | 2.54k | RETURN_DORIS_STATUS_IF_ERROR(binary_builder.Finish(&binary_array)); |
484 | 2.54k | columns.push_back(binary_array); |
485 | | |
486 | 2.54k | *out = arrow::RecordBatch::Make(_schema, input_data.num_rows(), columns); |
487 | 2.54k | return Status::OK(); |
488 | 2.54k | } |
489 | | |
490 | | Status PythonUDAFClient::_create_binary_request_batch( |
491 | | const std::shared_ptr<arrow::Buffer>& binary_data, |
492 | 766 | std::shared_ptr<arrow::RecordBatch>* out) { |
493 | 766 | std::vector<std::shared_ptr<arrow::Array>> columns; |
494 | | |
495 | | // Create NULL arrays for data columns (all columns except the last binary_data column) |
496 | | // Schema: [argument_types..., places, binary_data] |
497 | 766 | int num_data_columns = _schema->num_fields() - 1; |
498 | 2.78k | for (int i = 0; i < num_data_columns; ++i) { |
499 | 2.01k | std::unique_ptr<arrow::ArrayBuilder> builder; |
500 | 2.01k | std::shared_ptr<arrow::Array> null_array; |
501 | 2.01k | RETURN_DORIS_STATUS_IF_ERROR(arrow::MakeBuilder(arrow::default_memory_pool(), |
502 | 2.01k | _schema->field(i)->type(), &builder)); |
503 | 2.01k | RETURN_DORIS_STATUS_IF_ERROR(builder->AppendNull()); |
504 | 2.01k | RETURN_DORIS_STATUS_IF_ERROR(builder->Finish(&null_array)); |
505 | 2.01k | columns.push_back(null_array); |
506 | 2.01k | } |
507 | | |
508 | | // Create binary_data column |
509 | 766 | arrow::BinaryBuilder binary_builder; |
510 | 766 | std::shared_ptr<arrow::Array> binary_array; |
511 | 766 | RETURN_DORIS_STATUS_IF_ERROR( |
512 | 766 | binary_builder.Append(binary_data->data(), static_cast<int32_t>(binary_data->size()))); |
513 | 766 | RETURN_DORIS_STATUS_IF_ERROR(binary_builder.Finish(&binary_array)); |
514 | 766 | columns.push_back(binary_array); |
515 | | |
516 | 766 | *out = arrow::RecordBatch::Make(_schema, 1, columns); |
517 | 766 | return Status::OK(); |
518 | 766 | } |
519 | | |
520 | 11.2k | Status PythonUDAFClient::_get_empty_request_batch(std::shared_ptr<arrow::RecordBatch>* out) { |
521 | | // Return cached batch if already created |
522 | 11.2k | if (_empty_request_batch) { |
523 | 7.46k | *out = _empty_request_batch; |
524 | 7.46k | return Status::OK(); |
525 | 7.46k | } |
526 | | |
527 | | // Create empty batch on first use (all columns NULL, 1 row) |
528 | 3.74k | std::vector<std::shared_ptr<arrow::Array>> columns; |
529 | | |
530 | 16.4k | for (int i = 0; i < _schema->num_fields(); ++i) { |
531 | 12.7k | auto field = _schema->field(i); |
532 | 12.7k | std::unique_ptr<arrow::ArrayBuilder> builder; |
533 | 12.7k | std::shared_ptr<arrow::Array> null_array; |
534 | 12.7k | RETURN_DORIS_STATUS_IF_ERROR( |
535 | 12.7k | arrow::MakeBuilder(arrow::default_memory_pool(), field->type(), &builder)); |
536 | 12.7k | RETURN_DORIS_STATUS_IF_ERROR(builder->AppendNull()); |
537 | 12.7k | RETURN_DORIS_STATUS_IF_ERROR(builder->Finish(&null_array)); |
538 | 12.7k | columns.push_back(null_array); |
539 | 12.7k | } |
540 | | |
541 | 3.74k | _empty_request_batch = arrow::RecordBatch::Make(_schema, 1, columns); |
542 | 3.74k | *out = _empty_request_batch; |
543 | 3.74k | return Status::OK(); |
544 | 3.74k | } |
545 | | |
546 | | } // namespace doris |