Coverage Report

Created: 2026-04-13 08:21

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/embed.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <glog/logging.h>
21
#include <rapidjson/document.h>
22
23
#include <string_view>
24
25
#include "core/data_type/data_type_nullable.h"
26
#include "core/data_type/primitive_type.h"
27
#include "exprs/function/ai/ai_functions.h"
28
#include "util/jsonb_utils.h"
29
#include "util/s3_uri.h"
30
#include "util/s3_util.h"
31
32
namespace doris {
33
class FunctionEmbed : public AIFunction<FunctionEmbed> {
34
public:
35
    static constexpr auto name = "embed";
36
37
    static constexpr size_t number_of_arguments = 2;
38
39
    static constexpr auto system_prompt = "";
40
41
1
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
42
1
        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
43
1
    }
44
45
    Status execute_with_adapter(FunctionContext* context, Block& block,
46
                                const ColumnNumbers& arguments, uint32_t result,
47
                                size_t input_rows_count, const TAIResource& config,
48
10
                                std::shared_ptr<AIAdapter>& adapter) const {
49
10
        if (arguments.size() != 2) {
50
1
            return Status::InvalidArgument("Function EMBED expects 2 arguments, but got {}",
51
1
                                           arguments.size());
52
1
        }
53
54
9
        PrimitiveType input_type =
55
9
                remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type();
56
9
        if (input_type == PrimitiveType::TYPE_JSONB) {
57
6
            return _execute_multimodal_embed(context, block, arguments, result, input_rows_count,
58
6
                                             config, adapter);
59
6
        }
60
3
        if (input_type == PrimitiveType::TYPE_STRING || input_type == PrimitiveType::TYPE_VARCHAR ||
61
3
            input_type == PrimitiveType::TYPE_CHAR) {
62
2
            return _execute_text_embed(context, block, arguments, result, input_rows_count, config,
63
2
                                       adapter);
64
2
        }
65
1
        return Status::InvalidArgument(
66
1
                "Function EMBED expects the second argument to be STRING or JSON, but got type {}",
67
1
                block.get_by_position(arguments[1]).type->get_name());
68
3
    }
69
70
19
    static FunctionPtr create() { return std::make_shared<FunctionEmbed>(); }
71
72
private:
73
    Status _execute_text_embed(FunctionContext* context, Block& block,
74
                               const ColumnNumbers& arguments, uint32_t result,
75
                               size_t input_rows_count, const TAIResource& config,
76
2
                               std::shared_ptr<AIAdapter>& adapter) const {
77
2
        auto col_result = ColumnArray::create(
78
2
                ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
79
80
5
        for (size_t i = 0; i < input_rows_count; ++i) {
81
3
            std::string prompt;
82
3
            RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
83
84
3
            std::vector<float> float_result;
85
3
            RETURN_IF_ERROR(execute_single_request(prompt, float_result, config, adapter, context));
86
3
            _insert_embedding_result(*col_result, float_result);
87
3
        }
88
89
2
        block.replace_by_position(result, std::move(col_result));
90
2
        return Status::OK();
91
2
    }
92
93
    Status _execute_multimodal_embed(FunctionContext* context, Block& block,
94
                                     const ColumnNumbers& arguments, uint32_t result,
95
                                     size_t input_rows_count, const TAIResource& config,
96
6
                                     std::shared_ptr<AIAdapter>& adapter) const {
97
6
        auto col_result = ColumnArray::create(
98
6
                ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
99
100
6
        int64_t ttl_seconds = 3600;
101
6
        QueryContext* query_ctx = context->state()->get_query_ctx();
102
6
        if (query_ctx && query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) {
103
6
            ttl_seconds = query_ctx->query_options().file_presigned_url_ttl_seconds;
104
6
            if (ttl_seconds <= 0) {
105
1
                ttl_seconds = 3600;
106
1
            }
107
6
        }
108
109
6
        const ColumnWithTypeAndName& file_column = block.get_by_position(arguments[1]);
110
10
        for (size_t i = 0; i < input_rows_count; ++i) {
111
8
            rapidjson::Document file_input;
112
8
            RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input));
113
114
8
            MultimodalType media_type;
115
8
            RETURN_IF_ERROR(_infer_media_type(file_input, media_type));
116
117
6
            std::string media_url;
118
6
            RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds, media_url));
119
120
4
            std::string request_body;
121
4
            RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(media_type, media_url,
122
4
                                                                        request_body));
123
124
4
            std::vector<float> float_result;
125
4
            RETURN_IF_ERROR(execute_embedding_request(request_body, float_result, config, adapter,
126
4
                                                      context));
127
4
            _insert_embedding_result(*col_result, float_result);
128
4
        }
129
130
2
        block.replace_by_position(result, std::move(col_result));
131
2
        return Status::OK();
132
6
    }
133
134
    static void _insert_embedding_result(ColumnArray& col_array,
135
7
                                         const std::vector<float>& float_result) {
136
7
        auto& offsets = col_array.get_offsets();
137
7
        auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data());
138
7
        auto& nested_col =
139
7
                assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
140
7
        nested_col.reserve(nested_col.size() + float_result.size());
141
142
7
        size_t current_offset = nested_col.size();
143
7
        nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()),
144
7
                                        float_result.size());
145
7
        offsets.push_back(current_offset + float_result.size());
146
7
        auto& null_map = nested_nullable_col.get_null_map_column();
147
7
        null_map.insert_many_vals(0, float_result.size());
148
7
    }
149
150
24
    static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) {
151
24
        if (s.size() < prefix.size()) {
152
0
            return false;
153
0
        }
154
87
        return std::equal(prefix.begin(), prefix.end(), s.begin(), [](char a, char b) {
155
87
            return std::tolower(static_cast<unsigned char>(a)) ==
156
87
                   std::tolower(static_cast<unsigned char>(b));
157
87
        });
158
24
    }
159
160
    static Status _infer_media_type(const rapidjson::Value& file_input,
161
8
                                    MultimodalType& media_type) {
162
8
        std::string content_type;
163
8
        RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type", content_type));
164
165
7
        if (_starts_with_ignore_case(content_type, "image/")) {
166
4
            media_type = MultimodalType::IMAGE;
167
4
            return Status::OK();
168
4
        } else if (_starts_with_ignore_case(content_type, "video/")) {
169
1
            media_type = MultimodalType::VIDEO;
170
1
            return Status::OK();
171
2
        } else if (_starts_with_ignore_case(content_type, "audio/")) {
172
1
            media_type = MultimodalType::AUDIO;
173
1
            return Status::OK();
174
1
        }
175
176
1
        return Status::InvalidArgument("Unsupported content_type for EMBED: {}", content_type);
177
7
    }
178
179
    // Parse the FILE-like JSONB argument into a JSON object for downstream field reads.
180
    static Status _parse_file_input(const ColumnWithTypeAndName& file_column, size_t row_num,
181
8
                                    rapidjson::Document& file_input) {
182
8
        std::string file_json =
183
8
                JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data,
184
8
                                                  file_column.column->get_data_at(row_num).size);
185
8
        file_input.Parse(file_json.c_str());
186
8
        DORIS_CHECK(!file_input.HasParseError() && file_input.IsObject());
187
8
        return Status::OK();
188
8
    }
189
190
    // TODO(lzq): After support FILE type, We should use the interface provided by FILE to get the fields
191
    // replacing this function
192
    static Status _get_required_string_field(const rapidjson::Value& obj, const char* field_name,
193
19
                                             std::string& value) {
194
19
        auto iter = obj.FindMember(field_name);
195
19
        if (iter == obj.MemberEnd() || !iter->value.IsString()) {
196
3
            return Status::InvalidArgument(
197
3
                    "EMBED file json field '{}' is required and must be a string", field_name);
198
3
        }
199
16
        value = iter->value.GetString();
200
16
        if (value.empty()) {
201
0
            return Status::InvalidArgument("EMBED file json field '{}' can not be empty",
202
0
                                           field_name);
203
0
        }
204
16
        return Status::OK();
205
16
    }
206
207
    static Status init_s3_client_conf_from_json(const rapidjson::Value& file_input,
208
3
                                                S3ClientConf& s3_client_conf) {
209
3
        std::string endpoint;
210
3
        RETURN_IF_ERROR(_get_required_string_field(file_input, "endpoint", endpoint));
211
2
        std::string region;
212
2
        RETURN_IF_ERROR(_get_required_string_field(file_input, "region", region));
213
214
4
        auto get_optional_string_field = [&](const char* field_name, std::string& value) {
215
4
            auto iter = file_input.FindMember(field_name);
216
4
            if (iter == file_input.MemberEnd() || iter->value.IsNull()) {
217
0
                return;
218
0
            }
219
4
            DORIS_CHECK(iter->value.IsString());
220
4
            value = iter->value.GetString();
221
4
        };
222
223
1
        get_optional_string_field("ak", s3_client_conf.ak);
224
1
        get_optional_string_field("sk", s3_client_conf.sk);
225
1
        get_optional_string_field("role_arn", s3_client_conf.role_arn);
226
1
        get_optional_string_field("external_id", s3_client_conf.external_id);
227
1
        s3_client_conf.endpoint = endpoint;
228
1
        s3_client_conf.region = region;
229
230
1
        return Status::OK();
231
2
    }
232
233
    Status _resolve_media_url(const rapidjson::Value& file_input, int64_t ttl_seconds,
234
6
                              std::string& media_url) const {
235
6
        std::string uri;
236
6
        RETURN_IF_ERROR(_get_required_string_field(file_input, "uri", uri));
237
238
        // If it's a direct http/https URL, use it as-is
239
6
        if (_starts_with_ignore_case(uri, "http://") || _starts_with_ignore_case(uri, "https://")) {
240
3
            media_url = uri;
241
3
            return Status::OK();
242
3
        }
243
244
3
        S3ClientConf s3_client_conf;
245
3
        RETURN_IF_ERROR(init_s3_client_conf_from_json(file_input, s3_client_conf));
246
1
        auto s3_client = S3ClientFactory::instance().create(s3_client_conf);
247
1
        if (s3_client == nullptr) {
248
0
            return Status::InternalError("Failed to create S3 client for EMBED file input");
249
0
        }
250
251
1
        S3URI s3_uri(uri);
252
1
        RETURN_IF_ERROR(s3_uri.parse());
253
1
        std::string bucket = s3_uri.get_bucket();
254
1
        std::string key = s3_uri.get_key();
255
1
        DORIS_CHECK(!bucket.empty() && !key.empty());
256
1
        media_url = s3_client->generate_presigned_url({.bucket = bucket, .key = key}, ttl_seconds,
257
1
                                                      s3_client_conf);
258
1
        return Status::OK();
259
1
    }
260
};
261
262
}; // namespace doris