Coverage Report

Created: 2026-06-23 16:02

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/function.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/IFunction.cpp
19
// and modified by Doris
20
21
#include "exprs/function/function.h"
22
23
#include <algorithm>
24
#include <memory>
25
#include <numeric>
26
27
#include "common/status.h"
28
#include "core/assert_cast.h"
29
#include "core/column/column.h"
30
#include "core/column/column_const.h"
31
#include "core/column/column_nullable.h"
32
#include "core/column/column_vector.h"
33
#include "core/data_type/data_type_array.h"
34
#include "core/data_type/data_type_nothing.h"
35
#include "core/data_type/data_type_nullable.h"
36
#include "core/data_type/define_primitive_type.h"
37
#include "core/data_type/primitive_type.h"
38
#include "core/field.h"
39
#include "exec/common/util.hpp"
40
#include "exprs/aggregate/aggregate_function.h"
41
#include "exprs/function/function_helpers.h"
42
#include "storage/index/zone_map/zonemap_eval_context.h"
43
44
namespace doris {
45
ColumnPtr wrap_in_nullable(const ColumnPtr& src, const Block& block, const ColumnNumbers& args,
46
2.02M
                           size_t input_rows_count) {
47
2.02M
    ColumnPtr result_null_map_column;
48
    /// If result is already nullable.
49
2.02M
    ColumnPtr src_not_nullable = src;
50
2.02M
    MutableColumnPtr mutable_result_null_map_column;
51
52
2.02M
    if (auto nullable = check_and_get_column_ptr<ColumnNullable>(src)) {
53
1.03M
        src_not_nullable = nullable->get_nested_column_ptr();
54
1.03M
        result_null_map_column = nullable->get_null_map_column_ptr();
55
1.03M
    }
56
57
2.61M
    for (const auto& arg : args) {
58
2.61M
        const ColumnWithTypeAndName& elem = block.get_by_position(arg);
59
2.61M
        if (!elem.type->is_nullable() || is_column_const(*elem.column)) {
60
592k
            continue;
61
592k
        }
62
63
2.02M
        if (auto nullable = cast_to_column<ColumnNullable>(elem.column); nullable->has_null()) {
64
179k
            const ColumnPtr& null_map_column = nullable->get_null_map_column_ptr();
65
179k
            if (!result_null_map_column) { // NOLINT(bugprone-use-after-move)
66
152k
                result_null_map_column = null_map_column;
67
152k
                continue;
68
152k
            }
69
70
27.3k
            if (!mutable_result_null_map_column) {
71
26.0k
                mutable_result_null_map_column = (*std::move(result_null_map_column)).mutate();
72
26.0k
            }
73
74
27.3k
            NullMap& result_null_map =
75
27.3k
                    assert_cast<ColumnUInt8&>(*mutable_result_null_map_column).get_data();
76
27.3k
            const NullMap& src_null_map =
77
27.3k
                    assert_cast<const ColumnUInt8&>(*null_map_column).get_data();
78
79
27.3k
            VectorizedUtils::update_null_map(result_null_map, src_null_map);
80
27.3k
        }
81
2.02M
    }
82
83
    // Commit merged null map back: result_null_map_column was moved into
84
    // mutable_result_null_map_column when merging 2+ nullable args with nulls.
85
2.02M
    if (mutable_result_null_map_column) {
86
26.0k
        result_null_map_column = std::move(mutable_result_null_map_column);
87
26.0k
    }
88
89
2.02M
    if (!result_null_map_column) {
90
838k
        if (is_column_const(*src)) {
91
75
            return ColumnConst::create(
92
75
                    make_nullable(assert_cast<const ColumnConst&>(*src).get_data_column_ptr(),
93
75
                                  false),
94
75
                    input_rows_count);
95
75
        }
96
838k
        return ColumnNullable::create(src, ColumnUInt8::create(input_rows_count, 0));
97
838k
    }
98
99
1.18M
    return ColumnNullable::create(src_not_nullable, result_null_map_column);
100
2.02M
}
101
102
1.22M
bool have_null_column(const Block& block, const ColumnNumbers& args) {
103
1.92M
    return std::ranges::any_of(args, [&block](const auto& elem) {
104
1.92M
        return block.get_by_position(elem).type->is_nullable();
105
1.92M
    });
106
1.22M
}
107
108
659k
bool have_null_column(const ColumnsWithTypeAndName& args) {
109
1.20M
    return std::ranges::any_of(args, [](const auto& elem) { return elem.type->is_nullable(); });
110
659k
}
111
112
inline Status PreparedFunctionImpl::_execute_skipped_constant_deal(FunctionContext* context,
113
                                                                   Block& block,
114
                                                                   const ColumnNumbers& args,
115
                                                                   uint32_t result,
116
2.91M
                                                                   size_t input_rows_count) const {
117
2.91M
    bool executed = false;
118
2.91M
    RETURN_IF_ERROR(default_implementation_for_nulls(context, block, args, result, input_rows_count,
119
2.91M
                                                     &executed));
120
2.91M
    if (executed) {
121
561k
        return Status::OK();
122
561k
    }
123
2.35M
    return execute_impl(context, block, args, result, input_rows_count);
124
2.91M
}
125
126
Status PreparedFunctionImpl::default_implementation_for_constant_arguments(
127
        FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result,
128
2.91M
        size_t input_rows_count, bool* executed) const {
129
2.91M
    *executed = false;
130
2.91M
    ColumnNumbers args_expect_const = get_arguments_that_are_always_constant();
131
132
    // Check that these arguments are really constant.
133
2.91M
    for (auto arg_num : args_expect_const) {
134
1.37M
        if (arg_num < args.size() &&
135
1.37M
            !is_column_const(*block.get_by_position(args[arg_num]).column)) {
136
7
            return Status::InvalidArgument("Argument at index {} for function {} must be constant",
137
7
                                           arg_num, get_name());
138
7
        }
139
1.37M
    }
140
141
2.91M
    if (args.empty() || !use_default_implementation_for_constants() ||
142
2.91M
        !VectorizedUtils::all_arguments_are_constant(block, args)) {
143
2.82M
        return Status::OK();
144
2.82M
    }
145
146
    // now all columns are const.
147
94.2k
    Block temporary_block;
148
149
94.2k
    int arguments_size = (int)args.size();
150
225k
    for (size_t arg_num = 0; arg_num < arguments_size; ++arg_num) {
151
131k
        const ColumnWithTypeAndName& column = block.get_by_position(args[arg_num]);
152
        // Columns in const_list --> column_const,    others --> nested_column
153
        // that's because some functions assume certain specific columns are always constant.
154
        // If we unpack it, there will be unnecessary cost of virtual judge.
155
131k
        if (args_expect_const.end() !=
156
131k
            std::find(args_expect_const.begin(), args_expect_const.end(), arg_num)) {
157
327
            temporary_block.insert({column.column, column.type, column.name});
158
131k
        } else {
159
131k
            temporary_block.insert(
160
131k
                    {assert_cast<const ColumnConst*>(column.column.get())->get_data_column_ptr(),
161
131k
                     column.type, column.name});
162
131k
        }
163
131k
    }
164
165
94.2k
    temporary_block.insert(block.get_by_position(result));
166
167
94.2k
    ColumnNumbers temporary_argument_numbers(arguments_size);
168
225k
    for (int i = 0; i < arguments_size; ++i) {
169
131k
        temporary_argument_numbers[i] = i;
170
131k
    }
171
172
94.2k
    RETURN_IF_ERROR(_execute_skipped_constant_deal(context, temporary_block,
173
94.2k
                                                   temporary_argument_numbers, arguments_size,
174
94.2k
                                                   temporary_block.rows()));
175
176
93.1k
    ColumnPtr result_column;
177
    /// extremely rare case, when we have function with completely const arguments
178
    /// but some of them are produced by a function that is not is_deterministic
179
93.1k
    if (temporary_block.get_by_position(arguments_size).column->size() > 1) {
180
0
        result_column = temporary_block.get_by_position(arguments_size).column->clone_resized(1);
181
93.1k
    } else {
182
93.1k
        result_column = temporary_block.get_by_position(arguments_size).column;
183
93.1k
    }
184
    // We should handle the case where the result column is also a ColumnConst.
185
93.1k
    block.get_by_position(result).column = ColumnConst::create(result_column, input_rows_count);
186
93.1k
    *executed = true;
187
93.1k
    return Status::OK();
188
94.2k
}
189
190
Status PreparedFunctionImpl::default_implementation_for_nulls(
191
        FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result,
192
2.91M
        size_t input_rows_count, bool* executed) const {
193
2.91M
    *executed = false;
194
2.91M
    if (args.empty() || !use_default_implementation_for_nulls()) {
195
1.68M
        return Status::OK();
196
1.68M
    }
197
198
2.49M
    if (std::ranges::any_of(args, [&block](const auto& elem) {
199
2.49M
            return block.get_by_position(elem).column->only_null();
200
2.49M
        })) {
201
16.4k
        block.get_by_position(result).column =
202
16.4k
                block.get_by_position(result).type->create_column_const(input_rows_count, Field());
203
16.4k
        *executed = true;
204
16.4k
        return Status::OK();
205
16.4k
    }
206
207
1.22M
    if (have_null_column(block, args)) {
208
545k
        bool need_to_default = need_replace_null_data_to_default();
209
        // extract nested column from nulls
210
545k
        ColumnNumbers new_args;
211
545k
        Block new_block;
212
213
1.64M
        for (int i = 0; i < args.size(); ++i) {
214
1.10M
            uint32_t arg = args[i];
215
1.10M
            new_args.push_back(i);
216
1.10M
            new_block.insert(block.get_by_position(arg).unnest_nullable(need_to_default));
217
1.10M
        }
218
545k
        new_block.insert(block.get_by_position(result));
219
545k
        int new_result = new_block.columns() - 1;
220
221
545k
        RETURN_IF_ERROR(default_execute(context, new_block, new_args, new_result, block.rows()));
222
        // After run with nested, wrap them in null. Before this, block.get_by_position(result).type
223
        // is not compatible with get_by_position(result).column
224
225
545k
        block.get_by_position(result).column = wrap_in_nullable(
226
545k
                new_block.get_by_position(new_result).column, block, args, input_rows_count);
227
228
545k
        *executed = true;
229
545k
        return Status::OK();
230
545k
    }
231
676k
    return Status::OK();
232
1.22M
}
233
234
Status PreparedFunctionImpl::default_execute(FunctionContext* context, Block& block,
235
                                             const ColumnNumbers& args, uint32_t result,
236
2.91M
                                             size_t input_rows_count) const {
237
2.91M
    bool executed = false;
238
239
2.91M
    RETURN_IF_ERROR(default_implementation_for_constant_arguments(context, block, args, result,
240
2.91M
                                                                  input_rows_count, &executed));
241
2.91M
    if (executed) {
242
93.2k
        return Status::OK();
243
93.2k
    }
244
245
2.82M
    return _execute_skipped_constant_deal(context, block, args, result, input_rows_count);
246
2.91M
}
247
248
Status PreparedFunctionImpl::execute(FunctionContext* context, Block& block,
249
                                     const ColumnNumbers& args, uint32_t result,
250
2.37M
                                     size_t input_rows_count) const {
251
2.37M
    return default_execute(context, block, args, result, input_rows_count);
252
2.37M
}
253
254
756k
void FunctionBuilderImpl::check_number_of_arguments(size_t number_of_arguments) const {
255
756k
    if (is_variadic()) {
256
81.4k
        return;
257
81.4k
    }
258
259
674k
    size_t expected_number_of_arguments = get_number_of_arguments();
260
261
674k
    DCHECK_EQ(number_of_arguments, expected_number_of_arguments) << fmt::format(
262
0
            "Number of arguments for function {} doesn't match: passed {} , should be {}",
263
0
            get_name(), number_of_arguments, expected_number_of_arguments);
264
674k
    if (number_of_arguments != expected_number_of_arguments) {
265
0
        throw Exception(
266
0
                ErrorCode::INVALID_ARGUMENT,
267
0
                "Number of arguments for function {} doesn't match: passed {} , should be {}",
268
0
                get_name(), number_of_arguments, expected_number_of_arguments);
269
0
    }
270
674k
}
271
272
760k
DataTypePtr FunctionBuilderImpl::get_return_type(const ColumnsWithTypeAndName& arguments) const {
273
760k
    check_number_of_arguments(arguments.size());
274
275
760k
    if (!arguments.empty() && use_default_implementation_for_nulls()) {
276
658k
        if (have_null_column(arguments)) {
277
122k
            ColumnNumbers numbers(arguments.size());
278
122k
            std::iota(numbers.begin(), numbers.end(), 0);
279
122k
            auto [nested_block, _] =
280
122k
                    create_block_with_nested_columns(Block(arguments), numbers, false);
281
122k
            auto return_type = get_return_type_impl(
282
122k
                    ColumnsWithTypeAndName(nested_block.begin(), nested_block.end()));
283
122k
            if (!return_type) {
284
0
                return nullptr;
285
0
            }
286
122k
            return make_nullable(return_type);
287
122k
        }
288
658k
    }
289
290
637k
    return get_return_type_impl(arguments);
291
760k
}
292
293
bool FunctionBuilderImpl::is_date_or_datetime_or_decimal(
294
2.79k
        const DataTypePtr& return_type, const DataTypePtr& func_return_type) const {
295
2.79k
    return (is_date_or_datetime(return_type->get_primitive_type()) &&
296
2.79k
            is_date_or_datetime(func_return_type->get_primitive_type())) ||
297
2.79k
           (is_date_v2_or_datetime_v2(return_type->get_primitive_type()) &&
298
2.79k
            is_date_v2_or_datetime_v2(func_return_type->get_primitive_type())) ||
299
           // For some date functions such as str_to_date(string, string), return_type will
300
           // be datetimev2 if users enable datev2 but get_return_type(arguments) will still
301
           // return datetime. We need keep backward compatibility here.
302
2.79k
           (is_date_v2_or_datetime_v2(return_type->get_primitive_type()) &&
303
1.91k
            is_date_or_datetime(func_return_type->get_primitive_type())) ||
304
2.79k
           (is_date_or_datetime(return_type->get_primitive_type()) &&
305
1.88k
            is_date_v2_or_datetime_v2(func_return_type->get_primitive_type())) ||
306
2.79k
           (is_decimal(return_type->get_primitive_type()) &&
307
1.88k
            is_decimal(func_return_type->get_primitive_type())) ||
308
2.79k
           (is_time_type(return_type->get_primitive_type()) &&
309
522
            is_time_type(func_return_type->get_primitive_type()));
310
2.79k
}
311
312
593
bool contains_date_or_datetime_or_decimal(const DataTypePtr& type) {
313
593
    auto type_ptr = type->is_nullable() ? ((DataTypeNullable*)type.get())->get_nested_type() : type;
314
315
593
    switch (type_ptr->get_primitive_type()) {
316
11
    case TYPE_ARRAY: {
317
11
        const auto* array_type = assert_cast<const DataTypeArray*>(type_ptr.get());
318
11
        return contains_date_or_datetime_or_decimal(array_type->get_nested_type());
319
0
    }
320
0
    case TYPE_MAP: {
321
0
        const auto* map_type = assert_cast<const DataTypeMap*>(type_ptr.get());
322
0
        return contains_date_or_datetime_or_decimal(map_type->get_key_type()) ||
323
0
               contains_date_or_datetime_or_decimal(map_type->get_value_type());
324
0
    }
325
33
    case TYPE_STRUCT: {
326
33
        const auto* struct_type = assert_cast<const DataTypeStruct*>(type_ptr.get());
327
33
        const auto& elements = struct_type->get_elements();
328
71
        return std::ranges::any_of(elements, [](const DataTypePtr& element) {
329
71
            return contains_date_or_datetime_or_decimal(element);
330
71
        });
331
0
    }
332
549
    default:
333
        // For scalar types, check if it's date/datetime/decimal
334
549
        return is_date_or_datetime(type_ptr->get_primitive_type()) ||
335
549
               is_date_v2_or_datetime_v2(type_ptr->get_primitive_type()) ||
336
549
               is_decimal(type_ptr->get_primitive_type()) ||
337
549
               is_time_type(type_ptr->get_primitive_type());
338
593
    }
339
593
}
340
341
// make sure array/map/struct and nested array/map/struct can be checked
342
bool FunctionBuilderImpl::is_nested_type_date_or_datetime_or_decimal(
343
513
        const DataTypePtr& return_type, const DataTypePtr& func_return_type) const {
344
513
    auto return_type_ptr = return_type->is_nullable()
345
513
                                   ? ((DataTypeNullable*)return_type.get())->get_nested_type()
346
513
                                   : return_type;
347
513
    auto func_return_type_ptr =
348
513
            func_return_type->is_nullable()
349
513
                    ? ((DataTypeNullable*)func_return_type.get())->get_nested_type()
350
513
                    : func_return_type;
351
    // make sure that map/struct/array types are also checked
352
513
    if (return_type_ptr->get_primitive_type() != func_return_type_ptr->get_primitive_type()) {
353
2
        return false;
354
2
    }
355
356
    // Check if this type contains date/datetime/decimal types
357
511
    if (!contains_date_or_datetime_or_decimal(return_type_ptr)) {
358
        // If no date/datetime/decimal types, just pass through
359
487
        return true;
360
487
    }
361
362
    // If contains date/datetime/decimal types, recursively check each element
363
24
    switch (return_type_ptr->get_primitive_type()) {
364
11
    case TYPE_ARRAY: {
365
11
        auto nested_return_type = remove_nullable(
366
11
                (assert_cast<const DataTypeArray*>(return_type_ptr.get()))->get_nested_type());
367
11
        auto nested_func_type = remove_nullable(
368
11
                (assert_cast<const DataTypeArray*>(func_return_type_ptr.get()))->get_nested_type());
369
11
        return is_nested_type_date_or_datetime_or_decimal(nested_return_type, nested_func_type);
370
0
    }
371
0
    case TYPE_MAP: {
372
0
        const auto* return_map = assert_cast<const DataTypeMap*>(return_type_ptr.get());
373
0
        const auto* func_map = assert_cast<const DataTypeMap*>(func_return_type_ptr.get());
374
375
0
        auto key_return = remove_nullable(return_map->get_key_type());
376
0
        auto key_func = remove_nullable(func_map->get_key_type());
377
0
        auto value_return = remove_nullable(return_map->get_value_type());
378
0
        auto value_func = remove_nullable(func_map->get_value_type());
379
380
0
        return is_nested_type_date_or_datetime_or_decimal(key_return, key_func) &&
381
0
               is_nested_type_date_or_datetime_or_decimal(value_return, value_func);
382
0
    }
383
1
    case TYPE_STRUCT: {
384
1
        const auto* return_struct = assert_cast<const DataTypeStruct*>(return_type_ptr.get());
385
1
        const auto* func_struct = assert_cast<const DataTypeStruct*>(func_return_type_ptr.get());
386
387
1
        auto return_elements = return_struct->get_elements();
388
1
        auto func_elements = func_struct->get_elements();
389
390
1
        if (return_elements.size() != func_elements.size()) {
391
0
            return false;
392
0
        }
393
394
5
        for (size_t i = 0; i < return_elements.size(); i++) {
395
4
            auto elem_return = remove_nullable(return_elements[i]);
396
4
            auto elem_func = remove_nullable(func_elements[i]);
397
398
4
            if (!is_nested_type_date_or_datetime_or_decimal(elem_return, elem_func)) {
399
0
                return false;
400
0
            }
401
4
        }
402
1
        return true;
403
1
    }
404
12
    default:
405
12
        return is_date_or_datetime_or_decimal(return_type_ptr, func_return_type_ptr);
406
24
    }
407
24
}
408
409
ZoneMapFilterResult IFunctionBase::evaluate_zonemap_filter(
410
0
        const ZoneMapEvalContext& ctx, const VExprSPtrs& function_arguments) const {
411
0
    return unsupported_zonemap_filter(ctx);
412
0
}
413
414
} // namespace doris