Coverage Report

Created: 2026-03-13 21:50

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_python_udaf.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "exprs/aggregate/aggregate_function_python_udaf.h"
19
20
#include <arrow/array.h>
21
#include <arrow/memory_pool.h>
22
#include <arrow/record_batch.h>
23
#include <arrow/type.h>
24
#include <fmt/format.h>
25
26
#include "common/exception.h"
27
#include "common/logging.h"
28
#include "core/block/block.h"
29
#include "core/column/column_nullable.h"
30
#include "core/column/column_vector.h"
31
#include "core/data_type/data_type_factory.hpp"
32
#include "core/data_type/data_type_nullable.h"
33
#include "core/data_type/define_primitive_type.h"
34
#include "format/arrow/arrow_block_convertor.h"
35
#include "format/arrow/arrow_row_batch.h"
36
#include "runtime/user_function_cache.h"
37
#include "udf/python/python_env.h"
38
#include "udf/python/python_server.h"
39
#include "util/timezone_utils.h"
40
41
namespace doris {
42
43
0
// Forwards a "create state" request for `place` to the Python client.
// Returns whatever status the client reports (OK on success).
Status AggregatePythonUDAFData::create(int64_t place) {
    DCHECK(client) << "Client must be set before calling create";
    if (Status st = client->create(place); UNLIKELY(!st.ok())) {
        return st;
    }
    return Status::OK();
}
48
49
// Sends rows [row_num_start, row_num_end) of the argument columns to the Python
// client so they are accumulated into the single state identified by `place_id`.
//
// The argument columns are wrapped in a Block (named "0", "1", ... by position),
// converted to an Arrow RecordBatch for the requested row range only, and the
// whole resulting batch is passed to client->accumulate().
Status AggregatePythonUDAFData::add(int64_t place_id, const IColumn** columns,
                                    int64_t row_num_start, int64_t row_num_end,
                                    const DataTypes& argument_types) {
    DCHECK(client) << "Client must be set before calling add";

    // Wrap the full argument columns; the row range is applied during conversion.
    Block args_block;
    size_t col_idx = 0;
    for (const auto& arg_type : argument_types) {
        args_block.insert(ColumnWithTypeAndName(columns[col_idx]->get_ptr(), arg_type,
                                                std::to_string(col_idx)));
        ++col_idx;
    }

    std::shared_ptr<arrow::Schema> arrow_schema;
    RETURN_IF_ERROR(get_arrow_schema_from_block(args_block, &arrow_schema,
                                                TimezoneUtils::default_time_zone));

    cctz::time_zone tz;
    TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, tz);

    // Only [row_num_start, row_num_end) is converted into the Arrow batch.
    std::shared_ptr<arrow::RecordBatch> record_batch;
    RETURN_IF_ERROR(convert_to_arrow_batch(args_block, arrow_schema, arrow::default_memory_pool(),
                                           &record_batch, tz, row_num_start, row_num_end));

    // Single-place call: the batch carries no places column and is sent in full.
    RETURN_IF_ERROR(
            client->accumulate(place_id, true, *record_batch, 0, record_batch->num_rows()));
    return Status::OK();
}
76
77
// Sends rows [start, end) of the argument columns, together with a per-row
// "places" column, to the Python client in multi-place mode.
//
// `places[i] + place_offset` is the state address for row i; it is encoded as an
// int64 in the extra "places" column. All `num_rows` rows are wrapped in the
// Block, but only [start, end) is filled in the places column and converted to
// Arrow.
Status AggregatePythonUDAFData::add_batch(AggregateDataPtr* places, size_t place_offset,
                                          size_t num_rows, const IColumn** columns,
                                          const DataTypes& argument_types, size_t start,
                                          size_t end) {
    DCHECK(client) << "Client must be set before calling add_batch";
    DCHECK(end > start) << "end must be greater than start";
    DCHECK(end <= num_rows) << "end must not exceed num_rows";

    const size_t rows_in_slice = end - start;

    // Argument columns, named by positional index.
    Block args_block;
    const size_t arg_count = argument_types.size();
    for (size_t col = 0; col < arg_count; ++col) {
        DCHECK(columns[col]->size() == num_rows) << "Column size must match num_rows";
        args_block.insert(ColumnWithTypeAndName(columns[col]->get_ptr(), argument_types[col],
                                                std::to_string(col)));
    }

    // The places column has the same length as the inputs; only the rows of the
    // slice are written because only those rows are converted below.
    auto place_ids_col = ColumnInt64::create(num_rows);
    auto& place_ids = place_ids_col->get_data();
    for (size_t row = start; row < end; ++row) {
        place_ids[row] = reinterpret_cast<int64_t>(places[row] + place_offset);
    }

    static DataTypePtr places_type =
            DataTypeFactory::instance().create_data_type(PrimitiveType::TYPE_BIGINT, false);
    args_block.insert(ColumnWithTypeAndName(std::move(place_ids_col), places_type, "places"));

    std::shared_ptr<arrow::Schema> arrow_schema;
    RETURN_IF_ERROR(get_arrow_schema_from_block(args_block, &arrow_schema,
                                                TimezoneUtils::default_time_zone));

    cctz::time_zone tz;
    TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, tz);

    // Convert only [start, end); the places column is part of the block and is
    // therefore included in the converted batch.
    std::shared_ptr<arrow::RecordBatch> record_batch;
    RETURN_IF_ERROR(convert_to_arrow_batch(args_block, arrow_schema, arrow::default_memory_pool(),
                                           &record_batch, tz, start, end));

    // Multi-place call: place_id is passed as 0 (per the original author's note it
    // is ignored in this mode — confirm against the client implementation).
    RETURN_IF_ERROR(client->accumulate(0, false, *record_batch, 0, rows_in_slice));
    return Status::OK();
}
121
122
0
// Merges the serialized state held in `rhs.serialize_data` into the state
// identified by `place` on the Python side.
//
// The bytes are wrapped in an arrow::Buffer via arrow::Buffer::Wrap and handed
// to client->merge().
Status AggregatePythonUDAFData::merge(const AggregatePythonUDAFData& rhs, int64_t place) {
    DCHECK(client) << "Client must be set before calling merge";

    const auto* state_bytes = reinterpret_cast<const uint8_t*>(rhs.serialize_data.data());
    auto state_buffer = arrow::Buffer::Wrap(state_bytes, rhs.serialize_data.size());

    if (Status st = client->merge(place, state_buffer); UNLIKELY(!st.ok())) {
        return st;
    }
    return Status::OK();
}
131
132
0
// Asks the Python client to serialize the state identified by `place` and writes
// the resulting bytes to `buf` as a single binary value.
Status AggregatePythonUDAFData::write(BufferWritable& buf, int64_t place) const {
    DCHECK(client) << "Client must be set before calling write";

    std::shared_ptr<arrow::Buffer> state;
    RETURN_IF_ERROR(client->serialize(place, &state));

    // View the Arrow buffer's bytes as a StringRef for write_binary().
    const char* bytes = reinterpret_cast<const char*>(state->data());
    const size_t byte_count = state->size();
    buf.write_binary(StringRef {bytes, byte_count});
    return Status::OK();
}
143
144
0
// Reads one binary value from `buf` into `serialize_data`.
// merge() later wraps these bytes and sends them to the Python client.
void AggregatePythonUDAFData::read(BufferReadable& buf) {
    buf.read_binary(this->serialize_data);
}
149
150
0
// Forwards a reset request for the state identified by `place` to the Python client.
// Per the original author's note, the state is kept and returned to its
// initial value rather than removed.
Status AggregatePythonUDAFData::reset(int64_t place) {
    DCHECK(client) << "Client must be set before calling reset";
    if (Status st = client->reset(place); UNLIKELY(!st.ok())) {
        return st;
    }
    return Status::OK();
}
156
157
0
// Forwards a destroy request for the state identified by `place` to the Python client.
// Returns the client's status unchanged on failure.
Status AggregatePythonUDAFData::destroy(int64_t place) {
    DCHECK(client) << "Client must be set before calling destroy";
    if (Status st = client->destroy(place); UNLIKELY(!st.ok())) {
        return st;
    }
    return Status::OK();
}
162
163
// Finalizes the state identified by `place` on the Python side and appends the
// single resulting value to `to`.
//
// The Arrow RecordBatch returned by client->finalize() is converted into a Block
// with one column of `result_type`. Returns InternalError unless the converted
// block has exactly one row.
Status AggregatePythonUDAFData::get(IColumn& to, const DataTypePtr& result_type,
                                    int64_t place) const {
    DCHECK(client) << "Client must be set before calling get";

    std::shared_ptr<arrow::RecordBatch> final_batch;
    RETURN_IF_ERROR(client->finalize(place, &final_batch));

    // Arrow -> Block conversion using the default time zone.
    Block converted;
    DataTypes result_types = {result_type};
    cctz::time_zone tz;
    TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, tz);
    RETURN_IF_ERROR(convert_from_arrow_batch(final_batch, result_types, &converted, tz));

    if (converted.rows() != 1) {
        return Status::InternalError("Expected 1 row in result block, got {}", converted.rows());
    }

    // Copy row 0 of the only result column into the output column.
    to.insert_from(*converted.get_by_position(0).column, 0);
    return Status::OK();
}
187
188
0
// Populates `_func_meta` and `_python_version` from the function descriptor `_fn`.
//
// Returns InvalidArgument when the aggregate symbol or the runtime version is
// missing, and propagates errors from the version lookup, `_func_meta.check()`
// and, for module-type functions, the pypath lookup in the user function cache.
Status AggregatePythonUDAF::open() {
    _func_meta.id = _fn.id;
    _func_meta.name = _fn.name.function_name;

    // A UDAF carries its symbol inside aggregate_fn; it is mandatory.
    const bool has_symbol = _fn.__isset.aggregate_fn && _fn.aggregate_fn.__isset.symbol;
    if (!has_symbol) {
        return Status::InvalidArgument("Python UDAF symbol is not set");
    }
    _func_meta.symbol = _fn.aggregate_fn.symbol;

    // Load type: inline code takes precedence over a module location.
    const bool has_inline_code = !_fn.function_code.empty();
    const bool has_module_location = !_fn.hdfs_location.empty();
    if (has_inline_code) {
        _func_meta.type = PythonUDFLoadType::INLINE;
        _func_meta.location = "inline";
        _func_meta.inline_code = _fn.function_code;
    } else if (has_module_location) {
        _func_meta.type = PythonUDFLoadType::MODULE;
        _func_meta.location = _fn.hdfs_location;
        _func_meta.checksum = _fn.checksum;
    } else {
        _func_meta.type = PythonUDFLoadType::UNKNOWN;
        _func_meta.location = "unknown";
    }

    _func_meta.input_types = argument_types;
    _func_meta.return_type = _return_type;
    _func_meta.client_type = PythonClientType::UDAF;

    // The runtime version must be present and resolvable.
    if (!_fn.__isset.runtime_version || _fn.runtime_version.empty()) {
        return Status::InvalidArgument("Python UDAF runtime version is not set");
    }
    RETURN_IF_ERROR(
            PythonVersionManager::instance().get_version(_fn.runtime_version, &_python_version));

    _func_meta.runtime_version = _python_version.full_version;
    RETURN_IF_ERROR(_func_meta.check());
    _func_meta.always_nullable = _return_type->is_nullable();

    LOG(INFO) << fmt::format("Creating Python UDAF: {}, runtime_version: {}, func_meta: {}",
                             _fn.name.function_name, _python_version.to_string(),
                             _func_meta.to_string());

    // For module-type functions the location is replaced by the path returned
    // from the user function cache.
    if (_func_meta.type == PythonUDFLoadType::MODULE) {
        RETURN_IF_ERROR(UserFunctionCache::instance()->get_pypath(
                _func_meta.id, _func_meta.location, _func_meta.checksum, &_func_meta.location));
    }

    return Status::OK();
}
241
242
0
// Constructs the aggregate state at `place` and registers it with the Python server.
//
// Steps:
//   1. Build the shared Arrow schema once (guarded by `_schema_init_flag`): one
//      field per argument named by its index, followed by an int64 "places"
//      field and a binary "binary_data" field.
//   2. Placement-construct `Data` at `place`.
//   3. Obtain a client from PythonServerManager and store it in the state.
//   4. Ask the client to create the state, keyed by the address of `place`.
//
// Throws doris::Exception(INTERNAL_ERROR) if an argument type cannot be converted
// to Arrow, if no client can be obtained, or if state creation fails. In the
// latter two cases the just-constructed `Data` is destroyed before throwing.
void AggregatePythonUDAF::create(AggregateDataPtr __restrict place) const {
    // Validate the destination before constructing into it. Checking after the
    // placement-new would be too late to catch a null place.
    DCHECK(place != nullptr) << "Place must not be null";

    std::call_once(_schema_init_flag, [this]() {
        std::vector<std::shared_ptr<arrow::Field>> fields;
        // Argument fields plus the two trailing bookkeeping fields.
        fields.reserve(argument_types.size() + 2);

        std::string timezone = TimezoneUtils::default_time_zone;
        for (size_t i = 0; i < argument_types.size(); ++i) {
            std::shared_ptr<arrow::DataType> arrow_type;
            Status st = convert_to_arrow_type(argument_types[i], &arrow_type, timezone);
            if (!st.ok()) {
                throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                       "Failed to convert argument type {} to Arrow type: {}", i,
                                       st.to_string());
            }
            fields.push_back(arrow::field(std::to_string(i), arrow_type));
        }

        // "places" column: used by GROUP BY (multi-place) accumulation.
        fields.push_back(arrow::field("places", arrow::int64()));
        // "binary_data" column: used by merge operations.
        fields.push_back(arrow::field("binary_data", arrow::binary()));
        _schema = arrow::schema(fields);
    });

    // Initialize the data structure in the caller-provided storage.
    new (place) Data();

    if (Status st = PythonServerManager::instance().get_client(
                _func_meta, _python_version, &(this->data(place).client), _schema);
        UNLIKELY(!st.ok())) {
        this->data(place).~Data();
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Failed to get Python UDAF client: {}",
                               st.to_string());
    }

    // The address of `place` doubles as the state id on the Python server.
    int64_t place_id = reinterpret_cast<int64_t>(place);
    if (Status st = this->data(place).create(place_id); UNLIKELY(!st.ok())) {
        this->data(place).~Data();
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
    }
}
284
285
0
// Tears down the aggregate state at `place`.
//
// If a client is attached, the state is first destroyed on the Python server
// (a failure is logged as a warning, not raised) and the client reference is
// dropped. The `Data` object is then destructed in place. Any exception is
// caught and logged so that nothing escapes this noexcept function.
void AggregatePythonUDAF::destroy(AggregateDataPtr __restrict place) const noexcept {
    try {
        const int64_t place_id = reinterpret_cast<int64_t>(place);
        auto& state = this->data(place);

        if (state.client) {
            // Release the server-side state before dropping our client reference.
            if (Status st = state.destroy(place_id); UNLIKELY(!st.ok())) {
                LOG(WARNING) << "Failed to destroy Python UDAF state for place_id=" << place_id
                             << ", function=" << _func_meta.name << ": " << st.to_string();
            }
            state.client.reset();
        }

        state.~Data();
    } catch (const std::exception& e) {
        LOG(ERROR) << "Exception in AggregatePythonUDAF::destroy: " << e.what();
    } catch (...) {
        LOG(ERROR) << "Unknown exception in AggregatePythonUDAF::destroy";
    }
}
307
308
void AggregatePythonUDAF::add(AggregateDataPtr __restrict place, const IColumn** columns,
309
0
                              ssize_t row_num, Arena&) const {
310
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
311
0
    Status st = this->data(place).add(place_id, columns, row_num, row_num + 1, argument_types);
312
0
    if (UNLIKELY(!st.ok())) {
313
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
314
0
    }
315
0
}
316
317
void AggregatePythonUDAF::add_batch(size_t batch_size, AggregateDataPtr* places,
318
                                    size_t place_offset, const IColumn** columns, Arena&,
319
0
                                    bool /*agg_many*/) const {
320
0
    if (batch_size == 0) return;
321
322
0
    size_t start = 0;
323
0
    while (start < batch_size) {
324
        // Get the starting place for this segment
325
0
        AggregateDataPtr start_place = places[start] + place_offset;
326
0
        auto& start_place_data = this->data(start_place);
327
        // Get the process for this segment
328
0
        const auto* current_process = start_place_data.client->get_process().get();
329
330
        // Scan forward to find the end of this consecutive segment (same process)
331
0
        size_t end = start + 1;
332
0
        while (end < batch_size) {
333
0
            AggregateDataPtr end_place = places[end] + place_offset;
334
0
            auto& end_place_data = this->data(end_place);
335
0
            const auto* next_process = end_place_data.client->get_process().get();
336
            // If different process, end the current segment
337
0
            if (*next_process != *current_process) break;
338
0
            ++end;
339
0
        }
340
341
        // Send this segment to Python with zero-copy
342
        // Pass places array and let add_batch construct place_ids on-demand
343
0
        Status st = start_place_data.add_batch(places, place_offset, batch_size, columns,
344
0
                                               argument_types, start, end);
345
346
0
        if (UNLIKELY(!st.ok())) {
347
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
348
0
                                   "Failed to send segment to Python: " + st.to_string());
349
0
        }
350
351
0
        start = end;
352
0
    }
353
0
}
354
355
void AggregatePythonUDAF::add_batch_single_place(size_t batch_size, AggregateDataPtr place,
356
0
                                                 const IColumn** columns, Arena&) const {
357
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
358
0
    Status st = this->data(place).add(place_id, columns, 0, batch_size, argument_types);
359
0
    if (UNLIKELY(!st.ok())) {
360
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
361
0
    }
362
0
}
363
364
void AggregatePythonUDAF::add_range_single_place(int64_t partition_start, int64_t partition_end,
365
                                                 int64_t frame_start, int64_t frame_end,
366
                                                 AggregateDataPtr place, const IColumn** columns,
367
                                                 Arena& arena, UInt8* current_window_empty,
368
0
                                                 UInt8* current_window_has_inited) const {
369
    // Calculate actual frame range
370
0
    frame_start = std::max<int64_t>(frame_start, partition_start);
371
0
    frame_end = std::min<int64_t>(frame_end, partition_end);
372
373
0
    if (frame_start >= frame_end) {
374
0
        if (!*current_window_has_inited) {
375
0
            *current_window_empty = true;
376
0
        }
377
0
        return;
378
0
    }
379
380
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
381
0
    Status st = this->data(place).add(place_id, columns, frame_start, frame_end, argument_types);
382
0
    if (UNLIKELY(!st.ok())) {
383
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
384
0
    }
385
386
0
    *current_window_empty = false;
387
0
    *current_window_has_inited = true;
388
0
}
389
390
0
// Resets the state at `place` via its Data object.
// Throws doris::Exception(INTERNAL_ERROR) if the reset fails.
void AggregatePythonUDAF::reset(AggregateDataPtr place) const {
    const int64_t place_id = reinterpret_cast<int64_t>(place);
    if (Status st = this->data(place).reset(place_id); UNLIKELY(!st.ok())) {
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
    }
}
397
398
void AggregatePythonUDAF::merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
399
0
                                Arena&) const {
400
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
401
0
    Status st = this->data(place).merge(this->data(rhs), place_id);
402
0
    if (UNLIKELY(!st.ok())) {
403
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
404
0
    }
405
0
}
406
407
void AggregatePythonUDAF::serialize(ConstAggregateDataPtr __restrict place,
408
0
                                    BufferWritable& buf) const {
409
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
410
0
    Status st = this->data(place).write(buf, place_id);
411
0
    if (UNLIKELY(!st.ok())) {
412
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
413
0
    }
414
0
}
415
416
void AggregatePythonUDAF::deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
417
0
                                      Arena&) const {
418
0
    this->data(place).read(buf);
419
0
}
420
421
void AggregatePythonUDAF::insert_result_into(ConstAggregateDataPtr __restrict place,
422
0
                                             IColumn& to) const {
423
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
424
0
    Status st = this->data(place).get(to, _return_type, place_id);
425
0
    if (UNLIKELY(!st.ok())) {
426
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
427
0
    }
428
0
}
429
430
} // namespace doris