be/src/udf/python/python_udaf_client.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
#include "udf/python/python_udaf_client.h"

#include <arrow/array/builder_base.h>
#include <arrow/array/builder_binary.h>
#include <arrow/array/builder_primitive.h>
#include <arrow/flight/client.h>
#include <arrow/flight/server.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/writer.h>
#include <arrow/record_batch.h>
#include <arrow/type.h>

#include <limits>

#include "common/compiler_util.h"
#include "common/status.h"
#include "format/arrow/arrow_utils.h"
#include "udf/python/python_udf_meta.h"
#include "udf/python/python_udf_runtime.h"
#include "util/unaligned.h"
36 | | |
37 | | namespace doris { |
38 | | |
// Unified response structure for UDAF operations
// Arrow Schema: [success: bool, rows_processed: int64, data: binary]
// Different operations use different fields:
//   - CREATE/MERGE/RESET/DESTROY: use success only
//   - ACCUMULATE: use success + rows_processed (number of rows processed)
//   - SERIALIZE: use success + data (serialized_state)
//   - FINALIZE: use success + data (serialized result, may be null)
//   - Any failed operation: use success=false + data (UTF-8 error message)
//
// This unified schema allows all operations to return consistent format,
// solving Arrow Flight's limitation that all responses must have the same schema.
//
// NOTE: also used by _send_request() to validate every incoming response batch,
// and by make_udaf_failure_status() to sanity-check column counts.
static const std::shared_ptr<arrow::Schema> kUnifiedUDAFResponseSchema = arrow::schema({
        arrow::field("success", arrow::boolean()),
        arrow::field("rows_processed", arrow::int64()),
        arrow::field("serialized_data", arrow::binary()),
});
55 | | |
// Builds an InternalError Status from a failed UDAF response batch.
// The server reports failures in-band: success=false, with an optional UTF-8
// error message carried in the `serialized_data` column (column 2).
// Falls back to a generic "operation failed" message when the response is
// malformed or the message slot is null/empty.
Status PythonUDAFClient::make_udaf_failure_status(
        const std::shared_ptr<arrow::RecordBatch>& response, const char* operation,
        int64_t place_id) {
    // Defensive: even a failure response must match the unified 1-row schema shape.
    if (response == nullptr || response->num_rows() != 1 ||
        response->num_columns() != kUnifiedUDAFResponseSchema->num_fields()) [[unlikely]] {
        return Status::InternalError("Invalid {} failure response for place_id={}", operation,
                                     place_id);
    }

    // Column 2 (`serialized_data`) carries the error text on failure.
    auto data_array = std::static_pointer_cast<arrow::BinaryArray>(response->column(2));
    if (data_array->IsNull(0)) {
        // Server gave no detail: report a bare operation failure.
        return Status::InternalError("{} operation failed for place_id={}", operation, place_id);
    }

    const auto* offsets = data_array->raw_value_offsets();
    if (offsets == nullptr) [[unlikely]] {
        return Status::InternalError("Invalid {} failure response for place_id={}: null offsets",
                                     operation, place_id);
    }
    // Arrow Flight buffers may be unaligned after IPC deserialization,
    // so read the two int32 value offsets with unaligned loads instead of
    // dereferencing the pointer directly.
    int32_t offset_start = unaligned_load<int32_t>(offsets);
    int32_t offset_end = unaligned_load<int32_t>(offsets + 1);

    int32_t length = offset_end - offset_start;
    if (length <= 0) {
        // Empty message: same generic text as the null case above.
        return Status::InternalError("{} operation failed for place_id={}", operation, place_id);
    }
    // Copy the message out of the Arrow buffer; the Status owns its own string,
    // so it stays valid after `response` is released.
    const uint8_t* data = data_array->value_data()->data() + offset_start;
    std::string error_message(reinterpret_cast<const char*>(data), length);
    return Status::InternalError("{} operation failed for place_id={}: {}", operation, place_id,
                                 error_message);
}
88 | | |
#ifdef BE_TEST
// Test-only wrapper that forwards to make_udaf_failure_status() so unit tests
// can exercise the response-parsing edge cases directly.
Status PythonUDAFClient::make_udaf_failure_status_for_test(
        const std::shared_ptr<arrow::RecordBatch>& response, const char* operation,
        int64_t place_id) {
    return make_udaf_failure_status(response, operation, place_id);
}
#endif
96 | | |
97 | | Status PythonUDAFClient::create(const PythonUDFMeta& func_meta, ProcessPtr process, |
98 | | const std::shared_ptr<arrow::Schema>& data_schema, |
99 | 3.58k | PythonUDAFClientPtr* client) { |
100 | 3.58k | PythonUDAFClientPtr python_udaf_client = std::make_shared<PythonUDAFClient>(); |
101 | 3.58k | RETURN_IF_ERROR(python_udaf_client->init(func_meta, std::move(process), data_schema)); |
102 | 3.58k | *client = std::move(python_udaf_client); |
103 | 3.58k | return Status::OK(); |
104 | 3.58k | } |
105 | | |
106 | | Status PythonUDAFClient::init(const PythonUDFMeta& func_meta, ProcessPtr process, |
107 | 3.58k | const std::shared_ptr<arrow::Schema>& data_schema) { |
108 | 3.58k | _schema = data_schema; |
109 | 3.58k | return PythonClient::init(func_meta, std::move(process)); |
110 | 3.58k | } |
111 | | |
112 | 3.58k | Status PythonUDAFClient::create(int64_t place_id) { |
113 | 3.58k | std::shared_ptr<arrow::RecordBatch> request_batch; |
114 | 3.58k | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
115 | | |
116 | 3.58k | UDAFMetadata metadata { |
117 | 3.58k | .meta_version = UDAF_METADATA_VERSION, |
118 | 3.58k | .operation = static_cast<uint8_t>(UDAFOperation::CREATE), |
119 | 3.58k | .is_single_place = 0, |
120 | 3.58k | .place_id = place_id, |
121 | 3.58k | .row_start = 0, |
122 | 3.58k | .row_end = 0, |
123 | 3.58k | }; |
124 | | |
125 | 3.58k | std::shared_ptr<arrow::RecordBatch> response_batch; |
126 | 3.58k | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response_batch)); |
127 | | |
128 | | // Parse unified response_batch: [success: bool, rows_processed: int64, serialized_data: binary] |
129 | 3.56k | if (response_batch->num_rows() != 1) { |
130 | 0 | return Status::InternalError("Invalid CREATE response_batch: expected 1 row"); |
131 | 0 | } |
132 | | |
133 | 3.56k | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response_batch->column(0)); |
134 | 3.56k | if (!success_array->Value(0)) { |
135 | 0 | return make_udaf_failure_status(response_batch, "CREATE", place_id); |
136 | 0 | } |
137 | | |
138 | 3.56k | _created_place_id = place_id; |
139 | 3.56k | return Status::OK(); |
140 | 3.56k | } |
141 | | |
142 | | Status PythonUDAFClient::accumulate(int64_t place_id, bool is_single_place, |
143 | | const arrow::RecordBatch& input, int64_t row_start, |
144 | 2.58k | int64_t row_end) { |
145 | | // Validate input parameters |
146 | 2.58k | if (UNLIKELY(row_start < 0 || row_end < row_start || row_end > input.num_rows())) { |
147 | 0 | return Status::InvalidArgument( |
148 | 0 | "Invalid row range: row_start={}, row_end={}, input.num_rows={}", row_start, |
149 | 0 | row_end, input.num_rows()); |
150 | 0 | } |
151 | | |
152 | | // In multi-place mode, input RecordBatch must contain "places" column as last column |
153 | 2.58k | if (UNLIKELY(!is_single_place && |
154 | 2.58k | (input.num_columns() == 0 || |
155 | 2.58k | input.schema()->field(input.num_columns() - 1)->name() != "places"))) { |
156 | 0 | return Status::InternalError( |
157 | 0 | "In multi-place mode, input RecordBatch must contain 'places' column as the " |
158 | 0 | "last column"); |
159 | 0 | } |
160 | | |
161 | | // Create request batch: input data + NULL binary_data column |
162 | 2.58k | std::shared_ptr<arrow::RecordBatch> request_batch; |
163 | 2.58k | RETURN_IF_ERROR(_create_data_request_batch(input, &request_batch)); |
164 | | |
165 | | // Create metadata structure |
166 | 2.58k | UDAFMetadata metadata { |
167 | 2.58k | .meta_version = UDAF_METADATA_VERSION, |
168 | 2.58k | .operation = static_cast<uint8_t>(UDAFOperation::ACCUMULATE), |
169 | 2.58k | .is_single_place = static_cast<uint8_t>(is_single_place ? 1 : 0), |
170 | 2.58k | .place_id = place_id, |
171 | 2.58k | .row_start = row_start, |
172 | 2.58k | .row_end = row_end, |
173 | 2.58k | }; |
174 | | |
175 | | // Send to server with metadata in app_metadata |
176 | 2.58k | std::shared_ptr<arrow::RecordBatch> response; |
177 | 2.58k | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response)); |
178 | | |
179 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
180 | 2.58k | if (response->num_rows() != 1) { |
181 | 0 | return Status::InternalError("Invalid ACCUMULATE response: expected 1 row"); |
182 | 0 | } |
183 | | |
184 | 2.58k | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
185 | 2.58k | auto rows_processed_array = std::static_pointer_cast<arrow::Int64Array>(response->column(1)); |
186 | | |
187 | 2.58k | if (!success_array->Value(0)) { |
188 | 1 | return make_udaf_failure_status(response, "ACCUMULATE", place_id); |
189 | 1 | } |
190 | | |
191 | | // Arrow Flight buffers may be unaligned after IPC deserialization. |
192 | 2.58k | const auto* raw_ptr = rows_processed_array->raw_values(); |
193 | 2.58k | if (raw_ptr == nullptr) { |
194 | 0 | return Status::InternalError("ACCUMULATE response has null rows_processed array"); |
195 | 0 | } |
196 | 2.58k | int64_t rows_processed = unaligned_load<int64_t>(raw_ptr); |
197 | | |
198 | 2.58k | int64_t expected_rows = row_end - row_start; |
199 | | |
200 | 2.58k | if (rows_processed < expected_rows) { |
201 | 0 | return Status::InternalError( |
202 | 0 | "ACCUMULATE operation only processed {} out of {} rows for place_id={}", |
203 | 0 | rows_processed, expected_rows, place_id); |
204 | 0 | } |
205 | 2.58k | return Status::OK(); |
206 | 2.58k | } |
207 | | |
208 | | Status PythonUDAFClient::serialize(int64_t place_id, |
209 | 770 | std::shared_ptr<arrow::Buffer>* serialized_state) { |
210 | 770 | std::shared_ptr<arrow::RecordBatch> request_batch; |
211 | 770 | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
212 | | |
213 | 770 | UDAFMetadata metadata { |
214 | 770 | .meta_version = UDAF_METADATA_VERSION, |
215 | 770 | .operation = static_cast<uint8_t>(UDAFOperation::SERIALIZE), |
216 | 770 | .is_single_place = 0, |
217 | 770 | .place_id = place_id, |
218 | 770 | .row_start = 0, |
219 | 770 | .row_end = 0, |
220 | 770 | }; |
221 | | |
222 | 770 | std::shared_ptr<arrow::RecordBatch> response; |
223 | 770 | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response)); |
224 | | |
225 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
226 | 770 | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
227 | 770 | auto data_array = std::static_pointer_cast<arrow::BinaryArray>(response->column(2)); |
228 | | |
229 | 770 | if (!success_array->Value(0)) { |
230 | 0 | return make_udaf_failure_status(response, "SERIALIZE", place_id); |
231 | 0 | } |
232 | | |
233 | | // Arrow Flight buffers may be unaligned after IPC deserialization. |
234 | 770 | const auto* offsets = data_array->raw_value_offsets(); |
235 | 770 | if (offsets == nullptr) { |
236 | 0 | return Status::InternalError("SERIALIZE response has null offsets"); |
237 | 0 | } |
238 | 770 | int32_t offset_start = unaligned_load<int32_t>(offsets); |
239 | 770 | int32_t offset_end = unaligned_load<int32_t>(offsets + 1); |
240 | | |
241 | 770 | int32_t length = offset_end - offset_start; |
242 | | |
243 | 770 | if (length == 0) { |
244 | 0 | return Status::InternalError("SERIALIZE operation returned empty data for place_id={}", |
245 | 0 | place_id); |
246 | 0 | } |
247 | | |
248 | 770 | const uint8_t* data = data_array->value_data()->data() + offset_start; |
249 | 770 | *serialized_state = arrow::Buffer::Wrap(data, length); |
250 | 770 | return Status::OK(); |
251 | 770 | } |
252 | | |
253 | | Status PythonUDAFClient::merge(int64_t place_id, |
254 | 770 | const std::shared_ptr<arrow::Buffer>& serialized_state) { |
255 | 770 | std::shared_ptr<arrow::RecordBatch> request_batch; |
256 | 770 | RETURN_IF_ERROR(_create_binary_request_batch(serialized_state, &request_batch)); |
257 | | |
258 | 770 | UDAFMetadata metadata { |
259 | 770 | .meta_version = UDAF_METADATA_VERSION, |
260 | 770 | .operation = static_cast<uint8_t>(UDAFOperation::MERGE), |
261 | 770 | .is_single_place = 0, |
262 | 770 | .place_id = place_id, |
263 | 770 | .row_start = 0, |
264 | 770 | .row_end = 0, |
265 | 770 | }; |
266 | | |
267 | 770 | std::shared_ptr<arrow::RecordBatch> response; |
268 | 770 | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response)); |
269 | | |
270 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
271 | 770 | if (response->num_rows() != 1) { |
272 | 0 | return Status::InternalError("Invalid MERGE response: expected 1 row"); |
273 | 0 | } |
274 | | |
275 | 770 | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
276 | 770 | if (!success_array->Value(0)) { |
277 | 1 | return make_udaf_failure_status(response, "MERGE", place_id); |
278 | 1 | } |
279 | | |
280 | 769 | return Status::OK(); |
281 | 770 | } |
282 | | |
283 | 2.28k | Status PythonUDAFClient::finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output) { |
284 | 2.28k | std::shared_ptr<arrow::RecordBatch> request_batch; |
285 | 2.28k | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
286 | | |
287 | 2.28k | UDAFMetadata metadata { |
288 | 2.28k | .meta_version = UDAF_METADATA_VERSION, |
289 | 2.28k | .operation = static_cast<uint8_t>(UDAFOperation::FINALIZE), |
290 | 2.28k | .is_single_place = 0, |
291 | 2.28k | .place_id = place_id, |
292 | 2.28k | .row_start = 0, |
293 | 2.28k | .row_end = 0, |
294 | 2.28k | }; |
295 | | |
296 | 2.28k | std::shared_ptr<arrow::RecordBatch> response_batch; |
297 | 2.28k | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response_batch)); |
298 | | |
299 | | // Parse unified response_batch: [success: bool, rows_processed: int64, serialized_data: binary] |
300 | 2.28k | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response_batch->column(0)); |
301 | 2.28k | auto data_array = std::static_pointer_cast<arrow::BinaryArray>(response_batch->column(2)); |
302 | | |
303 | 2.28k | if (!success_array->Value(0)) { |
304 | 2 | return make_udaf_failure_status(response_batch, "FINALIZE", place_id); |
305 | 2 | } |
306 | | |
307 | | // Arrow Flight buffers may be unaligned after IPC deserialization. |
308 | 2.28k | const auto* offsets = data_array->raw_value_offsets(); |
309 | 2.28k | if (offsets == nullptr) { |
310 | 0 | return Status::InternalError("FINALIZE response has null offsets"); |
311 | 0 | } |
312 | 2.28k | int32_t offset_start = unaligned_load<int32_t>(offsets); |
313 | 2.28k | int32_t offset_end = unaligned_load<int32_t>(offsets + 1); |
314 | | |
315 | 2.28k | int32_t length = offset_end - offset_start; |
316 | | |
317 | 2.28k | if (length == 0) { |
318 | 0 | return Status::InternalError("FINALIZE operation returned empty data for place_id={}", |
319 | 0 | place_id); |
320 | 0 | } |
321 | | |
322 | 2.28k | const uint8_t* data = data_array->value_data()->data() + offset_start; |
323 | 2.28k | auto buffer = arrow::Buffer::Wrap(data, length); |
324 | 2.28k | auto input_stream = std::make_shared<arrow::io::BufferReader>(buffer); |
325 | | |
326 | 2.28k | auto reader_result = arrow::ipc::RecordBatchStreamReader::Open(input_stream); |
327 | 2.28k | if (UNLIKELY(!reader_result.ok())) { |
328 | 0 | return Status::InternalError("Failed to deserialize FINALIZE result: {}", |
329 | 0 | reader_result.status().message()); |
330 | 0 | } |
331 | 2.28k | auto reader = std::move(reader_result).ValueOrDie(); |
332 | | |
333 | 2.28k | auto batch_result = reader->Next(); |
334 | 2.28k | if (UNLIKELY(!batch_result.ok())) { |
335 | 0 | return Status::InternalError("Failed to read FINALIZE result: {}", |
336 | 0 | batch_result.status().message()); |
337 | 0 | } |
338 | | |
339 | 2.28k | *output = std::move(batch_result).ValueOrDie(); |
340 | | |
341 | 2.28k | return Status::OK(); |
342 | 2.28k | } |
343 | | |
344 | 657 | Status PythonUDAFClient::reset(int64_t place_id) { |
345 | 657 | std::shared_ptr<arrow::RecordBatch> request_batch; |
346 | 657 | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
347 | | |
348 | 657 | UDAFMetadata metadata { |
349 | 657 | .meta_version = UDAF_METADATA_VERSION, |
350 | 657 | .operation = static_cast<uint8_t>(UDAFOperation::RESET), |
351 | 657 | .is_single_place = 0, |
352 | 657 | .place_id = place_id, |
353 | 657 | .row_start = 0, |
354 | 657 | .row_end = 0, |
355 | 657 | }; |
356 | | |
357 | 657 | std::shared_ptr<arrow::RecordBatch> response; |
358 | 657 | RETURN_IF_ERROR(_send_request(metadata, request_batch, &response)); |
359 | | |
360 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
361 | 657 | if (response->num_rows() != 1) { |
362 | 0 | return Status::InternalError("Invalid RESET response: expected 1 row"); |
363 | 0 | } |
364 | | |
365 | 657 | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
366 | 657 | if (!success_array->Value(0)) { |
367 | 0 | return make_udaf_failure_status(response, "RESET", place_id); |
368 | 0 | } |
369 | | |
370 | 657 | return Status::OK(); |
371 | 657 | } |
372 | | |
373 | 3.56k | Status PythonUDAFClient::destroy(int64_t place_id) { |
374 | 3.56k | std::shared_ptr<arrow::RecordBatch> request_batch; |
375 | 3.56k | RETURN_IF_ERROR(_get_empty_request_batch(&request_batch)); |
376 | | |
377 | 3.56k | UDAFMetadata metadata { |
378 | 3.56k | .meta_version = UDAF_METADATA_VERSION, |
379 | 3.56k | .operation = static_cast<uint8_t>(UDAFOperation::DESTROY), |
380 | 3.56k | .is_single_place = 0, |
381 | 3.56k | .place_id = place_id, |
382 | 3.56k | .row_start = 0, |
383 | 3.56k | .row_end = 0, |
384 | 3.56k | }; |
385 | | |
386 | 3.56k | std::shared_ptr<arrow::RecordBatch> response; |
387 | 3.56k | Status st = _send_request(metadata, request_batch, &response); |
388 | | |
389 | | // Always clear tracking, even if RPC failed |
390 | 3.56k | _created_place_id.reset(); |
391 | | |
392 | 3.56k | if (!st.ok()) { |
393 | 0 | LOG(WARNING) << "Failed to destroy place_id=" << place_id << ": " << st.to_string(); |
394 | 0 | return st; |
395 | 0 | } |
396 | | |
397 | | // Parse unified response: [success: bool, rows_processed: int64, serialized_data: binary] |
398 | 3.56k | if (response->num_rows() != 1) { |
399 | 0 | return Status::InternalError("Invalid DESTROY response: expected 1 row"); |
400 | 0 | } |
401 | | |
402 | 3.56k | auto success_array = std::static_pointer_cast<arrow::BooleanArray>(response->column(0)); |
403 | | |
404 | 3.56k | if (!success_array->Value(0)) { |
405 | 0 | LOG(WARNING) << "DESTROY operation failed for place_id=" << place_id; |
406 | 0 | return make_udaf_failure_status(response, "DESTROY", place_id); |
407 | 0 | } |
408 | | |
409 | 3.56k | return Status::OK(); |
410 | 3.56k | } |
411 | | |
412 | 3.57k | Status PythonUDAFClient::close() { |
413 | 3.58k | if (!_inited || !_writer) return Status::OK(); |
414 | | |
415 | | // Destroy the place if it exists (cleanup on client destruction) |
416 | 3.55k | if (_created_place_id.has_value()) { |
417 | 0 | int64_t place_id = _created_place_id.value(); |
418 | 0 | Status st = destroy(place_id); |
419 | 0 | if (!st.ok()) { |
420 | 0 | LOG(WARNING) << "Failed to destroy place_id=" << place_id |
421 | 0 | << " during close: " << st.to_string(); |
422 | | // Clear tracking even on failure to prevent issues in base class close |
423 | 0 | _created_place_id.reset(); |
424 | 0 | } |
425 | 0 | } |
426 | | |
427 | 3.55k | return PythonClient::close(); |
428 | 3.57k | } |
429 | | |
// Sends one UDAF request over the shared bidirectional Flight stream and reads
// back exactly one unified response batch.
// The fixed-size `metadata` struct travels in Flight's app_metadata side
// channel; `request_batch` carries any column payload. All requests on this
// client are serialized by `_operation_mutex` (one in flight at a time).
Status PythonUDAFClient::_send_request(const UDAFMetadata& metadata,
                                       const std::shared_ptr<arrow::RecordBatch>& request_batch,
                                       std::shared_ptr<arrow::RecordBatch>* response_batch) {
    DCHECK(response_batch != nullptr);

    // Create app_metadata buffer from metadata struct. Zero-copy wrap is safe
    // here: `metadata` outlives the synchronous WriteWithMetadata() call below.
    auto app_metadata =
            arrow::Buffer::Wrap(reinterpret_cast<const uint8_t*>(&metadata), sizeof(metadata));

    std::lock_guard<std::mutex> lock(_operation_mutex);

    // Check if writer/reader are still valid (they could be reset by handle_error)
    if (UNLIKELY(!_writer || !_reader)) {
        return Status::InternalError("{} writer/reader have been closed due to previous error",
                                     _operation_name);
    }

    // Begin stream on first call (using data schema: argument_types + places + binary_data)
    if (UNLIKELY(!_begin)) {
        auto begin_res = _writer->Begin(_schema);
        if (!begin_res.ok()) {
            return handle_error(begin_res);
        }
        _begin = true;
    }

    // Write batch with metadata in app_metadata
    auto write_res = _writer->WriteWithMetadata(*request_batch, app_metadata);
    if (!write_res.ok()) {
        return handle_error(write_res);
    }

    // Read unified response: [success: bool, rows_processed: int64, serialized_data: binary]
    auto read_res = _reader->Next();
    if (!read_res.ok()) {
        return handle_error(read_res.status());
    }

    arrow::flight::FlightStreamChunk chunk = std::move(*read_res);
    if (!chunk.data) {
        return Status::InternalError("Received empty RecordBatch from {} server", _operation_name);
    }

    // Validate unified response schema
    if (!chunk.data->schema()->Equals(kUnifiedUDAFResponseSchema)) {
        return Status::InternalError(
                "Invalid response schema: expected [success: bool, rows_processed: int64, "
                "serialized_data: binary], got {}",
                chunk.data->schema()->ToString());
    }

    *response_batch = std::move(chunk.data);
    return Status::OK();
}
484 | | |
485 | | Status PythonUDAFClient::_create_data_request_batch(const arrow::RecordBatch& input_data, |
486 | 2.58k | std::shared_ptr<arrow::RecordBatch>* out) { |
487 | | // Determine if input has places column |
488 | 2.58k | int num_input_columns = input_data.num_columns(); |
489 | 2.58k | bool has_places = false; |
490 | 2.58k | if (num_input_columns > 0 && |
491 | 2.58k | input_data.schema()->field(num_input_columns - 1)->name() == "places") { |
492 | 720 | has_places = true; |
493 | 720 | } |
494 | | |
495 | | // Expected schema structure: [argument_types..., places, binary_data] |
496 | | // - Input in single-place mode: [argument_types...] |
497 | | // - Input in multi-place mode: [argument_types..., places] |
498 | 2.58k | std::vector<std::shared_ptr<arrow::Array>> columns; |
499 | | // Copy argument_types columns |
500 | 2.58k | int num_arg_columns = has_places ? (num_input_columns - 1) : num_input_columns; |
501 | | |
502 | 5.81k | for (int i = 0; i < num_arg_columns; ++i) { |
503 | 3.22k | columns.push_back(input_data.column(i)); |
504 | 3.22k | } |
505 | | |
506 | | // Add places column |
507 | 2.58k | if (has_places) { |
508 | | // Use existing places column from input |
509 | 721 | columns.push_back(input_data.column(num_input_columns - 1)); |
510 | 1.86k | } else { |
511 | | // Create NULL places column for single-place mode |
512 | 1.86k | arrow::Int64Builder places_builder; |
513 | 1.86k | std::shared_ptr<arrow::Array> places_array; |
514 | 1.86k | RETURN_DORIS_STATUS_IF_ERROR(places_builder.AppendNulls(input_data.num_rows())); |
515 | 1.86k | RETURN_DORIS_STATUS_IF_ERROR(places_builder.Finish(&places_array)); |
516 | 1.86k | columns.push_back(places_array); |
517 | 1.86k | } |
518 | | |
519 | | // Add NULL binary_data column |
520 | 2.58k | arrow::BinaryBuilder binary_builder; |
521 | 2.58k | std::shared_ptr<arrow::Array> binary_array; |
522 | 2.58k | RETURN_DORIS_STATUS_IF_ERROR(binary_builder.AppendNulls(input_data.num_rows())); |
523 | 2.58k | RETURN_DORIS_STATUS_IF_ERROR(binary_builder.Finish(&binary_array)); |
524 | 2.58k | columns.push_back(binary_array); |
525 | | |
526 | 2.58k | *out = arrow::RecordBatch::Make(_schema, input_data.num_rows(), columns); |
527 | 2.58k | return Status::OK(); |
528 | 2.58k | } |
529 | | |
530 | | Status PythonUDAFClient::_create_binary_request_batch( |
531 | | const std::shared_ptr<arrow::Buffer>& binary_data, |
532 | 770 | std::shared_ptr<arrow::RecordBatch>* out) { |
533 | 770 | std::vector<std::shared_ptr<arrow::Array>> columns; |
534 | | |
535 | | // Create NULL arrays for data columns (all columns except the last binary_data column) |
536 | | // Schema: [argument_types..., places, binary_data] |
537 | 770 | int num_data_columns = _schema->num_fields() - 1; |
538 | 2.73k | for (int i = 0; i < num_data_columns; ++i) { |
539 | 1.96k | std::unique_ptr<arrow::ArrayBuilder> builder; |
540 | 1.96k | std::shared_ptr<arrow::Array> null_array; |
541 | 1.96k | RETURN_DORIS_STATUS_IF_ERROR(arrow::MakeBuilder(arrow::default_memory_pool(), |
542 | 1.96k | _schema->field(i)->type(), &builder)); |
543 | 1.96k | RETURN_DORIS_STATUS_IF_ERROR(builder->AppendNull()); |
544 | 1.96k | RETURN_DORIS_STATUS_IF_ERROR(builder->Finish(&null_array)); |
545 | 1.96k | columns.push_back(null_array); |
546 | 1.96k | } |
547 | | |
548 | | // Create binary_data column |
549 | 770 | arrow::BinaryBuilder binary_builder; |
550 | 770 | std::shared_ptr<arrow::Array> binary_array; |
551 | 770 | RETURN_DORIS_STATUS_IF_ERROR( |
552 | 770 | binary_builder.Append(binary_data->data(), static_cast<int32_t>(binary_data->size()))); |
553 | 770 | RETURN_DORIS_STATUS_IF_ERROR(binary_builder.Finish(&binary_array)); |
554 | 770 | columns.push_back(binary_array); |
555 | | |
556 | 770 | *out = arrow::RecordBatch::Make(_schema, 1, columns); |
557 | 770 | return Status::OK(); |
558 | 770 | } |
559 | | |
560 | 10.8k | Status PythonUDAFClient::_get_empty_request_batch(std::shared_ptr<arrow::RecordBatch>* out) { |
561 | | // Return cached batch if already created |
562 | 10.8k | if (_empty_request_batch) { |
563 | 7.28k | *out = _empty_request_batch; |
564 | 7.28k | return Status::OK(); |
565 | 7.28k | } |
566 | | |
567 | | // Create empty batch on first use (all columns NULL, 1 row) |
568 | 3.57k | std::vector<std::shared_ptr<arrow::Array>> columns; |
569 | | |
570 | 16.0k | for (int i = 0; i < _schema->num_fields(); ++i) { |
571 | 12.5k | auto field = _schema->field(i); |
572 | 12.5k | std::unique_ptr<arrow::ArrayBuilder> builder; |
573 | 12.5k | std::shared_ptr<arrow::Array> null_array; |
574 | 12.5k | RETURN_DORIS_STATUS_IF_ERROR( |
575 | 12.5k | arrow::MakeBuilder(arrow::default_memory_pool(), field->type(), &builder)); |
576 | 12.5k | RETURN_DORIS_STATUS_IF_ERROR(builder->AppendNull()); |
577 | 12.5k | RETURN_DORIS_STATUS_IF_ERROR(builder->Finish(&null_array)); |
578 | 12.5k | columns.push_back(null_array); |
579 | 12.5k | } |
580 | | |
581 | 3.57k | _empty_request_batch = arrow::RecordBatch::Make(_schema, 1, columns); |
582 | 3.57k | *out = _empty_request_batch; |
583 | 3.57k | return Status::OK(); |
584 | 3.57k | } |
585 | | |
586 | | } // namespace doris |