be/src/udf/python/python_udf_meta.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "udf/python/python_udf_meta.h" |
19 | | |
20 | | #include <arrow/util/base64.h> |
21 | | #include <fmt/core.h> |
22 | | #include <rapidjson/stringbuffer.h> |
23 | | #include <rapidjson/writer.h> |
24 | | |
25 | | #include <sstream> |
26 | | |
27 | | #include "common/status.h" |
28 | | #include "format/arrow/arrow_utils.h" |
29 | | #include "util/string_util.h" |
30 | | |
31 | | namespace doris { |
32 | | |
33 | | Status PythonUDFMeta::convert_types_to_schema(const DataTypes& types, const std::string& timezone, |
34 | 13.7k | std::shared_ptr<arrow::Schema>* schema) { |
35 | 13.7k | arrow::SchemaBuilder builder; |
36 | 31.2k | for (size_t i = 0; i < types.size(); ++i) { |
37 | 17.4k | std::shared_ptr<arrow::DataType> arrow_type; |
38 | 17.4k | RETURN_IF_ERROR(convert_to_arrow_type(types[i], &arrow_type, timezone)); |
39 | 17.4k | std::shared_ptr<arrow::Field> field = std::make_shared<arrow::Field>( |
40 | 17.4k | "arg" + std::to_string(i), arrow_type, types[i]->is_nullable()); |
41 | 17.4k | RETURN_DORIS_STATUS_IF_ERROR(builder.AddField(field)); |
42 | 17.4k | } |
43 | 13.7k | RETURN_DORIS_STATUS_IF_RESULT_ERROR(schema, builder.Finish()); |
44 | 13.7k | return Status::OK(); |
45 | 13.7k | } |
46 | | |
47 | | Status PythonUDFMeta::serialize_arrow_schema(const std::shared_ptr<arrow::Schema>& schema, |
48 | 13.7k | std::shared_ptr<arrow::Buffer>* out) { |
49 | 13.7k | RETURN_DORIS_STATUS_IF_RESULT_ERROR( |
50 | 13.7k | out, arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool())); |
51 | 13.7k | return Status::OK(); |
52 | 13.7k | } |
53 | | |
54 | | /* |
55 | | json format: |
56 | | { |
57 | | "name": "xxx", |
58 | | "id": 123, |
59 | | "symbol": "xxx", |
60 | | "location": "xxx", |
61 | | "udf_load_type": 0 or 1, |
62 | | "client_type": 0 (UDF) or 1 (UDAF) or 2 (UDTF), |
63 | | "runtime_version": "x.xx.xx", |
64 | | "always_nullable": true, |
65 | | "inline_code": "base64_inline_code", |
66 | | "input_types": "base64_input_types", |
67 | | "return_type": "base64_return_type" |
68 | | } |
69 | | */ |
70 | 6.89k | Status PythonUDFMeta::serialize_to_json(std::string* json_str) const { |
71 | 6.89k | rapidjson::Document doc; |
72 | 6.89k | doc.SetObject(); |
73 | 6.89k | auto& allocator = doc.GetAllocator(); |
74 | 6.89k | doc.AddMember("name", rapidjson::Value().SetString(name.c_str(), allocator), allocator); |
75 | 6.89k | doc.AddMember("id", rapidjson::Value().SetInt64(id), allocator); |
76 | 6.89k | doc.AddMember("symbol", rapidjson::Value().SetString(symbol.c_str(), allocator), allocator); |
77 | 6.89k | doc.AddMember("location", rapidjson::Value().SetString(location.c_str(), allocator), allocator); |
78 | 6.89k | doc.AddMember("udf_load_type", rapidjson::Value().SetInt(static_cast<int>(type)), allocator); |
79 | 6.89k | doc.AddMember("client_type", rapidjson::Value().SetInt(static_cast<int>(client_type)), |
80 | 6.89k | allocator); |
81 | 6.89k | doc.AddMember("runtime_version", |
82 | 6.89k | rapidjson::Value().SetString(runtime_version.c_str(), allocator), allocator); |
83 | 6.89k | doc.AddMember("always_nullable", rapidjson::Value().SetBool(always_nullable), allocator); |
84 | | |
85 | 6.89k | { |
86 | | // Serialize base64 inline code to json |
87 | 6.89k | std::string base64_str = arrow::util::base64_encode(inline_code); |
88 | 6.89k | doc.AddMember("inline_code", rapidjson::Value().SetString(base64_str.c_str(), allocator), |
89 | 6.89k | allocator); |
90 | 6.89k | } |
91 | 6.89k | { |
92 | | // Serialize base64 input types to json |
93 | 6.89k | std::shared_ptr<arrow::Schema> input_schema; |
94 | 6.89k | RETURN_IF_ERROR(convert_types_to_schema(input_types, TimezoneUtils::default_time_zone, |
95 | 6.89k | &input_schema)); |
96 | 6.89k | std::shared_ptr<arrow::Buffer> input_schema_buffer; |
97 | 6.89k | RETURN_IF_ERROR(serialize_arrow_schema(input_schema, &input_schema_buffer)); |
98 | 6.89k | std::string base64_str = |
99 | 6.89k | arrow::util::base64_encode({input_schema_buffer->data_as<char>(), |
100 | 6.89k | static_cast<size_t>(input_schema_buffer->size())}); |
101 | 6.89k | doc.AddMember("input_types", rapidjson::Value().SetString(base64_str.c_str(), allocator), |
102 | 6.89k | allocator); |
103 | 6.89k | } |
104 | 0 | { |
105 | | // Serialize base64 return type to json |
106 | 6.89k | std::shared_ptr<arrow::Schema> return_schema; |
107 | 6.89k | RETURN_IF_ERROR(convert_types_to_schema({return_type}, TimezoneUtils::default_time_zone, |
108 | 6.89k | &return_schema)); |
109 | 6.89k | std::shared_ptr<arrow::Buffer> return_schema_buffer; |
110 | 6.89k | RETURN_IF_ERROR(serialize_arrow_schema(return_schema, &return_schema_buffer)); |
111 | 6.89k | std::string base64_str = |
112 | 6.89k | arrow::util::base64_encode({return_schema_buffer->data_as<char>(), |
113 | 6.89k | static_cast<size_t>(return_schema_buffer->size())}); |
114 | 6.89k | doc.AddMember("return_type", rapidjson::Value().SetString(base64_str.c_str(), allocator), |
115 | 6.89k | allocator); |
116 | 6.89k | } |
117 | | |
118 | | // Convert document to json string |
119 | 0 | rapidjson::StringBuffer buffer; |
120 | 6.89k | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
121 | 6.89k | doc.Accept(writer); |
122 | 6.89k | *json_str = std::string(buffer.GetString(), buffer.GetSize()); |
123 | 6.89k | return Status::OK(); |
124 | 6.89k | } |
125 | | |
126 | 2.00k | std::string PythonUDFMeta::to_string() const { |
127 | 2.00k | std::stringstream input_types_ss; |
128 | 2.00k | input_types_ss << "<"; |
129 | 5.23k | for (size_t i = 0; i < input_types.size(); ++i) { |
130 | 3.22k | input_types_ss << input_types[i]->get_name(); |
131 | 3.22k | if (i != input_types.size() - 1) { |
132 | 1.23k | input_types_ss << ", "; |
133 | 1.23k | } |
134 | 3.22k | } |
135 | 2.00k | input_types_ss << ">"; |
136 | 2.00k | return fmt::format( |
137 | 2.00k | "[name: {}, symbol: {}, location: {}, runtime_version: {}, always_nullable: {}, " |
138 | 2.00k | "inline_code: {}][input_types: {}][return_type: {}]", |
139 | 2.00k | name, symbol, location, runtime_version, always_nullable, inline_code, |
140 | 2.00k | input_types_ss.str(), return_type->get_name()); |
141 | 2.00k | } |
142 | | |
143 | 4.11k | Status PythonUDFMeta::check() const { |
144 | 4.11k | if (trim(name).empty()) { |
145 | 2 | return Status::InvalidArgument("Python UDF name is empty"); |
146 | 2 | } |
147 | | |
148 | 4.11k | if (trim(symbol).empty()) { |
149 | 1 | return Status::InvalidArgument("Python UDF symbol is empty"); |
150 | 1 | } |
151 | | |
152 | 4.11k | if (trim(runtime_version).empty()) { |
153 | 1 | return Status::InvalidArgument("Python UDF runtime version is empty"); |
154 | 1 | } |
155 | | |
156 | 4.11k | if (input_types.empty() && |
157 | 4.11k | (client_type == PythonClientType::UDAF || type == PythonUDFLoadType::UNKNOWN)) { |
158 | 1 | return Status::InvalidArgument("Python UDAF input types is empty"); |
159 | 1 | } |
160 | | |
161 | 4.10k | if (!return_type) { |
162 | 1 | return Status::InvalidArgument("Python UDF return type is empty"); |
163 | 1 | } |
164 | | |
165 | 4.10k | if (type == PythonUDFLoadType::UNKNOWN) { |
166 | 1 | return Status::InvalidArgument( |
167 | 1 | "Python UDF load type is invalid, please check inline code or file path"); |
168 | 1 | } |
169 | | |
170 | 4.10k | if (type == PythonUDFLoadType::MODULE) { |
171 | 1.85k | if (trim(location).empty()) { |
172 | 1 | return Status::InvalidArgument("Non-inline Python UDF location is empty"); |
173 | 1 | } |
174 | 1.85k | if (trim(checksum).empty()) { |
175 | 1 | return Status::InvalidArgument("Non-inline Python UDF checksum is empty"); |
176 | 1 | } |
177 | 1.85k | } |
178 | | |
179 | 4.10k | return Status::OK(); |
180 | 4.10k | } |
181 | | |
182 | | } // namespace doris |