Coverage Report

Created: 2026-03-17 16:40

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/function_python_udf.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "exprs/function/function_python_udf.h"
19
20
#include <arrow/record_batch.h>
21
#include <arrow/type_fwd.h>
22
#include <fmt/core.h>
23
#include <glog/logging.h>
24
25
#include <cstddef>
26
#include <cstdint>
27
#include <ctime>
28
#include <memory>
29
30
#include "common/status.h"
31
#include "core/block/block.h"
32
#include "format/arrow/arrow_block_convertor.h"
33
#include "runtime/user_function_cache.h"
34
#include "udf/python/python_server.h"
35
#include "udf/python/python_udf_client.h"
36
#include "udf/python/python_udf_meta.h"
37
#include "util/timezone_utils.h"
38
39
namespace doris {
40
41
PythonFunctionCall::PythonFunctionCall(const TFunction& fn, const DataTypes& argument_types,
42
                                       const DataTypePtr& return_type)
43
344
        : _fn(fn), _argument_types(argument_types), _return_type(return_type) {}
44
45
Status PythonFunctionCall::open(FunctionContext* context,
46
1.57k
                                FunctionContext::FunctionStateScope scope) {
47
1.57k
    if (scope == FunctionContext::FunctionStateScope::FRAGMENT_LOCAL) {
48
344
        LOG(INFO) << "Open python UDF fragment local";
49
344
        return Status::OK();
50
344
    }
51
52
1.23k
    PythonVersion version;
53
1.23k
    PythonUDFMeta func_meta;
54
1.23k
    func_meta.id = _fn.id;
55
1.23k
    func_meta.name = _fn.name.function_name;
56
1.23k
    func_meta.symbol = _fn.scalar_fn.symbol;
57
1.23k
    if (!_fn.function_code.empty()) {
58
568
        func_meta.type = PythonUDFLoadType::INLINE;
59
568
        func_meta.location = "inline";
60
568
        func_meta.inline_code = _fn.function_code;
61
665
    } else if (!_fn.hdfs_location.empty()) {
62
665
        func_meta.type = PythonUDFLoadType::MODULE;
63
665
        func_meta.location = _fn.hdfs_location;
64
665
        func_meta.checksum = _fn.checksum;
65
18.4E
    } else {
66
18.4E
        func_meta.type = PythonUDFLoadType::UNKNOWN;
67
18.4E
        func_meta.location = "unknown";
68
18.4E
    }
69
70
1.23k
    func_meta.input_types = _argument_types;
71
1.23k
    func_meta.return_type = _return_type;
72
1.23k
    func_meta.client_type = PythonClientType::UDF;
73
74
1.23k
    if (_fn.__isset.runtime_version && !_fn.runtime_version.empty()) {
75
1.23k
        RETURN_IF_ERROR(
76
1.23k
                PythonVersionManager::instance().get_version(_fn.runtime_version, &version));
77
18.4E
    } else {
78
18.4E
        return Status::InvalidArgument("Python UDF runtime version is not set");
79
18.4E
    }
80
81
1.23k
    func_meta.runtime_version = version.full_version;
82
1.23k
    RETURN_IF_ERROR(func_meta.check());
83
1.23k
    func_meta.always_nullable = _return_type->is_nullable();
84
1.23k
    LOG(INFO) << fmt::format("runtime_version: {}, func_meta: {}", version.to_string(),
85
1.23k
                             func_meta.to_string());
86
87
1.23k
    if (func_meta.type == PythonUDFLoadType::MODULE) {
88
665
        RETURN_IF_ERROR(UserFunctionCache::instance()->get_pypath(
89
665
                func_meta.id, func_meta.location, func_meta.checksum, &func_meta.location));
90
665
    }
91
92
1.23k
    PythonUDFClientPtr client = nullptr;
93
1.23k
    RETURN_IF_ERROR(PythonServerManager::instance().get_client(func_meta, version, &client));
94
95
1.23k
    if (!client) {
96
0
        return Status::InternalError("Python UDF client is null");
97
0
    }
98
99
1.23k
    context->set_function_state(FunctionContext::THREAD_LOCAL, client);
100
1.23k
    LOG(INFO) << fmt::format("Successfully get python UDF client, process: {}",
101
1.23k
                             client->print_process());
102
1.23k
    return Status::OK();
103
1.23k
}
104
105
Status PythonFunctionCall::execute_impl(FunctionContext* context, Block& block,
106
                                        const ColumnNumbers& arguments, uint32_t result,
107
520
                                        size_t num_rows) const {
108
520
    auto client = reinterpret_cast<PythonUDFClient*>(
109
520
            context->get_function_state(FunctionContext::THREAD_LOCAL));
110
520
    if (!client) {
111
0
        LOG(WARNING) << "Python UDF client is null";
112
0
        return Status::InternalError("Python UDF client is null");
113
0
    }
114
115
520
    int64_t input_rows = block.rows();
116
520
    uint32_t input_columns = block.columns();
117
520
    DCHECK(input_columns > 0 && result < input_columns &&
118
520
           _argument_types.size() == arguments.size());
119
520
    Block input_block;
120
520
    Block output_block;
121
122
520
    if (!_return_type->equals(*block.get_by_position(result).type)) {
123
0
        return Status::InternalError(fmt::format("Python UDF output type {} not equal to {}",
124
0
                                                 block.get_by_position(result).type->get_name(),
125
0
                                                 _return_type->get_name()));
126
0
    }
127
128
1.42k
    for (uint32_t i = 0; i < arguments.size(); ++i) {
129
902
        if (!_argument_types[i]->equals(*block.get_by_position(arguments[i]).type)) {
130
0
            return Status::InternalError(
131
0
                    fmt::format("Python UDF input type {} not equal to {}",
132
0
                                block.get_by_position(arguments[i]).type->get_name(),
133
0
                                _argument_types[i]->get_name()));
134
0
        }
135
902
        input_block.insert(block.get_by_position(arguments[i]));
136
902
    }
137
138
520
    std::shared_ptr<arrow::Schema> schema;
139
520
    RETURN_IF_ERROR(
140
520
            get_arrow_schema_from_block(input_block, &schema, TimezoneUtils::default_time_zone));
141
520
    std::shared_ptr<arrow::RecordBatch> input_batch;
142
520
    std::shared_ptr<arrow::RecordBatch> output_batch;
143
520
    cctz::time_zone _timezone_obj; // default UTC
144
520
    RETURN_IF_ERROR(convert_to_arrow_batch(input_block, schema, arrow::default_memory_pool(),
145
520
                                           &input_batch, _timezone_obj));
146
520
    RETURN_IF_ERROR(client->evaluate(*input_batch, &output_batch));
147
518
    int64_t output_rows = output_batch->num_rows();
148
149
518
    if (output_batch->num_columns() != 1) {
150
0
        return Status::InternalError(fmt::format("Python UDF output columns {} not equal to 1",
151
0
                                                 output_batch->num_columns()));
152
0
    }
153
154
518
    if (input_rows != output_rows) {
155
0
        return Status::InternalError(fmt::format(
156
0
                "Python UDF output rows {} not equal to input rows {}", output_rows, input_rows));
157
0
    }
158
159
518
    RETURN_IF_ERROR(
160
518
            convert_from_arrow_batch(output_batch, {_return_type}, &output_block, _timezone_obj));
161
518
    DCHECK_EQ(output_block.columns(), 1);
162
518
    block.replace_by_position(result, std::move(output_block.get_by_position(0).column));
163
518
    return Status::OK();
164
518
}
165
166
Status PythonFunctionCall::close(FunctionContext* context,
167
1.57k
                                 FunctionContext::FunctionStateScope scope) {
168
1.57k
    auto client = reinterpret_cast<PythonUDFClient*>(
169
1.57k
            context->get_function_state(FunctionContext::THREAD_LOCAL));
170
1.57k
    if (!client) {
171
0
        LOG(WARNING) << "Python UDF client is null";
172
0
        return Status::InternalError("Python UDF client is null");
173
0
    }
174
1.57k
    RETURN_IF_ERROR(client->close());
175
1.57k
    return Status::OK();
176
1.57k
}
177
178
} // namespace doris