be/src/udf/python/python_udaf_client.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <arrow/record_batch.h> |
21 | | #include <arrow/status.h> |
22 | | |
23 | | #include "udf/python/python_client.h" |
24 | | |
25 | | namespace doris { |
26 | | |
class PythonUDAFClient;

// Shared-ownership handle; instances are created via PythonUDAFClient::create().
using PythonUDAFClientPtr = std::shared_ptr<PythonUDAFClient>;

// Fixed-size (30 bytes) binary metadata structure for UDAF operations (Request).
// Packed so that the in-memory layout IS the wire layout sent to the Python
// server as app_metadata (see PythonUDAFClient::_send_request).
// NOTE(review): fields are serialized in host byte order — confirm the Python
// server decodes with matching endianness.
struct __attribute__((packed)) UDAFMetadata {
    uint32_t meta_version;   // 4 bytes: metadata version (current version = 1)
    uint8_t operation;       // 1 byte: UDAFOperation enum
    uint8_t is_single_place; // 1 byte: boolean (0 or 1, ACCUMULATE only)
    int64_t place_id;        // 8 bytes: aggregate state identifier (globally unique)
    int64_t row_start;       // 8 bytes: start row index (ACCUMULATE only)
    int64_t row_end;         // 8 bytes: end row index (exclusive, ACCUMULATE only)
};

// Compile-time guard: any accidental padding or field change would silently
// break the binary protocol with the Python server, so lock the size here.
static_assert(sizeof(UDAFMetadata) == 30, "UDAFMetadata size must be 30 bytes");

// Current metadata version constant; written into UDAFMetadata::meta_version.
constexpr uint32_t UDAF_METADATA_VERSION = 1;
45 | | |
/**
 * Python UDAF Client
 *
 * Implements Snowflake-style UDAF pattern with the following methods:
 * - __init__(): Initialize aggregate state
 * - aggregate_state: Property that returns internal state
 * - accumulate(input): Add new input to aggregate state
 * - merge(other_state): Combine two intermediate states
 * - finish(): Generate final result from aggregate state
 *
 * Communication protocol with Python server:
 * 1. CREATE: Initialize UDAF class instance and get initial state
 * 2. ACCUMULATE: Send input data batch and get updated states
 * 3. SERIALIZE: Get serialized state for shuffle/merge
 * 4. MERGE: Combine serialized states
 * 5. FINALIZE: Get final result from state
 * 6. RESET: Reset state to initial value
 * 7. DESTROY: Clean up resources
 */
class PythonUDAFClient : public PythonClient {
public:
    // UDAF operation types. Encoded into UDAFMetadata::operation (1 byte);
    // values must stay in sync with the Python server's decoder.
    enum class UDAFOperation : uint8_t {
        CREATE = 0,     // Create new aggregate state
        ACCUMULATE = 1, // Add input rows to state
        SERIALIZE = 2,  // Serialize state for shuffle
        MERGE = 3,      // Merge two states
        FINALIZE = 4,   // Get final result
        RESET = 5,      // Reset state
        DESTROY = 6     // Destroy state
    };

    PythonUDAFClient() = default;

    // Destructor performs best-effort cleanup: close() is non-virtual here, so
    // this calls PythonUDAFClient::close() directly, which (per its contract
    // below) destroys the tracked place before closing the connection.
    // Failures are logged, never thrown — destructors must not propagate errors.
    ~PythonUDAFClient() override {
        // Clean up all remaining states on destruction
        auto st = close();
        if (!st.ok()) {
            LOG(WARNING) << "Failed to close PythonUDAFClient in destructor: " << st.to_string();
        }
    }

    /**
     * Factory: construct and initialize a client in one step.
     * @param func_meta Function metadata
     * @param process Python process handle
     * @param data_schema Arrow schema for UDAF data
     * @param client Output: initialized client on success
     * @return Status
     */
    static Status create(const PythonUDFMeta& func_meta, ProcessPtr process,
                         const std::shared_ptr<arrow::Schema>& data_schema,
                         PythonUDAFClientPtr* client);

    /**
     * Initialize UDAF client with data schema
     * Overrides base class to set _schema before initialization
     * NOTE(review): not marked `override` — confirm whether PythonClient::init
     * is virtual with this signature; if not, this declaration hides/overloads
     * the base-class init rather than overriding it.
     * @param func_meta Function metadata
     * @param process Python process handle
     * @param data_schema Arrow schema for UDAF data
     * @return Status
     */
    Status init(const PythonUDFMeta& func_meta, ProcessPtr process,
                const std::shared_ptr<arrow::Schema>& data_schema);

    /**
     * Create aggregate state for a place
     * @param place_id Unique identifier for the aggregate state
     * @return Status
     */
    Status create(int64_t place_id);

    /**
     * Accumulate input data into aggregate state
     *
     * For single-place mode (is_single_place=true):
     * - input RecordBatch contains only data columns
     * - All rows are accumulated to the same place_id
     *
     * For multi-place mode (is_single_place=false):
     * - input RecordBatch MUST contain a "places" column (int64) as the last column
     * - The "places" column indicates which place each row belongs to
     * - place_id parameter is ignored (set to 0 by convention)
     *
     * @param place_id Aggregate state identifier (used only in single-place mode)
     * @param is_single_place Whether all rows go to single place
     * @param input Input data batch (must contain "places" column if is_single_place=false)
     * @param row_start Start row index
     * @param row_end End row index (exclusive)
     * @return Status
     */
    Status accumulate(int64_t place_id, bool is_single_place, const arrow::RecordBatch& input,
                      int64_t row_start, int64_t row_end);

    /**
     * Serialize aggregate state for shuffle/merge
     * @param place_id Aggregate state identifier
     * @param serialized_state Output serialized state
     * @return Status
     */
    Status serialize(int64_t place_id, std::shared_ptr<arrow::Buffer>* serialized_state);

    /**
     * Merge another serialized state into current state
     * @param place_id Target aggregate state identifier
     * @param serialized_state Serialized state to merge
     * @return Status
     */
    Status merge(int64_t place_id, const std::shared_ptr<arrow::Buffer>& serialized_state);

    /**
     * Get final result from aggregate state
     * @param place_id Aggregate state identifier
     * @param output Output result
     * @return Status
     */
    Status finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output);

    /**
     * Reset aggregate state to initial value
     * @param place_id Aggregate state identifier
     * @return Status
     */
    Status reset(int64_t place_id);

    /**
     * Destroy aggregate state and free resources
     * @param place_id Aggregate state identifier
     * @return Status
     */
    Status destroy(int64_t place_id);

    /**
     * Close client connection and cleanup
     * Overrides base class to destroy the tracked place first
     * NOTE(review): not marked `override` — confirm whether PythonClient::close
     * is virtual; if so, add `override`, otherwise this hides the base function.
     * @return Status
     */
    Status close();

#ifdef BE_TEST
    // Test-only shim exposing make_udaf_failure_status (private below).
    static Status make_udaf_failure_status_for_test(
            const std::shared_ptr<arrow::RecordBatch>& response, const char* operation,
            int64_t place_id);
#endif

private:
    DISALLOW_COPY_AND_ASSIGN(PythonUDAFClient);

    // Build an error Status describing a failed UDAF operation from the
    // server's response batch (used by the public operation methods).
    static Status make_udaf_failure_status(const std::shared_ptr<arrow::RecordBatch>& response,
                                           const char* operation, int64_t place_id);

    /**
     * Send RecordBatch request to Python server with app_metadata
     * @param metadata UDAFMetadata structure (will be sent as app_metadata)
     * @param request_batch Request RecordBatch (contains data columns + binary_data column)
     * @param response_batch Output RecordBatch
     * @return Status
     */
    Status _send_request(const UDAFMetadata& metadata,
                         const std::shared_ptr<arrow::RecordBatch>& request_batch,
                         std::shared_ptr<arrow::RecordBatch>* response_batch);

    /**
     * Create request batch with data columns (for ACCUMULATE)
     * Appends NULL binary_data column to input data batch
     */
    Status _create_data_request_batch(const arrow::RecordBatch& input_data,
                                      std::shared_ptr<arrow::RecordBatch>* out);

    /**
     * Create request batch with binary data (for MERGE)
     * Creates NULL data columns + binary_data column
     */
    Status _create_binary_request_batch(const std::shared_ptr<arrow::Buffer>& binary_data,
                                        std::shared_ptr<arrow::RecordBatch>* out);

    /**
     * Get or create empty request batch (for CREATE/SERIALIZE/FINALIZE/RESET/DESTROY)
     * All columns are NULL. Cached after first creation for reuse.
     */
    Status _get_empty_request_batch(std::shared_ptr<arrow::RecordBatch>* out);

    // Arrow Flight schema: [argument_types..., places: int64, binary_data: binary]
    std::shared_ptr<arrow::Schema> _schema;
    // Cached all-NULL batch returned by _get_empty_request_batch.
    std::shared_ptr<arrow::RecordBatch> _empty_request_batch;
    // Track created state for cleanup (destroyed by close()/destructor).
    std::optional<int64_t> _created_place_id;
    // Thread safety: protect gRPC stream operations
    // CRITICAL: gRPC ClientReaderWriter does NOT support concurrent Write() calls
    // Even within same thread, multiple pipeline tasks may trigger concurrent operations
    // (e.g., normal accumulate() + cleanup destroy() during task finalization)
    mutable std::mutex _operation_mutex;
};
231 | | |
232 | | } // namespace doris |