Coverage Report

Created: 2026-04-14 20:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/ai_adapter.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <gen_cpp/PaloInternalService_types.h>
21
#include <rapidjson/rapidjson.h>
22
23
#include <algorithm>
24
#include <memory>
25
#include <string>
26
#include <unordered_map>
27
#include <vector>
28
29
#include "common/status.h"
30
#include "core/string_buffer.hpp"
31
#include "rapidjson/document.h"
32
#include "rapidjson/stringbuffer.h"
33
#include "rapidjson/writer.h"
34
#include "service/http/http_client.h"
35
#include "service/http/http_headers.h"
36
37
namespace doris {
38
39
struct AIResource {
40
16
    AIResource() = default;
41
    AIResource(const TAIResource& tai)
42
11
            : endpoint(tai.endpoint),
43
11
              provider_type(tai.provider_type),
44
11
              model_name(tai.model_name),
45
11
              api_key(tai.api_key),
46
11
              temperature(tai.temperature),
47
11
              max_tokens(tai.max_tokens),
48
11
              max_retries(tai.max_retries),
49
11
              retry_delay_second(tai.retry_delay_second),
50
11
              anthropic_version(tai.anthropic_version),
51
11
              dimensions(tai.dimensions) {}
52
53
    std::string endpoint;
54
    std::string provider_type;
55
    std::string model_name;
56
    std::string api_key;
57
    double temperature;
58
    int64_t max_tokens;
59
    int32_t max_retries;
60
    int32_t retry_delay_second;
61
    std::string anthropic_version;
62
    int32_t dimensions;
63
64
1
    void serialize(BufferWritable& buf) const {
65
1
        buf.write_binary(endpoint);
66
1
        buf.write_binary(provider_type);
67
1
        buf.write_binary(model_name);
68
1
        buf.write_binary(api_key);
69
1
        buf.write_binary(temperature);
70
1
        buf.write_binary(max_tokens);
71
1
        buf.write_binary(max_retries);
72
1
        buf.write_binary(retry_delay_second);
73
1
        buf.write_binary(anthropic_version);
74
1
        buf.write_binary(dimensions);
75
1
    }
76
77
1
    void deserialize(BufferReadable& buf) {
78
1
        buf.read_binary(endpoint);
79
1
        buf.read_binary(provider_type);
80
1
        buf.read_binary(model_name);
81
1
        buf.read_binary(api_key);
82
1
        buf.read_binary(temperature);
83
1
        buf.read_binary(max_tokens);
84
1
        buf.read_binary(max_retries);
85
1
        buf.read_binary(retry_delay_second);
86
1
        buf.read_binary(anthropic_version);
87
1
        buf.read_binary(dimensions);
88
1
    }
89
};
90
91
class AIAdapter {
92
public:
93
111
    virtual ~AIAdapter() = default;
94
95
    // Set authentication headers for the HTTP client
96
    virtual Status set_authentication(HttpClient* client) const = 0;
97
98
71
    virtual void init(const TAIResource& config) { _config = config; }
99
12
    virtual void init(const AIResource& config) {
100
12
        _config.endpoint = config.endpoint;
101
12
        _config.provider_type = config.provider_type;
102
12
        _config.model_name = config.model_name;
103
12
        _config.api_key = config.api_key;
104
12
        _config.temperature = config.temperature;
105
12
        _config.max_tokens = config.max_tokens;
106
12
        _config.max_retries = config.max_retries;
107
12
        _config.retry_delay_second = config.retry_delay_second;
108
12
        _config.anthropic_version = config.anthropic_version;
109
12
    }
110
111
    // Build request payload based on input text strings
112
    virtual Status build_request_payload(const std::vector<std::string>& inputs,
113
                                         const char* const system_prompt,
114
1
                                         std::string& request_body) const {
115
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
116
1
    }
117
118
    // Parse response from AI service and extract generated text results
119
    virtual Status parse_response(const std::string& response_body,
120
1
                                  std::vector<std::string>& results) const {
121
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
122
1
    }
123
124
    virtual Status build_embedding_request(const std::vector<std::string>& inputs,
125
0
                                           std::string& request_body) const {
126
0
        return Status::NotSupported("{} does not support the Embed feature.",
127
0
                                    _config.provider_type);
128
0
    }
129
130
    virtual Status parse_embedding_response(const std::string& response_body,
131
0
                                            std::vector<std::vector<float>>& results) const {
132
0
        return Status::NotSupported("{} does not support the Embed feature.",
133
0
                                    _config.provider_type);
134
0
    }
135
136
protected:
137
    TAIResource _config;
138
139
    // Returns true if the model supports the dimension parameter.
140
1
    virtual bool supports_dimension_param(const std::string& model_name) const { return false; }
141
142
    // Different providers may have different dimension parameter names.
143
0
    virtual std::string get_dimension_param_name() const { return "dimensions"; }
144
145
    virtual void add_dimension_params(rapidjson::Value& doc,
146
11
                                      rapidjson::Document::AllocatorType& allocator) const {
147
11
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
148
5
            std::string param_name = get_dimension_param_name();
149
5
            rapidjson::Value name(param_name.c_str(), allocator);
150
5
            doc.AddMember(name, _config.dimensions, allocator);
151
5
        }
152
11
    }
153
};
154
155
// Most LLM-providers' Embedding formats are based on VoyageAI.
156
// The following adapters inherit from VoyageAIAdapter to directly reuse its embedding logic.
157
class VoyageAIAdapter : public AIAdapter {
158
public:
159
2
    Status set_authentication(HttpClient* client) const override {
160
2
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
161
2
        client->set_content_type("application/json");
162
163
2
        return Status::OK();
164
2
    }
165
166
    Status build_embedding_request(const std::vector<std::string>& inputs,
167
8
                                   std::string& request_body) const override {
168
8
        rapidjson::Document doc;
169
8
        doc.SetObject();
170
8
        auto& allocator = doc.GetAllocator();
171
172
        /*{
173
            "model": "xxx",
174
            "input": [
175
              "xxx",
176
              "xxx",
177
              ...
178
            ],
179
            "output_dimension": 512
180
        }*/
181
8
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
182
8
        add_dimension_params(doc, allocator);
183
184
8
        rapidjson::Value input(rapidjson::kArrayType);
185
8
        for (const auto& msg : inputs) {
186
8
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
187
8
        }
188
8
        doc.AddMember("input", input, allocator);
189
190
8
        rapidjson::StringBuffer buffer;
191
8
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
192
8
        doc.Accept(writer);
193
8
        request_body = buffer.GetString();
194
195
8
        return Status::OK();
196
8
    }
197
198
    Status parse_embedding_response(const std::string& response_body,
199
5
                                    std::vector<std::vector<float>>& results) const override {
200
5
        rapidjson::Document doc;
201
5
        doc.Parse(response_body.c_str());
202
203
5
        if (doc.HasParseError() || !doc.IsObject()) {
204
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
205
1
                                         response_body);
206
1
        }
207
4
        if (!doc.HasMember("data") || !doc["data"].IsArray()) {
208
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
209
1
                                         response_body);
210
1
        }
211
212
        /*{
213
            "data":[
214
              {
215
                "object": "embedding",
216
                "embedding": [...], <- only need this
217
                "index": 0
218
              },
219
              {
220
                "object": "embedding",
221
                "embedding": [...],
222
                "index": 1
223
              }, ...
224
            ],
225
            "model"....
226
        }*/
227
3
        const auto& data = doc["data"];
228
3
        results.reserve(data.Size());
229
7
        for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
230
5
            if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
231
1
                return Status::InternalError("Invalid {} response format: {}",
232
1
                                             _config.provider_type, response_body);
233
1
            }
234
235
4
            std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
236
4
                           std::back_inserter(results.emplace_back()),
237
10
                           [](const auto& val) { return val.GetFloat(); });
238
4
        }
239
240
2
        return Status::OK();
241
3
    }
242
243
protected:
244
2
    bool supports_dimension_param(const std::string& model_name) const override {
245
2
        static const std::unordered_set<std::string> no_dimension_models = {
246
2
                "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2",
247
2
                "voyage-multimodal-3"};
248
2
        return !no_dimension_models.contains(model_name);
249
2
    }
250
251
1
    std::string get_dimension_param_name() const override { return "output_dimension"; }
252
};
253
254
// Local AI adapter for locally hosted models (Ollama, LLaMA, etc.)
255
class LocalAdapter : public AIAdapter {
256
public:
257
    // Local deployments typically don't need authentication
258
2
    Status set_authentication(HttpClient* client) const override {
259
2
        client->set_content_type("application/json");
260
2
        return Status::OK();
261
2
    }
262
263
    Status build_request_payload(const std::vector<std::string>& inputs,
264
                                 const char* const system_prompt,
265
3
                                 std::string& request_body) const override {
266
3
        rapidjson::Document doc;
267
3
        doc.SetObject();
268
3
        auto& allocator = doc.GetAllocator();
269
270
3
        std::string end_point = _config.endpoint;
271
3
        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
272
2
            RETURN_IF_ERROR(
273
2
                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
274
2
        } else {
275
1
            RETURN_IF_ERROR(
276
1
                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
277
1
        }
278
279
3
        rapidjson::StringBuffer buffer;
280
3
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
281
3
        doc.Accept(writer);
282
3
        request_body = buffer.GetString();
283
284
3
        return Status::OK();
285
3
    }
286
287
    Status parse_response(const std::string& response_body,
288
7
                          std::vector<std::string>& results) const override {
289
7
        rapidjson::Document doc;
290
7
        doc.Parse(response_body.c_str());
291
292
7
        if (doc.HasParseError() || !doc.IsObject()) {
293
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
294
1
                                         response_body);
295
1
        }
296
297
        // Handle various response formats from local LLMs
298
        // Format 1: OpenAI-compatible format with choices/message/content
299
6
        if (doc.HasMember("choices") && doc["choices"].IsArray()) {
300
1
            const auto& choices = doc["choices"];
301
1
            results.reserve(choices.Size());
302
303
2
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
304
1
                if (choices[i].HasMember("message") && choices[i]["message"].HasMember("content") &&
305
1
                    choices[i]["message"]["content"].IsString()) {
306
1
                    results.emplace_back(choices[i]["message"]["content"].GetString());
307
1
                } else if (choices[i].HasMember("text") && choices[i]["text"].IsString()) {
308
                    // Some local LLMs use a simpler format
309
0
                    results.emplace_back(choices[i]["text"].GetString());
310
0
                }
311
1
            }
312
5
        } else if (doc.HasMember("text") && doc["text"].IsString()) {
313
            // Format 2: Simple response with just "text" or "content" field
314
1
            results.emplace_back(doc["text"].GetString());
315
4
        } else if (doc.HasMember("content") && doc["content"].IsString()) {
316
1
            results.emplace_back(doc["content"].GetString());
317
3
        } else if (doc.HasMember("response") && doc["response"].IsString()) {
318
            // Format 3: Response field (Ollama `generate` format)
319
1
            results.emplace_back(doc["response"].GetString());
320
2
        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
321
2
                   doc["message"].HasMember("content") && doc["message"]["content"].IsString()) {
322
            // Format 4: message/content field (Ollama `chat` format)
323
1
            results.emplace_back(doc["message"]["content"].GetString());
324
1
        } else {
325
1
            return Status::NotSupported("Unsupported response format from local AI.");
326
1
        }
327
328
5
        return Status::OK();
329
6
    }
330
331
    Status build_embedding_request(const std::vector<std::string>& inputs,
332
1
                                   std::string& request_body) const override {
333
1
        rapidjson::Document doc;
334
1
        doc.SetObject();
335
1
        auto& allocator = doc.GetAllocator();
336
337
1
        if (!_config.model_name.empty()) {
338
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
339
1
                          allocator);
340
1
        }
341
342
1
        add_dimension_params(doc, allocator);
343
344
1
        rapidjson::Value input(rapidjson::kArrayType);
345
1
        for (const auto& msg : inputs) {
346
1
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
347
1
        }
348
1
        doc.AddMember("input", input, allocator);
349
350
1
        rapidjson::StringBuffer buffer;
351
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
352
1
        doc.Accept(writer);
353
1
        request_body = buffer.GetString();
354
355
1
        return Status::OK();
356
1
    }
357
358
    Status parse_embedding_response(const std::string& response_body,
359
3
                                    std::vector<std::vector<float>>& results) const override {
360
3
        rapidjson::Document doc;
361
3
        doc.Parse(response_body.c_str());
362
363
3
        if (doc.HasParseError() || !doc.IsObject()) {
364
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
365
0
                                         response_body);
366
0
        }
367
368
        // parse different response format
369
3
        rapidjson::Value embedding;
370
3
        if (doc.HasMember("data") && doc["data"].IsArray()) {
371
            // "data":[{"object":"embedding", "embedding":[0.1, 0.2, ...], "index":0}, ...]
372
1
            const auto& data = doc["data"];
373
1
            results.reserve(data.Size());
374
3
            for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
375
2
                if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
376
0
                    return Status::InternalError("Invalid {} response format",
377
0
                                                 _config.provider_type);
378
0
                }
379
380
2
                std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
381
2
                               std::back_inserter(results.emplace_back()),
382
5
                               [](const auto& val) { return val.GetFloat(); });
383
2
            }
384
2
        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
385
            // "embeddings":[[0.1, 0.2, ...]]
386
1
            results.reserve(1);
387
2
            for (int i = 0; i < doc["embeddings"].Size(); i++) {
388
1
                embedding = doc["embeddings"][i];
389
1
                std::transform(embedding.Begin(), embedding.End(),
390
1
                               std::back_inserter(results.emplace_back()),
391
2
                               [](const auto& val) { return val.GetFloat(); });
392
1
            }
393
1
        } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
394
            // "embedding":[0.1, 0.2, ...]
395
1
            results.reserve(1);
396
1
            embedding = doc["embedding"];
397
1
            std::transform(embedding.Begin(), embedding.End(),
398
1
                           std::back_inserter(results.emplace_back()),
399
3
                           [](const auto& val) { return val.GetFloat(); });
400
1
        } else {
401
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
402
0
                                         response_body);
403
0
        }
404
405
3
        return Status::OK();
406
3
    }
407
408
private:
409
    Status build_ollama_request(rapidjson::Document& doc,
410
                                rapidjson::Document::AllocatorType& allocator,
411
                                const std::vector<std::string>& inputs,
412
2
                                const char* const system_prompt, std::string& request_body) const {
413
        /*
414
        for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
415
        {
416
            "model": <model_name>,
417
            "stream": false,
418
            "think": false,
419
            "options": {
420
                "temperature": <temperature>,
421
                "max_token": <max_token>
422
            },
423
            "messages": [
424
                {"role": "system", "content": <system_prompt>},
425
                {"role": "user", "content": <user_prompt>}
426
            ]
427
        }
428
        
429
        for endpoints end_with `/generate` like 'http://localhost:11434/api/generate':
430
        {
431
            "model": <model_name>,
432
            "stream": false,
433
            "think": false,
434
            "options": {
435
                "temperature": <temperature>,
436
                "max_token": <max_token>
437
            },
438
            "system": <system_prompt>,
439
            "prompt": <user_prompt>
440
        }
441
        */
442
443
        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint;
444
        // The rest remains identical.
445
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
446
2
        doc.AddMember("stream", false, allocator);
447
2
        doc.AddMember("think", false, allocator);
448
449
        // option section
450
2
        rapidjson::Value options(rapidjson::kObjectType);
451
2
        if (_config.temperature != -1) {
452
2
            options.AddMember("temperature", _config.temperature, allocator);
453
2
        }
454
2
        if (_config.max_tokens != -1) {
455
2
            options.AddMember("max_token", _config.max_tokens, allocator);
456
2
        }
457
2
        doc.AddMember("options", options, allocator);
458
459
        // prompt section
460
2
        if (_config.endpoint.ends_with("chat")) {
461
1
            rapidjson::Value messages(rapidjson::kArrayType);
462
1
            if (system_prompt && *system_prompt) {
463
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
464
1
                sys_msg.AddMember("role", "system", allocator);
465
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
466
1
                messages.PushBack(sys_msg, allocator);
467
1
            }
468
1
            for (const auto& input : inputs) {
469
1
                rapidjson::Value message(rapidjson::kObjectType);
470
1
                message.AddMember("role", "user", allocator);
471
1
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
472
1
                messages.PushBack(message, allocator);
473
1
            }
474
1
            doc.AddMember("messages", messages, allocator);
475
1
        } else {
476
1
            if (system_prompt && *system_prompt) {
477
1
                doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
478
1
            }
479
1
            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator);
480
1
        }
481
482
2
        return Status::OK();
483
2
    }
484
485
    Status build_default_request(rapidjson::Document& doc,
486
                                 rapidjson::Document::AllocatorType& allocator,
487
                                 const std::vector<std::string>& inputs,
488
1
                                 const char* const system_prompt, std::string& request_body) const {
489
        /*
490
        Default format(OpenAI-compatible):
491
        {
492
            "model": <model_name>,
493
            "temperature": <temperature>,
494
            "max_tokens": <max_tokens>,
495
            "messages": [
496
                {"role": "system", "content": <system_prompt>},
497
                {"role": "user", "content": <user_prompt>}
498
            ]
499
        }
500
        */
501
502
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
503
504
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
505
1
        if (_config.temperature != -1) {
506
1
            doc.AddMember("temperature", _config.temperature, allocator);
507
1
        }
508
1
        if (_config.max_tokens != -1) {
509
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
510
1
        }
511
512
1
        rapidjson::Value messages(rapidjson::kArrayType);
513
1
        if (system_prompt && *system_prompt) {
514
1
            rapidjson::Value sys_msg(rapidjson::kObjectType);
515
1
            sys_msg.AddMember("role", "system", allocator);
516
1
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
517
1
            messages.PushBack(sys_msg, allocator);
518
1
        }
519
1
        for (const auto& input : inputs) {
520
1
            rapidjson::Value message(rapidjson::kObjectType);
521
1
            message.AddMember("role", "user", allocator);
522
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
523
1
            messages.PushBack(message, allocator);
524
1
        }
525
1
        doc.AddMember("messages", messages, allocator);
526
1
        return Status::OK();
527
1
    }
528
};
529
530
// The OpenAI API format can be reused with some compatible AIs.
531
class OpenAIAdapter : public VoyageAIAdapter {
532
public:
533
8
    Status set_authentication(HttpClient* client) const override {
534
8
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
535
8
        client->set_content_type("application/json");
536
537
8
        return Status::OK();
538
8
    }
539
540
    Status build_request_payload(const std::vector<std::string>& inputs,
541
                                 const char* const system_prompt,
542
2
                                 std::string& request_body) const override {
543
2
        rapidjson::Document doc;
544
2
        doc.SetObject();
545
2
        auto& allocator = doc.GetAllocator();
546
547
2
        if (_config.endpoint.ends_with("responses")) {
548
            /*{
549
              "model": "gpt-4.1-mini",
550
              "input": [
551
                {"role": "system", "content": "system_prompt here"},
552
                {"role": "user", "content": "xxx"}
553
              ],
554
              "temperature": 0.7,
555
              "max_output_tokens": 150
556
            }*/
557
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
558
1
                          allocator);
559
560
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
561
1
            if (_config.temperature != -1) {
562
1
                doc.AddMember("temperature", _config.temperature, allocator);
563
1
            }
564
1
            if (_config.max_tokens != -1) {
565
1
                doc.AddMember("max_output_tokens", _config.max_tokens, allocator);
566
1
            }
567
568
            // input
569
1
            rapidjson::Value input(rapidjson::kArrayType);
570
1
            if (system_prompt && *system_prompt) {
571
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
572
1
                sys_msg.AddMember("role", "system", allocator);
573
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
574
1
                input.PushBack(sys_msg, allocator);
575
1
            }
576
1
            for (const auto& msg : inputs) {
577
1
                rapidjson::Value message(rapidjson::kObjectType);
578
1
                message.AddMember("role", "user", allocator);
579
1
                message.AddMember("content", rapidjson::Value(msg.c_str(), allocator), allocator);
580
1
                input.PushBack(message, allocator);
581
1
            }
582
1
            doc.AddMember("input", input, allocator);
583
1
        } else {
584
            /*{
585
              "model": "gpt-4",
586
              "messages": [
587
                {"role": "system", "content": "system_prompt here"},
588
                {"role": "user", "content": "xxx"}
589
              ],
590
              "temperature": x,
591
              "max_tokens": x,
592
            }*/
593
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
594
1
                          allocator);
595
596
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
597
1
            if (_config.temperature != -1) {
598
1
                doc.AddMember("temperature", _config.temperature, allocator);
599
1
            }
600
1
            if (_config.max_tokens != -1) {
601
1
                doc.AddMember("max_tokens", _config.max_tokens, allocator);
602
1
            }
603
604
1
            rapidjson::Value messages(rapidjson::kArrayType);
605
1
            if (system_prompt && *system_prompt) {
606
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
607
1
                sys_msg.AddMember("role", "system", allocator);
608
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
609
1
                messages.PushBack(sys_msg, allocator);
610
1
            }
611
1
            for (const auto& input : inputs) {
612
1
                rapidjson::Value message(rapidjson::kObjectType);
613
1
                message.AddMember("role", "user", allocator);
614
1
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
615
1
                messages.PushBack(message, allocator);
616
1
            }
617
1
            doc.AddMember("messages", messages, allocator);
618
1
        }
619
620
2
        rapidjson::StringBuffer buffer;
621
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
622
2
        doc.Accept(writer);
623
2
        request_body = buffer.GetString();
624
625
2
        return Status::OK();
626
2
    }
627
628
    Status parse_response(const std::string& response_body,
629
6
                          std::vector<std::string>& results) const override {
630
6
        rapidjson::Document doc;
631
6
        doc.Parse(response_body.c_str());
632
633
6
        if (doc.HasParseError() || !doc.IsObject()) {
634
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
635
1
                                         response_body);
636
1
        }
637
638
5
        if (doc.HasMember("output") && doc["output"].IsArray()) {
639
            /// for responses endpoint
640
            /*{
641
              "output": [
642
                {
643
                  "id": "msg_123",
644
                  "type": "message",
645
                  "role": "assistant",
646
                  "content": [
647
                    {
648
                      "type": "text",
649
                      "text": "result text here"   <- result
650
                    }
651
                  ]
652
                }
653
              ]
654
            }*/
655
1
            const auto& output = doc["output"];
656
1
            results.reserve(output.Size());
657
658
2
            for (rapidjson::SizeType i = 0; i < output.Size(); i++) {
659
1
                if (!output[i].HasMember("content") || !output[i]["content"].IsArray() ||
660
1
                    output[i]["content"].Empty() || !output[i]["content"][0].HasMember("text") ||
661
1
                    !output[i]["content"][0]["text"].IsString()) {
662
0
                    return Status::InternalError("Invalid output format in {} response: {}",
663
0
                                                 _config.provider_type, response_body);
664
0
                }
665
666
1
                results.emplace_back(output[i]["content"][0]["text"].GetString());
667
1
            }
668
4
        } else if (doc.HasMember("choices") && doc["choices"].IsArray()) {
669
            /// for completions endpoint
670
            /*{
671
              "object": "chat.completion",
672
              "model": "gpt-4",
673
              "choices": [
674
                {
675
                  ...
676
                  "message": {
677
                    "role": "assistant",
678
                    "content": "xxx"      <- result
679
                  },
680
                  ...
681
                }
682
              ],
683
              ...
684
            }*/
685
3
            const auto& choices = doc["choices"];
686
3
            results.reserve(choices.Size());
687
688
4
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
689
3
                if (!choices[i].HasMember("message") ||
690
3
                    !choices[i]["message"].HasMember("content") ||
691
3
                    !choices[i]["message"]["content"].IsString()) {
692
2
                    return Status::InternalError("Invalid choice format in {} response: {}",
693
2
                                                 _config.provider_type, response_body);
694
2
                }
695
696
1
                results.emplace_back(choices[i]["message"]["content"].GetString());
697
1
            }
698
3
        } else {
699
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
700
1
                                         response_body);
701
1
        }
702
703
2
        return Status::OK();
704
5
    }
705
706
protected:
707
2
    bool supports_dimension_param(const std::string& model_name) const override {
708
2
        return !(model_name == "text-embedding-ada-002");
709
2
    }
710
711
2
    std::string get_dimension_param_name() const override { return "dimensions"; }
712
};
713
714
class DeepSeekAdapter : public OpenAIAdapter {
715
public:
716
    Status build_embedding_request(const std::vector<std::string>& inputs,
717
1
                                   std::string& request_body) const override {
718
1
        return Status::NotSupported("{} does not support the Embed feature.",
719
1
                                    _config.provider_type);
720
1
    }
721
722
    Status parse_embedding_response(const std::string& response_body,
723
1
                                    std::vector<std::vector<float>>& results) const override {
724
1
        return Status::NotSupported("{} does not support the Embed feature.",
725
1
                                    _config.provider_type);
726
1
    }
727
};
728
729
class MoonShotAdapter : public OpenAIAdapter {
730
public:
731
    Status build_embedding_request(const std::vector<std::string>& inputs,
732
1
                                   std::string& request_body) const override {
733
1
        return Status::NotSupported("{} does not support the Embed feature.",
734
1
                                    _config.provider_type);
735
1
    }
736
737
    Status parse_embedding_response(const std::string& response_body,
738
1
                                    std::vector<std::vector<float>>& results) const override {
739
1
        return Status::NotSupported("{} does not support the Embed feature.",
740
1
                                    _config.provider_type);
741
1
    }
742
};
743
744
// Minimax adapter: chat completion is inherited from OpenAIAdapter, but the
// embedding request body uses Minimax's own schema ("texts" array + "type"),
// so build_embedding_request is overridden here.
class MinimaxAdapter : public OpenAIAdapter {
public:
    // Serializes `inputs` into a Minimax embedding request and stores the JSON
    // string in `request_body`. Always returns Status::OK().
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /*{
          "model": "embo-1",
          "texts": ["xxx", "xxx", ...],
          "type": "db"
        }*/
        rapidjson::Value texts(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator);
        }
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        doc.AddMember("texts", texts, allocator);
        // "db" selects the storage-side embedding type — TODO confirm against
        // the Minimax embeddings API; the query-side variant is not exposed here.
        doc.AddMember("type", rapidjson::Value("db", allocator), allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }
};
773
774
class ZhipuAdapter : public OpenAIAdapter {
775
protected:
776
2
    bool supports_dimension_param(const std::string& model_name) const override {
777
2
        return !(model_name == "embedding-2");
778
2
    }
779
};
780
781
class QwenAdapter : public OpenAIAdapter {
782
protected:
783
2
    bool supports_dimension_param(const std::string& model_name) const override {
784
2
        static const std::unordered_set<std::string> no_dimension_models = {
785
2
                "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"};
786
2
        return !no_dimension_models.contains(model_name);
787
2
    }
788
789
1
    std::string get_dimension_param_name() const override { return "dimension"; }
790
};
791
792
class BaichuanAdapter : public OpenAIAdapter {
793
protected:
794
0
    bool supports_dimension_param(const std::string& model_name) const override { return false; }
795
};
796
797
// Gemini's embedding format is different from VoyageAI, so it requires a separate adapter
798
class GeminiAdapter : public AIAdapter {
799
public:
800
2
    Status set_authentication(HttpClient* client) const override {
801
2
        client->set_header("x-goog-api-key", _config.api_key);
802
2
        client->set_content_type("application/json");
803
2
        return Status::OK();
804
2
    }
805
806
    Status build_request_payload(const std::vector<std::string>& inputs,
807
                                 const char* const system_prompt,
808
1
                                 std::string& request_body) const override {
809
1
        rapidjson::Document doc;
810
1
        doc.SetObject();
811
1
        auto& allocator = doc.GetAllocator();
812
813
        /*{
814
          "systemInstruction": {
815
              "parts": [
816
                {
817
                  "text": "system_prompt here"
818
                }
819
              ]
820
            }
821
          ],
822
          "contents": [
823
            {
824
              "parts": [
825
                {
826
                  "text": "xxx"
827
                }
828
              ]
829
            }
830
          ],
831
          "generationConfig": {
832
          "temperature": 0.7,
833
          "maxOutputTokens": 1024
834
          }
835
836
        }*/
837
1
        if (system_prompt && *system_prompt) {
838
1
            rapidjson::Value system_instruction(rapidjson::kObjectType);
839
1
            rapidjson::Value parts(rapidjson::kArrayType);
840
841
1
            rapidjson::Value part(rapidjson::kObjectType);
842
1
            part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator);
843
1
            parts.PushBack(part, allocator);
844
            // system_instruction.PushBack(content, allocator);
845
1
            system_instruction.AddMember("parts", parts, allocator);
846
1
            doc.AddMember("systemInstruction", system_instruction, allocator);
847
1
        }
848
849
1
        rapidjson::Value contents(rapidjson::kArrayType);
850
1
        for (const auto& input : inputs) {
851
1
            rapidjson::Value content(rapidjson::kObjectType);
852
1
            rapidjson::Value parts(rapidjson::kArrayType);
853
854
1
            rapidjson::Value part(rapidjson::kObjectType);
855
1
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
856
857
1
            parts.PushBack(part, allocator);
858
1
            content.AddMember("parts", parts, allocator);
859
1
            contents.PushBack(content, allocator);
860
1
        }
861
1
        doc.AddMember("contents", contents, allocator);
862
863
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
864
1
        rapidjson::Value generationConfig(rapidjson::kObjectType);
865
1
        if (_config.temperature != -1) {
866
1
            generationConfig.AddMember("temperature", _config.temperature, allocator);
867
1
        }
868
1
        if (_config.max_tokens != -1) {
869
1
            generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator);
870
1
        }
871
1
        doc.AddMember("generationConfig", generationConfig, allocator);
872
873
1
        rapidjson::StringBuffer buffer;
874
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
875
1
        doc.Accept(writer);
876
1
        request_body = buffer.GetString();
877
878
1
        return Status::OK();
879
1
    }
880
881
    Status parse_response(const std::string& response_body,
882
3
                          std::vector<std::string>& results) const override {
883
3
        rapidjson::Document doc;
884
3
        doc.Parse(response_body.c_str());
885
886
3
        if (doc.HasParseError() || !doc.IsObject()) {
887
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
888
1
                                         response_body);
889
1
        }
890
2
        if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) {
891
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
892
1
                                         response_body);
893
1
        }
894
895
        /*{
896
          "candidates":[
897
            {
898
              "content": {
899
                "parts": [
900
                  {
901
                    "text": "xxx"
902
                  }
903
                ]
904
              }
905
            }
906
          ]
907
        }*/
908
1
        const auto& candidates = doc["candidates"];
909
1
        results.reserve(candidates.Size());
910
911
2
        for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) {
912
1
            if (!candidates[i].HasMember("content") ||
913
1
                !candidates[i]["content"].HasMember("parts") ||
914
1
                !candidates[i]["content"]["parts"].IsArray() ||
915
1
                candidates[i]["content"]["parts"].Empty() ||
916
1
                !candidates[i]["content"]["parts"][0].HasMember("text") ||
917
1
                !candidates[i]["content"]["parts"][0]["text"].IsString()) {
918
0
                return Status::InternalError("Invalid candidate format in {} response",
919
0
                                             _config.provider_type);
920
0
            }
921
922
1
            results.emplace_back(candidates[i]["content"]["parts"][0]["text"].GetString());
923
1
        }
924
925
1
        return Status::OK();
926
1
    }
927
928
    Status build_embedding_request(const std::vector<std::string>& inputs,
929
2
                                   std::string& request_body) const override {
930
2
        rapidjson::Document doc;
931
2
        doc.SetObject();
932
2
        auto& allocator = doc.GetAllocator();
933
934
        /*{
935
          "model": "models/gemini-embedding-001",
936
          "content": {
937
              "parts": [
938
                {
939
                  "text": "xxx"
940
                }
941
              ]
942
            }
943
          "outputDimensionality": 1024
944
        }*/
945
946
        // gemini requires the model format as `models/{model}`
947
2
        std::string model_name = _config.model_name;
948
2
        if (!model_name.starts_with("models/")) {
949
2
            model_name = "models/" + model_name;
950
2
        }
951
2
        doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
952
2
        add_dimension_params(doc, allocator);
953
954
2
        rapidjson::Value content(rapidjson::kObjectType);
955
2
        for (const auto& input : inputs) {
956
2
            rapidjson::Value parts(rapidjson::kArrayType);
957
2
            rapidjson::Value part(rapidjson::kObjectType);
958
2
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
959
2
            parts.PushBack(part, allocator);
960
2
            content.AddMember("parts", parts, allocator);
961
2
        }
962
2
        doc.AddMember("content", content, allocator);
963
964
2
        rapidjson::StringBuffer buffer;
965
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
966
2
        doc.Accept(writer);
967
2
        request_body = buffer.GetString();
968
969
2
        return Status::OK();
970
2
    }
971
972
    Status parse_embedding_response(const std::string& response_body,
973
1
                                    std::vector<std::vector<float>>& results) const override {
974
1
        rapidjson::Document doc;
975
1
        doc.Parse(response_body.c_str());
976
977
1
        if (doc.HasParseError() || !doc.IsObject()) {
978
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
979
0
                                         response_body);
980
0
        }
981
1
        if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) {
982
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
983
0
                                         response_body);
984
0
        }
985
986
        /*{
987
          "embedding":{
988
            "values": [0.1, 0.2, 0.3]
989
          }
990
        }*/
991
1
        const auto& embedding = doc["embedding"];
992
1
        if (!embedding.HasMember("values") || !embedding["values"].IsArray()) {
993
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
994
0
                                         response_body);
995
0
        }
996
1
        std::transform(embedding["values"].Begin(), embedding["values"].End(),
997
1
                       std::back_inserter(results.emplace_back()),
998
3
                       [](const auto& val) { return val.GetFloat(); });
999
1000
1
        return Status::OK();
1001
1
    }
1002
1003
protected:
1004
2
    bool supports_dimension_param(const std::string& model_name) const override {
1005
2
        static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001",
1006
2
                                                                            "embedding-001"};
1007
2
        return !no_dimension_models.contains(model_name);
1008
2
    }
1009
1010
1
    std::string get_dimension_param_name() const override { return "outputDimensionality"; }
1011
};
1012
1013
class AnthropicAdapter : public VoyageAIAdapter {
1014
public:
1015
1
    Status set_authentication(HttpClient* client) const override {
1016
1
        client->set_header("x-api-key", _config.api_key);
1017
1
        client->set_header("anthropic-version", _config.anthropic_version);
1018
1
        client->set_content_type("application/json");
1019
1020
1
        return Status::OK();
1021
1
    }
1022
1023
    Status build_request_payload(const std::vector<std::string>& inputs,
1024
                                 const char* const system_prompt,
1025
1
                                 std::string& request_body) const override {
1026
1
        rapidjson::Document doc;
1027
1
        doc.SetObject();
1028
1
        auto& allocator = doc.GetAllocator();
1029
1030
        /*
1031
            "model": "claude-opus-4-1-20250805",
1032
            "max_tokens": 1024,
1033
            "system": "system_prompt here",
1034
            "messages": [
1035
              {"role": "user", "content": "xxx"}
1036
            ],
1037
            "temperature": 0.7
1038
        */
1039
1040
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
1041
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
1042
1
        if (_config.temperature != -1) {
1043
1
            doc.AddMember("temperature", _config.temperature, allocator);
1044
1
        }
1045
1
        if (_config.max_tokens != -1) {
1046
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
1047
1
        } else {
1048
            // Keep the default value, Anthropic requires this parameter
1049
0
            doc.AddMember("max_tokens", 2048, allocator);
1050
0
        }
1051
1
        if (system_prompt && *system_prompt) {
1052
1
            doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
1053
1
        }
1054
1055
1
        rapidjson::Value messages(rapidjson::kArrayType);
1056
1
        for (const auto& input : inputs) {
1057
1
            rapidjson::Value message(rapidjson::kObjectType);
1058
1
            message.AddMember("role", "user", allocator);
1059
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
1060
1
            messages.PushBack(message, allocator);
1061
1
        }
1062
1
        doc.AddMember("messages", messages, allocator);
1063
1064
1
        rapidjson::StringBuffer buffer;
1065
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1066
1
        doc.Accept(writer);
1067
1
        request_body = buffer.GetString();
1068
1069
1
        return Status::OK();
1070
1
    }
1071
1072
    Status parse_response(const std::string& response_body,
1073
3
                          std::vector<std::string>& results) const override {
1074
3
        rapidjson::Document doc;
1075
3
        doc.Parse(response_body.c_str());
1076
3
        if (doc.HasParseError() || !doc.IsObject()) {
1077
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1078
1
                                         response_body);
1079
1
        }
1080
2
        if (!doc.HasMember("content") || !doc["content"].IsArray()) {
1081
1
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1082
1
                                         response_body);
1083
1
        }
1084
1085
        /*{
1086
            "content": [
1087
              {
1088
                "text": "xxx",
1089
                "type": "text"
1090
              }
1091
            ]
1092
        }*/
1093
1
        const auto& content = doc["content"];
1094
1
        results.reserve(1);
1095
1096
1
        std::string result;
1097
2
        for (rapidjson::SizeType i = 0; i < content.Size(); i++) {
1098
1
            if (!content[i].HasMember("type") || !content[i]["type"].IsString() ||
1099
1
                !content[i].HasMember("text") || !content[i]["text"].IsString()) {
1100
0
                continue;
1101
0
            }
1102
1103
1
            if (std::string(content[i]["type"].GetString()) == "text") {
1104
1
                if (!result.empty()) {
1105
0
                    result += "\n";
1106
0
                }
1107
1
                result += content[i]["text"].GetString();
1108
1
            }
1109
1
        }
1110
1111
1
        results.emplace_back(std::move(result));
1112
1
        return Status::OK();
1113
2
    }
1114
};
1115
1116
// Mock adapter used only for UT to bypass real HTTP calls and return deterministic data.
1117
class MockAdapter : public AIAdapter {
1118
public:
1119
0
    Status set_authentication(HttpClient* client) const override { return Status::OK(); }
1120
1121
    Status build_request_payload(const std::vector<std::string>& inputs,
1122
                                 const char* const system_prompt,
1123
49
                                 std::string& request_body) const override {
1124
49
        return Status::OK();
1125
49
    }
1126
1127
    Status parse_response(const std::string& response_body,
1128
49
                          std::vector<std::string>& results) const override {
1129
49
        results.emplace_back(response_body);
1130
49
        return Status::OK();
1131
49
    }
1132
1133
    Status build_embedding_request(const std::vector<std::string>& inputs,
1134
1
                                   std::string& request_body) const override {
1135
1
        return Status::OK();
1136
1
    }
1137
1138
    Status parse_embedding_response(const std::string& response_body,
1139
1
                                    std::vector<std::vector<float>>& results) const override {
1140
1
        rapidjson::Document doc;
1141
1
        doc.SetObject();
1142
1
        doc.Parse(response_body.c_str());
1143
1
        if (doc.HasParseError() || !doc.IsObject()) {
1144
0
            return Status::InternalError("Failed to parse embedding response");
1145
0
        }
1146
1
        if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) {
1147
0
            return Status::InternalError("Invalid embedding response format");
1148
0
        }
1149
1150
1
        results.reserve(1);
1151
1
        std::transform(doc["embedding"].Begin(), doc["embedding"].End(),
1152
1
                       std::back_inserter(results.emplace_back()),
1153
5
                       [](const auto& val) { return val.GetFloat(); });
1154
1
        return Status::OK();
1155
1
    }
1156
};
1157
1158
class AIAdapterFactory {
1159
public:
1160
74
    static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) {
1161
74
        static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>>
1162
74
                adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }},
1163
74
                            {"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }},
1164
74
                            {"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }},
1165
74
                            {"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }},
1166
74
                            {"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }},
1167
74
                            {"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }},
1168
74
                            {"QWEN", []() { return std::make_shared<QwenAdapter>(); }},
1169
74
                            {"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }},
1170
74
                            {"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }},
1171
74
                            {"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }},
1172
74
                            {"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }},
1173
74
                            {"MOCK", []() { return std::make_shared<MockAdapter>(); }}};
1174
1175
74
        auto it = adapters.find(provider_type);
1176
74
        return (it != adapters.end()) ? it->second() : nullptr;
1177
74
    }
1178
};
1179
1180
} // namespace doris