be/src/udf/python/python_udaf_client.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <arrow/status.h> |
21 | | |
22 | | #include "udf/python/python_client.h" |
23 | | |
24 | | namespace doris { |
25 | | |
26 | | class PythonUDAFClient; |
27 | | |
28 | | using PythonUDAFClientPtr = std::shared_ptr<PythonUDAFClient>; |
29 | | |
30 | | // Fixed-size (30 bytes) binary metadata structure for UDAF operations (Request) |
 | | // NOTE(review): __attribute__((packed)) removes all padding so the struct can be |
 | | // shipped as a raw byte blob; field order and widths presumably must match the |
 | | // Python server's decoder — confirm before adding or reordering fields. Bump |
 | | // meta_version (and UDAF_METADATA_VERSION below) on any layout change. |
31 | | struct __attribute__((packed)) UDAFMetadata { |
32 | | uint32_t meta_version; // 4 bytes: metadata version (current version = 1) |
33 | | uint8_t operation; // 1 byte: UDAFOperation enum |
34 | | uint8_t is_single_place; // 1 byte: boolean (0 or 1, ACCUMULATE only) |
35 | | int64_t place_id; // 8 bytes: aggregate state identifier (globally unique) |
36 | | int64_t row_start; // 8 bytes: start row index (ACCUMULATE only) |
37 | | int64_t row_end; // 8 bytes: end row index (exclusive, ACCUMULATE only) |
38 | | }; |
39 | | |
 | | // Compile-time layout guard: 4 + 1 + 1 + 8 + 8 + 8 = 30 bytes (packed, no padding). |
40 | | static_assert(sizeof(UDAFMetadata) == 30, "UDAFMetadata size must be 30 bytes"); |
41 | | |
42 | | // Current metadata version constant |
43 | | constexpr uint32_t UDAF_METADATA_VERSION = 1; |
44 | | |
45 | | /** |
46 | | * Python UDAF Client |
47 | | * |
48 | | * Implements Snowflake-style UDAF pattern with the following methods: |
49 | | * - __init__(): Initialize aggregate state |
50 | | * - aggregate_state: Property that returns internal state |
51 | | * - accumulate(input): Add new input to aggregate state |
52 | | * - merge(other_state): Combine two intermediate states |
53 | | * - finish(): Generate final result from aggregate state |
54 | | * |
55 | | * Communication protocol with Python server: |
56 | | * 1. CREATE: Initialize UDAF class instance and get initial state |
57 | | * 2. ACCUMULATE: Send input data batch and get updated states |
58 | | * 3. SERIALIZE: Get serialized state for shuffle/merge |
59 | | * 4. MERGE: Combine serialized states |
60 | | * 5. FINALIZE: Get final result from state |
61 | | * 6. RESET: Reset state to initial value |
62 | | * 7. DESTROY: Clean up resources |
63 | | */ |
64 | | class PythonUDAFClient : public PythonClient { |
65 | | public: |
66 | | // UDAF operation types |
 | | // NOTE(review): numeric values travel on the wire as UDAFMetadata.operation |
 | | // (uint8_t); do not renumber without coordinating with the Python server. |
67 | | enum class UDAFOperation : uint8_t { |
68 | | CREATE = 0, // Create new aggregate state |
69 | | ACCUMULATE = 1, // Add input rows to state |
70 | | SERIALIZE = 2, // Serialize state for shuffle |
71 | | MERGE = 3, // Merge two states |
72 | | FINALIZE = 4, // Get final result |
73 | | RESET = 5, // Reset state |
74 | | DESTROY = 6 // Destroy state |
75 | | }; |
76 | | |
77 | 3.07k | PythonUDAFClient() = default;
78 | | |
79 | 3.07k | ~PythonUDAFClient() override { |
80 | | // Clean up all remaining states on destruction |
 | | // Best-effort: a destructor must not throw/propagate, so a failed close() |
 | | // is only logged, never surfaced to the caller. |
81 | 3.07k | auto st = close(); |
82 | 3.07k | if (!st.ok()) { |
83 | | LOG(WARNING) << "Failed to close PythonUDAFClient in destructor: " << st.to_string(); |
84 | 0 | } |
85 | 3.07k | } |
86 | | |
 | | // Factory: builds and initializes a client for the given function/process/schema. |
 | | // Note: overloads the member create(int64_t) below; this static one constructs |
 | | // the client object itself, the member one creates a remote aggregate state. |
87 | | static Status create(const PythonUDFMeta& func_meta, ProcessPtr process, |
88 | | const std::shared_ptr<arrow::Schema>& data_schema, |
89 | | PythonUDAFClientPtr* client); |
90 | | |
91 | | /** |
92 | | * Initialize UDAF client with data schema |
93 | | * Overrides base class to set _schema before initialization |
94 | | * @param func_meta Function metadata |
95 | | * @param process Python process handle |
96 | | * @param data_schema Arrow schema for UDAF data |
97 | | * @return Status |
98 | | */ |
99 | | Status init(const PythonUDFMeta& func_meta, ProcessPtr process, |
100 | | const std::shared_ptr<arrow::Schema>& data_schema); |
101 | | |
102 | | /** |
103 | | * Create aggregate state for a place |
104 | | * @param place_id Unique identifier for the aggregate state |
105 | | * @return Status |
106 | | */ |
107 | | Status create(int64_t place_id); |
108 | | |
109 | | /** |
110 | | * Accumulate input data into aggregate state |
111 | | * |
112 | | * For single-place mode (is_single_place=true): |
113 | | * - input RecordBatch contains only data columns |
114 | | * - All rows are accumulated to the same place_id |
115 | | * |
116 | | * For multi-place mode (is_single_place=false): |
117 | | * - input RecordBatch MUST contain a "places" column (int64) as the last column |
118 | | * - The "places" column indicates which place each row belongs to |
119 | | * - place_id parameter is ignored (set to 0 by convention) |
120 | | * |
121 | | * @param place_id Aggregate state identifier (used only in single-place mode) |
122 | | * @param is_single_place Whether all rows go to single place |
123 | | * @param input Input data batch (must contain "places" column if is_single_place=false) |
124 | | * @param row_start Start row index |
125 | | * @param row_end End row index (exclusive) |
126 | | * @return Status |
127 | | */ |
128 | | Status accumulate(int64_t place_id, bool is_single_place, const arrow::RecordBatch& input, |
129 | | int64_t row_start, int64_t row_end); |
130 | | |
131 | | /** |
132 | | * Serialize aggregate state for shuffle/merge |
133 | | * @param place_id Aggregate state identifier |
134 | | * @param serialized_state Output serialized state |
135 | | * @return Status |
136 | | */ |
137 | | Status serialize(int64_t place_id, std::shared_ptr<arrow::Buffer>* serialized_state); |
138 | | |
139 | | /** |
140 | | * Merge another serialized state into current state |
141 | | * @param place_id Target aggregate state identifier |
142 | | * @param serialized_state Serialized state to merge |
143 | | * @return Status |
144 | | */ |
145 | | Status merge(int64_t place_id, const std::shared_ptr<arrow::Buffer>& serialized_state); |
146 | | |
147 | | /** |
148 | | * Get final result from aggregate state |
149 | | * @param place_id Aggregate state identifier |
150 | | * @param output Output result |
151 | | * @return Status |
152 | | */ |
153 | | Status finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output); |
154 | | |
155 | | /** |
156 | | * Reset aggregate state to initial value |
157 | | * @param place_id Aggregate state identifier |
158 | | * @return Status |
159 | | */ |
160 | | Status reset(int64_t place_id); |
161 | | |
162 | | /** |
163 | | * Destroy aggregate state and free resources |
164 | | * @param place_id Aggregate state identifier |
165 | | * @return Status |
166 | | */ |
167 | | Status destroy(int64_t place_id); |
168 | | |
169 | | /** |
170 | | * Close client connection and cleanup |
171 | | * Overrides base class to destroy the tracked place first |
172 | | * @return Status |
173 | | */ |
 | | // NOTE(review): not declared 'override' — presumably PythonClient::close() is |
 | | // non-virtual and this intentionally shadows it; confirm callers always invoke |
 | | // close() through the derived type (the destructor above does). |
174 | | Status close(); |
175 | | |
176 | | private: |
177 | | DISALLOW_COPY_AND_ASSIGN(PythonUDAFClient); |
178 | | |
179 | | /** |
180 | | * Send RecordBatch request to Python server with app_metadata |
181 | | * @param metadata UDAFMetadata structure (will be sent as app_metadata) |
182 | | * @param request_batch Request RecordBatch (contains data columns + binary_data column) |
183 | | * @param response_batch Output RecordBatch |
184 | | * @return Status |
185 | | */ |
186 | | Status _send_request(const UDAFMetadata& metadata, |
187 | | const std::shared_ptr<arrow::RecordBatch>& request_batch, |
188 | | std::shared_ptr<arrow::RecordBatch>* response_batch); |
189 | | |
190 | | /** |
191 | | * Create request batch with data columns (for ACCUMULATE) |
192 | | * Appends NULL binary_data column to input data batch |
193 | | */ |
194 | | Status _create_data_request_batch(const arrow::RecordBatch& input_data, |
195 | | std::shared_ptr<arrow::RecordBatch>* out); |
196 | | |
197 | | /** |
198 | | * Create request batch with binary data (for MERGE) |
199 | | * Creates NULL data columns + binary_data column |
200 | | */ |
201 | | Status _create_binary_request_batch(const std::shared_ptr<arrow::Buffer>& binary_data, |
202 | | std::shared_ptr<arrow::RecordBatch>* out); |
203 | | |
204 | | /** |
205 | | * Get or create empty request batch (for CREATE/SERIALIZE/FINALIZE/RESET/DESTROY) |
206 | | * All columns are NULL. Cached after first creation for reuse. |
207 | | */ |
208 | | Status _get_empty_request_batch(std::shared_ptr<arrow::RecordBatch>* out); |
209 | | |
210 | | // Arrow Flight schema: [argument_types..., places: int64, binary_data: binary] |
211 | | std::shared_ptr<arrow::Schema> _schema; |
 | | // Lazily-built all-NULL batch reused by _get_empty_request_batch(). |
212 | | std::shared_ptr<arrow::RecordBatch> _empty_request_batch; |
213 | | // Track created state for cleanup |
214 | | std::optional<int64_t> _created_place_id; |
215 | | // Thread safety: protect gRPC stream operations |
216 | | // CRITICAL: gRPC ClientReaderWriter does NOT support concurrent Write() calls |
217 | | // Even within same thread, multiple pipeline tasks may trigger concurrent operations |
218 | | // (e.g., normal accumulate() + cleanup destroy() during task finalization) |
219 | | mutable std::mutex _operation_mutex; |
220 | | }; |
221 | | |
222 | | } // namespace doris |