Coverage Report

Created: 2026-03-12 02:33

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_foreach.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include "common/status.h"
24
#include "core/assert_cast.h"
25
#include "core/column/column_nullable.h"
26
#include "core/data_type/data_type_array.h"
27
#include "core/data_type/data_type_nullable.h"
28
#include "exec/common/arithmetic_overflow.h"
29
#include "exprs/aggregate/aggregate_function.h"
30
#include "exprs/function/array/function_array_utils.h"
31
32
namespace doris {
33
#include "common/compile_check_begin.h"
34
35
/// Per-place state of the -ForEach adaptor: a contiguous buffer holding one
/// nested aggregate state per array position seen so far.
struct AggregateFunctionForEachData {
    /// Number of nested states currently stored in `array_of_aggregate_datas`.
    size_t dynamic_array_size {0};
    /// Start of the nested-state buffer; stays null until states are allocated.
    char* array_of_aggregate_datas {nullptr};
};
39
40
/** Adaptor for aggregate functions.
41
  * Adding the -ForEach suffix (spelled `_foreach` in this file, see AGG_FOREACH_SUFFIX) to an aggregate function
42
  *  will convert that aggregate function to a function, accepting arrays,
43
  *  and applies the aggregation to each corresponding element of the arrays independently,
44
  *  returning arrays of aggregated values on corresponding positions.
45
  *
46
  * Example: sumForEach of:
47
  *  [1, 2],
48
  *  [3, 4, 5],
49
  *  [6, 7]
50
  * will return:
51
  *  [10, 13, 5]
52
  *
53
  * TODO Allow variable number of arguments.
54
  */
55
class AggregateFunctionForEach : public IAggregateFunctionDataHelper<AggregateFunctionForEachData,
56
                                                                     AggregateFunctionForEach>,
57
                                 VarargsExpression,
58
                                 NullableAggregateFunction {
59
protected:
60
    using Base =
61
            IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>;
62
63
    AggregateFunctionPtr nested_function;
64
    const size_t nested_size_of_data;
65
    const size_t num_arguments;
66
67
    /// Makes sure `place` holds at least `new_size` nested aggregate states and
    /// returns the wrapper state. The buffer only grows: when the current size
    /// is already >= `new_size` nothing is allocated or changed.
    ///
    /// Throws Exception(INTERNAL_ERROR) when `new_size` exceeds MAX_ARRAY_SIZE or
    /// when `new_size * nested_size_of_data` overflows, and rethrows anything the
    /// nested function throws from create()/merge().
    ///
    /// Exception safety: the wrapper state is only updated after every nested
    /// state has been created and every old state merged. On failure all newly
    /// created states are destroyed and the old states are left untouched, so a
    /// later destroy(place) sees a consistent array (no double destroy, no
    /// leaked nested states). The arena memory itself is not returned.
    AggregateFunctionForEachData& ensure_aggregate_data(AggregateDataPtr __restrict place,
                                                        size_t new_size, Arena& arena) const {
        AggregateFunctionForEachData& state = data(place);

        /// Ensure we have aggregate states for new_size elements, allocate
        /// from arena if needed. When reallocating, we can't copy the
        /// states to new buffer with memcpy, because they may contain pointers
        /// to themselves. In particular, this happens when a state contains
        /// a PODArrayWithStackMemory, which stores small number of elements
        /// inline. This is why we create new empty states in the new buffer,
        /// and merge the old states to them.
        const size_t old_size = state.dynamic_array_size;
        if (old_size >= new_size) {
            return state;
        }

        static constexpr size_t MAX_ARRAY_SIZE = 100 * 1000000000ULL;
        if (new_size > MAX_ARRAY_SIZE) {
            throw Exception(ErrorCode::INTERNAL_ERROR,
                            "Suspiciously large array size ({}) in -ForEach aggregate function",
                            new_size);
        }

        size_t allocation_size = 0;
        if (common::mul_overflow(new_size, nested_size_of_data, allocation_size)) {
            throw Exception(ErrorCode::INTERNAL_ERROR,
                            "Allocation size ({} * {}) overflows in -ForEach aggregate "
                            "function, but it should've been prevented by previous checks",
                            new_size, nested_size_of_data);
        }

        char* old_state = state.array_of_aggregate_datas;
        char* new_state = arena.aligned_alloc(allocation_size, nested_function->align_of_data());

        /// Destroys the first `count` nested states stored at `states`.
        const auto destroy_states = [this](char* states, size_t count) {
            for (size_t i = 0; i < count; ++i) {
                nested_function->destroy(&states[i * nested_size_of_data]);
            }
        };

        /// Create the new states; on failure roll back the ones already created.
        size_t created = 0;
        try {
            for (; created < new_size; ++created) {
                nested_function->create(&new_state[created * nested_size_of_data]);
            }
        } catch (...) {
            destroy_states(new_state, created);
            throw;
        }

        /// Fold the old states into the new buffer. The old states are kept alive
        /// until every merge has succeeded: `state` still points at them, so
        /// destroying them early would lead to a double destroy if a later merge
        /// threw. The price is that old and new states coexist slightly longer.
        try {
            for (size_t i = 0; i < old_size; ++i) {
                nested_function->merge(&new_state[i * nested_size_of_data],
                                       &old_state[i * nested_size_of_data], arena);
            }
        } catch (...) {
            destroy_states(new_state, new_size);
            throw;
        }

        destroy_states(old_state, old_size);

        state.array_of_aggregate_datas = new_state;
        state.dynamic_array_size = new_size;
        return state;
    }
127
128
public:
129
    constexpr static auto AGG_FOREACH_SUFFIX = "_foreach";
130
    /// Builds the adaptor around `nested_function_`, caching the nested state
    /// size and the number of argument columns.
    /// Throws Exception(INTERNAL_ERROR) when `arguments` is empty.
    AggregateFunctionForEach(AggregateFunctionPtr nested_function_, const DataTypes& arguments)
            : Base(arguments),
              nested_function(std::move(nested_function_)),
              nested_size_of_data(nested_function->size_of_data()),
              num_arguments(arguments.size()) {
        // num_arguments was just initialised from arguments.size(), so a zero
        // here is exactly the "no argument types supplied" case.
        if (num_arguments == 0) {
            throw Exception(ErrorCode::INTERNAL_ERROR,
                            "Aggregate function {} require at least one argument", get_name());
        }
    }
140
0
    void set_version(const int version_) override {
141
0
        Base::set_version(version_);
142
0
        nested_function->set_version(version_);
143
0
    }
144
145
2
    /// Name of the adaptor: the nested function's name followed by AGG_FOREACH_SUFFIX.
    String get_name() const override {
        String full_name = nested_function->get_name();
        full_name += AGG_FOREACH_SUFFIX;
        return full_name;
    }
146
147
2
    /// Result type: an array whose element type is the nested function's return type.
    DataTypePtr get_return_type() const override {
        DataTypePtr element_type = nested_function->get_return_type();
        return std::make_shared<DataTypeArray>(element_type);
    }
150
151
30
    /// Destroys every nested state currently stored in `place`, in index order.
    /// The buffer itself is not released here.
    void destroy(AggregateDataPtr __restrict place) const noexcept override {
        AggregateFunctionForEachData& state = data(place);

        char* const first_state = state.array_of_aggregate_datas;
        for (size_t idx = 0; idx < state.dynamic_array_size; ++idx) {
            nested_function->destroy(first_state + idx * nested_size_of_data);
        }
    }
160
161
4
    bool is_trivial() const override {
162
4
        return std::is_trivial_v<Data> && nested_function->is_trivial();
163
4
    }
164
165
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
166
8
               Arena& arena) const override {
167
8
        const AggregateFunctionForEachData& rhs_state = data(rhs);
168
8
        AggregateFunctionForEachData& state =
169
8
                ensure_aggregate_data(place, rhs_state.dynamic_array_size, arena);
170
171
8
        const char* rhs_nested_state = rhs_state.array_of_aggregate_datas;
172
8
        char* nested_state = state.array_of_aggregate_datas;
173
174
40
        for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) {
175
32
            nested_function->merge(nested_state, rhs_nested_state, arena);
176
177
32
            rhs_nested_state += nested_size_of_data;
178
32
            nested_state += nested_size_of_data;
179
32
        }
180
8
    }
181
182
10
    /// Writes the number of nested states followed by each nested state, in
    /// index order, using the nested function's own serialization.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const AggregateFunctionForEachData& state = data(place);

        // Length prefix first, so deserialize() knows how many states follow.
        buf.write_binary(state.dynamic_array_size);

        const char* const first_state = state.array_of_aggregate_datas;
        for (size_t idx = 0; idx < state.dynamic_array_size; ++idx) {
            nested_function->serialize(first_state + idx * nested_size_of_data, buf);
        }
    }
191
192
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
193
8
                     Arena& arena) const override {
194
8
        AggregateFunctionForEachData& state = data(place);
195
196
8
        size_t new_size = 0;
197
8
        buf.read_binary(new_size);
198
199
8
        ensure_aggregate_data(place, new_size, arena);
200
201
8
        char* nested_state = state.array_of_aggregate_datas;
202
40
        for (size_t i = 0; i < new_size; ++i) {
203
32
            nested_function->deserialize(nested_state, buf, arena);
204
32
            nested_state += nested_size_of_data;
205
32
        }
206
8
    }
207
208
12
    /// Appends one array row to `to` (which must be a ColumnArray): each nested
    /// state contributes one element to the array's data column, then a new
    /// offset closing the row is pushed.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        const AggregateFunctionForEachData& state = data(place);

        auto& result_array = assert_cast<ColumnArray&>(to);
        auto& result_offsets = result_array.get_offsets();
        IColumn& result_elements = result_array.get_data();

        const char* const first_state = state.array_of_aggregate_datas;
        for (size_t idx = 0; idx < state.dynamic_array_size; ++idx) {
            nested_function->insert_result_into(first_state + idx * nested_size_of_data,
                                                result_elements);
        }

        // The new row ends dynamic_array_size elements after the previous row.
        result_offsets.push_back(result_offsets.back() + state.dynamic_array_size);
    }
223
224
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
225
14
             Arena& arena) const override {
226
14
        std::vector<const IColumn*> nested(num_arguments);
227
228
28
        for (size_t i = 0; i < num_arguments; ++i) {
229
14
            nested[i] = &assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i])
230
14
                                 .get_data();
231
14
        }
232
233
14
        const auto& first_array_column =
234
14
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[0]);
235
14
        const auto& offsets = first_array_column.get_offsets();
236
237
14
        size_t begin = offsets[row_num - 1];
238
14
        size_t end = offsets[row_num];
239
240
        /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
241
14
        for (size_t i = 1; i < num_arguments; ++i) {
242
0
            const auto& ith_column =
243
0
                    assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i]);
244
0
            const auto& ith_offsets = ith_column.get_offsets();
245
246
0
            if (ith_offsets[row_num] != end ||
247
0
                (row_num != 0 && ith_offsets[row_num - 1] != begin)) {
248
0
                throw Exception(ErrorCode::INTERNAL_ERROR,
249
0
                                "Arrays passed to {} aggregate function have different sizes",
250
0
                                get_name());
251
0
            }
252
0
        }
253
254
14
        AggregateFunctionForEachData& state = ensure_aggregate_data(place, end - begin, arena);
255
256
14
        char* nested_state = state.array_of_aggregate_datas;
257
70
        for (size_t i = begin; i < end; ++i) {
258
56
            nested_function->add(nested_state, nested.data(), i, arena);
259
56
            nested_state += nested_size_of_data;
260
56
        }
261
14
    }
262
};
263
} // namespace doris
264
265
#include "common/compile_check_end.h"