Coverage Report

Created: 2026-03-12 16:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_count.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionCount.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include <stddef.h>
24
25
#include <algorithm>
26
#include <boost/iterator/iterator_facade.hpp>
27
#include <memory>
28
#include <vector>
29
30
#include "core/assert_cast.h"
31
#include "core/column/column.h"
32
#include "core/column/column_fixed_length_object.h"
33
#include "core/column/column_nullable.h"
34
#include "core/column/column_vector.h"
35
#include "core/data_type/data_type.h"
36
#include "core/data_type/data_type_fixed_length_object.h"
37
#include "core/data_type/data_type_number.h"
38
#include "core/types.h"
39
#include "exprs/aggregate/aggregate_function.h"
40
41
namespace doris {
42
#include "common/compile_check_begin.h"
43
class Arena;
44
class BufferReadable;
45
class BufferWritable;
46
47
struct AggregateFunctionCountData {
48
    UInt64 count = 0;
49
};
50
51
/// Simply count number of calls.
52
class AggregateFunctionCount final
53
        : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount>,
54
          VarargsExpression,
55
          NotNullableAggregateFunction {
56
public:
57
    AggregateFunctionCount(const DataTypes& argument_types_)
58
47.2k
            : IAggregateFunctionDataHelper(argument_types_) {}
59
60
4.40k
    String get_name() const override { return "count"; }
61
62
93.9k
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
63
64
0
    bool is_trivial() const override { return true; }
65
66
122M
    void add(AggregateDataPtr __restrict place, const IColumn**, ssize_t, Arena&) const override {
67
122M
        ++data(place).count;
68
122M
    }
69
70
130
    void reset(AggregateDataPtr place) const override {
71
130
        AggregateFunctionCount::data(place).count = 0;
72
130
    }
73
74
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
75
251k
               Arena&) const override {
76
251k
        data(place).count += data(rhs).count;
77
251k
    }
78
79
96
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
80
96
        buf.write_var_uint(data(place).count);
81
96
    }
82
83
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
84
96
                     Arena&) const override {
85
96
        buf.read_var_uint(data(place).count);
86
96
    }
87
88
354k
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
89
354k
        assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
90
354k
    }
91
92
    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
93
10.6k
                             MutableColumnPtr& dst, const size_t num_rows) const override {
94
10.6k
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
95
18.4E
        DCHECK(col.item_size() == sizeof(Data))
96
18.4E
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
97
10.6k
        col.resize(num_rows);
98
10.6k
        auto* data = col.get_data().data();
99
261k
        for (size_t i = 0; i != num_rows; ++i) {
100
251k
            *reinterpret_cast<Data*>(&data[sizeof(Data) * i]) =
101
251k
                    *reinterpret_cast<Data*>(places[i] + offset);
102
251k
        }
103
10.6k
    }
104
105
    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
106
80
                                           const size_t num_rows, Arena&) const override {
107
80
        auto& dst_col = assert_cast<ColumnFixedLengthObject&>(*dst);
108
80
        DCHECK(dst_col.item_size() == sizeof(Data))
109
0
                << "size is not equal: " << dst_col.item_size() << " " << sizeof(Data);
110
80
        dst_col.resize(num_rows);
111
80
        auto* data = dst_col.get_data().data();
112
391
        for (size_t i = 0; i != num_rows; ++i) {
113
311
            auto& state = *reinterpret_cast<Data*>(&data[sizeof(Data) * i]);
114
311
            state.count = 1;
115
311
        }
116
80
    }
117
118
    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
119
                                                 const IColumn& column, size_t begin, size_t end,
120
25.0k
                                                 Arena&) const override {
121
25.0k
        DCHECK(end <= column.size() && begin <= end)
122
4
                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
123
25.0k
        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
124
25.0k
        auto* data = reinterpret_cast<const Data*>(col.get_data().data());
125
50.0k
        for (size_t i = begin; i <= end; ++i) {
126
25.0k
            doris::AggregateFunctionCount::data(place).count += data[i].count;
127
25.0k
        }
128
25.0k
    }
129
130
    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
131
                                   AggregateDataPtr rhs, const IColumn* column, Arena& arena,
132
8.08k
                                   const size_t num_rows) const override {
133
8.08k
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
134
8.08k
        const auto* data = col.get_data().data();
135
8.08k
        this->merge_vec(places, offset, AggregateDataPtr(data), arena, num_rows);
136
8.08k
    }
137
138
    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
139
                                            AggregateDataPtr rhs, const IColumn* column,
140
1
                                            Arena& arena, const size_t num_rows) const override {
141
1
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
142
1
        const auto* data = col.get_data().data();
143
1
        this->merge_vec_selected(places, offset, AggregateDataPtr(data), arena, num_rows);
144
1
    }
145
146
    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
147
24.7k
                                         IColumn& to) const override {
148
24.7k
        auto& col = assert_cast<ColumnFixedLengthObject&>(to);
149
24.7k
        DCHECK(col.item_size() == sizeof(Data))
150
8
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
151
24.7k
        size_t old_size = col.size();
152
24.7k
        col.resize(old_size + 1);
153
24.7k
        (reinterpret_cast<Data*>(col.get_data().data()) + old_size)->count =
154
24.7k
                AggregateFunctionCount::data(place).count;
155
24.7k
    }
156
157
31.3k
    MutableColumnPtr create_serialize_column() const override {
158
31.3k
        return ColumnFixedLengthObject::create(sizeof(Data));
159
31.3k
    }
160
161
31.7k
    DataTypePtr get_serialized_type() const override {
162
31.7k
        return std::make_shared<DataTypeFixedLengthObject>();
163
31.7k
    }
164
165
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
166
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
167
                                Arena& arena, UInt8* use_null_result,
168
253
                                UInt8* could_use_previous_result) const override {
169
253
        frame_start = std::max<int64_t>(frame_start, partition_start);
170
253
        frame_end = std::min<int64_t>(frame_end, partition_end);
171
253
        if (frame_start >= frame_end) {
172
12
            if (!*could_use_previous_result) {
173
0
                *use_null_result = true;
174
0
            }
175
241
        } else {
176
241
            AggregateFunctionCount::data(place).count += frame_end - frame_start;
177
241
            *use_null_result = false;
178
241
            *could_use_previous_result = true;
179
241
        }
180
253
    }
181
};
182
183
// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of AggregateFunctionCount
184
// Simply count number of not-NULL values.
185
class AggregateFunctionCountNotNullUnary final
186
        : public IAggregateFunctionDataHelper<AggregateFunctionCountData,
187
                                              AggregateFunctionCountNotNullUnary> {
188
public:
189
    AggregateFunctionCountNotNullUnary(const DataTypes& argument_types_)
190
18.8k
            : IAggregateFunctionDataHelper(argument_types_) {}
191
192
251
    String get_name() const override { return "count"; }
193
194
38.5k
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
195
196
0
    bool is_trivial() const override { return true; }
197
198
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
199
1.43M
             Arena&) const override {
200
1.43M
        data(place).count +=
201
1.43M
                !assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0])
202
1.43M
                         .is_null_at(row_num);
203
1.43M
    }
204
205
271
    void reset(AggregateDataPtr place) const override { data(place).count = 0; }
206
207
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
208
22.7k
               Arena&) const override {
209
22.7k
        data(place).count += data(rhs).count;
210
22.7k
    }
211
212
6
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
213
6
        buf.write_var_uint(data(place).count);
214
6
    }
215
216
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
217
6
                     Arena&) const override {
218
6
        buf.read_var_uint(data(place).count);
219
6
    }
220
221
91.0k
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
222
91.0k
        if (to.is_nullable()) {
223
0
            auto& null_column = assert_cast<ColumnNullable&>(to);
224
0
            null_column.get_null_map_data().push_back(0);
225
0
            assert_cast<ColumnInt64&>(null_column.get_nested_column())
226
0
                    .get_data()
227
0
                    .push_back(data(place).count);
228
91.0k
        } else {
229
91.0k
            assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
230
91.0k
        }
231
91.0k
    }
232
233
    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
234
2.40k
                             MutableColumnPtr& dst, const size_t num_rows) const override {
235
2.40k
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
236
2.40k
        DCHECK(col.item_size() == sizeof(Data))
237
7
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
238
2.40k
        col.resize(num_rows);
239
2.40k
        auto* data = col.get_data().data();
240
54.4k
        for (size_t i = 0; i != num_rows; ++i) {
241
52.0k
            *reinterpret_cast<Data*>(&data[sizeof(Data) * i]) =
242
52.0k
                    *reinterpret_cast<Data*>(places[i] + offset);
243
52.0k
        }
244
2.40k
    }
245
246
    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
247
0
                                           const size_t num_rows, Arena&) const override {
248
0
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
249
0
        DCHECK(col.item_size() == sizeof(Data))
250
0
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
251
0
        col.resize(num_rows);
252
0
        auto& data = col.get_data();
253
0
        const ColumnNullable& input_col = assert_cast<const ColumnNullable&>(*columns[0]);
254
0
        for (size_t i = 0; i < num_rows; i++) {
255
0
            auto& state = *reinterpret_cast<Data*>(&data[sizeof(Data) * i]);
256
0
            state.count = !input_col.is_null_at(i);
257
0
        }
258
0
    }
259
260
    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
261
                                                 const IColumn& column, size_t begin, size_t end,
262
11.3k
                                                 Arena&) const override {
263
18.4E
        DCHECK(end <= column.size() && begin <= end)
264
18.4E
                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
265
11.3k
        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
266
11.3k
        auto* data = reinterpret_cast<const Data*>(col.get_data().data());
267
22.7k
        for (size_t i = begin; i <= end; ++i) {
268
11.3k
            doris::AggregateFunctionCountNotNullUnary::data(place).count += data[i].count;
269
11.3k
        }
270
11.3k
    }
271
272
    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
273
                                   AggregateDataPtr rhs, const IColumn* column, Arena& arena,
274
1.73k
                                   const size_t num_rows) const override {
275
1.73k
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
276
1.73k
        const auto* data = col.get_data().data();
277
1.73k
        this->merge_vec(places, offset, AggregateDataPtr(data), arena, num_rows);
278
1.73k
    }
279
280
    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
281
                                            AggregateDataPtr rhs, const IColumn* column,
282
0
                                            Arena& arena, const size_t num_rows) const override {
283
0
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
284
0
        const auto* data = col.get_data().data();
285
0
        this->merge_vec_selected(places, offset, AggregateDataPtr(data), arena, num_rows);
286
0
    }
287
288
    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
289
11.3k
                                         IColumn& to) const override {
290
11.3k
        auto& col = assert_cast<ColumnFixedLengthObject&>(to);
291
11.3k
        DCHECK(col.item_size() == sizeof(Data))
292
1
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
293
11.3k
        col.resize(1);
294
11.3k
        reinterpret_cast<Data*>(col.get_data().data())->count =
295
11.3k
                AggregateFunctionCountNotNullUnary::data(place).count;
296
11.3k
    }
297
298
13.7k
    MutableColumnPtr create_serialize_column() const override {
299
13.7k
        return ColumnFixedLengthObject::create(sizeof(Data));
300
13.7k
    }
301
302
13.7k
    DataTypePtr get_serialized_type() const override {
303
13.7k
        return std::make_shared<DataTypeFixedLengthObject>();
304
13.7k
    }
305
306
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
307
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
308
                                Arena& arena, UInt8* use_null_result,
309
391
                                UInt8* could_use_previous_result) const override {
310
391
        frame_start = std::max<int64_t>(frame_start, partition_start);
311
391
        frame_end = std::min<int64_t>(frame_end, partition_end);
312
391
        if (frame_start >= frame_end) {
313
14
            if (!*could_use_previous_result) {
314
0
                *use_null_result = true;
315
0
            }
316
377
        } else {
317
377
            const auto& nullable_column =
318
377
                    assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
319
377
            size_t count = 0;
320
377
            if (nullable_column.has_null()) {
321
380
                for (int64_t i = frame_start; i < frame_end; ++i) {
322
224
                    if (!nullable_column.is_null_at(i)) {
323
192
                        ++count;
324
192
                    }
325
224
                }
326
221
            } else {
327
221
                count = frame_end - frame_start;
328
221
            }
329
377
            *use_null_result = false;
330
377
            *could_use_previous_result = true;
331
377
            AggregateFunctionCountNotNullUnary::data(place).count += count;
332
377
        }
333
391
    }
334
};
335
336
} // namespace doris
337
338
#include "common/compile_check_end.h"