Coverage Report

Created: 2026-04-15 18:59

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/embed.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <glog/logging.h>
#include <rapidjson/document.h>

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <vector>

#include "core/data_type/data_type_nullable.h"
#include "core/data_type/primitive_type.h"
#include "exprs/function/ai/ai_functions.h"
#include "util/jsonb_utils.h"
#include "util/s3_uri.h"
#include "util/s3_util.h"
31
32
namespace doris {
33
class FunctionEmbed : public AIFunction<FunctionEmbed> {
34
public:
35
    static constexpr auto name = "embed";
36
37
    static constexpr size_t number_of_arguments = 2;
38
39
    static constexpr auto system_prompt = "";
40
41
2
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
42
2
        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
43
2
    }
44
45
    Status execute_with_adapter(FunctionContext* context, Block& block,
46
                                const ColumnNumbers& arguments, uint32_t result,
47
                                size_t input_rows_count, const TAIResource& config,
48
20
                                std::shared_ptr<AIAdapter>& adapter) const {
49
20
        if (arguments.size() != 2) {
50
2
            return Status::InvalidArgument("Function EMBED expects 2 arguments, but got {}",
51
2
                                           arguments.size());
52
2
        }
53
54
18
        PrimitiveType input_type =
55
18
                remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type();
56
18
        if (input_type == PrimitiveType::TYPE_JSONB) {
57
12
            return _execute_multimodal_embed(context, block, arguments, result, input_rows_count,
58
12
                                             config, adapter);
59
12
        }
60
6
        if (input_type == PrimitiveType::TYPE_STRING || input_type == PrimitiveType::TYPE_VARCHAR ||
61
6
            input_type == PrimitiveType::TYPE_CHAR) {
62
4
            return _execute_text_embed(context, block, arguments, result, input_rows_count, config,
63
4
                                       adapter);
64
4
        }
65
2
        return Status::InvalidArgument(
66
2
                "Function EMBED expects the second argument to be STRING or JSON, but got type {}",
67
2
                block.get_by_position(arguments[1]).type->get_name());
68
6
    }
69
70
24
    static FunctionPtr create() { return std::make_shared<FunctionEmbed>(); }
71
72
private:
73
    Status _execute_text_embed(FunctionContext* context, Block& block,
74
                               const ColumnNumbers& arguments, uint32_t result,
75
                               size_t input_rows_count, const TAIResource& config,
76
4
                               std::shared_ptr<AIAdapter>& adapter) const {
77
4
        auto col_result = ColumnArray::create(
78
4
                ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
79
80
10
        for (size_t i = 0; i < input_rows_count; ++i) {
81
6
            std::string prompt;
82
6
            RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
83
84
6
            std::vector<float> float_result;
85
6
            RETURN_IF_ERROR(execute_single_request(prompt, float_result, config, adapter, context));
86
6
            _insert_embedding_result(*col_result, float_result);
87
6
        }
88
89
4
        block.replace_by_position(result, std::move(col_result));
90
4
        return Status::OK();
91
4
    }
92
93
    Status _execute_multimodal_embed(FunctionContext* context, Block& block,
94
                                     const ColumnNumbers& arguments, uint32_t result,
95
                                     size_t input_rows_count, const TAIResource& config,
96
12
                                     std::shared_ptr<AIAdapter>& adapter) const {
97
12
        auto col_result = ColumnArray::create(
98
12
                ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
99
100
12
        int64_t ttl_seconds = 3600;
101
12
        QueryContext* query_ctx = context->state()->get_query_ctx();
102
12
        if (query_ctx && query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) {
103
12
            ttl_seconds = query_ctx->query_options().file_presigned_url_ttl_seconds;
104
12
            if (ttl_seconds <= 0) {
105
2
                ttl_seconds = 3600;
106
2
            }
107
12
        }
108
109
12
        const ColumnWithTypeAndName& file_column = block.get_by_position(arguments[1]);
110
20
        for (size_t i = 0; i < input_rows_count; ++i) {
111
16
            rapidjson::Document file_input;
112
16
            RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input));
113
114
16
            MultimodalType media_type;
115
16
            RETURN_IF_ERROR(_infer_media_type(file_input, media_type));
116
117
12
            std::string media_url;
118
12
            RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds, media_url));
119
120
8
            std::string request_body;
121
8
            RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(media_type, media_url,
122
8
                                                                        request_body));
123
124
8
            std::vector<float> float_result;
125
8
            RETURN_IF_ERROR(execute_embedding_request(request_body, float_result, config, adapter,
126
8
                                                      context));
127
8
            _insert_embedding_result(*col_result, float_result);
128
8
        }
129
130
4
        block.replace_by_position(result, std::move(col_result));
131
4
        return Status::OK();
132
12
    }
133
134
    static void _insert_embedding_result(ColumnArray& col_array,
135
14
                                         const std::vector<float>& float_result) {
136
14
        auto& offsets = col_array.get_offsets();
137
14
        auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data());
138
14
        auto& nested_col =
139
14
                assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
140
14
        nested_col.reserve(nested_col.size() + float_result.size());
141
142
14
        size_t current_offset = nested_col.size();
143
14
        nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()),
144
14
                                        float_result.size());
145
14
        offsets.push_back(current_offset + float_result.size());
146
14
        auto& null_map = nested_nullable_col.get_null_map_column();
147
14
        null_map.insert_many_vals(0, float_result.size());
148
14
    }
149
150
48
    static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) {
151
48
        if (s.size() < prefix.size()) {
152
0
            return false;
153
0
        }
154
174
        return std::equal(prefix.begin(), prefix.end(), s.begin(), [](char a, char b) {
155
174
            return std::tolower(static_cast<unsigned char>(a)) ==
156
174
                   std::tolower(static_cast<unsigned char>(b));
157
174
        });
158
48
    }
159
160
    static Status _infer_media_type(const rapidjson::Value& file_input,
161
16
                                    MultimodalType& media_type) {
162
16
        std::string content_type;
163
16
        RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type", content_type));
164
165
14
        if (_starts_with_ignore_case(content_type, "image/")) {
166
8
            media_type = MultimodalType::IMAGE;
167
8
            return Status::OK();
168
8
        } else if (_starts_with_ignore_case(content_type, "video/")) {
169
2
            media_type = MultimodalType::VIDEO;
170
2
            return Status::OK();
171
4
        } else if (_starts_with_ignore_case(content_type, "audio/")) {
172
2
            media_type = MultimodalType::AUDIO;
173
2
            return Status::OK();
174
2
        }
175
176
2
        return Status::InvalidArgument("Unsupported content_type for EMBED: {}", content_type);
177
14
    }
178
179
    // Parse the FILE-like JSONB argument into a JSON object for downstream field reads.
180
    static Status _parse_file_input(const ColumnWithTypeAndName& file_column, size_t row_num,
181
16
                                    rapidjson::Document& file_input) {
182
16
        std::string file_json =
183
16
                JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data,
184
16
                                                  file_column.column->get_data_at(row_num).size);
185
16
        file_input.Parse(file_json.c_str());
186
16
        DORIS_CHECK(!file_input.HasParseError() && file_input.IsObject());
187
16
        return Status::OK();
188
16
    }
189
190
    // TODO(lzq): After support FILE type, We should use the interface provided by FILE to get the fields
191
    // replacing this function
192
    static Status _get_required_string_field(const rapidjson::Value& obj, const char* field_name,
193
38
                                             std::string& value) {
194
38
        auto iter = obj.FindMember(field_name);
195
38
        if (iter == obj.MemberEnd() || !iter->value.IsString()) {
196
6
            return Status::InvalidArgument(
197
6
                    "EMBED file json field '{}' is required and must be a string", field_name);
198
6
        }
199
32
        value = iter->value.GetString();
200
32
        if (value.empty()) {
201
0
            return Status::InvalidArgument("EMBED file json field '{}' can not be empty",
202
0
                                           field_name);
203
0
        }
204
32
        return Status::OK();
205
32
    }
206
207
    static Status init_s3_client_conf_from_json(const rapidjson::Value& file_input,
208
6
                                                S3ClientConf& s3_client_conf) {
209
6
        std::string endpoint;
210
6
        RETURN_IF_ERROR(_get_required_string_field(file_input, "endpoint", endpoint));
211
4
        std::string region;
212
4
        RETURN_IF_ERROR(_get_required_string_field(file_input, "region", region));
213
214
8
        auto get_optional_string_field = [&](const char* field_name, std::string& value) {
215
8
            auto iter = file_input.FindMember(field_name);
216
8
            if (iter == file_input.MemberEnd() || iter->value.IsNull()) {
217
0
                return;
218
0
            }
219
8
            DORIS_CHECK(iter->value.IsString());
220
8
            value = iter->value.GetString();
221
8
        };
222
223
2
        get_optional_string_field("ak", s3_client_conf.ak);
224
2
        get_optional_string_field("sk", s3_client_conf.sk);
225
2
        get_optional_string_field("role_arn", s3_client_conf.role_arn);
226
2
        get_optional_string_field("external_id", s3_client_conf.external_id);
227
2
        s3_client_conf.endpoint = endpoint;
228
2
        s3_client_conf.region = region;
229
230
2
        return Status::OK();
231
4
    }
232
233
    Status _resolve_media_url(const rapidjson::Value& file_input, int64_t ttl_seconds,
234
12
                              std::string& media_url) const {
235
12
        std::string uri;
236
12
        RETURN_IF_ERROR(_get_required_string_field(file_input, "uri", uri));
237
238
        // If it's a direct http/https URL, use it as-is
239
12
        if (_starts_with_ignore_case(uri, "http://") || _starts_with_ignore_case(uri, "https://")) {
240
6
            media_url = uri;
241
6
            return Status::OK();
242
6
        }
243
244
6
        S3ClientConf s3_client_conf;
245
6
        RETURN_IF_ERROR(init_s3_client_conf_from_json(file_input, s3_client_conf));
246
2
        auto s3_client = S3ClientFactory::instance().create(s3_client_conf);
247
2
        if (s3_client == nullptr) {
248
0
            return Status::InternalError("Failed to create S3 client for EMBED file input");
249
0
        }
250
251
2
        S3URI s3_uri(uri);
252
2
        RETURN_IF_ERROR(s3_uri.parse());
253
2
        std::string bucket = s3_uri.get_bucket();
254
2
        std::string key = s3_uri.get_key();
255
2
        DORIS_CHECK(!bucket.empty() && !key.empty());
256
2
        media_url = s3_client->generate_presigned_url({.bucket = bucket, .key = key}, ttl_seconds,
257
2
                                                      s3_client_conf);
258
2
        return Status::OK();
259
2
    }
260
};
261
262
}; // namespace doris