Coverage Report

Created: 2026-03-25 16:15

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_count.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionCount.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include <stddef.h>
24
25
#include <algorithm>
26
#include <boost/iterator/iterator_facade.hpp>
27
#include <memory>
28
#include <vector>
29
30
#include "core/assert_cast.h"
31
#include "core/column/column.h"
32
#include "core/column/column_fixed_length_object.h"
33
#include "core/column/column_nullable.h"
34
#include "core/column/column_vector.h"
35
#include "core/data_type/data_type.h"
36
#include "core/data_type/data_type_fixed_length_object.h"
37
#include "core/data_type/data_type_number.h"
38
#include "core/types.h"
39
#include "exprs/aggregate/aggregate_function.h"
40
41
namespace doris {
42
#include "common/compile_check_begin.h"
43
class Arena;
44
class BufferReadable;
45
class BufferWritable;
46
47
struct AggregateFunctionCountData {
48
    UInt64 count = 0;
49
};
50
51
/// Simply count number of calls.
52
class AggregateFunctionCount final
53
        : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount>,
54
          VarargsExpression,
55
          NotNullableAggregateFunction {
56
public:
57
    AggregateFunctionCount(const DataTypes& argument_types_)
58
55.7k
            : IAggregateFunctionDataHelper(argument_types_) {}
59
60
5.65k
    bool is_simple_count() const override { return true; }
61
287
    String get_name() const override { return "count"; }
62
63
110k
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
64
65
0
    bool is_trivial() const override { return true; }
66
67
150M
    void add(AggregateDataPtr __restrict place, const IColumn**, ssize_t, Arena&) const override {
68
150M
        ++data(place).count;
69
150M
    }
70
71
137
    void reset(AggregateDataPtr place) const override {
72
137
        AggregateFunctionCount::data(place).count = 0;
73
137
    }
74
75
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
76
782k
               Arena&) const override {
77
782k
        data(place).count += data(rhs).count;
78
782k
    }
79
80
90
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
81
90
        buf.write_var_uint(data(place).count);
82
90
    }
83
84
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
85
90
                     Arena&) const override {
86
90
        buf.read_var_uint(data(place).count);
87
90
    }
88
89
1.51M
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
90
1.51M
        assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
91
1.51M
    }
92
93
    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
94
5.68k
                             MutableColumnPtr& dst, const size_t num_rows) const override {
95
5.68k
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
96
5.68k
        DCHECK(col.item_size() == sizeof(Data))
97
7
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
98
5.68k
        col.resize(num_rows);
99
5.68k
        auto* data = col.get_data().data();
100
788k
        for (size_t i = 0; i != num_rows; ++i) {
101
782k
            *reinterpret_cast<Data*>(&data[sizeof(Data) * i]) =
102
782k
                    *reinterpret_cast<Data*>(places[i] + offset);
103
782k
        }
104
5.68k
    }
105
106
    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
107
77
                                           const size_t num_rows, Arena&) const override {
108
77
        auto& dst_col = assert_cast<ColumnFixedLengthObject&>(*dst);
109
77
        DCHECK(dst_col.item_size() == sizeof(Data))
110
0
                << "size is not equal: " << dst_col.item_size() << " " << sizeof(Data);
111
77
        dst_col.resize(num_rows);
112
77
        auto* data = dst_col.get_data().data();
113
388
        for (size_t i = 0; i != num_rows; ++i) {
114
311
            auto& state = *reinterpret_cast<Data*>(&data[sizeof(Data) * i]);
115
311
            state.count = 1;
116
311
        }
117
77
    }
118
119
    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
120
                                                 const IColumn& column, size_t begin, size_t end,
121
27.5k
                                                 Arena&) const override {
122
27.5k
        DCHECK(end <= column.size() && begin <= end)
123
3
                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
124
27.5k
        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
125
27.5k
        auto* data = reinterpret_cast<const Data*>(col.get_data().data());
126
55.1k
        for (size_t i = begin; i <= end; ++i) {
127
27.6k
            doris::AggregateFunctionCount::data(place).count += data[i].count;
128
27.6k
        }
129
27.5k
    }
130
131
    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
132
                                   AggregateDataPtr rhs, const IColumn* column, Arena& arena,
133
3.78k
                                   const size_t num_rows) const override {
134
3.78k
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
135
3.78k
        const auto* data = col.get_data().data();
136
3.78k
        this->merge_vec(places, offset, AggregateDataPtr(data), arena, num_rows);
137
3.78k
    }
138
139
    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
140
                                            AggregateDataPtr rhs, const IColumn* column,
141
1
                                            Arena& arena, const size_t num_rows) const override {
142
1
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
143
1
        const auto* data = col.get_data().data();
144
1
        this->merge_vec_selected(places, offset, AggregateDataPtr(data), arena, num_rows);
145
1
    }
146
147
    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
148
27.2k
                                         IColumn& to) const override {
149
27.2k
        auto& col = assert_cast<ColumnFixedLengthObject&>(to);
150
27.2k
        DCHECK(col.item_size() == sizeof(Data))
151
9
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
152
27.2k
        size_t old_size = col.size();
153
27.2k
        col.resize(old_size + 1);
154
27.2k
        (reinterpret_cast<Data*>(col.get_data().data()) + old_size)->count =
155
27.2k
                AggregateFunctionCount::data(place).count;
156
27.2k
    }
157
158
35.5k
    MutableColumnPtr create_serialize_column() const override {
159
35.5k
        return ColumnFixedLengthObject::create(sizeof(Data));
160
35.5k
    }
161
162
35.6k
    DataTypePtr get_serialized_type() const override {
163
35.6k
        return std::make_shared<DataTypeFixedLengthObject>();
164
35.6k
    }
165
166
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
167
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
168
                                Arena& arena, UInt8* use_null_result,
169
253
                                UInt8* could_use_previous_result) const override {
170
253
        frame_start = std::max<int64_t>(frame_start, partition_start);
171
253
        frame_end = std::min<int64_t>(frame_end, partition_end);
172
253
        if (frame_start >= frame_end) {
173
12
            if (!*could_use_previous_result) {
174
0
                *use_null_result = true;
175
0
            }
176
241
        } else {
177
241
            AggregateFunctionCount::data(place).count += frame_end - frame_start;
178
241
            *use_null_result = false;
179
241
            *could_use_previous_result = true;
180
241
        }
181
253
    }
182
};
183
184
// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of AggregateFunctionCount
185
// Simply count number of not-NULL values.
186
class AggregateFunctionCountNotNullUnary final
187
        : public IAggregateFunctionDataHelper<AggregateFunctionCountData,
188
                                              AggregateFunctionCountNotNullUnary> {
189
public:
190
    AggregateFunctionCountNotNullUnary(const DataTypes& argument_types_)
191
19.0k
            : IAggregateFunctionDataHelper(argument_types_) {}
192
193
308
    String get_name() const override { return "count"; }
194
195
38.9k
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
196
197
0
    bool is_trivial() const override { return true; }
198
199
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
200
1.56M
             Arena&) const override {
201
1.56M
        data(place).count +=
202
1.56M
                !assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0])
203
1.56M
                         .is_null_at(row_num);
204
1.56M
    }
205
206
271
    void reset(AggregateDataPtr place) const override { data(place).count = 0; }
207
208
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
209
32.5k
               Arena&) const override {
210
32.5k
        data(place).count += data(rhs).count;
211
32.5k
    }
212
213
6
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
214
6
        buf.write_var_uint(data(place).count);
215
6
    }
216
217
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
218
6
                     Arena&) const override {
219
6
        buf.read_var_uint(data(place).count);
220
6
    }
221
222
103k
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
223
103k
        if (to.is_nullable()) {
224
0
            auto& null_column = assert_cast<ColumnNullable&>(to);
225
0
            null_column.get_null_map_data().push_back(0);
226
0
            assert_cast<ColumnInt64&>(null_column.get_nested_column())
227
0
                    .get_data()
228
0
                    .push_back(data(place).count);
229
103k
        } else {
230
103k
            assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
231
103k
        }
232
103k
    }
233
234
    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
235
2.24k
                             MutableColumnPtr& dst, const size_t num_rows) const override {
236
2.24k
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
237
2.24k
        DCHECK(col.item_size() == sizeof(Data))
238
2
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
239
2.24k
        col.resize(num_rows);
240
2.24k
        auto* data = col.get_data().data();
241
64.8k
        for (size_t i = 0; i != num_rows; ++i) {
242
62.5k
            *reinterpret_cast<Data*>(&data[sizeof(Data) * i]) =
243
62.5k
                    *reinterpret_cast<Data*>(places[i] + offset);
244
62.5k
        }
245
2.24k
    }
246
247
    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
248
2
                                           const size_t num_rows, Arena&) const override {
249
2
        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
250
2
        DCHECK(col.item_size() == sizeof(Data))
251
0
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
252
2
        col.resize(num_rows);
253
2
        auto& data = col.get_data();
254
2
        const ColumnNullable& input_col = assert_cast<const ColumnNullable&>(*columns[0]);
255
6
        for (size_t i = 0; i < num_rows; i++) {
256
4
            auto& state = *reinterpret_cast<Data*>(&data[sizeof(Data) * i]);
257
4
            state.count = !input_col.is_null_at(i);
258
4
        }
259
2
    }
260
261
    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
262
                                                 const IColumn& column, size_t begin, size_t end,
263
11.4k
                                                 Arena&) const override {
264
11.4k
        DCHECK(end <= column.size() && begin <= end)
265
0
                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
266
11.4k
        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
267
11.4k
        auto* data = reinterpret_cast<const Data*>(col.get_data().data());
268
22.8k
        for (size_t i = begin; i <= end; ++i) {
269
11.4k
            doris::AggregateFunctionCountNotNullUnary::data(place).count += data[i].count;
270
11.4k
        }
271
11.4k
    }
272
273
    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
274
                                   AggregateDataPtr rhs, const IColumn* column, Arena& arena,
275
1.38k
                                   const size_t num_rows) const override {
276
1.38k
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
277
1.38k
        const auto* data = col.get_data().data();
278
1.38k
        this->merge_vec(places, offset, AggregateDataPtr(data), arena, num_rows);
279
1.38k
    }
280
281
    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
282
                                            AggregateDataPtr rhs, const IColumn* column,
283
0
                                            Arena& arena, const size_t num_rows) const override {
284
0
        const auto& col = assert_cast<const ColumnFixedLengthObject&>(*column);
285
0
        const auto* data = col.get_data().data();
286
0
        this->merge_vec_selected(places, offset, AggregateDataPtr(data), arena, num_rows);
287
0
    }
288
289
    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
290
11.4k
                                         IColumn& to) const override {
291
11.4k
        auto& col = assert_cast<ColumnFixedLengthObject&>(to);
292
11.4k
        DCHECK(col.item_size() == sizeof(Data))
293
0
                << "size is not equal: " << col.item_size() << " " << sizeof(Data);
294
11.4k
        col.resize(1);
295
11.4k
        reinterpret_cast<Data*>(col.get_data().data())->count =
296
11.4k
                AggregateFunctionCountNotNullUnary::data(place).count;
297
11.4k
    }
298
299
13.6k
    MutableColumnPtr create_serialize_column() const override {
300
13.6k
        return ColumnFixedLengthObject::create(sizeof(Data));
301
13.6k
    }
302
303
13.6k
    DataTypePtr get_serialized_type() const override {
304
13.6k
        return std::make_shared<DataTypeFixedLengthObject>();
305
13.6k
    }
306
307
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
308
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
309
                                Arena& arena, UInt8* use_null_result,
310
391
                                UInt8* could_use_previous_result) const override {
311
391
        frame_start = std::max<int64_t>(frame_start, partition_start);
312
391
        frame_end = std::min<int64_t>(frame_end, partition_end);
313
391
        if (frame_start >= frame_end) {
314
14
            if (!*could_use_previous_result) {
315
0
                *use_null_result = true;
316
0
            }
317
377
        } else {
318
377
            const auto& nullable_column =
319
377
                    assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
320
377
            size_t count = 0;
321
377
            if (nullable_column.has_null()) {
322
222
                for (int64_t i = frame_start; i < frame_end; ++i) {
323
127
                    if (!nullable_column.is_null_at(i)) {
324
95
                        ++count;
325
95
                    }
326
127
                }
327
282
            } else {
328
282
                count = frame_end - frame_start;
329
282
            }
330
377
            *use_null_result = false;
331
377
            *could_use_previous_result = true;
332
377
            AggregateFunctionCountNotNullUnary::data(place).count += count;
333
377
        }
334
391
    }
335
};
336
337
} // namespace doris
338
339
#include "common/compile_check_end.h"