be/src/exprs/aggregate/aggregate_function_foreach.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // This file is copied from |
18 | | // https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.h |
19 | | // and modified by Doris |
20 | | |
21 | | #pragma once |
22 | | |
23 | | #include "common/status.h" |
24 | | #include "core/assert_cast.h" |
25 | | #include "core/column/column_nullable.h" |
26 | | #include "core/data_type/data_type_array.h" |
27 | | #include "core/data_type/data_type_nullable.h" |
28 | | #include "exec/common/arithmetic_overflow.h" |
29 | | #include "exprs/aggregate/aggregate_function.h" |
30 | | #include "exprs/function/array/function_array_utils.h" |
31 | | |
32 | | namespace doris { |
33 | | |
/// Per-group state for the -ForEach combinator: a dynamically grown, arena-backed
/// array of nested aggregate states, one state per array-element position seen so far.
struct AggregateFunctionForEachData {
    // Number of nested states currently allocated in `array_of_aggregate_datas`.
    size_t dynamic_array_size = 0;
    // Contiguous arena-allocated buffer of nested states; each state occupies the
    // nested function's size_of_data() bytes. nullptr until the first allocation.
    char* array_of_aggregate_datas = nullptr;
};
38 | | |
39 | | /** Adaptor for aggregate functions. |
40 | | * Adding -ForEach suffix to aggregate function |
41 | | * will convert that aggregate function to a function, accepting arrays, |
42 | | * and applies aggregation for each corresponding elements of arrays independently, |
43 | | * returning arrays of aggregated values on corresponding positions. |
44 | | * |
45 | | * Example: sumForEach of: |
46 | | * [1, 2], |
47 | | * [3, 4, 5], |
48 | | * [6, 7] |
49 | | * will return: |
50 | | * [10, 13, 5] |
51 | | * |
52 | | * TODO Allow variable number of arguments. |
53 | | */ |
class AggregateFunctionForEach : public IAggregateFunctionDataHelper<AggregateFunctionForEachData,
                                                                     AggregateFunctionForEach>,
                                 VarargsExpression,
                                 NullableAggregateFunction {
protected:
    using Base =
            IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>;

    // The wrapped aggregate function applied independently to each array position.
    AggregateFunctionPtr nested_function;
    // Cached nested_function->size_of_data(); stride between consecutive nested states.
    const size_t nested_size_of_data;
    // Number of (array-typed) arguments this combinator was created with.
    const size_t num_arguments;

    /// Grows the per-group state so it holds at least `new_size` nested states.
    /// Never shrinks. Returns the (possibly reallocated) state for `place`.
    /// Throws Exception(INTERNAL_ERROR) on suspiciously large or overflowing sizes.
    AggregateFunctionForEachData& ensure_aggregate_data(AggregateDataPtr __restrict place,
                                                        size_t new_size, Arena& arena) const {
        AggregateFunctionForEachData& state = data(place);

        /// Ensure we have aggregate states for new_size elements, allocate
        /// from arena if needed. When reallocating, we can't copy the
        /// states to new buffer with memcpy, because they may contain pointers
        /// to themselves. In particular, this happens when a state contains
        /// a PODArrayWithStackMemory, which stores small number of elements
        /// inline. This is why we create new empty states in the new buffer,
        /// and merge the old states to them.
        size_t old_size = state.dynamic_array_size;
        if (old_size < new_size) {
            // Hard cap (1e11 elements) to catch corrupted/hostile sizes before allocating.
            static constexpr size_t MAX_ARRAY_SIZE = 100 * 1000000000ULL;
            if (new_size > MAX_ARRAY_SIZE) {
                throw Exception(ErrorCode::INTERNAL_ERROR,
                                "Suspiciously large array size ({}) in -ForEach aggregate function",
                                new_size);
            }

            // Checked multiply: new_size * nested_size_of_data must not overflow size_t.
            size_t allocation_size = 0;
            if (common::mul_overflow(new_size, nested_size_of_data, allocation_size)) {
                throw Exception(ErrorCode::INTERNAL_ERROR,
                                "Allocation size ({} * {}) overflows in -ForEach aggregate "
                                "function, but it should've been prevented by previous checks",
                                new_size, nested_size_of_data);
            }

            char* old_state = state.array_of_aggregate_datas;

            // Fresh buffer, aligned for the nested state type; old buffer stays valid
            // (and owned by the arena) until we have merged out of it below.
            char* new_state =
                    arena.aligned_alloc(allocation_size, nested_function->align_of_data());

            // Default-construct all new_size nested states; on failure, tear down the
            // ones already constructed so no partially-initialized state leaks.
            size_t i;
            try {
                for (i = 0; i < new_size; ++i) {
                    nested_function->create(&new_state[i * nested_size_of_data]);
                }
            } catch (...) {
                size_t cleanup_size = i;

                for (i = 0; i < cleanup_size; ++i) {
                    nested_function->destroy(&new_state[i * nested_size_of_data]);
                }

                throw;
            }

            // Move old accumulated values into the new buffer via merge (memcpy is
            // unsafe — see comment above), then destroy the old states.
            for (i = 0; i < old_size; ++i) {
                nested_function->merge(&new_state[i * nested_size_of_data],
                                       &old_state[i * nested_size_of_data], arena);
                nested_function->destroy(&old_state[i * nested_size_of_data]);
            }

            state.array_of_aggregate_datas = new_state;
            state.dynamic_array_size = new_size;
        }

        return state;
    }

public:
    constexpr static auto AGG_FOREACH_SUFFIX = "_foreach";

    /// Wraps `nested_function_` so it is applied element-wise to array arguments.
    /// Throws Exception(INTERNAL_ERROR) if `arguments` is empty.
    AggregateFunctionForEach(AggregateFunctionPtr nested_function_, const DataTypes& arguments)
            : Base(arguments),
              nested_function {std::move(nested_function_)},
              nested_size_of_data(nested_function->size_of_data()),
              num_arguments(arguments.size()) {
        if (arguments.empty()) {
            throw Exception(ErrorCode::INTERNAL_ERROR,
                            "Aggregate function {} require at least one argument", get_name());
        }
    }

    // Propagates the serialization version to the nested function as well, so
    // serialize/deserialize of nested states agree across versions.
    void set_version(const int version_) override {
        Base::set_version(version_);
        nested_function->set_version(version_);
    }

    // Name is the nested function's name plus the "_foreach" suffix.
    String get_name() const override { return nested_function->get_name() + AGG_FOREACH_SUFFIX; }

    // Result is an array whose elements have the nested function's return type.
    DataTypePtr get_return_type() const override {
        return std::make_shared<DataTypeArray>(nested_function->get_return_type());
    }

    /// Destroys every nested state; the buffer itself is arena-owned and not freed here.
    void destroy(AggregateDataPtr __restrict place) const noexcept override {
        AggregateFunctionForEachData& state = data(place);

        char* nested_state = state.array_of_aggregate_datas;
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
            nested_function->destroy(nested_state);
            nested_state += nested_size_of_data;
        }
    }

    // Trivial only if both this combinator's Data and the nested states are trivial.
    bool is_trivial() const override {
        return std::is_trivial_v<Data> && nested_function->is_trivial();
    }

    /// Merges rhs into place position-wise. `place` is first grown to rhs's size,
    /// so every rhs state has a counterpart to merge into.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena& arena) const override {
        const AggregateFunctionForEachData& rhs_state = data(rhs);
        AggregateFunctionForEachData& state =
                ensure_aggregate_data(place, rhs_state.dynamic_array_size, arena);

        const char* rhs_nested_state = rhs_state.array_of_aggregate_datas;
        char* nested_state = state.array_of_aggregate_datas;

        // Merge min(sizes) positions; after ensure_aggregate_data this covers all of rhs.
        for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) {
            nested_function->merge(nested_state, rhs_nested_state, arena);

            rhs_nested_state += nested_size_of_data;
            nested_state += nested_size_of_data;
        }
    }

    /// Writes the state count followed by each nested state, in position order.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const AggregateFunctionForEachData& state = data(place);
        buf.write_binary(state.dynamic_array_size);
        const char* nested_state = state.array_of_aggregate_datas;
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
            nested_function->serialize(nested_state, buf);
            nested_state += nested_size_of_data;
        }
    }

    /// Inverse of serialize(): reads the count, grows the state, then deserializes
    /// each nested state in place.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena& arena) const override {
        AggregateFunctionForEachData& state = data(place);

        size_t new_size = 0;
        buf.read_binary(new_size);

        // `state` is a reference into `place`, so it observes the buffer/size that
        // ensure_aggregate_data installs.
        ensure_aggregate_data(place, new_size, arena);

        char* nested_state = state.array_of_aggregate_datas;
        for (size_t i = 0; i < new_size; ++i) {
            nested_function->deserialize(nested_state, buf, arena);
            nested_state += nested_size_of_data;
        }
    }

    /// Appends one array row to `to`: one nested result per position, then the
    /// new end offset for the array column.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        const AggregateFunctionForEachData& state = data(place);

        auto& arr_to = assert_cast<ColumnArray&>(to);
        auto& offsets_to = arr_to.get_offsets();
        IColumn& elems_to = arr_to.get_data();

        char* nested_state = state.array_of_aggregate_datas;
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
            nested_function->insert_result_into(nested_state, elems_to);
            nested_state += nested_size_of_data;
        }

        offsets_to.push_back(offsets_to.back() + state.dynamic_array_size);
    }

    /// Adds row `row_num` of the array arguments: feeds element i of each array to
    /// nested state i. All argument arrays must have equal lengths on this row.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena& arena) const override {
        // Unwrap each ColumnArray argument to its flat element column; the nested
        // function is invoked against these with element (not row) indices.
        std::vector<const IColumn*> nested(num_arguments);

        for (size_t i = 0; i < num_arguments; ++i) {
            nested[i] = &assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i])
                                 .get_data();
        }

        const auto& first_array_column =
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& offsets = first_array_column.get_offsets();

        // [begin, end) is this row's element range in the flattened data column.
        size_t begin = offsets[row_num - 1];
        size_t end = offsets[row_num];

        /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
        for (size_t i = 1; i < num_arguments; ++i) {
            const auto& ith_column =
                    assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i]);
            const auto& ith_offsets = ith_column.get_offsets();

            if (ith_offsets[row_num] != end ||
                (row_num != 0 && ith_offsets[row_num - 1] != begin)) {
                throw Exception(ErrorCode::INTERNAL_ERROR,
                                "Arrays passed to {} aggregate function have different sizes",
                                get_name());
            }
        }

        AggregateFunctionForEachData& state = ensure_aggregate_data(place, end - begin, arena);

        char* nested_state = state.array_of_aggregate_datas;
        for (size_t i = begin; i < end; ++i) {
            nested_function->add(nested_state, nested.data(), i, arena);
            nested_state += nested_size_of_data;
        }
    }
};
262 | | } // namespace doris |