Coverage Report

Created: 2025-10-16 20:27

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/be/src/vec/functions/function.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/IFunction.cpp
19
// and modified by Doris
20
21
#include "vec/functions/function.h"
22
23
#include <algorithm>
24
#include <memory>
25
#include <numeric>
26
27
#include "common/status.h"
28
#include "runtime/define_primitive_type.h"
29
#include "runtime/primitive_type.h"
30
#include "vec/aggregate_functions/aggregate_function.h"
31
#include "vec/columns/column.h"
32
#include "vec/columns/column_const.h"
33
#include "vec/columns/column_nullable.h"
34
#include "vec/columns/column_vector.h"
35
#include "vec/common/assert_cast.h"
36
#include "vec/core/field.h"
37
#include "vec/data_types/data_type_array.h"
38
#include "vec/data_types/data_type_nothing.h"
39
#include "vec/data_types/data_type_nullable.h"
40
#include "vec/functions/function_helpers.h"
41
#include "vec/utils/util.hpp"
42
43
namespace doris::vectorized {
44
#include "common/compile_check_begin.h"
45
ColumnPtr wrap_in_nullable(const ColumnPtr& src, const Block& block, const ColumnNumbers& args,
46
69.0k
                           uint32_t result, size_t input_rows_count) {
47
69.0k
    ColumnPtr result_null_map_column;
48
    /// If result is already nullable.
49
69.0k
    ColumnPtr src_not_nullable = src;
50
69.0k
    MutableColumnPtr mutable_result_null_map_column;
51
52
69.0k
    if (const auto* nullable = check_and_get_column<ColumnNullable>(*src)) {
53
8.43k
        src_not_nullable = nullable->get_nested_column_ptr();
54
8.43k
        result_null_map_column = nullable->get_null_map_column_ptr();
55
8.43k
    }
56
57
93.4k
    for (const auto& arg : args) {
58
93.4k
        const ColumnWithTypeAndName& elem = block.get_by_position(arg);
59
93.4k
        if (!elem.type->is_nullable() || is_column_const(*elem.column)) {
60
19.7k
            continue;
61
19.7k
        }
62
63
73.6k
        if (const auto* nullable = assert_cast<const ColumnNullable*>(elem.column.get());
64
73.6k
            nullable->has_null()) {
65
1.03k
            const ColumnPtr& null_map_column = nullable->get_null_map_column_ptr();
66
1.03k
            if (!result_null_map_column) { // NOLINT(bugprone-use-after-move)
67
945
                result_null_map_column = null_map_column->clone_resized(input_rows_count);
68
945
                continue;
69
945
            }
70
71
94
            if (!mutable_result_null_map_column) {
72
70
                mutable_result_null_map_column =
73
70
                        std::move(result_null_map_column)->assume_mutable();
74
70
            }
75
76
94
            NullMap& result_null_map =
77
94
                    assert_cast<ColumnUInt8&>(*mutable_result_null_map_column).get_data();
78
94
            const NullMap& src_null_map =
79
94
                    assert_cast<const ColumnUInt8&>(*null_map_column).get_data();
80
81
94
            VectorizedUtils::update_null_map(result_null_map, src_null_map);
82
94
        }
83
73.6k
    }
84
85
69.0k
    if (!result_null_map_column) {
86
59.6k
        if (is_column_const(*src)) {
87
0
            return ColumnConst::create(
88
0
                    make_nullable(assert_cast<const ColumnConst&>(*src).get_data_column_ptr(),
89
0
                                  false),
90
0
                    input_rows_count);
91
0
        }
92
59.6k
        return ColumnNullable::create(src, ColumnUInt8::create(input_rows_count, 0));
93
59.6k
    }
94
95
9.38k
    return ColumnNullable::create(src_not_nullable, result_null_map_column);
96
69.0k
}
97
98
16.0k
bool have_null_column(const Block& block, const ColumnNumbers& args) {
99
29.0k
    return std::ranges::any_of(args, [&block](const auto& elem) {
100
29.0k
        return block.get_by_position(elem).type->is_nullable();
101
29.0k
    });
102
16.0k
}
103
104
11.5k
bool have_null_column(const ColumnsWithTypeAndName& args) {
105
11.5k
    return std::ranges::any_of(args, [](const auto& elem) { return elem.type->is_nullable(); });
106
11.5k
}
107
108
inline Status PreparedFunctionImpl::_execute_skipped_constant_deal(
109
        FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result,
110
105k
        size_t input_rows_count, bool dry_run) const {
111
105k
    bool executed = false;
112
105k
    RETURN_IF_ERROR(default_implementation_for_nulls(context, block, args, result, input_rows_count,
113
105k
                                                     dry_run, &executed));
114
105k
    if (executed) {
115
11.3k
        return Status::OK();
116
11.3k
    }
117
118
94.6k
    if (dry_run) {
119
0
        return execute_impl_dry_run(context, block, args, result, input_rows_count);
120
94.6k
    } else {
121
94.6k
        return execute_impl(context, block, args, result, input_rows_count);
122
94.6k
    }
123
94.6k
}
124
125
Status PreparedFunctionImpl::default_implementation_for_constant_arguments(
126
        FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result,
127
105k
        size_t input_rows_count, bool dry_run, bool* executed) const {
128
105k
    *executed = false;
129
105k
    ColumnNumbers args_expect_const = get_arguments_that_are_always_constant();
130
131
    // Check that these arguments are really constant.
132
105k
    for (auto arg_num : args_expect_const) {
133
85.5k
        if (arg_num < args.size() &&
134
85.5k
            !is_column_const(*block.get_by_position(args[arg_num]).column)) {
135
0
            return Status::InvalidArgument("Argument at index {} for function {} must be constant",
136
0
                                           arg_num, get_name());
137
0
        }
138
85.5k
    }
139
140
105k
    if (args.empty() || !use_default_implementation_for_constants() ||
141
105k
        !VectorizedUtils::all_arguments_are_constant(block, args)) {
142
103k
        return Status::OK();
143
103k
    }
144
145
    // now all columns are const.
146
2.44k
    Block temporary_block;
147
148
2.44k
    int arguments_size = (int)args.size();
149
8.16k
    for (size_t arg_num = 0; arg_num < arguments_size; ++arg_num) {
150
5.71k
        const ColumnWithTypeAndName& column = block.get_by_position(args[arg_num]);
151
        // Columns in const_list --> column_const,    others --> nested_column
152
        // that's because some functions supposes some specific columns always constant.
153
        // If we unpack it, there will be unnecessary cost of virtual judge.
154
5.71k
        if (args_expect_const.end() !=
155
5.71k
            std::find(args_expect_const.begin(), args_expect_const.end(), arg_num)) {
156
0
            temporary_block.insert({column.column, column.type, column.name});
157
5.71k
        } else {
158
5.71k
            temporary_block.insert(
159
5.71k
                    {assert_cast<const ColumnConst*>(column.column.get())->get_data_column_ptr(),
160
5.71k
                     column.type, column.name});
161
5.71k
        }
162
5.71k
    }
163
164
2.44k
    temporary_block.insert(block.get_by_position(result));
165
166
2.44k
    ColumnNumbers temporary_argument_numbers(arguments_size);
167
8.16k
    for (int i = 0; i < arguments_size; ++i) {
168
5.71k
        temporary_argument_numbers[i] = i;
169
5.71k
    }
170
171
2.44k
    RETURN_IF_ERROR(_execute_skipped_constant_deal(context, temporary_block,
172
2.44k
                                                   temporary_argument_numbers, arguments_size,
173
2.44k
                                                   temporary_block.rows(), dry_run));
174
175
2.44k
    ColumnPtr result_column;
176
    /// extremely rare case, when we have function with completely const arguments
177
    /// but some of them produced by non is_deterministic function
178
2.44k
    if (temporary_block.get_by_position(arguments_size).column->size() > 1) {
179
0
        result_column = temporary_block.get_by_position(arguments_size).column->clone_resized(1);
180
2.44k
    } else {
181
2.44k
        result_column = temporary_block.get_by_position(arguments_size).column;
182
2.44k
    }
183
    // We shuold handle the case where the result column is also a ColumnConst.
184
2.44k
    block.get_by_position(result).column = ColumnConst::create(result_column, input_rows_count);
185
2.44k
    *executed = true;
186
2.44k
    return Status::OK();
187
2.44k
}
188
189
Status PreparedFunctionImpl::default_implementation_for_nulls(
190
        FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result,
191
105k
        size_t input_rows_count, bool dry_run, bool* executed) const {
192
105k
    *executed = false;
193
105k
    if (args.empty() || !use_default_implementation_for_nulls()) {
194
86.5k
        return Status::OK();
195
86.5k
    }
196
197
48.1k
    if (std::ranges::any_of(args, [&block](const auto& elem) {
198
48.1k
            return block.get_by_position(elem).column->only_null();
199
48.1k
        })) {
200
3.40k
        block.get_by_position(result).column =
201
3.40k
                block.get_by_position(result).type->create_column_const(input_rows_count, Field());
202
3.40k
        *executed = true;
203
3.40k
        return Status::OK();
204
3.40k
    }
205
206
16.0k
    if (have_null_column(block, args)) {
207
7.92k
        bool need_to_default = need_replace_null_data_to_default();
208
        // extract nested column from nulls
209
7.92k
        ColumnNumbers new_args;
210
20.9k
        for (auto arg : args) {
211
20.9k
            new_args.push_back(block.columns());
212
20.9k
            block.insert(block.get_by_position(arg).unnest_nullable(need_to_default));
213
20.9k
            DCHECK(!block.get_by_position(new_args.back()).column->is_nullable());
214
20.9k
        }
215
7.92k
        RETURN_IF_ERROR(execute_without_low_cardinality_columns(context, block, new_args, result,
216
7.92k
                                                                block.rows(), dry_run));
217
        // After run with nested, wrap them in null. Before this, block.get_by_position(result).type
218
        // is not compatible with get_by_position(result).column
219
7.92k
        block.get_by_position(result).column = wrap_in_nullable(
220
7.92k
                block.get_by_position(result).column, block, args, result, input_rows_count);
221
222
28.8k
        while (!new_args.empty()) {
223
20.9k
            block.erase(new_args.back());
224
20.9k
            new_args.pop_back();
225
20.9k
        }
226
7.92k
        *executed = true;
227
7.92k
        return Status::OK();
228
7.92k
    }
229
8.08k
    return Status::OK();
230
16.0k
}
231
232
Status PreparedFunctionImpl::execute_without_low_cardinality_columns(
233
        FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result,
234
105k
        size_t input_rows_count, bool dry_run) const {
235
105k
    bool executed = false;
236
237
105k
    RETURN_IF_ERROR(default_implementation_for_constant_arguments(
238
105k
            context, block, args, result, input_rows_count, dry_run, &executed));
239
105k
    if (executed) {
240
2.44k
        return Status::OK();
241
2.44k
    }
242
243
103k
    return _execute_skipped_constant_deal(context, block, args, result, input_rows_count, dry_run);
244
105k
}
245
246
Status PreparedFunctionImpl::execute(FunctionContext* context, Block& block,
247
                                     const ColumnNumbers& args, uint32_t result,
248
98.0k
                                     size_t input_rows_count, bool dry_run) const {
249
98.0k
    return execute_without_low_cardinality_columns(context, block, args, result, input_rows_count,
250
98.0k
                                                   dry_run);
251
98.0k
}
252
253
12.6k
void FunctionBuilderImpl::check_number_of_arguments(size_t number_of_arguments) const {
254
12.6k
    if (is_variadic()) {
255
3.95k
        return;
256
3.95k
    }
257
258
8.70k
    size_t expected_number_of_arguments = get_number_of_arguments();
259
260
8.70k
    DCHECK_EQ(number_of_arguments, expected_number_of_arguments) << fmt::format(
261
0
            "Number of arguments for function {} doesn't match: passed {} , should be {}",
262
0
            get_name(), number_of_arguments, expected_number_of_arguments);
263
8.70k
    if (number_of_arguments != expected_number_of_arguments) {
264
0
        throw Exception(
265
0
                ErrorCode::INVALID_ARGUMENT,
266
0
                "Number of arguments for function {} doesn't match: passed {} , should be {}",
267
0
                get_name(), number_of_arguments, expected_number_of_arguments);
268
0
    }
269
8.70k
}
270
271
12.6k
DataTypePtr FunctionBuilderImpl::get_return_type(const ColumnsWithTypeAndName& arguments) const {
272
12.6k
    check_number_of_arguments(arguments.size());
273
274
12.6k
    if (!arguments.empty() && use_default_implementation_for_nulls()) {
275
11.5k
        if (have_null_column(arguments)) {
276
11.3k
            ColumnNumbers numbers(arguments.size());
277
11.3k
            std::iota(numbers.begin(), numbers.end(), 0);
278
11.3k
            auto [nested_block, _] =
279
11.3k
                    create_block_with_nested_columns(Block(arguments), numbers, false);
280
11.3k
            auto return_type = get_return_type_impl(
281
11.3k
                    ColumnsWithTypeAndName(nested_block.begin(), nested_block.end()));
282
11.3k
            return make_nullable(return_type);
283
11.3k
        }
284
11.5k
    }
285
286
1.29k
    return get_return_type_impl(arguments);
287
12.6k
}
288
289
bool FunctionBuilderImpl::is_date_or_datetime_or_decimal(
290
9
        const DataTypePtr& return_type, const DataTypePtr& func_return_type) const {
291
9
    return (is_date_or_datetime(return_type->get_primitive_type()) &&
292
9
            is_date_or_datetime(func_return_type->get_primitive_type())) ||
293
9
           (is_date_v2_or_datetime_v2(return_type->get_primitive_type()) &&
294
9
            is_date_v2_or_datetime_v2(func_return_type->get_primitive_type())) ||
295
           // For some date functions such as str_to_date(string, string), return_type will
296
           // be datetimev2 if users enable datev2 but get_return_type(arguments) will still
297
           // return datetime. We need keep backward compatibility here.
298
9
           (is_date_v2_or_datetime_v2(return_type->get_primitive_type()) &&
299
5
            is_date_or_datetime(func_return_type->get_primitive_type())) ||
300
9
           (is_date_or_datetime(return_type->get_primitive_type()) &&
301
5
            is_date_v2_or_datetime_v2(func_return_type->get_primitive_type())) ||
302
9
           (is_decimal(return_type->get_primitive_type()) &&
303
5
            is_decimal(func_return_type->get_primitive_type())) ||
304
9
           (is_time_type(return_type->get_primitive_type()) &&
305
3
            is_time_type(func_return_type->get_primitive_type()));
306
9
}
307
308
0
bool contains_date_or_datetime_or_decimal(const DataTypePtr& type) {
309
0
    auto type_ptr = type->is_nullable() ? ((DataTypeNullable*)type.get())->get_nested_type() : type;
310
311
0
    switch (type_ptr->get_primitive_type()) {
312
0
    case TYPE_ARRAY: {
313
0
        const auto* array_type = assert_cast<const DataTypeArray*>(type_ptr.get());
314
0
        return contains_date_or_datetime_or_decimal(array_type->get_nested_type());
315
0
    }
316
0
    case TYPE_MAP: {
317
0
        const auto* map_type = assert_cast<const DataTypeMap*>(type_ptr.get());
318
0
        return contains_date_or_datetime_or_decimal(map_type->get_key_type()) ||
319
0
               contains_date_or_datetime_or_decimal(map_type->get_value_type());
320
0
    }
321
0
    case TYPE_STRUCT: {
322
0
        const auto* struct_type = assert_cast<const DataTypeStruct*>(type_ptr.get());
323
0
        const auto& elements = struct_type->get_elements();
324
0
        return std::ranges::any_of(elements, [](const DataTypePtr& element) {
325
0
            return contains_date_or_datetime_or_decimal(element);
326
0
        });
327
0
    }
328
0
    default:
329
        // For scalar types, check if it's date/datetime/decimal
330
0
        return is_date_or_datetime(type_ptr->get_primitive_type()) ||
331
0
               is_date_v2_or_datetime_v2(type_ptr->get_primitive_type()) ||
332
0
               is_decimal(type_ptr->get_primitive_type()) ||
333
0
               is_time_type(type_ptr->get_primitive_type());
334
0
    }
335
0
}
336
337
// make sure array/map/struct and nested  array/map/struct can be check
338
bool FunctionBuilderImpl::is_nested_type_date_or_datetime_or_decimal(
339
2
        const DataTypePtr& return_type, const DataTypePtr& func_return_type) const {
340
2
    auto return_type_ptr = return_type->is_nullable()
341
2
                                   ? ((DataTypeNullable*)return_type.get())->get_nested_type()
342
2
                                   : return_type;
343
2
    auto func_return_type_ptr =
344
2
            func_return_type->is_nullable()
345
2
                    ? ((DataTypeNullable*)func_return_type.get())->get_nested_type()
346
2
                    : func_return_type;
347
    // make sure that map/struct/array also need to check
348
2
    if (return_type_ptr->get_primitive_type() != func_return_type_ptr->get_primitive_type()) {
349
2
        return false;
350
2
    }
351
352
    // Check if this type contains date/datetime/decimal types
353
0
    if (!contains_date_or_datetime_or_decimal(return_type_ptr)) {
354
        // If no date/datetime/decimal types, just pass through
355
0
        return true;
356
0
    }
357
358
    // If contains date/datetime/decimal types, recursively check each element
359
0
    switch (return_type_ptr->get_primitive_type()) {
360
0
    case TYPE_ARRAY: {
361
0
        auto nested_return_type = remove_nullable(
362
0
                (assert_cast<const DataTypeArray*>(return_type_ptr.get()))->get_nested_type());
363
0
        auto nested_func_type = remove_nullable(
364
0
                (assert_cast<const DataTypeArray*>(func_return_type_ptr.get()))->get_nested_type());
365
0
        return is_nested_type_date_or_datetime_or_decimal(nested_return_type, nested_func_type);
366
0
    }
367
0
    case TYPE_MAP: {
368
0
        const auto* return_map = assert_cast<const DataTypeMap*>(return_type_ptr.get());
369
0
        const auto* func_map = assert_cast<const DataTypeMap*>(func_return_type_ptr.get());
370
371
0
        auto key_return = remove_nullable(return_map->get_key_type());
372
0
        auto key_func = remove_nullable(func_map->get_key_type());
373
0
        auto value_return = remove_nullable(return_map->get_value_type());
374
0
        auto value_func = remove_nullable(func_map->get_value_type());
375
376
0
        return is_nested_type_date_or_datetime_or_decimal(key_return, key_func) &&
377
0
               is_nested_type_date_or_datetime_or_decimal(value_return, value_func);
378
0
    }
379
0
    case TYPE_STRUCT: {
380
0
        const auto* return_struct = assert_cast<const DataTypeStruct*>(return_type_ptr.get());
381
0
        const auto* func_struct = assert_cast<const DataTypeStruct*>(func_return_type_ptr.get());
382
383
0
        auto return_elements = return_struct->get_elements();
384
0
        auto func_elements = func_struct->get_elements();
385
386
0
        if (return_elements.size() != func_elements.size()) {
387
0
            return false;
388
0
        }
389
390
0
        for (size_t i = 0; i < return_elements.size(); i++) {
391
0
            auto elem_return = remove_nullable(return_elements[i]);
392
0
            auto elem_func = remove_nullable(func_elements[i]);
393
394
0
            if (!is_nested_type_date_or_datetime_or_decimal(elem_return, elem_func)) {
395
0
                return false;
396
0
            }
397
0
        }
398
0
        return true;
399
0
    }
400
0
    default:
401
0
        return is_date_or_datetime_or_decimal(return_type_ptr, func_return_type_ptr);
402
0
    }
403
0
}
404
405
#include "common/compile_check_end.h"
406
} // namespace doris::vectorized