Coverage Report

Created: 2026-06-08 10:57

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_count.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionCount.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include <stddef.h>
24
25
#include <algorithm>
26
#include <boost/iterator/iterator_facade.hpp>
27
#include <memory>
28
#include <vector>
29
30
#include "core/assert_cast.h"
31
#include "core/column/column.h"
32
#include "core/column/column_fixed_length_object.h"
33
#include "core/column/column_nullable.h"
34
#include "core/column/column_vector.h"
35
#include "core/data_type/data_type.h"
36
#include "core/data_type/data_type_fixed_length_object.h"
37
#include "core/data_type/data_type_number.h"
38
#include "core/types.h"
39
#include "exprs/aggregate/aggregate_function.h"
40
#include "util/simd/bits.h"
41
42
namespace doris {
43
class Arena;
44
class BufferReadable;
45
class BufferWritable;
46
47
struct AggregateFunctionCountData {
48
    UInt64 count = 0;
49
};
50
51
/// Simply count number of calls.
52
class AggregateFunctionCount final
53
        : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount>,
54
          VarargsExpression,
55
          NotNullableAggregateFunction {
56
public:
57
    AggregateFunctionCount(const DataTypes& argument_types_)
58
43.4k
            : IAggregateFunctionDataHelper(argument_types_) {}
59
60
11.5k
    bool is_simple_count() const override { return true; }
61
54
    String get_name() const override { return "count"; }
62
63
107k
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
64
65
0
    bool is_trivial() const override { return true; }
66
67
43.0M
    void add(AggregateDataPtr __restrict place, const IColumn**, ssize_t, Arena&) const override {
68
43.0M
        ++data(place).count;
69
43.0M
    }
70
71
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn**,
72
93.4k
                                Arena&) const override {
73
93.4k
        data(place).count += batch_size;
74
93.4k
    }
75
76
160
    void reset(AggregateDataPtr place) const override {
77
160
        AggregateFunctionCount::data(place).count = 0;
78
160
    }
79
80
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
81
7.15k
               Arena&) const override {
82
7.15k
        data(place).count += data(rhs).count;
83
7.15k
    }
84
85
78
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
86
78
        buf.write_var_uint(data(place).count);
87
78
    }
88
89
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
90
78
                     Arena&) const override {
91
78
        buf.read_var_uint(data(place).count);
92
78
    }
93
94
334k
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
95
334k
        assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
96
334k
    }
97
98
    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
99
4.97k
                             MutableColumnPtr& dst, const size_t num_rows) const override {
100
4.97k
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
101
4.97k
        DCHECK(col.item_size() == sizeof(Data))
102
3
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
103
4.97k
        col.resize(num_rows);
104
4.97k
        auto* data = col.get_data().data();
105
12.1k
        for (size_t i = 0; i != num_rows; ++i) {
106
7.21k
            *reinterpret_cast<Data*>(&data[sizeof(Data) * i]) =
107
7.21k
                    *reinterpret_cast<Data*>(places[i] + offset);
108
7.21k
        }
109
4.97k
    }
110
111
    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
112
78
                                           const size_t num_rows, Arena&) const override {
113
78
        auto& dst_col = assert_cast<ColumnFixedLengthObject&>(*dst);
114
78
        DCHECK(dst_col.item_size() == sizeof(Data))
115
0
                << "size is not equal: " << dst_col.item_size() << " " << sizeof(Data);
116
78
        dst_col.resize(num_rows);
117
78
        auto* data = dst_col.get_data().data();
118
371
        for (size_t i = 0; i != num_rows; ++i) {
119
293
            auto& state = *reinterpret_cast<Data*>(&data[sizeof(Data) * i]);
120
293
            state.count = 1;
121
293
        }
122
78
    }
123
124
    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
125
                                                 const IColumn& column, size_t begin, size_t end,
126
20.6k
                                                 Arena&) const override {
127
20.6k
        DCHECK(end <= column.size() && begin <= end)
128
0
                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
129
20.6k
        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
130
20.6k
        auto* data = reinterpret_cast<const Data*>(col.get_data().data());
131
41.3k
        for (size_t i = begin; i <= end; ++i) {
132
20.6k
            doris::AggregateFunctionCount::data(place).count += data[i].count;
133
20.6k
        }
134
20.6k
    }
135
136
    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
137
                                   AggregateDataPtr rhs, const IColumn* column, Arena& arena,
138
2.80k
                                   const size_t num_rows) const override {
139
2.80k
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
140
2.80k
        const auto* data = col.get_data().data();
141
2.80k
        this->merge_vec(places, offset, AggregateDataPtr(data), arena, num_rows);
142
2.80k
    }
143
144
    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
145
                                            AggregateDataPtr rhs, const IColumn* column,
146
1
                                            Arena& arena, const size_t num_rows) const override {
147
1
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
148
1
        const auto* data = col.get_data().data();
149
1
        this->merge_vec_selected(places, offset, AggregateDataPtr(data), arena, num_rows);
150
1
    }
151
152
    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
153
20.3k
                                         IColumn& to) const override {
154
20.3k
        auto& col = assert_cast<ColumnFixedLengthObject&>(to);
155
20.3k
        DCHECK(col.item_size() == sizeof(Data))
156
22
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
157
20.3k
        size_t old_size = col.size();
158
20.3k
        col.resize(old_size + 1);
159
20.3k
        (reinterpret_cast<Data*>(col.get_data().data()) + old_size)->count =
160
20.3k
                AggregateFunctionCount::data(place).count;
161
20.3k
    }
162
163
28.2k
    MutableColumnPtr create_serialize_column() const override {
164
28.2k
        return ColumnFixedLengthObject::create(sizeof(Data));
165
28.2k
    }
166
167
28.3k
    DataTypePtr get_serialized_type() const override {
168
28.3k
        return std::make_shared<DataTypeFixedLengthObject>();
169
28.3k
    }
170
171
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
172
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
173
                                Arena& arena, UInt8* use_null_result,
174
260
                                UInt8* could_use_previous_result) const override {
175
260
        frame_start = std::max<int64_t>(frame_start, partition_start);
176
260
        frame_end = std::min<int64_t>(frame_end, partition_end);
177
260
        if (frame_start >= frame_end) {
178
12
            if (!*could_use_previous_result) {
179
0
                *use_null_result = true;
180
0
            }
181
248
        } else {
182
248
            AggregateFunctionCount::data(place).count += frame_end - frame_start;
183
248
            *use_null_result = false;
184
248
            *could_use_previous_result = true;
185
248
        }
186
260
    }
187
};
188
189
// Used for unary count(nullable_expr). SQL count(expr) counts non-NULL values.
190
class AggregateFunctionCountNotNullUnary final
191
        : public IAggregateFunctionDataHelper<AggregateFunctionCountData,
192
                                              AggregateFunctionCountNotNullUnary> {
193
public:
194
    AggregateFunctionCountNotNullUnary(const DataTypes& argument_types_)
195
12.1k
            : IAggregateFunctionDataHelper(argument_types_) {}
196
197
161
    String get_name() const override { return "count"; }
198
199
30.8k
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
200
201
0
    bool is_trivial() const override { return true; }
202
203
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
204
1.57M
             Arena&) const override {
205
1.57M
        data(place).count +=
206
1.57M
                !assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0])
207
1.57M
                         .is_null_at(row_num);
208
1.57M
    }
209
210
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
211
24.4k
                                Arena&) const override {
212
24.4k
        const auto& nullable_column =
213
24.4k
                assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
214
24.4k
        const auto& null_map = nullable_column.get_null_map_data();
215
24.4k
        DCHECK_LE(batch_size, null_map.size());
216
24.4k
        if (!nullable_column.has_null(0, batch_size)) {
217
21.3k
            data(place).count += batch_size;
218
21.3k
            return;
219
21.3k
        }
220
3.08k
        data(place).count +=
221
3.08k
                simd::count_zero_num(reinterpret_cast<const int8_t*>(null_map.data()), batch_size);
222
3.08k
    }
223
224
238
    void reset(AggregateDataPtr place) const override { data(place).count = 0; }
225
226
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
227
28.9k
               Arena&) const override {
228
28.9k
        data(place).count += data(rhs).count;
229
28.9k
    }
230
231
6
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
232
6
        buf.write_var_uint(data(place).count);
233
6
    }
234
235
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
236
6
                     Arena&) const override {
237
6
        buf.read_var_uint(data(place).count);
238
6
    }
239
240
111k
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
241
111k
        if (to.is_nullable()) {
242
0
            auto& null_column = assert_cast<ColumnNullable&>(to);
243
0
            null_column.get_null_map_data().push_back(0);
244
0
            assert_cast<ColumnInt64&>(null_column.get_nested_column())
245
0
                    .get_data()
246
0
                    .push_back(data(place).count);
247
111k
        } else {
248
111k
            assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
249
111k
        }
250
111k
    }
251
252
    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
253
2.47k
                             MutableColumnPtr& dst, const size_t num_rows) const override {
254
2.47k
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
255
2.47k
        DCHECK(col.item_size() == sizeof(Data))
256
1
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
257
2.47k
        col.resize(num_rows);
258
2.47k
        auto* data = col.get_data().data();
259
31.4k
        for (size_t i = 0; i != num_rows; ++i) {
260
28.9k
            *reinterpret_cast<Data*>(&data[sizeof(Data) * i]) =
261
28.9k
                    *reinterpret_cast<Data*>(places[i] + offset);
262
28.9k
        }
263
2.47k
    }
264
265
    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
266
2
                                           const size_t num_rows, Arena&) const override {
267
2
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
268
2
        DCHECK(col.item_size() == sizeof(Data))
269
0
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
270
2
        col.resize(num_rows);
271
2
        auto& data = col.get_data();
272
2
        const ColumnNullable& input_col = assert_cast<const ColumnNullable&>(*columns[0]);
273
11
        for (size_t i = 0; i < num_rows; i++) {
274
9
            auto& state = *reinterpret_cast<Data*>(&data[sizeof(Data) * i]);
275
9
            state.count = !input_col.is_null_at(i);
276
9
        }
277
2
    }
278
279
    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
280
                                                 const IColumn& column, size_t begin, size_t end,
281
5.83k
                                                 Arena&) const override {
282
5.83k
        DCHECK(end <= column.size() && begin <= end)
283
0
                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
284
5.83k
        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
285
5.83k
        auto* data = reinterpret_cast<const Data*>(col.get_data().data());
286
11.6k
        for (size_t i = begin; i <= end; ++i) {
287
5.84k
            doris::AggregateFunctionCountNotNullUnary::data(place).count += data[i].count;
288
5.84k
        }
289
5.83k
    }
290
291
    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
292
                                   AggregateDataPtr rhs, const IColumn* column, Arena& arena,
293
2.63k
                                   const size_t num_rows) const override {
294
2.63k
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
295
2.63k
        const auto* data = col.get_data().data();
296
2.63k
        this->merge_vec(places, offset, AggregateDataPtr(data), arena, num_rows);
297
2.63k
    }
298
299
    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
300
                                            AggregateDataPtr rhs, const IColumn* column,
301
2
                                            Arena& arena, const size_t num_rows) const override {
302
2
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
303
2
        const auto* data = col.get_data().data();
304
2
        this->merge_vec_selected(places, offset, AggregateDataPtr(data), arena, num_rows);
305
2
    }
306
307
    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
308
5.83k
                                         IColumn& to) const override {
309
5.83k
        auto& col = assert_cast<ColumnFixedLengthObject&>(to);
310
18.4E
        DCHECK(col.item_size() == sizeof(Data))
311
18.4E
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
312
5.83k
        col.resize(1);
313
5.83k
        reinterpret_cast<Data*>(col.get_data().data())->count =
314
5.83k
                AggregateFunctionCountNotNullUnary::data(place).count;
315
5.83k
    }
316
317
8.29k
    MutableColumnPtr create_serialize_column() const override {
318
8.29k
        return ColumnFixedLengthObject::create(sizeof(Data));
319
8.29k
    }
320
321
8.30k
    DataTypePtr get_serialized_type() const override {
322
8.30k
        return std::make_shared<DataTypeFixedLengthObject>();
323
8.30k
    }
324
325
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
326
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
327
                                Arena& arena, UInt8* use_null_result,
328
330
                                UInt8* could_use_previous_result) const override {
329
330
        frame_start = std::max<int64_t>(frame_start, partition_start);
330
330
        frame_end = std::min<int64_t>(frame_end, partition_end);
331
330
        if (frame_start >= frame_end) {
332
14
            if (!*could_use_previous_result) {
333
0
                *use_null_result = true;
334
0
            }
335
316
        } else {
336
316
            const auto& nullable_column =
337
316
                    assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
338
316
            size_t count = 0;
339
316
            if (nullable_column.has_null()) {
340
583
                for (int64_t i = frame_start; i < frame_end; ++i) {
341
352
                    if (!nullable_column.is_null_at(i)) {
342
318
                        ++count;
343
318
                    }
344
352
                }
345
231
            } else {
346
85
                count = frame_end - frame_start;
347
85
            }
348
316
            *use_null_result = false;
349
316
            *could_use_previous_result = true;
350
316
            AggregateFunctionCountNotNullUnary::data(place).count += count;
351
316
        }
352
330
    }
353
};
354
355
} // namespace doris