Coverage Report

Created: 2026-04-17 23:17

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/embed.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <glog/logging.h>
21
#include <rapidjson/document.h>
22
23
#include <string_view>
24
25
#include "core/data_type/data_type_nullable.h"
26
#include "core/data_type/primitive_type.h"
27
#include "exprs/function/ai/ai_functions.h"
28
#include "util/jsonb_utils.h"
29
#include "util/s3_uri.h"
30
#include "util/s3_util.h"
31
32
namespace doris {
33
class FunctionEmbed : public AIFunction<FunctionEmbed> {
34
public:
35
    static constexpr auto name = "embed";
36
37
    static constexpr size_t number_of_arguments = 2;
38
39
    static constexpr auto system_prompt = "";
40
41
1
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
42
1
        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
43
1
    }
44
45
    Status execute_with_adapter(FunctionContext* context, Block& block,
46
                                const ColumnNumbers& arguments, uint32_t result,
47
                                size_t input_rows_count, const TAIResource& config,
48
13
                                std::shared_ptr<AIAdapter>& adapter) const {
49
13
        if (arguments.size() != 2) {
50
1
            return Status::InvalidArgument("Function EMBED expects 2 arguments, but got {}",
51
1
                                           arguments.size());
52
1
        }
53
54
12
        PrimitiveType input_type =
55
12
                remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type();
56
12
        if (input_type == PrimitiveType::TYPE_JSONB) {
57
8
            return _execute_multimodal_embed(context, block, arguments, result, input_rows_count,
58
8
                                             config, adapter);
59
8
        }
60
4
        if (input_type == PrimitiveType::TYPE_STRING || input_type == PrimitiveType::TYPE_VARCHAR ||
61
4
            input_type == PrimitiveType::TYPE_CHAR) {
62
3
            return _execute_text_embed(context, block, arguments, result, input_rows_count, config,
63
3
                                       adapter);
64
3
        }
65
1
        return Status::InvalidArgument(
66
1
                "Function EMBED expects the second argument to be STRING or JSON, but got type {}",
67
1
                block.get_by_position(arguments[1]).type->get_name());
68
4
    }
69
70
12
    static FunctionPtr create() { return std::make_shared<FunctionEmbed>(); }
71
72
private:
73
11
    static int32_t _get_embed_max_batch_size(FunctionContext* context) {
74
11
        QueryContext* query_ctx = context->state()->get_query_ctx();
75
11
        DORIS_CHECK(query_ctx != nullptr);
76
77
11
        return query_ctx->query_options().embed_max_batch_size;
78
11
    }
79
80
    Status _execute_text_embed(FunctionContext* context, Block& block,
81
                               const ColumnNumbers& arguments, uint32_t result,
82
                               size_t input_rows_count, const TAIResource& config,
83
3
                               std::shared_ptr<AIAdapter>& adapter) const {
84
3
        auto col_result = ColumnArray::create(
85
3
                ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
86
3
        std::vector<std::string> batch_prompts;
87
3
        size_t current_batch_size = 0;
88
3
        const int32_t max_batch_size = _get_embed_max_batch_size(context);
89
3
        const size_t max_context_window_size =
90
3
                static_cast<size_t>(get_ai_context_window_size(context));
91
92
9
        for (size_t i = 0; i < input_rows_count; ++i) {
93
6
            std::string prompt;
94
6
            RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
95
96
6
            const size_t prompt_size = prompt.size();
97
98
6
            if (prompt_size > max_context_window_size) {
99
                // flush history batch
100
0
                RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config,
101
0
                                                            adapter, context));
102
0
                current_batch_size = 0;
103
104
0
                batch_prompts.emplace_back(std::move(prompt));
105
0
                RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config,
106
0
                                                            adapter, context));
107
0
                continue;
108
0
            }
109
110
6
            if (!batch_prompts.empty() &&
111
6
                (current_batch_size + prompt_size > max_context_window_size ||
112
3
                 batch_prompts.size() >= static_cast<size_t>(max_batch_size))) {
113
1
                RETURN_IF_ERROR(_flush_text_embedding_batch(batch_prompts, *col_result, config,
114
1
                                                            adapter, context));
115
1
                current_batch_size = 0;
116
1
            }
117
118
6
            batch_prompts.emplace_back(std::move(prompt));
119
6
            current_batch_size += prompt_size;
120
6
        }
121
122
3
        RETURN_IF_ERROR(
123
3
                _flush_text_embedding_batch(batch_prompts, *col_result, config, adapter, context));
124
125
3
        block.replace_by_position(result, std::move(col_result));
126
3
        return Status::OK();
127
3
    }
128
129
    Status _execute_multimodal_embed(FunctionContext* context, Block& block,
130
                                     const ColumnNumbers& arguments, uint32_t result,
131
                                     size_t input_rows_count, const TAIResource& config,
132
8
                                     std::shared_ptr<AIAdapter>& adapter) const {
133
8
        auto col_result = ColumnArray::create(
134
8
                ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
135
8
        std::vector<MultimodalType> batch_media_types;
136
8
        std::vector<std::string> batch_media_content_types;
137
8
        std::vector<std::string> batch_media_urls;
138
139
8
        int64_t ttl_seconds = 3600;
140
8
        QueryContext* query_ctx = context->state()->get_query_ctx();
141
8
        if (query_ctx && query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) {
142
8
            ttl_seconds = query_ctx->query_options().file_presigned_url_ttl_seconds;
143
8
            if (ttl_seconds <= 0) {
144
1
                ttl_seconds = 3600;
145
1
            }
146
8
        }
147
148
8
        const int32_t max_batch_size = _get_embed_max_batch_size(context);
149
150
8
        const ColumnWithTypeAndName& file_column = block.get_by_position(arguments[1]);
151
18
        for (size_t i = 0; i < input_rows_count; ++i) {
152
14
            rapidjson::Document file_input;
153
14
            RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input));
154
155
14
            std::string content_type;
156
14
            MultimodalType media_type;
157
14
            RETURN_IF_ERROR(_infer_media_type(file_input, content_type, media_type));
158
159
12
            std::string media_url;
160
12
            RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds, media_url));
161
162
10
            if (!batch_media_urls.empty() &&
163
10
                batch_media_urls.size() >= static_cast<size_t>(max_batch_size)) {
164
1
                RETURN_IF_ERROR(_flush_multimodal_embedding_batch(
165
1
                        batch_media_types, batch_media_content_types, batch_media_urls, *col_result,
166
1
                        config, adapter, context));
167
1
            }
168
169
10
            batch_media_types.emplace_back(media_type);
170
10
            batch_media_content_types.emplace_back(std::move(content_type));
171
10
            batch_media_urls.emplace_back(std::move(media_url));
172
10
        }
173
174
4
        RETURN_IF_ERROR(_flush_multimodal_embedding_batch(
175
4
                batch_media_types, batch_media_content_types, batch_media_urls, *col_result, config,
176
4
                adapter, context));
177
178
4
        block.replace_by_position(result, std::move(col_result));
179
4
        return Status::OK();
180
4
    }
181
182
    // EMBED-private helper.
183
    // Sends one embedding request with a prebuilt request body and validates returned row count.
184
    Status _execute_prebuilt_embedding_request(const std::string& request_body,
185
                                               std::vector<std::vector<float>>& results,
186
                                               size_t expected_size, const TAIResource& config,
187
                                               std::shared_ptr<AIAdapter>& adapter,
188
9
                                               FunctionContext* context) const {
189
9
        std::string response;
190
9
#ifdef BE_TEST
191
9
        if (config.provider_type == "MOCK") {
192
9
            results.clear();
193
9
            results.reserve(expected_size);
194
25
            for (size_t i = 0; i < expected_size; ++i) {
195
16
                results.emplace_back(std::initializer_list<float> {0, 1, 2, 3, 4});
196
16
            }
197
9
            return Status::OK();
198
9
        }
199
0
#endif
200
201
0
        RETURN_IF_ERROR(
202
0
                this->send_request_to_llm(request_body, response, config, adapter, context));
203
204
0
        RETURN_IF_ERROR(adapter->parse_embedding_response(response, results));
205
0
        if (results.empty()) {
206
0
            return Status::InternalError("AI returned empty result");
207
0
        }
208
0
        if (results.size() != expected_size) [[unlikely]] {
209
0
            return Status::InternalError(
210
0
                    "AI embedding returned {} results, but {} inputs were sent", results.size(),
211
0
                    expected_size);
212
0
        }
213
0
        return Status::OK();
214
0
    }
215
216
    // EMBED-private helper.
217
    // Flushes one accumulated text embedding batch into the output array column.
218
    Status _flush_text_embedding_batch(std::vector<std::string>& batch_prompts,
219
                                       ColumnArray& col_result, const TAIResource& config,
220
                                       std::shared_ptr<AIAdapter>& adapter,
221
4
                                       FunctionContext* context) const {
222
4
        if (batch_prompts.empty()) {
223
0
            return Status::OK();
224
0
        }
225
226
4
        std::string request_body;
227
4
        RETURN_IF_ERROR(adapter->build_embedding_request(batch_prompts, request_body));
228
4
        std::vector<std::vector<float>> batch_results;
229
4
        RETURN_IF_ERROR(_execute_prebuilt_embedding_request(
230
4
                request_body, batch_results, batch_prompts.size(), config, adapter, context));
231
6
        for (const auto& batch_result : batch_results) {
232
6
            _insert_embedding_result(col_result, batch_result);
233
6
        }
234
4
        batch_prompts.clear();
235
4
        return Status::OK();
236
4
    }
237
238
    // EMBED-private helper.
239
    // Flushes one accumulated multimodal embedding batch into the output array column.
240
    Status _flush_multimodal_embedding_batch(std::vector<MultimodalType>& batch_media_types,
241
                                             std::vector<std::string>& batch_media_content_types,
242
                                             std::vector<std::string>& batch_media_urls,
243
                                             ColumnArray& col_result, const TAIResource& config,
244
                                             std::shared_ptr<AIAdapter>& adapter,
245
5
                                             FunctionContext* context) const {
246
5
        if (batch_media_urls.empty()) {
247
0
            return Status::OK();
248
0
        }
249
250
5
        std::string request_body;
251
5
        RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(
252
5
                batch_media_types, batch_media_urls, batch_media_content_types, request_body));
253
254
5
        std::vector<std::vector<float>> batch_results;
255
5
        RETURN_IF_ERROR(_execute_prebuilt_embedding_request(
256
5
                request_body, batch_results, batch_media_urls.size(), config, adapter, context));
257
10
        for (const auto& batch_result : batch_results) {
258
10
            _insert_embedding_result(col_result, batch_result);
259
10
        }
260
5
        batch_media_types.clear();
261
5
        batch_media_content_types.clear();
262
5
        batch_media_urls.clear();
263
5
        return Status::OK();
264
5
    }
265
266
    static void _insert_embedding_result(ColumnArray& col_array,
267
16
                                         const std::vector<float>& float_result) {
268
16
        auto& offsets = col_array.get_offsets();
269
16
        auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data());
270
16
        auto& nested_col =
271
16
                assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
272
16
        nested_col.reserve(nested_col.size() + float_result.size());
273
274
16
        size_t current_offset = nested_col.size();
275
16
        nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()),
276
16
                                        float_result.size());
277
16
        offsets.push_back(current_offset + float_result.size());
278
16
        auto& null_map = nested_nullable_col.get_null_map_column();
279
16
        null_map.insert_many_vals(0, float_result.size());
280
16
    }
281
282
45
    static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) {
283
45
        if (s.size() < prefix.size()) {
284
0
            return false;
285
0
        }
286
204
        return std::equal(prefix.begin(), prefix.end(), s.begin(), [](char a, char b) {
287
204
            return std::tolower(static_cast<unsigned char>(a)) ==
288
204
                   std::tolower(static_cast<unsigned char>(b));
289
204
        });
290
45
    }
291
292
    static Status _infer_media_type(const rapidjson::Value& file_input, std::string& content_type,
293
14
                                    MultimodalType& media_type) {
294
14
        RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type", content_type));
295
296
13
        if (_starts_with_ignore_case(content_type, "image/")) {
297
8
            media_type = MultimodalType::IMAGE;
298
8
            return Status::OK();
299
8
        } else if (_starts_with_ignore_case(content_type, "video/")) {
300
2
            media_type = MultimodalType::VIDEO;
301
2
            return Status::OK();
302
3
        } else if (_starts_with_ignore_case(content_type, "audio/")) {
303
2
            media_type = MultimodalType::AUDIO;
304
2
            return Status::OK();
305
2
        }
306
307
1
        return Status::InvalidArgument("Unsupported content_type for EMBED: {}", content_type);
308
13
    }
309
310
    // Parse the FILE-like JSONB argument into a JSON object for downstream field reads.
311
    static Status _parse_file_input(const ColumnWithTypeAndName& file_column, size_t row_num,
312
14
                                    rapidjson::Document& file_input) {
313
14
        std::string file_json =
314
14
                JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data,
315
14
                                                  file_column.column->get_data_at(row_num).size);
316
14
        file_input.Parse(file_json.c_str());
317
14
        DORIS_CHECK(!file_input.HasParseError() && file_input.IsObject());
318
14
        return Status::OK();
319
14
    }
320
321
    // TODO(lzq): After support FILE type, We should use the interface provided by FILE to get the fields
322
    // replacing this function
323
    static Status _get_required_string_field(const rapidjson::Value& obj, const char* field_name,
324
31
                                             std::string& value) {
325
31
        auto iter = obj.FindMember(field_name);
326
31
        if (iter == obj.MemberEnd() || !iter->value.IsString()) {
327
3
            return Status::InvalidArgument(
328
3
                    "EMBED file json field '{}' is required and must be a string", field_name);
329
3
        }
330
28
        value = iter->value.GetString();
331
28
        if (value.empty()) {
332
0
            return Status::InvalidArgument("EMBED file json field '{}' can not be empty",
333
0
                                           field_name);
334
0
        }
335
28
        return Status::OK();
336
28
    }
337
338
    static Status init_s3_client_conf_from_json(const rapidjson::Value& file_input,
339
3
                                                S3ClientConf& s3_client_conf) {
340
3
        std::string endpoint;
341
3
        RETURN_IF_ERROR(_get_required_string_field(file_input, "endpoint", endpoint));
342
2
        std::string region;
343
2
        RETURN_IF_ERROR(_get_required_string_field(file_input, "region", region));
344
345
4
        auto get_optional_string_field = [&](const char* field_name, std::string& value) {
346
4
            auto iter = file_input.FindMember(field_name);
347
4
            if (iter == file_input.MemberEnd() || iter->value.IsNull()) {
348
0
                return;
349
0
            }
350
4
            DORIS_CHECK(iter->value.IsString());
351
4
            value = iter->value.GetString();
352
4
        };
353
354
1
        get_optional_string_field("ak", s3_client_conf.ak);
355
1
        get_optional_string_field("sk", s3_client_conf.sk);
356
1
        get_optional_string_field("role_arn", s3_client_conf.role_arn);
357
1
        get_optional_string_field("external_id", s3_client_conf.external_id);
358
1
        s3_client_conf.endpoint = endpoint;
359
1
        s3_client_conf.region = region;
360
361
1
        return Status::OK();
362
2
    }
363
364
    Status _resolve_media_url(const rapidjson::Value& file_input, int64_t ttl_seconds,
365
12
                              std::string& media_url) const {
366
12
        std::string uri;
367
12
        RETURN_IF_ERROR(_get_required_string_field(file_input, "uri", uri));
368
369
        // If it's a direct http/https URL, use it as-is
370
12
        if (_starts_with_ignore_case(uri, "http://") || _starts_with_ignore_case(uri, "https://")) {
371
9
            media_url = uri;
372
9
            return Status::OK();
373
9
        }
374
375
3
        S3ClientConf s3_client_conf;
376
3
        RETURN_IF_ERROR(init_s3_client_conf_from_json(file_input, s3_client_conf));
377
1
        auto s3_client = S3ClientFactory::instance().create(s3_client_conf);
378
1
        if (s3_client == nullptr) {
379
0
            return Status::InternalError("Failed to create S3 client for EMBED file input");
380
0
        }
381
382
1
        S3URI s3_uri(uri);
383
1
        RETURN_IF_ERROR(s3_uri.parse());
384
1
        std::string bucket = s3_uri.get_bucket();
385
1
        std::string key = s3_uri.get_key();
386
1
        DORIS_CHECK(!bucket.empty() && !key.empty());
387
1
        media_url = s3_client->generate_presigned_url({.bucket = bucket, .key = key}, ttl_seconds,
388
1
                                                      s3_client_conf);
389
1
        return Status::OK();
390
1
    }
391
};
392
393
}; // namespace doris