be/src/exprs/function/function.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // This file is copied from |
18 | | // https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/IFunction.cpp |
19 | | // and modified by Doris |
20 | | |
21 | | #include "exprs/function/function.h" |
22 | | |
23 | | #include <algorithm> |
24 | | #include <memory> |
25 | | #include <numeric> |
26 | | |
27 | | #include "common/status.h" |
28 | | #include "core/assert_cast.h" |
29 | | #include "core/column/column.h" |
30 | | #include "core/column/column_const.h" |
31 | | #include "core/column/column_nullable.h" |
32 | | #include "core/column/column_vector.h" |
33 | | #include "core/data_type/data_type_array.h" |
34 | | #include "core/data_type/data_type_nothing.h" |
35 | | #include "core/data_type/data_type_nullable.h" |
36 | | #include "core/data_type/define_primitive_type.h" |
37 | | #include "core/data_type/primitive_type.h" |
38 | | #include "core/field.h" |
39 | | #include "exec/common/util.hpp" |
40 | | #include "exprs/aggregate/aggregate_function.h" |
41 | | #include "exprs/function/function_helpers.h" |
42 | | |
43 | | namespace doris { |
44 | | ColumnPtr wrap_in_nullable(const ColumnPtr& src, const Block& block, const ColumnNumbers& args, |
45 | 658k | size_t input_rows_count) { |
46 | 658k | ColumnPtr result_null_map_column; |
47 | | /// If result is already nullable. |
48 | 658k | ColumnPtr src_not_nullable = src; |
49 | 658k | MutableColumnPtr mutable_result_null_map_column; |
50 | | |
51 | 658k | if (auto nullable = check_and_get_column_ptr<ColumnNullable>(src)) { |
52 | 186k | src_not_nullable = nullable->get_nested_column_ptr(); |
53 | 186k | result_null_map_column = nullable->get_null_map_column_ptr(); |
54 | 186k | } |
55 | | |
56 | 944k | for (const auto& arg : args) { |
57 | 944k | const ColumnWithTypeAndName& elem = block.get_by_position(arg); |
58 | 944k | if (!elem.type->is_nullable() || is_column_const(*elem.column)) { |
59 | 277k | continue; |
60 | 277k | } |
61 | | |
62 | 666k | if (auto nullable = cast_to_column<ColumnNullable>(elem.column); nullable->has_null()) { |
63 | 59.2k | const ColumnPtr& null_map_column = nullable->get_null_map_column_ptr(); |
64 | 59.2k | if (!result_null_map_column) { // NOLINT(bugprone-use-after-move) |
65 | 32.3k | result_null_map_column = null_map_column->clone_resized(input_rows_count); |
66 | 32.3k | continue; |
67 | 32.3k | } |
68 | | |
69 | 26.8k | if (!mutable_result_null_map_column) { |
70 | 25.5k | mutable_result_null_map_column = |
71 | 25.5k | std::move(result_null_map_column)->assume_mutable(); |
72 | 25.5k | } |
73 | | |
74 | 26.8k | NullMap& result_null_map = |
75 | 26.8k | assert_cast<ColumnUInt8&>(*mutable_result_null_map_column).get_data(); |
76 | 26.8k | const NullMap& src_null_map = |
77 | 26.8k | assert_cast<const ColumnUInt8&>(*null_map_column).get_data(); |
78 | | |
79 | 26.8k | VectorizedUtils::update_null_map(result_null_map, src_null_map); |
80 | 26.8k | } |
81 | 666k | } |
82 | | |
83 | 658k | if (!result_null_map_column) { |
84 | 439k | if (is_column_const(*src)) { |
85 | 73 | return ColumnConst::create( |
86 | 73 | make_nullable(assert_cast<const ColumnConst&>(*src).get_data_column_ptr(), |
87 | 73 | false), |
88 | 73 | input_rows_count); |
89 | 73 | } |
90 | 439k | return ColumnNullable::create(src, ColumnUInt8::create(input_rows_count, 0)); |
91 | 439k | } |
92 | | |
93 | 218k | return ColumnNullable::create(src_not_nullable, result_null_map_column); |
94 | 658k | } |
95 | | |
96 | 626k | bool have_null_column(const Block& block, const ColumnNumbers& args) { |
97 | 1.00M | return std::ranges::any_of(args, [&block](const auto& elem) { |
98 | 1.00M | return block.get_by_position(elem).type->is_nullable(); |
99 | 1.00M | }); |
100 | 626k | } |
101 | | |
102 | 624k | bool have_null_column(const ColumnsWithTypeAndName& args) { |
103 | 1.12M | return std::ranges::any_of(args, [](const auto& elem) { return elem.type->is_nullable(); }); |
104 | 624k | } |
105 | | |
106 | | inline Status PreparedFunctionImpl::_execute_skipped_constant_deal(FunctionContext* context, |
107 | | Block& block, |
108 | | const ColumnNumbers& args, |
109 | | uint32_t result, |
110 | 1.17M | size_t input_rows_count) const { |
111 | 1.17M | bool executed = false; |
112 | 1.17M | RETURN_IF_ERROR(default_implementation_for_nulls(context, block, args, result, input_rows_count, |
113 | 1.17M | &executed)); |
114 | 1.17M | if (executed) { |
115 | 275k | return Status::OK(); |
116 | 275k | } |
117 | 901k | return execute_impl(context, block, args, result, input_rows_count); |
118 | 1.17M | } |
119 | | |
120 | | Status PreparedFunctionImpl::default_implementation_for_constant_arguments( |
121 | | FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result, |
122 | 1.17M | size_t input_rows_count, bool* executed) const { |
123 | 1.17M | *executed = false; |
124 | 1.17M | ColumnNumbers args_expect_const = get_arguments_that_are_always_constant(); |
125 | | |
126 | | // Check that these arguments are really constant. |
127 | 1.17M | for (auto arg_num : args_expect_const) { |
128 | 421k | if (arg_num < args.size() && |
129 | 421k | !is_column_const(*block.get_by_position(args[arg_num]).column)) { |
130 | 3 | return Status::InvalidArgument("Argument at index {} for function {} must be constant", |
131 | 3 | arg_num, get_name()); |
132 | 3 | } |
133 | 421k | } |
134 | | |
135 | 1.17M | if (args.empty() || !use_default_implementation_for_constants() || |
136 | 1.17M | !VectorizedUtils::all_arguments_are_constant(block, args)) { |
137 | 1.08M | return Status::OK(); |
138 | 1.08M | } |
139 | | |
140 | | // now all columns are const. |
141 | 89.3k | Block temporary_block; |
142 | | |
143 | 89.3k | int arguments_size = (int)args.size(); |
144 | 215k | for (size_t arg_num = 0; arg_num < arguments_size; ++arg_num) { |
145 | 126k | const ColumnWithTypeAndName& column = block.get_by_position(args[arg_num]); |
146 | | // Columns listed in args_expect_const are kept as ColumnConst; all others are unpacked to their nested data column. |
147 | | // That is because some functions assume certain specific columns are always constant. |
148 | | // If we unpacked those as well, there would be an unnecessary cost of virtual judgment. |
149 | 126k | if (args_expect_const.end() != |
150 | 126k | std::find(args_expect_const.begin(), args_expect_const.end(), arg_num)) { |
151 | 499 | temporary_block.insert({column.column, column.type, column.name}); |
152 | 125k | } else { |
153 | 125k | temporary_block.insert( |
154 | 125k | {assert_cast<const ColumnConst*>(column.column.get())->get_data_column_ptr(), |
155 | 125k | column.type, column.name}); |
156 | 125k | } |
157 | 126k | } |
158 | | |
159 | 89.3k | temporary_block.insert(block.get_by_position(result)); |
160 | | |
161 | 89.3k | ColumnNumbers temporary_argument_numbers(arguments_size); |
162 | 215k | for (int i = 0; i < arguments_size; ++i) { |
163 | 126k | temporary_argument_numbers[i] = i; |
164 | 126k | } |
165 | | |
166 | 89.3k | RETURN_IF_ERROR(_execute_skipped_constant_deal(context, temporary_block, |
167 | 89.3k | temporary_argument_numbers, arguments_size, |
168 | 89.3k | temporary_block.rows())); |
169 | | |
170 | 88.2k | ColumnPtr result_column; |
171 | | /// Extremely rare case: the function has completely const arguments, |
172 | | /// but some of them were produced by a function that is not is_deterministic. |
173 | 88.2k | if (temporary_block.get_by_position(arguments_size).column->size() > 1) { |
174 | 0 | result_column = temporary_block.get_by_position(arguments_size).column->clone_resized(1); |
175 | 88.2k | } else { |
176 | 88.2k | result_column = temporary_block.get_by_position(arguments_size).column; |
177 | 88.2k | } |
178 | | // We should handle the case where the result column is also a ColumnConst. |
179 | 88.2k | block.get_by_position(result).column = ColumnConst::create(result_column, input_rows_count); |
180 | 88.2k | *executed = true; |
181 | 88.2k | return Status::OK(); |
182 | 89.3k | } |
183 | | |
184 | | Status PreparedFunctionImpl::default_implementation_for_nulls( |
185 | | FunctionContext* context, Block& block, const ColumnNumbers& args, uint32_t result, |
186 | 1.17M | size_t input_rows_count, bool* executed) const { |
187 | 1.17M | *executed = false; |
188 | 1.17M | if (args.empty() || !use_default_implementation_for_nulls()) { |
189 | 538k | return Status::OK(); |
190 | 538k | } |
191 | | |
192 | 1.28M | if (std::ranges::any_of(args, [&block](const auto& elem) { |
193 | 1.28M | return block.get_by_position(elem).column->only_null(); |
194 | 1.28M | })) { |
195 | 13.1k | block.get_by_position(result).column = |
196 | 13.1k | block.get_by_position(result).type->create_column_const(input_rows_count, Field()); |
197 | 13.1k | *executed = true; |
198 | 13.1k | return Status::OK(); |
199 | 13.1k | } |
200 | | |
201 | 626k | if (have_null_column(block, args)) { |
202 | 262k | bool need_to_default = need_replace_null_data_to_default(); |
203 | | // extract nested column from nulls |
204 | 262k | ColumnNumbers new_args; |
205 | 262k | Block new_block; |
206 | | |
207 | 800k | for (int i = 0; i < args.size(); ++i) { |
208 | 537k | uint32_t arg = args[i]; |
209 | 537k | new_args.push_back(i); |
210 | 537k | new_block.insert(block.get_by_position(arg).unnest_nullable(need_to_default)); |
211 | 537k | } |
212 | 262k | new_block.insert(block.get_by_position(result)); |
213 | 262k | int new_result = new_block.columns() - 1; |
214 | | |
215 | 262k | RETURN_IF_ERROR(default_execute(context, new_block, new_args, new_result, block.rows())); |
216 | | // After running on the nested columns, wrap the result in nullable. Before this, block.get_by_position(result).type |
217 | | // is not compatible with get_by_position(result).column. |
218 | | |
219 | 262k | block.get_by_position(result).column = wrap_in_nullable( |
220 | 262k | new_block.get_by_position(new_result).column, block, args, input_rows_count); |
221 | | |
222 | 262k | *executed = true; |
223 | 262k | return Status::OK(); |
224 | 262k | } |
225 | 363k | return Status::OK(); |
226 | 626k | } |
227 | | |
228 | | Status PreparedFunctionImpl::default_execute(FunctionContext* context, Block& block, |
229 | | const ColumnNumbers& args, uint32_t result, |
230 | 1.17M | size_t input_rows_count) const { |
231 | 1.17M | bool executed = false; |
232 | | |
233 | 1.17M | RETURN_IF_ERROR(default_implementation_for_constant_arguments(context, block, args, result, |
234 | 1.17M | input_rows_count, &executed)); |
235 | 1.17M | if (executed) { |
236 | 88.8k | return Status::OK(); |
237 | 88.8k | } |
238 | | |
239 | 1.08M | return _execute_skipped_constant_deal(context, block, args, result, input_rows_count); |
240 | 1.17M | } |
241 | | |
242 | | Status PreparedFunctionImpl::execute(FunctionContext* context, Block& block, |
243 | | const ColumnNumbers& args, uint32_t result, |
244 | 915k | size_t input_rows_count) const { |
245 | 915k | return default_execute(context, block, args, result, input_rows_count); |
246 | 915k | } |
247 | | |
248 | 709k | void FunctionBuilderImpl::check_number_of_arguments(size_t number_of_arguments) const { |
249 | 709k | if (is_variadic()) { |
250 | 77.1k | return; |
251 | 77.1k | } |
252 | | |
253 | 632k | size_t expected_number_of_arguments = get_number_of_arguments(); |
254 | | |
255 | 632k | DCHECK_EQ(number_of_arguments, expected_number_of_arguments) << fmt::format( |
256 | 0 | "Number of arguments for function {} doesn't match: passed {} , should be {}", |
257 | 0 | get_name(), number_of_arguments, expected_number_of_arguments); |
258 | 632k | if (number_of_arguments != expected_number_of_arguments) { |
259 | 0 | throw Exception( |
260 | 0 | ErrorCode::INVALID_ARGUMENT, |
261 | 0 | "Number of arguments for function {} doesn't match: passed {} , should be {}", |
262 | 0 | get_name(), number_of_arguments, expected_number_of_arguments); |
263 | 0 | } |
264 | 632k | } |
265 | | |
266 | 713k | DataTypePtr FunctionBuilderImpl::get_return_type(const ColumnsWithTypeAndName& arguments) const { |
267 | 713k | check_number_of_arguments(arguments.size()); |
268 | | |
269 | 713k | if (!arguments.empty() && use_default_implementation_for_nulls()) { |
270 | 624k | if (have_null_column(arguments)) { |
271 | 121k | ColumnNumbers numbers(arguments.size()); |
272 | 121k | std::iota(numbers.begin(), numbers.end(), 0); |
273 | 121k | auto [nested_block, _] = |
274 | 121k | create_block_with_nested_columns(Block(arguments), numbers, false); |
275 | 121k | auto return_type = get_return_type_impl( |
276 | 121k | ColumnsWithTypeAndName(nested_block.begin(), nested_block.end())); |
277 | 121k | if (!return_type) { |
278 | 0 | return nullptr; |
279 | 0 | } |
280 | 121k | return make_nullable(return_type); |
281 | 121k | } |
282 | 624k | } |
283 | | |
284 | 591k | return get_return_type_impl(arguments); |
285 | 713k | } |
286 | | |
287 | | bool FunctionBuilderImpl::is_date_or_datetime_or_decimal( |
288 | 2.67k | const DataTypePtr& return_type, const DataTypePtr& func_return_type) const { |
289 | 2.67k | return (is_date_or_datetime(return_type->get_primitive_type()) && |
290 | 2.67k | is_date_or_datetime(func_return_type->get_primitive_type())) || |
291 | 2.67k | (is_date_v2_or_datetime_v2(return_type->get_primitive_type()) && |
292 | 2.67k | is_date_v2_or_datetime_v2(func_return_type->get_primitive_type())) || |
293 | | // For some date functions such as str_to_date(string, string), return_type will |
294 | | // be datetimev2 if users enable datev2 but get_return_type(arguments) will still |
295 | | // return datetime. We need to keep backward compatibility here. |
296 | 2.67k | (is_date_v2_or_datetime_v2(return_type->get_primitive_type()) && |
297 | 1.78k | is_date_or_datetime(func_return_type->get_primitive_type())) || |
298 | 2.67k | (is_date_or_datetime(return_type->get_primitive_type()) && |
299 | 1.76k | is_date_v2_or_datetime_v2(func_return_type->get_primitive_type())) || |
300 | 2.67k | (is_decimal(return_type->get_primitive_type()) && |
301 | 1.76k | is_decimal(func_return_type->get_primitive_type())) || |
302 | 2.67k | (is_time_type(return_type->get_primitive_type()) && |
303 | 497 | is_time_type(func_return_type->get_primitive_type())); |
304 | 2.67k | } |
305 | | |
306 | 572 | bool contains_date_or_datetime_or_decimal(const DataTypePtr& type) { |
307 | 572 | auto type_ptr = type->is_nullable() ? ((DataTypeNullable*)type.get())->get_nested_type() : type; |
308 | | |
309 | 572 | switch (type_ptr->get_primitive_type()) { |
310 | 15 | case TYPE_ARRAY: { |
311 | 15 | const auto* array_type = assert_cast<const DataTypeArray*>(type_ptr.get()); |
312 | 15 | return contains_date_or_datetime_or_decimal(array_type->get_nested_type()); |
313 | 0 | } |
314 | 0 | case TYPE_MAP: { |
315 | 0 | const auto* map_type = assert_cast<const DataTypeMap*>(type_ptr.get()); |
316 | 0 | return contains_date_or_datetime_or_decimal(map_type->get_key_type()) || |
317 | 0 | contains_date_or_datetime_or_decimal(map_type->get_value_type()); |
318 | 0 | } |
319 | 33 | case TYPE_STRUCT: { |
320 | 33 | const auto* struct_type = assert_cast<const DataTypeStruct*>(type_ptr.get()); |
321 | 33 | const auto& elements = struct_type->get_elements(); |
322 | 71 | return std::ranges::any_of(elements, [](const DataTypePtr& element) { |
323 | 71 | return contains_date_or_datetime_or_decimal(element); |
324 | 71 | }); |
325 | 0 | } |
326 | 524 | default: |
327 | | // For scalar types, check if it's date/datetime/decimal |
328 | 524 | return is_date_or_datetime(type_ptr->get_primitive_type()) || |
329 | 524 | is_date_v2_or_datetime_v2(type_ptr->get_primitive_type()) || |
330 | 524 | is_decimal(type_ptr->get_primitive_type()) || |
331 | 524 | is_time_type(type_ptr->get_primitive_type()); |
332 | 572 | } |
333 | 572 | } |
334 | | |
335 | | // make sure array/map/struct and nested array/map/struct can be checked |
336 | | bool FunctionBuilderImpl::is_nested_type_date_or_datetime_or_decimal( |
337 | 488 | const DataTypePtr& return_type, const DataTypePtr& func_return_type) const { |
338 | 488 | auto return_type_ptr = return_type->is_nullable() |
339 | 488 | ? ((DataTypeNullable*)return_type.get())->get_nested_type() |
340 | 488 | : return_type; |
341 | 488 | auto func_return_type_ptr = |
342 | 488 | func_return_type->is_nullable() |
343 | 488 | ? ((DataTypeNullable*)func_return_type.get())->get_nested_type() |
344 | 488 | : func_return_type; |
345 | | // make sure that map/struct/array are also checked |
346 | 488 | if (return_type_ptr->get_primitive_type() != func_return_type_ptr->get_primitive_type()) { |
347 | 2 | return false; |
348 | 2 | } |
349 | | |
350 | | // Check if this type contains date/datetime/decimal types |
351 | 486 | if (!contains_date_or_datetime_or_decimal(return_type_ptr)) { |
352 | | // If no date/datetime/decimal types, just pass through |
353 | 454 | return true; |
354 | 454 | } |
355 | | |
356 | | // If contains date/datetime/decimal types, recursively check each element |
357 | 32 | switch (return_type_ptr->get_primitive_type()) { |
358 | 15 | case TYPE_ARRAY: { |
359 | 15 | auto nested_return_type = remove_nullable( |
360 | 15 | (assert_cast<const DataTypeArray*>(return_type_ptr.get()))->get_nested_type()); |
361 | 15 | auto nested_func_type = remove_nullable( |
362 | 15 | (assert_cast<const DataTypeArray*>(func_return_type_ptr.get()))->get_nested_type()); |
363 | 15 | return is_nested_type_date_or_datetime_or_decimal(nested_return_type, nested_func_type); |
364 | 0 | } |
365 | 0 | case TYPE_MAP: { |
366 | 0 | const auto* return_map = assert_cast<const DataTypeMap*>(return_type_ptr.get()); |
367 | 0 | const auto* func_map = assert_cast<const DataTypeMap*>(func_return_type_ptr.get()); |
368 | | 
369 | 0 | auto key_return = remove_nullable(return_map->get_key_type()); |
370 | 0 | auto key_func = remove_nullable(func_map->get_key_type()); |
371 | 0 | auto value_return = remove_nullable(return_map->get_value_type()); |
372 | 0 | auto value_func = remove_nullable(func_map->get_value_type()); |
373 | | 
374 | 0 | return is_nested_type_date_or_datetime_or_decimal(key_return, key_func) && |
375 | 0 | is_nested_type_date_or_datetime_or_decimal(value_return, value_func); |
376 | 0 | } |
377 | 1 | case TYPE_STRUCT: { |
378 | 1 | const auto* return_struct = assert_cast<const DataTypeStruct*>(return_type_ptr.get()); |
379 | 1 | const auto* func_struct = assert_cast<const DataTypeStruct*>(func_return_type_ptr.get()); |
380 | | |
381 | 1 | auto return_elements = return_struct->get_elements(); |
382 | 1 | auto func_elements = func_struct->get_elements(); |
383 | | |
384 | 1 | if (return_elements.size() != func_elements.size()) { |
385 | 0 | return false; |
386 | 0 | } |
387 | | |
388 | 5 | for (size_t i = 0; i < return_elements.size(); i++) { |
389 | 4 | auto elem_return = remove_nullable(return_elements[i]); |
390 | 4 | auto elem_func = remove_nullable(func_elements[i]); |
391 | | |
392 | 4 | if (!is_nested_type_date_or_datetime_or_decimal(elem_return, elem_func)) { |
393 | 0 | return false; |
394 | 0 | } |
395 | 4 | } |
396 | 1 | return true; |
397 | 1 | } |
398 | 16 | default: |
399 | 16 | return is_date_or_datetime_or_decimal(return_type_ptr, func_return_type_ptr); |
400 | 32 | } |
401 | 32 | } |
402 | | |
403 | | } // namespace doris |