Coverage Report

Created: 2026-04-09 15:45

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/ai_adapter.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <gen_cpp/PaloInternalService_types.h>
21
#include <rapidjson/rapidjson.h>
22
23
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
28
29
#include "common/status.h"
30
#include "core/string_buffer.hpp"
31
#include "rapidjson/document.h"
32
#include "rapidjson/stringbuffer.h"
33
#include "rapidjson/writer.h"
34
#include "service/http/http_client.h"
35
#include "service/http/http_headers.h"
36
37
namespace doris {
38
#include "common/compile_check_begin.h"
39
40
struct AIResource {
41
16
    AIResource() = default;
42
    AIResource(const TAIResource& tai)
43
11
            : endpoint(tai.endpoint),
44
11
              provider_type(tai.provider_type),
45
11
              model_name(tai.model_name),
46
11
              api_key(tai.api_key),
47
11
              temperature(tai.temperature),
48
11
              max_tokens(tai.max_tokens),
49
11
              max_retries(tai.max_retries),
50
11
              retry_delay_second(tai.retry_delay_second),
51
11
              anthropic_version(tai.anthropic_version),
52
11
              dimensions(tai.dimensions) {}
53
54
    std::string endpoint;
55
    std::string provider_type;
56
    std::string model_name;
57
    std::string api_key;
58
    double temperature;
59
    int64_t max_tokens;
60
    int32_t max_retries;
61
    int32_t retry_delay_second;
62
    std::string anthropic_version;
63
    int32_t dimensions;
64
65
1
    void serialize(BufferWritable& buf) const {
66
1
        buf.write_binary(endpoint);
67
1
        buf.write_binary(provider_type);
68
1
        buf.write_binary(model_name);
69
1
        buf.write_binary(api_key);
70
1
        buf.write_binary(temperature);
71
1
        buf.write_binary(max_tokens);
72
1
        buf.write_binary(max_retries);
73
1
        buf.write_binary(retry_delay_second);
74
1
        buf.write_binary(anthropic_version);
75
1
        buf.write_binary(dimensions);
76
1
    }
77
78
1
    void deserialize(BufferReadable& buf) {
79
1
        buf.read_binary(endpoint);
80
1
        buf.read_binary(provider_type);
81
1
        buf.read_binary(model_name);
82
1
        buf.read_binary(api_key);
83
1
        buf.read_binary(temperature);
84
1
        buf.read_binary(max_tokens);
85
1
        buf.read_binary(max_retries);
86
1
        buf.read_binary(retry_delay_second);
87
1
        buf.read_binary(anthropic_version);
88
1
        buf.read_binary(dimensions);
89
1
    }
90
};
91
92
// Kind of media referenced by a multimodal embedding request.
enum class MultimodalType {
    IMAGE,
    VIDEO,
    AUDIO,
};

// Returns the lower-case name of `type` ("image", "video", "audio"), or "unknown" for a value
// outside the enumerators.
inline const char* multimodal_type_to_string(MultimodalType type) {
    const char* name = "unknown";
    switch (type) {
    case MultimodalType::IMAGE:
        name = "image";
        break;
    case MultimodalType::VIDEO:
        name = "video";
        break;
    case MultimodalType::AUDIO:
        name = "audio";
        break;
    }
    return name;
}
105
106
class AIAdapter {
107
public:
108
119
    virtual ~AIAdapter() = default;
109
110
    // Set authentication headers for the HTTP client
111
    virtual Status set_authentication(HttpClient* client) const = 0;
112
113
78
    virtual void init(const TAIResource& config) { _config = config; }
114
12
    virtual void init(const AIResource& config) {
115
12
        _config.endpoint = config.endpoint;
116
12
        _config.provider_type = config.provider_type;
117
12
        _config.model_name = config.model_name;
118
12
        _config.api_key = config.api_key;
119
12
        _config.temperature = config.temperature;
120
12
        _config.max_tokens = config.max_tokens;
121
12
        _config.max_retries = config.max_retries;
122
12
        _config.retry_delay_second = config.retry_delay_second;
123
12
        _config.anthropic_version = config.anthropic_version;
124
12
    }
125
126
    // Build request payload based on input text strings
127
    virtual Status build_request_payload(const std::vector<std::string>& inputs,
128
                                         const char* const system_prompt,
129
1
                                         std::string& request_body) const {
130
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
131
1
    }
132
133
    // Parse response from AI service and extract generated text results
134
    virtual Status parse_response(const std::string& response_body,
135
1
                                  std::vector<std::string>& results) const {
136
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
137
1
    }
138
139
    virtual Status build_embedding_request(const std::vector<std::string>& inputs,
140
0
                                           std::string& request_body) const {
141
0
        return Status::NotSupported("{} does not support the Embed feature.",
142
0
                                    _config.provider_type);
143
0
    }
144
145
    virtual Status build_multimodal_embedding_request(MultimodalType /*media_type*/,
146
                                                      const std::string& /*media_url*/,
147
0
                                                      std::string& /*request_body*/) const {
148
0
        return Status::NotSupported("{} does not support multimodal Embed feature.",
149
0
                                    _config.provider_type);
150
0
    }
151
152
    virtual Status parse_embedding_response(const std::string& response_body,
153
0
                                            std::vector<std::vector<float>>& results) const {
154
0
        return Status::NotSupported("{} does not support the Embed feature.",
155
0
                                    _config.provider_type);
156
0
    }
157
158
protected:
159
    TAIResource _config;
160
161
    // return true if the model support dimension parameter
162
1
    virtual bool supports_dimension_param(const std::string& model_name) const { return false; }
163
164
    // Different providers may have different dimension parameter names.
165
0
    virtual std::string get_dimension_param_name() const { return "dimensions"; }
166
167
    virtual void add_dimension_params(rapidjson::Value& doc,
168
14
                                      rapidjson::Document::AllocatorType& allocator) const {
169
14
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
170
8
            std::string param_name = get_dimension_param_name();
171
8
            rapidjson::Value name(param_name.c_str(), allocator);
172
8
            doc.AddMember(name, _config.dimensions, allocator);
173
8
        }
174
14
    }
175
};
176
177
// Most LLM-providers' Embedding formats are based on VoyageAI.
178
// The following adapters inherit from VoyageAIAdapter to directly reuse its embedding logic.
179
class VoyageAIAdapter : public AIAdapter {
180
public:
181
2
    Status set_authentication(HttpClient* client) const override {
182
2
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
183
2
        client->set_content_type("application/json");
184
185
2
        return Status::OK();
186
2
    }
187
188
    Status build_embedding_request(const std::vector<std::string>& inputs,
189
8
                                   std::string& request_body) const override {
190
8
        rapidjson::Document doc;
191
8
        doc.SetObject();
192
8
        auto& allocator = doc.GetAllocator();
193
194
        /*{
195
            "model": "xxx",
196
            "input": [
197
              "xxx",
198
              "xxx",
199
              ...
200
            ],
201
            "output_dimensions": 512
202
        }*/
203
8
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
204
8
        add_dimension_params(doc, allocator);
205
206
8
        rapidjson::Value input(rapidjson::kArrayType);
207
8
        for (const auto& msg : inputs) {
208
8
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
209
8
        }
210
8
        doc.AddMember("input", input, allocator);
211
212
8
        rapidjson::StringBuffer buffer;
213
8
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
214
8
        doc.Accept(writer);
215
8
        request_body = buffer.GetString();
216
217
8
        return Status::OK();
218
8
    }
219
220
    Status build_multimodal_embedding_request(MultimodalType media_type,
221
                                              const std::string& media_url,
222
1
                                              std::string& request_body) const override {
223
1
        if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO)
224
0
                [[unlikely]] {
225
0
            return Status::InvalidArgument(
226
0
                    "VoyageAI only supports image/video multimodal embed, got {}",
227
0
                    multimodal_type_to_string(media_type));
228
0
        }
229
1
        if (_config.dimensions != -1) {
230
1
            LOG(WARNING) << "VoyageAI multimodal embedding currently ignores dimensions parameter, "
231
1
                         << "model=" << _config.model_name << ", dimensions=" << _config.dimensions;
232
1
        }
233
234
1
        rapidjson::Document doc;
235
1
        doc.SetObject();
236
1
        auto& allocator = doc.GetAllocator();
237
238
        /*{
239
            "inputs": [
240
              {
241
                "content": [
242
                  {"type": "image_url", "image_url": "<url>"}
243
                ]
244
              }
245
            ],
246
            "model": "voyage-multimodal-3.5"
247
        }*/
248
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
249
250
1
        rapidjson::Value inputs(rapidjson::kArrayType);
251
1
        rapidjson::Value input(rapidjson::kObjectType);
252
1
        rapidjson::Value content(rapidjson::kArrayType);
253
254
1
        rapidjson::Value media_item(rapidjson::kObjectType);
255
1
        if (media_type == MultimodalType::IMAGE) {
256
0
            media_item.AddMember("type", "image_url", allocator);
257
0
            media_item.AddMember("image_url", rapidjson::Value(media_url.c_str(), allocator),
258
0
                                 allocator);
259
1
        } else {
260
1
            media_item.AddMember("type", "video_url", allocator);
261
1
            media_item.AddMember("video_url", rapidjson::Value(media_url.c_str(), allocator),
262
1
                                 allocator);
263
1
        }
264
1
        content.PushBack(media_item, allocator);
265
266
1
        input.AddMember("content", content, allocator);
267
1
        inputs.PushBack(input, allocator);
268
1
        doc.AddMember("inputs", inputs, allocator);
269
270
1
        rapidjson::StringBuffer buffer;
271
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
272
1
        doc.Accept(writer);
273
1
        request_body = buffer.GetString();
274
1
        return Status::OK();
275
1
    }
276
277
    Status parse_embedding_response(const std::string& response_body,
278
5
                                    std::vector<std::vector<float>>& results) const override {
279
5
        rapidjson::Document doc;
280
5
        doc.Parse(response_body.c_str());
281
282
5
        if (doc.HasParseError() || !doc.IsObject()) {
283
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
284
1
                                         response_body);
285
1
        }
286
4
        if (!doc.HasMember("data") || !doc["data"].IsArray()) {
287
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
288
1
                                         response_body);
289
1
        }
290
291
        /*{
292
            "data":[
293
              {
294
                "object": "embedding",
295
                "embedding": [...], <- only need this
296
                "index": 0
297
              },
298
              {
299
                "object": "embedding",
300
                "embedding": [...],
301
                "index": 1
302
              }, ...
303
            ],
304
            "model"....
305
        }*/
306
3
        const auto& data = doc["data"];
307
3
        results.reserve(data.Size());
308
7
        for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
309
5
            if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
310
1
                return Status::InternalError("Invalid {} response format: {}",
311
1
                                             _config.provider_type, response_body);
312
1
            }
313
314
4
            std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
315
4
                           std::back_inserter(results.emplace_back()),
316
10
                           [](const auto& val) { return val.GetFloat(); });
317
4
        }
318
319
2
        return Status::OK();
320
3
    }
321
322
protected:
323
3
    bool supports_dimension_param(const std::string& model_name) const override {
324
3
        static const std::unordered_set<std::string> no_dimension_models = {
325
3
                "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2",
326
3
                "voyage-multimodal-3"};
327
3
        return !no_dimension_models.contains(model_name);
328
3
    }
329
330
1
    std::string get_dimension_param_name() const override { return "output_dimension"; }
331
};
332
333
// Local AI adapter for locally hosted models (Ollama, LLaMA, etc.)
334
class LocalAdapter : public AIAdapter {
335
public:
336
    // Local deployments typically don't need authentication
337
2
    Status set_authentication(HttpClient* client) const override {
338
2
        client->set_content_type("application/json");
339
2
        return Status::OK();
340
2
    }
341
342
    Status build_request_payload(const std::vector<std::string>& inputs,
343
                                 const char* const system_prompt,
344
3
                                 std::string& request_body) const override {
345
3
        rapidjson::Document doc;
346
3
        doc.SetObject();
347
3
        auto& allocator = doc.GetAllocator();
348
349
3
        std::string end_point = _config.endpoint;
350
3
        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
351
2
            RETURN_IF_ERROR(
352
2
                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
353
2
        } else {
354
1
            RETURN_IF_ERROR(
355
1
                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
356
1
        }
357
358
3
        rapidjson::StringBuffer buffer;
359
3
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
360
3
        doc.Accept(writer);
361
3
        request_body = buffer.GetString();
362
363
3
        return Status::OK();
364
3
    }
365
366
    Status parse_response(const std::string& response_body,
367
7
                          std::vector<std::string>& results) const override {
368
7
        rapidjson::Document doc;
369
7
        doc.Parse(response_body.c_str());
370
371
7
        if (doc.HasParseError() || !doc.IsObject()) {
372
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
373
1
                                         response_body);
374
1
        }
375
376
        // Handle various response formats from local LLMs
377
        // Format 1: OpenAI-compatible format with choices/message/content
378
6
        if (doc.HasMember("choices") && doc["choices"].IsArray()) {
379
1
            const auto& choices = doc["choices"];
380
1
            results.reserve(choices.Size());
381
382
2
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
383
1
                if (choices[i].HasMember("message") && choices[i]["message"].HasMember("content") &&
384
1
                    choices[i]["message"]["content"].IsString()) {
385
1
                    results.emplace_back(choices[i]["message"]["content"].GetString());
386
1
                } else if (choices[i].HasMember("text") && choices[i]["text"].IsString()) {
387
                    // Some local LLMs use a simpler format
388
0
                    results.emplace_back(choices[i]["text"].GetString());
389
0
                }
390
1
            }
391
5
        } else if (doc.HasMember("text") && doc["text"].IsString()) {
392
            // Format 2: Simple response with just "text" or "content" field
393
1
            results.emplace_back(doc["text"].GetString());
394
4
        } else if (doc.HasMember("content") && doc["content"].IsString()) {
395
1
            results.emplace_back(doc["content"].GetString());
396
3
        } else if (doc.HasMember("response") && doc["response"].IsString()) {
397
            // Format 3: Response field (Ollama `generate` format)
398
1
            results.emplace_back(doc["response"].GetString());
399
2
        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
400
2
                   doc["message"].HasMember("content") && doc["message"]["content"].IsString()) {
401
            // Format 4: message/content field (Ollama `chat` format)
402
1
            results.emplace_back(doc["message"]["content"].GetString());
403
1
        } else {
404
1
            return Status::NotSupported("Unsupported response format from local AI.");
405
1
        }
406
407
5
        return Status::OK();
408
6
    }
409
410
    Status build_embedding_request(const std::vector<std::string>& inputs,
411
1
                                   std::string& request_body) const override {
412
1
        rapidjson::Document doc;
413
1
        doc.SetObject();
414
1
        auto& allocator = doc.GetAllocator();
415
416
1
        if (!_config.model_name.empty()) {
417
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
418
1
                          allocator);
419
1
        }
420
421
1
        add_dimension_params(doc, allocator);
422
423
1
        rapidjson::Value input(rapidjson::kArrayType);
424
1
        for (const auto& msg : inputs) {
425
1
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
426
1
        }
427
1
        doc.AddMember("input", input, allocator);
428
429
1
        rapidjson::StringBuffer buffer;
430
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
431
1
        doc.Accept(writer);
432
1
        request_body = buffer.GetString();
433
434
1
        return Status::OK();
435
1
    }
436
437
    Status build_multimodal_embedding_request(MultimodalType /*media_type*/,
438
                                              const std::string& /*media_url*/,
439
0
                                              std::string& /*request_body*/) const override {
440
0
        return Status::NotSupported("{} does not support multimodal Embed feature.",
441
0
                                    _config.provider_type);
442
0
    }
443
444
    Status parse_embedding_response(const std::string& response_body,
445
3
                                    std::vector<std::vector<float>>& results) const override {
446
3
        rapidjson::Document doc;
447
3
        doc.Parse(response_body.c_str());
448
449
3
        if (doc.HasParseError() || !doc.IsObject()) {
450
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
451
0
                                         response_body);
452
0
        }
453
454
        // parse different response format
455
3
        rapidjson::Value embedding;
456
3
        if (doc.HasMember("data") && doc["data"].IsArray()) {
457
            // "data":["object":"embedding", "embedding":[0.1, 0.2...], "index":0]
458
1
            const auto& data = doc["data"];
459
1
            results.reserve(data.Size());
460
3
            for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
461
2
                if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
462
0
                    return Status::InternalError("Invalid {} response format",
463
0
                                                 _config.provider_type);
464
0
                }
465
466
2
                std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
467
2
                               std::back_inserter(results.emplace_back()),
468
5
                               [](const auto& val) { return val.GetFloat(); });
469
2
            }
470
2
        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
471
            // "embeddings":[[0.1, 0.2, ...]]
472
1
            results.reserve(1);
473
2
            for (int i = 0; i < doc["embeddings"].Size(); i++) {
474
1
                embedding = doc["embeddings"][i];
475
1
                std::transform(embedding.Begin(), embedding.End(),
476
1
                               std::back_inserter(results.emplace_back()),
477
2
                               [](const auto& val) { return val.GetFloat(); });
478
1
            }
479
1
        } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
480
            // "embedding":[0.1, 0.2, ...]
481
1
            results.reserve(1);
482
1
            embedding = doc["embedding"];
483
1
            std::transform(embedding.Begin(), embedding.End(),
484
1
                           std::back_inserter(results.emplace_back()),
485
3
                           [](const auto& val) { return val.GetFloat(); });
486
1
        } else {
487
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
488
0
                                         response_body);
489
0
        }
490
491
3
        return Status::OK();
492
3
    }
493
494
private:
495
    Status build_ollama_request(rapidjson::Document& doc,
496
                                rapidjson::Document::AllocatorType& allocator,
497
                                const std::vector<std::string>& inputs,
498
2
                                const char* const system_prompt, std::string& request_body) const {
499
        /*
500
        for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
501
        {
502
            "model": <model_name>,
503
            "stream": false,
504
            "think": false,
505
            "options": {
506
                "temperature": <temperature>,
507
                "max_token": <max_token>
508
            },
509
            "messages": [
510
                {"role": "system", "content": <system_prompt>},
511
                {"role": "user", "content": <user_prompt>}
512
            ]
513
        }
514
        
515
        for endpoints end_with `/generate` like 'http://localhost:11434/api/generate':
516
        {
517
            "model": <model_name>,
518
            "stream": false,
519
            "think": false
520
            "options": {
521
                "temperature": <temperature>,
522
                "max_token": <max_token>
523
            },
524
            "system": <system_prompt>,
525
            "prompt": <user_prompt>
526
        }
527
        */
528
529
        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint;
530
        // The rest remains identical.
531
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
532
2
        doc.AddMember("stream", false, allocator);
533
2
        doc.AddMember("think", false, allocator);
534
535
        // option section
536
2
        rapidjson::Value options(rapidjson::kObjectType);
537
2
        if (_config.temperature != -1) {
538
2
            options.AddMember("temperature", _config.temperature, allocator);
539
2
        }
540
2
        if (_config.max_tokens != -1) {
541
2
            options.AddMember("max_token", _config.max_tokens, allocator);
542
2
        }
543
2
        doc.AddMember("options", options, allocator);
544
545
        // prompt section
546
2
        if (_config.endpoint.ends_with("chat")) {
547
1
            rapidjson::Value messages(rapidjson::kArrayType);
548
1
            if (system_prompt && *system_prompt) {
549
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
550
1
                sys_msg.AddMember("role", "system", allocator);
551
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
552
1
                messages.PushBack(sys_msg, allocator);
553
1
            }
554
1
            for (const auto& input : inputs) {
555
1
                rapidjson::Value message(rapidjson::kObjectType);
556
1
                message.AddMember("role", "user", allocator);
557
1
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
558
1
                messages.PushBack(message, allocator);
559
1
            }
560
1
            doc.AddMember("messages", messages, allocator);
561
1
        } else {
562
1
            if (system_prompt && *system_prompt) {
563
1
                doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
564
1
            }
565
1
            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator);
566
1
        }
567
568
2
        return Status::OK();
569
2
    }
570
571
    Status build_default_request(rapidjson::Document& doc,
572
                                 rapidjson::Document::AllocatorType& allocator,
573
                                 const std::vector<std::string>& inputs,
574
1
                                 const char* const system_prompt, std::string& request_body) const {
575
        /*
576
        Default format(OpenAI-compatible):
577
        {
578
            "model": <model_name>,
579
            "temperature": <temperature>,
580
            "max_tokens": <max_tokens>,
581
            "messages": [
582
                {"role": "system", "content": <system_prompt>},
583
                {"role": "user", "content": <user_prompt>}
584
            ]
585
        }
586
        */
587
588
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
589
590
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
591
1
        if (_config.temperature != -1) {
592
1
            doc.AddMember("temperature", _config.temperature, allocator);
593
1
        }
594
1
        if (_config.max_tokens != -1) {
595
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
596
1
        }
597
598
1
        rapidjson::Value messages(rapidjson::kArrayType);
599
1
        if (system_prompt && *system_prompt) {
600
1
            rapidjson::Value sys_msg(rapidjson::kObjectType);
601
1
            sys_msg.AddMember("role", "system", allocator);
602
1
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
603
1
            messages.PushBack(sys_msg, allocator);
604
1
        }
605
1
        for (const auto& input : inputs) {
606
1
            rapidjson::Value message(rapidjson::kObjectType);
607
1
            message.AddMember("role", "user", allocator);
608
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
609
1
            messages.PushBack(message, allocator);
610
1
        }
611
1
        doc.AddMember("messages", messages, allocator);
612
1
        return Status::OK();
613
1
    }
614
};
615
616
// The OpenAI API format can be reused with some compatible AIs.
617
class OpenAIAdapter : public VoyageAIAdapter {
618
public:
619
8
    Status set_authentication(HttpClient* client) const override {
620
8
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
621
8
        client->set_content_type("application/json");
622
623
8
        return Status::OK();
624
8
    }
625
626
    Status build_request_payload(const std::vector<std::string>& inputs,
627
                                 const char* const system_prompt,
628
2
                                 std::string& request_body) const override {
629
2
        rapidjson::Document doc;
630
2
        doc.SetObject();
631
2
        auto& allocator = doc.GetAllocator();
632
633
2
        if (_config.endpoint.ends_with("responses")) {
634
            /*{
635
              "model": "gpt-4.1-mini",
636
              "input": [
637
                {"role": "system", "content": "system_prompt here"},
638
                {"role": "user", "content": "xxx"}
639
              ],
640
              "temperature": 0.7,
641
              "max_output_tokens": 150
642
            }*/
643
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
644
1
                          allocator);
645
646
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
647
1
            if (_config.temperature != -1) {
648
1
                doc.AddMember("temperature", _config.temperature, allocator);
649
1
            }
650
1
            if (_config.max_tokens != -1) {
651
1
                doc.AddMember("max_output_tokens", _config.max_tokens, allocator);
652
1
            }
653
654
            // input
655
1
            rapidjson::Value input(rapidjson::kArrayType);
656
1
            if (system_prompt && *system_prompt) {
657
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
658
1
                sys_msg.AddMember("role", "system", allocator);
659
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
660
1
                input.PushBack(sys_msg, allocator);
661
1
            }
662
1
            for (const auto& msg : inputs) {
663
1
                rapidjson::Value message(rapidjson::kObjectType);
664
1
                message.AddMember("role", "user", allocator);
665
1
                message.AddMember("content", rapidjson::Value(msg.c_str(), allocator), allocator);
666
1
                input.PushBack(message, allocator);
667
1
            }
668
1
            doc.AddMember("input", input, allocator);
669
1
        } else {
670
            /*{
671
              "model": "gpt-4",
672
              "messages": [
673
                {"role": "system", "content": "system_prompt here"},
674
                {"role": "user", "content": "xxx"}
675
              ],
676
              "temperature": x,
677
              "max_tokens": x,
678
            }*/
679
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
680
1
                          allocator);
681
682
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
683
1
            if (_config.temperature != -1) {
684
1
                doc.AddMember("temperature", _config.temperature, allocator);
685
1
            }
686
1
            if (_config.max_tokens != -1) {
687
1
                doc.AddMember("max_tokens", _config.max_tokens, allocator);
688
1
            }
689
690
1
            rapidjson::Value messages(rapidjson::kArrayType);
691
1
            if (system_prompt && *system_prompt) {
692
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
693
1
                sys_msg.AddMember("role", "system", allocator);
694
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
695
1
                messages.PushBack(sys_msg, allocator);
696
1
            }
697
1
            for (const auto& input : inputs) {
698
1
                rapidjson::Value message(rapidjson::kObjectType);
699
1
                message.AddMember("role", "user", allocator);
700
1
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
701
1
                messages.PushBack(message, allocator);
702
1
            }
703
1
            doc.AddMember("messages", messages, allocator);
704
1
        }
705
706
2
        rapidjson::StringBuffer buffer;
707
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
708
2
        doc.Accept(writer);
709
2
        request_body = buffer.GetString();
710
711
2
        return Status::OK();
712
2
    }
713
714
    Status parse_response(const std::string& response_body,
715
6
                          std::vector<std::string>& results) const override {
716
6
        rapidjson::Document doc;
717
6
        doc.Parse(response_body.c_str());
718
719
6
        if (doc.HasParseError() || !doc.IsObject()) {
720
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
721
1
                                         response_body);
722
1
        }
723
724
5
        if (doc.HasMember("output") && doc["output"].IsArray()) {
725
            /// for responses endpoint
726
            /*{
727
              "output": [
728
                {
729
                  "id": "msg_123",
730
                  "type": "message",
731
                  "role": "assistant",
732
                  "content": [
733
                    {
734
                      "type": "text",
735
                      "text": "result text here"   <- result
736
                    }
737
                  ]
738
                }
739
              ]
740
            }*/
741
1
            const auto& output = doc["output"];
742
1
            results.reserve(output.Size());
743
744
2
            for (rapidjson::SizeType i = 0; i < output.Size(); i++) {
745
1
                if (!output[i].HasMember("content") || !output[i]["content"].IsArray() ||
746
1
                    output[i]["content"].Empty() || !output[i]["content"][0].HasMember("text") ||
747
1
                    !output[i]["content"][0]["text"].IsString()) {
748
0
                    return Status::InternalError("Invalid output format in {} response: {}",
749
0
                                                 _config.provider_type, response_body);
750
0
                }
751
752
1
                results.emplace_back(output[i]["content"][0]["text"].GetString());
753
1
            }
754
4
        } else if (doc.HasMember("choices") && doc["choices"].IsArray()) {
755
            /// for completions endpoint
756
            /*{
757
              "object": "chat.completion",
758
              "model": "gpt-4",
759
              "choices": [
760
                {
761
                  ...
762
                  "message": {
763
                    "role": "assistant",
764
                    "content": "xxx"      <- result
765
                  },
766
                  ...
767
                }
768
              ],
769
              ...
770
            }*/
771
3
            const auto& choices = doc["choices"];
772
3
            results.reserve(choices.Size());
773
774
4
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
775
3
                if (!choices[i].HasMember("message") ||
776
3
                    !choices[i]["message"].HasMember("content") ||
777
3
                    !choices[i]["message"]["content"].IsString()) {
778
2
                    return Status::InternalError("Invalid choice format in {} response: {}",
779
2
                                                 _config.provider_type, response_body);
780
2
                }
781
782
1
                results.emplace_back(choices[i]["message"]["content"].GetString());
783
1
            }
784
3
        } else {
785
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
786
1
                                         response_body);
787
1
        }
788
789
2
        return Status::OK();
790
5
    }
791
792
    Status build_multimodal_embedding_request(MultimodalType /*media_type*/,
793
                                              const std::string& /*media_url*/,
794
1
                                              std::string& /*request_body*/) const override {
795
1
        return Status::NotSupported("{} does not support multimodal Embed feature.",
796
1
                                    _config.provider_type);
797
1
    }
798
799
protected:
800
2
    bool supports_dimension_param(const std::string& model_name) const override {
801
2
        return !(model_name == "text-embedding-ada-002");
802
2
    }
803
804
2
    std::string get_dimension_param_name() const override { return "dimensions"; }
805
};
806
807
class DeepSeekAdapter : public OpenAIAdapter {
808
public:
809
    Status build_embedding_request(const std::vector<std::string>& inputs,
810
1
                                   std::string& request_body) const override {
811
1
        return Status::NotSupported("{} does not support the Embed feature.",
812
1
                                    _config.provider_type);
813
1
    }
814
815
    Status parse_embedding_response(const std::string& response_body,
816
1
                                    std::vector<std::vector<float>>& results) const override {
817
1
        return Status::NotSupported("{} does not support the Embed feature.",
818
1
                                    _config.provider_type);
819
1
    }
820
};
821
822
class MoonShotAdapter : public OpenAIAdapter {
823
public:
824
    Status build_embedding_request(const std::vector<std::string>& inputs,
825
1
                                   std::string& request_body) const override {
826
1
        return Status::NotSupported("{} does not support the Embed feature.",
827
1
                                    _config.provider_type);
828
1
    }
829
830
    Status parse_embedding_response(const std::string& response_body,
831
1
                                    std::vector<std::vector<float>>& results) const override {
832
1
        return Status::NotSupported("{} does not support the Embed feature.",
833
1
                                    _config.provider_type);
834
1
    }
835
};
836
837
class MinimaxAdapter : public OpenAIAdapter {
838
public:
839
    Status build_embedding_request(const std::vector<std::string>& inputs,
840
1
                                   std::string& request_body) const override {
841
1
        rapidjson::Document doc;
842
1
        doc.SetObject();
843
1
        auto& allocator = doc.GetAllocator();
844
845
        /*{
846
          "text": ["xxx", "xxx", ...],
847
          "model": "embo-1",
848
          "type": "db"
849
        }*/
850
1
        rapidjson::Value texts(rapidjson::kArrayType);
851
1
        for (const auto& input : inputs) {
852
1
            texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator);
853
1
        }
854
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
855
1
        doc.AddMember("texts", texts, allocator);
856
1
        doc.AddMember("type", rapidjson::Value("db", allocator), allocator);
857
858
1
        rapidjson::StringBuffer buffer;
859
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
860
1
        doc.Accept(writer);
861
1
        request_body = buffer.GetString();
862
863
1
        return Status::OK();
864
1
    }
865
};
866
867
class ZhipuAdapter : public OpenAIAdapter {
868
protected:
869
2
    bool supports_dimension_param(const std::string& model_name) const override {
870
2
        return !(model_name == "embedding-2");
871
2
    }
872
};
873
874
class QwenAdapter : public OpenAIAdapter {
875
public:
876
    Status build_multimodal_embedding_request(MultimodalType media_type,
877
                                              const std::string& media_url,
878
3
                                              std::string& request_body) const override {
879
3
        if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO) {
880
1
            return Status::InvalidArgument(
881
1
                    "QWEN only supports image/video multimodal embed, got {}",
882
1
                    multimodal_type_to_string(media_type));
883
1
        }
884
885
2
        rapidjson::Document doc;
886
2
        doc.SetObject();
887
2
        auto& allocator = doc.GetAllocator();
888
889
        /*{
890
            "model": "tongyi-embedding-vision-plus",
891
            "input": {
892
              "contents": [
893
                {"image": "<url>"}
894
              ]
895
            }
896
            "parameters": {
897
              "dimension": 512
898
            }
899
        }*/
900
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
901
2
        rapidjson::Value input(rapidjson::kObjectType);
902
2
        rapidjson::Value contents(rapidjson::kArrayType);
903
904
2
        rapidjson::Value media_item(rapidjson::kObjectType);
905
2
        if (media_type == MultimodalType::IMAGE) {
906
1
            media_item.AddMember("image", rapidjson::Value(media_url.c_str(), allocator),
907
1
                                 allocator);
908
1
        } else {
909
1
            media_item.AddMember("video", rapidjson::Value(media_url.c_str(), allocator),
910
1
                                 allocator);
911
1
        }
912
2
        contents.PushBack(media_item, allocator);
913
914
2
        input.AddMember("contents", contents, allocator);
915
2
        doc.AddMember("input", input, allocator);
916
2
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
917
2
            rapidjson::Value parameters(rapidjson::kObjectType);
918
2
            std::string param_name = get_dimension_param_name();
919
2
            rapidjson::Value dimension_name(param_name.c_str(), allocator);
920
2
            parameters.AddMember(dimension_name, _config.dimensions, allocator);
921
2
            doc.AddMember("parameters", parameters, allocator);
922
2
        }
923
924
2
        rapidjson::StringBuffer buffer;
925
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
926
2
        doc.Accept(writer);
927
2
        request_body = buffer.GetString();
928
2
        return Status::OK();
929
3
    }
930
931
    Status parse_embedding_response(const std::string& response_body,
932
0
                                    std::vector<std::vector<float>>& results) const override {
933
0
        rapidjson::Document doc;
934
0
        doc.Parse(response_body.c_str());
935
936
0
        if (doc.HasParseError() || !doc.IsObject()) [[unlikely]] {
937
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
938
0
                                         response_body);
939
0
        }
940
        // Qwen multimodal embedding usually returns:
941
        // {
942
        //   "output": {
943
        //     "embeddings": [
944
        //       {"index":0, "embedding":[...], "type":"image|video|text"},
945
        //       ...
946
        //     ]
947
        //   }
948
        // }
949
        //
950
        // In text-only or compatibility endpoints, Qwen may also return OpenAI-style
951
        // "data":[{"embedding":[...]}]. For compatibility we first parse native
952
        // output.embeddings and then fallback to OpenAIAdapter parser.
953
0
        if (doc.HasMember("output") && doc["output"].IsObject() &&
954
0
            doc["output"].HasMember("embeddings") && doc["output"]["embeddings"].IsArray()) {
955
0
            const auto& embeddings = doc["output"]["embeddings"];
956
0
            results.reserve(embeddings.Size());
957
0
            for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) {
958
0
                if (!embeddings[i].HasMember("embedding") ||
959
0
                    !embeddings[i]["embedding"].IsArray()) {
960
0
                    return Status::InternalError("Invalid {} response format: {}",
961
0
                                                 _config.provider_type, response_body);
962
0
                }
963
0
                std::transform(embeddings[i]["embedding"].Begin(), embeddings[i]["embedding"].End(),
964
0
                               std::back_inserter(results.emplace_back()),
965
0
                               [](const auto& val) { return val.GetFloat(); });
966
0
            }
967
0
            return Status::OK();
968
0
        }
969
0
        return OpenAIAdapter::parse_embedding_response(response_body, results);
970
0
    }
971
972
protected:
973
4
    bool supports_dimension_param(const std::string& model_name) const override {
974
4
        static const std::unordered_set<std::string> no_dimension_models = {
975
4
                "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"};
976
4
        return !no_dimension_models.contains(model_name);
977
4
    }
978
979
3
    std::string get_dimension_param_name() const override { return "dimension"; }
980
};
981
982
class JinaAdapter : public VoyageAIAdapter {
983
public:
984
    Status build_multimodal_embedding_request(MultimodalType media_type,
985
                                              const std::string& media_url,
986
1
                                              std::string& request_body) const override {
987
1
        if (media_type != MultimodalType::IMAGE && media_type != MultimodalType::VIDEO)
988
0
                [[unlikely]] {
989
0
            return Status::InvalidArgument(
990
0
                    "JINA only supports image/video multimodal embed, got {}",
991
0
                    multimodal_type_to_string(media_type));
992
0
        }
993
994
1
        rapidjson::Document doc;
995
1
        doc.SetObject();
996
1
        auto& allocator = doc.GetAllocator();
997
998
        /*{
999
            "model": "jina-embeddings-v4",
1000
            "task": "text-matching",
1001
            "input": [
1002
              {"image": "<url>"}
1003
            ]
1004
        }*/
1005
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
1006
1
        doc.AddMember("task", "text-matching", allocator);
1007
1008
1
        rapidjson::Value input(rapidjson::kArrayType);
1009
1
        rapidjson::Value media_item(rapidjson::kObjectType);
1010
1
        if (media_type == MultimodalType::IMAGE) {
1011
1
            media_item.AddMember("image", rapidjson::Value(media_url.c_str(), allocator),
1012
1
                                 allocator);
1013
1
        } else {
1014
0
            media_item.AddMember("video", rapidjson::Value(media_url.c_str(), allocator),
1015
0
                                 allocator);
1016
0
        }
1017
1
        input.PushBack(media_item, allocator);
1018
1
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
1019
1
            doc.AddMember("dimensions", _config.dimensions, allocator);
1020
1
        }
1021
1
        doc.AddMember("input", input, allocator);
1022
1023
1
        rapidjson::StringBuffer buffer;
1024
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1025
1
        doc.Accept(writer);
1026
1
        request_body = buffer.GetString();
1027
1
        return Status::OK();
1028
1
    }
1029
};
1030
1031
class BaichuanAdapter : public OpenAIAdapter {
1032
protected:
1033
0
    bool supports_dimension_param(const std::string& model_name) const override { return false; }
1034
};
1035
1036
// Gemini's embedding format is different from VoyageAI, so it requires a separate adapter
1037
class GeminiAdapter : public AIAdapter {
1038
public:
1039
2
    Status set_authentication(HttpClient* client) const override {
1040
2
        client->set_header("x-goog-api-key", _config.api_key);
1041
2
        client->set_content_type("application/json");
1042
2
        return Status::OK();
1043
2
    }
1044
1045
    Status build_request_payload(const std::vector<std::string>& inputs,
1046
                                 const char* const system_prompt,
1047
1
                                 std::string& request_body) const override {
1048
1
        rapidjson::Document doc;
1049
1
        doc.SetObject();
1050
1
        auto& allocator = doc.GetAllocator();
1051
1052
        /*{
1053
          "systemInstruction": {
1054
              "parts": [
1055
                {
1056
                  "text": "system_prompt here"
1057
                }
1058
              ]
1059
            }
1060
          ],
1061
          "contents": [
1062
            {
1063
              "parts": [
1064
                {
1065
                  "text": "xxx"
1066
                }
1067
              ]
1068
            }
1069
          ],
1070
          "generationConfig": {
1071
          "temperature": 0.7,
1072
          "maxOutputTokens": 1024
1073
          }
1074
1075
        }*/
1076
1
        if (system_prompt && *system_prompt) {
1077
1
            rapidjson::Value system_instruction(rapidjson::kObjectType);
1078
1
            rapidjson::Value parts(rapidjson::kArrayType);
1079
1080
1
            rapidjson::Value part(rapidjson::kObjectType);
1081
1
            part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator);
1082
1
            parts.PushBack(part, allocator);
1083
            // system_instruction.PushBack(content, allocator);
1084
1
            system_instruction.AddMember("parts", parts, allocator);
1085
1
            doc.AddMember("systemInstruction", system_instruction, allocator);
1086
1
        }
1087
1088
1
        rapidjson::Value contents(rapidjson::kArrayType);
1089
1
        for (const auto& input : inputs) {
1090
1
            rapidjson::Value content(rapidjson::kObjectType);
1091
1
            rapidjson::Value parts(rapidjson::kArrayType);
1092
1093
1
            rapidjson::Value part(rapidjson::kObjectType);
1094
1
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
1095
1096
1
            parts.PushBack(part, allocator);
1097
1
            content.AddMember("parts", parts, allocator);
1098
1
            contents.PushBack(content, allocator);
1099
1
        }
1100
1
        doc.AddMember("contents", contents, allocator);
1101
1102
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
1103
1
        rapidjson::Value generationConfig(rapidjson::kObjectType);
1104
1
        if (_config.temperature != -1) {
1105
1
            generationConfig.AddMember("temperature", _config.temperature, allocator);
1106
1
        }
1107
1
        if (_config.max_tokens != -1) {
1108
1
            generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator);
1109
1
        }
1110
1
        doc.AddMember("generationConfig", generationConfig, allocator);
1111
1112
1
        rapidjson::StringBuffer buffer;
1113
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1114
1
        doc.Accept(writer);
1115
1
        request_body = buffer.GetString();
1116
1117
1
        return Status::OK();
1118
1
    }
1119
1120
    Status parse_response(const std::string& response_body,
1121
3
                          std::vector<std::string>& results) const override {
1122
3
        rapidjson::Document doc;
1123
3
        doc.Parse(response_body.c_str());
1124
1125
3
        if (doc.HasParseError() || !doc.IsObject()) {
1126
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1127
1
                                         response_body);
1128
1
        }
1129
2
        if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) {
1130
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1131
1
                                         response_body);
1132
1
        }
1133
1134
        /*{
1135
          "candidates":[
1136
            {
1137
              "content": {
1138
                "parts": [
1139
                  {
1140
                    "text": "xxx"
1141
                  }
1142
                ]
1143
              }
1144
            }
1145
          ]
1146
        }*/
1147
1
        const auto& candidates = doc["candidates"];
1148
1
        results.reserve(candidates.Size());
1149
1150
2
        for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) {
1151
1
            if (!candidates[i].HasMember("content") ||
1152
1
                !candidates[i]["content"].HasMember("parts") ||
1153
1
                !candidates[i]["content"]["parts"].IsArray() ||
1154
1
                candidates[i]["content"]["parts"].Empty() ||
1155
1
                !candidates[i]["content"]["parts"][0].HasMember("text") ||
1156
1
                !candidates[i]["content"]["parts"][0]["text"].IsString()) {
1157
0
                return Status::InternalError("Invalid candidate format in {} response",
1158
0
                                             _config.provider_type);
1159
0
            }
1160
1161
1
            results.emplace_back(candidates[i]["content"]["parts"][0]["text"].GetString());
1162
1
        }
1163
1164
1
        return Status::OK();
1165
1
    }
1166
1167
    Status build_embedding_request(const std::vector<std::string>& inputs,
1168
2
                                   std::string& request_body) const override {
1169
2
        rapidjson::Document doc;
1170
2
        doc.SetObject();
1171
2
        auto& allocator = doc.GetAllocator();
1172
1173
        /*{
1174
          "model": "models/gemini-embedding-001",
1175
          "content": {
1176
              "parts": [
1177
                {
1178
                  "text": "xxx"
1179
                }
1180
              ]
1181
            }
1182
          "outputDimensionality": 1024
1183
        }*/
1184
1185
        // gemini requires the model format as `models/{model}`
1186
2
        std::string model_name = _config.model_name;
1187
2
        if (!model_name.starts_with("models/")) {
1188
2
            model_name = "models/" + model_name;
1189
2
        }
1190
2
        doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
1191
2
        add_dimension_params(doc, allocator);
1192
1193
2
        rapidjson::Value content(rapidjson::kObjectType);
1194
2
        for (const auto& input : inputs) {
1195
2
            rapidjson::Value parts(rapidjson::kArrayType);
1196
2
            rapidjson::Value part(rapidjson::kObjectType);
1197
2
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
1198
2
            parts.PushBack(part, allocator);
1199
2
            content.AddMember("parts", parts, allocator);
1200
2
        }
1201
2
        doc.AddMember("content", content, allocator);
1202
1203
2
        rapidjson::StringBuffer buffer;
1204
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1205
2
        doc.Accept(writer);
1206
2
        request_body = buffer.GetString();
1207
1208
2
        return Status::OK();
1209
2
    }
1210
1211
    Status build_multimodal_embedding_request(MultimodalType media_type,
1212
                                              const std::string& media_url,
1213
3
                                              std::string& request_body) const override {
1214
3
        const char* mime_type = nullptr;
1215
3
        switch (media_type) {
1216
1
        case MultimodalType::IMAGE:
1217
1
            mime_type = "image/png";
1218
1
            break;
1219
1
        case MultimodalType::AUDIO:
1220
1
            mime_type = "audio/mpeg";
1221
1
            break;
1222
1
        case MultimodalType::VIDEO:
1223
1
            mime_type = "video/mp4";
1224
1
            break;
1225
3
        }
1226
1227
3
        rapidjson::Document doc;
1228
3
        doc.SetObject();
1229
3
        auto& allocator = doc.GetAllocator();
1230
1231
3
        std::string model_name = _config.model_name;
1232
3
        if (!model_name.starts_with("models/")) {
1233
3
            model_name = "models/" + model_name;
1234
3
        }
1235
3
        doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
1236
3
        add_dimension_params(doc, allocator);
1237
1238
3
        rapidjson::Value content(rapidjson::kObjectType);
1239
3
        rapidjson::Value parts(rapidjson::kArrayType);
1240
3
        rapidjson::Value part(rapidjson::kObjectType);
1241
3
        rapidjson::Value file_data(rapidjson::kObjectType);
1242
3
        file_data.AddMember("mime_type", rapidjson::Value(mime_type, allocator), allocator);
1243
3
        file_data.AddMember("file_uri", rapidjson::Value(media_url.c_str(), allocator), allocator);
1244
3
        part.AddMember("file_data", file_data, allocator);
1245
3
        parts.PushBack(part, allocator);
1246
3
        content.AddMember("parts", parts, allocator);
1247
3
        doc.AddMember("content", content, allocator);
1248
1249
3
        rapidjson::StringBuffer buffer;
1250
3
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1251
3
        doc.Accept(writer);
1252
3
        request_body = buffer.GetString();
1253
3
        return Status::OK();
1254
3
    }
1255
1256
    Status parse_embedding_response(const std::string& response_body,
1257
1
                                    std::vector<std::vector<float>>& results) const override {
1258
1
        rapidjson::Document doc;
1259
1
        doc.Parse(response_body.c_str());
1260
1261
1
        if (doc.HasParseError() || !doc.IsObject()) {
1262
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1263
0
                                         response_body);
1264
0
        }
1265
1
        if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) {
1266
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1267
0
                                         response_body);
1268
0
        }
1269
1270
        /*{
1271
          "embedding":{
1272
            "values": [0.1, 0.2, 0.3]
1273
          }
1274
        }*/
1275
1
        const auto& embedding = doc["embedding"];
1276
1
        if (!embedding.HasMember("values") || !embedding["values"].IsArray()) {
1277
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1278
0
                                         response_body);
1279
0
        }
1280
1
        std::transform(embedding["values"].Begin(), embedding["values"].End(),
1281
1
                       std::back_inserter(results.emplace_back()),
1282
3
                       [](const auto& val) { return val.GetFloat(); });
1283
1284
1
        return Status::OK();
1285
1
    }
1286
1287
protected:
1288
5
    bool supports_dimension_param(const std::string& model_name) const override {
1289
5
        static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001",
1290
5
                                                                            "embedding-001"};
1291
5
        return !no_dimension_models.contains(model_name);
1292
5
    }
1293
1294
4
    std::string get_dimension_param_name() const override { return "outputDimensionality"; }
1295
};
1296
1297
class AnthropicAdapter : public VoyageAIAdapter {
1298
public:
1299
1
    Status set_authentication(HttpClient* client) const override {
1300
1
        client->set_header("x-api-key", _config.api_key);
1301
1
        client->set_header("anthropic-version", _config.anthropic_version);
1302
1
        client->set_content_type("application/json");
1303
1304
1
        return Status::OK();
1305
1
    }
1306
1307
    Status build_request_payload(const std::vector<std::string>& inputs,
1308
                                 const char* const system_prompt,
1309
1
                                 std::string& request_body) const override {
1310
1
        rapidjson::Document doc;
1311
1
        doc.SetObject();
1312
1
        auto& allocator = doc.GetAllocator();
1313
1314
        /*
1315
            "model": "claude-opus-4-1-20250805",
1316
            "max_tokens": 1024,
1317
            "system": "system_prompt here",
1318
            "messages": [
1319
              {"role": "user", "content": "xxx"}
1320
            ],
1321
            "temperature": 0.7
1322
        */
1323
1324
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
1325
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
1326
1
        if (_config.temperature != -1) {
1327
1
            doc.AddMember("temperature", _config.temperature, allocator);
1328
1
        }
1329
1
        if (_config.max_tokens != -1) {
1330
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
1331
1
        } else {
1332
            // Keep the default value, Anthropic requires this parameter
1333
0
            doc.AddMember("max_tokens", 2048, allocator);
1334
0
        }
1335
1
        if (system_prompt && *system_prompt) {
1336
1
            doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
1337
1
        }
1338
1339
1
        rapidjson::Value messages(rapidjson::kArrayType);
1340
1
        for (const auto& input : inputs) {
1341
1
            rapidjson::Value message(rapidjson::kObjectType);
1342
1
            message.AddMember("role", "user", allocator);
1343
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
1344
1
            messages.PushBack(message, allocator);
1345
1
        }
1346
1
        doc.AddMember("messages", messages, allocator);
1347
1348
1
        rapidjson::StringBuffer buffer;
1349
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1350
1
        doc.Accept(writer);
1351
1
        request_body = buffer.GetString();
1352
1353
1
        return Status::OK();
1354
1
    }
1355
1356
    Status parse_response(const std::string& response_body,
1357
3
                          std::vector<std::string>& results) const override {
1358
3
        rapidjson::Document doc;
1359
3
        doc.Parse(response_body.c_str());
1360
3
        if (doc.HasParseError() || !doc.IsObject()) {
1361
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1362
1
                                         response_body);
1363
1
        }
1364
2
        if (!doc.HasMember("content") || !doc["content"].IsArray()) {
1365
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1366
1
                                         response_body);
1367
1
        }
1368
1369
        /*{
1370
            "content": [
1371
              {
1372
                "text": "xxx",
1373
                "type": "text"
1374
              }
1375
            ]
1376
        }*/
1377
1
        const auto& content = doc["content"];
1378
1
        results.reserve(1);
1379
1380
1
        std::string result;
1381
2
        for (rapidjson::SizeType i = 0; i < content.Size(); i++) {
1382
1
            if (!content[i].HasMember("type") || !content[i]["type"].IsString() ||
1383
1
                !content[i].HasMember("text") || !content[i]["text"].IsString()) {
1384
0
                continue;
1385
0
            }
1386
1387
1
            if (std::string(content[i]["type"].GetString()) == "text") {
1388
1
                if (!result.empty()) {
1389
0
                    result += "\n";
1390
0
                }
1391
1
                result += content[i]["text"].GetString();
1392
1
            }
1393
1
        }
1394
1395
1
        results.emplace_back(std::move(result));
1396
1
        return Status::OK();
1397
2
    }
1398
};
1399
1400
// Mock adapter used only for UT to bypass real HTTP calls and return deterministic data.
1401
class MockAdapter : public AIAdapter {
1402
public:
1403
0
    Status set_authentication(HttpClient* client) const override { return Status::OK(); }
1404
1405
    Status build_request_payload(const std::vector<std::string>& inputs,
1406
                                 const char* const system_prompt,
1407
49
                                 std::string& request_body) const override {
1408
49
        return Status::OK();
1409
49
    }
1410
1411
    Status parse_response(const std::string& response_body,
1412
49
                          std::vector<std::string>& results) const override {
1413
49
        results.emplace_back(response_body);
1414
49
        return Status::OK();
1415
49
    }
1416
1417
    Status build_embedding_request(const std::vector<std::string>& inputs,
1418
1
                                   std::string& request_body) const override {
1419
1
        return Status::OK();
1420
1
    }
1421
1422
    Status build_multimodal_embedding_request(MultimodalType /*media_type*/,
1423
                                              const std::string& /*media_url*/,
1424
0
                                              std::string& /*request_body*/) const override {
1425
0
        return Status::OK();
1426
0
    }
1427
1428
    Status parse_embedding_response(const std::string& response_body,
1429
1
                                    std::vector<std::vector<float>>& results) const override {
1430
1
        rapidjson::Document doc;
1431
1
        doc.SetObject();
1432
1
        doc.Parse(response_body.c_str());
1433
1
        if (doc.HasParseError() || !doc.IsObject()) {
1434
0
            return Status::InternalError("Failed to parse embedding response");
1435
0
        }
1436
1
        if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) {
1437
0
            return Status::InternalError("Invalid embedding response format");
1438
0
        }
1439
1440
1
        results.reserve(1);
1441
1
        std::transform(doc["embedding"].Begin(), doc["embedding"].End(),
1442
1
                       std::back_inserter(results.emplace_back()),
1443
5
                       [](const auto& val) { return val.GetFloat(); });
1444
1
        return Status::OK();
1445
1
    }
1446
};
1447
1448
class AIAdapterFactory {
1449
public:
1450
75
    static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) {
1451
75
        static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>>
1452
75
                adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }},
1453
75
                            {"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }},
1454
75
                            {"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }},
1455
75
                            {"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }},
1456
75
                            {"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }},
1457
75
                            {"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }},
1458
75
                            {"QWEN", []() { return std::make_shared<QwenAdapter>(); }},
1459
75
                            {"JINA", []() { return std::make_shared<JinaAdapter>(); }},
1460
75
                            {"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }},
1461
75
                            {"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }},
1462
75
                            {"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }},
1463
75
                            {"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }},
1464
75
                            {"MOCK", []() { return std::make_shared<MockAdapter>(); }}};
1465
1466
75
        auto it = adapters.find(provider_type);
1467
75
        return (it != adapters.end()) ? it->second() : nullptr;
1468
75
    }
1469
};
1470
1471
#include "common/compile_check_end.h"
1472
} // namespace doris