Coverage Report

Created: 2026-04-10 04:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_python_udaf.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include "common/status.h"
21
#include "core/column/column.h"
22
#include "core/data_type/data_type.h"
23
#include "core/types.h"
24
#include "exprs/aggregate/aggregate_function.h"
25
#include "udf/python/python_env.h"
26
#include "udf/python/python_udaf_client.h"
27
#include "udf/python/python_udf_meta.h"
28
29
namespace doris {
30
31
/**
32
 * Aggregate state data for Python UDAF
33
 * 
34
 * Python UDAF state is managed remotely (in Python server).
35
 * We cache serialized state for shuffle/merge operations (similar to Java UDAF).
36
 */
37
struct AggregatePythonUDAFData { // per-state data holder; every non-special member function is only declared in this header
38
    std::string serialize_data; // serialized state bytes; NOTE(review): presumably the cache used for shuffle/merge — confirm in the definitions
39
    PythonUDAFClientPtr client; // client handle; NOTE(review): presumably the link to the Python server — confirm in python_udaf_client.h
40
41
0
    AggregatePythonUDAFData() = default;
42
43
    AggregatePythonUDAFData(const AggregatePythonUDAFData& other) // member-wise copy of both fields, same effect as a defaulted copy constructor
44
0
            : serialize_data(other.serialize_data), client(other.client) {}
45
46
0
    ~AggregatePythonUDAFData() = default; // NOTE(review): user-declared copy ctor + dtor suppress the implicit move operations, so moves fall back to copies
47
48
    Status create(int64_t place); // NOTE(review): `place` presumably identifies the remote state — confirm against the definition
49
50
    Status add(int64_t place_id, const IColumn** columns, int64_t row_num_start, // feeds a row range of `columns` to one state
51
               int64_t row_num_end, const DataTypes& argument_types); // NOTE(review): whether row_num_end is inclusive is not visible here — confirm
52
53
    Status add_batch(AggregateDataPtr* places, size_t place_offset, size_t num_rows, // multi-place variant: takes an array of places rather than one id
54
                     const IColumn** columns, const DataTypes& argument_types, size_t start,
55
                     size_t end); // NOTE(review): relation between num_rows and [start, end] is not visible here — confirm
56
57
    Status merge(const AggregatePythonUDAFData& rhs, int64_t place); // rhs is read-only; NOTE(review): presumably merges rhs into the state at `place` — confirm
58
59
    Status write(BufferWritable& buf, int64_t place) const; // const: does not modify this object's members
60
61
    void read(BufferReadable& buf); // unlike its siblings this returns void, so it has no Status channel for reporting failure
62
63
    Status reset(int64_t place);
64
65
    Status destroy(int64_t place);
66
67
    Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const; // `to` is the only non-const output parameter; const on this object
68
};
69
70
/**
71
 * Python UDAF Aggregate Function
72
 * 
73
 * Implements Snowflake-style UDAF pattern:
74
 * - __init__(): Initialize aggregate state
75
 * - aggregate_state: Property returning serializable state
76
 * - accumulate(*args): Add input to state
77
 * - merge(other_state): Combine two states
78
 * - finish(): Get final result
79
 * 
80
 * Communication with Python server via PythonUDAFClient using Arrow Flight.
81
 */
82
class AggregatePythonUDAF final // only the constructor, destructor, static create() and the two getters are defined inline; the rest is declared only
83
        : public IAggregateFunctionDataHelper<AggregatePythonUDAFData, AggregatePythonUDAF>,
84
          VarargsExpression,
85
          NullableAggregateFunction { // the two bases without an access specifier are inherited privately (class default)
86
public:
87
    ENABLE_FACTORY_CREATOR(AggregatePythonUDAF); // project macro defined elsewhere; NOTE(review): presumably adds factory helpers — confirm
88
89
    AggregatePythonUDAF(const TFunction& fn, const DataTypes& argument_types_,
90
                        const DataTypePtr& return_type)
91
0
            : IAggregateFunctionDataHelper(argument_types_), _fn(fn), _return_type(return_type) {} // copies fn and return_type; _func_meta, _python_version and _schema are not set here
92
93
0
    ~AggregatePythonUDAF() override = default;
94
95
    static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_,
96
0
                                       const DataTypePtr& return_type) {
97
0
        return std::make_shared<AggregatePythonUDAF>(fn, argument_types_, return_type); // forwards the three arguments to the constructor; does not call open()
98
0
    }
99
100
0
    String get_name() const override { return _fn.name.function_name; } // name comes from the TFunction copied at construction
101
102
0
    DataTypePtr get_return_type() const override { return _return_type; } // return type supplied at construction
103
104
    /**
105
     * Initialize function metadata
106
     */
107
    Status open(); // NOTE(review): presumably populates _func_meta / _python_version — confirm in the definition
108
109
    /**
110
     * Create aggregate state in Python server
111
     */
112
    void create(AggregateDataPtr __restrict place) const override; // overload of the static factory create() above; returns void, so no Status is surfaced
113
114
    /**
115
     * Destroy aggregate state in Python server
116
     */
117
    void destroy(AggregateDataPtr __restrict place) const noexcept override; // noexcept: an exception escaping the definition would terminate the program
118
119
    /**
120
     * Add single row to aggregate state
121
     */
122
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
123
             Arena&) const override; // Arena parameter is unnamed in this declaration
124
125
    /**
126
     * Add batch of rows to multiple aggregate states (GROUP BY)
127
     */
128
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
129
                   const IColumn** columns, Arena&, bool /*agg_many*/) const override; // Arena and agg_many are unnamed in this declaration
130
131
    /**
132
     * Add batch of rows to single aggregate state (no GROUP BY)
133
     */
134
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
135
                                Arena&) const override;
136
137
    /**
138
     * Add range of rows to single place (for window functions)
139
     */
140
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
141
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
142
                                Arena& arena, UInt8* current_window_empty,
143
                                UInt8* current_window_has_inited) const override; // NOTE(review): how the two UInt8 out-flags are used is not visible here — confirm
144
145
    /**
146
     * Reset aggregate state to initial value
147
     */
148
    void reset(AggregateDataPtr place) const override;
149
150
    /**
151
     * Merge two aggregate states
152
     */
153
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena&) const override; // rhs is taken as a const pointer; `place` is the mutable side
154
155
    /**
156
     * Serialize aggregate state for shuffle
157
     */
158
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override;
159
160
    /**
161
     * Deserialize aggregate state from shuffle
162
     */
163
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena&) const override;
164
165
    /**
166
     * Get final result and insert into output column
167
     */
168
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override;
169
170
private:
171
    TFunction _fn; // by-value copy of the descriptor passed to the constructor
172
    DataTypePtr _return_type; // value handed back by get_return_type()
173
    PythonUDFMeta _func_meta; // not set by the constructor; NOTE(review): presumably filled by open() — confirm
174
    PythonVersion _python_version; // not set by the constructor; NOTE(review): presumably filled by open() — confirm
175
    // Arrow Flight schema: [argument_types..., places: int64, binary_data: binary]
176
    // Used for all UDAF RPC operations
177
    // - places column is always present (NULL in single-place mode, actual place_id values in GROUP BY mode)
178
    // - binary_data column contains serialized data for MERGE operations (NULL for ACCUMULATE)
179
    mutable std::shared_ptr<arrow::Schema> _schema; // mutable so the const member functions can assign it
180
    mutable std::once_flag _schema_init_flag; // NOTE(review): suggests one-time init via std::call_once — confirm; <mutex> is not included directly by this header
181
};
182
183
} // namespace doris