Coverage Report

Created: 2026-03-15 08:11

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_java_udaf.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <jni.h>
21
#include <unistd.h>
22
23
#include <cstdint>
24
#include <memory>
25
26
#include "absl/strings/substitute.h"
27
#include "common/cast_set.h"
28
#include "common/compiler_util.h"
29
#include "common/exception.h"
30
#include "common/logging.h"
31
#include "common/status.h"
32
#include "core/column/column_array.h"
33
#include "core/column/column_map.h"
34
#include "core/column/column_string.h"
35
#include "core/field.h"
36
#include "core/string_ref.h"
37
#include "core/types.h"
38
#include "exec/connector/jni_connector.h"
39
#include "exprs/aggregate/aggregate_function.h"
40
#include "runtime/user_function_cache.h"
41
#include "util/io_helper.h"
42
#include "util/jni-util.h"
43
44
namespace doris {
45
#include "common/compile_check_begin.h"
46
47
// Java executor class and the JNI signatures of the UdafExecutor methods called from BE.
// `inline constexpr` gives each constant exactly one definition program-wide, so this header
// can be included from several translation units; a plain mutable `const char*` global has
// external linkage and would be defined once per includer (duplicate-symbol link error).
inline constexpr const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor";
inline constexpr const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V";
inline constexpr const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V";
inline constexpr const char* UDAF_EXECUTOR_DESTROY_SIGNATURE = "()V";
inline constexpr const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZIIJILjava/util/Map;)V";
inline constexpr const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
inline constexpr const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
inline constexpr const char* UDAF_EXECUTOR_GET_SIGNATURE = "(JLjava/util/Map;)J";
inline constexpr const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
// Calling Java method about those signature means: "(argument-types)return-type"
// https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
58
59
struct AggregateJavaUdafData {
60
public:
61
0
    AggregateJavaUdafData() = default;
62
0
    AggregateJavaUdafData(int64_t num_args) { cast_set(argument_size, num_args); }
63
64
0
    ~AggregateJavaUdafData() = default;
65
66
0
    Status close_and_delete_object() {
67
0
        JNIEnv* env = nullptr;
68
69
0
        RETURN_IF_ERROR(Jni::Env::Get(&env));
70
71
0
        auto st = executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_close_id)
72
0
                          .call();
73
0
        if (!st.ok()) {
74
0
            LOG(WARNING) << "Failed to close JAVA UDAF: " << st.to_string();
75
0
            return st;
76
0
        }
77
0
        return Status::OK();
78
0
    }
79
80
0
    Status init_udaf(const TFunction& fn, const std::string& local_location) {
81
0
        JNIEnv* env = nullptr;
82
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf init_udaf function");
83
0
        RETURN_IF_ERROR(Jni::Util::find_class(env, UDAF_EXECUTOR_CLASS, &executor_cl));
84
0
        RETURN_NOT_OK_STATUS_WITH_WARN(register_func_id(env),
85
0
                                       "Java-Udaf register_func_id function");
86
87
0
        TJavaUdfExecutorCtorParams ctor_params;
88
0
        ctor_params.__set_fn(fn);
89
0
        if (!fn.hdfs_location.empty() && !fn.checksum.empty()) {
90
0
            ctor_params.__set_location(local_location);
91
0
        }
92
93
0
        Jni::LocalArray ctor_params_bytes;
94
0
        RETURN_IF_ERROR(Jni::Util::SerializeThriftMsg(env, &ctor_params, &ctor_params_bytes));
95
0
        RETURN_IF_ERROR(executor_cl.new_object(env, executor_ctor_id)
96
0
                                .with_arg(ctor_params_bytes)
97
0
                                .call(&executor_obj));
98
0
        return Status::OK();
99
0
    }
100
101
    Status add(int64_t places_address, bool is_single_place, const IColumn** columns,
102
               int64_t row_num_start, int64_t row_num_end, const DataTypes& argument_types,
103
0
               int64_t place_offset) {
104
0
        JNIEnv* env = nullptr;
105
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf add function");
106
107
0
        Block input_block;
108
0
        for (size_t i = 0; i < argument_size; ++i) {
109
0
            input_block.insert(ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i],
110
0
                                                     std::to_string(i)));
111
0
        }
112
0
        std::unique_ptr<long[]> input_table;
113
0
        RETURN_IF_ERROR(JniConnector::to_java_table(&input_block, input_table));
114
0
        auto input_table_schema = JniConnector::parse_table_schema(&input_block);
115
0
        std::map<String, String> input_params = {
116
0
                {"meta_address", std::to_string((long)input_table.get())},
117
0
                {"required_fields", input_table_schema.first},
118
0
                {"columns_types", input_table_schema.second}};
119
120
0
        Jni::LocalObject input_map;
121
0
        RETURN_IF_ERROR(Jni::Util::convert_to_java_map(env, input_params, &input_map));
122
        // invoke add batch
123
        // Keep consistent with the function signature of executor_add_batch_id.
124
125
0
        return executor_obj.call_void_method(env, executor_add_batch_id)
126
0
                .with_arg((jboolean)is_single_place)
127
0
                .with_arg(cast_set<jint>(row_num_start))
128
0
                .with_arg(cast_set<jint>(row_num_end))
129
0
                .with_arg(cast_set<jlong>(places_address))
130
0
                .with_arg(cast_set<jint>(place_offset))
131
0
                .with_arg(input_map)
132
0
                .call();
133
0
    }
134
135
0
    Status merge(const AggregateJavaUdafData& rhs, int64_t place) {
136
0
        JNIEnv* env = nullptr;
137
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf merge function");
138
0
        serialize_data = rhs.serialize_data;
139
0
        Jni::LocalArray byte_arr;
140
0
        RETURN_IF_ERROR(Jni::Util::WriteBufferToByteArray(env, (jbyte*)serialize_data.data(),
141
0
                                                          cast_set<jsize>(serialize_data.length()),
142
0
                                                          &byte_arr));
143
144
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_merge_id)
145
0
                .with_arg((jlong)place)
146
0
                .with_arg(byte_arr)
147
0
                .call();
148
0
    }
149
150
0
    Status write(BufferWritable& buf, int64_t place) {
151
0
        JNIEnv* env = nullptr;
152
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf write function");
153
        // TODO: Here get a byte[] from FE serialize, and then allocate the same length bytes to
154
        // save it in BE, Because i'm not sure there is a way to use the byte[] not allocate again.
155
0
        Jni::LocalArray arr;
156
0
        RETURN_IF_ERROR(
157
0
                executor_obj.call_nonvirtual_object_method(env, executor_cl, executor_serialize_id)
158
0
                        .with_arg((jlong)place)
159
0
                        .call(&arr));
160
161
0
        jsize len = 0;
162
0
        RETURN_IF_ERROR(arr.get_length(env, &len));
163
0
        serialize_data.resize(len);
164
0
        RETURN_IF_ERROR(arr.get_byte_elements(env, 0, len,
165
0
                                              reinterpret_cast<jbyte*>(serialize_data.data())));
166
0
        buf.write_binary(serialize_data);
167
0
        return Status::OK();
168
0
    }
169
170
0
    Status reset(int64_t place) {
171
0
        JNIEnv* env = nullptr;
172
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf reset function");
173
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_reset_id)
174
0
                .with_arg(cast_set<jlong>(place))
175
0
                .call();
176
0
    }
177
178
0
    void read(BufferReadable& buf) { buf.read_binary(serialize_data); }
179
180
0
    Status destroy() {
181
0
        JNIEnv* env = nullptr;
182
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf destroy function");
183
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_destroy_id)
184
0
                .call();
185
0
    }
186
187
0
    Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const {
188
0
        JNIEnv* env = nullptr;
189
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf get value function");
190
191
0
        Block output_block;
192
0
        output_block.insert(ColumnWithTypeAndName(to.get_ptr(), result_type, "_result_"));
193
0
        auto output_table_schema = JniConnector::parse_table_schema(&output_block);
194
0
        std::string output_nullable = result_type->is_nullable() ? "true" : "false";
195
0
        std::map<String, String> output_params = {{"is_nullable", output_nullable},
196
0
                                                  {"required_fields", output_table_schema.first},
197
0
                                                  {"columns_types", output_table_schema.second}};
198
199
0
        Jni::LocalObject output_map;
200
0
        RETURN_IF_ERROR(Jni::Util::convert_to_java_map(env, output_params, &output_map));
201
0
        long output_address;
202
203
0
        RETURN_IF_ERROR(executor_obj.call_long_method(env, executor_get_value_id)
204
0
                                .with_arg(cast_set<jlong>(place))
205
0
                                .with_arg(output_map)
206
0
                                .call(&output_address));
207
208
0
        return JniConnector::fill_block(&output_block, {0}, output_address);
209
0
    }
210
211
private:
212
0
    Status register_func_id(JNIEnv* env) {
213
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "<init>", UDAF_EXECUTOR_CTOR_SIGNATURE,
214
0
                                               &executor_ctor_id));
215
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "reset", UDAF_EXECUTOR_RESET_SIGNATURE,
216
0
                                               &executor_reset_id));
217
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "close", UDAF_EXECUTOR_CLOSE_SIGNATURE,
218
0
                                               &executor_close_id));
219
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "merge", UDAF_EXECUTOR_MERGE_SIGNATURE,
220
0
                                               &executor_merge_id));
221
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE,
222
0
                                               &executor_serialize_id));
223
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "getValue", UDAF_EXECUTOR_GET_SIGNATURE,
224
0
                                               &executor_get_value_id));
225
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE,
226
0
                                               &executor_destroy_id));
227
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "addBatch", UDAF_EXECUTOR_ADD_SIGNATURE,
228
0
                                               &executor_add_batch_id));
229
230
0
        return Status::OK();
231
0
    }
232
233
private:
234
    // TODO: too many variables are hold, it's causing a lot of memory waste
235
    // it's time to refactor it.
236
    Jni::GlobalClass executor_cl;
237
    Jni::GlobalObject executor_obj;
238
239
    Jni::MethodId executor_ctor_id;
240
    Jni::MethodId executor_add_batch_id;
241
    Jni::MethodId executor_merge_id;
242
    Jni::MethodId executor_serialize_id;
243
    Jni::MethodId executor_get_value_id;
244
    Jni::MethodId executor_reset_id;
245
    Jni::MethodId executor_close_id;
246
    Jni::MethodId executor_destroy_id;
247
    int argument_size = 0;
248
    std::string serialize_data;
249
};
250
251
class AggregateJavaUdaf final
252
        : public IAggregateFunctionDataHelper<AggregateJavaUdafData, AggregateJavaUdaf>,
253
          VarargsExpression,
254
          NullableAggregateFunction {
255
public:
256
    ENABLE_FACTORY_CREATOR(AggregateJavaUdaf);
257
    AggregateJavaUdaf(const TFunction& fn, const DataTypes& argument_types_,
258
                      const DataTypePtr& return_type)
259
0
            : IAggregateFunctionDataHelper(argument_types_),
260
0
              _fn(fn),
261
0
              _return_type(return_type),
262
0
              _first_created(true),
263
0
              _exec_place(nullptr) {}
264
0
    ~AggregateJavaUdaf() override = default;
265
266
    static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_,
267
0
                                       const DataTypePtr& return_type) {
268
0
        return std::make_shared<AggregateJavaUdaf>(fn, argument_types_, return_type);
269
0
    }
270
    //Note: The condition is added because maybe the BE can't find java-udaf impl jar
271
    //So need to check as soon as possible, before call Data function
272
0
    Status check_udaf(const TFunction& fn) {
273
0
        auto function_cache = UserFunctionCache::instance();
274
        // get jar path if both file path location and checksum are null
275
0
        if (!fn.hdfs_location.empty() && !fn.checksum.empty()) {
276
0
            return function_cache->get_jarpath(fn.id, fn.hdfs_location, fn.checksum,
277
0
                                               &_local_location);
278
0
        } else {
279
0
            return Status::OK();
280
0
        }
281
0
    }
282
283
0
    void create(AggregateDataPtr __restrict place) const override {
284
0
        new (place) Data(argument_types.size());
285
0
        if (_first_created) {
286
0
            Status status = this->data(place).init_udaf(_fn, _local_location);
287
0
            _first_created = false;
288
0
            _exec_place = place;
289
0
            if (UNLIKELY(!status.ok())) {
290
0
                static_cast<void>(this->data(place).destroy());
291
0
                this->data(place).~Data();
292
0
                throw doris::Exception(ErrorCode::INTERNAL_ERROR, status.to_string());
293
0
            }
294
0
        }
295
0
    }
296
297
    // To avoid multiple times JNI call, Here will destroy all data at once
298
0
    void destroy(AggregateDataPtr __restrict place) const noexcept override {
299
0
        if (place == _exec_place) {
300
0
            Status status = Status::OK();
301
0
            status = this->data(_exec_place).destroy();
302
0
            status = this->data(_exec_place).close_and_delete_object();
303
0
            _first_created = true;
304
0
            if (UNLIKELY(!status.ok())) {
305
0
                LOG(WARNING) << "Failed to destroy function: " << status.to_string();
306
0
            }
307
0
        }
308
0
        this->data(place).~Data();
309
0
    }
310
311
0
    String get_name() const override { return _fn.name.function_name; }
312
313
0
    DataTypePtr get_return_type() const override { return _return_type; }
314
315
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
316
0
             Arena&) const override {
317
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
318
0
        Status st = this->data(_exec_place)
319
0
                            .add(places_address, true, columns, row_num, row_num + 1,
320
0
                                 argument_types, 0);
321
0
        if (UNLIKELY(!st.ok())) {
322
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
323
0
        }
324
0
    }
325
326
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
327
0
                   const IColumn** columns, Arena&, bool /*agg_many*/) const override {
328
0
        int64_t places_address = reinterpret_cast<int64_t>(places);
329
0
        Status st = this->data(_exec_place)
330
0
                            .add(places_address, false, columns, 0, batch_size, argument_types,
331
0
                                 place_offset);
332
0
        if (UNLIKELY(!st.ok())) {
333
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
334
0
        }
335
0
    }
336
337
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
338
0
                                Arena&) const override {
339
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
340
0
        Status st = this->data(_exec_place)
341
0
                            .add(places_address, true, columns, 0, batch_size, argument_types, 0);
342
0
        if (UNLIKELY(!st.ok())) {
343
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
344
0
        }
345
0
    }
346
347
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
348
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
349
                                Arena&, UInt8* current_window_empty,
350
0
                                UInt8* current_window_has_inited) const override {
351
0
        frame_start = std::max<int64_t>(frame_start, partition_start);
352
0
        frame_end = std::min<int64_t>(frame_end, partition_end);
353
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
354
0
        Status st = this->data(_exec_place)
355
0
                            .add(places_address, true, columns, frame_start, frame_end,
356
0
                                 argument_types, 0);
357
0
        if (frame_start >= frame_end) {
358
0
            if (!*current_window_has_inited) {
359
0
                *current_window_empty = true;
360
0
            }
361
0
        } else {
362
0
            *current_window_empty = false;
363
0
            *current_window_has_inited = true;
364
0
        }
365
0
        if (UNLIKELY(!st.ok())) {
366
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
367
0
        }
368
0
    }
369
370
0
    void reset(AggregateDataPtr place) const override {
371
0
        Status st = this->data(_exec_place).reset(reinterpret_cast<int64_t>(place));
372
0
        if (UNLIKELY(!st.ok())) {
373
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
374
0
        }
375
0
    }
376
377
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
378
0
               Arena&) const override {
379
0
        Status st =
380
0
                this->data(_exec_place).merge(this->data(rhs), reinterpret_cast<int64_t>(place));
381
0
        if (UNLIKELY(!st.ok())) {
382
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
383
0
        }
384
0
    }
385
386
0
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
387
0
        Status st = this->data(_exec_place).write(buf, reinterpret_cast<int64_t>(place));
388
0
        if (UNLIKELY(!st.ok())) {
389
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
390
0
        }
391
0
    }
392
393
    // during merge-finalized phase, for deserialize and merge firstly,
394
    // will call create --- deserialize --- merge --- destory for each rows ,
395
    // so need doing new (place), to create Data and read to buf, then call merge ,
396
    // and during destory about deserialize, because haven't done init_udaf,
397
    // so it's can't call ~Data, only to change _destory_deserialize flag.
398
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
399
0
                     Arena&) const override {
400
0
        this->data(place).read(buf);
401
0
    }
402
403
0
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
404
0
        Status st = this->data(_exec_place).get(to, _return_type, reinterpret_cast<int64_t>(place));
405
0
        if (UNLIKELY(!st.ok())) {
406
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
407
0
        }
408
0
    }
409
410
private:
411
    TFunction _fn;
412
    DataTypePtr _return_type;
413
    mutable bool _first_created;
414
    mutable AggregateDataPtr _exec_place;
415
    std::string _local_location;
416
};
417
418
} // namespace doris
419
420
#include "common/compile_check_end.h"