Coverage Report

Created: 2026-03-13 21:50

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_map_v2.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <glog/logging.h>
21
#include <parallel_hashmap/phmap.h>
22
23
#include "core/assert_cast.h"
24
#include "core/column/column_decimal.h"
25
#include "core/column/column_map.h"
26
#include "core/column/column_string.h"
27
#include "core/data_type/data_type_map.h"
28
#include "core/field.h"
29
#include "exprs/aggregate/aggregate_function.h"
30
#include "exprs/aggregate/aggregate_function_simple_factory.h"
31
32
namespace doris {
33
#include "common/compile_check_begin.h"
34
35
struct AggregateFunctionMapAggDataV2 {
36
    using Map = phmap::flat_hash_map<Field, int64_t>;
37
38
0
    AggregateFunctionMapAggDataV2() {
39
0
        throw Exception(Status::FatalError("__builtin_unreachable"));
40
0
    }
41
42
    AggregateFunctionMapAggDataV2(const DataTypes& argument_types, const int be_exec_version)
43
0
            : _be_version(be_exec_version) {
44
0
        _key_type = make_nullable(argument_types[0]);
45
0
        _value_type = make_nullable(argument_types[1]);
46
0
        _key_column = _key_type->create_column();
47
0
        _value_column = _value_type->create_column();
48
0
    }
49
50
0
    void reset() {
51
0
        _map.clear();
52
0
        _key_column->clear();
53
0
        _value_column->clear();
54
0
    }
55
56
0
    void add_single(const Field& key, const Field& value) {
57
0
        if (UNLIKELY(_map.find(key) != _map.end())) {
58
0
            return;
59
0
        }
60
61
0
        _map.emplace(key, _key_column->size());
62
0
        _key_column->insert(key);
63
0
        _value_column->insert(value);
64
0
    }
65
66
0
    void add(const Field& key_, const Field& value) {
67
0
        DCHECK(!key_.is_null());
68
0
        auto key_array = key_.get<TYPE_ARRAY>();
69
0
        auto value_array = value.get<TYPE_ARRAY>();
70
71
0
        const auto count = key_array.size();
72
0
        DCHECK_EQ(count, value_array.size());
73
74
0
        for (size_t i = 0; i != count; ++i) {
75
0
            const auto& key = key_array[i];
76
77
0
            if (UNLIKELY(_map.find(key) != _map.end())) {
78
0
                continue;
79
0
            }
80
81
0
            _map.emplace(key, _key_column->size());
82
0
            _key_column->insert(key);
83
0
            _value_column->insert(value_array[i]);
84
0
        }
85
0
    }
86
87
0
    void merge(const AggregateFunctionMapAggDataV2& other) {
88
0
        const size_t num_rows = other._key_column->size();
89
0
        if (num_rows == 0) {
90
0
            return;
91
0
        }
92
93
0
        auto& other_key_column_nullable = assert_cast<ColumnNullable&>(*other._key_column);
94
0
        for (size_t i = 0; i != num_rows; ++i) {
95
0
            const auto& key = other_key_column_nullable[i];
96
0
            if (_map.find(key) != _map.cend()) {
97
0
                continue;
98
0
            }
99
100
0
            _map.emplace(key, _key_column->size());
101
0
            _key_column->insert(key);
102
103
0
            _value_column->insert((*other._value_column)[i]);
104
0
        }
105
0
    }
106
107
0
    void insert_result_into(IColumn& to) const {
108
0
        auto& dst = assert_cast<ColumnMap&>(to);
109
0
        size_t num_rows = _key_column->size();
110
0
        auto& offsets = dst.get_offsets();
111
0
        auto& dst_key_column = assert_cast<ColumnNullable&>(dst.get_keys());
112
0
        dst_key_column.insert_range_from(*_key_column, 0, num_rows);
113
0
        dst.get_values().insert_range_from(*_value_column, 0, num_rows);
114
0
        if (offsets.empty()) {
115
0
            offsets.push_back(num_rows);
116
0
        } else {
117
0
            offsets.push_back(offsets.back() + num_rows);
118
0
        }
119
0
    }
120
121
0
    void write(BufferWritable& buf) const {
122
0
        auto serialized_bytes =
123
0
                _key_type->get_uncompressed_serialized_bytes(*_key_column, _be_version);
124
125
0
        std::string serialized_buffer;
126
0
        serialized_buffer.resize(serialized_bytes);
127
128
0
        auto* buf_ptr = _key_type->serialize(*_key_column, serialized_buffer.data(), _be_version);
129
0
        int64_t written_bytes = buf_ptr - serialized_buffer.data();
130
0
        DCHECK_LE(written_bytes, serialized_bytes);
131
132
0
        serialized_buffer.resize(serialized_bytes);
133
0
        buf.write_binary(serialized_buffer);
134
135
0
        serialized_bytes =
136
0
                _value_type->get_uncompressed_serialized_bytes(*_value_column, _be_version);
137
138
0
        serialized_buffer.resize(serialized_bytes);
139
140
0
        buf_ptr = _value_type->serialize(*_value_column, serialized_buffer.data(), _be_version);
141
0
        written_bytes = buf_ptr - serialized_buffer.data();
142
0
        DCHECK_LE(written_bytes, serialized_bytes);
143
144
0
        serialized_buffer.resize(written_bytes);
145
0
        buf.write_binary(serialized_buffer);
146
0
    }
147
148
0
    void read(BufferReadable& buf) {
149
0
        std::string deserialized_buffer;
150
151
0
        buf.read_binary(deserialized_buffer);
152
153
0
        const auto* ptr =
154
0
                _key_type->deserialize(deserialized_buffer.data(), &_key_column, _be_version);
155
0
        auto read_bytes = ptr - deserialized_buffer.data();
156
0
        DCHECK_EQ(read_bytes, deserialized_buffer.size());
157
158
0
        buf.read_binary(deserialized_buffer);
159
160
0
        ptr = _value_type->deserialize(deserialized_buffer.data(), &_value_column, _be_version);
161
0
        read_bytes = ptr - deserialized_buffer.data();
162
0
        DCHECK_EQ(read_bytes, deserialized_buffer.size());
163
0
    }
164
165
private:
166
    Map _map;
167
    Arena _arena;
168
    IColumn::MutablePtr _key_column;
169
    IColumn::MutablePtr _value_column;
170
    DataTypePtr _key_type;
171
    DataTypePtr _value_type;
172
173
    int _be_version;
174
};
175
176
template <typename Data>
177
class AggregateFunctionMapAggV2 final
178
        : public IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>,
179
          MultiExpression,
180
          NotNullableAggregateFunction {
181
public:
182
    AggregateFunctionMapAggV2() = default;
183
    AggregateFunctionMapAggV2(const DataTypes& argument_types_)
184
0
            : IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>(argument_types_) {
185
0
    }
186
187
    using IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>::version;
188
189
0
    std::string get_name() const override { return "map_agg_v2"; }
190
191
0
    DataTypePtr get_return_type() const override {
192
        /// keys and values column of `ColumnMap` are always nullable.
193
0
        return std::make_shared<DataTypeMap>(make_nullable(argument_types[0]),
194
0
                                             make_nullable(argument_types[1]));
195
0
    }
196
197
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
198
0
             Arena&) const override {
199
0
        Field key, value;
200
0
        columns[0]->get(row_num, key);
201
0
        columns[1]->get(row_num, value);
202
0
        this->data(place).add_single(key, value);
203
0
    }
204
205
0
    void create(AggregateDataPtr __restrict place) const override {
206
0
        new (place) Data(argument_types, version);
207
0
    }
208
209
0
    void reset(AggregateDataPtr place) const override { this->data(place).reset(); }
210
211
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
212
0
               Arena&) const override {
213
0
        this->data(place).merge(this->data(rhs));
214
0
    }
215
216
0
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
217
0
        this->data(place).write(buf);
218
0
    }
219
220
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
221
0
                     Arena&) const override {
222
0
        this->data(place).read(buf);
223
0
    }
224
225
    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
226
0
                                           const size_t num_rows, Arena&) const override {
227
0
        auto& col = assert_cast<ColumnMap&>(*dst);
228
0
        for (size_t i = 0; i != num_rows; ++i) {
229
0
            Field key, value;
230
0
            columns[0]->get(i, key);
231
0
            columns[1]->get(i, value);
232
0
            col.insert(Field::create_field<TYPE_MAP>(
233
0
                    Map {Field::create_field<TYPE_ARRAY>(Array {key}),
234
0
                         Field::create_field<TYPE_ARRAY>(Array {value})}));
235
0
        }
236
0
    }
237
238
    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
239
0
                             MutableColumnPtr& dst, const size_t num_rows) const override {
240
0
        for (size_t i = 0; i != num_rows; ++i) {
241
0
            Data& data_ = this->data(places[i] + offset);
242
0
            data_.insert_result_into(*dst);
243
0
        }
244
0
    }
245
246
    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
247
                                                 const IColumn& column, size_t begin, size_t end,
248
0
                                                 Arena&) const override {
249
0
        DCHECK(end <= column.size() && begin <= end)
250
0
                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
251
0
        const auto& col = assert_cast<const ColumnMap&>(column);
252
0
        for (size_t i = begin; i <= end; ++i) {
253
0
            auto map = col[i].get<TYPE_MAP>();
254
0
            this->data(place).add(map[0], map[1]);
255
0
        }
256
0
    }
257
258
    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
259
                                   AggregateDataPtr rhs, const IColumn* column, Arena&,
260
0
                                   const size_t num_rows) const override {
261
0
        const auto& col = assert_cast<const ColumnMap&>(*column);
262
0
        for (size_t i = 0; i != num_rows; ++i) {
263
0
            auto map = col[i].get<TYPE_MAP>();
264
0
            this->data(places[i] + offset).add(map[0], map[1]);
265
0
        }
266
0
    }
267
268
    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
269
                                            AggregateDataPtr rhs, const IColumn* column, Arena&,
270
0
                                            const size_t num_rows) const override {
271
0
        const auto& col = assert_cast<const ColumnMap&>(*column);
272
0
        for (size_t i = 0; i != num_rows; ++i) {
273
0
            if (places[i]) {
274
0
                auto map = col[i].get<TYPE_MAP>();
275
0
                this->data(places[i] + offset).add(map[0], map[1]);
276
0
            }
277
0
        }
278
0
    }
279
280
    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
281
0
                                         IColumn& to) const override {
282
0
        this->data(place).insert_result_into(to);
283
0
    }
284
285
0
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
286
0
        this->data(place).insert_result_into(to);
287
0
    }
288
289
0
    [[nodiscard]] MutableColumnPtr create_serialize_column() const override {
290
0
        return get_return_type()->create_column();
291
0
    }
292
293
0
    [[nodiscard]] DataTypePtr get_serialized_type() const override { return get_return_type(); }
294
295
protected:
296
    using IAggregateFunction::argument_types;
297
};
298
299
} // namespace doris
300
301
#include "common/compile_check_end.h"