Coverage Report

Created: 2026-04-22 18:57

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/ai_adapter.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <gen_cpp/PaloInternalService_types.h>
21
#include <rapidjson/rapidjson.h>
22
23
#include <algorithm>
24
#include <cctype>
25
#include <memory>
26
#include <string>
27
#include <unordered_map>
28
#include <vector>
29
30
#include "common/status.h"
31
#include "core/string_buffer.hpp"
32
#include "rapidjson/document.h"
33
#include "rapidjson/stringbuffer.h"
34
#include "rapidjson/writer.h"
35
#include "service/http/http_client.h"
36
#include "service/http/http_headers.h"
37
#include "util/security.h"
38
39
namespace doris {
40
41
struct AIResource {
42
19
    AIResource() = default;
43
    AIResource(const TAIResource& tai)
44
12
            : endpoint(tai.endpoint),
45
12
              provider_type(tai.provider_type),
46
12
              model_name(tai.model_name),
47
12
              api_key(tai.api_key),
48
12
              temperature(tai.temperature),
49
12
              max_tokens(tai.max_tokens),
50
12
              max_retries(tai.max_retries),
51
12
              retry_delay_second(tai.retry_delay_second),
52
12
              anthropic_version(tai.anthropic_version),
53
12
              dimensions(tai.dimensions) {}
54
55
    std::string endpoint;
56
    std::string provider_type;
57
    std::string model_name;
58
    std::string api_key;
59
    double temperature;
60
    int64_t max_tokens;
61
    int32_t max_retries;
62
    int32_t retry_delay_second;
63
    std::string anthropic_version;
64
    int32_t dimensions;
65
66
1
    void serialize(BufferWritable& buf) const {
67
1
        buf.write_binary(endpoint);
68
1
        buf.write_binary(provider_type);
69
1
        buf.write_binary(model_name);
70
1
        buf.write_binary(api_key);
71
1
        buf.write_binary(temperature);
72
1
        buf.write_binary(max_tokens);
73
1
        buf.write_binary(max_retries);
74
1
        buf.write_binary(retry_delay_second);
75
1
        buf.write_binary(anthropic_version);
76
1
        buf.write_binary(dimensions);
77
1
    }
78
79
1
    void deserialize(BufferReadable& buf) {
80
1
        buf.read_binary(endpoint);
81
1
        buf.read_binary(provider_type);
82
1
        buf.read_binary(model_name);
83
1
        buf.read_binary(api_key);
84
1
        buf.read_binary(temperature);
85
1
        buf.read_binary(max_tokens);
86
1
        buf.read_binary(max_retries);
87
1
        buf.read_binary(retry_delay_second);
88
1
        buf.read_binary(anthropic_version);
89
1
        buf.read_binary(dimensions);
90
1
    }
91
};
92
93
enum class MultimodalType { IMAGE, VIDEO, AUDIO };
94
95
3
inline const char* multimodal_type_to_string(MultimodalType type) {
96
3
    switch (type) {
97
1
    case MultimodalType::IMAGE:
98
1
        return "image";
99
1
    case MultimodalType::VIDEO:
100
1
        return "video";
101
1
    case MultimodalType::AUDIO:
102
1
        return "audio";
103
3
    }
104
0
    return "unknown";
105
3
}
106
107
class AIAdapter {
108
public:
109
167
    virtual ~AIAdapter() = default;
110
111
    // Set authentication headers for the HTTP client
112
    virtual Status set_authentication(HttpClient* client) const = 0;
113
114
123
    virtual void init(const TAIResource& config) { _config = config; }
115
13
    virtual void init(const AIResource& config) {
116
13
        _config.endpoint = config.endpoint;
117
13
        _config.provider_type = config.provider_type;
118
13
        _config.model_name = config.model_name;
119
13
        _config.api_key = config.api_key;
120
13
        _config.temperature = config.temperature;
121
13
        _config.max_tokens = config.max_tokens;
122
13
        _config.max_retries = config.max_retries;
123
13
        _config.retry_delay_second = config.retry_delay_second;
124
13
        _config.anthropic_version = config.anthropic_version;
125
13
    }
126
127
    // Build request payload based on input text strings
128
    virtual Status build_request_payload(const std::vector<std::string>& inputs,
129
                                         const char* const system_prompt,
130
1
                                         std::string& request_body) const {
131
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
132
1
    }
133
134
    // Parse response from AI service and extract generated text results
135
    virtual Status parse_response(const std::string& response_body,
136
1
                                  std::vector<std::string>& results) const {
137
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
138
1
    }
139
140
    virtual Status build_embedding_request(const std::vector<std::string>& inputs,
141
0
                                           std::string& request_body) const {
142
0
        return embed_not_supported_status();
143
0
    }
144
145
    virtual Status build_multimodal_embedding_request(
146
            const std::vector<MultimodalType>& /*media_types*/,
147
            const std::vector<std::string>& /*media_urls*/,
148
            const std::vector<std::string>& /*media_content_types*/,
149
0
            std::string& /*request_body*/) const {
150
0
        return Status::NotSupported("{} does not support multimodal Embed feature.",
151
0
                                    _config.provider_type);
152
0
    }
153
154
    virtual Status parse_embedding_response(const std::string& response_body,
155
0
                                            std::vector<std::vector<float>>& results) const {
156
0
        return embed_not_supported_status();
157
0
    }
158
159
protected:
160
    TAIResource _config;
161
162
4
    Status embed_not_supported_status() const {
163
4
        return Status::NotSupported(
164
4
                "{} does not support the Embed feature. Currently supported providers are "
165
4
                "OpenAI, Gemini, Voyage, Jina, Qwen, and Minimax.",
166
4
                _config.provider_type);
167
4
    }
168
169
    // Appends one provider-parsed text result to `results`.
170
    // The adapter has already parsed the provider's outer response envelope before calling here.
171
    // Example:
172
    // provider response -> choices[0].message.content = "[\"1\",\"0\",\"1\"]"
173
    // this helper       -> appends "1", "0", "1" into `results`
174
    static Status append_parsed_text_result(std::string_view text,
175
87
                                            std::vector<std::string>& results) {
176
87
        size_t begin = 0;
177
87
        size_t end = text.size();
178
117
        while (begin < end && std::isspace(static_cast<unsigned char>(text[begin]))) {
179
30
            ++begin;
180
30
        }
181
111
        while (begin < end && std::isspace(static_cast<unsigned char>(text[end - 1]))) {
182
24
            --end;
183
24
        }
184
185
87
        if (begin < end && text[begin] == '[' && text[end - 1] == ']') {
186
66
            rapidjson::Document doc;
187
66
            doc.Parse(text.data() + begin, end - begin);
188
66
            if (!doc.HasParseError() && doc.IsArray()) {
189
139
                for (rapidjson::SizeType i = 0; i < doc.Size(); ++i) {
190
76
                    if (!doc[i].IsString()) {
191
1
                        return Status::InternalError(
192
1
                                "Invalid batch result format, array element {} is not a string", i);
193
1
                    }
194
75
                    results.emplace_back(doc[i].GetString(), doc[i].GetStringLength());
195
75
                }
196
63
                return Status::OK();
197
64
            }
198
66
        }
199
200
23
        results.emplace_back(text.data(), text.size());
201
23
        return Status::OK();
202
87
    }
203
204
    // return true if the model support dimension parameter
205
1
    virtual bool supports_dimension_param(const std::string& model_name) const { return false; }
206
207
    // Different providers may have different dimension parameter names.
208
0
    virtual std::string get_dimension_param_name() const { return "dimensions"; }
209
210
    virtual void add_dimension_params(rapidjson::Value& doc,
211
20
                                      rapidjson::Document::AllocatorType& allocator) const {
212
20
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
213
13
            std::string param_name = get_dimension_param_name();
214
13
            rapidjson::Value name(param_name.c_str(), allocator);
215
13
            doc.AddMember(name, _config.dimensions, allocator);
216
13
        }
217
20
    }
218
219
    // Validates common multimodal embedding request invariants shared by providers.
220
    Status validate_multimodal_embedding_inputs(
221
            std::string_view provider_name, const std::vector<MultimodalType>& media_types,
222
            const std::vector<std::string>& media_urls,
223
16
            std::initializer_list<MultimodalType> supported_types) const {
224
16
        if (media_urls.empty()) {
225
1
            return Status::InvalidArgument("{} multimodal embed inputs can not be empty",
226
1
                                           provider_name);
227
1
        }
228
15
        if (media_types.size() != media_urls.size()) {
229
1
            return Status::InvalidArgument(
230
1
                    "{} multimodal embed input size mismatch, media_types={}, media_urls={}",
231
1
                    provider_name, media_types.size(), media_urls.size());
232
1
        }
233
19
        for (MultimodalType media_type : media_types) {
234
19
            bool supported = false;
235
31
            for (MultimodalType supported_type : supported_types) {
236
31
                if (media_type == supported_type) {
237
18
                    supported = true;
238
18
                    break;
239
18
                }
240
31
            }
241
19
            if (!supported) [[unlikely]] {
242
1
                return Status::InvalidArgument(
243
1
                        "{} only supports {} multimodal embed, got {}", provider_name,
244
1
                        supported_multimodal_types_to_string(supported_types),
245
1
                        multimodal_type_to_string(media_type));
246
1
            }
247
19
        }
248
13
        return Status::OK();
249
14
    }
250
251
    static std::string supported_multimodal_types_to_string(
252
1
            std::initializer_list<MultimodalType> supported_types) {
253
1
        std::string result;
254
2
        for (MultimodalType type : supported_types) {
255
2
            if (!result.empty()) {
256
1
                result += "/";
257
1
            }
258
2
            result += multimodal_type_to_string(type);
259
2
        }
260
1
        return result;
261
1
    }
262
};
263
264
// Most LLM-providers' Embedding formats are based on VoyageAI.
265
// The following adapters inherit from VoyageAIAdapter to directly reuse its embedding logic.
266
class VoyageAIAdapter : public AIAdapter {
267
public:
268
2
    Status set_authentication(HttpClient* client) const override {
269
2
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
270
2
        client->set_content_type("application/json");
271
272
2
        return Status::OK();
273
2
    }
274
275
    Status build_embedding_request(const std::vector<std::string>& inputs,
276
8
                                   std::string& request_body) const override {
277
8
        rapidjson::Document doc;
278
8
        doc.SetObject();
279
8
        auto& allocator = doc.GetAllocator();
280
281
        /*{
282
            "model": "xxx",
283
            "input": [
284
              "xxx",
285
              "xxx",
286
              ...
287
            ],
288
            "output_dimensions": 512
289
        }*/
290
8
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
291
8
        add_dimension_params(doc, allocator);
292
293
8
        rapidjson::Value input(rapidjson::kArrayType);
294
8
        for (const auto& msg : inputs) {
295
8
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
296
8
        }
297
8
        doc.AddMember("input", input, allocator);
298
299
8
        rapidjson::StringBuffer buffer;
300
8
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
301
8
        doc.Accept(writer);
302
8
        request_body = buffer.GetString();
303
304
8
        return Status::OK();
305
8
    }
306
307
    Status build_multimodal_embedding_request(
308
            const std::vector<MultimodalType>& media_types,
309
            const std::vector<std::string>& media_urls,
310
            const std::vector<std::string>& /*media_content_types*/,
311
2
            std::string& request_body) const override {
312
2
        RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
313
2
                "VoyageAI", media_types, media_urls,
314
2
                {MultimodalType::IMAGE, MultimodalType::VIDEO}));
315
2
        if (_config.dimensions != -1) {
316
2
            LOG(WARNING) << "VoyageAI multimodal embedding currently ignores dimensions parameter, "
317
2
                         << "model=" << _config.model_name << ", dimensions=" << _config.dimensions;
318
2
        }
319
320
2
        rapidjson::Document doc;
321
2
        doc.SetObject();
322
2
        auto& allocator = doc.GetAllocator();
323
324
        /*{
325
            "inputs": [
326
              {
327
                "content": [
328
                  {"type": "image_url", "image_url": "<url>"}
329
                ]
330
              },
331
              {
332
                "content": [
333
                  {"type": "video_url", "video_url": "<url>"}
334
                ]
335
              }
336
            ],
337
            "model": "voyage-multimodal-3.5"
338
        }*/
339
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
340
341
2
        rapidjson::Value request_inputs(rapidjson::kArrayType);
342
5
        for (size_t i = 0; i < media_urls.size(); ++i) {
343
3
            rapidjson::Value input(rapidjson::kObjectType);
344
3
            rapidjson::Value content(rapidjson::kArrayType);
345
3
            rapidjson::Value media_item(rapidjson::kObjectType);
346
3
            if (media_types[i] == MultimodalType::IMAGE) {
347
1
                media_item.AddMember("type", "image_url", allocator);
348
1
                media_item.AddMember("image_url",
349
1
                                     rapidjson::Value(media_urls[i].c_str(), allocator), allocator);
350
2
            } else {
351
2
                media_item.AddMember("type", "video_url", allocator);
352
2
                media_item.AddMember("video_url",
353
2
                                     rapidjson::Value(media_urls[i].c_str(), allocator), allocator);
354
2
            }
355
3
            content.PushBack(media_item, allocator);
356
3
            input.AddMember("content", content, allocator);
357
3
            request_inputs.PushBack(input, allocator);
358
3
        }
359
360
2
        doc.AddMember("inputs", request_inputs, allocator);
361
362
2
        rapidjson::StringBuffer buffer;
363
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
364
2
        doc.Accept(writer);
365
2
        request_body = buffer.GetString();
366
2
        return Status::OK();
367
2
    }
368
369
    Status parse_embedding_response(const std::string& response_body,
370
5
                                    std::vector<std::vector<float>>& results) const override {
371
5
        rapidjson::Document doc;
372
5
        doc.Parse(response_body.c_str());
373
374
5
        if (doc.HasParseError() || !doc.IsObject()) {
375
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
376
1
                                         response_body);
377
1
        }
378
4
        if (!doc.HasMember("data") || !doc["data"].IsArray()) {
379
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
380
1
                                         response_body);
381
1
        }
382
383
        /*{
384
            "data":[
385
              {
386
                "object": "embedding",
387
                "embedding": [...], <- only need this
388
                "index": 0
389
              },
390
              {
391
                "object": "embedding",
392
                "embedding": [...],
393
                "index": 1
394
              }, ...
395
            ],
396
            "model"....
397
        }*/
398
3
        const auto& data = doc["data"];
399
3
        results.reserve(data.Size());
400
7
        for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
401
5
            if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
402
1
                return Status::InternalError("Invalid {} response format: {}",
403
1
                                             _config.provider_type, response_body);
404
1
            }
405
406
4
            std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
407
4
                           std::back_inserter(results.emplace_back()),
408
10
                           [](const auto& val) { return val.GetFloat(); });
409
4
        }
410
411
2
        return Status::OK();
412
3
    }
413
414
protected:
415
4
    bool supports_dimension_param(const std::string& model_name) const override {
416
4
        static const std::unordered_set<std::string> no_dimension_models = {
417
4
                "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2",
418
4
                "voyage-multimodal-3"};
419
4
        return !no_dimension_models.contains(model_name);
420
4
    }
421
422
1
    std::string get_dimension_param_name() const override { return "output_dimension"; }
423
};
424
425
// Local AI adapter for locally hosted models (Ollama, LLaMA, etc.)
426
class LocalAdapter : public AIAdapter {
427
public:
428
    // Local deployments typically don't need authentication
429
2
    Status set_authentication(HttpClient* client) const override {
430
2
        client->set_content_type("application/json");
431
2
        return Status::OK();
432
2
    }
433
434
    Status build_request_payload(const std::vector<std::string>& inputs,
435
                                 const char* const system_prompt,
436
3
                                 std::string& request_body) const override {
437
3
        rapidjson::Document doc;
438
3
        doc.SetObject();
439
3
        auto& allocator = doc.GetAllocator();
440
441
3
        std::string end_point = _config.endpoint;
442
3
        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
443
2
            RETURN_IF_ERROR(
444
2
                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
445
2
        } else {
446
1
            RETURN_IF_ERROR(
447
1
                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
448
1
        }
449
450
3
        rapidjson::StringBuffer buffer;
451
3
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
452
3
        doc.Accept(writer);
453
3
        request_body = buffer.GetString();
454
455
3
        return Status::OK();
456
3
    }
457
458
    Status parse_response(const std::string& response_body,
459
7
                          std::vector<std::string>& results) const override {
460
7
        rapidjson::Document doc;
461
7
        doc.Parse(response_body.c_str());
462
463
7
        if (doc.HasParseError() || !doc.IsObject()) {
464
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
465
1
                                         response_body);
466
1
        }
467
468
        // Handle various response formats from local LLMs
469
        // Format 1: OpenAI-compatible format with choices/message/content
470
6
        if (doc.HasMember("choices") && doc["choices"].IsArray()) {
471
1
            const auto& choices = doc["choices"];
472
1
            results.reserve(choices.Size());
473
474
2
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
475
1
                if (choices[i].HasMember("message") && choices[i]["message"].HasMember("content") &&
476
1
                    choices[i]["message"]["content"].IsString()) {
477
1
                    RETURN_IF_ERROR(append_parsed_text_result(
478
1
                            choices[i]["message"]["content"].GetString(), results));
479
1
                } else if (choices[i].HasMember("text") && choices[i]["text"].IsString()) {
480
                    // Some local LLMs use a simpler format
481
0
                    RETURN_IF_ERROR(
482
0
                            append_parsed_text_result(choices[i]["text"].GetString(), results));
483
0
                }
484
1
            }
485
5
        } else if (doc.HasMember("text") && doc["text"].IsString()) {
486
            // Format 2: Simple response with just "text" or "content" field
487
1
            RETURN_IF_ERROR(append_parsed_text_result(doc["text"].GetString(), results));
488
4
        } else if (doc.HasMember("content") && doc["content"].IsString()) {
489
1
            RETURN_IF_ERROR(append_parsed_text_result(doc["content"].GetString(), results));
490
3
        } else if (doc.HasMember("response") && doc["response"].IsString()) {
491
            // Format 3: Response field (Ollama `generate` format)
492
1
            RETURN_IF_ERROR(append_parsed_text_result(doc["response"].GetString(), results));
493
2
        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
494
2
                   doc["message"].HasMember("content") && doc["message"]["content"].IsString()) {
495
            // Format 4: message/content field (Ollama `chat` format)
496
1
            RETURN_IF_ERROR(
497
1
                    append_parsed_text_result(doc["message"]["content"].GetString(), results));
498
1
        } else {
499
1
            return Status::NotSupported("Unsupported response format from local AI.");
500
1
        }
501
5
        return Status::OK();
502
6
    }
503
504
    Status build_embedding_request(const std::vector<std::string>& inputs,
505
1
                                   std::string& request_body) const override {
506
1
        rapidjson::Document doc;
507
1
        doc.SetObject();
508
1
        auto& allocator = doc.GetAllocator();
509
510
1
        if (!_config.model_name.empty()) {
511
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
512
1
                          allocator);
513
1
        }
514
515
1
        add_dimension_params(doc, allocator);
516
517
1
        rapidjson::Value input(rapidjson::kArrayType);
518
1
        for (const auto& msg : inputs) {
519
1
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
520
1
        }
521
1
        doc.AddMember("input", input, allocator);
522
523
1
        rapidjson::StringBuffer buffer;
524
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
525
1
        doc.Accept(writer);
526
1
        request_body = buffer.GetString();
527
528
1
        return Status::OK();
529
1
    }
530
531
    Status build_multimodal_embedding_request(
532
            const std::vector<MultimodalType>& /*media_types*/,
533
            const std::vector<std::string>& /*media_urls*/,
534
            const std::vector<std::string>& /*media_content_types*/,
535
0
            std::string& /*request_body*/) const override {
536
0
        return Status::NotSupported("{} does not support multimodal Embed feature.",
537
0
                                    _config.provider_type);
538
0
    }
539
540
    Status parse_embedding_response(const std::string& response_body,
541
3
                                    std::vector<std::vector<float>>& results) const override {
542
3
        rapidjson::Document doc;
543
3
        doc.Parse(response_body.c_str());
544
545
3
        if (doc.HasParseError() || !doc.IsObject()) {
546
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
547
0
                                         response_body);
548
0
        }
549
550
        // parse different response format
551
3
        rapidjson::Value embedding;
552
3
        if (doc.HasMember("data") && doc["data"].IsArray()) {
553
            // "data":["object":"embedding", "embedding":[0.1, 0.2...], "index":0]
554
1
            const auto& data = doc["data"];
555
1
            results.reserve(data.Size());
556
3
            for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
557
2
                if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
558
0
                    return Status::InternalError("Invalid {} response format",
559
0
                                                 _config.provider_type);
560
0
                }
561
562
2
                std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
563
2
                               std::back_inserter(results.emplace_back()),
564
5
                               [](const auto& val) { return val.GetFloat(); });
565
2
            }
566
2
        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
567
            // "embeddings":[[0.1, 0.2, ...]]
568
1
            results.reserve(1);
569
2
            for (int i = 0; i < doc["embeddings"].Size(); i++) {
570
1
                embedding = doc["embeddings"][i];
571
1
                std::transform(embedding.Begin(), embedding.End(),
572
1
                               std::back_inserter(results.emplace_back()),
573
2
                               [](const auto& val) { return val.GetFloat(); });
574
1
            }
575
1
        } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
576
            // "embedding":[0.1, 0.2, ...]
577
1
            results.reserve(1);
578
1
            embedding = doc["embedding"];
579
1
            std::transform(embedding.Begin(), embedding.End(),
580
1
                           std::back_inserter(results.emplace_back()),
581
3
                           [](const auto& val) { return val.GetFloat(); });
582
1
        } else {
583
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
584
0
                                         response_body);
585
0
        }
586
587
3
        return Status::OK();
588
3
    }
589
590
private:
591
    Status build_ollama_request(rapidjson::Document& doc,
592
                                rapidjson::Document::AllocatorType& allocator,
593
                                const std::vector<std::string>& inputs,
594
2
                                const char* const system_prompt, std::string& request_body) const {
595
        /*
596
        for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
597
        {
598
            "model": <model_name>,
599
            "stream": false,
600
            "think": false,
601
            "options": {
602
                "temperature": <temperature>,
603
                "max_token": <max_token>
604
            },
605
            "messages": [
606
                {"role": "system", "content": <system_prompt>},
607
                {"role": "user", "content": <user_prompt>}
608
            ]
609
        }
610
        
611
        for endpoints end_with `/generate` like 'http://localhost:11434/api/generate':
612
        {
613
            "model": <model_name>,
614
            "stream": false,
615
            "think": false
616
            "options": {
617
                "temperature": <temperature>,
618
                "max_token": <max_token>
619
            },
620
            "system": <system_prompt>,
621
            "prompt": <user_prompt>
622
        }
623
        */
624
625
        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint;
626
        // The rest remains identical.
627
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
628
2
        doc.AddMember("stream", false, allocator);
629
2
        doc.AddMember("think", false, allocator);
630
631
        // option section
632
2
        rapidjson::Value options(rapidjson::kObjectType);
633
2
        if (_config.temperature != -1) {
634
2
            options.AddMember("temperature", _config.temperature, allocator);
635
2
        }
636
2
        if (_config.max_tokens != -1) {
637
2
            options.AddMember("max_token", _config.max_tokens, allocator);
638
2
        }
639
2
        doc.AddMember("options", options, allocator);
640
641
        // prompt section
642
2
        if (_config.endpoint.ends_with("chat")) {
643
1
            rapidjson::Value messages(rapidjson::kArrayType);
644
1
            if (system_prompt && *system_prompt) {
645
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
646
1
                sys_msg.AddMember("role", "system", allocator);
647
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
648
1
                messages.PushBack(sys_msg, allocator);
649
1
            }
650
1
            for (const auto& input : inputs) {
651
1
                rapidjson::Value message(rapidjson::kObjectType);
652
1
                message.AddMember("role", "user", allocator);
653
1
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
654
1
                messages.PushBack(message, allocator);
655
1
            }
656
1
            doc.AddMember("messages", messages, allocator);
657
1
        } else {
658
1
            if (system_prompt && *system_prompt) {
659
1
                doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
660
1
            }
661
1
            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator);
662
1
        }
663
664
2
        return Status::OK();
665
2
    }
666
667
    Status build_default_request(rapidjson::Document& doc,
668
                                 rapidjson::Document::AllocatorType& allocator,
669
                                 const std::vector<std::string>& inputs,
670
1
                                 const char* const system_prompt, std::string& request_body) const {
671
        /*
672
        Default format(OpenAI-compatible):
673
        {
674
            "model": <model_name>,
675
            "temperature": <temperature>,
676
            "max_tokens": <max_tokens>,
677
            "messages": [
678
                {"role": "system", "content": <system_prompt>},
679
                {"role": "user", "content": <user_prompt>}
680
            ]
681
        }
682
        */
683
684
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
685
686
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
687
1
        if (_config.temperature != -1) {
688
1
            doc.AddMember("temperature", _config.temperature, allocator);
689
1
        }
690
1
        if (_config.max_tokens != -1) {
691
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
692
1
        }
693
694
1
        rapidjson::Value messages(rapidjson::kArrayType);
695
1
        if (system_prompt && *system_prompt) {
696
1
            rapidjson::Value sys_msg(rapidjson::kObjectType);
697
1
            sys_msg.AddMember("role", "system", allocator);
698
1
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
699
1
            messages.PushBack(sys_msg, allocator);
700
1
        }
701
1
        for (const auto& input : inputs) {
702
1
            rapidjson::Value message(rapidjson::kObjectType);
703
1
            message.AddMember("role", "user", allocator);
704
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
705
1
            messages.PushBack(message, allocator);
706
1
        }
707
1
        doc.AddMember("messages", messages, allocator);
708
1
        return Status::OK();
709
1
    }
710
};
711
712
// The OpenAI API format can be reused with some compatible AIs.
713
class OpenAIAdapter : public VoyageAIAdapter {
714
public:
715
13
    Status set_authentication(HttpClient* client) const override {
716
13
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
717
13
        client->set_content_type("application/json");
718
719
13
        return Status::OK();
720
13
    }
721
722
    Status build_request_payload(const std::vector<std::string>& inputs,
723
                                 const char* const system_prompt,
724
4
                                 std::string& request_body) const override {
725
4
        rapidjson::Document doc;
726
4
        doc.SetObject();
727
4
        auto& allocator = doc.GetAllocator();
728
729
4
        if (_config.endpoint.ends_with("responses")) {
730
            /*{
731
              "model": "gpt-4.1-mini",
732
              "input": [
733
                {"role": "system", "content": "system_prompt here"},
734
                {"role": "user", "content": "xxx"}
735
              ],
736
              "temperature": 0.7,
737
              "max_output_tokens": 150
738
            }*/
739
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
740
1
                          allocator);
741
742
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
743
1
            if (_config.temperature != -1) {
744
1
                doc.AddMember("temperature", _config.temperature, allocator);
745
1
            }
746
1
            if (_config.max_tokens != -1) {
747
1
                doc.AddMember("max_output_tokens", _config.max_tokens, allocator);
748
1
            }
749
750
            // input
751
1
            rapidjson::Value input(rapidjson::kArrayType);
752
1
            if (system_prompt && *system_prompt) {
753
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
754
1
                sys_msg.AddMember("role", "system", allocator);
755
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
756
1
                input.PushBack(sys_msg, allocator);
757
1
            }
758
1
            for (const auto& msg : inputs) {
759
1
                rapidjson::Value message(rapidjson::kObjectType);
760
1
                message.AddMember("role", "user", allocator);
761
1
                message.AddMember("content", rapidjson::Value(msg.c_str(), allocator), allocator);
762
1
                input.PushBack(message, allocator);
763
1
            }
764
1
            doc.AddMember("input", input, allocator);
765
3
        } else {
766
            /*{
767
              "model": "gpt-4",
768
              "messages": [
769
                {"role": "system", "content": "system_prompt here"},
770
                {"role": "user", "content": "xxx"}
771
              ],
772
              "temperature": x,
773
              "max_tokens": x,
774
            }*/
775
3
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
776
3
                          allocator);
777
778
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
779
3
            if (_config.temperature != -1) {
780
3
                doc.AddMember("temperature", _config.temperature, allocator);
781
3
            }
782
3
            if (_config.max_tokens != -1) {
783
3
                doc.AddMember("max_tokens", _config.max_tokens, allocator);
784
3
            }
785
786
3
            rapidjson::Value messages(rapidjson::kArrayType);
787
3
            if (system_prompt && *system_prompt) {
788
3
                rapidjson::Value sys_msg(rapidjson::kObjectType);
789
3
                sys_msg.AddMember("role", "system", allocator);
790
3
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
791
3
                messages.PushBack(sys_msg, allocator);
792
3
            }
793
3
            for (const auto& input : inputs) {
794
3
                rapidjson::Value message(rapidjson::kObjectType);
795
3
                message.AddMember("role", "user", allocator);
796
3
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
797
3
                messages.PushBack(message, allocator);
798
3
            }
799
3
            doc.AddMember("messages", messages, allocator);
800
3
        }
801
802
4
        rapidjson::StringBuffer buffer;
803
4
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
804
4
        doc.Accept(writer);
805
4
        request_body = buffer.GetString();
806
807
4
        return Status::OK();
808
4
    }
809
810
    Status parse_response(const std::string& response_body,
811
10
                          std::vector<std::string>& results) const override {
812
10
        rapidjson::Document doc;
813
10
        doc.Parse(response_body.c_str());
814
815
10
        if (doc.HasParseError() || !doc.IsObject()) {
816
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
817
1
                                         response_body);
818
1
        }
819
820
9
        if (doc.HasMember("output") && doc["output"].IsArray()) {
821
            /// for responses endpoint
822
            /*{
823
              "output": [
824
                {
825
                  "id": "msg_123",
826
                  "type": "message",
827
                  "role": "assistant",
828
                  "content": [
829
                    {
830
                      "type": "text",
831
                      "text": "result text here"   <- result
832
                    }
833
                  ]
834
                }
835
              ]
836
            }*/
837
1
            const auto& output = doc["output"];
838
1
            results.reserve(output.Size());
839
840
2
            for (rapidjson::SizeType i = 0; i < output.Size(); i++) {
841
1
                if (!output[i].HasMember("content") || !output[i]["content"].IsArray() ||
842
1
                    output[i]["content"].Empty() || !output[i]["content"][0].HasMember("text") ||
843
1
                    !output[i]["content"][0]["text"].IsString()) {
844
0
                    return Status::InternalError("Invalid output format in {} response: {}",
845
0
                                                 _config.provider_type, response_body);
846
0
                }
847
848
1
                RETURN_IF_ERROR(append_parsed_text_result(
849
1
                        output[i]["content"][0]["text"].GetString(), results));
850
1
            }
851
8
        } else if (doc.HasMember("choices") && doc["choices"].IsArray()) {
852
            /// for completions endpoint
853
            /*{
854
              "object": "chat.completion",
855
              "model": "gpt-4",
856
              "choices": [
857
                {
858
                  ...
859
                  "message": {
860
                    "role": "assistant",
861
                    "content": "xxx"      <- result
862
                  },
863
                  ...
864
                }
865
              ],
866
              ...
867
            }*/
868
7
            const auto& choices = doc["choices"];
869
7
            results.reserve(choices.Size());
870
871
12
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
872
7
                if (!choices[i].HasMember("message") ||
873
7
                    !choices[i]["message"].HasMember("content") ||
874
7
                    !choices[i]["message"]["content"].IsString()) {
875
2
                    return Status::InternalError("Invalid choice format in {} response: {}",
876
2
                                                 _config.provider_type, response_body);
877
2
                }
878
879
5
                RETURN_IF_ERROR(append_parsed_text_result(
880
5
                        choices[i]["message"]["content"].GetString(), results));
881
5
            }
882
7
        } else {
883
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
884
1
                                         response_body);
885
1
        }
886
887
6
        return Status::OK();
888
9
    }
889
890
    Status build_multimodal_embedding_request(
891
            const std::vector<MultimodalType>& /*media_types*/,
892
            const std::vector<std::string>& /*media_urls*/,
893
            const std::vector<std::string>& /*media_content_types*/,
894
1
            std::string& /*request_body*/) const override {
895
1
        return Status::NotSupported("{} does not support multimodal Embed feature.",
896
1
                                    _config.provider_type);
897
1
    }
898
899
protected:
900
2
    bool supports_dimension_param(const std::string& model_name) const override {
901
2
        return !(model_name == "text-embedding-ada-002");
902
2
    }
903
904
2
    std::string get_dimension_param_name() const override { return "dimensions"; }
905
};
906
907
// Chat-only adapter: inherits the OpenAI-compatible chat behavior from
// OpenAIAdapter, while both embedding hooks report "not supported".
class DeepSeekAdapter : public OpenAIAdapter {
public:
    // Embedding requests are rejected with the shared not-supported status.
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        return embed_not_supported_status();
    }

    // Counterpart of build_embedding_request: parsing is equally unsupported.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        return embed_not_supported_status();
    }
};
919
920
// Chat-only adapter: reuses the OpenAI-compatible chat behavior from
// OpenAIAdapter; embedding requests and responses are not supported.
class MoonShotAdapter : public OpenAIAdapter {
public:
    // Embedding requests are rejected with the shared not-supported status.
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        return embed_not_supported_status();
    }

    // Counterpart of build_embedding_request: parsing is equally unsupported.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        return embed_not_supported_status();
    }
};
932
933
class MinimaxAdapter : public OpenAIAdapter {
934
public:
935
    Status build_embedding_request(const std::vector<std::string>& inputs,
936
1
                                   std::string& request_body) const override {
937
1
        rapidjson::Document doc;
938
1
        doc.SetObject();
939
1
        auto& allocator = doc.GetAllocator();
940
941
        /*{
942
          "text": ["xxx", "xxx", ...],
943
          "model": "embo-1",
944
          "type": "db"
945
        }*/
946
1
        rapidjson::Value texts(rapidjson::kArrayType);
947
1
        for (const auto& input : inputs) {
948
1
            texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator);
949
1
        }
950
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
951
1
        doc.AddMember("texts", texts, allocator);
952
1
        doc.AddMember("type", rapidjson::Value("db", allocator), allocator);
953
954
1
        rapidjson::StringBuffer buffer;
955
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
956
1
        doc.Accept(writer);
957
1
        request_body = buffer.GetString();
958
959
1
        return Status::OK();
960
1
    }
961
};
962
963
class ZhipuAdapter : public OpenAIAdapter {
964
protected:
965
2
    bool supports_dimension_param(const std::string& model_name) const override {
966
2
        return !(model_name == "embedding-2");
967
2
    }
968
};
969
970
// Qwen (DashScope) adapter. Chat behavior comes from OpenAIAdapter; this
// class adds the Qwen-native multimodal embedding request format and a
// response parser that accepts both the native "output.embeddings" layout
// and the OpenAI-style "data" layout (via fallback to the base parser).
class QwenAdapter : public OpenAIAdapter {
public:
    // Builds the Qwen multimodal embedding request. Only IMAGE and VIDEO
    // media are accepted (validated up front); each URL becomes one
    // {"image"|"video": "<url>"} entry under input.contents.
    Status build_multimodal_embedding_request(
            const std::vector<MultimodalType>& media_types,
            const std::vector<std::string>& media_urls,
            const std::vector<std::string>& /*media_content_types*/,
            std::string& request_body) const override {
        RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
                "QWEN", media_types, media_urls, {MultimodalType::IMAGE, MultimodalType::VIDEO}));

        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /*{
            "model": "tongyi-embedding-vision-plus",
            "input": {
              "contents": [
                {"image": "<url>"},
                {"video": "<url>"}
              ]
            },
            "parameters": {
              "dimension": 512
            }
        }*/
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        rapidjson::Value input(rapidjson::kObjectType);
        rapidjson::Value contents(rapidjson::kArrayType);

        for (size_t i = 0; i < media_urls.size(); ++i) {
            rapidjson::Value media_item(rapidjson::kObjectType);
            // Validation above guarantees each type is IMAGE or VIDEO.
            if (media_types[i] == MultimodalType::IMAGE) {
                media_item.AddMember("image", rapidjson::Value(media_urls[i].c_str(), allocator),
                                     allocator);
            } else {
                media_item.AddMember("video", rapidjson::Value(media_urls[i].c_str(), allocator),
                                     allocator);
            }
            contents.PushBack(media_item, allocator);
        }

        input.AddMember("contents", contents, allocator);
        doc.AddMember("input", input, allocator);
        // Optional output dimension, nested under "parameters" (Qwen-specific
        // placement; the key name comes from get_dimension_param_name()).
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
            rapidjson::Value parameters(rapidjson::kObjectType);
            std::string param_name = get_dimension_param_name();
            rapidjson::Value dimension_name(param_name.c_str(), allocator);
            parameters.AddMember(dimension_name, _config.dimensions, allocator);
            doc.AddMember("parameters", parameters, allocator);
        }

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();
        return Status::OK();
    }

    // Parses either the Qwen-native or OpenAI-style embedding response.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) [[unlikely]] {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        // Qwen multimodal embedding usually returns:
        // {
        //   "output": {
        //     "embeddings": [
        //       {"index":0, "embedding":[...], "type":"image|video|text"},
        //       ...
        //     ]
        //   }
        // }
        //
        // In text-only or compatibility endpoints, Qwen may also return OpenAI-style
        // "data":[{"embedding":[...]}]. For compatibility we first parse native
        // output.embeddings and then fallback to OpenAIAdapter parser.
        if (doc.HasMember("output") && doc["output"].IsObject() &&
            doc["output"].HasMember("embeddings") && doc["output"]["embeddings"].IsArray()) {
            const auto& embeddings = doc["output"]["embeddings"];
            results.reserve(embeddings.Size());
            for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) {
                if (!embeddings[i].HasMember("embedding") ||
                    !embeddings[i]["embedding"].IsArray()) {
                    return Status::InternalError("Invalid {} response format: {}",
                                                 _config.provider_type, response_body);
                }
                // Copy each numeric array into a new results row.
                std::transform(embeddings[i]["embedding"].Begin(), embeddings[i]["embedding"].End(),
                               std::back_inserter(results.emplace_back()),
                               [](const auto& val) { return val.GetFloat(); });
            }
            return Status::OK();
        }
        // Not the native layout: delegate to the OpenAI-style "data" parser.
        return OpenAIAdapter::parse_embedding_response(response_body, results);
    }

protected:
    // Older/compat models that reject an explicit output dimension.
    bool supports_dimension_param(const std::string& model_name) const override {
        static const std::unordered_set<std::string> no_dimension_models = {
                "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"};
        return !no_dimension_models.contains(model_name);
    }

    // Qwen uses singular "dimension" (inside "parameters"), unlike OpenAI's
    // top-level "dimensions".
    std::string get_dimension_param_name() const override { return "dimension"; }
};
1079
1080
class JinaAdapter : public VoyageAIAdapter {
1081
public:
1082
    Status build_multimodal_embedding_request(
1083
            const std::vector<MultimodalType>& media_types,
1084
            const std::vector<std::string>& media_urls,
1085
            const std::vector<std::string>& /*media_content_types*/,
1086
2
            std::string& request_body) const override {
1087
2
        RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
1088
2
                "JINA", media_types, media_urls, {MultimodalType::IMAGE, MultimodalType::VIDEO}));
1089
1090
2
        rapidjson::Document doc;
1091
2
        doc.SetObject();
1092
2
        auto& allocator = doc.GetAllocator();
1093
1094
        /*{
1095
            "model": "jina-embeddings-v4",
1096
            "task": "text-matching",
1097
            "input": [
1098
              {"image": "<url>"},
1099
              {"video": "<url>"}
1100
            ]
1101
        }*/
1102
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
1103
2
        doc.AddMember("task", "text-matching", allocator);
1104
1105
2
        rapidjson::Value input(rapidjson::kArrayType);
1106
5
        for (size_t i = 0; i < media_urls.size(); ++i) {
1107
3
            rapidjson::Value media_item(rapidjson::kObjectType);
1108
3
            if (media_types[i] == MultimodalType::IMAGE) {
1109
2
                media_item.AddMember("image", rapidjson::Value(media_urls[i].c_str(), allocator),
1110
2
                                     allocator);
1111
2
            } else {
1112
1
                media_item.AddMember("video", rapidjson::Value(media_urls[i].c_str(), allocator),
1113
1
                                     allocator);
1114
1
            }
1115
3
            input.PushBack(media_item, allocator);
1116
3
        }
1117
2
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
1118
2
            doc.AddMember("dimensions", _config.dimensions, allocator);
1119
2
        }
1120
2
        doc.AddMember("input", input, allocator);
1121
1122
2
        rapidjson::StringBuffer buffer;
1123
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1124
2
        doc.Accept(writer);
1125
2
        request_body = buffer.GetString();
1126
2
        return Status::OK();
1127
2
    }
1128
};
1129
1130
// Baichuan uses the OpenAI-compatible format but never sends an explicit
// embedding dimension parameter.
class BaichuanAdapter : public OpenAIAdapter {
protected:
    // Unconditionally disable the dimension request parameter.
    bool supports_dimension_param(const std::string& model_name) const override { return false; }
};
1134
1135
// Gemini's embedding format is different from VoyageAI, so it requires a separate adapter
1136
// Gemini's embedding format is different from VoyageAI, so it requires a separate adapter
class GeminiAdapter : public AIAdapter {
public:
    // Gemini authenticates via the "x-goog-api-key" header (no Bearer token).
    Status set_authentication(HttpClient* client) const override {
        client->set_header("x-goog-api-key", _config.api_key);
        client->set_content_type("application/json");
        return Status::OK();
    }

    // Serializes a generateContent-style chat request into 'request_body'.
    Status build_request_payload(const std::vector<std::string>& inputs,
                                 const char* const system_prompt,
                                 std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /*{
          "systemInstruction": {
            "parts": [
              {
                "text": "system_prompt here"
              }
            ]
          },
          "contents": [
            {
              "parts": [
                {
                  "text": "xxx"
                }
              ]
            }
          ],
          "generationConfig": {
            "temperature": 0.7,
            "maxOutputTokens": 1024
          }
        }*/
        if (system_prompt && *system_prompt) {
            rapidjson::Value system_instruction(rapidjson::kObjectType);
            rapidjson::Value parts(rapidjson::kArrayType);

            rapidjson::Value part(rapidjson::kObjectType);
            part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator);
            parts.PushBack(part, allocator);
            system_instruction.AddMember("parts", parts, allocator);
            doc.AddMember("systemInstruction", system_instruction, allocator);
        }

        // One "contents" entry per input, each wrapping its text in a parts array.
        rapidjson::Value contents(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            rapidjson::Value content(rapidjson::kObjectType);
            rapidjson::Value parts(rapidjson::kArrayType);

            rapidjson::Value part(rapidjson::kObjectType);
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);

            parts.PushBack(part, allocator);
            content.AddMember("parts", parts, allocator);
            contents.PushBack(content, allocator);
        }
        doc.AddMember("contents", contents, allocator);

        // If 'temperature' and 'max_tokens' are set, add them to the request body.
        // (An empty generationConfig object is still emitted when both are unset.)
        rapidjson::Value generationConfig(rapidjson::kObjectType);
        if (_config.temperature != -1) {
            generationConfig.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator);
        }
        doc.AddMember("generationConfig", generationConfig, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    // Extracts the generated text from each candidate's first content part.
    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        /*{
          "candidates":[
            {
              "content": {
                "parts": [
                  {
                    "text": "xxx"
                  }
                ]
              }
            }
          ]
        }*/
        const auto& candidates = doc["candidates"];
        results.reserve(candidates.Size());

        for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) {
            // Only parts[0].text is consumed from each candidate.
            if (!candidates[i].HasMember("content") ||
                !candidates[i]["content"].HasMember("parts") ||
                !candidates[i]["content"]["parts"].IsArray() ||
                candidates[i]["content"]["parts"].Empty() ||
                !candidates[i]["content"]["parts"][0].HasMember("text") ||
                !candidates[i]["content"]["parts"][0]["text"].IsString()) {
                return Status::InternalError("Invalid candidate format in {} response",
                                             _config.provider_type);
            }

            RETURN_IF_ERROR(append_parsed_text_result(
                    candidates[i]["content"]["parts"][0]["text"].GetString(), results));
        }
        return Status::OK();
    }

    // Builds a batch text-embedding request: one entry per input under
    // "requests", each repeating the model name and optional dimension.
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /*{
          "requests": [
            {
              "model": "models/gemini-embedding-001",
              "content": {
                "parts": [
                  {
                    "text": "xxx"
                  }
                ]
              },
              "outputDimensionality": 1024
            },
            {
              "model": "models/gemini-embedding-001",
              "content": {
                "parts": [
                  {
                    "text": "yyy"
                  }
                ]
              },
              "outputDimensionality": 1024
            }
          ]
        }*/

        // gemini requires the model format as `models/{model}`
        std::string model_name = _config.model_name;
        if (!model_name.starts_with("models/")) {
            model_name = "models/" + model_name;
        }

        rapidjson::Value requests(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            rapidjson::Value request(rapidjson::kObjectType);
            request.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
            // Base-class helper; adds the dimension member when configured.
            add_dimension_params(request, allocator);

            rapidjson::Value content(rapidjson::kObjectType);
            rapidjson::Value parts(rapidjson::kArrayType);
            rapidjson::Value part(rapidjson::kObjectType);
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
            parts.PushBack(part, allocator);
            content.AddMember("parts", parts, allocator);
            request.AddMember("content", content, allocator);
            requests.PushBack(request, allocator);
        }
        doc.AddMember("requests", requests, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    // Builds a batch multimodal embedding request: each media item becomes a
    // file_data part carrying its original MIME type and URI. Accepts IMAGE,
    // AUDIO and VIDEO media, and requires one content type per URL.
    Status build_multimodal_embedding_request(const std::vector<MultimodalType>& media_types,
                                              const std::vector<std::string>& media_urls,
                                              const std::vector<std::string>& media_content_types,
                                              std::string& request_body) const override {
        RETURN_IF_ERROR(validate_multimodal_embedding_inputs(
                "Gemini", media_types, media_urls,
                {MultimodalType::IMAGE, MultimodalType::AUDIO, MultimodalType::VIDEO}));
        if (media_content_types.size() != media_urls.size()) {
            return Status::InvalidArgument(
                    "Gemini multimodal embed input size mismatch, media_content_types={}, "
                    "media_urls={}",
                    media_content_types.size(), media_urls.size());
        }

        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /*{
          "requests": [
            {
              "model": "models/gemini-embedding-2-preview",
              "content": {
                "parts": [
                  {"file_data": {"mime_type": "<original content_type>", "file_uri": "<url>"}}
                ]
              },
              "outputDimensionality": 768
            },
            {
              "model": "models/gemini-embedding-2-preview",
              "content": {
                "parts": [
                  {"file_data": {"mime_type": "<original content_type>", "file_uri": "<url>"}}
                ]
              },
              "outputDimensionality": 768
            }
          ]
        }*/
        // gemini requires the model format as `models/{model}`
        std::string model_name = _config.model_name;
        if (!model_name.starts_with("models/")) {
            model_name = "models/" + model_name;
        }

        rapidjson::Value requests(rapidjson::kArrayType);
        for (size_t i = 0; i < media_urls.size(); ++i) {
            rapidjson::Value request(rapidjson::kObjectType);
            request.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
            // Base-class helper; adds the dimension member when configured.
            add_dimension_params(request, allocator);

            rapidjson::Value content(rapidjson::kObjectType);
            rapidjson::Value parts(rapidjson::kArrayType);
            rapidjson::Value part(rapidjson::kObjectType);
            rapidjson::Value file_data(rapidjson::kObjectType);
            file_data.AddMember("mime_type",
                                rapidjson::Value(media_content_types[i].c_str(), allocator),
                                allocator);
            file_data.AddMember("file_uri", rapidjson::Value(media_urls[i].c_str(), allocator),
                                allocator);
            part.AddMember("file_data", file_data, allocator);
            parts.PushBack(part, allocator);
            content.AddMember("parts", parts, allocator);
            request.AddMember("content", content, allocator);
            requests.PushBack(request, allocator);
        }
        doc.AddMember("requests", requests, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();
        return Status::OK();
    }

    // Parses either the batch ("embeddings" array) or single ("embedding"
    // object) response shape into one float vector per embedding.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
            /*{
              "embeddings": [
                {"values": [0.1, 0.2, 0.3]},
                {"values": [0.4, 0.5, 0.6]}
              ]
            }*/
            const auto& embeddings = doc["embeddings"];
            results.reserve(embeddings.Size());
            for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) {
                if (!embeddings[i].HasMember("values") || !embeddings[i]["values"].IsArray()) {
                    return Status::InternalError("Invalid {} response format: {}",
                                                 _config.provider_type, response_body);
                }
                // Copy each numeric array into a new results row.
                std::transform(embeddings[i]["values"].Begin(), embeddings[i]["values"].End(),
                               std::back_inserter(results.emplace_back()),
                               [](const auto& val) { return val.GetFloat(); });
            }
            return Status::OK();
        }
        if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        /*{
          "embedding":{
            "values": [0.1, 0.2, 0.3]
          }
        }*/
        const auto& embedding = doc["embedding"];
        if (!embedding.HasMember("values") || !embedding["values"].IsArray()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }
        std::transform(embedding["values"].Begin(), embedding["values"].End(),
                       std::back_inserter(results.emplace_back()),
                       [](const auto& val) { return val.GetFloat(); });

        return Status::OK();
    }

protected:
    // "embedding-001" (with or without the "models/" prefix) rejects an
    // explicit output dimension.
    bool supports_dimension_param(const std::string& model_name) const override {
        static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001",
                                                                            "embedding-001"};
        return !no_dimension_models.contains(model_name);
    }

    std::string get_dimension_param_name() const override { return "outputDimensionality"; }
};
1465
1466
class AnthropicAdapter : public VoyageAIAdapter {
1467
public:
1468
1
    Status set_authentication(HttpClient* client) const override {
1469
1
        client->set_header("x-api-key", _config.api_key);
1470
1
        client->set_header("anthropic-version", _config.anthropic_version);
1471
1
        client->set_content_type("application/json");
1472
1473
1
        return Status::OK();
1474
1
    }
1475
1476
    Status build_request_payload(const std::vector<std::string>& inputs,
1477
                                 const char* const system_prompt,
1478
1
                                 std::string& request_body) const override {
1479
1
        rapidjson::Document doc;
1480
1
        doc.SetObject();
1481
1
        auto& allocator = doc.GetAllocator();
1482
1483
        /*
1484
            "model": "claude-opus-4-1-20250805",
1485
            "max_tokens": 1024,
1486
            "system": "system_prompt here",
1487
            "messages": [
1488
              {"role": "user", "content": "xxx"}
1489
            ],
1490
            "temperature": 0.7
1491
        */
1492
1493
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
1494
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
1495
1
        if (_config.temperature != -1) {
1496
1
            doc.AddMember("temperature", _config.temperature, allocator);
1497
1
        }
1498
1
        if (_config.max_tokens != -1) {
1499
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
1500
1
        } else {
1501
            // Keep the default value, Anthropic requires this parameter
1502
0
            doc.AddMember("max_tokens", 2048, allocator);
1503
0
        }
1504
1
        if (system_prompt && *system_prompt) {
1505
1
            doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
1506
1
        }
1507
1508
1
        rapidjson::Value messages(rapidjson::kArrayType);
1509
1
        for (const auto& input : inputs) {
1510
1
            rapidjson::Value message(rapidjson::kObjectType);
1511
1
            message.AddMember("role", "user", allocator);
1512
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
1513
1
            messages.PushBack(message, allocator);
1514
1
        }
1515
1
        doc.AddMember("messages", messages, allocator);
1516
1517
1
        rapidjson::StringBuffer buffer;
1518
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1519
1
        doc.Accept(writer);
1520
1
        request_body = buffer.GetString();
1521
1522
1
        return Status::OK();
1523
1
    }
1524
1525
    Status parse_response(const std::string& response_body,
1526
3
                          std::vector<std::string>& results) const override {
1527
3
        rapidjson::Document doc;
1528
3
        doc.Parse(response_body.c_str());
1529
3
        if (doc.HasParseError() || !doc.IsObject()) {
1530
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1531
1
                                         response_body);
1532
1
        }
1533
2
        if (!doc.HasMember("content") || !doc["content"].IsArray()) {
1534
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1535
1
                                         response_body);
1536
1
        }
1537
1538
        /*{
1539
            "content": [
1540
              {
1541
                "text": "xxx",
1542
                "type": "text"
1543
              }
1544
            ]
1545
        }*/
1546
1
        const auto& content = doc["content"];
1547
1
        results.reserve(1);
1548
1549
1
        std::string result;
1550
2
        for (rapidjson::SizeType i = 0; i < content.Size(); i++) {
1551
1
            if (!content[i].HasMember("type") || !content[i]["type"].IsString() ||
1552
1
                !content[i].HasMember("text") || !content[i]["text"].IsString()) {
1553
0
                continue;
1554
0
            }
1555
1556
1
            if (std::string(content[i]["type"].GetString()) == "text") {
1557
1
                if (!result.empty()) {
1558
0
                    result += "\n";
1559
0
                }
1560
1
                result += content[i]["text"].GetString();
1561
1
            }
1562
1
        }
1563
1564
1
        return append_parsed_text_result(result, results);
1565
2
    }
1566
};
1567
1568
// Mock adapter used only for UT to bypass real HTTP calls and return deterministic data.
1569
class MockAdapter : public AIAdapter {
1570
public:
1571
0
    Status set_authentication(HttpClient* client) const override { return Status::OK(); }
1572
1573
    Status build_request_payload(const std::vector<std::string>& inputs,
1574
                                 const char* const system_prompt,
1575
2
                                 std::string& request_body) const override {
1576
2
        return Status::OK();
1577
2
    }
1578
1579
    Status parse_response(const std::string& response_body,
1580
74
                          std::vector<std::string>& results) const override {
1581
74
        return append_parsed_text_result(response_body, results);
1582
74
    }
1583
1584
    Status build_embedding_request(const std::vector<std::string>& inputs,
1585
2
                                   std::string& request_body) const override {
1586
2
        return Status::OK();
1587
2
    }
1588
1589
    Status build_multimodal_embedding_request(
1590
            const std::vector<MultimodalType>& /*media_types*/,
1591
            const std::vector<std::string>& /*media_urls*/,
1592
            const std::vector<std::string>& /*media_content_types*/,
1593
2
            std::string& /*request_body*/) const override {
1594
2
        return Status::OK();
1595
2
    }
1596
1597
    Status parse_embedding_response(const std::string& response_body,
1598
0
                                    std::vector<std::vector<float>>& results) const override {
1599
0
        rapidjson::Document doc;
1600
0
        doc.SetObject();
1601
0
        doc.Parse(response_body.c_str());
1602
0
        if (doc.HasParseError() || !doc.IsObject()) {
1603
0
            return Status::InternalError("Failed to parse embedding response");
1604
0
        }
1605
0
        if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) {
1606
0
            return Status::InternalError("Invalid embedding response format");
1607
0
        }
1608
1609
0
        results.reserve(1);
1610
0
        std::transform(doc["embedding"].Begin(), doc["embedding"].End(),
1611
0
                       std::back_inserter(results.emplace_back()),
1612
0
                       [](const auto& val) { return val.GetFloat(); });
1613
0
        return Status::OK();
1614
0
    }
1615
};
1616
1617
class AIAdapterFactory {
1618
public:
1619
106
    static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) {
1620
106
        static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>>
1621
106
                adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }},
1622
106
                            {"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }},
1623
106
                            {"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }},
1624
106
                            {"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }},
1625
106
                            {"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }},
1626
106
                            {"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }},
1627
106
                            {"QWEN", []() { return std::make_shared<QwenAdapter>(); }},
1628
106
                            {"JINA", []() { return std::make_shared<JinaAdapter>(); }},
1629
106
                            {"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }},
1630
106
                            {"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }},
1631
106
                            {"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }},
1632
106
                            {"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }},
1633
106
                            {"MOCK", []() { return std::make_shared<MockAdapter>(); }}};
1634
1635
106
        auto it = adapters.find(provider_type);
1636
106
        return (it != adapters.end()) ? it->second() : nullptr;
1637
106
    }
1638
};
1639
1640
} // namespace doris