be/src/format/arrow/arrow_row_batch.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "format/arrow/arrow_row_batch.h" |
19 | | |
20 | | #include <arrow/buffer.h> |
21 | | #include <arrow/io/memory.h> |
22 | | #include <arrow/ipc/writer.h> |
23 | | #include <arrow/record_batch.h> |
24 | | #include <arrow/result.h> |
25 | | #include <arrow/status.h> |
26 | | #include <arrow/type.h> |
27 | | #include <arrow/type_fwd.h> |
28 | | #include <arrow/util/key_value_metadata.h> |
29 | | #include <glog/logging.h> |
30 | | #include <stdint.h> |
31 | | |
32 | | #include <algorithm> |
33 | | #include <cstdlib> |
34 | | #include <memory> |
35 | | #include <utility> |
36 | | #include <vector> |
37 | | |
38 | | #include "core/block/block.h" |
39 | | #include "core/data_type/data_type_agg_state.h" |
40 | | #include "core/data_type/data_type_array.h" |
41 | | #include "core/data_type/data_type_map.h" |
42 | | #include "core/data_type/data_type_struct.h" |
43 | | #include "core/data_type/define_primitive_type.h" |
44 | | #include "exprs/vexpr.h" |
45 | | #include "exprs/vexpr_context.h" |
46 | | #include "format/arrow/arrow_block_convertor.h" |
47 | | #include "runtime/descriptors.h" |
48 | | |
49 | | namespace doris { |
50 | | |
51 | | Status convert_to_arrow_type(const DataTypePtr& origin_type, |
52 | | std::shared_ptr<arrow::DataType>* result, |
53 | 565 | const std::string& timezone) { |
54 | 565 | auto type = get_serialized_type(origin_type); |
55 | 565 | switch (type->get_primitive_type()) { |
56 | 0 | case TYPE_NULL: |
57 | 0 | *result = arrow::null(); |
58 | 0 | break; |
59 | 9 | case TYPE_TINYINT: |
60 | 9 | *result = arrow::int8(); |
61 | 9 | break; |
62 | 9 | case TYPE_SMALLINT: |
63 | 9 | *result = arrow::int16(); |
64 | 9 | break; |
65 | 37 | case TYPE_INT: |
66 | 37 | *result = arrow::int32(); |
67 | 37 | break; |
68 | 9 | case TYPE_BIGINT: |
69 | 9 | *result = arrow::int64(); |
70 | 9 | break; |
71 | 9 | case TYPE_FLOAT: |
72 | 9 | *result = arrow::float32(); |
73 | 9 | break; |
74 | 22 | case TYPE_DOUBLE: |
75 | 22 | *result = arrow::float64(); |
76 | 22 | break; |
77 | 0 | case TYPE_TIMEV2: |
78 | 0 | *result = arrow::float64(); |
79 | 0 | break; |
80 | 25 | case TYPE_IPV4: |
81 | | // ipv4 is uint32, but parquet not uint32, it's will be convert to int64 |
82 | | // so use int32 directly |
83 | 25 | *result = arrow::int32(); |
84 | 25 | break; |
85 | 20 | case TYPE_IPV6: |
86 | 20 | *result = arrow::utf8(); |
87 | 20 | break; |
88 | 19 | case TYPE_LARGEINT: |
89 | 19 | case TYPE_VARCHAR: |
90 | 19 | case TYPE_CHAR: |
91 | 36 | case TYPE_DATE: |
92 | 49 | case TYPE_DATETIME: |
93 | 97 | case TYPE_STRING: |
94 | 97 | case TYPE_JSONB: |
95 | 97 | *result = arrow::utf8(); |
96 | 97 | break; |
97 | 13 | case TYPE_DATEV2: |
98 | 13 | *result = std::make_shared<arrow::Date32Type>(); |
99 | 13 | break; |
100 | | // TODO: maybe need to distinguish TYPE_DATETIME and TYPE_TIMESTAMPTZ |
101 | 0 | case TYPE_TIMESTAMPTZ: |
102 | 18 | case TYPE_DATETIMEV2: |
103 | 18 | if (type->get_scale() > 3) { |
104 | 0 | *result = std::make_shared<arrow::TimestampType>(arrow::TimeUnit::MICRO, timezone); |
105 | 18 | } else if (type->get_scale() > 0) { |
106 | 4 | *result = std::make_shared<arrow::TimestampType>(arrow::TimeUnit::MILLI, timezone); |
107 | 14 | } else { |
108 | 14 | *result = std::make_shared<arrow::TimestampType>(arrow::TimeUnit::SECOND, timezone); |
109 | 14 | } |
110 | 18 | break; |
111 | 5 | case TYPE_DECIMALV2: |
112 | 27 | case TYPE_DECIMAL32: |
113 | 43 | case TYPE_DECIMAL64: |
114 | 55 | case TYPE_DECIMAL128I: |
115 | 55 | *result = std::make_shared<arrow::Decimal128Type>(type->get_precision(), type->get_scale()); |
116 | 55 | break; |
117 | 16 | case TYPE_DECIMAL256: |
118 | 16 | *result = std::make_shared<arrow::Decimal256Type>(type->get_precision(), type->get_scale()); |
119 | 16 | break; |
120 | 9 | case TYPE_BOOLEAN: |
121 | 9 | *result = arrow::boolean(); |
122 | 9 | break; |
123 | 100 | case TYPE_ARRAY: { |
124 | 100 | const auto* type_arr = assert_cast<const DataTypeArray*>(remove_nullable(type).get()); |
125 | 100 | std::shared_ptr<arrow::DataType> item_type; |
126 | 100 | RETURN_IF_ERROR(convert_to_arrow_type(type_arr->get_nested_type(), &item_type, timezone)); |
127 | 100 | *result = std::make_shared<arrow::ListType>(item_type); |
128 | 100 | break; |
129 | 100 | } |
130 | 62 | case TYPE_MAP: { |
131 | 62 | const auto* type_map = assert_cast<const DataTypeMap*>(remove_nullable(type).get()); |
132 | 62 | std::shared_ptr<arrow::DataType> key_type; |
133 | 62 | std::shared_ptr<arrow::DataType> val_type; |
134 | 62 | RETURN_IF_ERROR(convert_to_arrow_type(type_map->get_key_type(), &key_type, timezone)); |
135 | 62 | RETURN_IF_ERROR(convert_to_arrow_type(type_map->get_value_type(), &val_type, timezone)); |
136 | 62 | *result = std::make_shared<arrow::MapType>(key_type, val_type); |
137 | 62 | break; |
138 | 62 | } |
139 | 53 | case TYPE_STRUCT: { |
140 | 53 | const auto* type_struct = assert_cast<const DataTypeStruct*>(remove_nullable(type).get()); |
141 | 53 | std::vector<std::shared_ptr<arrow::Field>> fields; |
142 | 175 | for (size_t i = 0; i < type_struct->get_elements().size(); i++) { |
143 | 122 | std::shared_ptr<arrow::DataType> field_type; |
144 | 122 | RETURN_IF_ERROR( |
145 | 122 | convert_to_arrow_type(type_struct->get_element(i), &field_type, timezone)); |
146 | 122 | fields.push_back( |
147 | 122 | std::make_shared<arrow::Field>(type_struct->get_element_name(i), field_type, |
148 | 122 | type_struct->get_element(i)->is_nullable())); |
149 | 122 | } |
150 | 53 | *result = std::make_shared<arrow::StructType>(fields); |
151 | 53 | break; |
152 | 53 | } |
153 | 0 | case TYPE_VARIANT: { |
154 | 0 | *result = arrow::utf8(); |
155 | 0 | break; |
156 | 53 | } |
157 | 0 | case TYPE_QUANTILE_STATE: |
158 | 1 | case TYPE_BITMAP: |
159 | 2 | case TYPE_HLL: { |
160 | 2 | *result = arrow::binary(); |
161 | 2 | break; |
162 | 1 | } |
163 | 0 | case TYPE_VARBINARY: { |
164 | 0 | *result = arrow::binary(); |
165 | 0 | break; |
166 | 1 | } |
167 | 0 | default: |
168 | 0 | return Status::InvalidArgument("Unknown primitive type({}) convert to Arrow type", |
169 | 0 | type->get_name()); |
170 | 565 | } |
171 | 565 | return Status::OK(); |
172 | 565 | } |
173 | | |
174 | | // Helper function to create an Arrow Field with type metadata if applicable, such as IP types |
175 | | static std::shared_ptr<arrow::Field> create_arrow_field_with_metadata( |
176 | | const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type, |
177 | 204 | bool is_nullable, PrimitiveType primitive_type) { |
178 | 204 | if (primitive_type == PrimitiveType::TYPE_IPV4) { |
179 | 4 | auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {"IPV4"}); |
180 | 4 | return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable, metadata); |
181 | 200 | } else if (primitive_type == PrimitiveType::TYPE_IPV6) { |
182 | 4 | auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {"IPV6"}); |
183 | 4 | return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable, metadata); |
184 | 196 | } else { |
185 | 196 | return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable); |
186 | 196 | } |
187 | 204 | } |
188 | | |
189 | | Status get_arrow_schema_from_block(const Block& block, std::shared_ptr<arrow::Schema>* result, |
190 | 30 | const std::string& timezone) { |
191 | 30 | std::vector<std::shared_ptr<arrow::Field>> fields; |
192 | 204 | for (const auto& type_and_name : block) { |
193 | 204 | std::shared_ptr<arrow::DataType> arrow_type; |
194 | 204 | RETURN_IF_ERROR(convert_to_arrow_type(type_and_name.type, &arrow_type, timezone)); |
195 | 204 | auto field = create_arrow_field_with_metadata(type_and_name.name, arrow_type, |
196 | 204 | type_and_name.type->is_nullable(), |
197 | 204 | type_and_name.type->get_primitive_type()); |
198 | 204 | fields.push_back(field); |
199 | 204 | } |
200 | 30 | *result = arrow::schema(std::move(fields)); |
201 | 30 | return Status::OK(); |
202 | 30 | } |
203 | | |
204 | | Status get_arrow_schema_from_expr_ctxs(const VExprContextSPtrs& output_vexpr_ctxs, |
205 | | std::shared_ptr<arrow::Schema>* result, |
206 | 0 | const std::string& timezone) { |
207 | 0 | std::vector<std::shared_ptr<arrow::Field>> fields; |
208 | 0 | for (int i = 0; i < output_vexpr_ctxs.size(); i++) { |
209 | 0 | std::shared_ptr<arrow::DataType> arrow_type; |
210 | 0 | auto root_expr = output_vexpr_ctxs.at(i)->root(); |
211 | 0 | RETURN_IF_ERROR(convert_to_arrow_type(root_expr->data_type(), &arrow_type, timezone)); |
212 | 0 | auto field_name = root_expr->is_slot_ref() && !root_expr->expr_label().empty() |
213 | 0 | ? root_expr->expr_label() |
214 | 0 | : fmt::format("{}_{}", root_expr->data_type()->get_name(), i); |
215 | 0 | auto field = |
216 | 0 | create_arrow_field_with_metadata(field_name, arrow_type, root_expr->is_nullable(), |
217 | 0 | root_expr->data_type()->get_primitive_type()); |
218 | 0 | fields.push_back(field); |
219 | 0 | } |
220 | 0 | *result = arrow::schema(std::move(fields)); |
221 | 0 | return Status::OK(); |
222 | 0 | } |
223 | | |
224 | 0 | Status serialize_record_batch(const arrow::RecordBatch& record_batch, std::string* result) { |
225 | | // create sink memory buffer outputstream with the computed capacity |
226 | 0 | int64_t capacity; |
227 | 0 | arrow::Status a_st = arrow::ipc::GetRecordBatchSize(record_batch, &capacity); |
228 | 0 | if (!a_st.ok()) { |
229 | 0 | return Status::InternalError("GetRecordBatchSize failure, reason: {}", a_st.ToString()); |
230 | 0 | } |
231 | 0 | auto sink_res = arrow::io::BufferOutputStream::Create(capacity, arrow::default_memory_pool()); |
232 | 0 | if (!sink_res.ok()) { |
233 | 0 | return Status::InternalError("create BufferOutputStream failure, reason: {}", |
234 | 0 | sink_res.status().ToString()); |
235 | 0 | } |
236 | 0 | std::shared_ptr<arrow::io::BufferOutputStream> sink = sink_res.ValueOrDie(); |
237 | | // create RecordBatch Writer |
238 | 0 | auto res = arrow::ipc::MakeStreamWriter(sink.get(), record_batch.schema()); |
239 | 0 | if (!res.ok()) { |
240 | 0 | return Status::InternalError("open RecordBatchStreamWriter failure, reason: {}", |
241 | 0 | res.status().ToString()); |
242 | 0 | } |
243 | | // write RecordBatch to memory buffer outputstream |
244 | 0 | std::shared_ptr<arrow::ipc::RecordBatchWriter> record_batch_writer = res.ValueOrDie(); |
245 | 0 | a_st = record_batch_writer->WriteRecordBatch(record_batch); |
246 | 0 | if (!a_st.ok()) { |
247 | 0 | return Status::InternalError("write record batch failure, reason: {}", a_st.ToString()); |
248 | 0 | } |
249 | 0 | a_st = record_batch_writer->Close(); |
250 | 0 | if (!a_st.ok()) { |
251 | 0 | return Status::InternalError("Close failed, reason: {}", a_st.ToString()); |
252 | 0 | } |
253 | 0 | auto finish_res = sink->Finish(); |
254 | 0 | if (!finish_res.ok()) { |
255 | 0 | return Status::InternalError("allocate result buffer failure, reason: {}", |
256 | 0 | finish_res.status().ToString()); |
257 | 0 | } |
258 | 0 | *result = finish_res.ValueOrDie()->ToString(); |
259 | | // close the sink |
260 | 0 | a_st = sink->Close(); |
261 | 0 | if (!a_st.ok()) { |
262 | 0 | return Status::InternalError("Close failed, reason: {}", a_st.ToString()); |
263 | 0 | } |
264 | 0 | return Status::OK(); |
265 | 0 | } |
266 | | |
267 | 0 | Status serialize_arrow_schema(std::shared_ptr<arrow::Schema>* schema, std::string* result) { |
268 | 0 | auto make_empty_result = arrow::RecordBatch::MakeEmpty(*schema); |
269 | 0 | if (!make_empty_result.ok()) { |
270 | 0 | return Status::InternalError("serialize_arrow_schema failed, reason: {}", |
271 | 0 | make_empty_result.status().ToString()); |
272 | 0 | } |
273 | 0 | auto batch = make_empty_result.ValueOrDie(); |
274 | 0 | return serialize_record_batch(*batch, result); |
275 | 0 | } |
276 | | |
277 | | } // namespace doris |