be/src/exprs/function/ai/embed.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <glog/logging.h> |
21 | | #include <rapidjson/document.h> |
22 | | |
23 | | #include <string_view> |
24 | | |
25 | | #include "core/data_type/data_type_nullable.h" |
26 | | #include "core/data_type/primitive_type.h" |
27 | | #include "exprs/function/ai/ai_functions.h" |
28 | | #include "util/jsonb_utils.h" |
29 | | #include "util/s3_uri.h" |
30 | | #include "util/s3_util.h" |
31 | | |
32 | | namespace doris { |
33 | | class FunctionEmbed : public AIFunction<FunctionEmbed> { |
34 | | public: |
35 | | static constexpr auto name = "embed"; |
36 | | |
37 | | static constexpr size_t number_of_arguments = 2; |
38 | | |
39 | | static constexpr auto system_prompt = ""; |
40 | | |
41 | 1 | DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { |
42 | 1 | return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>())); |
43 | 1 | } |
44 | | |
45 | | Status execute_with_adapter(FunctionContext* context, Block& block, |
46 | | const ColumnNumbers& arguments, uint32_t result, |
47 | | size_t input_rows_count, const TAIResource& config, |
48 | 13 | std::shared_ptr<AIAdapter>& adapter) const { |
49 | 13 | if (arguments.size() != 2) { |
50 | 1 | return Status::InvalidArgument("Function EMBED expects 2 arguments, but got {}", |
51 | 1 | arguments.size()); |
52 | 1 | } |
53 | | |
54 | 12 | PrimitiveType input_type = |
55 | 12 | remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type(); |
56 | 12 | if (input_type == PrimitiveType::TYPE_JSONB) { |
57 | 8 | return _execute_multimodal_embed(context, block, arguments, result, input_rows_count, |
58 | 8 | config, adapter); |
59 | 8 | } |
60 | 4 | if (input_type == PrimitiveType::TYPE_STRING || input_type == PrimitiveType::TYPE_VARCHAR || |
61 | 4 | input_type == PrimitiveType::TYPE_CHAR) { |
62 | 3 | return _execute_text_embed(context, block, arguments, result, input_rows_count, config, |
63 | 3 | adapter); |
64 | 3 | } |
65 | 1 | return Status::InvalidArgument( |
66 | 1 | "Function EMBED expects the second argument to be STRING or JSON, but got type {}", |
67 | 1 | block.get_by_position(arguments[1]).type->get_name()); |
68 | 4 | } |
69 | | |
70 | 12 | static FunctionPtr create() { return std::make_shared<FunctionEmbed>(); } |
71 | | |
72 | | private: |
73 | 11 | static int32_t _get_embed_max_batch_size(FunctionContext* context) { |
74 | 11 | QueryContext* query_ctx = context->state()->get_query_ctx(); |
75 | 11 | DORIS_CHECK(query_ctx != nullptr); |
76 | | |
77 | 11 | return query_ctx->query_options().embed_max_batch_size; |
78 | 11 | } |
79 | | |
80 | | Status _execute_text_embed(FunctionContext* context, Block& block, |
81 | | const ColumnNumbers& arguments, uint32_t result, |
82 | | size_t input_rows_count, const TAIResource& config, |
83 | 3 | std::shared_ptr<AIAdapter>& adapter) const { |
84 | 3 | auto col_result = ColumnArray::create( |
85 | 3 | ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create())); |
86 | 3 | std::vector<std::string> batch_prompts; |
87 | 3 | size_t current_batch_size = 0; |
88 | 3 | const int32_t max_batch_size = _get_embed_max_batch_size(context); |
89 | 3 | const size_t max_context_window_size = |
90 | 3 | static_cast<size_t>(get_ai_context_window_size(context)); |
91 | | |
92 | 9 | for (size_t i = 0; i < input_rows_count; ++i) { |
93 | 6 | std::string prompt; |
94 | 6 | RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt)); |
95 | | |
96 | 6 | const size_t prompt_size = prompt.size(); |
97 | | |
98 | 6 | if (prompt_size > max_context_window_size) { |
99 | | // flush history batch |
100 | 0 | RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config, |
101 | 0 | adapter, context)); |
102 | 0 | current_batch_size = 0; |
103 | |
|
104 | 0 | batch_prompts.emplace_back(std::move(prompt)); |
105 | 0 | RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config, |
106 | 0 | adapter, context)); |
107 | 0 | continue; |
108 | 0 | } |
109 | | |
110 | 6 | if (!batch_prompts.empty() && |
111 | 6 | (current_batch_size + prompt_size > max_context_window_size || |
112 | 3 | batch_prompts.size() >= static_cast<size_t>(max_batch_size))) { |
113 | 1 | RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config, |
114 | 1 | adapter, context)); |
115 | 1 | current_batch_size = 0; |
116 | 1 | } |
117 | | |
118 | 6 | batch_prompts.emplace_back(std::move(prompt)); |
119 | 6 | current_batch_size += prompt_size; |
120 | 6 | } |
121 | | |
122 | 3 | RETURN_IF_ERROR( |
123 | 3 | _flush_text_embedding_batch(batch_prompts, *col_result, config, adapter, context)); |
124 | | |
125 | 3 | block.replace_by_position(result, std::move(col_result)); |
126 | 3 | return Status::OK(); |
127 | 3 | } |
128 | | |
129 | | Status _execute_multimodal_embed(FunctionContext* context, Block& block, |
130 | | const ColumnNumbers& arguments, uint32_t result, |
131 | | size_t input_rows_count, const TAIResource& config, |
132 | 8 | std::shared_ptr<AIAdapter>& adapter) const { |
133 | 8 | auto col_result = ColumnArray::create( |
134 | 8 | ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create())); |
135 | 8 | std::vector<MultimodalType> batch_media_types; |
136 | 8 | std::vector<std::string> batch_media_content_types; |
137 | 8 | std::vector<std::string> batch_media_urls; |
138 | | |
139 | 8 | int64_t ttl_seconds = 3600; |
140 | 8 | QueryContext* query_ctx = context->state()->get_query_ctx(); |
141 | 8 | if (query_ctx && query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) { |
142 | 8 | ttl_seconds = query_ctx->query_options().file_presigned_url_ttl_seconds; |
143 | 8 | if (ttl_seconds <= 0) { |
144 | 1 | ttl_seconds = 3600; |
145 | 1 | } |
146 | 8 | } |
147 | | |
148 | 8 | const int32_t max_batch_size = _get_embed_max_batch_size(context); |
149 | | |
150 | 8 | const ColumnWithTypeAndName& file_column = block.get_by_position(arguments[1]); |
151 | 18 | for (size_t i = 0; i < input_rows_count; ++i) { |
152 | 14 | rapidjson::Document file_input; |
153 | 14 | RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input)); |
154 | | |
155 | 14 | std::string content_type; |
156 | 14 | MultimodalType media_type; |
157 | 14 | RETURN_IF_ERROR(_infer_media_type(file_input, content_type, media_type)); |
158 | | |
159 | 12 | std::string media_url; |
160 | 12 | RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds, media_url)); |
161 | | |
162 | 10 | if (!batch_media_urls.empty() && |
163 | 10 | batch_media_urls.size() >= static_cast<size_t>(max_batch_size)) { |
164 | 1 | RETURN_IF_ERROR(_flush_multimodal_embedding_batch( |
165 | 1 | batch_media_types, batch_media_content_types, batch_media_urls, *col_result, |
166 | 1 | config, adapter, context)); |
167 | 1 | } |
168 | | |
169 | 10 | batch_media_types.emplace_back(media_type); |
170 | 10 | batch_media_content_types.emplace_back(std::move(content_type)); |
171 | 10 | batch_media_urls.emplace_back(std::move(media_url)); |
172 | 10 | } |
173 | | |
174 | 4 | RETURN_IF_ERROR(_flush_multimodal_embedding_batch( |
175 | 4 | batch_media_types, batch_media_content_types, batch_media_urls, *col_result, config, |
176 | 4 | adapter, context)); |
177 | | |
178 | 4 | block.replace_by_position(result, std::move(col_result)); |
179 | 4 | return Status::OK(); |
180 | 4 | } |
181 | | |
182 | | // EMBED-private helper. |
183 | | // Sends one embedding request with a prebuilt request body and validates returned row count. |
184 | | Status _execute_prebuilt_embedding_request(const std::string& request_body, |
185 | | std::vector<std::vector<float>>& results, |
186 | | size_t expected_size, const TAIResource& config, |
187 | | std::shared_ptr<AIAdapter>& adapter, |
188 | 9 | FunctionContext* context) const { |
189 | 9 | std::string response; |
190 | 9 | #ifdef BE_TEST |
191 | 9 | if (config.provider_type == "MOCK") { |
192 | 9 | results.clear(); |
193 | 9 | results.reserve(expected_size); |
194 | 25 | for (size_t i = 0; i < expected_size; ++i) { |
195 | 16 | results.emplace_back(std::initializer_list<float> {0, 1, 2, 3, 4}); |
196 | 16 | } |
197 | 9 | return Status::OK(); |
198 | 9 | } |
199 | 0 | #endif |
200 | | |
201 | 0 | RETURN_IF_ERROR( |
202 | 0 | this->send_request_to_llm(request_body, response, config, adapter, context)); |
203 | | |
204 | 0 | RETURN_IF_ERROR(adapter->parse_embedding_response(response, results)); |
205 | 0 | if (results.empty()) { |
206 | 0 | return Status::InternalError("AI returned empty result"); |
207 | 0 | } |
208 | 0 | if (results.size() != expected_size) [[unlikely]] { |
209 | 0 | return Status::InternalError( |
210 | 0 | "AI embedding returned {} results, but {} inputs were sent", results.size(), |
211 | 0 | expected_size); |
212 | 0 | } |
213 | 0 | return Status::OK(); |
214 | 0 | } |
215 | | |
216 | | // EMBED-private helper. |
217 | | // Flushes one accumulated text embedding batch into the output array column. |
218 | | Status _flush_text_embedding_batch(std::vector<std::string>& batch_prompts, |
219 | | ColumnArray& col_result, const TAIResource& config, |
220 | | std::shared_ptr<AIAdapter>& adapter, |
221 | 4 | FunctionContext* context) const { |
222 | 4 | if (batch_prompts.empty()) { |
223 | 0 | return Status::OK(); |
224 | 0 | } |
225 | | |
226 | 4 | std::string request_body; |
227 | 4 | RETURN_IF_ERROR(adapter->build_embedding_request(batch_prompts, request_body)); |
228 | 4 | std::vector<std::vector<float>> batch_results; |
229 | 4 | RETURN_IF_ERROR(_execute_prebuilt_embedding_request( |
230 | 4 | request_body, batch_results, batch_prompts.size(), config, adapter, context)); |
231 | 6 | for (const auto& batch_result : batch_results) { |
232 | 6 | _insert_embedding_result(col_result, batch_result); |
233 | 6 | } |
234 | 4 | batch_prompts.clear(); |
235 | 4 | return Status::OK(); |
236 | 4 | } |
237 | | |
238 | | // EMBED-private helper. |
239 | | // Flushes one accumulated multimodal embedding batch into the output array column. |
240 | | Status _flush_multimodal_embedding_batch(std::vector<MultimodalType>& batch_media_types, |
241 | | std::vector<std::string>& batch_media_content_types, |
242 | | std::vector<std::string>& batch_media_urls, |
243 | | ColumnArray& col_result, const TAIResource& config, |
244 | | std::shared_ptr<AIAdapter>& adapter, |
245 | 5 | FunctionContext* context) const { |
246 | 5 | if (batch_media_urls.empty()) { |
247 | 0 | return Status::OK(); |
248 | 0 | } |
249 | | |
250 | 5 | std::string request_body; |
251 | 5 | RETURN_IF_ERROR(adapter->build_multimodal_embedding_request( |
252 | 5 | batch_media_types, batch_media_urls, batch_media_content_types, request_body)); |
253 | | |
254 | 5 | std::vector<std::vector<float>> batch_results; |
255 | 5 | RETURN_IF_ERROR(_execute_prebuilt_embedding_request( |
256 | 5 | request_body, batch_results, batch_media_urls.size(), config, adapter, context)); |
257 | 10 | for (const auto& batch_result : batch_results) { |
258 | 10 | _insert_embedding_result(col_result, batch_result); |
259 | 10 | } |
260 | 5 | batch_media_types.clear(); |
261 | 5 | batch_media_content_types.clear(); |
262 | 5 | batch_media_urls.clear(); |
263 | 5 | return Status::OK(); |
264 | 5 | } |
265 | | |
266 | | static void _insert_embedding_result(ColumnArray& col_array, |
267 | 16 | const std::vector<float>& float_result) { |
268 | 16 | auto& offsets = col_array.get_offsets(); |
269 | 16 | auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data()); |
270 | 16 | auto& nested_col = |
271 | 16 | assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr())); |
272 | 16 | nested_col.reserve(nested_col.size() + float_result.size()); |
273 | | |
274 | 16 | size_t current_offset = nested_col.size(); |
275 | 16 | nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()), |
276 | 16 | float_result.size()); |
277 | 16 | offsets.push_back(current_offset + float_result.size()); |
278 | 16 | auto& null_map = nested_nullable_col.get_null_map_column(); |
279 | 16 | null_map.insert_many_vals(0, float_result.size()); |
280 | 16 | } |
281 | | |
282 | 45 | static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) { |
283 | 45 | if (s.size() < prefix.size()) { |
284 | 0 | return false; |
285 | 0 | } |
286 | 204 | return std::equal(prefix.begin(), prefix.end(), s.begin(), [](char a, char b) { |
287 | 204 | return std::tolower(static_cast<unsigned char>(a)) == |
288 | 204 | std::tolower(static_cast<unsigned char>(b)); |
289 | 204 | }); |
290 | 45 | } |
291 | | |
292 | | static Status _infer_media_type(const rapidjson::Value& file_input, std::string& content_type, |
293 | 14 | MultimodalType& media_type) { |
294 | 14 | RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type", content_type)); |
295 | | |
296 | 13 | if (_starts_with_ignore_case(content_type, "image/")) { |
297 | 8 | media_type = MultimodalType::IMAGE; |
298 | 8 | return Status::OK(); |
299 | 8 | } else if (_starts_with_ignore_case(content_type, "video/")) { |
300 | 2 | media_type = MultimodalType::VIDEO; |
301 | 2 | return Status::OK(); |
302 | 3 | } else if (_starts_with_ignore_case(content_type, "audio/")) { |
303 | 2 | media_type = MultimodalType::AUDIO; |
304 | 2 | return Status::OK(); |
305 | 2 | } |
306 | | |
307 | 1 | return Status::InvalidArgument("Unsupported content_type for EMBED: {}", content_type); |
308 | 13 | } |
309 | | |
310 | | // Parse the FILE-like JSONB argument into a JSON object for downstream field reads. |
311 | | static Status _parse_file_input(const ColumnWithTypeAndName& file_column, size_t row_num, |
312 | 14 | rapidjson::Document& file_input) { |
313 | 14 | std::string file_json = |
314 | 14 | JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data, |
315 | 14 | file_column.column->get_data_at(row_num).size); |
316 | 14 | file_input.Parse(file_json.c_str()); |
317 | 14 | DORIS_CHECK(!file_input.HasParseError() && file_input.IsObject()); |
318 | 14 | return Status::OK(); |
319 | 14 | } |
320 | | |
321 | | // TODO(lzq): After support FILE type, We should use the interface provided by FILE to get the fields |
322 | | // replacing this function |
323 | | static Status _get_required_string_field(const rapidjson::Value& obj, const char* field_name, |
324 | 31 | std::string& value) { |
325 | 31 | auto iter = obj.FindMember(field_name); |
326 | 31 | if (iter == obj.MemberEnd() || !iter->value.IsString()) { |
327 | 3 | return Status::InvalidArgument( |
328 | 3 | "EMBED file json field '{}' is required and must be a string", field_name); |
329 | 3 | } |
330 | 28 | value = iter->value.GetString(); |
331 | 28 | if (value.empty()) { |
332 | 0 | return Status::InvalidArgument("EMBED file json field '{}' can not be empty", |
333 | 0 | field_name); |
334 | 0 | } |
335 | 28 | return Status::OK(); |
336 | 28 | } |
337 | | |
338 | | static Status init_s3_client_conf_from_json(const rapidjson::Value& file_input, |
339 | 3 | S3ClientConf& s3_client_conf) { |
340 | 3 | std::string endpoint; |
341 | 3 | RETURN_IF_ERROR(_get_required_string_field(file_input, "endpoint", endpoint)); |
342 | 2 | std::string region; |
343 | 2 | RETURN_IF_ERROR(_get_required_string_field(file_input, "region", region)); |
344 | | |
345 | 4 | auto get_optional_string_field = [&](const char* field_name, std::string& value) { |
346 | 4 | auto iter = file_input.FindMember(field_name); |
347 | 4 | if (iter == file_input.MemberEnd() || iter->value.IsNull()) { |
348 | 0 | return; |
349 | 0 | } |
350 | 4 | DORIS_CHECK(iter->value.IsString()); |
351 | 4 | value = iter->value.GetString(); |
352 | 4 | }; |
353 | | |
354 | 1 | get_optional_string_field("ak", s3_client_conf.ak); |
355 | 1 | get_optional_string_field("sk", s3_client_conf.sk); |
356 | 1 | get_optional_string_field("role_arn", s3_client_conf.role_arn); |
357 | 1 | get_optional_string_field("external_id", s3_client_conf.external_id); |
358 | 1 | s3_client_conf.endpoint = endpoint; |
359 | 1 | s3_client_conf.region = region; |
360 | | |
361 | 1 | return Status::OK(); |
362 | 2 | } |
363 | | |
364 | | Status _resolve_media_url(const rapidjson::Value& file_input, int64_t ttl_seconds, |
365 | 12 | std::string& media_url) const { |
366 | 12 | std::string uri; |
367 | 12 | RETURN_IF_ERROR(_get_required_string_field(file_input, "uri", uri)); |
368 | | |
369 | | // If it's a direct http/https URL, use it as-is |
370 | 12 | if (_starts_with_ignore_case(uri, "http://") || _starts_with_ignore_case(uri, "https://")) { |
371 | 9 | media_url = uri; |
372 | 9 | return Status::OK(); |
373 | 9 | } |
374 | | |
375 | 3 | S3ClientConf s3_client_conf; |
376 | 3 | RETURN_IF_ERROR(init_s3_client_conf_from_json(file_input, s3_client_conf)); |
377 | 1 | auto s3_client = S3ClientFactory::instance().create(s3_client_conf); |
378 | 1 | if (s3_client == nullptr) { |
379 | 0 | return Status::InternalError("Failed to create S3 client for EMBED file input"); |
380 | 0 | } |
381 | | |
382 | 1 | S3URI s3_uri(uri); |
383 | 1 | RETURN_IF_ERROR(s3_uri.parse()); |
384 | 1 | std::string bucket = s3_uri.get_bucket(); |
385 | 1 | std::string key = s3_uri.get_key(); |
386 | 1 | DORIS_CHECK(!bucket.empty() && !key.empty()); |
387 | 1 | media_url = s3_client->generate_presigned_url({.bucket = bucket, .key = key}, ttl_seconds, |
388 | 1 | s3_client_conf); |
389 | 1 | return Status::OK(); |
390 | 1 | } |
391 | | }; |
392 | | |
393 | | }; // namespace doris |