Coverage Report

Created: 2026-04-16 10:20

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/ai_adapter.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <gen_cpp/PaloInternalService_types.h>
21
#include <rapidjson/rapidjson.h>
22
23
#include <algorithm>
24
#include <memory>
25
#include <string>
26
#include <unordered_map>
27
#include <vector>
28
29
#include "common/status.h"
30
#include "core/string_buffer.hpp"
31
#include "rapidjson/document.h"
32
#include "rapidjson/stringbuffer.h"
33
#include "rapidjson/writer.h"
34
#include "service/http/http_client.h"
35
#include "service/http/http_headers.h"
36
37
namespace doris {
38
39
struct AIResource {
40
16
    AIResource() = default;
41
    AIResource(const TAIResource& tai)
42
11
            : endpoint(tai.endpoint),
43
11
              provider_type(tai.provider_type),
44
11
              model_name(tai.model_name),
45
11
              api_key(tai.api_key),
46
11
              temperature(tai.temperature),
47
11
              max_tokens(tai.max_tokens),
48
11
              max_retries(tai.max_retries),
49
11
              retry_delay_second(tai.retry_delay_second),
50
11
              anthropic_version(tai.anthropic_version),
51
11
              dimensions(tai.dimensions) {}
52
53
    std::string endpoint;
54
    std::string provider_type;
55
    std::string model_name;
56
    std::string api_key;
57
    double temperature;
58
    int64_t max_tokens;
59
    int32_t max_retries;
60
    int32_t retry_delay_second;
61
    std::string anthropic_version;
62
    int32_t dimensions;
63
64
1
    void serialize(BufferWritable& buf) const {
65
1
        buf.write_binary(endpoint);
66
1
        buf.write_binary(provider_type);
67
1
        buf.write_binary(model_name);
68
1
        buf.write_binary(api_key);
69
1
        buf.write_binary(temperature);
70
1
        buf.write_binary(max_tokens);
71
1
        buf.write_binary(max_retries);
72
1
        buf.write_binary(retry_delay_second);
73
1
        buf.write_binary(anthropic_version);
74
1
        buf.write_binary(dimensions);
75
1
    }
76
77
1
    void deserialize(BufferReadable& buf) {
78
1
        buf.read_binary(endpoint);
79
1
        buf.read_binary(provider_type);
80
1
        buf.read_binary(model_name);
81
1
        buf.read_binary(api_key);
82
1
        buf.read_binary(temperature);
83
1
        buf.read_binary(max_tokens);
84
1
        buf.read_binary(max_retries);
85
1
        buf.read_binary(retry_delay_second);
86
1
        buf.read_binary(anthropic_version);
87
1
        buf.read_binary(dimensions);
88
1
    }
89
};
90
91
enum class MultimodalType { IMAGE, VIDEO, AUDIO };
92
93
1
inline const char* multimodal_type_to_string(MultimodalType type) {
94
1
    switch (type) {
95
0
    case MultimodalType::IMAGE:
96
0
        return "image";
97
0
    case MultimodalType::VIDEO:
98
0
        return "video";
99
1
    case MultimodalType::AUDIO:
100
1
        return "audio";
101
1
    }
102
0
    return "unknown";
103
1
}
104
105
class AIAdapter {
106
public:
107
140
    virtual ~AIAdapter() = default;
108
109
    // Set authentication headers for the HTTP client
110
    virtual Status set_authentication(HttpClient* client) const = 0;
111
112
99
    virtual void init(const TAIResource& config) { _config = config; }
113
12
    virtual void init(const AIResource& config) {
114
12
        _config.endpoint = config.endpoint;
115
12
        _config.provider_type = config.provider_type;
116
12
        _config.model_name = config.model_name;
117
12
        _config.api_key = config.api_key;
118
12
        _config.temperature = config.temperature;
119
12
        _config.max_tokens = config.max_tokens;
120
12
        _config.max_retries = config.max_retries;
121
12
        _config.retry_delay_second = config.retry_delay_second;
122
12
        _config.anthropic_version = config.anthropic_version;
123
12
    }
124
125
    // Build request payload based on input text strings
126
    virtual Status build_request_payload(const std::vector<std::string>& inputs,
127
                                         const char* const system_prompt,
128
1
                                         std::string& request_body) const {
129
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
130
1
    }
131
132
    // Parse response from AI service and extract generated text results
133
    virtual Status parse_response(const std::string& response_body,
134
1
                                  std::vector<std::string>& results) const {
135
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
136
1
    }
137
138
    virtual Status build_embedding_request(const std::vector<std::string>& inputs,
139
0
                                           std::string& request_body) const {
140
0
        return Status::NotSupported("{} does not support the Embed feature.",
141
0
                                    _config.provider_type);
142
0
    }
143
144
    virtual Status build_multimodal_embedding_request(MultimodalType /*media_type*/,
145
                                                      const std::string& /*media_url*/,
146
0
                                                      std::string& /*request_body*/) const {
147
0
        return Status::NotSupported("{} does not support multimodal Embed feature.",
148
0
                                    _config.provider_type);
149
0
    }
150
151
    virtual Status parse_embedding_response(const std::string& response_body,
152
0
                                            std::vector<std::vector<float>>& results) const {
153
0
        return Status::NotSupported("{} does not support the Embed feature.",
154
0
                                    _config.provider_type);
155
0
    }
156
157
protected:
158
    TAIResource _config;
159
160
    // return true if the model support dimension parameter
161
1
    virtual bool supports_dimension_param(const std::string& model_name) const { return false; }
162
163
    // Different providers may have different dimension parameter names.
164
0
    virtual std::string get_dimension_param_name() const { return "dimensions"; }
165
166
    virtual void add_dimension_params(rapidjson::Value& doc,
167
14
                                      rapidjson::Document::AllocatorType& allocator) const {
168
14
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
169
8
            std::string param_name = get_dimension_param_name();
170
8
            rapidjson::Value name(param_name.c_str(), allocator);
171
8
            doc.AddMember(name, _config.dimensions, allocator);
172
8
        }
173
14
    }
174
};
175
176
// Most LLM-providers' Embedding formats are based on VoyageAI.
177
// The following adapters inherit from VoyageAIAdapter to directly reuse its embedding logic.
178
class VoyageAIAdapter : public AIAdapter {
179
public:
180
2
    Status set_authentication(HttpClient* client) const override {
181
2
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
182
2
        client->set_content_type("application/json");
183
184
2
        return Status::OK();
185
2
    }
186
187
    Status build_embedding_request(const std::vector<std::string>& inputs,
188
8
                                   std::string& request_body) const override {
189
8
        rapidjson::Document doc;
190
8
        doc.SetObject();
191
8
        auto& allocator = doc.GetAllocator();
192
193
        /*{
194
            "model": "xxx",
195
            "input": [
196
              "xxx",
197
              "xxx",
198
              ...
199
            ],
200
            "output_dimensions": 512
201
        }*/
202
8
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
203
8
        add_dimension_params(doc, allocator);
204
205
8
        rapidjson::Value input(rapidjson::kArrayType);
206
8
        for (const auto& msg : inputs) {
207
8
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
208
8
        }
209
8
        doc.AddMember("input", input, allocator);
210
211
8
        rapidjson::StringBuffer buffer;
212
8
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
213
8
        doc.Accept(writer);
214
8
        request_body = buffer.GetString();
215
216
8
        return Status::OK();
217
8
    }
218
219
    Status build_multimodal_embedding_request(MultimodalType media_type,
220
                                              const std::string& media_url,
221
1
                                              std::string& request_body) const override {
222
1
        if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO)
223
0
                [[unlikely]] {
224
0
            return Status::InvalidArgument(
225
0
                    "VoyageAI only supports image/video multimodal embed, got {}",
226
0
                    multimodal_type_to_string(media_type));
227
0
        }
228
1
        if (_config.dimensions != -1) {
229
1
            LOG(WARNING) << "VoyageAI multimodal embedding currently ignores dimensions parameter, "
230
1
                         << "model=" << _config.model_name << ", dimensions=" << _config.dimensions;
231
1
        }
232
233
1
        rapidjson::Document doc;
234
1
        doc.SetObject();
235
1
        auto& allocator = doc.GetAllocator();
236
237
        /*{
238
            "inputs": [
239
              {
240
                "content": [
241
                  {"type": "image_url", "image_url": "<url>"}
242
                ]
243
              }
244
            ],
245
            "model": "voyage-multimodal-3.5"
246
        }*/
247
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
248
249
1
        rapidjson::Value inputs(rapidjson::kArrayType);
250
1
        rapidjson::Value input(rapidjson::kObjectType);
251
1
        rapidjson::Value content(rapidjson::kArrayType);
252
253
1
        rapidjson::Value media_item(rapidjson::kObjectType);
254
1
        if (media_type == MultimodalType::IMAGE) {
255
0
            media_item.AddMember("type", "image_url", allocator);
256
0
            media_item.AddMember("image_url", rapidjson::Value(media_url.c_str(), allocator),
257
0
                                 allocator);
258
1
        } else {
259
1
            media_item.AddMember("type", "video_url", allocator);
260
1
            media_item.AddMember("video_url", rapidjson::Value(media_url.c_str(), allocator),
261
1
                                 allocator);
262
1
        }
263
1
        content.PushBack(media_item, allocator);
264
265
1
        input.AddMember("content", content, allocator);
266
1
        inputs.PushBack(input, allocator);
267
1
        doc.AddMember("inputs", inputs, allocator);
268
269
1
        rapidjson::StringBuffer buffer;
270
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
271
1
        doc.Accept(writer);
272
1
        request_body = buffer.GetString();
273
1
        return Status::OK();
274
1
    }
275
276
    Status parse_embedding_response(const std::string& response_body,
277
5
                                    std::vector<std::vector<float>>& results) const override {
278
5
        rapidjson::Document doc;
279
5
        doc.Parse(response_body.c_str());
280
281
5
        if (doc.HasParseError() || !doc.IsObject()) {
282
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
283
1
                                         response_body);
284
1
        }
285
4
        if (!doc.HasMember("data") || !doc["data"].IsArray()) {
286
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
287
1
                                         response_body);
288
1
        }
289
290
        /*{
291
            "data":[
292
              {
293
                "object": "embedding",
294
                "embedding": [...], <- only need this
295
                "index": 0
296
              },
297
              {
298
                "object": "embedding",
299
                "embedding": [...],
300
                "index": 1
301
              }, ...
302
            ],
303
            "model"....
304
        }*/
305
3
        const auto& data = doc["data"];
306
3
        results.reserve(data.Size());
307
7
        for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
308
5
            if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
309
1
                return Status::InternalError("Invalid {} response format: {}",
310
1
                                             _config.provider_type, response_body);
311
1
            }
312
313
4
            std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
314
4
                           std::back_inserter(results.emplace_back()),
315
10
                           [](const auto& val) { return val.GetFloat(); });
316
4
        }
317
318
2
        return Status::OK();
319
3
    }
320
321
protected:
322
3
    bool supports_dimension_param(const std::string& model_name) const override {
323
3
        static const std::unordered_set<std::string> no_dimension_models = {
324
3
                "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2",
325
3
                "voyage-multimodal-3"};
326
3
        return !no_dimension_models.contains(model_name);
327
3
    }
328
329
1
    std::string get_dimension_param_name() const override { return "output_dimension"; }
330
};
331
332
// Local AI adapter for locally hosted models (Ollama, LLaMA, etc.)
333
class LocalAdapter : public AIAdapter {
334
public:
335
    // Local deployments typically don't need authentication
336
2
    Status set_authentication(HttpClient* client) const override {
337
2
        client->set_content_type("application/json");
338
2
        return Status::OK();
339
2
    }
340
341
    Status build_request_payload(const std::vector<std::string>& inputs,
342
                                 const char* const system_prompt,
343
3
                                 std::string& request_body) const override {
344
3
        rapidjson::Document doc;
345
3
        doc.SetObject();
346
3
        auto& allocator = doc.GetAllocator();
347
348
3
        std::string end_point = _config.endpoint;
349
3
        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
350
2
            RETURN_IF_ERROR(
351
2
                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
352
2
        } else {
353
1
            RETURN_IF_ERROR(
354
1
                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
355
1
        }
356
357
3
        rapidjson::StringBuffer buffer;
358
3
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
359
3
        doc.Accept(writer);
360
3
        request_body = buffer.GetString();
361
362
3
        return Status::OK();
363
3
    }
364
365
    Status parse_response(const std::string& response_body,
366
7
                          std::vector<std::string>& results) const override {
367
7
        rapidjson::Document doc;
368
7
        doc.Parse(response_body.c_str());
369
370
7
        if (doc.HasParseError() || !doc.IsObject()) {
371
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
372
1
                                         response_body);
373
1
        }
374
375
        // Handle various response formats from local LLMs
376
        // Format 1: OpenAI-compatible format with choices/message/content
377
6
        if (doc.HasMember("choices") && doc["choices"].IsArray()) {
378
1
            const auto& choices = doc["choices"];
379
1
            results.reserve(choices.Size());
380
381
2
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
382
1
                if (choices[i].HasMember("message") && choices[i]["message"].HasMember("content") &&
383
1
                    choices[i]["message"]["content"].IsString()) {
384
1
                    results.emplace_back(choices[i]["message"]["content"].GetString());
385
1
                } else if (choices[i].HasMember("text") && choices[i]["text"].IsString()) {
386
                    // Some local LLMs use a simpler format
387
0
                    results.emplace_back(choices[i]["text"].GetString());
388
0
                }
389
1
            }
390
5
        } else if (doc.HasMember("text") && doc["text"].IsString()) {
391
            // Format 2: Simple response with just "text" or "content" field
392
1
            results.emplace_back(doc["text"].GetString());
393
4
        } else if (doc.HasMember("content") && doc["content"].IsString()) {
394
1
            results.emplace_back(doc["content"].GetString());
395
3
        } else if (doc.HasMember("response") && doc["response"].IsString()) {
396
            // Format 3: Response field (Ollama `generate` format)
397
1
            results.emplace_back(doc["response"].GetString());
398
2
        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
399
2
                   doc["message"].HasMember("content") && doc["message"]["content"].IsString()) {
400
            // Format 4: message/content field (Ollama `chat` format)
401
1
            results.emplace_back(doc["message"]["content"].GetString());
402
1
        } else {
403
1
            return Status::NotSupported("Unsupported response format from local AI.");
404
1
        }
405
406
5
        return Status::OK();
407
6
    }
408
409
    Status build_embedding_request(const std::vector<std::string>& inputs,
410
1
                                   std::string& request_body) const override {
411
1
        rapidjson::Document doc;
412
1
        doc.SetObject();
413
1
        auto& allocator = doc.GetAllocator();
414
415
1
        if (!_config.model_name.empty()) {
416
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
417
1
                          allocator);
418
1
        }
419
420
1
        add_dimension_params(doc, allocator);
421
422
1
        rapidjson::Value input(rapidjson::kArrayType);
423
1
        for (const auto& msg : inputs) {
424
1
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
425
1
        }
426
1
        doc.AddMember("input", input, allocator);
427
428
1
        rapidjson::StringBuffer buffer;
429
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
430
1
        doc.Accept(writer);
431
1
        request_body = buffer.GetString();
432
433
1
        return Status::OK();
434
1
    }
435
436
    Status build_multimodal_embedding_request(MultimodalType /*media_type*/,
437
                                              const std::string& /*media_url*/,
438
0
                                              std::string& /*request_body*/) const override {
439
0
        return Status::NotSupported("{} does not support multimodal Embed feature.",
440
0
                                    _config.provider_type);
441
0
    }
442
443
    Status parse_embedding_response(const std::string& response_body,
444
3
                                    std::vector<std::vector<float>>& results) const override {
445
3
        rapidjson::Document doc;
446
3
        doc.Parse(response_body.c_str());
447
448
3
        if (doc.HasParseError() || !doc.IsObject()) {
449
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
450
0
                                         response_body);
451
0
        }
452
453
        // parse different response format
454
3
        rapidjson::Value embedding;
455
3
        if (doc.HasMember("data") && doc["data"].IsArray()) {
456
            // "data":["object":"embedding", "embedding":[0.1, 0.2...], "index":0]
457
1
            const auto& data = doc["data"];
458
1
            results.reserve(data.Size());
459
3
            for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
460
2
                if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
461
0
                    return Status::InternalError("Invalid {} response format",
462
0
                                                 _config.provider_type);
463
0
                }
464
465
2
                std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
466
2
                               std::back_inserter(results.emplace_back()),
467
5
                               [](const auto& val) { return val.GetFloat(); });
468
2
            }
469
2
        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
470
            // "embeddings":[[0.1, 0.2, ...]]
471
1
            results.reserve(1);
472
2
            for (int i = 0; i < doc["embeddings"].Size(); i++) {
473
1
                embedding = doc["embeddings"][i];
474
1
                std::transform(embedding.Begin(), embedding.End(),
475
1
                               std::back_inserter(results.emplace_back()),
476
2
                               [](const auto& val) { return val.GetFloat(); });
477
1
            }
478
1
        } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
479
            // "embedding":[0.1, 0.2, ...]
480
1
            results.reserve(1);
481
1
            embedding = doc["embedding"];
482
1
            std::transform(embedding.Begin(), embedding.End(),
483
1
                           std::back_inserter(results.emplace_back()),
484
3
                           [](const auto& val) { return val.GetFloat(); });
485
1
        } else {
486
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
487
0
                                         response_body);
488
0
        }
489
490
3
        return Status::OK();
491
3
    }
492
493
private:
494
    Status build_ollama_request(rapidjson::Document& doc,
495
                                rapidjson::Document::AllocatorType& allocator,
496
                                const std::vector<std::string>& inputs,
497
2
                                const char* const system_prompt, std::string& request_body) const {
498
        /*
499
        for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
500
        {
501
            "model": <model_name>,
502
            "stream": false,
503
            "think": false,
504
            "options": {
505
                "temperature": <temperature>,
506
                "max_token": <max_token>
507
            },
508
            "messages": [
509
                {"role": "system", "content": <system_prompt>},
510
                {"role": "user", "content": <user_prompt>}
511
            ]
512
        }
513
        
514
        for endpoints end_with `/generate` like 'http://localhost:11434/api/generate':
515
        {
516
            "model": <model_name>,
517
            "stream": false,
518
            "think": false
519
            "options": {
520
                "temperature": <temperature>,
521
                "max_token": <max_token>
522
            },
523
            "system": <system_prompt>,
524
            "prompt": <user_prompt>
525
        }
526
        */
527
528
        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint;
529
        // The rest remains identical.
530
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
531
2
        doc.AddMember("stream", false, allocator);
532
2
        doc.AddMember("think", false, allocator);
533
534
        // option section
535
2
        rapidjson::Value options(rapidjson::kObjectType);
536
2
        if (_config.temperature != -1) {
537
2
            options.AddMember("temperature", _config.temperature, allocator);
538
2
        }
539
2
        if (_config.max_tokens != -1) {
540
2
            options.AddMember("max_token", _config.max_tokens, allocator);
541
2
        }
542
2
        doc.AddMember("options", options, allocator);
543
544
        // prompt section
545
2
        if (_config.endpoint.ends_with("chat")) {
546
1
            rapidjson::Value messages(rapidjson::kArrayType);
547
1
            if (system_prompt && *system_prompt) {
548
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
549
1
                sys_msg.AddMember("role", "system", allocator);
550
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
551
1
                messages.PushBack(sys_msg, allocator);
552
1
            }
553
1
            for (const auto& input : inputs) {
554
1
                rapidjson::Value message(rapidjson::kObjectType);
555
1
                message.AddMember("role", "user", allocator);
556
1
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
557
1
                messages.PushBack(message, allocator);
558
1
            }
559
1
            doc.AddMember("messages", messages, allocator);
560
1
        } else {
561
1
            if (system_prompt && *system_prompt) {
562
1
                doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
563
1
            }
564
1
            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator);
565
1
        }
566
567
2
        return Status::OK();
568
2
    }
569
570
    Status build_default_request(rapidjson::Document& doc,
571
                                 rapidjson::Document::AllocatorType& allocator,
572
                                 const std::vector<std::string>& inputs,
573
1
                                 const char* const system_prompt, std::string& request_body) const {
574
        /*
575
        Default format(OpenAI-compatible):
576
        {
577
            "model": <model_name>,
578
            "temperature": <temperature>,
579
            "max_tokens": <max_tokens>,
580
            "messages": [
581
                {"role": "system", "content": <system_prompt>},
582
                {"role": "user", "content": <user_prompt>}
583
            ]
584
        }
585
        */
586
587
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
588
589
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
590
1
        if (_config.temperature != -1) {
591
1
            doc.AddMember("temperature", _config.temperature, allocator);
592
1
        }
593
1
        if (_config.max_tokens != -1) {
594
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
595
1
        }
596
597
1
        rapidjson::Value messages(rapidjson::kArrayType);
598
1
        if (system_prompt && *system_prompt) {
599
1
            rapidjson::Value sys_msg(rapidjson::kObjectType);
600
1
            sys_msg.AddMember("role", "system", allocator);
601
1
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
602
1
            messages.PushBack(sys_msg, allocator);
603
1
        }
604
1
        for (const auto& input : inputs) {
605
1
            rapidjson::Value message(rapidjson::kObjectType);
606
1
            message.AddMember("role", "user", allocator);
607
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
608
1
            messages.PushBack(message, allocator);
609
1
        }
610
1
        doc.AddMember("messages", messages, allocator);
611
1
        return Status::OK();
612
1
    }
613
};
614
615
// The OpenAI API format can be reused with some compatible AIs.
616
class OpenAIAdapter : public VoyageAIAdapter {
617
public:
618
11
    Status set_authentication(HttpClient* client) const override {
619
11
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
620
11
        client->set_content_type("application/json");
621
622
11
        return Status::OK();
623
11
    }
624
625
    Status build_request_payload(const std::vector<std::string>& inputs,
626
                                 const char* const system_prompt,
627
2
                                 std::string& request_body) const override {
628
2
        rapidjson::Document doc;
629
2
        doc.SetObject();
630
2
        auto& allocator = doc.GetAllocator();
631
632
2
        if (_config.endpoint.ends_with("responses")) {
633
            /*{
634
              "model": "gpt-4.1-mini",
635
              "input": [
636
                {"role": "system", "content": "system_prompt here"},
637
                {"role": "user", "content": "xxx"}
638
              ],
639
              "temperature": 0.7,
640
              "max_output_tokens": 150
641
            }*/
642
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
643
1
                          allocator);
644
645
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
646
1
            if (_config.temperature != -1) {
647
1
                doc.AddMember("temperature", _config.temperature, allocator);
648
1
            }
649
1
            if (_config.max_tokens != -1) {
650
1
                doc.AddMember("max_output_tokens", _config.max_tokens, allocator);
651
1
            }
652
653
            // input
654
1
            rapidjson::Value input(rapidjson::kArrayType);
655
1
            if (system_prompt && *system_prompt) {
656
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
657
1
                sys_msg.AddMember("role", "system", allocator);
658
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
659
1
                input.PushBack(sys_msg, allocator);
660
1
            }
661
1
            for (const auto& msg : inputs) {
662
1
                rapidjson::Value message(rapidjson::kObjectType);
663
1
                message.AddMember("role", "user", allocator);
664
1
                message.AddMember("content", rapidjson::Value(msg.c_str(), allocator), allocator);
665
1
                input.PushBack(message, allocator);
666
1
            }
667
1
            doc.AddMember("input", input, allocator);
668
1
        } else {
669
            /*{
670
              "model": "gpt-4",
671
              "messages": [
672
                {"role": "system", "content": "system_prompt here"},
673
                {"role": "user", "content": "xxx"}
674
              ],
675
              "temperature": x,
676
              "max_tokens": x,
677
            }*/
678
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
679
1
                          allocator);
680
681
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
682
1
            if (_config.temperature != -1) {
683
1
                doc.AddMember("temperature", _config.temperature, allocator);
684
1
            }
685
1
            if (_config.max_tokens != -1) {
686
1
                doc.AddMember("max_tokens", _config.max_tokens, allocator);
687
1
            }
688
689
1
            rapidjson::Value messages(rapidjson::kArrayType);
690
1
            if (system_prompt && *system_prompt) {
691
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
692
1
                sys_msg.AddMember("role", "system", allocator);
693
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
694
1
                messages.PushBack(sys_msg, allocator);
695
1
            }
696
1
            for (const auto& input : inputs) {
697
1
                rapidjson::Value message(rapidjson::kObjectType);
698
1
                message.AddMember("role", "user", allocator);
699
1
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
700
1
                messages.PushBack(message, allocator);
701
1
            }
702
1
            doc.AddMember("messages", messages, allocator);
703
1
        }
704
705
2
        rapidjson::StringBuffer buffer;
706
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
707
2
        doc.Accept(writer);
708
2
        request_body = buffer.GetString();
709
710
2
        return Status::OK();
711
2
    }
712
713
    Status parse_response(const std::string& response_body,
714
6
                          std::vector<std::string>& results) const override {
715
6
        rapidjson::Document doc;
716
6
        doc.Parse(response_body.c_str());
717
718
6
        if (doc.HasParseError() || !doc.IsObject()) {
719
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
720
1
                                         response_body);
721
1
        }
722
723
5
        if (doc.HasMember("output") && doc["output"].IsArray()) {
724
            /// for responses endpoint
725
            /*{
726
              "output": [
727
                {
728
                  "id": "msg_123",
729
                  "type": "message",
730
                  "role": "assistant",
731
                  "content": [
732
                    {
733
                      "type": "text",
734
                      "text": "result text here"   <- result
735
                    }
736
                  ]
737
                }
738
              ]
739
            }*/
740
1
            const auto& output = doc["output"];
741
1
            results.reserve(output.Size());
742
743
2
            for (rapidjson::SizeType i = 0; i < output.Size(); i++) {
744
1
                if (!output[i].HasMember("content") || !output[i]["content"].IsArray() ||
745
1
                    output[i]["content"].Empty() || !output[i]["content"][0].HasMember("text") ||
746
1
                    !output[i]["content"][0]["text"].IsString()) {
747
0
                    return Status::InternalError("Invalid output format in {} response: {}",
748
0
                                                 _config.provider_type, response_body);
749
0
                }
750
751
1
                results.emplace_back(output[i]["content"][0]["text"].GetString());
752
1
            }
753
4
        } else if (doc.HasMember("choices") && doc["choices"].IsArray()) {
754
            /// for completions endpoint
755
            /*{
756
              "object": "chat.completion",
757
              "model": "gpt-4",
758
              "choices": [
759
                {
760
                  ...
761
                  "message": {
762
                    "role": "assistant",
763
                    "content": "xxx"      <- result
764
                  },
765
                  ...
766
                }
767
              ],
768
              ...
769
            }*/
770
3
            const auto& choices = doc["choices"];
771
3
            results.reserve(choices.Size());
772
773
4
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
774
3
                if (!choices[i].HasMember("message") ||
775
3
                    !choices[i]["message"].HasMember("content") ||
776
3
                    !choices[i]["message"]["content"].IsString()) {
777
2
                    return Status::InternalError("Invalid choice format in {} response: {}",
778
2
                                                 _config.provider_type, response_body);
779
2
                }
780
781
1
                results.emplace_back(choices[i]["message"]["content"].GetString());
782
1
            }
783
3
        } else {
784
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
785
1
                                         response_body);
786
1
        }
787
788
2
        return Status::OK();
789
5
    }
790
791
    Status build_multimodal_embedding_request(MultimodalType /*media_type*/,
792
                                              const std::string& /*media_url*/,
793
1
                                              std::string& /*request_body*/) const override {
794
1
        return Status::NotSupported("{} does not support multimodal Embed feature.",
795
1
                                    _config.provider_type);
796
1
    }
797
798
protected:
799
2
    bool supports_dimension_param(const std::string& model_name) const override {
800
2
        return !(model_name == "text-embedding-ada-002");
801
2
    }
802
803
2
    std::string get_dimension_param_name() const override { return "dimensions"; }
804
};
805
806
class DeepSeekAdapter : public OpenAIAdapter {
807
public:
808
    Status build_embedding_request(const std::vector<std::string>& inputs,
809
1
                                   std::string& request_body) const override {
810
1
        return Status::NotSupported("{} does not support the Embed feature.",
811
1
                                    _config.provider_type);
812
1
    }
813
814
    Status parse_embedding_response(const std::string& response_body,
815
1
                                    std::vector<std::vector<float>>& results) const override {
816
1
        return Status::NotSupported("{} does not support the Embed feature.",
817
1
                                    _config.provider_type);
818
1
    }
819
};
820
821
class MoonShotAdapter : public OpenAIAdapter {
822
public:
823
    Status build_embedding_request(const std::vector<std::string>& inputs,
824
1
                                   std::string& request_body) const override {
825
1
        return Status::NotSupported("{} does not support the Embed feature.",
826
1
                                    _config.provider_type);
827
1
    }
828
829
    Status parse_embedding_response(const std::string& response_body,
830
1
                                    std::vector<std::vector<float>>& results) const override {
831
1
        return Status::NotSupported("{} does not support the Embed feature.",
832
1
                                    _config.provider_type);
833
1
    }
834
};
835
836
class MinimaxAdapter : public OpenAIAdapter {
public:
    /// Minimax uses its own embedding request schema instead of the
    /// OpenAI-compatible one inherited from OpenAIAdapter.
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /*{
          "model": "embo-1",
          "texts": ["xxx", "xxx", ...],
          "type": "db"
        }*/
        rapidjson::Value texts(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator);
        }
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        doc.AddMember("texts", texts, allocator);
        // NOTE(review): "type":"db" presumably selects storage-side embeddings
        // (vs. query-side) — confirm against the Minimax embeddings API docs.
        doc.AddMember("type", rapidjson::Value("db", allocator), allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }
};
865
866
class ZhipuAdapter : public OpenAIAdapter {
867
protected:
868
2
    bool supports_dimension_param(const std::string& model_name) const override {
869
2
        return !(model_name == "embedding-2");
870
2
    }
871
};
872
873
class QwenAdapter : public OpenAIAdapter {
public:
    /// Builds a Qwen (DashScope) multimodal embedding request for a single
    /// image or video URL; other media types are rejected.
    Status build_multimodal_embedding_request(MultimodalType media_type,
                                              const std::string& media_url,
                                              std::string& request_body) const override {
        if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO) {
            return Status::InvalidArgument(
                    "QWEN only supports image/video multimodal embed, got {}",
                    multimodal_type_to_string(media_type));
        }

        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /*{
            "model": "tongyi-embedding-vision-plus",
            "input": {
              "contents": [
                {"image": "<url>"}
              ]
            },
            "parameters": {
              "dimension": 512
            }
        }*/
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        rapidjson::Value input(rapidjson::kObjectType);
        rapidjson::Value contents(rapidjson::kArrayType);

        // The JSON key ("image" or "video") carries the media kind.
        rapidjson::Value media_item(rapidjson::kObjectType);
        if (media_type == MultimodalType::IMAGE) {
            media_item.AddMember("image", rapidjson::Value(media_url.c_str(), allocator),
                                 allocator);
        } else {
            media_item.AddMember("video", rapidjson::Value(media_url.c_str(), allocator),
                                 allocator);
        }
        contents.PushBack(media_item, allocator);

        input.AddMember("contents", contents, allocator);
        doc.AddMember("input", input, allocator);
        // Optional output-size hint; skipped for models with fixed dimensions.
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
            rapidjson::Value parameters(rapidjson::kObjectType);
            std::string param_name = get_dimension_param_name();
            rapidjson::Value dimension_name(param_name.c_str(), allocator);
            parameters.AddMember(dimension_name, _config.dimensions, allocator);
            doc.AddMember("parameters", parameters, allocator);
        }

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();
        return Status::OK();
    }

    /// Parses a Qwen embedding response into float vectors.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) [[unlikely]] {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        // Qwen multimodal embedding usually returns:
        // {
        //   "output": {
        //     "embeddings": [
        //       {"index":0, "embedding":[...], "type":"image|video|text"},
        //       ...
        //     ]
        //   }
        // }
        //
        // In text-only or compatibility endpoints, Qwen may also return OpenAI-style
        // "data":[{"embedding":[...]}]. For compatibility we first parse native
        // output.embeddings and then fallback to OpenAIAdapter parser.
        if (doc.HasMember("output") && doc["output"].IsObject() &&
            doc["output"].HasMember("embeddings") && doc["output"]["embeddings"].IsArray()) {
            const auto& embeddings = doc["output"]["embeddings"];
            results.reserve(embeddings.Size());
            for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) {
                if (!embeddings[i].HasMember("embedding") ||
                    !embeddings[i]["embedding"].IsArray()) {
                    return Status::InternalError("Invalid {} response format: {}",
                                                 _config.provider_type, response_body);
                }
                std::transform(embeddings[i]["embedding"].Begin(), embeddings[i]["embedding"].End(),
                               std::back_inserter(results.emplace_back()),
                               [](const auto& val) { return val.GetFloat(); });
            }
            return Status::OK();
        }
        return OpenAIAdapter::parse_embedding_response(response_body, results);
    }

protected:
    // Older / fixed-size Qwen embedding models do not accept a dimension hint.
    bool supports_dimension_param(const std::string& model_name) const override {
        static const std::unordered_set<std::string> no_dimension_models = {
                "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"};
        return !no_dimension_models.contains(model_name);
    }

    // Qwen uses singular "dimension" (OpenAI uses "dimensions").
    std::string get_dimension_param_name() const override { return "dimension"; }
};
980
981
class JinaAdapter : public VoyageAIAdapter {
982
public:
983
    Status build_multimodal_embedding_request(MultimodalType media_type,
984
                                              const std::string& media_url,
985
1
                                              std::string& request_body) const override {
986
1
        if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO)
987
0
                [[unlikely]] {
988
0
            return Status::InvalidArgument(
989
0
                    "JINA only supports image/video multimodal embed, got {}",
990
0
                    multimodal_type_to_string(media_type));
991
0
        }
992
993
1
        rapidjson::Document doc;
994
1
        doc.SetObject();
995
1
        auto& allocator = doc.GetAllocator();
996
997
        /*{
998
            "model": "jina-embeddings-v4",
999
            "task": "text-matching",
1000
            "input": [
1001
              {"image": "<url>"}
1002
            ]
1003
        }*/
1004
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
1005
1
        doc.AddMember("task", "text-matching", allocator);
1006
1007
1
        rapidjson::Value input(rapidjson::kArrayType);
1008
1
        rapidjson::Value media_item(rapidjson::kObjectType);
1009
1
        if (media_type == MultimodalType::IMAGE) {
1010
1
            media_item.AddMember("image", rapidjson::Value(media_url.c_str(), allocator),
1011
1
                                 allocator);
1012
1
        } else {
1013
0
            media_item.AddMember("video", rapidjson::Value(media_url.c_str(), allocator),
1014
0
                                 allocator);
1015
0
        }
1016
1
        input.PushBack(media_item, allocator);
1017
1
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
1018
1
            doc.AddMember("dimensions", _config.dimensions, allocator);
1019
1
        }
1020
1
        doc.AddMember("input", input, allocator);
1021
1022
1
        rapidjson::StringBuffer buffer;
1023
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1024
1
        doc.Accept(writer);
1025
1
        request_body = buffer.GetString();
1026
1
        return Status::OK();
1027
1
    }
1028
};
1029
1030
class BaichuanAdapter : public OpenAIAdapter {
1031
protected:
1032
0
    bool supports_dimension_param(const std::string& model_name) const override { return false; }
1033
};
1034
1035
// Gemini's embedding format is different from VoyageAI, so it requires a separate adapter
1036
class GeminiAdapter : public AIAdapter {
1037
public:
1038
2
    Status set_authentication(HttpClient* client) const override {
1039
2
        client->set_header("x-goog-api-key", _config.api_key);
1040
2
        client->set_content_type("application/json");
1041
2
        return Status::OK();
1042
2
    }
1043
1044
    Status build_request_payload(const std::vector<std::string>& inputs,
1045
                                 const char* const system_prompt,
1046
1
                                 std::string& request_body) const override {
1047
1
        rapidjson::Document doc;
1048
1
        doc.SetObject();
1049
1
        auto& allocator = doc.GetAllocator();
1050
1051
        /*{
1052
          "systemInstruction": {
1053
              "parts": [
1054
                {
1055
                  "text": "system_prompt here"
1056
                }
1057
              ]
1058
            }
1059
          ],
1060
          "contents": [
1061
            {
1062
              "parts": [
1063
                {
1064
                  "text": "xxx"
1065
                }
1066
              ]
1067
            }
1068
          ],
1069
          "generationConfig": {
1070
          "temperature": 0.7,
1071
          "maxOutputTokens": 1024
1072
          }
1073
1074
        }*/
1075
1
        if (system_prompt && *system_prompt) {
1076
1
            rapidjson::Value system_instruction(rapidjson::kObjectType);
1077
1
            rapidjson::Value parts(rapidjson::kArrayType);
1078
1079
1
            rapidjson::Value part(rapidjson::kObjectType);
1080
1
            part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator);
1081
1
            parts.PushBack(part, allocator);
1082
            // system_instruction.PushBack(content, allocator);
1083
1
            system_instruction.AddMember("parts", parts, allocator);
1084
1
            doc.AddMember("systemInstruction", system_instruction, allocator);
1085
1
        }
1086
1087
1
        rapidjson::Value contents(rapidjson::kArrayType);
1088
1
        for (const auto& input : inputs) {
1089
1
            rapidjson::Value content(rapidjson::kObjectType);
1090
1
            rapidjson::Value parts(rapidjson::kArrayType);
1091
1092
1
            rapidjson::Value part(rapidjson::kObjectType);
1093
1
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
1094
1095
1
            parts.PushBack(part, allocator);
1096
1
            content.AddMember("parts", parts, allocator);
1097
1
            contents.PushBack(content, allocator);
1098
1
        }
1099
1
        doc.AddMember("contents", contents, allocator);
1100
1101
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
1102
1
        rapidjson::Value generationConfig(rapidjson::kObjectType);
1103
1
        if (_config.temperature != -1) {
1104
1
            generationConfig.AddMember("temperature", _config.temperature, allocator);
1105
1
        }
1106
1
        if (_config.max_tokens != -1) {
1107
1
            generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator);
1108
1
        }
1109
1
        doc.AddMember("generationConfig", generationConfig, allocator);
1110
1111
1
        rapidjson::StringBuffer buffer;
1112
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1113
1
        doc.Accept(writer);
1114
1
        request_body = buffer.GetString();
1115
1116
1
        return Status::OK();
1117
1
    }
1118
1119
    Status parse_response(const std::string& response_body,
1120
3
                          std::vector<std::string>& results) const override {
1121
3
        rapidjson::Document doc;
1122
3
        doc.Parse(response_body.c_str());
1123
1124
3
        if (doc.HasParseError() || !doc.IsObject()) {
1125
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1126
1
                                         response_body);
1127
1
        }
1128
2
        if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) {
1129
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1130
1
                                         response_body);
1131
1
        }
1132
1133
        /*{
1134
          "candidates":[
1135
            {
1136
              "content": {
1137
                "parts": [
1138
                  {
1139
                    "text": "xxx"
1140
                  }
1141
                ]
1142
              }
1143
            }
1144
          ]
1145
        }*/
1146
1
        const auto& candidates = doc["candidates"];
1147
1
        results.reserve(candidates.Size());
1148
1149
2
        for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) {
1150
1
            if (!candidates[i].HasMember("content") ||
1151
1
                !candidates[i]["content"].HasMember("parts") ||
1152
1
                !candidates[i]["content"]["parts"].IsArray() ||
1153
1
                candidates[i]["content"]["parts"].Empty() ||
1154
1
                !candidates[i]["content"]["parts"][0].HasMember("text") ||
1155
1
                !candidates[i]["content"]["parts"][0]["text"].IsString()) {
1156
0
                return Status::InternalError("Invalid candidate format in {} response",
1157
0
                                             _config.provider_type);
1158
0
            }
1159
1160
1
            results.emplace_back(candidates[i]["content"]["parts"][0]["text"].GetString());
1161
1
        }
1162
1163
1
        return Status::OK();
1164
1
    }
1165
1166
    Status build_embedding_request(const std::vector<std::string>& inputs,
1167
2
                                   std::string& request_body) const override {
1168
2
        rapidjson::Document doc;
1169
2
        doc.SetObject();
1170
2
        auto& allocator = doc.GetAllocator();
1171
1172
        /*{
1173
          "model": "models/gemini-embedding-001",
1174
          "content": {
1175
              "parts": [
1176
                {
1177
                  "text": "xxx"
1178
                }
1179
              ]
1180
            }
1181
          "outputDimensionality": 1024
1182
        }*/
1183
1184
        // gemini requires the model format as `models/{model}`
1185
2
        std::string model_name = _config.model_name;
1186
2
        if (!model_name.starts_with("models/")) {
1187
2
            model_name = "models/" + model_name;
1188
2
        }
1189
2
        doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
1190
2
        add_dimension_params(doc, allocator);
1191
1192
2
        rapidjson::Value content(rapidjson::kObjectType);
1193
2
        for (const auto& input : inputs) {
1194
2
            rapidjson::Value parts(rapidjson::kArrayType);
1195
2
            rapidjson::Value part(rapidjson::kObjectType);
1196
2
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
1197
2
            parts.PushBack(part, allocator);
1198
2
            content.AddMember("parts", parts, allocator);
1199
2
        }
1200
2
        doc.AddMember("content", content, allocator);
1201
1202
2
        rapidjson::StringBuffer buffer;
1203
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1204
2
        doc.Accept(writer);
1205
2
        request_body = buffer.GetString();
1206
1207
2
        return Status::OK();
1208
2
    }
1209
1210
    Status build_multimodal_embedding_request(MultimodalType media_type,
1211
                                              const std::string& media_url,
1212
3
                                              std::string& request_body) const override {
1213
3
        const char* mime_type = nullptr;
1214
3
        switch (media_type) {
1215
1
        case MultimodalType::IMAGE:
1216
1
            mime_type = "image/png";
1217
1
            break;
1218
1
        case MultimodalType::AUDIO:
1219
1
            mime_type = "audio/mpeg";
1220
1
            break;
1221
1
        case MultimodalType::VIDEO:
1222
1
            mime_type = "video/mp4";
1223
1
            break;
1224
3
        }
1225
1226
3
        rapidjson::Document doc;
1227
3
        doc.SetObject();
1228
3
        auto& allocator = doc.GetAllocator();
1229
1230
3
        std::string model_name = _config.model_name;
1231
3
        if (!model_name.starts_with("models/")) {
1232
3
            model_name = "models/" + model_name;
1233
3
        }
1234
3
        doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
1235
3
        add_dimension_params(doc, allocator);
1236
1237
3
        rapidjson::Value content(rapidjson::kObjectType);
1238
3
        rapidjson::Value parts(rapidjson::kArrayType);
1239
3
        rapidjson::Value part(rapidjson::kObjectType);
1240
3
        rapidjson::Value file_data(rapidjson::kObjectType);
1241
3
        file_data.AddMember("mime_type", rapidjson::Value(mime_type, allocator), allocator);
1242
3
        file_data.AddMember("file_uri", rapidjson::Value(media_url.c_str(), allocator), allocator);
1243
3
        part.AddMember("file_data", file_data, allocator);
1244
3
        parts.PushBack(part, allocator);
1245
3
        content.AddMember("parts", parts, allocator);
1246
3
        doc.AddMember("content", content, allocator);
1247
1248
3
        rapidjson::StringBuffer buffer;
1249
3
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1250
3
        doc.Accept(writer);
1251
3
        request_body = buffer.GetString();
1252
3
        return Status::OK();
1253
3
    }
1254
1255
    Status parse_embedding_response(const std::string& response_body,
1256
1
                                    std::vector<std::vector<float>>& results) const override {
1257
1
        rapidjson::Document doc;
1258
1
        doc.Parse(response_body.c_str());
1259
1260
1
        if (doc.HasParseError() || !doc.IsObject()) {
1261
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1262
0
                                         response_body);
1263
0
        }
1264
1
        if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) {
1265
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1266
0
                                         response_body);
1267
0
        }
1268
1269
        /*{
1270
          "embedding":{
1271
            "values": [0.1, 0.2, 0.3]
1272
          }
1273
        }*/
1274
1
        const auto& embedding = doc["embedding"];
1275
1
        if (!embedding.HasMember("values") || !embedding["values"].IsArray()) {
1276
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1277
0
                                         response_body);
1278
0
        }
1279
1
        std::transform(embedding["values"].Begin(), embedding["values"].End(),
1280
1
                       std::back_inserter(results.emplace_back()),
1281
3
                       [](const auto& val) { return val.GetFloat(); });
1282
1283
1
        return Status::OK();
1284
1
    }
1285
1286
protected:
1287
5
    bool supports_dimension_param(const std::string& model_name) const override {
1288
5
        static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001",
1289
5
                                                                            "embedding-001"};
1290
5
        return !no_dimension_models.contains(model_name);
1291
5
    }
1292
1293
4
    std::string get_dimension_param_name() const override { return "outputDimensionality"; }
1294
};
1295
1296
class AnthropicAdapter : public VoyageAIAdapter {
1297
public:
1298
1
    Status set_authentication(HttpClient* client) const override {
1299
1
        client->set_header("x-api-key", _config.api_key);
1300
1
        client->set_header("anthropic-version", _config.anthropic_version);
1301
1
        client->set_content_type("application/json");
1302
1303
1
        return Status::OK();
1304
1
    }
1305
1306
    Status build_request_payload(const std::vector<std::string>& inputs,
1307
                                 const char* const system_prompt,
1308
1
                                 std::string& request_body) const override {
1309
1
        rapidjson::Document doc;
1310
1
        doc.SetObject();
1311
1
        auto& allocator = doc.GetAllocator();
1312
1313
        /*
1314
            "model": "claude-opus-4-1-20250805",
1315
            "max_tokens": 1024,
1316
            "system": "system_prompt here",
1317
            "messages": [
1318
              {"role": "user", "content": "xxx"}
1319
            ],
1320
            "temperature": 0.7
1321
        */
1322
1323
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
1324
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
1325
1
        if (_config.temperature != -1) {
1326
1
            doc.AddMember("temperature", _config.temperature, allocator);
1327
1
        }
1328
1
        if (_config.max_tokens != -1) {
1329
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
1330
1
        } else {
1331
            // Keep the default value, Anthropic requires this parameter
1332
0
            doc.AddMember("max_tokens", 2048, allocator);
1333
0
        }
1334
1
        if (system_prompt && *system_prompt) {
1335
1
            doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
1336
1
        }
1337
1338
1
        rapidjson::Value messages(rapidjson::kArrayType);
1339
1
        for (const auto& input : inputs) {
1340
1
            rapidjson::Value message(rapidjson::kObjectType);
1341
1
            message.AddMember("role", "user", allocator);
1342
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
1343
1
            messages.PushBack(message, allocator);
1344
1
        }
1345
1
        doc.AddMember("messages", messages, allocator);
1346
1347
1
        rapidjson::StringBuffer buffer;
1348
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1349
1
        doc.Accept(writer);
1350
1
        request_body = buffer.GetString();
1351
1352
1
        return Status::OK();
1353
1
    }
1354
1355
    Status parse_response(const std::string& response_body,
1356
3
                          std::vector<std::string>& results) const override {
1357
3
        rapidjson::Document doc;
1358
3
        doc.Parse(response_body.c_str());
1359
3
        if (doc.HasParseError() || !doc.IsObject()) {
1360
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1361
1
                                         response_body);
1362
1
        }
1363
2
        if (!doc.HasMember("content") || !doc["content"].IsArray()) {
1364
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1365
1
                                         response_body);
1366
1
        }
1367
1368
        /*{
1369
            "content": [
1370
              {
1371
                "text": "xxx",
1372
                "type": "text"
1373
              }
1374
            ]
1375
        }*/
1376
1
        const auto& content = doc["content"];
1377
1
        results.reserve(1);
1378
1379
1
        std::string result;
1380
2
        for (rapidjson::SizeType i = 0; i < content.Size(); i++) {
1381
1
            if (!content[i].HasMember("type") || !content[i]["type"].IsString() ||
1382
1
                !content[i].HasMember("text") || !content[i]["text"].IsString()) {
1383
0
                continue;
1384
0
            }
1385
1386
1
            if (std::string(content[i]["type"].GetString()) == "text") {
1387
1
                if (!result.empty()) {
1388
0
                    result += "\n";
1389
0
                }
1390
1
                result += content[i]["text"].GetString();
1391
1
            }
1392
1
        }
1393
1394
1
        results.emplace_back(std::move(result));
1395
1
        return Status::OK();
1396
2
    }
1397
};
1398
1399
// Mock adapter used only for UT to bypass real HTTP calls and return deterministic data.
1400
class MockAdapter : public AIAdapter {
1401
public:
1402
0
    Status set_authentication(HttpClient* client) const override { return Status::OK(); }
1403
1404
    Status build_request_payload(const std::vector<std::string>& inputs,
1405
                                 const char* const system_prompt,
1406
57
                                 std::string& request_body) const override {
1407
57
        return Status::OK();
1408
57
    }
1409
1410
    Status parse_response(const std::string& response_body,
1411
57
                          std::vector<std::string>& results) const override {
1412
57
        results.emplace_back(response_body);
1413
57
        return Status::OK();
1414
57
    }
1415
1416
    Status build_embedding_request(const std::vector<std::string>& inputs,
1417
3
                                   std::string& request_body) const override {
1418
3
        return Status::OK();
1419
3
    }
1420
1421
    Status build_multimodal_embedding_request(MultimodalType /*media_type*/,
1422
                                              const std::string& /*media_url*/,
1423
4
                                              std::string& /*request_body*/) const override {
1424
4
        return Status::OK();
1425
4
    }
1426
1427
    Status parse_embedding_response(const std::string& response_body,
1428
8
                                    std::vector<std::vector<float>>& results) const override {
1429
8
        rapidjson::Document doc;
1430
8
        doc.SetObject();
1431
8
        doc.Parse(response_body.c_str());
1432
8
        if (doc.HasParseError() || !doc.IsObject()) {
1433
0
            return Status::InternalError("Failed to parse embedding response");
1434
0
        }
1435
8
        if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) {
1436
0
            return Status::InternalError("Invalid embedding response format");
1437
0
        }
1438
1439
8
        results.reserve(1);
1440
8
        std::transform(doc["embedding"].Begin(), doc["embedding"].End(),
1441
8
                       std::back_inserter(results.emplace_back()),
1442
40
                       [](const auto& val) { return val.GetFloat(); });
1443
8
        return Status::OK();
1444
8
    }
1445
};
1446
1447
class AIAdapterFactory {
1448
public:
1449
91
    static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) {
1450
91
        static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>>
1451
91
                adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }},
1452
91
                            {"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }},
1453
91
                            {"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }},
1454
91
                            {"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }},
1455
91
                            {"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }},
1456
91
                            {"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }},
1457
91
                            {"QWEN", []() { return std::make_shared<QwenAdapter>(); }},
1458
91
                            {"JINA", []() { return std::make_shared<JinaAdapter>(); }},
1459
91
                            {"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }},
1460
91
                            {"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }},
1461
91
                            {"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }},
1462
91
                            {"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }},
1463
91
                            {"MOCK", []() { return std::make_shared<MockAdapter>(); }}};
1464
1465
91
        auto it = adapters.find(provider_type);
1466
91
        return (it != adapters.end()) ? it->second() : nullptr;
1467
91
    }
1468
};
1469
1470
} // namespace doris