Coverage Report

Created: 2026-04-20 14:38

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_java_udaf.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <jni.h>
21
#include <unistd.h>
22
23
#include <cstdint>
24
#include <memory>
25
26
#include "absl/strings/substitute.h"
27
#include "common/cast_set.h"
28
#include "common/compiler_util.h"
29
#include "common/exception.h"
30
#include "common/logging.h"
31
#include "common/status.h"
32
#include "core/column/column_array.h"
33
#include "core/column/column_map.h"
34
#include "core/column/column_string.h"
35
#include "core/field.h"
36
#include "core/string_ref.h"
37
#include "core/types.h"
38
#include "exprs/aggregate/aggregate_function.h"
39
#include "format/jni/jni_data_bridge.h"
40
#include "runtime/user_function_cache.h"
41
#include "util/jni-util.h"
42
43
namespace doris {
44
45
// JNI class name and method descriptors of the Java-side UDAF executor.
// Each signature uses the JNI "(argument-types)return-type" descriptor format, see
// https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
//
// They are `constexpr` (which makes the pointer const and gives it internal linkage) because
// this is a header: plain mutable `const char*` globals defined here would be emitted with
// external linkage in every including translation unit and fail to link with duplicate symbols.
constexpr const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor";
constexpr const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V";
constexpr const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V";
constexpr const char* UDAF_EXECUTOR_DESTROY_SIGNATURE = "()V";
constexpr const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZIIJILjava/util/Map;)V";
constexpr const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
constexpr const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
constexpr const char* UDAF_EXECUTOR_GET_SIGNATURE = "(JLjava/util/Map;)J";
constexpr const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
56
57
struct AggregateJavaUdafData {
58
public:
59
0
    AggregateJavaUdafData() = default;
60
0
    AggregateJavaUdafData(int64_t num_args) { cast_set(argument_size, num_args); }
61
62
0
    ~AggregateJavaUdafData() = default;
63
64
0
    Status close_and_delete_object() {
65
0
        JNIEnv* env = nullptr;
66
67
0
        RETURN_IF_ERROR(Jni::Env::Get(&env));
68
69
0
        auto st = executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_close_id)
70
0
                          .call();
71
0
        if (!st.ok()) {
72
0
            LOG(WARNING) << "Failed to close JAVA UDAF: " << st.to_string();
73
0
            return st;
74
0
        }
75
0
        return Status::OK();
76
0
    }
77
78
0
    Status init_udaf(const TFunction& fn, const std::string& local_location) {
79
0
        JNIEnv* env = nullptr;
80
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf init_udaf function");
81
0
        RETURN_IF_ERROR(Jni::Util::find_class(env, UDAF_EXECUTOR_CLASS, &executor_cl));
82
0
        RETURN_NOT_OK_STATUS_WITH_WARN(register_func_id(env),
83
0
                                       "Java-Udaf register_func_id function");
84
85
0
        TJavaUdfExecutorCtorParams ctor_params;
86
0
        ctor_params.__set_fn(fn);
87
0
        if (!fn.hdfs_location.empty() && !fn.checksum.empty()) {
88
0
            ctor_params.__set_location(local_location);
89
0
        }
90
91
0
        Jni::LocalArray ctor_params_bytes;
92
0
        RETURN_IF_ERROR(Jni::Util::SerializeThriftMsg(env, &ctor_params, &ctor_params_bytes));
93
0
        RETURN_IF_ERROR(executor_cl.new_object(env, executor_ctor_id)
94
0
                                .with_arg(ctor_params_bytes)
95
0
                                .call(&executor_obj));
96
0
        return Status::OK();
97
0
    }
98
99
    Status add(int64_t places_address, bool is_single_place, const IColumn** columns,
100
               int64_t row_num_start, int64_t row_num_end, const DataTypes& argument_types,
101
0
               int64_t place_offset) {
102
0
        JNIEnv* env = nullptr;
103
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf add function");
104
105
0
        Block input_block;
106
0
        for (size_t i = 0; i < argument_size; ++i) {
107
0
            input_block.insert(ColumnWithTypeAndName(columns[i]->get_ptr(), argument_types[i],
108
0
                                                     std::to_string(i)));
109
0
        }
110
0
        std::unique_ptr<long[]> input_table;
111
0
        RETURN_IF_ERROR(JniDataBridge::to_java_table(&input_block, input_table));
112
0
        auto input_table_schema = JniDataBridge::parse_table_schema(&input_block);
113
0
        std::map<String, String> input_params = {
114
0
                {"meta_address", std::to_string((long)input_table.get())},
115
0
                {"required_fields", input_table_schema.first},
116
0
                {"columns_types", input_table_schema.second}};
117
118
0
        Jni::LocalObject input_map;
119
0
        RETURN_IF_ERROR(Jni::Util::convert_to_java_map(env, input_params, &input_map));
120
        // invoke add batch
121
        // Keep consistent with the function signature of executor_add_batch_id.
122
123
0
        return executor_obj.call_void_method(env, executor_add_batch_id)
124
0
                .with_arg((jboolean)is_single_place)
125
0
                .with_arg(cast_set<jint>(row_num_start))
126
0
                .with_arg(cast_set<jint>(row_num_end))
127
0
                .with_arg(cast_set<jlong>(places_address))
128
0
                .with_arg(cast_set<jint>(place_offset))
129
0
                .with_arg(input_map)
130
0
                .call();
131
0
    }
132
133
0
    Status merge(const AggregateJavaUdafData& rhs, int64_t place) {
134
0
        JNIEnv* env = nullptr;
135
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf merge function");
136
0
        serialize_data = rhs.serialize_data;
137
0
        Jni::LocalArray byte_arr;
138
0
        RETURN_IF_ERROR(Jni::Util::WriteBufferToByteArray(env, (jbyte*)serialize_data.data(),
139
0
                                                          cast_set<jsize>(serialize_data.length()),
140
0
                                                          &byte_arr));
141
142
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_merge_id)
143
0
                .with_arg((jlong)place)
144
0
                .with_arg(byte_arr)
145
0
                .call();
146
0
    }
147
148
0
    Status write(BufferWritable& buf, int64_t place) {
149
0
        JNIEnv* env = nullptr;
150
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf write function");
151
        // TODO: Here get a byte[] from FE serialize, and then allocate the same length bytes to
152
        // save it in BE, Because i'm not sure there is a way to use the byte[] not allocate again.
153
0
        Jni::LocalArray arr;
154
0
        RETURN_IF_ERROR(
155
0
                executor_obj.call_nonvirtual_object_method(env, executor_cl, executor_serialize_id)
156
0
                        .with_arg((jlong)place)
157
0
                        .call(&arr));
158
159
0
        jsize len = 0;
160
0
        RETURN_IF_ERROR(arr.get_length(env, &len));
161
0
        serialize_data.resize(len);
162
0
        RETURN_IF_ERROR(arr.get_byte_elements(env, 0, len,
163
0
                                              reinterpret_cast<jbyte*>(serialize_data.data())));
164
0
        buf.write_binary(serialize_data);
165
0
        return Status::OK();
166
0
    }
167
168
0
    Status reset(int64_t place) {
169
0
        JNIEnv* env = nullptr;
170
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf reset function");
171
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_reset_id)
172
0
                .with_arg(cast_set<jlong>(place))
173
0
                .call();
174
0
    }
175
176
0
    void read(BufferReadable& buf) { buf.read_binary(serialize_data); }
177
178
0
    Status destroy() {
179
0
        JNIEnv* env = nullptr;
180
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf destroy function");
181
0
        return executor_obj.call_nonvirtual_void_method(env, executor_cl, executor_destroy_id)
182
0
                .call();
183
0
    }
184
185
0
    Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const {
186
0
        JNIEnv* env = nullptr;
187
0
        RETURN_NOT_OK_STATUS_WITH_WARN(Jni::Env::Get(&env), "Java-Udaf get value function");
188
189
0
        Block output_block;
190
0
        output_block.insert(ColumnWithTypeAndName(to.get_ptr(), result_type, "_result_"));
191
0
        auto output_table_schema = JniDataBridge::parse_table_schema(&output_block);
192
0
        std::string output_nullable = result_type->is_nullable() ? "true" : "false";
193
0
        std::map<String, String> output_params = {{"is_nullable", output_nullable},
194
0
                                                  {"required_fields", output_table_schema.first},
195
0
                                                  {"columns_types", output_table_schema.second}};
196
197
0
        Jni::LocalObject output_map;
198
0
        RETURN_IF_ERROR(Jni::Util::convert_to_java_map(env, output_params, &output_map));
199
0
        long output_address;
200
201
0
        RETURN_IF_ERROR(executor_obj.call_long_method(env, executor_get_value_id)
202
0
                                .with_arg(cast_set<jlong>(place))
203
0
                                .with_arg(output_map)
204
0
                                .call(&output_address));
205
206
0
        return JniDataBridge::fill_block(&output_block, {0}, output_address);
207
0
    }
208
209
private:
210
0
    Status register_func_id(JNIEnv* env) {
211
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "<init>", UDAF_EXECUTOR_CTOR_SIGNATURE,
212
0
                                               &executor_ctor_id));
213
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "reset", UDAF_EXECUTOR_RESET_SIGNATURE,
214
0
                                               &executor_reset_id));
215
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "close", UDAF_EXECUTOR_CLOSE_SIGNATURE,
216
0
                                               &executor_close_id));
217
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "merge", UDAF_EXECUTOR_MERGE_SIGNATURE,
218
0
                                               &executor_merge_id));
219
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE,
220
0
                                               &executor_serialize_id));
221
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "getValue", UDAF_EXECUTOR_GET_SIGNATURE,
222
0
                                               &executor_get_value_id));
223
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE,
224
0
                                               &executor_destroy_id));
225
0
        RETURN_IF_ERROR(executor_cl.get_method(env, "addBatch", UDAF_EXECUTOR_ADD_SIGNATURE,
226
0
                                               &executor_add_batch_id));
227
228
0
        return Status::OK();
229
0
    }
230
231
private:
232
    // TODO: too many variables are hold, it's causing a lot of memory waste
233
    // it's time to refactor it.
234
    Jni::GlobalClass executor_cl;
235
    Jni::GlobalObject executor_obj;
236
237
    Jni::MethodId executor_ctor_id;
238
    Jni::MethodId executor_add_batch_id;
239
    Jni::MethodId executor_merge_id;
240
    Jni::MethodId executor_serialize_id;
241
    Jni::MethodId executor_get_value_id;
242
    Jni::MethodId executor_reset_id;
243
    Jni::MethodId executor_close_id;
244
    Jni::MethodId executor_destroy_id;
245
    int argument_size = 0;
246
    std::string serialize_data;
247
};
248
249
class AggregateJavaUdaf final
250
        : public IAggregateFunctionDataHelper<AggregateJavaUdafData, AggregateJavaUdaf>,
251
          VarargsExpression,
252
          NullableAggregateFunction {
253
public:
254
    ENABLE_FACTORY_CREATOR(AggregateJavaUdaf);
255
    AggregateJavaUdaf(const TFunction& fn, const DataTypes& argument_types_,
256
                      const DataTypePtr& return_type)
257
0
            : IAggregateFunctionDataHelper(argument_types_),
258
0
              _fn(fn),
259
0
              _return_type(return_type),
260
0
              _first_created(true),
261
0
              _exec_place(nullptr) {}
262
0
    ~AggregateJavaUdaf() override = default;
263
264
    static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_,
265
0
                                       const DataTypePtr& return_type) {
266
0
        return std::make_shared<AggregateJavaUdaf>(fn, argument_types_, return_type);
267
0
    }
268
    //Note: The condition is added because maybe the BE can't find java-udaf impl jar
269
    //So need to check as soon as possible, before call Data function
270
0
    Status check_udaf(const TFunction& fn) {
271
0
        auto function_cache = UserFunctionCache::instance();
272
        // get jar path if both file path location and checksum are null
273
0
        if (!fn.hdfs_location.empty() && !fn.checksum.empty()) {
274
0
            return function_cache->get_jarpath(fn.id, fn.hdfs_location, fn.checksum,
275
0
                                               &_local_location);
276
0
        } else {
277
0
            return Status::OK();
278
0
        }
279
0
    }
280
281
0
    void create(AggregateDataPtr __restrict place) const override {
282
0
        new (place) Data(argument_types.size());
283
0
        if (_first_created) {
284
0
            Status status = this->data(place).init_udaf(_fn, _local_location);
285
0
            _first_created = false;
286
0
            _exec_place = place;
287
0
            if (UNLIKELY(!status.ok())) {
288
0
                static_cast<void>(this->data(place).destroy());
289
0
                this->data(place).~Data();
290
0
                throw doris::Exception(ErrorCode::INTERNAL_ERROR, status.to_string());
291
0
            }
292
0
        }
293
0
    }
294
295
    // To avoid multiple times JNI call, Here will destroy all data at once
296
0
    void destroy(AggregateDataPtr __restrict place) const noexcept override {
297
0
        if (place == _exec_place) {
298
0
            Status status = Status::OK();
299
0
            status = this->data(_exec_place).destroy();
300
0
            status = this->data(_exec_place).close_and_delete_object();
301
0
            _first_created = true;
302
0
            if (UNLIKELY(!status.ok())) {
303
0
                LOG(WARNING) << "Failed to destroy function: " << status.to_string();
304
0
            }
305
0
        }
306
0
        this->data(place).~Data();
307
0
    }
308
309
0
    String get_name() const override { return _fn.name.function_name; }
310
311
0
    DataTypePtr get_return_type() const override { return _return_type; }
312
313
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
314
0
             Arena&) const override {
315
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
316
0
        Status st = this->data(_exec_place)
317
0
                            .add(places_address, true, columns, row_num, row_num + 1,
318
0
                                 argument_types, 0);
319
0
        if (UNLIKELY(!st.ok())) {
320
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
321
0
        }
322
0
    }
323
324
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
325
0
                   const IColumn** columns, Arena&, bool /*agg_many*/) const override {
326
0
        int64_t places_address = reinterpret_cast<int64_t>(places);
327
0
        Status st = this->data(_exec_place)
328
0
                            .add(places_address, false, columns, 0, batch_size, argument_types,
329
0
                                 place_offset);
330
0
        if (UNLIKELY(!st.ok())) {
331
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
332
0
        }
333
0
    }
334
335
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
336
0
                                Arena&) const override {
337
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
338
0
        Status st = this->data(_exec_place)
339
0
                            .add(places_address, true, columns, 0, batch_size, argument_types, 0);
340
0
        if (UNLIKELY(!st.ok())) {
341
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
342
0
        }
343
0
    }
344
345
    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
346
                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
347
                                Arena&, UInt8* current_window_empty,
348
0
                                UInt8* current_window_has_inited) const override {
349
0
        frame_start = std::max<int64_t>(frame_start, partition_start);
350
0
        frame_end = std::min<int64_t>(frame_end, partition_end);
351
0
        int64_t places_address = reinterpret_cast<int64_t>(place);
352
0
        Status st = this->data(_exec_place)
353
0
                            .add(places_address, true, columns, frame_start, frame_end,
354
0
                                 argument_types, 0);
355
0
        if (frame_start >= frame_end) {
356
0
            if (!*current_window_has_inited) {
357
0
                *current_window_empty = true;
358
0
            }
359
0
        } else {
360
0
            *current_window_empty = false;
361
0
            *current_window_has_inited = true;
362
0
        }
363
0
        if (UNLIKELY(!st.ok())) {
364
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
365
0
        }
366
0
    }
367
368
0
    void reset(AggregateDataPtr place) const override {
369
0
        Status st = this->data(_exec_place).reset(reinterpret_cast<int64_t>(place));
370
0
        if (UNLIKELY(!st.ok())) {
371
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
372
0
        }
373
0
    }
374
375
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
376
0
               Arena&) const override {
377
0
        Status st =
378
0
                this->data(_exec_place).merge(this->data(rhs), reinterpret_cast<int64_t>(place));
379
0
        if (UNLIKELY(!st.ok())) {
380
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
381
0
        }
382
0
    }
383
384
0
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
385
0
        Status st = this->data(_exec_place).write(buf, reinterpret_cast<int64_t>(place));
386
0
        if (UNLIKELY(!st.ok())) {
387
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
388
0
        }
389
0
    }
390
391
    // during merge-finalized phase, for deserialize and merge firstly,
392
    // will call create --- deserialize --- merge --- destory for each rows ,
393
    // so need doing new (place), to create Data and read to buf, then call merge ,
394
    // and during destory about deserialize, because haven't done init_udaf,
395
    // so it's can't call ~Data, only to change _destory_deserialize flag.
396
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
397
0
                     Arena&) const override {
398
0
        this->data(place).read(buf);
399
0
    }
400
401
0
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
402
0
        Status st = this->data(_exec_place).get(to, _return_type, reinterpret_cast<int64_t>(place));
403
0
        if (UNLIKELY(!st.ok())) {
404
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
405
0
        }
406
0
    }
407
408
private:
409
    TFunction _fn;
410
    DataTypePtr _return_type;
411
    mutable bool _first_created;
412
    mutable AggregateDataPtr _exec_place;
413
    std::string _local_location;
414
};
415
416
} // namespace doris