be/src/exprs/aggregate/aggregate_function_distinct.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // This file is copied from |
18 | | // https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionDistinct.cpp |
19 | | // and modified by Doris |
20 | | |
21 | | #include "exprs/aggregate/aggregate_function_distinct.h" |
22 | | |
23 | | #include <algorithm> |
24 | | |
25 | | #include "core/data_type/data_type_nullable.h" |
26 | | #include "exprs/aggregate/aggregate_function_combinator.h" |
27 | | #include "exprs/aggregate/aggregate_function_simple_factory.h" |
28 | | #include "exprs/aggregate/helpers.h" |
29 | | |
30 | | namespace doris { |
31 | | #include "common/compile_check_begin.h" |
32 | | |
33 | | template <PrimitiveType T> |
34 | | struct Reducer { |
35 | | template <bool stable> |
36 | | using Output = AggregateFunctionDistinctSingleNumericData<T, stable>; |
37 | | using AggregateFunctionDistinctNormal = AggregateFunctionDistinct<Output, false>; |
38 | | }; |
39 | | |
40 | | template <PrimitiveType T> |
41 | | using AggregateFunctionDistinctNumeric = typename Reducer<T>::AggregateFunctionDistinctNormal; |
42 | | |
43 | | class AggregateFunctionCombinatorDistinct final : public IAggregateFunctionCombinator { |
44 | | public: |
45 | 0 | String get_name() const override { return "Distinct"; } |
46 | | |
47 | 16 | DataTypes transform_arguments(const DataTypes& arguments) const override { |
48 | 16 | if (arguments.empty()) { |
49 | 0 | throw doris::Exception( |
50 | 0 | ErrorCode::INTERNAL_ERROR, |
51 | 0 | "Incorrect number of arguments for aggregate function with Distinct suffix"); |
52 | 0 | } |
53 | 16 | return arguments; |
54 | 16 | } |
55 | | |
56 | | AggregateFunctionPtr transform_aggregate_function( |
57 | | const AggregateFunctionPtr& nested_function, const DataTypes& arguments, |
58 | 16 | const bool result_is_nullable, const AggregateFunctionAttr& attr) const override { |
59 | 16 | DCHECK(nested_function != nullptr); |
60 | 16 | if (nested_function == nullptr) { |
61 | 0 | return nullptr; |
62 | 0 | } |
63 | | |
64 | 16 | if (arguments.size() == 1) { |
65 | 16 | AggregateFunctionPtr res( |
66 | 16 | creator_with_type_list<TYPE_TINYINT, TYPE_SMALLINT, TYPE_INT, TYPE_BIGINT, |
67 | 16 | TYPE_LARGEINT>:: |
68 | 16 | create<AggregateFunctionDistinctNumeric>(arguments, result_is_nullable, |
69 | 16 | attr, nested_function)); |
70 | 16 | if (res) { |
71 | 16 | return res; |
72 | 16 | } |
73 | | |
74 | 0 | res = creator_without_type::create< |
75 | 0 | AggregateFunctionDistinct<AggregateFunctionDistinctSingleGenericData>>( |
76 | 0 | arguments, result_is_nullable, attr, nested_function); |
77 | 0 | return res; |
78 | 16 | } |
79 | 0 | return creator_without_type::create< |
80 | 0 | AggregateFunctionDistinct<AggregateFunctionDistinctMultipleGenericData>>( |
81 | 0 | arguments, result_is_nullable, attr, nested_function); |
82 | 16 | } |
83 | | }; |
84 | | |
85 | 7 | void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory) { |
86 | 7 | AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, |
87 | 7 | const DataTypePtr& result_type, |
88 | 7 | const bool result_is_nullable, |
89 | 16 | const AggregateFunctionAttr& attr) { |
90 | | // 1. we should get not nullable types; |
91 | 16 | DataTypes nested_types(types.size()); |
92 | 16 | std::ranges::transform(types, nested_types.begin(), |
93 | 16 | [](const auto& e) { return remove_nullable(e); }); |
94 | 16 | auto function_combinator = std::make_shared<AggregateFunctionCombinatorDistinct>(); |
95 | 16 | auto transform_arguments = function_combinator->transform_arguments(nested_types); |
96 | 16 | auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size()); |
97 | 16 | auto nested_function = factory.get(nested_function_name, transform_arguments, result_type, |
98 | 16 | false, BeExecVersionManager::get_newest_version(), attr); |
99 | 16 | return function_combinator->transform_aggregate_function(nested_function, types, |
100 | 16 | result_is_nullable, attr); |
101 | 16 | }; |
102 | 7 | factory.register_distinct_function_combinator(creator, DISTINCT_FUNCTION_PREFIX); |
103 | 7 | factory.register_distinct_function_combinator(creator, DISTINCT_FUNCTION_PREFIX, true); |
104 | 7 | } |
105 | | } // namespace doris |