Coverage Report

Created: 2026-03-14 13:33

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_rpc.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <gen_cpp/Exprs_types.h>
21
#include <gen_cpp/function_service.pb.h>
22
23
#include <cstdint>
24
#include <memory>
25
26
#include "common/status.h"
27
#include "core/block/block.h"
28
#include "core/block/column_numbers.h"
29
#include "core/column/column_string.h"
30
#include "core/column/column_vector.h"
31
#include "core/data_type/data_type_string.h"
32
#include "core/field.h"
33
#include "core/string_ref.h"
34
#include "core/types.h"
35
#include "exprs/aggregate/aggregate_function.h"
36
#include "exprs/function/function_rpc.h"
37
#include "json2pb/json_to_pb.h"
38
#include "json2pb/pb_to_json.h"
39
#include "runtime/exec_env.h"
40
#include "runtime/user_function_cache.h"
41
#include "util/brpc_client_cache.h"
42
#include "util/io_helper.h"
43
#include "util/jni-util.h"
44
namespace doris {
45
#include "common/compile_check_avoid_begin.h"
46
// The RPC function is deprecated, so compile checks are disabled for this file.
47
48
0
// Marker written by AggregateRpcUdafData::serialize() instead of the protobuf payload
// when the state is in error; deserialize() recognises it and flags the receiving
// state as errored. Kept as a macro because other translation units may use the name.
#define error_default_str "#$@"

// Number of pending requests queued by AggregateRpcUdafData::buffer_add() that triggers
// a flush of the merged batch to the RPC server. `inline` gives the header a single
// definition shared by every translation unit that includes it.
inline constexpr int64_t max_buffered_rows = 4096;
51
52
struct AggregateRpcUdafData {
53
private:
54
    std::string _update_fn;
55
    std::string _merge_fn;
56
    std::string _server_addr;
57
    std::string _finalize_fn;
58
    bool _saved_last_result;
59
    std::shared_ptr<PFunctionService_Stub> _client;
60
    PFunctionCallResponse _res;
61
    std::vector<PFunctionCallRequest> _buffer_request;
62
    bool _error;
63
64
public:
65
0
    AggregateRpcUdafData() = default;
66
0
    AggregateRpcUdafData(int64_t num_args) { set_last_result(false); }
67
68
0
    bool has_last_result() { return _saved_last_result == true; }
69
70
0
    void set_last_result(bool flag) { _saved_last_result = flag; }
71
72
0
    ~AggregateRpcUdafData() {}
73
74
0
    void set_error(bool flag) { _error = flag; }
75
76
0
    bool has_error() { return _error == true; }
77
78
0
    Status merge(AggregateRpcUdafData& rhs) {
79
0
        static_cast<void>(send_buffer_to_rpc_server());
80
0
        if (has_last_result()) {
81
0
            PFunctionCallRequest request;
82
0
            PFunctionCallResponse response;
83
0
            brpc::Controller cntl;
84
0
            PFunctionCallResponse current_res = rhs.get_result();
85
0
            request.set_function_name(_merge_fn);
86
            //last result
87
0
            PValues* arg = request.add_args();
88
0
            arg->CopyFrom(_res.result(0));
89
0
            arg = request.add_args();
90
            //current result
91
0
            arg->CopyFrom(current_res.result(0));
92
            //send to rpc server  that impl the merge op, the will save the result
93
0
            RETURN_IF_ERROR(send_rpc_request(cntl, request, response));
94
0
            _res = response;
95
0
        } else {
96
0
            _res = rhs.get_result();
97
0
            set_last_result(true);
98
0
        }
99
0
        return Status::OK();
100
0
    }
101
102
0
    Status init(const TFunction& fn) {
103
0
        _update_fn = fn.aggregate_fn.update_fn_symbol;
104
0
        _merge_fn = fn.aggregate_fn.merge_fn_symbol;
105
0
        _server_addr = fn.hdfs_location;
106
0
        _finalize_fn = fn.aggregate_fn.finalize_fn_symbol;
107
0
        set_error(false);
108
0
        _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server_addr);
109
0
        if (_client == nullptr) {
110
0
            std::string err_msg = "init rpc error, addr:" + _server_addr;
111
0
            LOG(ERROR) << err_msg;
112
0
            set_error(true);
113
0
            return Status::InternalError(err_msg);
114
0
        }
115
0
        return Status::OK();
116
0
    }
117
118
    Status send_rpc_request(brpc::Controller& cntl, PFunctionCallRequest& request,
119
0
                            PFunctionCallResponse& response) {
120
0
        _client->fn_call(&cntl, &request, &response, nullptr);
121
0
        if (cntl.Failed()) {
122
0
            set_error(true);
123
0
            std::stringstream err_msg;
124
0
            err_msg << " call rpc function failed";
125
0
            err_msg << " _server_addr:" << _server_addr;
126
0
            err_msg << " function_name:" << request.function_name();
127
0
            err_msg << " err:" << cntl.ErrorText();
128
0
            LOG(ERROR) << err_msg.str();
129
0
            return Status::InternalError(err_msg.str());
130
0
        }
131
0
        if (!response.has_status() || response.result_size() == 0) {
132
0
            set_error(true);
133
0
            std::stringstream err_msg;
134
0
            err_msg << " call rpc function failed, status or result is not set";
135
0
            err_msg << " _server_addr:" << _server_addr;
136
0
            err_msg << " function_name:" << request.function_name();
137
0
            LOG(ERROR) << err_msg.str();
138
0
            return Status::InternalError(err_msg.str());
139
0
        }
140
0
        if (response.status().status_code() != 0) {
141
0
            set_error(true);
142
0
            std::stringstream err_msg;
143
0
            err_msg << " call rpc function failed";
144
0
            err_msg << " _server_addr:" << _server_addr;
145
0
            err_msg << " function_name:" << request.function_name();
146
0
            err_msg << " err:" << response.status().DebugString();
147
0
            LOG(ERROR) << err_msg.str();
148
0
            return Status::InternalError(err_msg.str());
149
0
        }
150
0
        return Status::OK();
151
0
    }
152
153
    Status gen_request_data(PFunctionCallRequest& request, const IColumn** columns, int start,
154
0
                            int end, const DataTypes& argument_types) {
155
0
        for (int i = 0; i < argument_types.size(); i++) {
156
0
            PValues* arg = request.add_args();
157
0
            auto data_type = argument_types[i];
158
0
            if (auto st = data_type->get_serde()->write_column_to_pb(*columns[i], *arg, start, end);
159
0
                !st.ok()) {
160
0
                return st;
161
0
            }
162
0
        }
163
0
        return Status::OK();
164
0
    }
165
166
#define ADD_VALUE(TYPEVALUE)                                           \
167
0
    if (_buffer_request[j].args(i).TYPEVALUE##_##size() > 0) {         \
168
0
        arg->add_##TYPEVALUE(_buffer_request[j].args(i).TYPEVALUE(0)); \
169
0
    }
170
171
0
    PFunctionCallRequest merge_buffer_request(PFunctionCallRequest& request) {
172
0
        int args_size = _buffer_request[0].args_size();
173
0
        request.set_function_name(_update_fn);
174
0
        for (int i = 0; i < args_size; i++) {
175
0
            PValues* arg = request.add_args();
176
0
            arg->mutable_type()->CopyFrom(_buffer_request[0].args(i).type());
177
0
            for (int j = 0; j < _buffer_request.size(); j++) {
178
0
                ADD_VALUE(double_value);
179
0
                ADD_VALUE(float_value);
180
0
                ADD_VALUE(int32_value);
181
0
                ADD_VALUE(int64_value);
182
0
                ADD_VALUE(uint32_value);
183
0
                ADD_VALUE(uint64_value);
184
0
                ADD_VALUE(bool_value);
185
0
                ADD_VALUE(string_value);
186
0
                ADD_VALUE(bytes_value);
187
0
            }
188
0
        }
189
0
        return request;
190
0
    }
191
192
    // called in group agg op
193
    Status buffer_add(const IColumn** columns, int start, int end,
194
0
                      const DataTypes& argument_types) {
195
0
        PFunctionCallRequest request;
196
0
        static_cast<void>(gen_request_data(request, columns, start, end, argument_types));
197
0
        _buffer_request.push_back(request);
198
0
        if (_buffer_request.size() >= max_buffered_rows) {
199
0
            static_cast<void>(send_buffer_to_rpc_server());
200
0
        }
201
0
        return Status::OK();
202
0
    }
203
204
    //clear buffer request
205
0
    Status send_buffer_to_rpc_server() {
206
0
        if (_buffer_request.size() > 0) {
207
0
            PFunctionCallRequest request;
208
0
            PFunctionCallResponse response;
209
0
            brpc::Controller cntl;
210
0
            merge_buffer_request(request);
211
0
            if (has_last_result()) {
212
0
                request.mutable_context()
213
0
                        ->mutable_function_context()
214
0
                        ->mutable_args_data()
215
0
                        ->CopyFrom(_res.result());
216
0
            }
217
0
            RETURN_IF_ERROR(send_rpc_request(cntl, request, response));
218
0
            _res = response;
219
0
            set_last_result(true);
220
0
            _buffer_request.clear();
221
0
        }
222
0
        return Status::OK();
223
0
    }
224
225
0
    Status add(const IColumn** columns, int start, int end, const DataTypes& argument_types) {
226
0
        PFunctionCallRequest request;
227
0
        PFunctionCallResponse response;
228
0
        brpc::Controller cntl;
229
0
        request.set_function_name(_update_fn);
230
0
        static_cast<void>(gen_request_data(request, columns, start, end, argument_types));
231
0
        if (has_last_result()) {
232
0
            request.mutable_context()->mutable_function_context()->mutable_args_data()->CopyFrom(
233
0
                    _res.result());
234
0
        }
235
0
        RETURN_IF_ERROR(send_rpc_request(cntl, request, response));
236
0
        _res = response;
237
0
        set_last_result(true);
238
0
        return Status::OK();
239
0
    }
240
241
0
    void serialize(BufferWritable& buf) {
242
0
        static_cast<void>(send_buffer_to_rpc_server());
243
0
        std::string serialize_data = error_default_str;
244
0
        if (!has_error()) {
245
0
            serialize_data = _res.SerializeAsString();
246
0
        } else {
247
0
            LOG(ERROR) << "serialize empty buf";
248
0
        }
249
0
        buf.write_binary(serialize_data);
250
0
    }
251
252
0
    void deserialize(BufferReadable& buf) {
253
0
        static_cast<void>(send_buffer_to_rpc_server());
254
0
        std::string serialize_data;
255
0
        buf.read_binary(serialize_data);
256
0
        if (error_default_str != serialize_data) {
257
0
            _res.ParseFromString(serialize_data);
258
0
            set_last_result(true);
259
0
        } else {
260
0
            LOG(ERROR) << "deserialize empty buf";
261
0
            set_error(true);
262
0
        }
263
0
    }
264
265
#define GETDATA(LOCATTYPE, TYPEVALUE)                                                          \
266
0
    if (response.result_size() > 0 && response.result(0).TYPEVALUE##_##value_size() > 0) {     \
267
0
        LOCATTYPE ret = response.result(0).TYPEVALUE##_##value(0);                             \
268
0
        to.insert_data((char*)&ret, 0);                                                        \
269
0
    } else {                                                                                   \
270
0
        LOG(ERROR) << "_server_addr:" << _server_addr << ",_finalize_fn:" << _finalize_fn      \
271
0
                   << ",msg: failed to get final result cause return type need " << #TYPEVALUE \
272
0
                   << "but result is empty";                                                   \
273
0
        to.insert_default();                                                                   \
274
0
    }
275
276
    //if any unexpected error happen will return NULL
277
0
    Status get(IColumn& to, const DataTypePtr& return_type) {
278
0
        if (has_error()) {
279
0
            to.insert_default();
280
0
            return Status::OK();
281
0
        }
282
0
        static_cast<void>(send_buffer_to_rpc_server());
283
0
        PFunctionCallRequest request;
284
0
        PFunctionCallResponse response;
285
0
        brpc::Controller cntl;
286
0
        request.set_function_name(_finalize_fn);
287
0
        request.mutable_context()->mutable_function_context()->mutable_args_data()->CopyFrom(
288
0
                _res.result());
289
0
        static_cast<void>(send_rpc_request(cntl, request, response));
290
0
        if (has_error()) {
291
0
            to.insert_default();
292
0
            return Status::OK();
293
0
        }
294
0
        DataTypePtr result_type = return_type;
295
0
        if (return_type->is_nullable()) {
296
0
            result_type =
297
0
                    reinterpret_cast<const DataTypeNullable*>(return_type.get())->get_nested_type();
298
0
        }
299
0
        switch (result_type->get_primitive_type()) {
300
0
        case TYPE_FLOAT:
301
0
            GETDATA(float, float);
302
0
            break;
303
0
        case TYPE_DOUBLE:
304
0
            GETDATA(double, double);
305
0
            break;
306
0
        case TYPE_INT:
307
0
            GETDATA(int32_t, int32);
308
0
            break;
309
0
        case TYPE_BIGINT:
310
0
            GETDATA(int64_t, int64);
311
0
            break;
312
0
        case TYPE_BOOLEAN:
313
0
            GETDATA(uint8_t, bool);
314
0
            break;
315
0
        case TYPE_STRING:
316
0
        case TYPE_CHAR:
317
0
        case TYPE_VARCHAR: {
318
0
            if (response.result_size() > 0 && response.result(0).string_value_size() > 0) {
319
0
                std::string ret = response.result(0).string_value(0);
320
0
                to.insert_data(ret.c_str(), ret.size());
321
0
            } else {
322
0
                LOG(ERROR) << "_server_addr:" << _server_addr << ",_finalize_fn:" << _finalize_fn
323
0
                           << ",msg: failed to get final result cause return type need string but "
324
0
                              "result is empty";
325
0
                to.insert_default();
326
0
            }
327
0
            break;
328
0
        }
329
0
        default:
330
0
            LOG(ERROR) << "failed to get result cause unkown return type";
331
0
            to.insert_default();
332
0
        }
333
0
        return Status::OK();
334
0
    }
335
336
0
    PFunctionCallResponse get_result() {
337
0
        static_cast<void>(send_buffer_to_rpc_server());
338
0
        return _res;
339
0
    }
340
};
341
342
class AggregateRpcUdaf final
343
        : public IAggregateFunctionDataHelper<AggregateRpcUdafData, AggregateRpcUdaf> {
344
public:
345
    AggregateRpcUdaf(const TFunction& fn, const DataTypes& argument_types_,
346
                     const DataTypePtr& return_type)
347
0
            : IAggregateFunctionDataHelper(argument_types_), _fn(fn), _return_type(return_type) {}
348
0
    ~AggregateRpcUdaf() = default;
349
350
    static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_,
351
0
                                       const DataTypePtr& return_type) {
352
0
        return std::make_shared<AggregateRpcUdaf>(fn, argument_types_, return_type);
353
0
    }
354
355
0
    void create(AggregateDataPtr __restrict place) const override {
356
0
        new (place) Data(argument_types.size());
357
0
        Status status = Status::OK();
358
0
        SAFE_CREATE(RETURN_IF_STATUS_ERROR(status, data(place).init(_fn)),
359
0
                    this->data(place).~Data());
360
0
    }
361
362
0
    String get_name() const override { return _fn.name.function_name; }
363
364
0
    DataTypePtr get_return_type() const override { return _return_type; }
365
366
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
367
0
             Arena&) const override {
368
0
        static_cast<void>(
369
0
                this->data(place).buffer_add(columns, row_num, row_num + 1, argument_types));
370
0
    }
371
372
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
373
0
                                Arena&) const override {
374
0
        static_cast<void>(this->data(place).add(columns, 0, batch_size, argument_types));
375
0
    }
376
377
0
    void reset(AggregateDataPtr place) const override {}
378
379
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
380
0
               Arena&) const override {
381
        // place is essentially an AggregateDataPtr, passed as a ConstAggregateDataPtr.
382
        // todo: rethink the merge method to determine whether const_cast is necessary.
383
0
        static_cast<void>(this->data(place).merge(this->data(const_cast<AggregateDataPtr>(rhs))));
384
0
    }
385
386
0
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
387
        // place is essentially an AggregateDataPtr, passed as a ConstAggregateDataPtr.
388
        // todo: rethink the serialize method to determine whether const_cast is necessary.
389
0
        this->data(const_cast<AggregateDataPtr&>(place)).serialize(buf);
390
0
    }
391
392
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
393
0
                     Arena&) const override {
394
0
        this->data(place).deserialize(buf);
395
0
    }
396
397
0
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
398
        // place is essentially an AggregateDataPtr, passed as a ConstAggregateDataPtr.
399
        // todo: rethink the get method to determine whether const_cast is necessary.
400
0
        static_cast<void>(this->data(const_cast<AggregateDataPtr>(place)).get(to, _return_type));
401
0
    }
402
403
private:
404
    TFunction _fn;
405
    DataTypePtr _return_type;
406
};
407
408
#include "common/compile_check_avoid_end.h"
409
} // namespace doris