be/src/exprs/aggregate/aggregate_function_python_udaf.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include "common/status.h" |
21 | | #include "core/column/column.h" |
22 | | #include "core/data_type/data_type.h" |
23 | | #include "core/types.h" |
24 | | #include "exprs/aggregate/aggregate_function.h" |
25 | | #include "udf/python/python_env.h" |
26 | | #include "udf/python/python_udaf_client.h" |
27 | | #include "udf/python/python_udf_meta.h" |
28 | | |
29 | | namespace doris { |
30 | | #include "common/compile_check_begin.h" |
31 | | |
32 | | /** |
33 | | * Aggregate state data for Python UDAF |
34 | | * |
35 | | * Python UDAF state is managed remotely (in Python server). |
36 | | * We cache serialized state for shuffle/merge operations (similar to Java UDAF). |
37 | | */ |
38 | | struct AggregatePythonUDAFData { |
39 | | std::string serialize_data; |
40 | | PythonUDAFClientPtr client; |
41 | | |
42 | 0 | AggregatePythonUDAFData() = default; |
43 | | |
44 | | AggregatePythonUDAFData(const AggregatePythonUDAFData& other) |
45 | 0 | : serialize_data(other.serialize_data), client(other.client) {} |
46 | | |
47 | 0 | ~AggregatePythonUDAFData() = default; |
48 | | |
49 | | Status create(int64_t place); |
50 | | |
51 | | Status add(int64_t place_id, const IColumn** columns, int64_t row_num_start, |
52 | | int64_t row_num_end, const DataTypes& argument_types); |
53 | | |
54 | | Status add_batch(AggregateDataPtr* places, size_t place_offset, size_t num_rows, |
55 | | const IColumn** columns, const DataTypes& argument_types, size_t start, |
56 | | size_t end); |
57 | | |
58 | | Status merge(const AggregatePythonUDAFData& rhs, int64_t place); |
59 | | |
60 | | Status write(BufferWritable& buf, int64_t place) const; |
61 | | |
62 | | void read(BufferReadable& buf); |
63 | | |
64 | | Status reset(int64_t place); |
65 | | |
66 | | Status destroy(int64_t place); |
67 | | |
68 | | Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const; |
69 | | }; |
70 | | |
71 | | /** |
72 | | * Python UDAF Aggregate Function |
73 | | * |
74 | | * Implements Snowflake-style UDAF pattern: |
75 | | * - __init__(): Initialize aggregate state |
76 | | * - aggregate_state: Property returning serializable state |
77 | | * - accumulate(*args): Add input to state |
78 | | * - merge(other_state): Combine two states |
79 | | * - finish(): Get final result |
80 | | * |
81 | | * Communication with Python server via PythonUDAFClient using Arrow Flight. |
82 | | */ |
83 | | class AggregatePythonUDAF final |
84 | | : public IAggregateFunctionDataHelper<AggregatePythonUDAFData, AggregatePythonUDAF>, |
85 | | VarargsExpression, |
86 | | NullableAggregateFunction { |
87 | | public: |
88 | | ENABLE_FACTORY_CREATOR(AggregatePythonUDAF); |
89 | | |
90 | | AggregatePythonUDAF(const TFunction& fn, const DataTypes& argument_types_, |
91 | | const DataTypePtr& return_type) |
92 | 0 | : IAggregateFunctionDataHelper(argument_types_), _fn(fn), _return_type(return_type) {} |
93 | | |
94 | 0 | ~AggregatePythonUDAF() override = default; |
95 | | |
96 | | static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_, |
97 | 0 | const DataTypePtr& return_type) { |
98 | 0 | return std::make_shared<AggregatePythonUDAF>(fn, argument_types_, return_type); |
99 | 0 | } |
100 | | |
101 | 0 | String get_name() const override { return _fn.name.function_name; } |
102 | | |
103 | 0 | DataTypePtr get_return_type() const override { return _return_type; } |
104 | | |
105 | | /** |
106 | | * Initialize function metadata |
107 | | */ |
108 | | Status open(); |
109 | | |
110 | | /** |
111 | | * Create aggregate state in Python server |
112 | | */ |
113 | | void create(AggregateDataPtr __restrict place) const override; |
114 | | |
115 | | /** |
116 | | * Destroy aggregate state in Python server |
117 | | */ |
118 | | void destroy(AggregateDataPtr __restrict place) const noexcept override; |
119 | | |
120 | | /** |
121 | | * Add single row to aggregate state |
122 | | */ |
123 | | void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, |
124 | | Arena&) const override; |
125 | | |
126 | | /** |
127 | | * Add batch of rows to multiple aggregate states (GROUP BY) |
128 | | */ |
129 | | void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset, |
130 | | const IColumn** columns, Arena&, bool /*agg_many*/) const override; |
131 | | |
132 | | /** |
133 | | * Add batch of rows to single aggregate state (no GROUP BY) |
134 | | */ |
135 | | void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, |
136 | | Arena&) const override; |
137 | | |
138 | | /** |
139 | | * Add range of rows to single place (for window functions) |
140 | | */ |
141 | | void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start, |
142 | | int64_t frame_end, AggregateDataPtr place, const IColumn** columns, |
143 | | Arena& arena, UInt8* current_window_empty, |
144 | | UInt8* current_window_has_inited) const override; |
145 | | |
146 | | /** |
147 | | * Reset aggregate state to initial value |
148 | | */ |
149 | | void reset(AggregateDataPtr place) const override; |
150 | | |
151 | | /** |
152 | | * Merge two aggregate states |
153 | | */ |
154 | | void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena&) const override; |
155 | | |
156 | | /** |
157 | | * Serialize aggregate state for shuffle |
158 | | */ |
159 | | void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override; |
160 | | |
161 | | /** |
162 | | * Deserialize aggregate state from shuffle |
163 | | */ |
164 | | void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena&) const override; |
165 | | |
166 | | /** |
167 | | * Get final result and insert into output column |
168 | | */ |
169 | | void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override; |
170 | | |
171 | | private: |
172 | | TFunction _fn; |
173 | | DataTypePtr _return_type; |
174 | | PythonUDFMeta _func_meta; |
175 | | PythonVersion _python_version; |
176 | | // Arrow Flight schema: [argument_types..., places: int64, binary_data: binary] |
177 | | // Used for all UDAF RPC operations |
178 | | // - places column is always present (NULL in single-place mode, actual place_id values in GROUP BY mode) |
179 | | // - binary_data column contains serialized data for MERGE operations (NULL for ACCUMULATE) |
180 | | mutable std::shared_ptr<arrow::Schema> _schema; |
181 | | mutable std::once_flag _schema_init_flag; |
182 | | }; |
183 | | |
184 | | } // namespace doris |
185 | | |
186 | | #include "common/compile_check_end.h" |