be/src/exprs/aggregate/aggregate_function_java_udaf.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <jni.h> |
21 | | #include <unistd.h> |
22 | | |
23 | | #include <cstdint> |
24 | | #include <memory> |
25 | | |
26 | | #include "absl/strings/substitute.h" |
27 | | #include "common/cast_set.h" |
28 | | #include "common/compiler_util.h" |
29 | | #include "common/exception.h" |
30 | | #include "common/logging.h" |
31 | | #include "common/status.h" |
32 | | #include "core/column/column_array.h" |
33 | | #include "core/column/column_map.h" |
34 | | #include "core/column/column_string.h" |
35 | | #include "core/field.h" |
36 | | #include "core/string_ref.h" |
37 | | #include "core/types.h" |
38 | | #include "exec/connector/jni_connector.h" |
39 | | #include "exprs/aggregate/aggregate_function.h" |
40 | | #include "runtime/user_function_cache.h" |
41 | | #include "util/io_helper.h" |
42 | | #include "util/jni-util.h" |
43 | | |
44 | | namespace doris { |
45 | | #include "common/compile_check_begin.h" |
46 | | |
// JNI names and type signatures of the Java UdafExecutor methods invoked from this file.
// A JNI method signature has the form "(argument-types)return-type", e.g. "(J[B)V" is a
// method taking a long and a byte[] and returning void.
// https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
//
// Declared `inline constexpr` so that this header can be included from several translation
// units without producing multiple definitions, and so the pointers cannot be reassigned.
inline constexpr const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor";
inline constexpr const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V";
inline constexpr const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V";
inline constexpr const char* UDAF_EXECUTOR_DESTROY_SIGNATURE = "()V";
inline constexpr const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZIIJILjava/util/Map;)V";
inline constexpr const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
inline constexpr const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
inline constexpr const char* UDAF_EXECUTOR_GET_SIGNATURE = "(JLjava/util/Map;)J";
inline constexpr const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
58 | | |
59 | | struct AggregateJavaUdafData { |
60 | | public: |
61 | 0 | AggregateJavaUdafData() = default; |
62 | 0 | AggregateJavaUdafData(int64_t num_args) { cast_set(argument_size, num_args); } |
63 | | |
64 | 0 | ~AggregateJavaUdafData() = default; |
65 | | |
66 | 0 | Status close_and_delete_object() { |
67 | 0 | JNIEnv* env = nullptr; |
68 | |
|
69 | 0 | RETURN_IF_ERROR(Jni::Env::Get(&env)); |
70 | | |
71 | 0 | auto st = executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_close_id) |
72 | 0 | .call(); |
73 | 0 | if (!st.ok()) { |
74 | 0 | LOG(WARNING) << "Failed to close JAVA UDAF: " << st.to_string(); |
75 | 0 | return st; |
76 | 0 | } |
77 | 0 | return Status::OK(); |
78 | 0 | } |
79 | | |
80 | 0 | Status init_udaf(const TFunction& fn, const std::string& local_location) { |
81 | 0 | JNIEnv* env = nullptr; |
82 | 0 | RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf init_udaf function"); |
83 | 0 | RETURN_IF_ERROR(Jni::Util::find_class(env, UDAF_EXECUTOR_CLASS, &executor_cl)); |
84 | 0 | RETURN_NOT_OK_STATUS_WITH_WARN(register_func_id(env), |
85 | 0 | "Java-Udaf register_func_id function"); |
86 | |
|
87 | 0 | TJavaUdfExecutorCtorParams ctor_params; |
88 | 0 | ctor_params.__set_fn(fn); |
89 | 0 | if (!fn.hdfs_location.empty() && !fn.checksum.empty()) { |
90 | 0 | ctor_params.__set_location(local_location); |
91 | 0 | } |
92 | |
|
93 | 0 | Jni::LocalArray ctor_params_bytes; |
94 | 0 | RETURN_IF_ERROR(Jni::Util::SerializeThriftMsg(env, &ctor_params, &ctor_params_bytes)); |
95 | 0 | RETURN_IF_ERROR(executor_cl.new_object(env, executor_ctor_id) |
96 | 0 | .with_arg(ctor_params_bytes) |
97 | 0 | .call(&executor_obj)); |
98 | 0 | return Status::OK(); |
99 | 0 | } |
100 | | |
101 | | Status add(int64_t places_address, bool is_single_place, const IColumn** columns, |
102 | | int64_t row_num_start, int64_t row_num_end, const DataTypes& argument_types, |
103 | 0 | int64_t place_offset) { |
104 | 0 | JNIEnv* env = nullptr; |
105 | 0 | RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf add function"); |
106 | |
|
107 | 0 | Block input_block; |
108 | 0 | for (size_t i = 0; i < argument_size; ++i) { |
109 | 0 | input_block.insert(ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i], |
110 | 0 | std::to_string(i))); |
111 | 0 | } |
112 | 0 | std::unique_ptr<long[]> input_table; |
113 | 0 | RETURN_IF_ERROR(JniConnector::to_java_table(&input_block, input_table)); |
114 | 0 | auto input_table_schema = JniConnector::parse_table_schema(&input_block); |
115 | 0 | std::map<String, String> input_params = { |
116 | 0 | {"meta_address", std::to_string((long)input_table.get())}, |
117 | 0 | {"required_fields", input_table_schema.first}, |
118 | 0 | {"columns_types", input_table_schema.second}}; |
119 | |
|
120 | 0 | Jni::LocalObject input_map; |
121 | 0 | RETURN_IF_ERROR(Jni::Util::convert_to_java_map(env, input_params, &input_map)); |
122 | | // invoke add batch |
123 | | // Keep consistent with the function signature of executor_add_batch_id. |
124 | | |
125 | 0 | return executor_obj.call_void_method(env, executor_add_batch_id) |
126 | 0 | .with_arg((jboolean)is_single_place) |
127 | 0 | .with_arg(cast_set<jint>(row_num_start)) |
128 | 0 | .with_arg(cast_set<jint>(row_num_end)) |
129 | 0 | .with_arg(cast_set<jlong>(places_address)) |
130 | 0 | .with_arg(cast_set<jint>(place_offset)) |
131 | 0 | .with_arg(input_map) |
132 | 0 | .call(); |
133 | 0 | } |
134 | | |
135 | 0 | Status merge(const AggregateJavaUdafData& rhs, int64_t place) { |
136 | 0 | JNIEnv* env = nullptr; |
137 | 0 | RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf merge function"); |
138 | 0 | serialize_data = rhs.serialize_data; |
139 | 0 | Jni::LocalArray byte_arr; |
140 | 0 | RETURN_IF_ERROR(Jni::Util::WriteBufferToByteArray(env, (jbyte*)serialize_data.data(), |
141 | 0 | cast_set<jsize>(serialize_data.length()), |
142 | 0 | &byte_arr)); |
143 | | |
144 | 0 | return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_merge_id) |
145 | 0 | .with_arg((jlong)place) |
146 | 0 | .with_arg(byte_arr) |
147 | 0 | .call(); |
148 | 0 | } |
149 | | |
150 | 0 | Status write(BufferWritable& buf, int64_t place) { |
151 | 0 | JNIEnv* env = nullptr; |
152 | 0 | RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf write function"); |
153 | | // TODO: Here get a byte[] from FE serialize, and then allocate the same length bytes to |
154 | | // save it in BE, Because i'm not sure there is a way to use the byte[] not allocate again. |
155 | 0 | Jni::LocalArray arr; |
156 | 0 | RETURN_IF_ERROR( |
157 | 0 | executor_obj.call_nonvirtual_object_method(env, executor_cl, executor_serialize_id) |
158 | 0 | .with_arg((jlong)place) |
159 | 0 | .call(&arr)); |
160 | | |
161 | 0 | jsize len = 0; |
162 | 0 | RETURN_IF_ERROR(arr.get_length(env, &len)); |
163 | 0 | serialize_data.resize(len); |
164 | 0 | RETURN_IF_ERROR(arr.get_byte_elements(env, 0, len, |
165 | 0 | reinterpret_cast<jbyte*>(serialize_data.data()))); |
166 | 0 | buf.write_binary(serialize_data); |
167 | 0 | return Status::OK(); |
168 | 0 | } |
169 | | |
170 | 0 | Status reset(int64_t place) { |
171 | 0 | JNIEnv* env = nullptr; |
172 | 0 | RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf reset function"); |
173 | 0 | return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_reset_id) |
174 | 0 | .with_arg(cast_set<jlong>(place)) |
175 | 0 | .call(); |
176 | 0 | } |
177 | | |
178 | 0 | void read(BufferReadable& buf) { buf.read_binary(serialize_data); } |
179 | | |
180 | 0 | Status destroy() { |
181 | 0 | JNIEnv* env = nullptr; |
182 | 0 | RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf destroy function"); |
183 | 0 | return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_destroy_id) |
184 | 0 | .call(); |
185 | 0 | } |
186 | | |
187 | 0 | Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const { |
188 | 0 | JNIEnv* env = nullptr; |
189 | 0 | RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf get value function"); |
190 | |
|
191 | 0 | Block output_block; |
192 | 0 | output_block.insert(ColumnWithTypeAndName(to.get_ptr(), result_type, "_result_")); |
193 | 0 | auto output_table_schema = JniConnector::parse_table_schema(&output_block); |
194 | 0 | std::string output_nullable = result_type->is_nullable() ? "true" : "false"; |
195 | 0 | std::map<String, String> output_params = {{"is_nullable", output_nullable}, |
196 | 0 | {"required_fields", output_table_schema.first}, |
197 | 0 | {"columns_types", output_table_schema.second}}; |
198 | |
|
199 | 0 | Jni::LocalObject output_map; |
200 | 0 | RETURN_IF_ERROR(Jni::Util::convert_to_java_map(env, output_params, &output_map)); |
201 | 0 | long output_address; |
202 | |
|
203 | 0 | RETURN_IF_ERROR(executor_obj.call_long_method(env, executor_get_value_id) |
204 | 0 | .with_arg(cast_set<jlong>(place)) |
205 | 0 | .with_arg(output_map) |
206 | 0 | .call(&output_address)); |
207 | | |
208 | 0 | return JniConnector::fill_block(&output_block, {0}, output_address); |
209 | 0 | } |
210 | | |
211 | | private: |
212 | 0 | Status register_func_id(JNIEnv* env) { |
213 | 0 | RETURN_IF_ERROR(executor_cl.get_method(env, "<init>", UDAF_EXECUTOR_CTOR_SIGNATURE, |
214 | 0 | &executor_ctor_id)); |
215 | 0 | RETURN_IF_ERROR(executor_cl.get_method(env, "reset", UDAF_EXECUTOR_RESET_SIGNATURE, |
216 | 0 | &executor_reset_id)); |
217 | 0 | RETURN_IF_ERROR(executor_cl.get_method(env, "close", UDAF_EXECUTOR_CLOSE_SIGNATURE, |
218 | 0 | &executor_close_id)); |
219 | 0 | RETURN_IF_ERROR(executor_cl.get_method(env, "merge", UDAF_EXECUTOR_MERGE_SIGNATURE, |
220 | 0 | &executor_merge_id)); |
221 | 0 | RETURN_IF_ERROR(executor_cl.get_method(env, "serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, |
222 | 0 | &executor_serialize_id)); |
223 | 0 | RETURN_IF_ERROR(executor_cl.get_method(env, "getValue", UDAF_EXECUTOR_GET_SIGNATURE, |
224 | 0 | &executor_get_value_id)); |
225 | 0 | RETURN_IF_ERROR(executor_cl.get_method(env, "destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, |
226 | 0 | &executor_destroy_id)); |
227 | 0 | RETURN_IF_ERROR(executor_cl.get_method(env, "addBatch", UDAF_EXECUTOR_ADD_SIGNATURE, |
228 | 0 | &executor_add_batch_id)); |
229 | | |
230 | 0 | return Status::OK(); |
231 | 0 | } |
232 | | |
233 | | private: |
234 | | // TODO: too many variables are hold, it's causing a lot of memory waste |
235 | | // it's time to refactor it. |
236 | | Jni::GlobalClass executor_cl; |
237 | | Jni::GlobalObject executor_obj; |
238 | | |
239 | | Jni::MethodId executor_ctor_id; |
240 | | Jni::MethodId executor_add_batch_id; |
241 | | Jni::MethodId executor_merge_id; |
242 | | Jni::MethodId executor_serialize_id; |
243 | | Jni::MethodId executor_get_value_id; |
244 | | Jni::MethodId executor_reset_id; |
245 | | Jni::MethodId executor_close_id; |
246 | | Jni::MethodId executor_destroy_id; |
247 | | int argument_size = 0; |
248 | | std::string serialize_data; |
249 | | }; |
250 | | |
251 | | class AggregateJavaUdaf final |
252 | | : public IAggregateFunctionDataHelper<AggregateJavaUdafData, AggregateJavaUdaf>, |
253 | | VarargsExpression, |
254 | | NullableAggregateFunction { |
255 | | public: |
256 | | ENABLE_FACTORY_CREATOR(AggregateJavaUdaf); |
257 | | AggregateJavaUdaf(const TFunction& fn, const DataTypes& argument_types_, |
258 | | const DataTypePtr& return_type) |
259 | 0 | : IAggregateFunctionDataHelper(argument_types_), |
260 | 0 | _fn(fn), |
261 | 0 | _return_type(return_type), |
262 | 0 | _first_created(true), |
263 | 0 | _exec_place(nullptr) {} |
264 | 0 | ~AggregateJavaUdaf() override = default; |
265 | | |
266 | | static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_, |
267 | 0 | const DataTypePtr& return_type) { |
268 | 0 | return std::make_shared<AggregateJavaUdaf>(fn, argument_types_, return_type); |
269 | 0 | } |
270 | | //Note: The condition is added because maybe the BE can't find java-udaf impl jar |
271 | | //So need to check as soon as possible, before call Data function |
272 | 0 | Status check_udaf(const TFunction& fn) { |
273 | 0 | auto function_cache = UserFunctionCache::instance(); |
274 | | // get jar path if both file path location and checksum are null |
275 | 0 | if (!fn.hdfs_location.empty() && !fn.checksum.empty()) { |
276 | 0 | return function_cache->get_jarpath(fn.id, fn.hdfs_location, fn.checksum, |
277 | 0 | &_local_location); |
278 | 0 | } else { |
279 | 0 | return Status::OK(); |
280 | 0 | } |
281 | 0 | } |
282 | | |
283 | 0 | void create(AggregateDataPtr __restrict place) const override { |
284 | 0 | new (place) Data(argument_types.size()); |
285 | 0 | if (_first_created) { |
286 | 0 | Status status = this->data(place).init_udaf(_fn, _local_location); |
287 | 0 | _first_created = false; |
288 | 0 | _exec_place = place; |
289 | 0 | if (UNLIKELY(!status.ok())) { |
290 | 0 | static_cast<void>(this->data(place).destroy()); |
291 | 0 | this->data(place).~Data(); |
292 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, status.to_string()); |
293 | 0 | } |
294 | 0 | } |
295 | 0 | } |
296 | | |
297 | | // To avoid multiple times JNI call, Here will destroy all data at once |
298 | 0 | void destroy(AggregateDataPtr __restrict place) const noexcept override { |
299 | 0 | if (place == _exec_place) { |
300 | 0 | Status status = Status::OK(); |
301 | 0 | status = this->data(_exec_place).destroy(); |
302 | 0 | status = this->data(_exec_place).close_and_delete_object(); |
303 | 0 | _first_created = true; |
304 | 0 | if (UNLIKELY(!status.ok())) { |
305 | 0 | LOG(WARNING) << "Failed to destroy function: " << status.to_string(); |
306 | 0 | } |
307 | 0 | } |
308 | 0 | this->data(place).~Data(); |
309 | 0 | } |
310 | | |
311 | 0 | String get_name() const override { return _fn.name.function_name; } |
312 | | |
313 | 0 | DataTypePtr get_return_type() const override { return _return_type; } |
314 | | |
315 | | void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, |
316 | 0 | Arena&) const override { |
317 | 0 | int64_t places_address = reinterpret_cast<int64_t>(place); |
318 | 0 | Status st = this->data(_exec_place) |
319 | 0 | .add(places_address, true, columns, row_num, row_num + 1, |
320 | 0 | argument_types, 0); |
321 | 0 | if (UNLIKELY(!st.ok())) { |
322 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
323 | 0 | } |
324 | 0 | } |
325 | | |
326 | | void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset, |
327 | 0 | const IColumn** columns, Arena&, bool /*agg_many*/) const override { |
328 | 0 | int64_t places_address = reinterpret_cast<int64_t>(places); |
329 | 0 | Status st = this->data(_exec_place) |
330 | 0 | .add(places_address, false, columns, 0, batch_size, argument_types, |
331 | 0 | place_offset); |
332 | 0 | if (UNLIKELY(!st.ok())) { |
333 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
334 | 0 | } |
335 | 0 | } |
336 | | |
337 | | void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, |
338 | 0 | Arena&) const override { |
339 | 0 | int64_t places_address = reinterpret_cast<int64_t>(place); |
340 | 0 | Status st = this->data(_exec_place) |
341 | 0 | .add(places_address, true, columns, 0, batch_size, argument_types, 0); |
342 | 0 | if (UNLIKELY(!st.ok())) { |
343 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
344 | 0 | } |
345 | 0 | } |
346 | | |
347 | | void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start, |
348 | | int64_t frame_end, AggregateDataPtr place, const IColumn** columns, |
349 | | Arena&, UInt8* current_window_empty, |
350 | 0 | UInt8* current_window_has_inited) const override { |
351 | 0 | frame_start = std::max<int64_t>(frame_start, partition_start); |
352 | 0 | frame_end = std::min<int64_t>(frame_end, partition_end); |
353 | 0 | int64_t places_address = reinterpret_cast<int64_t>(place); |
354 | 0 | Status st = this->data(_exec_place) |
355 | 0 | .add(places_address, true, columns, frame_start, frame_end, |
356 | 0 | argument_types, 0); |
357 | 0 | if (frame_start >= frame_end) { |
358 | 0 | if (!*current_window_has_inited) { |
359 | 0 | *current_window_empty = true; |
360 | 0 | } |
361 | 0 | } else { |
362 | 0 | *current_window_empty = false; |
363 | 0 | *current_window_has_inited = true; |
364 | 0 | } |
365 | 0 | if (UNLIKELY(!st.ok())) { |
366 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
367 | 0 | } |
368 | 0 | } |
369 | | |
370 | 0 | void reset(AggregateDataPtr place) const override { |
371 | 0 | Status st = this->data(_exec_place).reset(reinterpret_cast<int64_t>(place)); |
372 | 0 | if (UNLIKELY(!st.ok())) { |
373 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
374 | 0 | } |
375 | 0 | } |
376 | | |
377 | | void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, |
378 | 0 | Arena&) const override { |
379 | 0 | Status st = |
380 | 0 | this->data(_exec_place).merge(this->data(rhs), reinterpret_cast<int64_t>(place)); |
381 | 0 | if (UNLIKELY(!st.ok())) { |
382 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
383 | 0 | } |
384 | 0 | } |
385 | | |
386 | 0 | void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { |
387 | 0 | Status st = this->data(_exec_place).write(buf, reinterpret_cast<int64_t>(place)); |
388 | 0 | if (UNLIKELY(!st.ok())) { |
389 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
390 | 0 | } |
391 | 0 | } |
392 | | |
393 | | // during merge-finalized phase, for deserialize and merge firstly, |
394 | | // will call create --- deserialize --- merge --- destory for each rows , |
395 | | // so need doing new (place), to create Data and read to buf, then call merge , |
396 | | // and during destory about deserialize, because haven't done init_udaf, |
397 | | // so it's can't call ~Data, only to change _destory_deserialize flag. |
398 | | void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, |
399 | 0 | Arena&) const override { |
400 | 0 | this->data(place).read(buf); |
401 | 0 | } |
402 | | |
403 | 0 | void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { |
404 | 0 | Status st = this->data(_exec_place).get(to, _return_type, reinterpret_cast<int64_t>(place)); |
405 | 0 | if (UNLIKELY(!st.ok())) { |
406 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); |
407 | 0 | } |
408 | 0 | } |
409 | | |
410 | | private: |
411 | | TFunction _fn; |
412 | | DataTypePtr _return_type; |
413 | | mutable bool _first_created; |
414 | | mutable AggregateDataPtr _exec_place; |
415 | | std::string _local_location; |
416 | | }; |
417 | | |
418 | | } // namespace doris |
419 | | |
420 | | #include "common/compile_check_end.h" |