Coverage Report

Created: 2026-04-16 20:39

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_rpc.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <gen_cpp/Exprs_types.h>
21
#include <gen_cpp/function_service.pb.h>
22
23
#include <cstdint>
24
#include <memory>
25
26
#include "common/status.h"
27
#include "core/block/block.h"
28
#include "core/block/column_numbers.h"
29
#include "core/column/column_string.h"
30
#include "core/column/column_vector.h"
31
#include "core/data_type/data_type_string.h"
32
#include "core/field.h"
33
#include "core/string_ref.h"
34
#include "core/types.h"
35
#include "exprs/aggregate/aggregate_function.h"
36
#include "exprs/function/function_rpc.h"
37
#include "json2pb/json_to_pb.h"
38
#include "json2pb/pb_to_json.h"
39
#include "runtime/exec_env.h"
40
#include "runtime/user_function_cache.h"
41
#include "util/brpc_client_cache.h"
42
#include "util/jni-util.h"
43
namespace doris {
44
#include "common/compile_check_avoid_begin.h"
45
// RPC functions are deprecated, so compile checks are intentionally disabled for this file.
46
47
0
// Marker payload that serialize() writes instead of the protobuf state when the aggregate
// has failed; deserialize() compares against it and marks the receiving state as errored.
#define error_default_str "#$@"

// Number of buffered single-row requests that triggers a flush to the RPC server.
constexpr int64_t max_buffered_rows {4096};
50
51
struct AggregateRpcUdafData {
52
private:
53
    std::string _update_fn;
54
    std::string _merge_fn;
55
    std::string _server_addr;
56
    std::string _finalize_fn;
57
    bool _saved_last_result;
58
    std::shared_ptr<PFunctionService_Stub> _client;
59
    PFunctionCallResponse _res;
60
    std::vector<PFunctionCallRequest> _buffer_request;
61
    bool _error;
62
63
public:
64
0
    AggregateRpcUdafData() = default;
65
0
    AggregateRpcUdafData(int64_t num_args) { set_last_result(false); }
66
67
0
    bool has_last_result() { return _saved_last_result == true; }
68
69
0
    void set_last_result(bool flag) { _saved_last_result = flag; }
70
71
0
    ~AggregateRpcUdafData() {}
72
73
0
    void set_error(bool flag) { _error = flag; }
74
75
0
    bool has_error() { return _error == true; }
76
77
0
    Status merge(AggregateRpcUdafData& rhs) {
78
0
        static_cast<void>(send_buffer_to_rpc_server());
79
0
        if (has_last_result()) {
80
0
            PFunctionCallRequest request;
81
0
            PFunctionCallResponse response;
82
0
            brpc::Controller cntl;
83
0
            PFunctionCallResponse current_res = rhs.get_result();
84
0
            request.set_function_name(_merge_fn);
85
            //last result
86
0
            PValues* arg = request.add_args();
87
0
            arg->CopyFrom(_res.result(0));
88
0
            arg = request.add_args();
89
            //current result
90
0
            arg->CopyFrom(current_res.result(0));
91
            //send to rpc server  that impl the merge op, the will save the result
92
0
            RETURN_IF_ERROR(send_rpc_request(cntl, request, response));
93
0
            _res = response;
94
0
        } else {
95
0
            _res = rhs.get_result();
96
0
            set_last_result(true);
97
0
        }
98
0
        return Status::OK();
99
0
    }
100
101
0
    Status init(const TFunction& fn) {
102
0
        _update_fn = fn.aggregate_fn.update_fn_symbol;
103
0
        _merge_fn = fn.aggregate_fn.merge_fn_symbol;
104
0
        _server_addr = fn.hdfs_location;
105
0
        _finalize_fn = fn.aggregate_fn.finalize_fn_symbol;
106
0
        set_error(false);
107
0
        _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server_addr);
108
0
        if (_client == nullptr) {
109
0
            std::string err_msg = "init rpc error, addr:" + _server_addr;
110
0
            LOG(ERROR) << err_msg;
111
0
            set_error(true);
112
0
            return Status::InternalError(err_msg);
113
0
        }
114
0
        return Status::OK();
115
0
    }
116
117
    Status send_rpc_request(brpc::Controller& cntl, PFunctionCallRequest& request,
118
0
                            PFunctionCallResponse& response) {
119
0
        _client->fn_call(&cntl, &request, &response, nullptr);
120
0
        if (cntl.Failed()) {
121
0
            set_error(true);
122
0
            std::stringstream err_msg;
123
0
            err_msg << " call rpc function failed";
124
0
            err_msg << " _server_addr:" << _server_addr;
125
0
            err_msg << " function_name:" << request.function_name();
126
0
            err_msg << " err:" << cntl.ErrorText();
127
0
            LOG(ERROR) << err_msg.str();
128
0
            return Status::InternalError(err_msg.str());
129
0
        }
130
0
        if (!response.has_status() || response.result_size() == 0) {
131
0
            set_error(true);
132
0
            std::stringstream err_msg;
133
0
            err_msg << " call rpc function failed, status or result is not set";
134
0
            err_msg << " _server_addr:" << _server_addr;
135
0
            err_msg << " function_name:" << request.function_name();
136
0
            LOG(ERROR) << err_msg.str();
137
0
            return Status::InternalError(err_msg.str());
138
0
        }
139
0
        if (response.status().status_code() != 0) {
140
0
            set_error(true);
141
0
            std::stringstream err_msg;
142
0
            err_msg << " call rpc function failed";
143
0
            err_msg << " _server_addr:" << _server_addr;
144
0
            err_msg << " function_name:" << request.function_name();
145
0
            err_msg << " err:" << response.status().DebugString();
146
0
            LOG(ERROR) << err_msg.str();
147
0
            return Status::InternalError(err_msg.str());
148
0
        }
149
0
        return Status::OK();
150
0
    }
151
152
    Status gen_request_data(PFunctionCallRequest& request, const IColumn** columns, int start,
153
0
                            int end, const DataTypes& argument_types) {
154
0
        for (int i = 0; i < argument_types.size(); i++) {
155
0
            PValues* arg = request.add_args();
156
0
            auto data_type = argument_types[i];
157
0
            if (auto st = data_type->get_serde()->write_column_to_pb(*columns[i], *arg, start, end);
158
0
                !st.ok()) {
159
0
                return st;
160
0
            }
161
0
        }
162
0
        return Status::OK();
163
0
    }
164
165
#define ADD_VALUE(TYPEVALUE)                                           \
166
0
    if (_buffer_request[j].args(i).TYPEVALUE##_##size() > 0) {         \
167
0
        arg->add_##TYPEVALUE(_buffer_request[j].args(i).TYPEVALUE(0)); \
168
0
    }
169
170
0
    PFunctionCallRequest merge_buffer_request(PFunctionCallRequest& request) {
171
0
        int args_size = _buffer_request[0].args_size();
172
0
        request.set_function_name(_update_fn);
173
0
        for (int i = 0; i < args_size; i++) {
174
0
            PValues* arg = request.add_args();
175
0
            arg->mutable_type()->CopyFrom(_buffer_request[0].args(i).type());
176
0
            for (int j = 0; j < _buffer_request.size(); j++) {
177
0
                ADD_VALUE(double_value);
178
0
                ADD_VALUE(float_value);
179
0
                ADD_VALUE(int32_value);
180
0
                ADD_VALUE(int64_value);
181
0
                ADD_VALUE(uint32_value);
182
0
                ADD_VALUE(uint64_value);
183
0
                ADD_VALUE(bool_value);
184
0
                ADD_VALUE(string_value);
185
0
                ADD_VALUE(bytes_value);
186
0
            }
187
0
        }
188
0
        return request;
189
0
    }
190
191
    // called in group agg op
192
    Status buffer_add(const IColumn** columns, int start, int end,
193
0
                      const DataTypes& argument_types) {
194
0
        PFunctionCallRequest request;
195
0
        static_cast<void>(gen_request_data(request, columns, start, end, argument_types));
196
0
        _buffer_request.push_back(request);
197
0
        if (_buffer_request.size() >= max_buffered_rows) {
198
0
            static_cast<void>(send_buffer_to_rpc_server());
199
0
        }
200
0
        return Status::OK();
201
0
    }
202
203
    //clear buffer request
204
0
    Status send_buffer_to_rpc_server() {
205
0
        if (_buffer_request.size() > 0) {
206
0
            PFunctionCallRequest request;
207
0
            PFunctionCallResponse response;
208
0
            brpc::Controller cntl;
209
0
            merge_buffer_request(request);
210
0
            if (has_last_result()) {
211
0
                request.mutable_context()
212
0
                        ->mutable_function_context()
213
0
                        ->mutable_args_data()
214
0
                        ->CopyFrom(_res.result());
215
0
            }
216
0
            RETURN_IF_ERROR(send_rpc_request(cntl, request, response));
217
0
            _res = response;
218
0
            set_last_result(true);
219
0
            _buffer_request.clear();
220
0
        }
221
0
        return Status::OK();
222
0
    }
223
224
0
    Status add(const IColumn** columns, int start, int end, const DataTypes& argument_types) {
225
0
        PFunctionCallRequest request;
226
0
        PFunctionCallResponse response;
227
0
        brpc::Controller cntl;
228
0
        request.set_function_name(_update_fn);
229
0
        static_cast<void>(gen_request_data(request, columns, start, end, argument_types));
230
0
        if (has_last_result()) {
231
0
            request.mutable_context()->mutable_function_context()->mutable_args_data()->CopyFrom(
232
0
                    _res.result());
233
0
        }
234
0
        RETURN_IF_ERROR(send_rpc_request(cntl, request, response));
235
0
        _res = response;
236
0
        set_last_result(true);
237
0
        return Status::OK();
238
0
    }
239
240
0
    void serialize(BufferWritable& buf) {
241
0
        static_cast<void>(send_buffer_to_rpc_server());
242
0
        std::string serialize_data = error_default_str;
243
0
        if (!has_error()) {
244
0
            serialize_data = _res.SerializeAsString();
245
0
        } else {
246
0
            LOG(ERROR) << "serialize empty buf";
247
0
        }
248
0
        buf.write_binary(serialize_data);
249
0
    }
250
251
0
    void deserialize(BufferReadable& buf) {
252
0
        static_cast<void>(send_buffer_to_rpc_server());
253
0
        std::string serialize_data;
254
0
        buf.read_binary(serialize_data);
255
0
        if (error_default_str != serialize_data) {
256
0
            _res.ParseFromString(serialize_data);
257
0
            set_last_result(true);
258
0
        } else {
259
0
            LOG(ERROR) << "deserialize empty buf";
260
0
            set_error(true);
261
0
        }
262
0
    }
263
264
#define GETDATA(LOCATTYPE, TYPEVALUE)                                                          \
265
0
    if (response.result_size() > 0 && response.result(0).TYPEVALUE##_##value_size() > 0) {     \
266
0
        LOCATTYPE ret = response.result(0).TYPEVALUE##_##value(0);                             \
267
0
        to.insert_data((char*)&ret, 0);                                                        \
268
0
    } else {                                                                                   \
269
0
        LOG(ERROR) << "_server_addr:" << _server_addr << ",_finalize_fn:" << _finalize_fn      \
270
0
                   << ",msg: failed to get final result cause return type need " << #TYPEVALUE \
271
0
                   << "but result is empty";                                                   \
272
0
        to.insert_default();                                                                   \
273
0
    }
274
275
    //if any unexpected error happen will return NULL
276
0
    Status get(IColumn& to, const DataTypePtr& return_type) {
277
0
        if (has_error()) {
278
0
            to.insert_default();
279
0
            return Status::OK();
280
0
        }
281
0
        static_cast<void>(send_buffer_to_rpc_server());
282
0
        PFunctionCallRequest request;
283
0
        PFunctionCallResponse response;
284
0
        brpc::Controller cntl;
285
0
        request.set_function_name(_finalize_fn);
286
0
        request.mutable_context()->mutable_function_context()->mutable_args_data()->CopyFrom(
287
0
                _res.result());
288
0
        static_cast<void>(send_rpc_request(cntl, request, response));
289
0
        if (has_error()) {
290
0
            to.insert_default();
291
0
            return Status::OK();
292
0
        }
293
0
        DataTypePtr result_type = return_type;
294
0
        if (return_type->is_nullable()) {
295
0
            result_type =
296
0
                    reinterpret_cast<const DataTypeNullable*>(return_type.get())->get_nested_type();
297
0
        }
298
0
        switch (result_type->get_primitive_type()) {
299
0
        case TYPE_FLOAT:
300
0
            GETDATA(float, float);
301
0
            break;
302
0
        case TYPE_DOUBLE:
303
0
            GETDATA(double, double);
304
0
            break;
305
0
        case TYPE_INT:
306
0
            GETDATA(int32_t, int32);
307
0
            break;
308
0
        case TYPE_BIGINT:
309
0
            GETDATA(int64_t, int64);
310
0
            break;
311
0
        case TYPE_BOOLEAN:
312
0
            GETDATA(uint8_t, bool);
313
0
            break;
314
0
        case TYPE_STRING:
315
0
        case TYPE_CHAR:
316
0
        case TYPE_VARCHAR: {
317
0
            if (response.result_size() > 0 && response.result(0).string_value_size() > 0) {
318
0
                std::string ret = response.result(0).string_value(0);
319
0
                to.insert_data(ret.c_str(), ret.size());
320
0
            } else {
321
0
                LOG(ERROR) << "_server_addr:" << _server_addr << ",_finalize_fn:" << _finalize_fn
322
0
                           << ",msg: failed to get final result cause return type need string but "
323
0
                              "result is empty";
324
0
                to.insert_default();
325
0
            }
326
0
            break;
327
0
        }
328
0
        default:
329
0
            LOG(ERROR) << "failed to get result cause unkown return type";
330
0
            to.insert_default();
331
0
        }
332
0
        return Status::OK();
333
0
    }
334
335
0
    PFunctionCallResponse get_result() {
336
0
        static_cast<void>(send_buffer_to_rpc_server());
337
0
        return _res;
338
0
    }
339
};
340
341
class AggregateRpcUdaf final
342
        : public IAggregateFunctionDataHelper<AggregateRpcUdafData, AggregateRpcUdaf> {
343
public:
344
    AggregateRpcUdaf(const TFunction& fn, const DataTypes& argument_types_,
345
                     const DataTypePtr& return_type)
346
0
            : IAggregateFunctionDataHelper(argument_types_), _fn(fn), _return_type(return_type) {}
347
0
    ~AggregateRpcUdaf() = default;
348
349
    static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types_,
350
0
                                       const DataTypePtr& return_type) {
351
0
        return std::make_shared<AggregateRpcUdaf>(fn, argument_types_, return_type);
352
0
    }
353
354
0
    void create(AggregateDataPtr __restrict place) const override {
355
0
        new (place) Data(argument_types.size());
356
0
        Status status = Status::OK();
357
0
        SAFE_CREATE(RETURN_IF_STATUS_ERROR(status, data(place).init(_fn)),
358
0
                    this->data(place).~Data());
359
0
    }
360
361
0
    String get_name() const override { return _fn.name.function_name; }
362
363
0
    DataTypePtr get_return_type() const override { return _return_type; }
364
365
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
366
0
             Arena&) const override {
367
0
        static_cast<void>(
368
0
                this->data(place).buffer_add(columns, row_num, row_num + 1, argument_types));
369
0
    }
370
371
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
372
0
                                Arena&) const override {
373
0
        static_cast<void>(this->data(place).add(columns, 0, batch_size, argument_types));
374
0
    }
375
376
0
    void reset(AggregateDataPtr place) const override {}
377
378
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
379
0
               Arena&) const override {
380
        // place is essentially an AggregateDataPtr, passed as a ConstAggregateDataPtr.
381
        // todo: rethink the merge method to determine whether const_cast is necessary.
382
0
        static_cast<void>(this->data(place).merge(this->data(const_cast<AggregateDataPtr>(rhs))));
383
0
    }
384
385
0
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
386
        // place is essentially an AggregateDataPtr, passed as a ConstAggregateDataPtr.
387
        // todo: rethink the serialize method to determine whether const_cast is necessary.
388
0
        this->data(const_cast<AggregateDataPtr&>(place)).serialize(buf);
389
0
    }
390
391
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
392
0
                     Arena&) const override {
393
0
        this->data(place).deserialize(buf);
394
0
    }
395
396
0
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
397
        // place is essentially an AggregateDataPtr, passed as a ConstAggregateDataPtr.
398
        // todo: rethink the get method to determine whether const_cast is necessary.
399
0
        static_cast<void>(this->data(const_cast<AggregateDataPtr>(place)).get(to, _return_type));
400
0
    }
401
402
private:
403
    TFunction _fn;
404
    DataTypePtr _return_type;
405
};
406
407
#include "common/compile_check_avoid_end.h"
408
} // namespace doris