Coverage Report

Created: 2026-05-12 20:22

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/vectorized_agg_fn.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "exprs/vectorized_agg_fn.h"
19
20
#include <fmt/format.h>
21
#include <fmt/ranges.h> // IWYU pragma: keep
22
#include <gen_cpp/Exprs_types.h>
23
#include <gen_cpp/PlanNodes_types.h>
24
#include <glog/logging.h>
25
26
#include <memory>
27
#include <ostream>
28
#include <string_view>
29
30
#include "common/config.h"
31
#include "common/object_pool.h"
32
#include "core/block/block.h"
33
#include "core/block/column_with_type_and_name.h"
34
#include "core/block/materialize_block.h"
35
#include "core/data_type/data_type_agg_state.h"
36
#include "core/data_type/data_type_factory.hpp"
37
#include "exec/common/util.hpp"
38
#include "exprs/aggregate/aggregate_function_ai_agg.h"
39
#include "exprs/aggregate/aggregate_function_java_udaf.h"
40
#include "exprs/aggregate/aggregate_function_python_udaf.h"
41
#include "exprs/aggregate/aggregate_function_rpc.h"
42
#include "exprs/aggregate/aggregate_function_simple_factory.h"
43
#include "exprs/aggregate/aggregate_function_sort.h"
44
#include "exprs/aggregate/aggregate_function_state_merge.h"
45
#include "exprs/aggregate/aggregate_function_state_union.h"
46
#include "exprs/vexpr.h"
47
#include "exprs/vexpr_context.h"
48
49
static constexpr int64_t BE_VERSION_THAT_SUPPORT_NULLABLE_CHECK = 8;
50
51
namespace doris {
52
class RowDescriptor;
53
class Arena;
54
class BufferWritable;
55
class IColumn;
56
} // namespace doris
57
58
namespace doris {
59
60
template <class FunctionType>
61
AggregateFunctionPtr get_agg_state_function(const DataTypes& argument_types,
62
40
                                            DataTypePtr return_type) {
63
40
    return FunctionType::create(
64
40
            assert_cast<const DataTypeAggState*>(argument_types[0].get())->get_nested_function(),
65
40
            argument_types, return_type);
66
40
}
_ZN5doris22get_agg_state_functionINS_19AggregateStateUnionEEESt10shared_ptrINS_18IAggregateFunctionEERKSt6vectorIS2_IKNS_9IDataTypeEESaIS8_EES8_
Line
Count
Source
62
12
                                            DataTypePtr return_type) {
63
12
    return FunctionType::create(
64
12
            assert_cast<const DataTypeAggState*>(argument_types[0].get())->get_nested_function(),
65
12
            argument_types, return_type);
66
12
}
_ZN5doris22get_agg_state_functionINS_19AggregateStateMergeEEESt10shared_ptrINS_18IAggregateFunctionEERKSt6vectorIS2_IKNS_9IDataTypeEESaIS8_EES8_
Line
Count
Source
62
28
                                            DataTypePtr return_type) {
63
28
    return FunctionType::create(
64
28
            assert_cast<const DataTypeAggState*>(argument_types[0].get())->get_nested_function(),
65
28
            argument_types, return_type);
66
28
}
67
68
AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, const bool without_key,
69
                               const bool is_window_function)
70
42.1k
        : _fn(desc.fn),
71
42.1k
          _is_merge(desc.agg_expr.is_merge_agg),
72
42.1k
          _without_key(without_key),
73
42.1k
          _is_window_function(is_window_function),
74
42.1k
          _data_type(DataTypeFactory::instance().create_data_type(
75
18.4E
                  desc.fn.ret_type, desc.__isset.is_nullable ? desc.is_nullable : true)) {
76
42.1k
    if (desc.agg_expr.__isset.param_types) {
77
42.1k
        const auto& param_types = desc.agg_expr.param_types;
78
42.1k
        for (const auto& param_type : param_types) {
79
28.8k
            _argument_types_with_sort.push_back(
80
28.8k
                    DataTypeFactory::instance().create_data_type(param_type));
81
28.8k
        }
82
42.1k
    }
83
42.1k
}
84
85
Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
86
                              const bool without_key, const bool is_window_function,
87
42.1k
                              AggFnEvaluator** result) {
88
42.1k
    *result =
89
42.1k
            pool->add(AggFnEvaluator::create_unique(desc.nodes[0], without_key, is_window_function)
90
42.1k
                              .release());
91
42.1k
    auto& agg_fn_evaluator = *result;
92
42.1k
    int node_idx = 0;
93
78.1k
    for (int i = 0; i < desc.nodes[0].num_children; ++i) {
94
36.0k
        ++node_idx;
95
36.0k
        VExprSPtr expr;
96
36.0k
        VExprContextSPtr ctx;
97
36.0k
        RETURN_IF_ERROR(VExpr::create_tree_from_thrift(desc.nodes, &node_idx, expr, ctx));
98
36.0k
        agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx);
99
36.0k
    }
100
101
42.1k
    auto sort_size = sort_info.ordering_exprs.size();
102
42.1k
    auto real_arguments_size = agg_fn_evaluator->_argument_types_with_sort.size() - sort_size;
103
    // Child arguments contains [real arguments, order by arguments], we pass the arguments
104
    // to the order by functions
105
42.1k
    for (int i = 0; i < sort_size; ++i) {
106
0
        agg_fn_evaluator->_sort_description.emplace_back(real_arguments_size + i,
107
0
                                                         sort_info.is_asc_order[i] ? 1 : -1,
108
0
                                                         sort_info.nulls_first[i] ? -1 : 1);
109
0
    }
110
111
    // Pass the real arguments to get functions
112
71.0k
    for (int i = 0; i < real_arguments_size; ++i) {
113
28.8k
        agg_fn_evaluator->_real_argument_types.emplace_back(
114
28.8k
                agg_fn_evaluator->_argument_types_with_sort[i]);
115
28.8k
    }
116
42.1k
    return Status::OK();
117
42.1k
}
118
119
Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
120
                               const SlotDescriptor* intermediate_slot_desc,
121
42.1k
                               const SlotDescriptor* output_slot_desc) {
122
42.1k
    DCHECK(intermediate_slot_desc != nullptr);
123
42.1k
    DCHECK(_intermediate_slot_desc == nullptr);
124
42.1k
    _output_slot_desc = output_slot_desc;
125
42.1k
    _intermediate_slot_desc = intermediate_slot_desc;
126
127
42.1k
    Status status = VExpr::prepare(_input_exprs_ctxs, state, desc);
128
42.1k
    RETURN_IF_ERROR(status);
129
130
42.1k
    DataTypes tmp_argument_types;
131
42.1k
    tmp_argument_types.reserve(_input_exprs_ctxs.size());
132
133
42.1k
    std::vector<std::string_view> child_expr_name;
134
135
    // prepare for argument
136
42.1k
    for (auto& _input_exprs_ctx : _input_exprs_ctxs) {
137
36.0k
        auto data_type = _input_exprs_ctx->root()->data_type();
138
36.0k
        tmp_argument_types.emplace_back(data_type);
139
36.0k
        child_expr_name.emplace_back(_input_exprs_ctx->root()->expr_name());
140
36.0k
    }
141
142
42.1k
    std::vector<std::string> column_names;
143
42.1k
    for (const auto& expr_ctx : _input_exprs_ctxs) {
144
36.0k
        const auto& root = expr_ctx->root();
145
36.0k
        if (!root->expr_name().empty() && !root->is_constant()) {
146
12.8k
            column_names.emplace_back(root->expr_name());
147
12.8k
        }
148
36.0k
    }
149
150
42.1k
    const DataTypes& argument_types =
151
42.1k
            _real_argument_types.empty() ? tmp_argument_types : _real_argument_types;
152
153
42.1k
    if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
154
0
        if (config::enable_java_support) {
155
0
            _function = AggregateJavaUdaf::create(_fn, argument_types, _data_type);
156
0
            RETURN_IF_ERROR(static_cast<AggregateJavaUdaf*>(_function.get())->check_udaf(_fn));
157
0
        } else {
158
0
            return Status::InternalError(
159
0
                    "Java UDAF is not enabled, you can change be config enable_java_support to "
160
0
                    "true and restart be.");
161
0
        }
162
42.1k
    } else if (_fn.binary_type == TFunctionBinaryType::PYTHON_UDF) {
163
0
        if (config::enable_python_udf_support) {
164
0
            _function = AggregatePythonUDAF::create(_fn, argument_types, _data_type);
165
0
            RETURN_IF_ERROR(static_cast<AggregatePythonUDAF*>(_function.get())->open());
166
0
            LOG(INFO) << fmt::format(
167
0
                    "Created Python UDAF: {}, runtime_version: {}, function_code: {}",
168
0
                    _fn.name.function_name, _fn.runtime_version, _fn.function_code);
169
0
        } else {
170
0
            return Status::InternalError(
171
0
                    "Python UDAF is not enabled, you can change be config "
172
0
                    "enable_python_udf_support to true and restart be.");
173
0
        }
174
42.1k
    } else if (_fn.binary_type == TFunctionBinaryType::RPC) {
175
0
        _function = AggregateRpcUdaf::create(_fn, argument_types, _data_type);
176
42.1k
    } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) {
177
40
        if (argument_types.size() != 1) {
178
0
            return Status::InternalError("Agg state Function must input 1 argument but get {}",
179
0
                                         argument_types.size());
180
0
        }
181
40
        if (argument_types[0]->is_nullable()) {
182
0
            return Status::InternalError("Agg state function input type must be not nullable");
183
0
        }
184
40
        if (argument_types[0]->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) {
185
0
            return Status::InternalError(
186
0
                    "Agg state function input type must be agg_state but get {}",
187
0
                    argument_types[0]->get_family_name());
188
0
        }
189
190
40
        std::string type_function_name =
191
40
                assert_cast<const DataTypeAggState*>(argument_types[0].get())->get_function_name();
192
40
        if (type_function_name + AGG_UNION_SUFFIX == _fn.name.function_name) {
193
12
            if (_data_type->is_nullable()) {
194
0
                return Status::InternalError(
195
0
                        "Union function return type must be not nullable, real={}",
196
0
                        _data_type->get_name());
197
0
            }
198
12
            if (_data_type->get_primitive_type() != PrimitiveType::TYPE_AGG_STATE) {
199
0
                return Status::InternalError(
200
0
                        "Union function return type must be AGG_STATE, real={}",
201
0
                        _data_type->get_name());
202
0
            }
203
12
            _function = get_agg_state_function<AggregateStateUnion>(argument_types, _data_type);
204
28
        } else if (type_function_name + AGG_MERGE_SUFFIX == _fn.name.function_name) {
205
28
            auto type = assert_cast<const DataTypeAggState*>(argument_types[0].get())
206
28
                                ->get_nested_function()
207
28
                                ->get_return_type();
208
28
            if (!type->equals(*_data_type)) {
209
0
                return Status::InternalError("{}'s expect return type is {}, but input {}",
210
0
                                             argument_types[0]->get_name(), type->get_name(),
211
0
                                             _data_type->get_name());
212
0
            }
213
28
            _function = get_agg_state_function<AggregateStateMerge>(argument_types, _data_type);
214
28
        } else {
215
0
            return Status::InternalError("{} not match function {}", argument_types[0]->get_name(),
216
0
                                         _fn.name.function_name);
217
0
        }
218
42.1k
    } else {
219
42.1k
        const bool is_foreach =
220
42.1k
                AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name) ||
221
42.1k
                AggregateFunctionSimpleFactory::is_foreachv2(_fn.name.function_name);
222
        // Here, only foreachv1 needs special treatment, and v2 can follow the normal code logic.
223
42.1k
        if (AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name)) {
224
0
            _function = AggregateFunctionSimpleFactory::instance().get(
225
0
                    _fn.name.function_name, argument_types, _data_type,
226
0
                    AggregateFunctionSimpleFactory::result_nullable_by_foreach(_data_type),
227
0
                    state->be_exec_version(),
228
0
                    {.is_window_function = _is_window_function,
229
0
                     .is_foreach = is_foreach,
230
0
                     .enable_aggregate_function_null_v2 =
231
0
                             state->enable_aggregate_function_null_v2(),
232
0
                     .new_version_percentile =
233
0
                             state->query_options().__isset.new_version_percentile &&
234
0
                             state->query_options().new_version_percentile,
235
0
                     .column_names = std::move(column_names)});
236
42.1k
        } else {
237
42.1k
            _function = AggregateFunctionSimpleFactory::instance().get(
238
42.1k
                    _fn.name.function_name, argument_types, _data_type, _data_type->is_nullable(),
239
42.1k
                    state->be_exec_version(),
240
42.1k
                    {.is_window_function = _is_window_function,
241
42.1k
                     .is_foreach = is_foreach,
242
42.1k
                     .enable_aggregate_function_null_v2 =
243
42.1k
                             state->enable_aggregate_function_null_v2(),
244
42.1k
                     .new_version_percentile =
245
42.1k
                             state->query_options().__isset.new_version_percentile &&
246
42.1k
                             state->query_options().new_version_percentile,
247
42.1k
                     .column_names = std::move(column_names)});
248
42.1k
        }
249
42.1k
    }
250
42.1k
    if (_function == nullptr) {
251
0
        return Status::InternalError("Agg Function {} is not implemented", _fn.signature);
252
0
    }
253
254
42.1k
    if (!_sort_description.empty()) {
255
0
        _function = transform_to_sort_agg_function(_function, _argument_types_with_sort,
256
0
                                                   _sort_description, state);
257
0
    }
258
259
42.1k
    if (_fn.name.function_name == "ai_agg") {
260
0
        _function->set_query_context(state->get_query_ctx());
261
0
    }
262
263
    // Foreachv2, like foreachv1, does not check the return type,
264
    // because its return type is related to the internal agg.
265
42.1k
    if (!AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name) &&
266
42.1k
        !AggregateFunctionSimpleFactory::is_foreachv2(_fn.name.function_name)) {
267
42.1k
        if (state->be_exec_version() >= BE_VERSION_THAT_SUPPORT_NULLABLE_CHECK) {
268
42.1k
            RETURN_IF_ERROR(
269
42.1k
                    _function->verify_result_type(_without_key, argument_types, _data_type));
270
42.1k
        }
271
42.1k
    }
272
42.1k
    _expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name);
273
42.1k
    return Status::OK();
274
42.1k
}
275
276
42.1k
Status AggFnEvaluator::open(RuntimeState* state) {
277
42.1k
    return VExpr::open(_input_exprs_ctxs, state);
278
42.1k
}
279
280
1.60M
void AggFnEvaluator::create(AggregateDataPtr place) {
281
1.60M
    _function->create(place);
282
1.60M
}
283
284
257
void AggFnEvaluator::destroy(AggregateDataPtr place) {
285
257
    _function->destroy(place);
286
257
}
287
288
74.7k
Status AggFnEvaluator::execute_single_add(Block* block, AggregateDataPtr place, Arena& arena) {
289
74.7k
    RETURN_IF_ERROR(_calc_argument_columns(block));
290
74.7k
    _function->add_batch_single_place(block->rows(), place, _agg_columns.data(), arena);
291
74.7k
    return Status::OK();
292
74.7k
}
293
294
Status AggFnEvaluator::execute_batch_add(Block* block, size_t offset, AggregateDataPtr* places,
295
23.0k
                                         Arena& arena, bool agg_many) {
296
23.0k
    RETURN_IF_ERROR(_calc_argument_columns(block));
297
23.0k
    _function->add_batch(block->rows(), places, offset, _agg_columns.data(), arena, agg_many);
298
23.0k
    return Status::OK();
299
23.0k
}
300
301
Status AggFnEvaluator::execute_batch_add_selected(Block* block, size_t offset,
302
2
                                                  AggregateDataPtr* places, Arena& arena) {
303
2
    RETURN_IF_ERROR(_calc_argument_columns(block));
304
2
    _function->add_batch_selected(block->rows(), places, offset, _agg_columns.data(), arena);
305
2
    return Status::OK();
306
2
}
307
308
Status AggFnEvaluator::streaming_agg_serialize_to_column(Block* block, MutableColumnPtr& dst,
309
24
                                                         const size_t num_rows, Arena& arena) {
310
24
    RETURN_IF_ERROR(_calc_argument_columns(block));
311
24
    _function->streaming_agg_serialize_to_column(_agg_columns.data(), dst, num_rows, arena);
312
24
    return Status::OK();
313
24
}
314
315
18.8k
void AggFnEvaluator::insert_result_info(AggregateDataPtr place, IColumn* column) {
316
18.8k
    _function->insert_result_into(place, *column);
317
18.8k
}
318
319
void AggFnEvaluator::insert_result_info_vec(const std::vector<AggregateDataPtr>& places,
320
6.77k
                                            size_t offset, IColumn* column, const size_t num_rows) {
321
6.77k
    _function->insert_result_into_vec(places, offset, *column, num_rows);
322
6.77k
}
323
324
151
void AggFnEvaluator::reset(AggregateDataPtr place) {
325
151
    _function->reset(place);
326
151
}
327
328
0
std::string AggFnEvaluator::debug_string(const std::vector<AggFnEvaluator*>& exprs) {
329
0
    std::stringstream out;
330
0
    out << "[";
331
332
0
    for (int i = 0; i < exprs.size(); ++i) {
333
0
        out << (i == 0 ? "" : " ") << exprs[i]->debug_string();
334
0
    }
335
336
0
    out << "]";
337
0
    return out.str();
338
0
}
339
340
0
std::string AggFnEvaluator::debug_string() const {
341
0
    std::stringstream out;
342
0
    out << "AggFnEvaluator(";
343
0
    out << _fn.signature;
344
0
    out << ")";
345
0
    return out.str();
346
0
}
347
348
97.8k
Status AggFnEvaluator::_calc_argument_columns(Block* block) {
349
97.8k
    SCOPED_TIMER(_expr_timer);
350
97.8k
    _agg_columns.resize(_input_exprs_ctxs.size());
351
97.8k
    std::vector<int> column_ids(_input_exprs_ctxs.size());
352
171k
    for (int i = 0; i < _input_exprs_ctxs.size(); ++i) {
353
74.0k
        int column_id = -1;
354
74.0k
        RETURN_IF_ERROR(_input_exprs_ctxs[i]->execute(block, &column_id));
355
74.0k
        column_ids[i] = column_id;
356
74.0k
    }
357
97.8k
    materialize_block_inplace(*block, column_ids.data(),
358
97.8k
                              column_ids.data() + _input_exprs_ctxs.size());
359
171k
    for (int i = 0; i < _input_exprs_ctxs.size(); ++i) {
360
74.0k
        _agg_columns[i] = block->get_by_position(column_ids[i]).column.get();
361
74.0k
    }
362
97.8k
    return Status::OK();
363
97.8k
}
364
365
73.4k
AggFnEvaluator* AggFnEvaluator::clone(RuntimeState* state, ObjectPool* pool) {
366
73.4k
    return pool->add(AggFnEvaluator::create_unique(*this, state).release());
367
73.4k
}
368
369
AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state)
370
73.4k
        : _fn(evaluator._fn),
371
73.4k
          _is_merge(evaluator._is_merge),
372
73.4k
          _without_key(evaluator._without_key),
373
73.4k
          _is_window_function(evaluator._is_window_function),
374
73.4k
          _argument_types_with_sort(evaluator._argument_types_with_sort),
375
73.4k
          _real_argument_types(evaluator._real_argument_types),
376
73.4k
          _intermediate_slot_desc(evaluator._intermediate_slot_desc),
377
73.4k
          _output_slot_desc(evaluator._output_slot_desc),
378
73.4k
          _sort_description(evaluator._sort_description),
379
73.4k
          _data_type(evaluator._data_type),
380
73.4k
          _function(evaluator._function),
381
73.4k
          _expr_name(evaluator._expr_name),
382
73.4k
          _agg_columns(evaluator._agg_columns) {
383
73.4k
    if (evaluator._fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
384
0
        DataTypes tmp_argument_types;
385
0
        tmp_argument_types.reserve(evaluator._input_exprs_ctxs.size());
386
        // prepare for argument
387
0
        for (auto& _input_exprs_ctx : evaluator._input_exprs_ctxs) {
388
0
            auto data_type = _input_exprs_ctx->root()->data_type();
389
0
            tmp_argument_types.emplace_back(data_type);
390
0
        }
391
0
        const DataTypes& argument_types =
392
0
                _real_argument_types.empty() ? tmp_argument_types : _real_argument_types;
393
0
        _function = AggregateJavaUdaf::create(evaluator._fn, argument_types, evaluator._data_type);
394
0
        THROW_IF_ERROR(static_cast<AggregateJavaUdaf*>(_function.get())->check_udaf(evaluator._fn));
395
0
    }
396
73.4k
    DCHECK(_function != nullptr);
397
398
73.4k
    _input_exprs_ctxs.resize(evaluator._input_exprs_ctxs.size());
399
123k
    for (size_t i = 0; i < _input_exprs_ctxs.size(); i++) {
400
49.7k
        WARN_IF_ERROR(evaluator._input_exprs_ctxs[i]->clone(state, _input_exprs_ctxs[i]), "");
401
49.7k
    }
402
73.4k
}
403
404
Status AggFnEvaluator::check_agg_fn_output(uint32_t key_size,
405
                                           const std::vector<AggFnEvaluator*>& agg_fn,
406
11.0k
                                           const RowDescriptor& output_row_desc) {
407
11.0k
    auto name_and_types = VectorizedUtils::create_name_and_data_types(output_row_desc);
408
32.9k
    for (uint32_t i = key_size, j = 0; i < name_and_types.size(); i++, j++) {
409
21.8k
        auto&& [name, column_type] = name_and_types[i];
410
21.8k
        auto agg_return_type = agg_fn[j]->function()->get_return_type();
411
21.8k
        if (!column_type->equals(*agg_return_type)) {
412
2.10k
            if (!column_type->is_nullable() || agg_return_type->is_nullable() ||
413
2.10k
                !remove_nullable(column_type)->equals(*agg_return_type)) {
414
0
                return Status::InternalError(
415
0
                        "column_type not match data_types in agg node, column_type={}, "
416
0
                        "data_types={},column name={}",
417
0
                        column_type->get_name(), agg_return_type->get_name(), name);
418
0
            }
419
2.10k
        }
420
21.8k
    }
421
11.0k
    return Status::OK();
422
11.0k
}
423
424
297k
bool AggFnEvaluator::is_blockable() const {
425
297k
    return _function->is_blockable() ||
426
297k
           std::any_of(_input_exprs_ctxs.begin(), _input_exprs_ctxs.end(),
427
297k
                       [](VExprContextSPtr ctx) { return ctx->root()->is_blockable(); });
428
297k
}
429
430
} // namespace doris