be/src/exprs/aggregate/aggregate_function_foreach.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // This file is copied from |
18 | | // https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.h |
19 | | // and modified by Doris |
20 | | |
21 | | #pragma once |
22 | | |
23 | | #include "common/status.h" |
24 | | #include "core/assert_cast.h" |
25 | | #include "core/column/column_nullable.h" |
26 | | #include "core/data_type/data_type_array.h" |
27 | | #include "core/data_type/data_type_nullable.h" |
28 | | #include "exec/common/arithmetic_overflow.h" |
29 | | #include "exprs/aggregate/aggregate_function.h" |
30 | | #include "exprs/function/array/function_array_utils.h" |
31 | | |
32 | | namespace doris { |
33 | | #include "common/compile_check_begin.h" |
34 | | |
/// Per-group state of a -ForEach aggregate: a raw buffer of nested aggregate states
/// plus the number of states currently held in it.
struct AggregateFunctionForEachData {
    /// Number of nested aggregate states stored in `array_of_aggregate_datas`.
    size_t dynamic_array_size {0};
    /// Start of the buffer holding the nested states back to back;
    /// nullptr until a buffer has been assigned.
    char* array_of_aggregate_datas {nullptr};
};
39 | | |
40 | | /** Adaptor for aggregate functions. |
41 | | * Adding -ForEach suffix to aggregate function |
42 | | * will convert that aggregate function to a function, accepting arrays, |
43 | | * and applies aggregation for each corresponding elements of arrays independently, |
44 | | * returning arrays of aggregated values on corresponding positions. |
45 | | * |
46 | | * Example: sumForEach of: |
47 | | * [1, 2], |
48 | | * [3, 4, 5], |
49 | | * [6, 7] |
50 | | * will return: |
51 | | * [10, 13, 5] |
52 | | * |
53 | | * TODO Allow variable number of arguments. |
54 | | */ |
55 | | class AggregateFunctionForEach : public IAggregateFunctionDataHelper<AggregateFunctionForEachData, |
56 | | AggregateFunctionForEach>, |
57 | | VarargsExpression, |
58 | | NullableAggregateFunction { |
59 | | protected: |
60 | | using Base = |
61 | | IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>; |
62 | | |
63 | | AggregateFunctionPtr nested_function; |
64 | | const size_t nested_size_of_data; |
65 | | const size_t num_arguments; |
66 | | |
67 | | AggregateFunctionForEachData& ensure_aggregate_data(AggregateDataPtr __restrict place, |
68 | 30 | size_t new_size, Arena& arena) const { |
69 | 30 | AggregateFunctionForEachData& state = data(place); |
70 | | |
71 | | /// Ensure we have aggregate states for new_size elements, allocate |
72 | | /// from arena if needed. When reallocating, we can't copy the |
73 | | /// states to new buffer with memcpy, because they may contain pointers |
74 | | /// to themselves. In particular, this happens when a state contains |
75 | | /// a PODArrayWithStackMemory, which stores small number of elements |
76 | | /// inline. This is why we create new empty states in the new buffer, |
77 | | /// and merge the old states to them. |
78 | 30 | size_t old_size = state.dynamic_array_size; |
79 | 30 | if (old_size < new_size) { |
80 | 30 | static constexpr size_t MAX_ARRAY_SIZE = 100 * 1000000000ULL; |
81 | 30 | if (new_size > MAX_ARRAY_SIZE) { |
82 | 0 | throw Exception(ErrorCode::INTERNAL_ERROR, |
83 | 0 | "Suspiciously large array size ({}) in -ForEach aggregate function", |
84 | 0 | new_size); |
85 | 0 | } |
86 | | |
87 | 30 | size_t allocation_size = 0; |
88 | 30 | if (common::mul_overflow(new_size, nested_size_of_data, allocation_size)) { |
89 | 0 | throw Exception(ErrorCode::INTERNAL_ERROR, |
90 | 0 | "Allocation size ({} * {}) overflows in -ForEach aggregate " |
91 | 0 | "function, but it should've been prevented by previous checks", |
92 | 0 | new_size, nested_size_of_data); |
93 | 0 | } |
94 | | |
95 | 30 | char* old_state = state.array_of_aggregate_datas; |
96 | | |
97 | 30 | char* new_state = |
98 | 30 | arena.aligned_alloc(allocation_size, nested_function->align_of_data()); |
99 | | |
100 | 30 | size_t i; |
101 | 30 | try { |
102 | 150 | for (i = 0; i < new_size; ++i) { |
103 | 120 | nested_function->create(&new_state[i * nested_size_of_data]); |
104 | 120 | } |
105 | 30 | } catch (...) { |
106 | 0 | size_t cleanup_size = i; |
107 | |
|
108 | 0 | for (i = 0; i < cleanup_size; ++i) { |
109 | 0 | nested_function->destroy(&new_state[i * nested_size_of_data]); |
110 | 0 | } |
111 | |
|
112 | 0 | throw; |
113 | 0 | } |
114 | | |
115 | 30 | for (i = 0; i < old_size; ++i) { |
116 | 0 | nested_function->merge(&new_state[i * nested_size_of_data], |
117 | 0 | &old_state[i * nested_size_of_data], arena); |
118 | 0 | nested_function->destroy(&old_state[i * nested_size_of_data]); |
119 | 0 | } |
120 | | |
121 | 30 | state.array_of_aggregate_datas = new_state; |
122 | 30 | state.dynamic_array_size = new_size; |
123 | 30 | } |
124 | | |
125 | 30 | return state; |
126 | 30 | } |
127 | | |
128 | | public: |
129 | | constexpr static auto AGG_FOREACH_SUFFIX = "_foreach"; |
130 | | AggregateFunctionForEach(AggregateFunctionPtr nested_function_, const DataTypes& arguments) |
131 | 2 | : Base(arguments), |
132 | 2 | nested_function {std::move(nested_function_)}, |
133 | 2 | nested_size_of_data(nested_function->size_of_data()), |
134 | 2 | num_arguments(arguments.size()) { |
135 | 2 | if (arguments.empty()) { |
136 | 0 | throw Exception(ErrorCode::INTERNAL_ERROR, |
137 | 0 | "Aggregate function {} require at least one argument", get_name()); |
138 | 0 | } |
139 | 2 | } |
140 | 0 | void set_version(const int version_) override { |
141 | 0 | Base::set_version(version_); |
142 | 0 | nested_function->set_version(version_); |
143 | 0 | } |
144 | | |
145 | 2 | String get_name() const override { return nested_function->get_name() + AGG_FOREACH_SUFFIX; } |
146 | | |
147 | 2 | DataTypePtr get_return_type() const override { |
148 | 2 | return std::make_shared<DataTypeArray>(nested_function->get_return_type()); |
149 | 2 | } |
150 | | |
151 | 30 | void destroy(AggregateDataPtr __restrict place) const noexcept override { |
152 | 30 | AggregateFunctionForEachData& state = data(place); |
153 | | |
154 | 30 | char* nested_state = state.array_of_aggregate_datas; |
155 | 150 | for (size_t i = 0; i < state.dynamic_array_size; ++i) { |
156 | 120 | nested_function->destroy(nested_state); |
157 | 120 | nested_state += nested_size_of_data; |
158 | 120 | } |
159 | 30 | } |
160 | | |
161 | 4 | bool is_trivial() const override { |
162 | 4 | return std::is_trivial_v<Data> && nested_function->is_trivial(); |
163 | 4 | } |
164 | | |
165 | | void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, |
166 | 8 | Arena& arena) const override { |
167 | 8 | const AggregateFunctionForEachData& rhs_state = data(rhs); |
168 | 8 | AggregateFunctionForEachData& state = |
169 | 8 | ensure_aggregate_data(place, rhs_state.dynamic_array_size, arena); |
170 | | |
171 | 8 | const char* rhs_nested_state = rhs_state.array_of_aggregate_datas; |
172 | 8 | char* nested_state = state.array_of_aggregate_datas; |
173 | | |
174 | 40 | for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) { |
175 | 32 | nested_function->merge(nested_state, rhs_nested_state, arena); |
176 | | |
177 | 32 | rhs_nested_state += nested_size_of_data; |
178 | 32 | nested_state += nested_size_of_data; |
179 | 32 | } |
180 | 8 | } |
181 | | |
182 | 10 | void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { |
183 | 10 | const AggregateFunctionForEachData& state = data(place); |
184 | 10 | buf.write_binary(state.dynamic_array_size); |
185 | 10 | const char* nested_state = state.array_of_aggregate_datas; |
186 | 50 | for (size_t i = 0; i < state.dynamic_array_size; ++i) { |
187 | 40 | nested_function->serialize(nested_state, buf); |
188 | 40 | nested_state += nested_size_of_data; |
189 | 40 | } |
190 | 10 | } |
191 | | |
192 | | void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, |
193 | 8 | Arena& arena) const override { |
194 | 8 | AggregateFunctionForEachData& state = data(place); |
195 | | |
196 | 8 | size_t new_size = 0; |
197 | 8 | buf.read_binary(new_size); |
198 | | |
199 | 8 | ensure_aggregate_data(place, new_size, arena); |
200 | | |
201 | 8 | char* nested_state = state.array_of_aggregate_datas; |
202 | 40 | for (size_t i = 0; i < new_size; ++i) { |
203 | 32 | nested_function->deserialize(nested_state, buf, arena); |
204 | 32 | nested_state += nested_size_of_data; |
205 | 32 | } |
206 | 8 | } |
207 | | |
208 | 12 | void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { |
209 | 12 | const AggregateFunctionForEachData& state = data(place); |
210 | | |
211 | 12 | auto& arr_to = assert_cast<ColumnArray&>(to); |
212 | 12 | auto& offsets_to = arr_to.get_offsets(); |
213 | 12 | IColumn& elems_to = arr_to.get_data(); |
214 | | |
215 | 12 | char* nested_state = state.array_of_aggregate_datas; |
216 | 60 | for (size_t i = 0; i < state.dynamic_array_size; ++i) { |
217 | 48 | nested_function->insert_result_into(nested_state, elems_to); |
218 | 48 | nested_state += nested_size_of_data; |
219 | 48 | } |
220 | | |
221 | 12 | offsets_to.push_back(offsets_to.back() + state.dynamic_array_size); |
222 | 12 | } |
223 | | |
224 | | void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, |
225 | 14 | Arena& arena) const override { |
226 | 14 | std::vector<const IColumn*> nested(num_arguments); |
227 | | |
228 | 28 | for (size_t i = 0; i < num_arguments; ++i) { |
229 | 14 | nested[i] = &assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i]) |
230 | 14 | .get_data(); |
231 | 14 | } |
232 | | |
233 | 14 | const auto& first_array_column = |
234 | 14 | assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[0]); |
235 | 14 | const auto& offsets = first_array_column.get_offsets(); |
236 | | |
237 | 14 | size_t begin = offsets[row_num - 1]; |
238 | 14 | size_t end = offsets[row_num]; |
239 | | |
240 | | /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance. |
241 | 14 | for (size_t i = 1; i < num_arguments; ++i) { |
242 | 0 | const auto& ith_column = |
243 | 0 | assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i]); |
244 | 0 | const auto& ith_offsets = ith_column.get_offsets(); |
245 | |
|
246 | 0 | if (ith_offsets[row_num] != end || |
247 | 0 | (row_num != 0 && ith_offsets[row_num - 1] != begin)) { |
248 | 0 | throw Exception(ErrorCode::INTERNAL_ERROR, |
249 | 0 | "Arrays passed to {} aggregate function have different sizes", |
250 | 0 | get_name()); |
251 | 0 | } |
252 | 0 | } |
253 | | |
254 | 14 | AggregateFunctionForEachData& state = ensure_aggregate_data(place, end - begin, arena); |
255 | | |
256 | 14 | char* nested_state = state.array_of_aggregate_datas; |
257 | 70 | for (size_t i = begin; i < end; ++i) { |
258 | 56 | nested_function->add(nested_state, nested.data(), i, arena); |
259 | 56 | nested_state += nested_size_of_data; |
260 | 56 | } |
261 | 14 | } |
262 | | }; |
263 | | } // namespace doris |
264 | | |
265 | | #include "common/compile_check_end.h" |