Coverage Report

Created: 2026-03-16 21:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_python_udaf.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include "common/status.h"
21
#include "core/column/column.h"
22
#include "core/data_type/data_type.h"
23
#include "core/types.h"
24
#include "exprs/aggregate/aggregate_function.h"
25
#include "udf/python/python_env.h"
26
#include "udf/python/python_udaf_client.h"
27
#include "udf/python/python_udf_meta.h"
28
29
namespace doris {
30
#include "common/compile_check_begin.h"
31
32
/**
33
 * Aggregate state data for Python UDAF.
34
 *
35
 * The Python UDAF state itself is managed remotely (in the Python server).
36
 * The serialized state is cached here for shuffle/merge operations (similar to Java UDAF).
37
 */
38
struct AggregatePythonUDAFData {
39
    std::string serialize_data; // Cached serialized state, used for shuffle/merge.
40
    PythonUDAFClientPtr client; // Client handle; presumably talks to the Python server -- TODO confirm.
41
42
0
    AggregatePythonUDAFData() = default;
43
44
    AggregatePythonUDAFData(const AggregatePythonUDAFData& other)
45
0
            : serialize_data(other.serialize_data), client(other.client) {} // Member-wise copy of both fields.
46
47
0
    ~AggregatePythonUDAFData() = default;
48
49
    Status create(int64_t place); // NOTE(review): `place` is presumably the id of the remote state -- confirm in the .cpp.
50
51
    Status add(int64_t place_id, const IColumn** columns, int64_t row_num_start,
52
               int64_t row_num_end, const DataTypes& argument_types); // Row range [start, end) vs [start, end]: TODO confirm.
53
54
    Status add_batch(AggregateDataPtr* places, size_t place_offset, size_t num_rows,
55
                     const IColumn** columns, const DataTypes& argument_types, size_t start,
56
                     size_t end); // Multi-place variant of add(); takes an array of places instead of one id.
57
58
    Status merge(const AggregatePythonUDAFData& rhs, int64_t place); // `rhs` is taken by const reference.
59
60
    Status write(BufferWritable& buf, int64_t place) const; // Const: does not modify this object's non-mutable members.
61
62
    void read(BufferReadable& buf); // Unlike the other operations, returns void rather than Status.
63
64
    Status reset(int64_t place);
65
66
    Status destroy(int64_t place);
67
68
    Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const; // Output goes into `to`.
69
};
70
71
/**
72
 * Python UDAF aggregate function.
73
 *
74
 * Implements the Snowflake-style UDAF pattern on the Python side:
75
 * - __init__(): Initialize aggregate state
76
 * - aggregate_state: Property returning serializable state
77
 * - accumulate(*args): Add input to state
78
 * - merge(other_state): Combine two states
79
 * - finish(): Get final result
80
 *
81
 * Communicates with the Python server via PythonUDAFClient using Arrow Flight.
82
 */
83
class AggregatePythonUDAF final
84
        : public IAggregateFunctionDataHelper<AggregatePythonUDAFData, AggregatePythonUDAF>,
85
          VarargsExpression, // No access specifier: in a `class` this is a private base.
86
          NullableAggregateFunction { // Also a private base, for the same reason.
87
public:
88
    ENABLE_FACTORY_CREATOR(AggregatePythonUDAF);
89
90
    AggregatePythonUDAF(const TFunction& fn, const DataTypes& argument_types_,
91
                        const DataTypePtr& return_type)
92
0
            : IAggregateFunctionDataHelper(argument_types_), _fn(fn), _return_type(return_type) {} // _func_meta / _python_version are not set here.
93
94
0
    ~AggregatePythonUDAF() override = default;
95
96
    static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_,
97
0
                                       const DataTypePtr& return_type) {
98
0
        return std::make_shared<AggregatePythonUDAF>(fn, argument_types_, return_type); // Does not call open().
99
0
    }
100
101
0
    String get_name() const override { return _fn.name.function_name; }
102
103
0
    DataTypePtr get_return_type() const override { return _return_type; }
104
105
    /**
106
     * Initialize function metadata. Not invoked by the constructor or by create().
107
     */
108
    Status open();
109
110
    /**
111
     * Create aggregate state in the Python server.
112
     */
113
    void create(AggregateDataPtr __restrict place) const override;
114
115
    /**
116
     * Destroy aggregate state in the Python server. Declared noexcept.
117
     */
118
    void destroy(AggregateDataPtr __restrict place) const noexcept override;
119
120
    /**
121
     * Add a single row (row_num of columns) to the aggregate state.
122
     */
123
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
124
             Arena&) const override;
125
126
    /**
127
     * Add a batch of rows to multiple aggregate states (GROUP BY).
128
     */
129
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
130
                   const IColumn** columns, Arena&, bool /*agg_many*/) const override;
131
132
    /**
133
     * Add a batch of rows to a single aggregate state (no GROUP BY).
134
     */
135
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
136
                                Arena&) const override;
137
138
    /**
139
     * Add a range of rows to a single place (for window functions).
140
     */
141
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
142
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
143
                                Arena& arena, UInt8* current_window_empty,
144
                                UInt8* current_window_has_inited) const override;
145
146
    /**
147
     * Reset the aggregate state to its initial value.
148
     */
149
    void reset(AggregateDataPtr place) const override;
150
151
    /**
152
     * Merge the aggregate state `rhs` into `place`.
153
     */
154
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena&) const override;
155
156
    /**
157
     * Serialize the aggregate state into `buf` for shuffle.
158
     */
159
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override;
160
161
    /**
162
     * Deserialize the aggregate state from `buf` after shuffle.
163
     */
164
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena&) const override;
165
166
    /**
167
     * Get the final result and insert it into the output column `to`.
168
     */
169
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override;
170
171
private:
172
    TFunction _fn; // Copy of the function descriptor; source of get_name().
173
    DataTypePtr _return_type; // Returned as-is by get_return_type().
174
    PythonUDFMeta _func_meta; // Not set by the constructor; presumably filled in by open() -- TODO confirm.
175
    PythonVersion _python_version; // Not set by the constructor; presumably filled in by open() -- TODO confirm.
176
    // Arrow Flight schema: [argument_types..., places: int64, binary_data: binary]
177
    // Used for all UDAF RPC operations
178
    // - places column is always present (NULL in single-place mode, actual place_id values in GROUP BY mode)
179
    // - binary_data column contains serialized data for MERGE operations (NULL for ACCUMULATE)
180
    mutable std::shared_ptr<arrow::Schema> _schema; // mutable: assignable from the const member functions above.
181
    mutable std::once_flag _schema_init_flag; // Presumably guards one-time _schema setup via std::call_once -- TODO confirm.
182
};
183
184
} // namespace doris
185
186
#include "common/compile_check_end.h"