Coverage Report

Created: 2026-04-14 20:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_foreach.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include "common/status.h"
24
#include "core/assert_cast.h"
25
#include "core/column/column_nullable.h"
26
#include "core/data_type/data_type_array.h"
27
#include "core/data_type/data_type_nullable.h"
28
#include "exec/common/arithmetic_overflow.h"
29
#include "exprs/aggregate/aggregate_function.h"
30
#include "exprs/function/array/function_array_utils.h"
31
32
namespace doris {
33
34
struct AggregateFunctionForEachData {
35
    size_t dynamic_array_size = 0;
36
    char* array_of_aggregate_datas = nullptr;
37
};
38
39
/** Adaptor for aggregate functions.
40
  * Adding the -ForEach suffix to an aggregate function
41
  *  converts it into a function that accepts arrays,
42
  *  applies the aggregation to corresponding elements of the arrays independently,
43
  *  and returns an array of the aggregated values at the corresponding positions.
44
  *
45
  * Example: sumForEach of:
46
  *  [1, 2],
47
  *  [3, 4, 5],
48
  *  [6, 7]
49
  * will return:
50
  *  [10, 13, 5]
51
  *
52
  * TODO Allow variable number of arguments.
53
  */
54
class AggregateFunctionForEach : public IAggregateFunctionDataHelper<AggregateFunctionForEachData,
55
                                                                     AggregateFunctionForEach>,
56
                                 VarargsExpression,
57
                                 NullableAggregateFunction {
58
protected:
59
    using Base =
60
            IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>;
61
62
    AggregateFunctionPtr nested_function;
63
    const size_t nested_size_of_data;
64
    const size_t num_arguments;
65
66
    AggregateFunctionForEachData& ensure_aggregate_data(AggregateDataPtr __restrict place,
67
30
                                                        size_t new_size, Arena& arena) const {
68
30
        AggregateFunctionForEachData& state = data(place);
69
70
        /// Ensure we have aggregate states for new_size elements, allocating
71
        /// from the arena if needed. When reallocating, we can't copy the
72
        /// states to the new buffer with memcpy, because they may contain pointers
73
        /// to themselves. In particular, this happens when a state contains
74
        /// a PODArrayWithStackMemory, which stores a small number of elements
75
        /// inline. This is why we create new empty states in the new buffer
76
        /// and merge the old states into them.
77
30
        size_t old_size = state.dynamic_array_size;
78
30
        if (old_size < new_size) {
79
30
            static constexpr size_t MAX_ARRAY_SIZE = 100 * 1000000000ULL;
80
30
            if (new_size > MAX_ARRAY_SIZE) {
81
0
                throw Exception(ErrorCode::INTERNAL_ERROR,
82
0
                                "Suspiciously large array size ({}) in -ForEach aggregate function",
83
0
                                new_size);
84
0
            }
85
86
30
            size_t allocation_size = 0;
87
30
            if (common::mul_overflow(new_size, nested_size_of_data, allocation_size)) {
88
0
                throw Exception(ErrorCode::INTERNAL_ERROR,
89
0
                                "Allocation size ({} * {}) overflows in -ForEach aggregate "
90
0
                                "function, but it should've been prevented by previous checks",
91
0
                                new_size, nested_size_of_data);
92
0
            }
93
94
30
            char* old_state = state.array_of_aggregate_datas;
95
96
30
            char* new_state =
97
30
                    arena.aligned_alloc(allocation_size, nested_function->align_of_data());
98
99
30
            size_t i;
100
30
            try {
101
150
                for (i = 0; i < new_size; ++i) {
102
120
                    nested_function->create(&new_state[i * nested_size_of_data]);
103
120
                }
104
30
            } catch (...) {
105
0
                size_t cleanup_size = i;
106
107
0
                for (i = 0; i < cleanup_size; ++i) {
108
0
                    nested_function->destroy(&new_state[i * nested_size_of_data]);
109
0
                }
110
111
0
                throw;
112
0
            }
113
114
30
            for (i = 0; i < old_size; ++i) {
115
0
                nested_function->merge(&new_state[i * nested_size_of_data],
116
0
                                       &old_state[i * nested_size_of_data], arena);
117
0
                nested_function->destroy(&old_state[i * nested_size_of_data]);
118
0
            }
119
120
30
            state.array_of_aggregate_datas = new_state;
121
30
            state.dynamic_array_size = new_size;
122
30
        }
123
124
30
        return state;
125
30
    }
126
127
public:
128
    constexpr static auto AGG_FOREACH_SUFFIX = "_foreach";
129
    AggregateFunctionForEach(AggregateFunctionPtr nested_function_, const DataTypes& arguments)
130
2
            : Base(arguments),
131
2
              nested_function {std::move(nested_function_)},
132
2
              nested_size_of_data(nested_function->size_of_data()),
133
2
              num_arguments(arguments.size()) {
134
2
        if (arguments.empty()) {
135
0
            throw Exception(ErrorCode::INTERNAL_ERROR,
136
0
                            "Aggregate function {} require at least one argument", get_name());
137
0
        }
138
2
    }
139
0
    void set_version(const int version_) override {
140
0
        Base::set_version(version_);
141
0
        nested_function->set_version(version_);
142
0
    }
143
144
2
    String get_name() const override { return nested_function->get_name() + AGG_FOREACH_SUFFIX; }
145
146
2
    DataTypePtr get_return_type() const override {
147
2
        return std::make_shared<DataTypeArray>(nested_function->get_return_type());
148
2
    }
149
150
30
    void destroy(AggregateDataPtr __restrict place) const noexcept override {
151
30
        AggregateFunctionForEachData& state = data(place);
152
153
30
        char* nested_state = state.array_of_aggregate_datas;
154
150
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
155
120
            nested_function->destroy(nested_state);
156
120
            nested_state += nested_size_of_data;
157
120
        }
158
30
    }
159
160
4
    bool is_trivial() const override {
161
4
        return std::is_trivial_v<Data> && nested_function->is_trivial();
162
4
    }
163
164
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
165
8
               Arena& arena) const override {
166
8
        const AggregateFunctionForEachData& rhs_state = data(rhs);
167
8
        AggregateFunctionForEachData& state =
168
8
                ensure_aggregate_data(place, rhs_state.dynamic_array_size, arena);
169
170
8
        const char* rhs_nested_state = rhs_state.array_of_aggregate_datas;
171
8
        char* nested_state = state.array_of_aggregate_datas;
172
173
40
        for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) {
174
32
            nested_function->merge(nested_state, rhs_nested_state, arena);
175
176
32
            rhs_nested_state += nested_size_of_data;
177
32
            nested_state += nested_size_of_data;
178
32
        }
179
8
    }
180
181
10
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
182
10
        const AggregateFunctionForEachData& state = data(place);
183
10
        buf.write_binary(state.dynamic_array_size);
184
10
        const char* nested_state = state.array_of_aggregate_datas;
185
50
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
186
40
            nested_function->serialize(nested_state, buf);
187
40
            nested_state += nested_size_of_data;
188
40
        }
189
10
    }
190
191
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
192
8
                     Arena& arena) const override {
193
8
        AggregateFunctionForEachData& state = data(place);
194
195
8
        size_t new_size = 0;
196
8
        buf.read_binary(new_size);
197
198
8
        ensure_aggregate_data(place, new_size, arena);
199
200
8
        char* nested_state = state.array_of_aggregate_datas;
201
40
        for (size_t i = 0; i < new_size; ++i) {
202
32
            nested_function->deserialize(nested_state, buf, arena);
203
32
            nested_state += nested_size_of_data;
204
32
        }
205
8
    }
206
207
12
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
208
12
        const AggregateFunctionForEachData& state = data(place);
209
210
12
        auto& arr_to = assert_cast<ColumnArray&>(to);
211
12
        auto& offsets_to = arr_to.get_offsets();
212
12
        IColumn& elems_to = arr_to.get_data();
213
214
12
        char* nested_state = state.array_of_aggregate_datas;
215
60
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
216
48
            nested_function->insert_result_into(nested_state, elems_to);
217
48
            nested_state += nested_size_of_data;
218
48
        }
219
220
12
        offsets_to.push_back(offsets_to.back() + state.dynamic_array_size);
221
12
    }
222
223
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
224
14
             Arena& arena) const override {
225
14
        std::vector<const IColumn*> nested(num_arguments);
226
227
28
        for (size_t i = 0; i < num_arguments; ++i) {
228
14
            nested[i] = &assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i])
229
14
                                 .get_data();
230
14
        }
231
232
14
        const auto& first_array_column =
233
14
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[0]);
234
14
        const auto& offsets = first_array_column.get_offsets();
235
236
14
        size_t begin = offsets[row_num - 1];
237
14
        size_t end = offsets[row_num];
238
239
        /// Sanity check: for this row, every argument array must span the same offsets as the first one. NOTE: we can implement a specialization for the single-argument case if this check hurts performance.
240
14
        for (size_t i = 1; i < num_arguments; ++i) {
241
0
            const auto& ith_column =
242
0
                    assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i]);
243
0
            const auto& ith_offsets = ith_column.get_offsets();
244
245
0
            if (ith_offsets[row_num] != end ||
246
0
                (row_num != 0 && ith_offsets[row_num - 1] != begin)) {
247
0
                throw Exception(ErrorCode::INTERNAL_ERROR,
248
0
                                "Arrays passed to {} aggregate function have different sizes",
249
0
                                get_name());
250
0
            }
251
0
        }
252
253
14
        AggregateFunctionForEachData& state = ensure_aggregate_data(place, end - begin, arena);
254
255
14
        char* nested_state = state.array_of_aggregate_datas;
256
70
        for (size_t i = begin; i < end; ++i) {
257
56
            nested_function->add(nested_state, nested.data(), i, arena);
258
56
            nested_state += nested_size_of_data;
259
56
        }
260
14
    }
261
};
262
} // namespace doris