Coverage Report

Created: 2026-04-14 17:06

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_java_udaf.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <jni.h>
21
#include <unistd.h>
22
23
#include <cstdint>
24
#include <memory>
25
26
#include "absl/strings/substitute.h"
27
#include "common/cast_set.h"
28
#include "common/compiler_util.h"
29
#include "common/exception.h"
30
#include "common/logging.h"
31
#include "common/status.h"
32
#include "core/column/column_array.h"
33
#include "core/column/column_map.h"
34
#include "core/column/column_string.h"
35
#include "core/field.h"
36
#include "core/string_ref.h"
37
#include "core/types.h"
38
#include "exprs/aggregate/aggregate_function.h"
39
#include "format/jni/jni_data_bridge.h"
40
#include "runtime/user_function_cache.h"
41
#include "util/io_helper.h"
42
#include "util/jni-util.h"
43
44
namespace doris {
45
46
// JNI class name and method signatures of the Java-side UDAF executor.
// A JNI signature reads "(argument-types)return-type", e.g. "(J)[B" takes a long and returns
// a byte[]; see
// https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
//
// `inline constexpr` gives each constant exactly one definition shared by every translation unit
// that includes this header. A plain `const char*` global (non-const pointer, external linkage)
// would be defined again in every including .cpp, which is an ODR violation / link error.
inline constexpr const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor";
inline constexpr const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V";
inline constexpr const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V";
inline constexpr const char* UDAF_EXECUTOR_DESTROY_SIGNATURE = "()V";
inline constexpr const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZIIJILjava/util/Map;)V";
inline constexpr const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
inline constexpr const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
inline constexpr const char* UDAF_EXECUTOR_GET_SIGNATURE = "(JLjava/util/Map;)J";
inline constexpr const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
57
58
struct AggregateJavaUdafData {
59
public:
60
0
    AggregateJavaUdafData() = default;
61
0
    AggregateJavaUdafData(int64_t num_args) { cast_set(argument_size, num_args); }
62
63
0
    ~AggregateJavaUdafData() = default;
64
65
0
    Status close_and_delete_object() {
66
0
        JNIEnv* env = nullptr;
67
68
0
        RETURN_IF_ERROR(Jni::Env::Get(&env));
69
70
0
        auto st = executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_close_id)
71
0
                          .call();
72
0
        if (!st.ok()) {
73
0
            LOG(WARNING) << "Failed to close JAVA UDAF: " << st.to_string();
74
0
            return st;
75
0
        }
76
0
        return Status::OK();
77
0
    }
78
79
0
    Status init_udaf(const TFunction& fn, const std::string& local_location) {
80
0
        JNIEnv* env = nullptr;
81
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf init_udaf function");
82
0
        RETURN_IF_ERROR(Jni::Util::find_class(env, UDAF_EXECUTOR_CLASS, &executor_cl));
83
0
        RETURN_NOT_OK_STATUS_WITH_WARN(register_func_id(env),
84
0
                                       "Java-Udaf register_func_id function");
85
86
0
        TJavaUdfExecutorCtorParams ctor_params;
87
0
        ctor_params.__set_fn(fn);
88
0
        if (!fn.hdfs_location.empty() && !fn.checksum.empty()) {
89
0
            ctor_params.__set_location(local_location);
90
0
        }
91
92
0
        Jni::LocalArray ctor_params_bytes;
93
0
        RETURN_IF_ERROR(Jni::Util::SerializeThriftMsg(env, &ctor_params, &ctor_params_bytes));
94
0
        RETURN_IF_ERROR(executor_cl.new_object(env, executor_ctor_id)
95
0
                                .with_arg(ctor_params_bytes)
96
0
                                .call(&executor_obj));
97
0
        return Status::OK();
98
0
    }
99
100
    Status add(int64_t places_address, bool is_single_place, const IColumn** columns,
101
               int64_t row_num_start, int64_t row_num_end, const DataTypes& argument_types,
102
0
               int64_t place_offset) {
103
0
        JNIEnv* env = nullptr;
104
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf add function");
105
106
0
        Block input_block;
107
0
        for (size_t i = 0; i < argument_size; ++i) {
108
0
            input_block.insert(ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i],
109
0
                                                     std::to_string(i)));
110
0
        }
111
0
        std::unique_ptr<long[]> input_table;
112
0
        RETURN_IF_ERROR(JniDataBridge::to_java_table(&input_block, input_table));
113
0
        auto input_table_schema = JniDataBridge::parse_table_schema(&input_block);
114
0
        std::map<String, String> input_params = {
115
0
                {"meta_address", std::to_string((long)input_table.get())},
116
0
                {"required_fields", input_table_schema.first},
117
0
                {"columns_types", input_table_schema.second}};
118
119
0
        Jni::LocalObject input_map;
120
0
        RETURN_IF_ERROR(Jni::Util::convert_to_java_map(env, input_params, &input_map));
121
        // invoke add batch
122
        // Keep consistent with the function signature of executor_add_batch_id.
123
124
0
        return executor_obj.call_void_method(env, executor_add_batch_id)
125
0
                .with_arg((jboolean)is_single_place)
126
0
                .with_arg(cast_set<jint>(row_num_start))
127
0
                .with_arg(cast_set<jint>(row_num_end))
128
0
                .with_arg(cast_set<jlong>(places_address))
129
0
                .with_arg(cast_set<jint>(place_offset))
130
0
                .with_arg(input_map)
131
0
                .call();
132
0
    }
133
134
0
    Status merge(const AggregateJavaUdafData& rhs, int64_t place) {
135
0
        JNIEnv* env = nullptr;
136
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf merge function");
137
0
        serialize_data = rhs.serialize_data;
138
0
        Jni::LocalArray byte_arr;
139
0
        RETURN_IF_ERROR(Jni::Util::WriteBufferToByteArray(env, (jbyte*)serialize_data.data(),
140
0
                                                          cast_set<jsize>(serialize_data.length()),
141
0
                                                          &byte_arr));
142
143
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_merge_id)
144
0
                .with_arg((jlong)place)
145
0
                .with_arg(byte_arr)
146
0
                .call();
147
0
    }
148
149
0
    Status write(BufferWritable& buf, int64_t place) {
150
0
        JNIEnv* env = nullptr;
151
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf write function");
152
        // TODO: Here get a byte[] from FE serialize, and then allocate the same length bytes to
153
        // save it in BE, Because i'm not sure there is a way to use the byte[] not allocate again.
154
0
        Jni::LocalArray arr;
155
0
        RETURN_IF_ERROR(
156
0
                executor_obj.call_nonvirtual_object_method(env, executor_cl, executor_serialize_id)
157
0
                        .with_arg((jlong)place)
158
0
                        .call(&arr));
159
160
0
        jsize len = 0;
161
0
        RETURN_IF_ERROR(arr.get_length(env, &len));
162
0
        serialize_data.resize(len);
163
0
        RETURN_IF_ERROR(arr.get_byte_elements(env, 0, len,
164
0
                                              reinterpret_cast<jbyte*>(serialize_data.data())));
165
0
        buf.write_binary(serialize_data);
166
0
        return Status::OK();
167
0
    }
168
169
0
    Status reset(int64_t place) {
170
0
        JNIEnv* env = nullptr;
171
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf reset function");
172
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_reset_id)
173
0
                .with_arg(cast_set<jlong>(place))
174
0
                .call();
175
0
    }
176
177
0
    void read(BufferReadable& buf) { buf.read_binary(serialize_data); }
178
179
0
    Status destroy() {
180
0
        JNIEnv* env = nullptr;
181
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf destroy function");
182
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_destroy_id)
183
0
                .call();
184
0
    }
185
186
0
    Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const {
187
0
        JNIEnv* env = nullptr;
188
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf get value function");
189
190
0
        Block output_block;
191
0
        output_block.insert(ColumnWithTypeAndName(to.get_ptr(), result_type, "_result_"));
192
0
        auto output_table_schema = JniDataBridge::parse_table_schema(&output_block);
193
0
        std::string output_nullable = result_type->is_nullable() ? "true" : "false";
194
0
        std::map<String, String> output_params = {{"is_nullable", output_nullable},
195
0
                                                  {"required_fields", output_table_schema.first},
196
0
                                                  {"columns_types", output_table_schema.second}};
197
198
0
        Jni::LocalObject output_map;
199
0
        RETURN_IF_ERROR(Jni::Util::convert_to_java_map(env, output_params, &output_map));
200
0
        long output_address;
201
202
0
        RETURN_IF_ERROR(executor_obj.call_long_method(env, executor_get_value_id)
203
0
                                .with_arg(cast_set<jlong>(place))
204
0
                                .with_arg(output_map)
205
0
                                .call(&output_address));
206
207
0
        return JniDataBridge::fill_block(&output_block, {0}, output_address);
208
0
    }
209
210
private:
211
0
    Status register_func_id(JNIEnv* env) {
212
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "<init>", UDAF_EXECUTOR_CTOR_SIGNATURE,
213
0
                                               &executor_ctor_id));
214
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "reset", UDAF_EXECUTOR_RESET_SIGNATURE,
215
0
                                               &executor_reset_id));
216
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "close", UDAF_EXECUTOR_CLOSE_SIGNATURE,
217
0
                                               &executor_close_id));
218
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "merge", UDAF_EXECUTOR_MERGE_SIGNATURE,
219
0
                                               &executor_merge_id));
220
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE,
221
0
                                               &executor_serialize_id));
222
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "getValue", UDAF_EXECUTOR_GET_SIGNATURE,
223
0
                                               &executor_get_value_id));
224
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE,
225
0
                                               &executor_destroy_id));
226
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "addBatch", UDAF_EXECUTOR_ADD_SIGNATURE,
227
0
                                               &executor_add_batch_id));
228
229
0
        return Status::OK();
230
0
    }
231
232
private:
233
    // TODO: too many variables are hold, it's causing a lot of memory waste
234
    // it's time to refactor it.
235
    Jni::GlobalClass executor_cl;
236
    Jni::GlobalObject executor_obj;
237
238
    Jni::MethodId executor_ctor_id;
239
    Jni::MethodId executor_add_batch_id;
240
    Jni::MethodId executor_merge_id;
241
    Jni::MethodId executor_serialize_id;
242
    Jni::MethodId executor_get_value_id;
243
    Jni::MethodId executor_reset_id;
244
    Jni::MethodId executor_close_id;
245
    Jni::MethodId executor_destroy_id;
246
    int argument_size = 0;
247
    std::string serialize_data;
248
};
249
250
class AggregateJavaUdaf final
251
        : public IAggregateFunctionDataHelper<AggregateJavaUdafData, AggregateJavaUdaf>,
252
          VarargsExpression,
253
          NullableAggregateFunction {
254
public:
255
    ENABLE_FACTORY_CREATOR(AggregateJavaUdaf);
256
    AggregateJavaUdaf(const TFunction& fn, const DataTypes& argument_types_,
257
                      const DataTypePtr& return_type)
258
0
            : IAggregateFunctionDataHelper(argument_types_),
259
0
              _fn(fn),
260
0
              _return_type(return_type),
261
0
              _first_created(true),
262
0
              _exec_place(nullptr) {}
263
0
    ~AggregateJavaUdaf() override = default;
264
265
    static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_,
266
0
                                       const DataTypePtr& return_type) {
267
0
        return std::make_shared<AggregateJavaUdaf>(fn, argument_types_, return_type);
268
0
    }
269
    //Note: The condition is added because maybe the BE can't find java-udaf impl jar
270
    //So need to check as soon as possible, before call Data function
271
0
    Status check_udaf(const TFunction& fn) {
272
0
        auto function_cache = UserFunctionCache::instance();
273
        // get jar path if both file path location and checksum are null
274
0
        if (!fn.hdfs_location.empty() && !fn.checksum.empty()) {
275
0
            return function_cache->get_jarpath(fn.id, fn.hdfs_location, fn.checksum,
276
0
                                               &_local_location);
277
0
        } else {
278
0
            return Status::OK();
279
0
        }
280
0
    }
281
282
0
    void create(AggregateDataPtr __restrict place) const override {
283
0
        new (place) Data(argument_types.size());
284
0
        if (_first_created) {
285
0
            Status status = this->data(place).init_udaf(_fn, _local_location);
286
0
            _first_created = false;
287
0
            _exec_place = place;
288
0
            if (UNLIKELY(!status.ok())) {
289
0
                static_cast<void>(this->data(place).destroy());
290
0
                this->data(place).~Data();
291
0
                throw doris::Exception(ErrorCode::INTERNAL_ERROR, status.to_string());
292
0
            }
293
0
        }
294
0
    }
295
296
    // To avoid multiple times JNI call, Here will destroy all data at once
297
0
    void destroy(AggregateDataPtr __restrict place) const noexcept override {
298
0
        if (place == _exec_place) {
299
0
            Status status = Status::OK();
300
0
            status = this->data(_exec_place).destroy();
301
0
            status = this->data(_exec_place).close_and_delete_object();
302
0
            _first_created = true;
303
0
            if (UNLIKELY(!status.ok())) {
304
0
                LOG(WARNING) << "Failed to destroy function: " << status.to_string();
305
0
            }
306
0
        }
307
0
        this->data(place).~Data();
308
0
    }
309
310
0
    String get_name() const override { return _fn.name.function_name; }
311
312
0
    DataTypePtr get_return_type() const override { return _return_type; }
313
314
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
315
0
             Arena&) const override {
316
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
317
0
        Status st = this->data(_exec_place)
318
0
                            .add(places_address, true, columns, row_num, row_num + 1,
319
0
                                 argument_types, 0);
320
0
        if (UNLIKELY(!st.ok())) {
321
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
322
0
        }
323
0
    }
324
325
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
326
0
                   const IColumn** columns, Arena&, bool /*agg_many*/) const override {
327
0
        int64_t places_address = reinterpret_cast<int64_t>(places);
328
0
        Status st = this->data(_exec_place)
329
0
                            .add(places_address, false, columns, 0, batch_size, argument_types,
330
0
                                 place_offset);
331
0
        if (UNLIKELY(!st.ok())) {
332
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
333
0
        }
334
0
    }
335
336
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
337
0
                                Arena&) const override {
338
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
339
0
        Status st = this->data(_exec_place)
340
0
                            .add(places_address, true, columns, 0, batch_size, argument_types, 0);
341
0
        if (UNLIKELY(!st.ok())) {
342
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
343
0
        }
344
0
    }
345
346
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
347
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
348
                                Arena&, UInt8* current_window_empty,
349
0
                                UInt8* current_window_has_inited) const override {
350
0
        frame_start = std::max<int64_t>(frame_start, partition_start);
351
0
        frame_end = std::min<int64_t>(frame_end, partition_end);
352
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
353
0
        Status st = this->data(_exec_place)
354
0
                            .add(places_address, true, columns, frame_start, frame_end,
355
0
                                 argument_types, 0);
356
0
        if (frame_start >= frame_end) {
357
0
            if (!*current_window_has_inited) {
358
0
                *current_window_empty = true;
359
0
            }
360
0
        } else {
361
0
            *current_window_empty = false;
362
0
            *current_window_has_inited = true;
363
0
        }
364
0
        if (UNLIKELY(!st.ok())) {
365
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
366
0
        }
367
0
    }
368
369
0
    void reset(AggregateDataPtr place) const override {
370
0
        Status st = this->data(_exec_place).reset(reinterpret_cast<int64_t>(place));
371
0
        if (UNLIKELY(!st.ok())) {
372
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
373
0
        }
374
0
    }
375
376
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
377
0
               Arena&) const override {
378
0
        Status st =
379
0
                this->data(_exec_place).merge(this->data(rhs), reinterpret_cast<int64_t>(place));
380
0
        if (UNLIKELY(!st.ok())) {
381
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
382
0
        }
383
0
    }
384
385
0
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
386
0
        Status st = this->data(_exec_place).write(buf, reinterpret_cast<int64_t>(place));
387
0
        if (UNLIKELY(!st.ok())) {
388
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
389
0
        }
390
0
    }
391
392
    // during merge-finalized phase, for deserialize and merge firstly,
393
    // will call create --- deserialize --- merge --- destory for each rows ,
394
    // so need doing new (place), to create Data and read to buf, then call merge ,
395
    // and during destory about deserialize, because haven't done init_udaf,
396
    // so it's can't call ~Data, only to change _destory_deserialize flag.
397
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
398
0
                     Arena&) const override {
399
0
        this->data(place).read(buf);
400
0
    }
401
402
0
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
403
0
        Status st = this->data(_exec_place).get(to, _return_type, reinterpret_cast<int64_t>(place));
404
0
        if (UNLIKELY(!st.ok())) {
405
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
406
0
        }
407
0
    }
408
409
private:
410
    TFunction _fn;
411
    DataTypePtr _return_type;
412
    mutable bool _first_created;
413
    mutable AggregateDataPtr _exec_place;
414
    std::string _local_location;
415
};
416
417
} // namespace doris