Coverage Report

Created: 2026-03-14 20:54

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/function_python_udf.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "exprs/function/function_python_udf.h"

#include <arrow/record_batch.h>
#include <arrow/type_fwd.h>
#include <fmt/core.h>
#include <glog/logging.h>

#include <cstddef>
#include <cstdint>
#include <ctime>
#include <memory>
#include <string>

#include "common/status.h"
#include "core/block/block.h"
#include "exec/connector/jni_connector.h"
#include "format/arrow/arrow_block_convertor.h"
#include "format/arrow/arrow_row_batch.h"
#include "runtime/user_function_cache.h"
#include "udf/python/python_server.h"
#include "udf/python/python_udf_client.h"
#include "udf/python/python_udf_meta.h"
#include "util/timezone_utils.h"
40
41
namespace doris {
42
43
// Builds a Python UDF call wrapper.
//
// Only copies are taken here: the Thrift function descriptor `fn` (read by open()
// to build the UDF metadata) and the argument/return types (used by open() and
// checked against the block's column types in execute_impl()). No Python client
// is created until open() runs.
PythonFunctionCall::PythonFunctionCall(const TFunction& fn, const DataTypes& argument_types,
                                       const DataTypePtr& return_type)
        : _fn(fn),
          _argument_types(argument_types),
          _return_type(return_type) {}
46
47
// Prepares the Python UDF for execution on `context`.
//
// FRAGMENT_LOCAL scope is a no-op. For any other scope this:
//   1. resolves the requested Python runtime version (required),
//   2. translates the Thrift descriptor into a PythonUDFMeta (inline code,
//      module location + checksum, or UNKNOWN) and validates it,
//   3. for MODULE UDFs, replaces the location with the path returned by
//      UserFunctionCache::get_pypath,
//   4. obtains a client from PythonServerManager and stores it as THREAD_LOCAL
//      function state, where execute_impl() and close() look it up.
//
// Returns InvalidArgument if no runtime version is set, InternalError if the
// server manager hands back a null client, and otherwise the first failing
// status from version lookup, meta validation, module fetch or client creation.
Status PythonFunctionCall::open(FunctionContext* context,
                                FunctionContext::FunctionStateScope scope) {
    if (scope == FunctionContext::FunctionStateScope::FRAGMENT_LOCAL) {
        LOG(INFO) << "Open python UDF fragment local";
        return Status::OK();
    }

    // The runtime version is mandatory; reject early before building any metadata.
    if (!_fn.__isset.runtime_version || _fn.runtime_version.empty()) {
        return Status::InvalidArgument("Python UDF runtime version is not set");
    }
    PythonVersion version;
    RETURN_IF_ERROR(PythonVersionManager::instance().get_version(_fn.runtime_version, &version));

    PythonUDFMeta func_meta;
    func_meta.id = _fn.id;
    func_meta.name = _fn.name.function_name;
    func_meta.symbol = _fn.scalar_fn.symbol;
    if (!_fn.function_code.empty()) {
        // Source code was shipped inline with the function definition.
        func_meta.type = PythonUDFLoadType::INLINE;
        func_meta.location = "inline";
        func_meta.inline_code = _fn.function_code;
    } else if (!_fn.hdfs_location.empty()) {
        // Code lives in an external module; it is fetched further below.
        func_meta.type = PythonUDFLoadType::MODULE;
        func_meta.location = _fn.hdfs_location;
        func_meta.checksum = _fn.checksum;
    } else {
        func_meta.type = PythonUDFLoadType::UNKNOWN;
        func_meta.location = "unknown";
    }

    func_meta.input_types = _argument_types;
    func_meta.return_type = _return_type;
    func_meta.client_type = PythonClientType::UDF;
    func_meta.runtime_version = version.full_version;
    RETURN_IF_ERROR(func_meta.check());
    func_meta.always_nullable = _return_type->is_nullable();
    LOG(INFO) << fmt::format("runtime_version: {}, func_meta: {}", version.to_string(),
                             func_meta.to_string());

    if (func_meta.type == PythonUDFLoadType::MODULE) {
        // Resolve into a separate string. Passing func_meta.location as both the
        // input location and the output pointer would make the two arguments refer
        // to the same object.
        std::string module_path;
        RETURN_IF_ERROR(UserFunctionCache::instance()->get_pypath(
                func_meta.id, func_meta.location, func_meta.checksum, &module_path));
        func_meta.location = std::move(module_path);
    }

    PythonUDFClientPtr client = nullptr;
    RETURN_IF_ERROR(PythonServerManager::instance().get_client(func_meta, version, &client));
    if (!client) {
        return Status::InternalError("Python UDF client is null");
    }

    // execute_impl() and close() retrieve the client from this THREAD_LOCAL slot.
    context->set_function_state(FunctionContext::THREAD_LOCAL, client);
    LOG(INFO) << fmt::format("Successfully get python UDF client, process: {}",
                             client->print_process());
    return Status::OK();
}
106
107
// Evaluates the Python UDF over one block.
//
// The argument columns are copied into a scratch block, converted to an Arrow
// record batch and sent to the THREAD_LOCAL PythonUDFClient stored by open().
// The single column of the returned batch replaces column `result` in `block`.
//
// Parameters:
//   context   - function context holding the THREAD_LOCAL client.
//   block     - input/output block; column `result` is overwritten on success.
//   arguments - positions of the argument columns inside `block`.
//   result    - position of the output column inside `block`.
//   num_rows  - not used; the row count checked against the Python output is
//               taken from block.rows().
//
// Returns InternalError if the client is missing, the argument count or any
// argument/return type disagrees with the registered signature, or the Python
// side returns a batch of unexpected shape; otherwise the first failing status
// from Arrow conversion or client evaluation.
Status PythonFunctionCall::execute_impl(FunctionContext* context, Block& block,
                                        const ColumnNumbers& arguments, uint32_t result,
                                        size_t num_rows) const {
    // Function state is type-erased (void*); static_cast is the correct way back.
    auto* client = static_cast<PythonUDFClient*>(
            context->get_function_state(FunctionContext::THREAD_LOCAL));
    if (!client) {
        LOG(WARNING) << "Python UDF client is null";
        return Status::InternalError("Python UDF client is null");
    }

    const int64_t input_rows = block.rows();
    const uint32_t input_columns = block.columns();
    DCHECK(input_columns > 0 && result < input_columns);

    // _argument_types[i] is indexed for every entry of `arguments` below, so a
    // count mismatch must be rejected in release builds too, not only by DCHECK.
    if (_argument_types.size() != arguments.size()) {
        return Status::InternalError(fmt::format("Python UDF expects {} arguments but got {}",
                                                 _argument_types.size(), arguments.size()));
    }

    if (!_return_type->equals(*block.get_by_position(result).type)) {
        return Status::InternalError(fmt::format("Python UDF output type {} not equal to {}",
                                                 block.get_by_position(result).type->get_name(),
                                                 _return_type->get_name()));
    }

    Block input_block;
    for (size_t i = 0; i < arguments.size(); ++i) {
        const auto& arg = block.get_by_position(arguments[i]);
        if (!_argument_types[i]->equals(*arg.type)) {
            return Status::InternalError(fmt::format("Python UDF input type {} not equal to {}",
                                                     arg.type->get_name(),
                                                     _argument_types[i]->get_name()));
        }
        input_block.insert(arg);
    }

    std::shared_ptr<arrow::Schema> schema;
    RETURN_IF_ERROR(
            get_arrow_schema_from_block(input_block, &schema, TimezoneUtils::default_time_zone));

    // A default-constructed cctz::time_zone is UTC.
    // NOTE(review): the schema above is built with TimezoneUtils::default_time_zone
    // while both batch conversions use UTC. Confirm this asymmetry is intended.
    cctz::time_zone utc_timezone;
    std::shared_ptr<arrow::RecordBatch> input_batch;
    RETURN_IF_ERROR(convert_to_arrow_batch(input_block, schema, arrow::default_memory_pool(),
                                           &input_batch, utc_timezone));

    std::shared_ptr<arrow::RecordBatch> output_batch;
    RETURN_IF_ERROR(client->evaluate(*input_batch, &output_batch));
    if (!output_batch) {
        return Status::InternalError("Python UDF returned no output batch");
    }

    if (output_batch->num_columns() != 1) {
        return Status::InternalError(fmt::format("Python UDF output columns {} not equal to 1",
                                                 output_batch->num_columns()));
    }

    const int64_t output_rows = output_batch->num_rows();
    if (input_rows != output_rows) {
        return Status::InternalError(fmt::format(
                "Python UDF output rows {} not equal to input rows {}", output_rows, input_rows));
    }

    Block output_block;
    RETURN_IF_ERROR(
            convert_from_arrow_batch(output_batch, {_return_type}, &output_block, utc_timezone));
    DCHECK_EQ(output_block.columns(), 1);
    block.replace_by_position(result, std::move(output_block.get_by_position(0).column));
    return Status::OK();
}
167
168
// Closes the Python UDF client that open() stored as THREAD_LOCAL state.
//
// open() creates no state for FRAGMENT_LOCAL scope, so closing that scope is a
// no-op as well. Without this guard, closing both scopes on the same context
// would call client->close() twice on the same THREAD_LOCAL client.
//
// Returns InternalError if no client is stored for a non-FRAGMENT_LOCAL close,
// otherwise the status of client->close().
Status PythonFunctionCall::close(FunctionContext* context,
                                 FunctionContext::FunctionStateScope scope) {
    if (scope == FunctionContext::FunctionStateScope::FRAGMENT_LOCAL) {
        return Status::OK();
    }

    // Function state is type-erased (void*); static_cast is the correct way back.
    auto* client = static_cast<PythonUDFClient*>(
            context->get_function_state(FunctionContext::THREAD_LOCAL));
    if (!client) {
        LOG(WARNING) << "Python UDF client is null";
        return Status::InternalError("Python UDF client is null");
    }
    RETURN_IF_ERROR(client->close());
    return Status::OK();
}
179
180
} // namespace doris