be/src/exprs/function/ai/embed.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <glog/logging.h> |
21 | | #include <rapidjson/document.h> |
22 | | |
23 | | #include <string_view> |
24 | | |
25 | | #include "core/data_type/data_type_nullable.h" |
26 | | #include "core/data_type/primitive_type.h" |
27 | | #include "exprs/function/ai/ai_functions.h" |
28 | | #include "util/jsonb_utils.h" |
29 | | #include "util/s3_uri.h" |
30 | | #include "util/s3_util.h" |
31 | | |
32 | | namespace doris { |
33 | | class FunctionEmbed : public AIFunction<FunctionEmbed> { |
34 | | public: |
35 | | static constexpr auto name = "embed"; |
36 | | |
37 | | static constexpr size_t number_of_arguments = 2; |
38 | | |
39 | | static constexpr auto system_prompt = ""; |
40 | | |
41 | 1 | DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { |
42 | 1 | return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>())); |
43 | 1 | } |
44 | | |
45 | | Status execute_with_adapter(FunctionContext* context, Block& block, |
46 | | const ColumnNumbers& arguments, uint32_t result, |
47 | | size_t input_rows_count, const TAIResource& config, |
48 | 10 | std::shared_ptr<AIAdapter>& adapter) const { |
49 | 10 | if (arguments.size() != 2) { |
50 | 1 | return Status::InvalidArgument("Function EMBED expects 2 arguments, but got {}", |
51 | 1 | arguments.size()); |
52 | 1 | } |
53 | | |
54 | 9 | PrimitiveType input_type = |
55 | 9 | remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type(); |
56 | 9 | if (input_type == PrimitiveType::TYPE_JSONB) { |
57 | 6 | return _execute_multimodal_embed(context, block, arguments, result, input_rows_count, |
58 | 6 | config, adapter); |
59 | 6 | } |
60 | 3 | if (input_type == PrimitiveType::TYPE_STRING || input_type == PrimitiveType::TYPE_VARCHAR || |
61 | 3 | input_type == PrimitiveType::TYPE_CHAR) { |
62 | 2 | return _execute_text_embed(context, block, arguments, result, input_rows_count, config, |
63 | 2 | adapter); |
64 | 2 | } |
65 | 1 | return Status::InvalidArgument( |
66 | 1 | "Function EMBED expects the second argument to be STRING or JSON, but got type {}", |
67 | 1 | block.get_by_position(arguments[1]).type->get_name()); |
68 | 3 | } |
69 | | |
70 | 19 | static FunctionPtr create() { return std::make_shared<FunctionEmbed>(); } |
71 | | |
72 | | private: |
73 | | Status _execute_text_embed(FunctionContext* context, Block& block, |
74 | | const ColumnNumbers& arguments, uint32_t result, |
75 | | size_t input_rows_count, const TAIResource& config, |
76 | 2 | std::shared_ptr<AIAdapter>& adapter) const { |
77 | 2 | auto col_result = ColumnArray::create( |
78 | 2 | ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create())); |
79 | | |
80 | 5 | for (size_t i = 0; i < input_rows_count; ++i) { |
81 | 3 | std::string prompt; |
82 | 3 | RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt)); |
83 | | |
84 | 3 | std::vector<float> float_result; |
85 | 3 | RETURN_IF_ERROR(execute_single_request(prompt, float_result, config, adapter, context)); |
86 | 3 | _insert_embedding_result(*col_result, float_result); |
87 | 3 | } |
88 | | |
89 | 2 | block.replace_by_position(result, std::move(col_result)); |
90 | 2 | return Status::OK(); |
91 | 2 | } |
92 | | |
93 | | Status _execute_multimodal_embed(FunctionContext* context, Block& block, |
94 | | const ColumnNumbers& arguments, uint32_t result, |
95 | | size_t input_rows_count, const TAIResource& config, |
96 | 6 | std::shared_ptr<AIAdapter>& adapter) const { |
97 | 6 | auto col_result = ColumnArray::create( |
98 | 6 | ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create())); |
99 | | |
100 | 6 | int64_t ttl_seconds = 3600; |
101 | 6 | QueryContext* query_ctx = context->state()->get_query_ctx(); |
102 | 6 | if (query_ctx && query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) { |
103 | 6 | ttl_seconds = query_ctx->query_options().file_presigned_url_ttl_seconds; |
104 | 6 | if (ttl_seconds <= 0) { |
105 | 1 | ttl_seconds = 3600; |
106 | 1 | } |
107 | 6 | } |
108 | | |
109 | 6 | const ColumnWithTypeAndName& file_column = block.get_by_position(arguments[1]); |
110 | 10 | for (size_t i = 0; i < input_rows_count; ++i) { |
111 | 8 | rapidjson::Document file_input; |
112 | 8 | RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input)); |
113 | | |
114 | 8 | MultimodalType media_type; |
115 | 8 | RETURN_IF_ERROR(_infer_media_type(file_input, media_type)); |
116 | | |
117 | 6 | std::string media_url; |
118 | 6 | RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds, media_url)); |
119 | | |
120 | 4 | std::string request_body; |
121 | 4 | RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(media_type, media_url, |
122 | 4 | request_body)); |
123 | | |
124 | 4 | std::vector<float> float_result; |
125 | 4 | RETURN_IF_ERROR(execute_embedding_request(request_body, float_result, config, adapter, |
126 | 4 | context)); |
127 | 4 | _insert_embedding_result(*col_result, float_result); |
128 | 4 | } |
129 | | |
130 | 2 | block.replace_by_position(result, std::move(col_result)); |
131 | 2 | return Status::OK(); |
132 | 6 | } |
133 | | |
134 | | static void _insert_embedding_result(ColumnArray& col_array, |
135 | 7 | const std::vector<float>& float_result) { |
136 | 7 | auto& offsets = col_array.get_offsets(); |
137 | 7 | auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data()); |
138 | 7 | auto& nested_col = |
139 | 7 | assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr())); |
140 | 7 | nested_col.reserve(nested_col.size() + float_result.size()); |
141 | | |
142 | 7 | size_t current_offset = nested_col.size(); |
143 | 7 | nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()), |
144 | 7 | float_result.size()); |
145 | 7 | offsets.push_back(current_offset + float_result.size()); |
146 | 7 | auto& null_map = nested_nullable_col.get_null_map_column(); |
147 | 7 | null_map.insert_many_vals(0, float_result.size()); |
148 | 7 | } |
149 | | |
150 | 24 | static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) { |
151 | 24 | if (s.size() < prefix.size()) { |
152 | 0 | return false; |
153 | 0 | } |
154 | 87 | return std::equal(prefix.begin(), prefix.end(), s.begin(), [](char a, char b) { |
155 | 87 | return std::tolower(static_cast<unsigned char>(a)) == |
156 | 87 | std::tolower(static_cast<unsigned char>(b)); |
157 | 87 | }); |
158 | 24 | } |
159 | | |
160 | | static Status _infer_media_type(const rapidjson::Value& file_input, |
161 | 8 | MultimodalType& media_type) { |
162 | 8 | std::string content_type; |
163 | 8 | RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type", content_type)); |
164 | | |
165 | 7 | if (_starts_with_ignore_case(content_type, "image/")) { |
166 | 4 | media_type = MultimodalType::IMAGE; |
167 | 4 | return Status::OK(); |
168 | 4 | } else if (_starts_with_ignore_case(content_type, "video/")) { |
169 | 1 | media_type = MultimodalType::VIDEO; |
170 | 1 | return Status::OK(); |
171 | 2 | } else if (_starts_with_ignore_case(content_type, "audio/")) { |
172 | 1 | media_type = MultimodalType::AUDIO; |
173 | 1 | return Status::OK(); |
174 | 1 | } |
175 | | |
176 | 1 | return Status::InvalidArgument("Unsupported content_type for EMBED: {}", content_type); |
177 | 7 | } |
178 | | |
179 | | // Parse the FILE-like JSONB argument into a JSON object for downstream field reads. |
180 | | static Status _parse_file_input(const ColumnWithTypeAndName& file_column, size_t row_num, |
181 | 8 | rapidjson::Document& file_input) { |
182 | 8 | std::string file_json = |
183 | 8 | JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data, |
184 | 8 | file_column.column->get_data_at(row_num).size); |
185 | 8 | file_input.Parse(file_json.c_str()); |
186 | 8 | DORIS_CHECK(!file_input.HasParseError() && file_input.IsObject()); |
187 | 8 | return Status::OK(); |
188 | 8 | } |
189 | | |
190 | | // TODO(lzq): After support FILE type, We should use the interface provided by FILE to get the fields |
191 | | // replacing this function |
192 | | static Status _get_required_string_field(const rapidjson::Value& obj, const char* field_name, |
193 | 19 | std::string& value) { |
194 | 19 | auto iter = obj.FindMember(field_name); |
195 | 19 | if (iter == obj.MemberEnd() || !iter->value.IsString()) { |
196 | 3 | return Status::InvalidArgument( |
197 | 3 | "EMBED file json field '{}' is required and must be a string", field_name); |
198 | 3 | } |
199 | 16 | value = iter->value.GetString(); |
200 | 16 | if (value.empty()) { |
201 | 0 | return Status::InvalidArgument("EMBED file json field '{}' can not be empty", |
202 | 0 | field_name); |
203 | 0 | } |
204 | 16 | return Status::OK(); |
205 | 16 | } |
206 | | |
207 | | static Status init_s3_client_conf_from_json(const rapidjson::Value& file_input, |
208 | 3 | S3ClientConf& s3_client_conf) { |
209 | 3 | std::string endpoint; |
210 | 3 | RETURN_IF_ERROR(_get_required_string_field(file_input, "endpoint", endpoint)); |
211 | 2 | std::string region; |
212 | 2 | RETURN_IF_ERROR(_get_required_string_field(file_input, "region", region)); |
213 | | |
214 | 4 | auto get_optional_string_field = [&](const char* field_name, std::string& value) { |
215 | 4 | auto iter = file_input.FindMember(field_name); |
216 | 4 | if (iter == file_input.MemberEnd() || iter->value.IsNull()) { |
217 | 0 | return; |
218 | 0 | } |
219 | 4 | DORIS_CHECK(iter->value.IsString()); |
220 | 4 | value = iter->value.GetString(); |
221 | 4 | }; |
222 | | |
223 | 1 | get_optional_string_field("ak", s3_client_conf.ak); |
224 | 1 | get_optional_string_field("sk", s3_client_conf.sk); |
225 | 1 | get_optional_string_field("role_arn", s3_client_conf.role_arn); |
226 | 1 | get_optional_string_field("external_id", s3_client_conf.external_id); |
227 | 1 | s3_client_conf.endpoint = endpoint; |
228 | 1 | s3_client_conf.region = region; |
229 | | |
230 | 1 | return Status::OK(); |
231 | 2 | } |
232 | | |
233 | | Status _resolve_media_url(const rapidjson::Value& file_input, int64_t ttl_seconds, |
234 | 6 | std::string& media_url) const { |
235 | 6 | std::string uri; |
236 | 6 | RETURN_IF_ERROR(_get_required_string_field(file_input, "uri", uri)); |
237 | | |
238 | | // If it's a direct http/https URL, use it as-is |
239 | 6 | if (_starts_with_ignore_case(uri, "http://") || _starts_with_ignore_case(uri, "https://")) { |
240 | 3 | media_url = uri; |
241 | 3 | return Status::OK(); |
242 | 3 | } |
243 | | |
244 | 3 | S3ClientConf s3_client_conf; |
245 | 3 | RETURN_IF_ERROR(init_s3_client_conf_from_json(file_input, s3_client_conf)); |
246 | 1 | auto s3_client = S3ClientFactory::instance().create(s3_client_conf); |
247 | 1 | if (s3_client == nullptr) { |
248 | 0 | return Status::InternalError("Failed to create S3 client for EMBED file input"); |
249 | 0 | } |
250 | | |
251 | 1 | S3URI s3_uri(uri); |
252 | 1 | RETURN_IF_ERROR(s3_uri.parse()); |
253 | 1 | std::string bucket = s3_uri.get_bucket(); |
254 | 1 | std::string key = s3_uri.get_key(); |
255 | 1 | DORIS_CHECK(!bucket.empty() && !key.empty()); |
256 | 1 | media_url = s3_client->generate_presigned_url({.bucket = bucket, .key = key}, ttl_seconds, |
257 | 1 | s3_client_conf); |
258 | 1 | return Status::OK(); |
259 | 1 | } |
260 | | }; |
261 | | |
262 | | }; // namespace doris |