be/src/exprs/aggregate/aggregate_function_map_v2.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <glog/logging.h> |
21 | | #include <parallel_hashmap/phmap.h> |
22 | | |
23 | | #include "core/assert_cast.h" |
24 | | #include "core/column/column_decimal.h" |
25 | | #include "core/column/column_map.h" |
26 | | #include "core/column/column_string.h" |
27 | | #include "core/data_type/data_type_map.h" |
28 | | #include "core/field.h" |
29 | | #include "exprs/aggregate/aggregate_function.h" |
30 | | #include "exprs/aggregate/aggregate_function_simple_factory.h" |
31 | | |
32 | | namespace doris { |
33 | | #include "common/compile_check_begin.h" |
34 | | |
35 | | struct AggregateFunctionMapAggDataV2 { |
36 | | using Map = phmap::flat_hash_map<Field, int64_t>; |
37 | | |
38 | 0 | AggregateFunctionMapAggDataV2() { |
39 | 0 | throw Exception(Status::FatalError("__builtin_unreachable")); |
40 | 0 | } |
41 | | |
42 | | AggregateFunctionMapAggDataV2(const DataTypes& argument_types, const int be_exec_version) |
43 | 0 | : _be_version(be_exec_version) { |
44 | 0 | _key_type = make_nullable(argument_types[0]); |
45 | 0 | _value_type = make_nullable(argument_types[1]); |
46 | 0 | _key_column = _key_type->create_column(); |
47 | 0 | _value_column = _value_type->create_column(); |
48 | 0 | } |
49 | | |
50 | 0 | void reset() { |
51 | 0 | _map.clear(); |
52 | 0 | _key_column->clear(); |
53 | 0 | _value_column->clear(); |
54 | 0 | } |
55 | | |
56 | 0 | void add_single(const Field& key, const Field& value) { |
57 | 0 | if (UNLIKELY(_map.find(key) != _map.end())) { |
58 | 0 | return; |
59 | 0 | } |
60 | | |
61 | 0 | _map.emplace(key, _key_column->size()); |
62 | 0 | _key_column->insert(key); |
63 | 0 | _value_column->insert(value); |
64 | 0 | } |
65 | | |
66 | 0 | void add(const Field& key_, const Field& value) { |
67 | 0 | DCHECK(!key_.is_null()); |
68 | 0 | auto key_array = key_.get<TYPE_ARRAY>(); |
69 | 0 | auto value_array = value.get<TYPE_ARRAY>(); |
70 | |
|
71 | 0 | const auto count = key_array.size(); |
72 | 0 | DCHECK_EQ(count, value_array.size()); |
73 | |
|
74 | 0 | for (size_t i = 0; i != count; ++i) { |
75 | 0 | const auto& key = key_array[i]; |
76 | |
|
77 | 0 | if (UNLIKELY(_map.find(key) != _map.end())) { |
78 | 0 | continue; |
79 | 0 | } |
80 | | |
81 | 0 | _map.emplace(key, _key_column->size()); |
82 | 0 | _key_column->insert(key); |
83 | 0 | _value_column->insert(value_array[i]); |
84 | 0 | } |
85 | 0 | } |
86 | | |
87 | 0 | void merge(const AggregateFunctionMapAggDataV2& other) { |
88 | 0 | const size_t num_rows = other._key_column->size(); |
89 | 0 | if (num_rows == 0) { |
90 | 0 | return; |
91 | 0 | } |
92 | | |
93 | 0 | auto& other_key_column_nullable = assert_cast<ColumnNullable&>(*other._key_column); |
94 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
95 | 0 | const auto& key = other_key_column_nullable[i]; |
96 | 0 | if (_map.find(key) != _map.cend()) { |
97 | 0 | continue; |
98 | 0 | } |
99 | | |
100 | 0 | _map.emplace(key, _key_column->size()); |
101 | 0 | _key_column->insert(key); |
102 | |
|
103 | 0 | _value_column->insert((*other._value_column)[i]); |
104 | 0 | } |
105 | 0 | } |
106 | | |
107 | 0 | void insert_result_into(IColumn& to) const { |
108 | 0 | auto& dst = assert_cast<ColumnMap&>(to); |
109 | 0 | size_t num_rows = _key_column->size(); |
110 | 0 | auto& offsets = dst.get_offsets(); |
111 | 0 | auto& dst_key_column = assert_cast<ColumnNullable&>(dst.get_keys()); |
112 | 0 | dst_key_column.insert_range_from(*_key_column, 0, num_rows); |
113 | 0 | dst.get_values().insert_range_from(*_value_column, 0, num_rows); |
114 | 0 | if (offsets.empty()) { |
115 | 0 | offsets.push_back(num_rows); |
116 | 0 | } else { |
117 | 0 | offsets.push_back(offsets.back() + num_rows); |
118 | 0 | } |
119 | 0 | } |
120 | | |
121 | 0 | void write(BufferWritable& buf) const { |
122 | 0 | auto serialized_bytes = |
123 | 0 | _key_type->get_uncompressed_serialized_bytes(*_key_column, _be_version); |
124 | |
|
125 | 0 | std::string serialized_buffer; |
126 | 0 | serialized_buffer.resize(serialized_bytes); |
127 | |
|
128 | 0 | auto* buf_ptr = _key_type->serialize(*_key_column, serialized_buffer.data(), _be_version); |
129 | 0 | int64_t written_bytes = buf_ptr - serialized_buffer.data(); |
130 | 0 | DCHECK_LE(written_bytes, serialized_bytes); |
131 | |
|
132 | 0 | serialized_buffer.resize(serialized_bytes); |
133 | 0 | buf.write_binary(serialized_buffer); |
134 | |
|
135 | 0 | serialized_bytes = |
136 | 0 | _value_type->get_uncompressed_serialized_bytes(*_value_column, _be_version); |
137 | |
|
138 | 0 | serialized_buffer.resize(serialized_bytes); |
139 | |
|
140 | 0 | buf_ptr = _value_type->serialize(*_value_column, serialized_buffer.data(), _be_version); |
141 | 0 | written_bytes = buf_ptr - serialized_buffer.data(); |
142 | 0 | DCHECK_LE(written_bytes, serialized_bytes); |
143 | |
|
144 | 0 | serialized_buffer.resize(written_bytes); |
145 | 0 | buf.write_binary(serialized_buffer); |
146 | 0 | } |
147 | | |
148 | 0 | void read(BufferReadable& buf) { |
149 | 0 | std::string deserialized_buffer; |
150 | |
|
151 | 0 | buf.read_binary(deserialized_buffer); |
152 | |
|
153 | 0 | const auto* ptr = |
154 | 0 | _key_type->deserialize(deserialized_buffer.data(), &_key_column, _be_version); |
155 | 0 | auto read_bytes = ptr - deserialized_buffer.data(); |
156 | 0 | DCHECK_EQ(read_bytes, deserialized_buffer.size()); |
157 | |
|
158 | 0 | buf.read_binary(deserialized_buffer); |
159 | |
|
160 | 0 | ptr = _value_type->deserialize(deserialized_buffer.data(), &_value_column, _be_version); |
161 | 0 | read_bytes = ptr - deserialized_buffer.data(); |
162 | 0 | DCHECK_EQ(read_bytes, deserialized_buffer.size()); |
163 | 0 | } |
164 | | |
165 | | private: |
166 | | Map _map; |
167 | | Arena _arena; |
168 | | IColumn::MutablePtr _key_column; |
169 | | IColumn::MutablePtr _value_column; |
170 | | DataTypePtr _key_type; |
171 | | DataTypePtr _value_type; |
172 | | |
173 | | int _be_version; |
174 | | }; |
175 | | |
176 | | template <typename Data> |
177 | | class AggregateFunctionMapAggV2 final |
178 | | : public IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>, |
179 | | MultiExpression, |
180 | | NotNullableAggregateFunction { |
181 | | public: |
182 | | AggregateFunctionMapAggV2() = default; |
183 | | AggregateFunctionMapAggV2(const DataTypes& argument_types_) |
184 | 0 | : IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>(argument_types_) { |
185 | 0 | } |
186 | | |
187 | | using IAggregateFunctionDataHelper<Data, AggregateFunctionMapAggV2<Data>>::version; |
188 | | |
189 | 0 | std::string get_name() const override { return "map_agg_v2"; } |
190 | | |
191 | 0 | DataTypePtr get_return_type() const override { |
192 | | /// keys and values column of `ColumnMap` are always nullable. |
193 | 0 | return std::make_shared<DataTypeMap>(make_nullable(argument_types[0]), |
194 | 0 | make_nullable(argument_types[1])); |
195 | 0 | } |
196 | | |
197 | | void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, |
198 | 0 | Arena&) const override { |
199 | 0 | Field key, value; |
200 | 0 | columns[0]->get(row_num, key); |
201 | 0 | columns[1]->get(row_num, value); |
202 | 0 | this->data(place).add_single(key, value); |
203 | 0 | } |
204 | | |
205 | 0 | void create(AggregateDataPtr __restrict place) const override { |
206 | 0 | new (place) Data(argument_types, version); |
207 | 0 | } |
208 | | |
209 | 0 | void reset(AggregateDataPtr place) const override { this->data(place).reset(); } |
210 | | |
211 | | void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, |
212 | 0 | Arena&) const override { |
213 | 0 | this->data(place).merge(this->data(rhs)); |
214 | 0 | } |
215 | | |
216 | 0 | void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { |
217 | 0 | this->data(place).write(buf); |
218 | 0 | } |
219 | | |
220 | | void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, |
221 | 0 | Arena&) const override { |
222 | 0 | this->data(place).read(buf); |
223 | 0 | } |
224 | | |
225 | | void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, |
226 | 0 | const size_t num_rows, Arena&) const override { |
227 | 0 | auto& col = assert_cast<ColumnMap&>(*dst); |
228 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
229 | 0 | Field key, value; |
230 | 0 | columns[0]->get(i, key); |
231 | 0 | columns[1]->get(i, value); |
232 | 0 | col.insert(Field::create_field<TYPE_MAP>( |
233 | 0 | Map {Field::create_field<TYPE_ARRAY>(Array {key}), |
234 | 0 | Field::create_field<TYPE_ARRAY>(Array {value})})); |
235 | 0 | } |
236 | 0 | } |
237 | | |
238 | | void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset, |
239 | 0 | MutableColumnPtr& dst, const size_t num_rows) const override { |
240 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
241 | 0 | Data& data_ = this->data(places[i] + offset); |
242 | 0 | data_.insert_result_into(*dst); |
243 | 0 | } |
244 | 0 | } |
245 | | |
246 | | void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place, |
247 | | const IColumn& column, size_t begin, size_t end, |
248 | 0 | Arena&) const override { |
249 | 0 | DCHECK(end <= column.size() && begin <= end) |
250 | 0 | << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size(); |
251 | 0 | const auto& col = assert_cast<const ColumnMap&>(column); |
252 | 0 | for (size_t i = begin; i <= end; ++i) { |
253 | 0 | auto map = col[i].get<TYPE_MAP>(); |
254 | 0 | this->data(place).add(map[0], map[1]); |
255 | 0 | } |
256 | 0 | } |
257 | | |
258 | | void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset, |
259 | | AggregateDataPtr rhs, const IColumn* column, Arena&, |
260 | 0 | const size_t num_rows) const override { |
261 | 0 | const auto& col = assert_cast<const ColumnMap&>(*column); |
262 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
263 | 0 | auto map = col[i].get<TYPE_MAP>(); |
264 | 0 | this->data(places[i] + offset).add(map[0], map[1]); |
265 | 0 | } |
266 | 0 | } |
267 | | |
268 | | void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset, |
269 | | AggregateDataPtr rhs, const IColumn* column, Arena&, |
270 | 0 | const size_t num_rows) const override { |
271 | 0 | const auto& col = assert_cast<const ColumnMap&>(*column); |
272 | 0 | for (size_t i = 0; i != num_rows; ++i) { |
273 | 0 | if (places[i]) { |
274 | 0 | auto map = col[i].get<TYPE_MAP>(); |
275 | 0 | this->data(places[i] + offset).add(map[0], map[1]); |
276 | 0 | } |
277 | 0 | } |
278 | 0 | } |
279 | | |
280 | | void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, |
281 | 0 | IColumn& to) const override { |
282 | 0 | this->data(place).insert_result_into(to); |
283 | 0 | } |
284 | | |
285 | 0 | void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { |
286 | 0 | this->data(place).insert_result_into(to); |
287 | 0 | } |
288 | | |
289 | 0 | [[nodiscard]] MutableColumnPtr create_serialize_column() const override { |
290 | 0 | return get_return_type()->create_column(); |
291 | 0 | } |
292 | | |
293 | 0 | [[nodiscard]] DataTypePtr get_serialized_type() const override { return get_return_type(); } |
294 | | |
295 | | protected: |
296 | | using IAggregateFunction::argument_types; |
297 | | }; |
298 | | |
299 | | } // namespace doris |
300 | | |
301 | | #include "common/compile_check_end.h" |