be/src/udf/python/python_udf_meta.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "udf/python/python_udf_meta.h" |
19 | | |
20 | | #include <arrow/util/base64.h> |
21 | | #include <fmt/core.h> |
22 | | #include <rapidjson/stringbuffer.h> |
23 | | #include <rapidjson/writer.h> |
24 | | |
25 | | #include <sstream> |
26 | | |
27 | | #include "common/status.h" |
28 | | #include "format/arrow/arrow_utils.h" |
29 | | #include "util/string_util.h" |
30 | | |
31 | | namespace doris { |
32 | | |
33 | | Status PythonUDFMeta::convert_types_to_schema(const DataTypes& types, const std::string& timezone, |
34 | 12 | std::shared_ptr<arrow::Schema>* schema) { |
35 | 12 | assert(!types.empty()); |
36 | 12 | arrow::SchemaBuilder builder; |
37 | 27 | for (size_t i = 0; i < types.size(); ++i) { |
38 | 15 | std::shared_ptr<arrow::DataType> arrow_type; |
39 | 15 | RETURN_IF_ERROR(convert_to_arrow_type(types[i], &arrow_type, timezone)); |
40 | 15 | std::shared_ptr<arrow::Field> field = std::make_shared<arrow::Field>( |
41 | 15 | "arg" + std::to_string(i), arrow_type, types[i]->is_nullable()); |
42 | 15 | RETURN_DORIS_STATUS_IF_ERROR(builder.AddField(field)); |
43 | 15 | } |
44 | 12 | RETURN_DORIS_STATUS_IF_RESULT_ERROR(schema, builder.Finish()); |
45 | 12 | return Status::OK(); |
46 | 12 | } |
47 | | |
48 | | Status PythonUDFMeta::serialize_arrow_schema(const std::shared_ptr<arrow::Schema>& schema, |
49 | 11 | std::shared_ptr<arrow::Buffer>* out) { |
50 | 11 | RETURN_DORIS_STATUS_IF_RESULT_ERROR( |
51 | 11 | out, arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool())); |
52 | 11 | return Status::OK(); |
53 | 11 | } |
54 | | |
55 | | /* |
56 | | json format: |
57 | | { |
58 | | "name": "xxx", |
59 | | "symbol": "xxx", |
60 | | "location": "xxx", |
61 | | "udf_load_type": 0 or 1, |
62 | | "client_type": 0 (UDF) or 1 (UDAF) or 2 (UDTF), |
63 | | "runtime_version": "x.xx.xx", |
64 | | "always_nullable": true, |
65 | | "inline_code": "base64_inline_code", |
66 | | "input_types": "base64_input_types", |
67 | | "return_type": "base64_return_type" |
68 | | } |
69 | | */ |
70 | 5 | Status PythonUDFMeta::serialize_to_json(std::string* json_str) const { |
71 | 5 | rapidjson::Document doc; |
72 | 5 | doc.SetObject(); |
73 | 5 | auto& allocator = doc.GetAllocator(); |
74 | 5 | doc.AddMember("name", rapidjson::Value().SetString(name.c_str(), allocator), allocator); |
75 | 5 | doc.AddMember("symbol", rapidjson::Value().SetString(symbol.c_str(), allocator), allocator); |
76 | 5 | doc.AddMember("location", rapidjson::Value().SetString(location.c_str(), allocator), allocator); |
77 | 5 | doc.AddMember("udf_load_type", rapidjson::Value().SetInt(static_cast<int>(type)), allocator); |
78 | 5 | doc.AddMember("client_type", rapidjson::Value().SetInt(static_cast<int>(client_type)), |
79 | 5 | allocator); |
80 | 5 | doc.AddMember("runtime_version", |
81 | 5 | rapidjson::Value().SetString(runtime_version.c_str(), allocator), allocator); |
82 | 5 | doc.AddMember("always_nullable", rapidjson::Value().SetBool(always_nullable), allocator); |
83 | | |
84 | 5 | { |
85 | | // Serialize base64 inline code to json |
86 | 5 | std::string base64_str = arrow::util::base64_encode(inline_code); |
87 | 5 | doc.AddMember("inline_code", rapidjson::Value().SetString(base64_str.c_str(), allocator), |
88 | 5 | allocator); |
89 | 5 | } |
90 | 5 | { |
91 | | // Serialize base64 input types to json |
92 | 5 | std::shared_ptr<arrow::Schema> input_schema; |
93 | 5 | RETURN_IF_ERROR(convert_types_to_schema(input_types, TimezoneUtils::default_time_zone, |
94 | 5 | &input_schema)); |
95 | 5 | std::shared_ptr<arrow::Buffer> input_schema_buffer; |
96 | 5 | RETURN_IF_ERROR(serialize_arrow_schema(input_schema, &input_schema_buffer)); |
97 | 5 | std::string base64_str = |
98 | 5 | arrow::util::base64_encode({input_schema_buffer->data_as<char>(), |
99 | 5 | static_cast<size_t>(input_schema_buffer->size())}); |
100 | 5 | doc.AddMember("input_types", rapidjson::Value().SetString(base64_str.c_str(), allocator), |
101 | 5 | allocator); |
102 | 5 | } |
103 | 0 | { |
104 | | // Serialize base64 return type to json |
105 | 5 | std::shared_ptr<arrow::Schema> return_schema; |
106 | 5 | RETURN_IF_ERROR(convert_types_to_schema({return_type}, TimezoneUtils::default_time_zone, |
107 | 5 | &return_schema)); |
108 | 5 | std::shared_ptr<arrow::Buffer> return_schema_buffer; |
109 | 5 | RETURN_IF_ERROR(serialize_arrow_schema(return_schema, &return_schema_buffer)); |
110 | 5 | std::string base64_str = |
111 | 5 | arrow::util::base64_encode({return_schema_buffer->data_as<char>(), |
112 | 5 | static_cast<size_t>(return_schema_buffer->size())}); |
113 | 5 | doc.AddMember("return_type", rapidjson::Value().SetString(base64_str.c_str(), allocator), |
114 | 5 | allocator); |
115 | 5 | } |
116 | | |
117 | | // Convert document to json string |
118 | 0 | rapidjson::StringBuffer buffer; |
119 | 5 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
120 | 5 | doc.Accept(writer); |
121 | 5 | *json_str = std::string(buffer.GetString(), buffer.GetSize()); |
122 | 5 | return Status::OK(); |
123 | 5 | } |
124 | | |
125 | 2 | std::string PythonUDFMeta::to_string() const { |
126 | 2 | std::stringstream input_types_ss; |
127 | 2 | input_types_ss << "<"; |
128 | 7 | for (size_t i = 0; i < input_types.size(); ++i) { |
129 | 5 | input_types_ss << input_types[i]->get_name(); |
130 | 5 | if (i != input_types.size() - 1) { |
131 | 3 | input_types_ss << ", "; |
132 | 3 | } |
133 | 5 | } |
134 | 2 | input_types_ss << ">"; |
135 | 2 | return fmt::format( |
136 | 2 | "[name: {}, symbol: {}, location: {}, runtime_version: {}, always_nullable: {}, " |
137 | 2 | "inline_code: {}][input_types: {}][return_type: {}]", |
138 | 2 | name, symbol, location, runtime_version, always_nullable, inline_code, |
139 | 2 | input_types_ss.str(), return_type->get_name()); |
140 | 2 | } |
141 | | |
142 | 11 | Status PythonUDFMeta::check() const { |
143 | 11 | if (trim(name).empty()) { |
144 | 2 | return Status::InvalidArgument("Python UDF name is empty"); |
145 | 2 | } |
146 | | |
147 | 9 | if (trim(symbol).empty()) { |
148 | 1 | return Status::InvalidArgument("Python UDF symbol is empty"); |
149 | 1 | } |
150 | | |
151 | 8 | if (trim(runtime_version).empty()) { |
152 | 1 | return Status::InvalidArgument("Python UDF runtime version is empty"); |
153 | 1 | } |
154 | | |
155 | 7 | if (input_types.empty()) { |
156 | 1 | return Status::InvalidArgument("Python UDF input types is empty"); |
157 | 1 | } |
158 | | |
159 | 6 | if (!return_type) { |
160 | 1 | return Status::InvalidArgument("Python UDF return type is empty"); |
161 | 1 | } |
162 | | |
163 | 5 | if (type == PythonUDFLoadType::UNKNOWN) { |
164 | 1 | return Status::InvalidArgument( |
165 | 1 | "Python UDF load type is invalid, please check inline code or file path"); |
166 | 1 | } |
167 | | |
168 | 4 | if (type == PythonUDFLoadType::MODULE) { |
169 | 3 | if (trim(location).empty()) { |
170 | 1 | return Status::InvalidArgument("Non-inline Python UDF location is empty"); |
171 | 1 | } |
172 | 2 | if (trim(checksum).empty()) { |
173 | 1 | return Status::InvalidArgument("Non-inline Python UDF checksum is empty"); |
174 | 1 | } |
175 | 2 | } |
176 | | |
177 | 2 | return Status::OK(); |
178 | 4 | } |
179 | | |
180 | | } // namespace doris |