be/src/udf/python/python_udtf_client.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "udf/python/python_udtf_client.h" |
19 | | |
20 | | #include "arrow/array/array_nested.h" |
21 | | #include "arrow/array/array_primitive.h" |
22 | | #include "arrow/record_batch.h" |
23 | | #include "arrow/type.h" |
24 | | #include "common/status.h" |
25 | | |
26 | | namespace doris { |
27 | | |
28 | | Status PythonUDTFClient::create(const PythonUDFMeta& func_meta, ProcessPtr process, |
29 | 1.64k | PythonUDTFClientPtr* client) { |
30 | 1.64k | PythonUDTFClientPtr python_udtf_client = std::make_shared<PythonUDTFClient>(); |
31 | 1.64k | RETURN_IF_ERROR(python_udtf_client->init(func_meta, std::move(process))); |
32 | 1.64k | *client = std::move(python_udtf_client); |
33 | 1.64k | return Status::OK(); |
34 | 1.64k | } |
35 | | |
36 | | Status PythonUDTFClient::evaluate(const arrow::RecordBatch& input, |
37 | 265 | std::shared_ptr<arrow::ListArray>* list_array) { |
38 | 265 | RETURN_IF_ERROR(begin_stream(input.schema())); |
39 | 265 | RETURN_IF_ERROR(write_batch(input)); |
40 | | |
41 | | // Read the response (ListArray-based) |
42 | 265 | std::shared_ptr<arrow::RecordBatch> response_batch; |
43 | 265 | RETURN_IF_ERROR(read_batch(&response_batch)); |
44 | | |
45 | | // Validate response structure: should have a single ListArray column |
46 | 265 | if (response_batch->num_columns() != 1) { |
47 | 0 | return Status::InternalError( |
48 | 0 | fmt::format("Invalid UDTF response: expected 1 column (ListArray), got {}", |
49 | 0 | response_batch->num_columns())); |
50 | 0 | } |
51 | | |
52 | 265 | auto list_array_ptr = response_batch->column(0); |
53 | 265 | if (list_array_ptr->type_id() != arrow::Type::LIST) { |
54 | 0 | return Status::InternalError( |
55 | 0 | fmt::format("Invalid UDTF response: expected ListArray, got type {}", |
56 | 0 | list_array_ptr->type()->ToString())); |
57 | 0 | } |
58 | | |
59 | 265 | *list_array = std::static_pointer_cast<arrow::ListArray>(list_array_ptr); |
60 | 265 | return Status::OK(); |
61 | 265 | } |
62 | | |
63 | | } // namespace doris |