be/src/exprs/aggregate/aggregate_function_map_v2.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <glog/logging.h> |
21 | | #include <parallel_hashmap/phmap.h> |
22 | | |
23 | | #include "core/assert_cast.h" |
24 | | #include "core/column/column_decimal.h" |
25 | | #include "core/column/column_map.h" |
26 | | #include "core/column/column_string.h" |
27 | | #include "core/data_type/data_type_map.h" |
28 | | #include "core/field.h" |
29 | | #include "exprs/aggregate/aggregate_function.h" |
30 | | #include "exprs/aggregate/aggregate_function_simple_factory.h" |
31 | | |
32 | | namespace doris { |
33 | | |
34 | | struct AggregateFunctionMapAggDataV2 { |
35 | | using Map = phmap::flat_hash_map<Field, int64_t>; |
36 | | |
37 | 0 | AggregateFunctionMapAggDataV2() { |
38 | 0 | throw Exception(Status::FatalError("__builtin_unreachable")); |
39 | 0 | } |
40 | | |
41 | | AggregateFunctionMapAggDataV2(const DataTypes& argument_types, const int be_exec_version) |
42 | 0 | : _be_version(be_exec_version) { |
43 | 0 | _key_type = make_nullable(argument_types[0]); |
44 | 0 | _value_type = make_nullable(argument_types[1]); |
45 | 0 | _key_column = _key_type->create_column(); |
46 | 0 | _value_column = _value_type->create_column(); |
47 | 0 | } |
48 | | |
49 | 0 | void reset() { |
50 | 0 | _map.clear(); |
51 | 0 | _key_column->clear(); |
52 | 0 | _value_column->clear(); |
53 | 0 | } |
54 | | |
55 | 0 | void add_single(const Field& key, const Field& value) { |
56 | 0 | if (UNLIKELY(_map.find(key) != _map.end())) { |
57 | 0 | return; |
58 | 0 | } |
59 | | |
60 | 0 | _map.emplace(key, _key_column->size()); |
61 | 0 | _key_column->insert(key); |
62 | 0 | _value_column->insert(value); |
63 | 0 | } |
64 | | |
65 | 0 | void add(const Field& key_, const Field& value) { |
66 | 0 | DCHECK(!key_.is_null()); |
67 | 0 | auto key_array = key_.get<TYPE_ARRAY>(); |
68 | 0 | auto value_array = value.get<TYPE_ARRAY>(); |
69 | |
|
70 | 0 | const auto count = key_array.size(); |
71 | 0 | DCHECK_EQ(count, value_array.size()); |
72 | |
|
73 | 0 | for (size_t i = 0; i != count; ++i) { |
74 | 0 | const auto& key = key_array[i]; |
75 | |
|
76 | 0 | if (UNLIKELY(_map.find(key) != _map.end())) { |
77 | 0 | continue; |
78 | 0 | } |
79 | | |
80 | 0 | _map.emplace(key, _key_column->size()); |
81 | 0 | _key_column->insert(key); |
82 | 0 | _value_column->insert(value_array[i]); |
83 | 0 | } |
84 | 0 | } |
85 | | |
86 | 0 | void merge(const AggregateFunctionMapAggDataV2& other) { |
87 | 0 | const size_t num_rows = other._key_column->size(); |
88 | 0 | if (num_rows == 0) { |
89 | 0 | return; |
90 | 0 | } |
91 | | |
92 | 0 | auto& other_key_column_nullable = assert_cast<ColumnNullable&>(*other._key_column); |
93 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
94 | 0 | const auto& key = other_key_column_nullable[i]; |
95 | 0 | if (_map.find(key) != _map.cend()) { |
96 | 0 | continue; |
97 | 0 | } |
98 | | |
99 | 0 | _map.emplace(key, _key_column->size()); |
100 | 0 | _key_column->insert(key); |
101 | |
|
102 | 0 | _value_column->insert((*other._value_column)[i]); |
103 | 0 | } |
104 | 0 | } |
105 | | |
106 | 0 | void insert_result_into(IColumn& to) const { |
107 | 0 | auto& dst = assert_cast<ColumnMap&>(to); |
108 | 0 | size_t num_rows = _key_column->size(); |
109 | 0 | auto& offsets = dst.get_offsets(); |
110 | 0 | auto& dst_key_column = assert_cast<ColumnNullable&>(dst.get_keys()); |
111 | 0 | dst_key_column.insert_range_from(*_key_column, 0, num_rows); |
112 | 0 | dst.get_values().insert_range_from(*_value_column, 0, num_rows); |
113 | 0 | if (offsets.empty()) { |
114 | 0 | offsets.push_back(num_rows); |
115 | 0 | } else { |
116 | 0 | offsets.push_back(offsets.back() + num_rows); |
117 | 0 | } |
118 | 0 | } |
119 | | |
120 | 0 | void write(BufferWritable& buf) const { |
121 | 0 | auto serialized_bytes = |
122 | 0 | _key_type->get_uncompressed_serialized_bytes(*_key_column, _be_version); |
123 | |
|
124 | 0 | std::string serialized_buffer; |
125 | 0 | serialized_buffer.resize(serialized_bytes); |
126 | |
|
127 | 0 | auto* buf_ptr = _key_type->serialize(*_key_column, serialized_buffer.data(), _be_version); |
128 | 0 | int64_t written_bytes = buf_ptr - serialized_buffer.data(); |
129 | 0 | DCHECK_LE(written_bytes, serialized_bytes); |
130 | |
|
131 | 0 | serialized_buffer.resize(serialized_bytes); |
132 | 0 | buf.write_binary(serialized_buffer); |
133 | |
|
134 | 0 | serialized_bytes = |
135 | 0 | _value_type->get_uncompressed_serialized_bytes(*_value_column, _be_version); |
136 | |
|
137 | 0 | serialized_buffer.resize(serialized_bytes); |
138 | |
|
139 | 0 | buf_ptr = _value_type->serialize(*_value_column, serialized_buffer.data(), _be_version); |
140 | 0 | written_bytes = buf_ptr - serialized_buffer.data(); |
141 | 0 | DCHECK_LE(written_bytes, serialized_bytes); |
142 | |
|
143 | 0 | serialized_buffer.resize(written_bytes); |
144 | 0 | buf.write_binary(serialized_buffer); |
145 | 0 | } |
146 | | |
147 | 0 | void read(BufferReadable& buf) { |
148 | 0 | std::string deserialized_buffer; |
149 | |
|
150 | 0 | buf.read_binary(deserialized_buffer); |
151 | |
|
152 | 0 | const auto* ptr = |
153 | 0 | _key_type->deserialize(deserialized_buffer.data(), &_key_column, _be_version); |
154 | 0 | auto read_bytes = ptr - deserialized_buffer.data(); |
155 | 0 | DCHECK_EQ(read_bytes, deserialized_buffer.size()); |
156 | |
|
157 | 0 | buf.read_binary(deserialized_buffer); |
158 | |
|
159 | 0 | ptr = _value_type->deserialize(deserialized_buffer.data(), &_value_column, _be_version); |
160 | 0 | read_bytes = ptr - deserialized_buffer.data(); |
161 | 0 | DCHECK_EQ(read_bytes, deserialized_buffer.size()); |
162 | 0 | } |
163 | | |
164 | | private: |
165 | | Map _map; |
166 | | Arena _arena; |
167 | | IColumn::MutablePtr _key_column; |
168 | | IColumn::MutablePtr _value_column; |
169 | | DataTypePtr _key_type; |
170 | | DataTypePtr _value_type; |
171 | | |
172 | | int _be_version; |
173 | | }; |
174 | | |
175 | | template <typename Data> |
176 | | class AggregateFunctionMapAggV2 final |
177 | | : public IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>, |
178 | | MultiExpression, |
179 | | NotNullableAggregateFunction { |
180 | | public: |
181 | | AggregateFunctionMapAggV2() = default; |
182 | | AggregateFunctionMapAggV2(const DataTypes& argument_types_) |
183 | 0 | : IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>(argument_types_) { |
184 | 0 | } |
185 | | |
186 | | using IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>::version; |
187 | | |
188 | 0 | std::string get_name() const override { return "map_agg_v2"; } |
189 | | |
190 | 0 | DataTypePtr get_return_type() const override { |
191 | | /// keys and values column of `ColumnMap` are always nullable. |
192 | 0 | return std::make_shared<DataTypeMap>(make_nullable(argument_types[0]), |
193 | 0 | make_nullable(argument_types[1])); |
194 | 0 | } |
195 | | |
196 | | void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, |
197 | 0 | Arena&) const override { |
198 | 0 | Field key, value; |
199 | 0 | columns[0]->get(row_num, key); |
200 | 0 | columns[1]->get(row_num, value); |
201 | 0 | this->data(place).add_single(key, value); |
202 | 0 | } |
203 | | |
204 | 0 | void create(AggregateDataPtr __restrict place) const override { |
205 | 0 | new (place) Data(argument_types, version); |
206 | 0 | } |
207 | | |
208 | 0 | void reset(AggregateDataPtr place) const override { this->data(place).reset(); } |
209 | | |
210 | | void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, |
211 | 0 | Arena&) const override { |
212 | 0 | this->data(place).merge(this->data(rhs)); |
213 | 0 | } |
214 | | |
215 | 0 | void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { |
216 | 0 | this->data(place).write(buf); |
217 | 0 | } |
218 | | |
219 | | void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, |
220 | 0 | Arena&) const override { |
221 | 0 | this->data(place).read(buf); |
222 | 0 | } |
223 | | |
224 | | void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, |
225 | 0 | const size_t num_rows, Arena&) const override { |
226 | 0 | auto& col = assert_cast<ColumnMap&>(*dst); |
227 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
228 | 0 | Field key, value; |
229 | 0 | columns[0]->get(i, key); |
230 | 0 | columns[1]->get(i, value); |
231 | 0 | col.insert(Field::create_field<TYPE_MAP>( |
232 | 0 | Map {Field::create_field<TYPE_ARRAY>(Array {key}), |
233 | 0 | Field::create_field<TYPE_ARRAY>(Array {value})})); |
234 | 0 | } |
235 | 0 | } |
236 | | |
237 | | void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset, |
238 | 0 | MutableColumnPtr& dst, const size_t num_rows) const override { |
239 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
240 | 0 | Data& data_ = this->data(places[i] + offset); |
241 | 0 | data_.insert_result_into(*dst); |
242 | 0 | } |
243 | 0 | } |
244 | | |
245 | | void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place, |
246 | | const IColumn& column, size_t begin, size_t end, |
247 | 0 | Arena&) const override { |
248 | 0 | DCHECK(end <= column.size() && begin <= end) |
249 | 0 | << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size(); |
250 | 0 | const auto& col = assert_cast<const ColumnMap&>(column); |
251 | 0 | for (size_t i = begin; i <= end; ++i) { |
252 | 0 | auto map = col[i].get<TYPE_MAP>(); |
253 | 0 | this->data(place).add(map[0], map[1]); |
254 | 0 | } |
255 | 0 | } |
256 | | |
257 | | void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset, |
258 | | AggregateDataPtr rhs, const IColumn* column, Arena&, |
259 | 0 | const size_t num_rows) const override { |
260 | 0 | const auto& col = assert_cast<const ColumnMap&>(*column); |
261 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
262 | 0 | auto map = col[i].get<TYPE_MAP>(); |
263 | 0 | this->data(places[i] + offset).add(map[0], map[1]); |
264 | 0 | } |
265 | 0 | } |
266 | | |
267 | | void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset, |
268 | | AggregateDataPtr rhs, const IColumn* column, Arena&, |
269 | 0 | const size_t num_rows) const override { |
270 | 0 | const auto& col = assert_cast<const ColumnMap&>(*column); |
271 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
272 | 0 | if (places[i]) { |
273 | 0 | auto map = col[i].get<TYPE_MAP>(); |
274 | 0 | this->data(places[i] + offset).add(map[0], map[1]); |
275 | 0 | } |
276 | 0 | } |
277 | 0 | } |
278 | | |
279 | | void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, |
280 | 0 | IColumn& to) const override { |
281 | 0 | this->data(place).insert_result_into(to); |
282 | 0 | } |
283 | | |
284 | 0 | void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { |
285 | 0 | this->data(place).insert_result_into(to); |
286 | 0 | } |
287 | | |
288 | 0 | [[nodiscard]] MutableColumnPtr create_serialize_column() const override { |
289 | 0 | return get_return_type()->create_column(); |
290 | 0 | } |
291 | | |
292 | 0 | [[nodiscard]] DataTypePtr get_serialized_type() const override { return get_return_type(); } |
293 | | |
294 | | protected: |
295 | | using IAggregateFunction::argument_types; |
296 | | }; |
297 | | |
298 | | } // namespace doris |