Coverage Report

Created: 2026-04-01 10:28

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_python_udaf.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "exprs/aggregate/aggregate_function_python_udaf.h"
19
20
#include <arrow/array.h>
21
#include <arrow/memory_pool.h>
22
#include <arrow/record_batch.h>
23
#include <arrow/type.h>
24
#include <fmt/format.h>
25
26
#include "common/exception.h"
27
#include "common/logging.h"
28
#include "core/block/block.h"
29
#include "core/column/column_nullable.h"
30
#include "core/column/column_vector.h"
31
#include "core/data_type/data_type_factory.hpp"
32
#include "core/data_type/data_type_nullable.h"
33
#include "core/data_type/define_primitive_type.h"
34
#include "format/arrow/arrow_block_convertor.h"
35
#include "format/arrow/arrow_row_batch.h"
36
#include "runtime/user_function_cache.h"
37
#include "udf/python/python_env.h"
38
#include "udf/python/python_server.h"
39
#include "util/timezone_utils.h"
40
41
namespace doris {
42
43
0
// Asks the Python server, through the attached client, to create the UDAF state
// identified by `place`. Any error status from the client is returned unchanged.
Status AggregatePythonUDAFData::create(int64_t place) {
    DCHECK(client) << "Client must be set before calling create";
    return client->create(place);
}
48
49
// Sends rows [row_num_start, row_num_end) of the argument columns to the Python server
// as one Arrow batch for the single state identified by `place_id`.
// The full columns are wrapped in a Block (their ColumnPtrs are shared, not copied);
// the row range is applied while converting to Arrow.
Status AggregatePythonUDAFData::add(int64_t place_id, const IColumn** columns,
                                    int64_t row_num_start, int64_t row_num_end,
                                    const DataTypes& argument_types) {
    DCHECK(client) << "Client must be set before calling add";

    Block args_block;
    const size_t arg_count = argument_types.size();
    for (size_t col = 0; col < arg_count; ++col) {
        args_block.insert(ColumnWithTypeAndName(columns[col]->get_ptr(), argument_types[col],
                                                std::to_string(col)));
    }

    const auto& tz_name = TimezoneUtils::default_time_zone;

    std::shared_ptr<arrow::Schema> arrow_schema;
    RETURN_IF_ERROR(get_arrow_schema_from_block(args_block, &arrow_schema, tz_name));

    cctz::time_zone ctz;
    TimezoneUtils::find_cctz_time_zone(tz_name, ctz);

    std::shared_ptr<arrow::RecordBatch> record_batch;
    RETURN_IF_ERROR(convert_to_arrow_batch(args_block, arrow_schema, arrow::default_memory_pool(),
                                           &record_batch, ctz, row_num_start, row_num_end));

    // is_single_place = true: every row belongs to place_id, so no places column is attached.
    RETURN_IF_ERROR(
            client->accumulate(place_id, true, *record_batch, 0, record_batch->num_rows()));
    return Status::OK();
}
76
77
// Sends rows [start, end) of a multi-place batch to the Python server in one request.
// Besides the argument columns, a trailing "places" BIGINT column carries, for each row,
// the address of that row's state (places[row] + place_offset). The server can therefore
// route every row itself (is_single_place = false).
// `places` and the argument columns span `num_rows` rows; only [start, end) is converted.
Status AggregatePythonUDAFData::add_batch(AggregateDataPtr* places, size_t place_offset,
                                          size_t num_rows, const IColumn** columns,
                                          const DataTypes& argument_types, size_t start,
                                          size_t end) {
    DCHECK(client) << "Client must be set before calling add_batch";
    DCHECK(end > start) << "end must be greater than start";
    DCHECK(end <= num_rows) << "end must not exceed num_rows";

    Block block;
    for (size_t col = 0; col < argument_types.size(); ++col) {
        DCHECK(columns[col]->size() == num_rows) << "Column size must match num_rows";
        block.insert(ColumnWithTypeAndName(columns[col]->get_ptr(), argument_types[col],
                                           std::to_string(col)));
    }

    // Only slots [start, end) are filled here; the rest are never converted below.
    auto place_ids = ColumnInt64::create(num_rows);
    auto& place_id_data = place_ids->get_data();
    for (size_t row = start; row < end; ++row) {
        place_id_data[row] = reinterpret_cast<int64_t>(places[row] + place_offset);
    }

    static DataTypePtr places_type =
            DataTypeFactory::instance().create_data_type(PrimitiveType::TYPE_BIGINT, false);
    block.insert(ColumnWithTypeAndName(std::move(place_ids), places_type, "places"));

    const auto& tz_name = TimezoneUtils::default_time_zone;

    std::shared_ptr<arrow::Schema> arrow_schema;
    RETURN_IF_ERROR(get_arrow_schema_from_block(block, &arrow_schema, tz_name));

    cctz::time_zone ctz;
    TimezoneUtils::find_cctz_time_zone(tz_name, ctz);

    // The [start, end) slice also covers the places column.
    std::shared_ptr<arrow::RecordBatch> record_batch;
    RETURN_IF_ERROR(convert_to_arrow_batch(block, arrow_schema, arrow::default_memory_pool(),
                                           &record_batch, ctz, start, end));

    // The place_id argument (0) is ignored when is_single_place = false; routing uses "places".
    RETURN_IF_ERROR(client->accumulate(0, false, *record_batch, 0, end - start));
    return Status::OK();
}
121
122
0
// Merges a peer state into the state of `place` on the Python server.
// The peer state is taken from rhs.serialize_data, which read() fills from a serialized
// buffer. The bytes are wrapped without copying, so rhs must outlive the client->merge call.
// NOTE(review): if rhs was never populated by read(), an empty buffer is sent — confirm the
// framework only calls this after deserialization.
Status AggregatePythonUDAFData::merge(const AggregatePythonUDAFData& rhs, int64_t place) {
    DCHECK(client) << "Client must be set before calling merge";

    const auto* peer_bytes = reinterpret_cast<const uint8_t*>(rhs.serialize_data.data());
    const auto peer_len = rhs.serialize_data.size();
    auto peer_state = arrow::Buffer::Wrap(peer_bytes, peer_len);
    RETURN_IF_ERROR(client->merge(place, peer_state));
    return Status::OK();
}
131
132
0
// Serializes the Python-side state of `place` and appends it to `buf` as a single binary
// value via write_binary; read() reads it back with read_binary.
// Returns the client's error status, or InternalError if the client reported success but
// produced no buffer.
Status AggregatePythonUDAFData::write(BufferWritable& buf, int64_t place) const {
    DCHECK(client) << "Client must be set before calling write";

    std::shared_ptr<arrow::Buffer> serialized_state;
    RETURN_IF_ERROR(client->serialize(place, &serialized_state));
    // Guard against an OK status without a buffer; dereferencing it below would crash.
    if (serialized_state == nullptr) {
        return Status::InternalError("Python UDAF serialize returned no state for place {}",
                                     place);
    }

    const char* data = reinterpret_cast<const char*>(serialized_state->data());
    size_t size = serialized_state->size();
    buf.write_binary(StringRef {data, size});
    return Status::OK();
}
143
144
0
// Loads a serialized state from `buf` into serialize_data. Nothing is sent to the Python
// server here; merge() forwards these bytes later (during deserialize_and_merge()).
void AggregatePythonUDAFData::read(BufferReadable& buf) {
    auto& target = this->serialize_data;
    buf.read_binary(target);
}
149
150
0
// Asks the Python server to reset the state of `place` back to its initial value.
// The state itself keeps existing; any client error status is returned unchanged.
Status AggregatePythonUDAFData::reset(int64_t place) {
    DCHECK(client) << "Client must be set before calling reset";
    return client->reset(place);
}
156
157
0
// Asks the Python server to drop the state of `place`; any client error status is
// returned unchanged.
Status AggregatePythonUDAFData::destroy(int64_t place) {
    DCHECK(client) << "Client must be set before calling destroy";
    return client->destroy(place);
}
162
163
// Finalizes the Python-side state of `place` and appends the resulting value to `to`.
// The finalize batch is converted from Arrow using `result_type`.
// Returns the client's or converter's error status. Returns InternalError if no batch was
// produced, or if the converted block does not contain exactly one row.
Status AggregatePythonUDAFData::get(IColumn& to, const DataTypePtr& result_type,
                                    int64_t place) const {
    DCHECK(client) << "Client must be set before calling get";

    std::shared_ptr<arrow::RecordBatch> result;
    RETURN_IF_ERROR(client->finalize(place, &result));
    // Guard against an OK status without a batch; convert_from_arrow_batch is presumably
    // not null-safe.
    if (result == nullptr) {
        return Status::InternalError("Python UDAF finalize returned no result for place {}",
                                     place);
    }

    Block result_block;
    DataTypes types = {result_type};
    cctz::time_zone timezone_obj;
    TimezoneUtils::find_cctz_time_zone(TimezoneUtils::default_time_zone, timezone_obj);
    RETURN_IF_ERROR(convert_from_arrow_batch(result, types, &result_block, timezone_obj));

    if (result_block.rows() != 1) {
        return Status::InternalError("Expected 1 row in result block, got {}", result_block.rows());
    }

    const auto& result_column = result_block.get_by_position(0).column;
    to.insert_from(*result_column, 0);
    return Status::OK();
}
187
188
0
// Fills _func_meta from the TFunction descriptor (_fn) and resolves the Python runtime.
// Returns InvalidArgument when the aggregate symbol or the runtime version is missing.
// Otherwise propagates errors from the version lookup, _func_meta.check(), and, for
// module-based functions, UserFunctionCache::get_pypath().
Status AggregatePythonUDAF::open() {
    _func_meta.id = _fn.id;
    _func_meta.name = _fn.name.function_name;

    // A UDAF's entry-point symbol is carried in aggregate_fn and is mandatory.
    const bool symbol_present = _fn.__isset.aggregate_fn && _fn.aggregate_fn.__isset.symbol;
    if (!symbol_present) {
        return Status::InvalidArgument("Python UDAF symbol is not set");
    }
    _func_meta.symbol = _fn.aggregate_fn.symbol;

    // Code source: inline code takes precedence over a module location.
    // If neither is present, the load type is UNKNOWN.
    if (!_fn.function_code.empty()) {
        _func_meta.type = PythonUDFLoadType::INLINE;
        _func_meta.location = "inline";
        _func_meta.inline_code = _fn.function_code;
    } else if (!_fn.hdfs_location.empty()) {
        _func_meta.type = PythonUDFLoadType::MODULE;
        _func_meta.location = _fn.hdfs_location;
        _func_meta.checksum = _fn.checksum;
    } else {
        _func_meta.type = PythonUDFLoadType::UNKNOWN;
        _func_meta.location = "unknown";
    }

    _func_meta.input_types = argument_types;
    _func_meta.return_type = _return_type;
    _func_meta.client_type = PythonClientType::UDAF;

    // A non-empty runtime version is mandatory; it is looked up via PythonVersionManager.
    const bool version_present = _fn.__isset.runtime_version && !_fn.runtime_version.empty();
    if (!version_present) {
        return Status::InvalidArgument("Python UDAF runtime version is not set");
    }
    RETURN_IF_ERROR(
            PythonVersionManager::instance().get_version(_fn.runtime_version, &_python_version));

    _func_meta.runtime_version = _python_version.full_version;
    RETURN_IF_ERROR(_func_meta.check());
    _func_meta.always_nullable = _return_type->is_nullable();

    LOG(INFO) << fmt::format("Creating Python UDAF: {}, runtime_version: {}, func_meta: {}",
                             _fn.name.function_name, _python_version.to_string(),
                             _func_meta.to_string());

    // For module-based functions, get_pypath() overwrites _func_meta.location
    // (presumably with a locally cached path).
    if (_func_meta.type == PythonUDFLoadType::MODULE) {
        RETURN_IF_ERROR(UserFunctionCache::instance()->get_pypath(
                _func_meta.id, _func_meta.location, _func_meta.checksum, &_func_meta.location));
    }

    return Status::OK();
}
241
242
0
// Constructs the aggregate state at `place`.
// The steps are:
//   1. Build the Arrow schema shared by all states, once (guarded by _schema_init_flag).
//   2. Placement-new the Data.
//   3. Obtain a Python server client for this function.
//   4. Create the server-side state keyed by the address of `place`.
// If step 3 or 4 fails, the Data is destroyed here before a doris::Exception is thrown.
// NOTE(review): presumably callers do not call destroy() on a place whose create() threw —
// confirm.
void AggregatePythonUDAF::create(AggregateDataPtr __restrict place) const {
    std::call_once(_schema_init_flag, [this]() {
        const std::string tz_name = TimezoneUtils::default_time_zone;
        std::vector<std::shared_ptr<arrow::Field>> schema_fields;
        schema_fields.reserve(argument_types.size() + 2);

        // One field per argument, named by its position.
        for (size_t idx = 0; idx < argument_types.size(); ++idx) {
            const auto& arg_type = argument_types[idx];
            std::shared_ptr<arrow::DataType> arrow_type;
            if (Status st = convert_to_arrow_type(arg_type, &arrow_type, tz_name); !st.ok()) {
                throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                       "Failed to convert argument type {} to Arrow type: {}", idx,
                                       st.to_string());
            }
            schema_fields.push_back(create_arrow_field_with_metadata(
                    std::to_string(idx), arrow_type, arg_type->is_nullable(),
                    arg_type->get_primitive_type()));
        }

        // Always-present trailing fields: state addresses for GROUP BY routing,
        // and a binary column used by merge operations.
        schema_fields.push_back(arrow::field("places", arrow::int64()));
        schema_fields.push_back(arrow::field("binary_data", arrow::binary()));
        _schema = arrow::schema(schema_fields);
    });

    new (place) Data();
    DCHECK(reinterpret_cast<Data*>(place)) << "Place must not be null";

    auto& state = this->data(place);

    Status client_st = PythonServerManager::instance().get_client(_func_meta, _python_version,
                                                                  &state.client, _schema);
    if (UNLIKELY(!client_st.ok())) {
        state.~Data();
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Failed to get Python UDAF client: {}",
                               client_st.to_string());
    }

    // The server-side state is keyed by the address of `place`.
    const auto place_id = reinterpret_cast<int64_t>(place);
    Status create_st = state.create(place_id);
    if (UNLIKELY(!create_st.ok())) {
        state.~Data();
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, create_st.to_string());
    }
}
286
287
0
// Releases the state stored at `place`.
// If a client is attached, the Python server is asked to drop the per-place state. This is
// best-effort: error statuses and exceptions are only logged. The client handle is then reset.
// ~Data() always runs afterwards, even when the server-side cleanup throws. Previously, an
// exception skipped it and leaked Data's members (client handle, serialize_data).
// Declared noexcept: nothing thrown by the server call escapes.
void AggregatePythonUDAF::destroy(AggregateDataPtr __restrict place) const noexcept {
    auto& state = this->data(place);

    if (state.client) {
        const auto place_id = reinterpret_cast<int64_t>(place);
        try {
            Status st = state.destroy(place_id);
            if (UNLIKELY(!st.ok())) {
                LOG(WARNING) << "Failed to destroy Python UDAF state for place_id=" << place_id
                             << ", function=" << _func_meta.name << ": " << st.to_string();
            }
        } catch (const std::exception& e) {
            LOG(ERROR) << "Exception in AggregatePythonUDAF::destroy: " << e.what();
        } catch (...) {
            LOG(ERROR) << "Unknown exception in AggregatePythonUDAF::destroy";
        }
        state.client.reset();
    }

    state.~Data();
}
309
310
void AggregatePythonUDAF::add(AggregateDataPtr __restrict place, const IColumn** columns,
311
0
                              ssize_t row_num, Arena&) const {
312
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
313
0
    Status st = this->data(place).add(place_id, columns, row_num, row_num + 1, argument_types);
314
0
    if (UNLIKELY(!st.ok())) {
315
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
316
0
    }
317
0
}
318
319
void AggregatePythonUDAF::add_batch(size_t batch_size, AggregateDataPtr* places,
320
                                    size_t place_offset, const IColumn** columns, Arena&,
321
0
                                    bool /*agg_many*/) const {
322
0
    if (batch_size == 0) return;
323
324
0
    size_t start = 0;
325
0
    while (start < batch_size) {
326
        // Get the starting place for this segment
327
0
        AggregateDataPtr start_place = places[start] + place_offset;
328
0
        auto& start_place_data = this->data(start_place);
329
        // Get the process for this segment
330
0
        const auto* current_process = start_place_data.client->get_process().get();
331
332
        // Scan forward to find the end of this consecutive segment (same process)
333
0
        size_t end = start + 1;
334
0
        while (end < batch_size) {
335
0
            AggregateDataPtr end_place = places[end] + place_offset;
336
0
            auto& end_place_data = this->data(end_place);
337
0
            const auto* next_process = end_place_data.client->get_process().get();
338
            // If different process, end the current segment
339
0
            if (*next_process != *current_process) break;
340
0
            ++end;
341
0
        }
342
343
        // Send this segment to Python with zero-copy
344
        // Pass places array and let add_batch construct place_ids on-demand
345
0
        Status st = start_place_data.add_batch(places, place_offset, batch_size, columns,
346
0
                                               argument_types, start, end);
347
348
0
        if (UNLIKELY(!st.ok())) {
349
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
350
0
                                   "Failed to send segment to Python: " + st.to_string());
351
0
        }
352
353
0
        start = end;
354
0
    }
355
0
}
356
357
void AggregatePythonUDAF::add_batch_single_place(size_t batch_size, AggregateDataPtr place,
358
0
                                                 const IColumn** columns, Arena&) const {
359
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
360
0
    Status st = this->data(place).add(place_id, columns, 0, batch_size, argument_types);
361
0
    if (UNLIKELY(!st.ok())) {
362
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
363
0
    }
364
0
}
365
366
void AggregatePythonUDAF::add_range_single_place(int64_t partition_start, int64_t partition_end,
367
                                                 int64_t frame_start, int64_t frame_end,
368
                                                 AggregateDataPtr place, const IColumn** columns,
369
                                                 Arena& arena, UInt8* current_window_empty,
370
0
                                                 UInt8* current_window_has_inited) const {
371
    // Calculate actual frame range
372
0
    frame_start = std::max<int64_t>(frame_start, partition_start);
373
0
    frame_end = std::min<int64_t>(frame_end, partition_end);
374
375
0
    if (frame_start >= frame_end) {
376
0
        if (!*current_window_has_inited) {
377
0
            *current_window_empty = true;
378
0
        }
379
0
        return;
380
0
    }
381
382
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
383
0
    Status st = this->data(place).add(place_id, columns, frame_start, frame_end, argument_types);
384
0
    if (UNLIKELY(!st.ok())) {
385
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
386
0
    }
387
388
0
    *current_window_empty = false;
389
0
    *current_window_has_inited = true;
390
0
}
391
392
0
// Resets the state at `place` to its initial value on the Python server.
// Throws doris::Exception(INTERNAL_ERROR) on failure.
void AggregatePythonUDAF::reset(AggregateDataPtr place) const {
    auto& state = this->data(place);
    Status st = state.reset(reinterpret_cast<int64_t>(place));
    if (UNLIKELY(!st.ok())) {
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
    }
}
399
400
void AggregatePythonUDAF::merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
401
0
                                Arena&) const {
402
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
403
0
    Status st = this->data(place).merge(this->data(rhs), place_id);
404
0
    if (UNLIKELY(!st.ok())) {
405
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
406
0
    }
407
0
}
408
409
void AggregatePythonUDAF::serialize(ConstAggregateDataPtr __restrict place,
410
0
                                    BufferWritable& buf) const {
411
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
412
0
    Status st = this->data(place).write(buf, place_id);
413
0
    if (UNLIKELY(!st.ok())) {
414
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
415
0
    }
416
0
}
417
418
void AggregatePythonUDAF::deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
419
0
                                      Arena&) const {
420
0
    this->data(place).read(buf);
421
0
}
422
423
void AggregatePythonUDAF::insert_result_into(ConstAggregateDataPtr __restrict place,
424
0
                                             IColumn& to) const {
425
0
    int64_t place_id = reinterpret_cast<int64_t>(place);
426
0
    Status st = this->data(place).get(to, _return_type, place_id);
427
0
    if (UNLIKELY(!st.ok())) {
428
0
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
429
0
    }
430
0
}
431
432
} // namespace doris