Coverage Report

Created: 2026-03-16 13:09

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/ai_adapter.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <gen_cpp/PaloInternalService_types.h>
#include <rapidjson/rapidjson.h>

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/status.h"
#include "core/string_buffer.hpp"
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include "service/http/http_client.h"
#include "service/http/http_headers.h"
36
37
namespace doris {
38
#include "common/compile_check_begin.h"
39
40
struct AIResource {
41
16
    AIResource() = default;
42
    AIResource(const TAIResource& tai)
43
11
            : endpoint(tai.endpoint),
44
11
              provider_type(tai.provider_type),
45
11
              model_name(tai.model_name),
46
11
              api_key(tai.api_key),
47
11
              temperature(tai.temperature),
48
11
              max_tokens(tai.max_tokens),
49
11
              max_retries(tai.max_retries),
50
11
              retry_delay_second(tai.retry_delay_second),
51
11
              anthropic_version(tai.anthropic_version),
52
11
              dimensions(tai.dimensions) {}
53
54
    std::string endpoint;
55
    std::string provider_type;
56
    std::string model_name;
57
    std::string api_key;
58
    double temperature;
59
    int64_t max_tokens;
60
    int32_t max_retries;
61
    int32_t retry_delay_second;
62
    std::string anthropic_version;
63
    int32_t dimensions;
64
65
1
    void serialize(BufferWritable& buf) const {
66
1
        buf.write_binary(endpoint);
67
1
        buf.write_binary(provider_type);
68
1
        buf.write_binary(model_name);
69
1
        buf.write_binary(api_key);
70
1
        buf.write_binary(temperature);
71
1
        buf.write_binary(max_tokens);
72
1
        buf.write_binary(max_retries);
73
1
        buf.write_binary(retry_delay_second);
74
1
        buf.write_binary(anthropic_version);
75
1
        buf.write_binary(dimensions);
76
1
    }
77
78
1
    void deserialize(BufferReadable& buf) {
79
1
        buf.read_binary(endpoint);
80
1
        buf.read_binary(provider_type);
81
1
        buf.read_binary(model_name);
82
1
        buf.read_binary(api_key);
83
1
        buf.read_binary(temperature);
84
1
        buf.read_binary(max_tokens);
85
1
        buf.read_binary(max_retries);
86
1
        buf.read_binary(retry_delay_second);
87
1
        buf.read_binary(anthropic_version);
88
1
        buf.read_binary(dimensions);
89
1
    }
90
};
91
92
class AIAdapter {
93
public:
94
111
    virtual ~AIAdapter() = default;
95
96
    // Set authentication headers for the HTTP client
97
    virtual Status set_authentication(HttpClient* client) const = 0;
98
99
71
    virtual void init(const TAIResource& config) { _config = config; }
100
12
    virtual void init(const AIResource& config) {
101
12
        _config.endpoint = config.endpoint;
102
12
        _config.provider_type = config.provider_type;
103
12
        _config.model_name = config.model_name;
104
12
        _config.api_key = config.api_key;
105
12
        _config.temperature = config.temperature;
106
12
        _config.max_tokens = config.max_tokens;
107
12
        _config.max_retries = config.max_retries;
108
12
        _config.retry_delay_second = config.retry_delay_second;
109
12
        _config.anthropic_version = config.anthropic_version;
110
12
    }
111
112
    // Build request payload based on input text strings
113
    virtual Status build_request_payload(const std::vector<std::string>& inputs,
114
                                         const char* const system_prompt,
115
1
                                         std::string& request_body) const {
116
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
117
1
    }
118
119
    // Parse response from AI service and extract generated text results
120
    virtual Status parse_response(const std::string& response_body,
121
1
                                  std::vector<std::string>& results) const {
122
1
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
123
1
    }
124
125
    virtual Status build_embedding_request(const std::vector<std::string>& inputs,
126
0
                                           std::string& request_body) const {
127
0
        return Status::NotSupported("{} does not support the Embed feature.",
128
0
                                    _config.provider_type);
129
0
    }
130
131
    virtual Status parse_embedding_response(const std::string& response_body,
132
0
                                            std::vector<std::vector<float>>& results) const {
133
0
        return Status::NotSupported("{} does not support the Embed feature.",
134
0
                                    _config.provider_type);
135
0
    }
136
137
protected:
138
    TAIResource _config;
139
140
    // return true if the model support dimension parameter
141
1
    virtual bool supports_dimension_param(const std::string& model_name) const { return false; }
142
143
    // Different providers may have different dimension parameter names.
144
0
    virtual std::string get_dimension_param_name() const { return "dimensions"; }
145
146
    virtual void add_dimension_params(rapidjson::Value& doc,
147
11
                                      rapidjson::Document::AllocatorType& allocator) const {
148
11
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
149
5
            std::string param_name = get_dimension_param_name();
150
5
            rapidjson::Value name(param_name.c_str(), allocator);
151
5
            doc.AddMember(name, _config.dimensions, allocator);
152
5
        }
153
11
    }
154
};
155
156
// Most LLM-providers' Embedding formats are based on VoyageAI.
157
// The following adapters inherit from VoyageAIAdapter to directly reuse its embedding logic.
158
class VoyageAIAdapter : public AIAdapter {
public:
    // VoyageAI uses a Bearer token with a JSON body.
    Status set_authentication(HttpClient* client) const override {
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
        client->set_content_type("application/json");

        return Status::OK();
    }

    // Serializes `inputs` into VoyageAI's embedding request body.
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /*{
            "model": "xxx",
            "input": [
              "xxx",
              "xxx",
              ...
            ],
            "output_dimension": 512   <- optional, see add_dimension_params()
        }*/
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        add_dimension_params(doc, allocator);

        rapidjson::Value input(rapidjson::kArrayType);
        for (const auto& msg : inputs) {
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
        }
        doc.AddMember("input", input, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    // Extracts one float vector per "data" entry from the embedding response.
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (!doc.HasMember("data") || !doc["data"].IsArray()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        /*{
            "data":[
              {
                "object": "embedding",
                "embedding": [...], <- only need this
                "index": 0
              },
              {
                "object": "embedding",
                "embedding": [...],
                "index": 1
              }, ...
            ],
            "model"....
        }*/
        const auto& data = doc["data"];
        results.reserve(data.Size());
        for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
            if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
                return Status::InternalError("Invalid {} response format: {}",
                                             _config.provider_type, response_body);
            }

            std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
                           std::back_inserter(results.emplace_back()),
                           [](const auto& val) { return val.GetFloat(); });
        }

        return Status::OK();
    }

protected:
    // Older/fixed-dimension Voyage models reject a dimension parameter.
    bool supports_dimension_param(const std::string& model_name) const override {
        static const std::unordered_set<std::string> no_dimension_models = {
                "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2",
                "voyage-multimodal-3"};
        return !no_dimension_models.contains(model_name);
    }

    // VoyageAI spells the parameter "output_dimension" (singular).
    std::string get_dimension_param_name() const override { return "output_dimension"; }
};
254
255
// Local AI adapter for locally hosted models (Ollama, LLaMA, etc.)
256
class LocalAdapter : public AIAdapter {
257
public:
258
    // Local deployments typically don't need authentication
259
2
    Status set_authentication(HttpClient* client) const override {
260
2
        client->set_content_type("application/json");
261
2
        return Status::OK();
262
2
    }
263
264
    Status build_request_payload(const std::vector<std::string>& inputs,
265
                                 const char* const system_prompt,
266
3
                                 std::string& request_body) const override {
267
3
        rapidjson::Document doc;
268
3
        doc.SetObject();
269
3
        auto& allocator = doc.GetAllocator();
270
271
3
        std::string end_point = _config.endpoint;
272
3
        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
273
2
            RETURN_IF_ERROR(
274
2
                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
275
2
        } else {
276
1
            RETURN_IF_ERROR(
277
1
                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
278
1
        }
279
280
3
        rapidjson::StringBuffer buffer;
281
3
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
282
3
        doc.Accept(writer);
283
3
        request_body = buffer.GetString();
284
285
3
        return Status::OK();
286
3
    }
287
288
    Status parse_response(const std::string& response_body,
289
7
                          std::vector<std::string>& results) const override {
290
7
        rapidjson::Document doc;
291
7
        doc.Parse(response_body.c_str());
292
293
7
        if (doc.HasParseError() || !doc.IsObject()) {
294
1
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
295
1
                                         response_body);
296
1
        }
297
298
        // Handle various response formats from local LLMs
299
        // Format 1: OpenAI-compatible format with choices/message/content
300
6
        if (doc.HasMember("choices") && doc["choices"].IsArray()) {
301
1
            const auto& choices = doc["choices"];
302
1
            results.reserve(choices.Size());
303
304
2
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
305
1
                if (choices[i].HasMember("message") && choices[i]["message"].HasMember("content") &&
306
1
                    choices[i]["message"]["content"].IsString()) {
307
1
                    results.emplace_back(choices[i]["message"]["content"].GetString());
308
1
                } else if (choices[i].HasMember("text") && choices[i]["text"].IsString()) {
309
                    // Some local LLMs use a simpler format
310
0
                    results.emplace_back(choices[i]["text"].GetString());
311
0
                }
312
1
            }
313
5
        } else if (doc.HasMember("text") && doc["text"].IsString()) {
314
            // Format 2: Simple response with just "text" or "content" field
315
1
            results.emplace_back(doc["text"].GetString());
316
4
        } else if (doc.HasMember("content") && doc["content"].IsString()) {
317
1
            results.emplace_back(doc["content"].GetString());
318
3
        } else if (doc.HasMember("response") && doc["response"].IsString()) {
319
            // Format 3: Response field (Ollama `generate` format)
320
1
            results.emplace_back(doc["response"].GetString());
321
2
        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
322
2
                   doc["message"].HasMember("content") && doc["message"]["content"].IsString()) {
323
            // Format 4: message/content field (Ollama `chat` format)
324
1
            results.emplace_back(doc["message"]["content"].GetString());
325
1
        } else {
326
1
            return Status::NotSupported("Unsupported response format from local AI.");
327
1
        }
328
329
5
        return Status::OK();
330
6
    }
331
332
    Status build_embedding_request(const std::vector<std::string>& inputs,
333
1
                                   std::string& request_body) const override {
334
1
        rapidjson::Document doc;
335
1
        doc.SetObject();
336
1
        auto& allocator = doc.GetAllocator();
337
338
1
        if (!_config.model_name.empty()) {
339
1
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
340
1
                          allocator);
341
1
        }
342
343
1
        add_dimension_params(doc, allocator);
344
345
1
        rapidjson::Value input(rapidjson::kArrayType);
346
1
        for (const auto& msg : inputs) {
347
1
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
348
1
        }
349
1
        doc.AddMember("input", input, allocator);
350
351
1
        rapidjson::StringBuffer buffer;
352
1
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
353
1
        doc.Accept(writer);
354
1
        request_body = buffer.GetString();
355
356
1
        return Status::OK();
357
1
    }
358
359
    Status parse_embedding_response(const std::string& response_body,
360
3
                                    std::vector<std::vector<float>>& results) const override {
361
3
        rapidjson::Document doc;
362
3
        doc.Parse(response_body.c_str());
363
364
3
        if (doc.HasParseError() || !doc.IsObject()) {
365
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
366
0
                                         response_body);
367
0
        }
368
369
        // parse different response format
370
3
        rapidjson::Value embedding;
371
3
        if (doc.HasMember("data") && doc["data"].IsArray()) {
372
            // "data":["object":"embedding", "embedding":[0.1, 0.2...], "index":0]
373
1
            const auto& data = doc["data"];
374
1
            results.reserve(data.Size());
375
3
            for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
376
2
                if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
377
0
                    return Status::InternalError("Invalid {} response format",
378
0
                                                 _config.provider_type);
379
0
                }
380
381
2
                std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
382
2
                               std::back_inserter(results.emplace_back()),
383
5
                               [](const auto& val) { return val.GetFloat(); });
384
2
            }
385
2
        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
386
            // "embeddings":[[0.1, 0.2, ...]]
387
1
            results.reserve(1);
388
2
            for (int i = 0; i < doc["embeddings"].Size(); i++) {
389
1
                embedding = doc["embeddings"][i];
390
1
                std::transform(embedding.Begin(), embedding.End(),
391
1
                               std::back_inserter(results.emplace_back()),
392
2
                               [](const auto& val) { return val.GetFloat(); });
393
1
            }
394
1
        } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
395
            // "embedding":[0.1, 0.2, ...]
396
1
            results.reserve(1);
397
1
            embedding = doc["embedding"];
398
1
            std::transform(embedding.Begin(), embedding.End(),
399
1
                           std::back_inserter(results.emplace_back()),
400
3
                           [](const auto& val) { return val.GetFloat(); });
401
1
        } else {
402
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
403
0
                                         response_body);
404
0
        }
405
406
3
        return Status::OK();
407
3
    }
408
409
private:
410
    Status build_ollama_request(rapidjson::Document& doc,
411
                                rapidjson::Document::AllocatorType& allocator,
412
                                const std::vector<std::string>& inputs,
413
2
                                const char* const system_prompt, std::string& request_body) const {
414
        /*
415
        for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
416
        {
417
            "model": <model_name>,
418
            "stream": false,
419
            "think": false,
420
            "options": {
421
                "temperature": <temperature>,
422
                "max_token": <max_token>
423
            },
424
            "messages": [
425
                {"role": "system", "content": <system_prompt>},
426
                {"role": "user", "content": <user_prompt>}
427
            ]
428
        }
429
        
430
        for endpoints end_with `/generate` like 'http://localhost:11434/api/generate':
431
        {
432
            "model": <model_name>,
433
            "stream": false,
434
            "think": false
435
            "options": {
436
                "temperature": <temperature>,
437
                "max_token": <max_token>
438
            },
439
            "system": <system_prompt>,
440
            "prompt": <user_prompt>
441
        }
442
        */
443
444
        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint;
445
        // The rest remains identical.
446
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
447
2
        doc.AddMember("stream", false, allocator);
448
2
        doc.AddMember("think", false, allocator);
449
450
        // option section
451
2
        rapidjson::Value options(rapidjson::kObjectType);
452
2
        if (_config.temperature != -1) {
453
2
            options.AddMember("temperature", _config.temperature, allocator);
454
2
        }
455
2
        if (_config.max_tokens != -1) {
456
2
            options.AddMember("max_token", _config.max_tokens, allocator);
457
2
        }
458
2
        doc.AddMember("options", options, allocator);
459
460
        // prompt section
461
2
        if (_config.endpoint.ends_with("chat")) {
462
1
            rapidjson::Value messages(rapidjson::kArrayType);
463
1
            if (system_prompt && *system_prompt) {
464
1
                rapidjson::Value sys_msg(rapidjson::kObjectType);
465
1
                sys_msg.AddMember("role", "system", allocator);
466
1
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
467
1
                messages.PushBack(sys_msg, allocator);
468
1
            }
469
1
            for (const auto& input : inputs) {
470
1
                rapidjson::Value message(rapidjson::kObjectType);
471
1
                message.AddMember("role", "user", allocator);
472
1
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
473
1
                messages.PushBack(message, allocator);
474
1
            }
475
1
            doc.AddMember("messages", messages, allocator);
476
1
        } else {
477
1
            if (system_prompt && *system_prompt) {
478
1
                doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
479
1
            }
480
1
            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator);
481
1
        }
482
483
2
        return Status::OK();
484
2
    }
485
486
    Status build_default_request(rapidjson::Document& doc,
487
                                 rapidjson::Document::AllocatorType& allocator,
488
                                 const std::vector<std::string>& inputs,
489
1
                                 const char* const system_prompt, std::string& request_body) const {
490
        /*
491
        Default format(OpenAI-compatible):
492
        {
493
            "model": <model_name>,
494
            "temperature": <temperature>,
495
            "max_tokens": <max_tokens>,
496
            "messages": [
497
                {"role": "system", "content": <system_prompt>},
498
                {"role": "user", "content": <user_prompt>}
499
            ]
500
        }
501
        */
502
503
1
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
504
505
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
506
1
        if (_config.temperature != -1) {
507
1
            doc.AddMember("temperature", _config.temperature, allocator);
508
1
        }
509
1
        if (_config.max_tokens != -1) {
510
1
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
511
1
        }
512
513
1
        rapidjson::Value messages(rapidjson::kArrayType);
514
1
        if (system_prompt && *system_prompt) {
515
1
            rapidjson::Value sys_msg(rapidjson::kObjectType);
516
1
            sys_msg.AddMember("role", "system", allocator);
517
1
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
518
1
            messages.PushBack(sys_msg, allocator);
519
1
        }
520
1
        for (const auto& input : inputs) {
521
1
            rapidjson::Value message(rapidjson::kObjectType);
522
1
            message.AddMember("role", "user", allocator);
523
1
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
524
1
            messages.PushBack(message, allocator);
525
1
        }
526
1
        doc.AddMember("messages", messages, allocator);
527
1
        return Status::OK();
528
1
    }
529
};
530
531
// The OpenAI API format can be reused with some compatible AIs.
532
class OpenAIAdapter : public VoyageAIAdapter {
public:
    // OpenAI-compatible auth: Bearer token + JSON body.
    Status set_authentication(HttpClient* client) const override {
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
        client->set_content_type("application/json");

        return Status::OK();
    }

    // Builds the chat payload; the Responses API format is used when the
    // endpoint ends with "responses", otherwise Chat Completions format.
    Status build_request_payload(const std::vector<std::string>& inputs,
                                 const char* const system_prompt,
                                 std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        if (_config.endpoint.ends_with("responses")) {
            /*{
              "model": "gpt-4.1-mini",
              "input": [
                {"role": "system", "content": "system_prompt here"},
                {"role": "user", "content": "xxx"}
              ],
              "temperature": 0.7,
              "max_output_tokens": 150
            }*/
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
                          allocator);

            // If 'temperature' and 'max_tokens' are set, add them to the request body.
            if (_config.temperature != -1) {
                doc.AddMember("temperature", _config.temperature, allocator);
            }
            if (_config.max_tokens != -1) {
                doc.AddMember("max_output_tokens", _config.max_tokens, allocator);
            }

            // input
            rapidjson::Value input(rapidjson::kArrayType);
            if (system_prompt && *system_prompt) {
                rapidjson::Value sys_msg(rapidjson::kObjectType);
                sys_msg.AddMember("role", "system", allocator);
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
                input.PushBack(sys_msg, allocator);
            }
            for (const auto& msg : inputs) {
                rapidjson::Value message(rapidjson::kObjectType);
                message.AddMember("role", "user", allocator);
                message.AddMember("content", rapidjson::Value(msg.c_str(), allocator), allocator);
                input.PushBack(message, allocator);
            }
            doc.AddMember("input", input, allocator);
        } else {
            /*{
              "model": "gpt-4",
              "messages": [
                {"role": "system", "content": "system_prompt here"},
                {"role": "user", "content": "xxx"}
              ],
              "temperature": x,
              "max_tokens": x,
            }*/
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
                          allocator);

            // If 'temperature' and 'max_tokens' are set, add them to the request body.
            if (_config.temperature != -1) {
                doc.AddMember("temperature", _config.temperature, allocator);
            }
            if (_config.max_tokens != -1) {
                doc.AddMember("max_tokens", _config.max_tokens, allocator);
            }

            rapidjson::Value messages(rapidjson::kArrayType);
            if (system_prompt && *system_prompt) {
                rapidjson::Value sys_msg(rapidjson::kObjectType);
                sys_msg.AddMember("role", "system", allocator);
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
                messages.PushBack(sys_msg, allocator);
            }
            for (const auto& input : inputs) {
                rapidjson::Value message(rapidjson::kObjectType);
                message.AddMember("role", "user", allocator);
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
                messages.PushBack(message, allocator);
            }
            doc.AddMember("messages", messages, allocator);
        }

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    // Extracts generated text from either the Responses API shape
    // ("output" array) or the Chat Completions shape ("choices" array).
    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }

        if (doc.HasMember("output") && doc["output"].IsArray()) {
            /// for responses endpoint
            /*{
              "output": [
                {
                  "id": "msg_123",
                  "type": "message",
                  "role": "assistant",
                  "content": [
                    {
                      "type": "text",
                      "text": "result text here"   <- result
                    }
                  ]
                }
              ]
            }*/
            const auto& output = doc["output"];
            results.reserve(output.Size());

            for (rapidjson::SizeType i = 0; i < output.Size(); i++) {
                if (!output[i].HasMember("content") || !output[i]["content"].IsArray() ||
                    output[i]["content"].Empty() || !output[i]["content"][0].HasMember("text") ||
                    !output[i]["content"][0]["text"].IsString()) {
                    return Status::InternalError("Invalid output format in {} response: {}",
                                                 _config.provider_type, response_body);
                }

                results.emplace_back(output[i]["content"][0]["text"].GetString());
            }
        } else if (doc.HasMember("choices") && doc["choices"].IsArray()) {
            /// for completions endpoint
            /*{
              "object": "chat.completion",
              "model": "gpt-4",
              "choices": [
                {
                  ...
                  "message": {
                    "role": "assistant",
                    "content": "xxx"      <- result
                  },
                  ...
                }
              ],
              ...
            }*/
            const auto& choices = doc["choices"];
            results.reserve(choices.Size());

            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
                if (!choices[i].HasMember("message") ||
                    !choices[i]["message"].HasMember("content") ||
                    !choices[i]["message"]["content"].IsString()) {
                    return Status::InternalError("Invalid choice format in {} response: {}",
                                                 _config.provider_type, response_body);
                }

                results.emplace_back(choices[i]["message"]["content"].GetString());
            }
        } else {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        return Status::OK();
    }

protected:
    // text-embedding-ada-002 is the only OpenAI embedding model with a fixed
    // output dimension.
    bool supports_dimension_param(const std::string& model_name) const override {
        return !(model_name == "text-embedding-ada-002");
    }

    std::string get_dimension_param_name() const override { return "dimensions"; }
};
714
715
class DeepSeekAdapter : public OpenAIAdapter {
716
public:
717
    Status build_embedding_request(const std::vector<std::string>& inputs,
718
1
                                   std::string& request_body) const override {
719
1
        return Status::NotSupported("{} does not support the Embed feature.",
720
1
                                    _config.provider_type);
721
1
    }
722
723
    Status parse_embedding_response(const std::string& response_body,
724
1
                                    std::vector<std::vector<float>>& results) const override {
725
1
        return Status::NotSupported("{} does not support the Embed feature.",
726
1
                                    _config.provider_type);
727
1
    }
728
};
729
730
class MoonShotAdapter : public OpenAIAdapter {
731
public:
732
    Status build_embedding_request(const std::vector<std::string>& inputs,
733
1
                                   std::string& request_body) const override {
734
1
        return Status::NotSupported("{} does not support the Embed feature.",
735
1
                                    _config.provider_type);
736
1
    }
737
738
    Status parse_embedding_response(const std::string& response_body,
739
1
                                    std::vector<std::vector<float>>& results) const override {
740
1
        return Status::NotSupported("{} does not support the Embed feature.",
741
1
                                    _config.provider_type);
742
1
    }
743
};
744
745
class MinimaxAdapter : public OpenAIAdapter {
public:
    /// Builds a MiniMax embedding request. MiniMax uses its own body layout
    /// ("texts" array + "type") rather than the OpenAI "input" field.
    ///
    /// Emitted shape (member order: model, texts, type):
    /// {
    ///   "model": "embo-1",          // NOTE(review): confirm model id spelling against MiniMax docs
    ///   "texts": ["xxx", "xxx", ...],
    ///   "type": "db"                // "db" = embeddings intended for storage/retrieval
    /// }
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        // Copy every input string into the "texts" array (allocator-owned copies).
        rapidjson::Value texts(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator);
        }
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        doc.AddMember("texts", texts, allocator);
        doc.AddMember("type", rapidjson::Value("db", allocator), allocator);

        // Serialize the DOM into the outgoing request body.
        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }
};
774
775
class ZhipuAdapter : public OpenAIAdapter {
776
protected:
777
2
    bool supports_dimension_param(const std::string& model_name) const override {
778
2
        return !(model_name == "embedding-2");
779
2
    }
780
};
781
782
class QwenAdapter : public OpenAIAdapter {
783
protected:
784
2
    bool supports_dimension_param(const std::string& model_name) const override {
785
2
        static const std::unordered_set<std::string> no_dimension_models = {
786
2
                "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"};
787
2
        return !no_dimension_models.contains(model_name);
788
2
    }
789
790
1
    std::string get_dimension_param_name() const override { return "dimension"; }
791
};
792
793
class BaichuanAdapter : public OpenAIAdapter {
794
protected:
795
0
    bool supports_dimension_param(const std::string& model_name) const override { return false; }
796
};
797
798
// Gemini's embedding format is different from VoyageAI, so it requires a separate adapter
799
class GeminiAdapter : public AIAdapter {
public:
    /// Gemini authenticates with a dedicated API-key header rather than an
    /// Authorization bearer token.
    Status set_authentication(HttpClient* client) const override {
        client->set_header("x-goog-api-key", _config.api_key);
        client->set_content_type("application/json");
        return Status::OK();
    }

    /// Builds a generateContent request. Each input becomes one entry of the
    /// "contents" array; a non-empty system prompt is attached as
    /// "systemInstruction". temperature / max_tokens equal to -1 mean "unset".
    ///
    /// Emitted shape:
    /// {
    ///   "systemInstruction": { "parts": [ { "text": "system_prompt here" } ] },
    ///   "contents": [ { "parts": [ { "text": "xxx" } ] } ],
    ///   "generationConfig": { "temperature": 0.7, "maxOutputTokens": 1024 }
    /// }
    Status build_request_payload(const std::vector<std::string>& inputs,
                                 const char* const system_prompt,
                                 std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        // Only attach a system instruction when the prompt is non-null and non-empty.
        if (system_prompt && *system_prompt) {
            rapidjson::Value system_instruction(rapidjson::kObjectType);
            rapidjson::Value parts(rapidjson::kArrayType);

            rapidjson::Value part(rapidjson::kObjectType);
            part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator);
            parts.PushBack(part, allocator);
            system_instruction.AddMember("parts", parts, allocator);
            doc.AddMember("systemInstruction", system_instruction, allocator);
        }

        // One {"parts":[{"text":...}]} object per input string.
        rapidjson::Value contents(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            rapidjson::Value content(rapidjson::kObjectType);
            rapidjson::Value parts(rapidjson::kArrayType);

            rapidjson::Value part(rapidjson::kObjectType);
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);

            parts.PushBack(part, allocator);
            content.AddMember("parts", parts, allocator);
            contents.PushBack(content, allocator);
        }
        doc.AddMember("contents", contents, allocator);

        // If 'temperature' and 'max_tokens' are set, add them to the request body.
        // Gemini calls the token cap "maxOutputTokens".
        rapidjson::Value generationConfig(rapidjson::kObjectType);
        if (_config.temperature != -1) {
            generationConfig.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator);
        }
        doc.AddMember("generationConfig", generationConfig, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    /// Extracts candidate texts from a generateContent response:
    /// { "candidates": [ { "content": { "parts": [ { "text": "xxx" } ] } } ] }
    /// Only parts[0].text of each candidate is read.
    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        const auto& candidates = doc["candidates"];
        results.reserve(candidates.Size());

        for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) {
            // Validate the full path candidates[i].content.parts[0].text before use.
            if (!candidates[i].HasMember("content") ||
                !candidates[i]["content"].HasMember("parts") ||
                !candidates[i]["content"]["parts"].IsArray() ||
                candidates[i]["content"]["parts"].Empty() ||
                !candidates[i]["content"]["parts"][0].HasMember("text") ||
                !candidates[i]["content"]["parts"][0]["text"].IsString()) {
                return Status::InternalError("Invalid candidate format in {} response",
                                             _config.provider_type);
            }

            results.emplace_back(candidates[i]["content"]["parts"][0]["text"].GetString());
        }

        return Status::OK();
    }

    /// Builds an embedContent request:
    /// {
    ///   "model": "models/gemini-embedding-001",
    ///   "content": { "parts": [ { "text": "xxx" } ] },
    ///   "outputDimensionality": 1024
    /// }
    ///
    /// NOTE(review): when inputs.size() > 1 the loop below AddMember()s a
    /// "parts" key into the same "content" object once per input, producing
    /// duplicate keys in the serialized JSON. Looks like callers only ever pass
    /// a single input here — confirm against the call sites.
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        // gemini requires the model format as `models/{model}`
        std::string model_name = _config.model_name;
        if (!model_name.starts_with("models/")) {
            model_name = "models/" + model_name;
        }
        doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
        add_dimension_params(doc, allocator);

        rapidjson::Value content(rapidjson::kObjectType);
        for (const auto& input : inputs) {
            rapidjson::Value parts(rapidjson::kArrayType);
            rapidjson::Value part(rapidjson::kObjectType);
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
            parts.PushBack(part, allocator);
            content.AddMember("parts", parts, allocator);
        }
        doc.AddMember("content", content, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    /// Parses a single-embedding embedContent response:
    /// { "embedding": { "values": [0.1, 0.2, 0.3] } }
    /// Exactly one row is appended to `results` (matches the single-content
    /// request built above).
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        const auto& embedding = doc["embedding"];
        if (!embedding.HasMember("values") || !embedding["values"].IsArray()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }
        // Copy the float vector into a freshly appended result row.
        std::transform(embedding["values"].Begin(), embedding["values"].End(),
                       std::back_inserter(results.emplace_back()),
                       [](const auto& val) { return val.GetFloat(); });

        return Status::OK();
    }

protected:
    /// embedding-001 (with or without the "models/" prefix) does not accept an
    /// output-dimension parameter; newer models do.
    bool supports_dimension_param(const std::string& model_name) const override {
        static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001",
                                                                            "embedding-001"};
        return !no_dimension_models.contains(model_name);
    }

    /// Gemini names the dimension field "outputDimensionality".
    std::string get_dimension_param_name() const override { return "outputDimensionality"; }
};
1013
1014
class AnthropicAdapter : public VoyageAIAdapter {
public:
    /// Anthropic authenticates via "x-api-key" plus a mandatory
    /// "anthropic-version" header taken from the resource config.
    Status set_authentication(HttpClient* client) const override {
        client->set_header("x-api-key", _config.api_key);
        client->set_header("anthropic-version", _config.anthropic_version);
        client->set_content_type("application/json");

        return Status::OK();
    }

    /// Builds a Messages API request body. temperature / max_tokens equal to -1
    /// mean "unset"; max_tokens is required by Anthropic, so a default of 2048
    /// is used when unset.
    ///
    /// Emitted shape:
    /// {
    ///   "model": "claude-opus-4-1-20250805",
    ///   "temperature": 0.7,
    ///   "max_tokens": 1024,
    ///   "system": "system_prompt here",
    ///   "messages": [ {"role": "user", "content": "xxx"} ]
    /// }
    Status build_request_payload(const std::vector<std::string>& inputs,
                                 const char* const system_prompt,
                                 std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        // If 'temperature' and 'max_tokens' are set, add them to the request body.
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        if (_config.temperature != -1) {
            doc.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
        } else {
            // Keep the default value, Anthropic requires this parameter
            doc.AddMember("max_tokens", 2048, allocator);
        }
        // Anthropic takes the system prompt as a top-level "system" string,
        // not as a message with role "system".
        if (system_prompt && *system_prompt) {
            doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
        }

        // Every input becomes a user message.
        rapidjson::Value messages(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            rapidjson::Value message(rapidjson::kObjectType);
            message.AddMember("role", "user", allocator);
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
            messages.PushBack(message, allocator);
        }
        doc.AddMember("messages", messages, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    /// Parses a Messages API response:
    /// { "content": [ { "text": "xxx", "type": "text" } ] }
    /// All "text"-typed blocks are concatenated (newline-separated) into a
    /// single result; non-text or malformed blocks are silently skipped.
    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());
        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }
        if (!doc.HasMember("content") || !doc["content"].IsArray()) {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        const auto& content = doc["content"];
        results.reserve(1);

        std::string result;
        for (rapidjson::SizeType i = 0; i < content.Size(); i++) {
            // Skip blocks missing a string "type"/"text" rather than erroring:
            // responses may interleave non-text blocks.
            if (!content[i].HasMember("type") || !content[i]["type"].IsString() ||
                !content[i].HasMember("text") || !content[i]["text"].IsString()) {
                continue;
            }

            if (std::string(content[i]["type"].GetString()) == "text") {
                if (!result.empty()) {
                    result += "\n";
                }
                result += content[i]["text"].GetString();
            }
        }

        results.emplace_back(std::move(result));
        return Status::OK();
    }
};
1116
1117
// Mock adapter used only for UT to bypass real HTTP calls and return deterministic data.
1118
class MockAdapter : public AIAdapter {
1119
public:
1120
0
    Status set_authentication(HttpClient* client) const override { return Status::OK(); }
1121
1122
    Status build_request_payload(const std::vector<std::string>& inputs,
1123
                                 const char* const system_prompt,
1124
49
                                 std::string& request_body) const override {
1125
49
        return Status::OK();
1126
49
    }
1127
1128
    Status parse_response(const std::string& response_body,
1129
49
                          std::vector<std::string>& results) const override {
1130
49
        results.emplace_back(response_body);
1131
49
        return Status::OK();
1132
49
    }
1133
1134
    Status build_embedding_request(const std::vector<std::string>& inputs,
1135
1
                                   std::string& request_body) const override {
1136
1
        return Status::OK();
1137
1
    }
1138
1139
    Status parse_embedding_response(const std::string& response_body,
1140
1
                                    std::vector<std::vector<float>>& results) const override {
1141
1
        rapidjson::Document doc;
1142
1
        doc.SetObject();
1143
1
        doc.Parse(response_body.c_str());
1144
1
        if (doc.HasParseError() || !doc.IsObject()) {
1145
0
            return Status::InternalError("Failed to parse embedding response");
1146
0
        }
1147
1
        if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) {
1148
0
            return Status::InternalError("Invalid embedding response format");
1149
0
        }
1150
1151
1
        results.reserve(1);
1152
1
        std::transform(doc["embedding"].Begin(), doc["embedding"].End(),
1153
1
                       std::back_inserter(results.emplace_back()),
1154
5
                       [](const auto& val) { return val.GetFloat(); });
1155
1
        return Status::OK();
1156
1
    }
1157
};
1158
1159
class AIAdapterFactory {
1160
public:
1161
74
    static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) {
1162
74
        static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>>
1163
74
                adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }},
1164
74
                            {"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }},
1165
74
                            {"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }},
1166
74
                            {"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }},
1167
74
                            {"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }},
1168
74
                            {"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }},
1169
74
                            {"QWEN", []() { return std::make_shared<QwenAdapter>(); }},
1170
74
                            {"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }},
1171
74
                            {"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }},
1172
74
                            {"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }},
1173
74
                            {"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }},
1174
74
                            {"MOCK", []() { return std::make_shared<MockAdapter>(); }}};
1175
1176
74
        auto it = adapters.find(provider_type);
1177
74
        return (it != adapters.end()) ? it->second() : nullptr;
1178
74
    }
1179
};
1180
1181
#include "common/compile_check_end.h"
1182
} // namespace doris