Coverage Report

Created: 2026-05-19 15:09

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_foreach.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include "common/status.h"
24
#include "core/assert_cast.h"
25
#include "core/column/column_nullable.h"
26
#include "core/data_type/data_type_array.h"
27
#include "core/data_type/data_type_nullable.h"
28
#include "exec/common/arithmetic_overflow.h"
29
#include "exprs/aggregate/aggregate_function.h"
30
#include "exprs/function/array/function_array_utils.h"
31
32
namespace doris {
33
34
/// Per-group state of the -ForEach combinator.
/// `array_of_aggregate_datas` points to `dynamic_array_size` nested aggregate states
/// stored back to back (one per array position).
struct AggregateFunctionForEachData {
    /// How many nested states the buffer currently holds.
    size_t dynamic_array_size {0};
    /// Start of the nested-state buffer; stays nullptr until storage is first allocated.
    char* array_of_aggregate_datas {nullptr};
};
38
39
/** Adaptor for aggregate functions.
40
  * Adding -ForEach suffix to aggregate function
41
  *  will convert that aggregate function to a function, accepting arrays,
42
  *  and applies aggregation for each corresponding elements of arrays independently,
43
  *  returning arrays of aggregated values on corresponding positions.
44
  *
45
  * Example: sumForEach of:
46
  *  [1, 2],
47
  *  [3, 4, 5],
48
  *  [6, 7]
49
  * will return:
50
  *  [10, 13, 5]
51
  *
52
  * TODO Allow variable number of arguments.
53
  */
54
class AggregateFunctionForEach : public AggregateFunctionNonFinalBase,
55
                                 public IAggregateFunctionDataHelper<AggregateFunctionForEachData,
56
                                                                     AggregateFunctionForEach>,
57
                                 VarargsExpression,
58
                                 NullableAggregateFunction {
59
protected:
60
    using Base =
61
            IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>;
62
63
    AggregateFunctionPtr nested_function;
64
    const size_t nested_size_of_data;
65
    const size_t num_arguments;
66
67
    AggregateFunctionForEachData& ensure_aggregate_data(AggregateDataPtr __restrict place,
68
0
                                                        size_t new_size, Arena& arena) const {
69
0
        AggregateFunctionForEachData& state = data(place);
70
71
        /// Ensure we have aggregate states for new_size elements, allocate
72
        /// from arena if needed. When reallocating, we can't copy the
73
        /// states to new buffer with memcpy, because they may contain pointers
74
        /// to themselves. In particular, this happens when a state contains
75
        /// a PODArrayWithStackMemory, which stores small number of elements
76
        /// inline. This is why we create new empty states in the new buffer,
77
        /// and merge the old states to them.
78
0
        size_t old_size = state.dynamic_array_size;
79
0
        if (old_size < new_size) {
80
0
            static constexpr size_t MAX_ARRAY_SIZE = 100 * 1000000000ULL;
81
0
            if (new_size > MAX_ARRAY_SIZE) {
82
0
                throw Exception(ErrorCode::INTERNAL_ERROR,
83
0
                                "Suspiciously large array size ({}) in -ForEach aggregate function",
84
0
                                new_size);
85
0
            }
86
87
0
            size_t allocation_size = 0;
88
0
            if (common::mul_overflow(new_size, nested_size_of_data, allocation_size)) {
89
0
                throw Exception(ErrorCode::INTERNAL_ERROR,
90
0
                                "Allocation size ({} * {}) overflows in -ForEach aggregate "
91
0
                                "function, but it should've been prevented by previous checks",
92
0
                                new_size, nested_size_of_data);
93
0
            }
94
95
0
            char* old_state = state.array_of_aggregate_datas;
96
97
0
            char* new_state =
98
0
                    arena.aligned_alloc(allocation_size, nested_function->align_of_data());
99
100
0
            size_t i;
101
0
            try {
102
0
                for (i = 0; i < new_size; ++i) {
103
0
                    nested_function->create(&new_state[i * nested_size_of_data]);
104
0
                }
105
0
            } catch (...) {
106
0
                size_t cleanup_size = i;
107
108
0
                for (i = 0; i < cleanup_size; ++i) {
109
0
                    nested_function->destroy(&new_state[i * nested_size_of_data]);
110
0
                }
111
112
0
                throw;
113
0
            }
114
115
0
            for (i = 0; i < old_size; ++i) {
116
0
                nested_function->merge(&new_state[i * nested_size_of_data],
117
0
                                       &old_state[i * nested_size_of_data], arena);
118
0
                nested_function->destroy(&old_state[i * nested_size_of_data]);
119
0
            }
120
121
0
            state.array_of_aggregate_datas = new_state;
122
0
            state.dynamic_array_size = new_size;
123
0
        }
124
125
0
        return state;
126
0
    }
127
128
public:
129
    constexpr static auto AGG_FOREACH_SUFFIX = "_foreach";
130
    AggregateFunctionForEach(AggregateFunctionPtr nested_function_, const DataTypes& arguments)
131
0
            : Base(arguments),
132
0
              nested_function {std::move(nested_function_)},
133
0
              nested_size_of_data(nested_function->size_of_data()),
134
0
              num_arguments(arguments.size()) {
135
0
        if (arguments.empty()) {
136
0
            throw Exception(ErrorCode::INTERNAL_ERROR,
137
0
                            "Aggregate function {} require at least one argument", get_name());
138
0
        }
139
0
    }
140
0
    void set_version(const int version_) override {
141
0
        Base::set_version(version_);
142
0
        nested_function->set_version(version_);
143
0
    }
144
145
0
    String get_name() const override { return nested_function->get_name() + AGG_FOREACH_SUFFIX; }
146
147
0
    DataTypePtr get_return_type() const override {
148
0
        return std::make_shared<DataTypeArray>(nested_function->get_return_type());
149
0
    }
150
151
0
    void destroy(AggregateDataPtr __restrict place) const noexcept override {
152
0
        AggregateFunctionForEachData& state = data(place);
153
154
0
        char* nested_state = state.array_of_aggregate_datas;
155
0
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
156
0
            nested_function->destroy(nested_state);
157
0
            nested_state += nested_size_of_data;
158
0
        }
159
0
    }
160
161
0
    bool is_trivial() const override {
162
0
        return std::is_trivial_v<Data> && nested_function->is_trivial();
163
0
    }
164
165
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
166
0
               Arena& arena) const override {
167
0
        const AggregateFunctionForEachData& rhs_state = data(rhs);
168
0
        AggregateFunctionForEachData& state =
169
0
                ensure_aggregate_data(place, rhs_state.dynamic_array_size, arena);
170
171
0
        const char* rhs_nested_state = rhs_state.array_of_aggregate_datas;
172
0
        char* nested_state = state.array_of_aggregate_datas;
173
174
0
        for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) {
175
0
            nested_function->merge(nested_state, rhs_nested_state, arena);
176
177
0
            rhs_nested_state += nested_size_of_data;
178
0
            nested_state += nested_size_of_data;
179
0
        }
180
0
    }
181
182
0
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
183
0
        const AggregateFunctionForEachData& state = data(place);
184
0
        buf.write_binary(state.dynamic_array_size);
185
0
        const char* nested_state = state.array_of_aggregate_datas;
186
0
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
187
0
            nested_function->serialize(nested_state, buf);
188
0
            nested_state += nested_size_of_data;
189
0
        }
190
0
    }
191
192
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
193
0
                     Arena& arena) const override {
194
0
        AggregateFunctionForEachData& state = data(place);
195
196
0
        size_t new_size = 0;
197
0
        buf.read_binary(new_size);
198
199
0
        ensure_aggregate_data(place, new_size, arena);
200
201
0
        char* nested_state = state.array_of_aggregate_datas;
202
0
        for (size_t i = 0; i < new_size; ++i) {
203
0
            nested_function->deserialize(nested_state, buf, arena);
204
0
            nested_state += nested_size_of_data;
205
0
        }
206
0
    }
207
208
0
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
209
0
        const AggregateFunctionForEachData& state = data(place);
210
211
0
        auto& arr_to = assert_cast<ColumnArray&>(to);
212
0
        auto& offsets_to = arr_to.get_offsets();
213
0
        IColumn& elems_to = arr_to.get_data();
214
215
0
        char* nested_state = state.array_of_aggregate_datas;
216
0
        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
217
0
            nested_function->insert_result_into(nested_state, elems_to);
218
0
            nested_state += nested_size_of_data;
219
0
        }
220
221
0
        offsets_to.push_back(offsets_to.back() + state.dynamic_array_size);
222
0
    }
223
224
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
225
0
             Arena& arena) const override {
226
0
        std::vector<const IColumn*> nested(num_arguments);
227
228
0
        for (size_t i = 0; i < num_arguments; ++i) {
229
0
            nested[i] = &assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i])
230
0
                                 .get_data();
231
0
        }
232
233
0
        const auto& first_array_column =
234
0
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[0]);
235
0
        const auto& offsets = first_array_column.get_offsets();
236
237
0
        size_t begin = offsets[row_num - 1];
238
0
        size_t end = offsets[row_num];
239
240
        /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
241
0
        for (size_t i = 1; i < num_arguments; ++i) {
242
0
            const auto& ith_column =
243
0
                    assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[i]);
244
0
            const auto& ith_offsets = ith_column.get_offsets();
245
246
0
            if (ith_offsets[row_num] != end ||
247
0
                (row_num != 0 && ith_offsets[row_num - 1] != begin)) {
248
0
                throw Exception(ErrorCode::INTERNAL_ERROR,
249
0
                                "Arrays passed to {} aggregate function have different sizes",
250
0
                                get_name());
251
0
            }
252
0
        }
253
254
0
        AggregateFunctionForEachData& state = ensure_aggregate_data(place, end - begin, arena);
255
256
0
        char* nested_state = state.array_of_aggregate_datas;
257
0
        for (size_t i = begin; i < end; ++i) {
258
0
            nested_function->add(nested_state, nested.data(), i, arena);
259
0
            nested_state += nested_size_of_data;
260
0
        }
261
0
    }
262
};
263
} // namespace doris