Coverage Report

Created: 2026-03-19 18:23

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/ai_adapter.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once

#include <gen_cpp/PaloInternalService_types.h>
#include <rapidjson/rapidjson.h>

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/status.h"
#include "core/string_buffer.hpp"
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include "service/http/http_client.h"
#include "service/http/http_headers.h"
36
37
namespace doris {
38
#include "common/compile_check_begin.h"
39
40
struct AIResource {
41
32
    AIResource() = default;
42
    AIResource(const TAIResource& tai)
43
22
            : endpoint(tai.endpoint),
44
22
              provider_type(tai.provider_type),
45
22
              model_name(tai.model_name),
46
22
              api_key(tai.api_key),
47
22
              temperature(tai.temperature),
48
22
              max_tokens(tai.max_tokens),
49
22
              max_retries(tai.max_retries),
50
22
              retry_delay_second(tai.retry_delay_second),
51
22
              anthropic_version(tai.anthropic_version),
52
22
              dimensions(tai.dimensions) {}
53
54
    std::string endpoint;
55
    std::string provider_type;
56
    std::string model_name;
57
    std::string api_key;
58
    double temperature;
59
    int64_t max_tokens;
60
    int32_t max_retries;
61
    int32_t retry_delay_second;
62
    std::string anthropic_version;
63
    int32_t dimensions;
64
65
2
    void serialize(BufferWritable& buf) const {
66
2
        buf.write_binary(endpoint);
67
2
        buf.write_binary(provider_type);
68
2
        buf.write_binary(model_name);
69
2
        buf.write_binary(api_key);
70
2
        buf.write_binary(temperature);
71
2
        buf.write_binary(max_tokens);
72
2
        buf.write_binary(max_retries);
73
2
        buf.write_binary(retry_delay_second);
74
2
        buf.write_binary(anthropic_version);
75
2
        buf.write_binary(dimensions);
76
2
    }
77
78
2
    void deserialize(BufferReadable& buf) {
79
2
        buf.read_binary(endpoint);
80
2
        buf.read_binary(provider_type);
81
2
        buf.read_binary(model_name);
82
2
        buf.read_binary(api_key);
83
2
        buf.read_binary(temperature);
84
2
        buf.read_binary(max_tokens);
85
2
        buf.read_binary(max_retries);
86
2
        buf.read_binary(retry_delay_second);
87
2
        buf.read_binary(anthropic_version);
88
2
        buf.read_binary(dimensions);
89
2
    }
90
};
91
92
class AIAdapter {
93
public:
94
222
    virtual ~AIAdapter() = default;
95
96
    // Set authentication headers for the HTTP client
97
    virtual Status set_authentication(HttpClient* client) const = 0;
98
99
142
    virtual void init(const TAIResource& config) { _config = config; }
100
24
    virtual void init(const AIResource& config) {
101
24
        _config.endpoint = config.endpoint;
102
24
        _config.provider_type = config.provider_type;
103
24
        _config.model_name = config.model_name;
104
24
        _config.api_key = config.api_key;
105
24
        _config.temperature = config.temperature;
106
24
        _config.max_tokens = config.max_tokens;
107
24
        _config.max_retries = config.max_retries;
108
24
        _config.retry_delay_second = config.retry_delay_second;
109
24
        _config.anthropic_version = config.anthropic_version;
110
24
    }
111
112
    // Build request payload based on input text strings
113
    virtual Status build_request_payload(const std::vector<std::string>& inputs,
114
                                         const char* const system_prompt,
115
2
                                         std::string& request_body) const {
116
2
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
117
2
    }
118
119
    // Parse response from AI service and extract generated text results
120
    virtual Status parse_response(const std::string& response_body,
121
2
                                  std::vector<std::string>& results) const {
122
2
        return Status::NotSupported("{} don't support text generation", _config.provider_type);
123
2
    }
124
125
    virtual Status build_embedding_request(const std::vector<std::string>& inputs,
126
0
                                           std::string& request_body) const {
127
0
        return Status::NotSupported("{} does not support the Embed feature.",
128
0
                                    _config.provider_type);
129
0
    }
130
131
    virtual Status parse_embedding_response(const std::string& response_body,
132
0
                                            std::vector<std::vector<float>>& results) const {
133
0
        return Status::NotSupported("{} does not support the Embed feature.",
134
0
                                    _config.provider_type);
135
0
    }
136
137
protected:
138
    TAIResource _config;
139
140
    // return true if the model support dimension parameter
141
2
    virtual bool supports_dimension_param(const std::string& model_name) const { return false; }
142
143
    // Different providers may have different dimension parameter names.
144
0
    virtual std::string get_dimension_param_name() const { return "dimensions"; }
145
146
    virtual void add_dimension_params(rapidjson::Value& doc,
147
22
                                      rapidjson::Document::AllocatorType& allocator) const {
148
22
        if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) {
149
10
            std::string param_name = get_dimension_param_name();
150
10
            rapidjson::Value name(param_name.c_str(), allocator);
151
10
            doc.AddMember(name, _config.dimensions, allocator);
152
10
        }
153
22
    }
154
};
155
156
// Most LLM-providers' Embedding formats are based on VoyageAI.
157
// The following adapters inherit from VoyageAIAdapter to directly reuse its embedding logic.
158
class VoyageAIAdapter : public AIAdapter {
159
public:
160
4
    Status set_authentication(HttpClient* client) const override {
161
4
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
162
4
        client->set_content_type("application/json");
163
164
4
        return Status::OK();
165
4
    }
166
167
    Status build_embedding_request(const std::vector<std::string>& inputs,
168
16
                                   std::string& request_body) const override {
169
16
        rapidjson::Document doc;
170
16
        doc.SetObject();
171
16
        auto& allocator = doc.GetAllocator();
172
173
        /*{
174
            "model": "xxx",
175
            "input": [
176
              "xxx",
177
              "xxx",
178
              ...
179
            ],
180
            "output_dimensions": 512
181
        }*/
182
16
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
183
16
        add_dimension_params(doc, allocator);
184
185
16
        rapidjson::Value input(rapidjson::kArrayType);
186
16
        for (const auto& msg : inputs) {
187
16
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
188
16
        }
189
16
        doc.AddMember("input", input, allocator);
190
191
16
        rapidjson::StringBuffer buffer;
192
16
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
193
16
        doc.Accept(writer);
194
16
        request_body = buffer.GetString();
195
196
16
        return Status::OK();
197
16
    }
198
199
    Status parse_embedding_response(const std::string& response_body,
200
10
                                    std::vector<std::vector<float>>& results) const override {
201
10
        rapidjson::Document doc;
202
10
        doc.Parse(response_body.c_str());
203
204
10
        if (doc.HasParseError() || !doc.IsObject()) {
205
2
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
206
2
                                         response_body);
207
2
        }
208
8
        if (!doc.HasMember("data") || !doc["data"].IsArray()) {
209
2
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
210
2
                                         response_body);
211
2
        }
212
213
        /*{
214
            "data":[
215
              {
216
                "object": "embedding",
217
                "embedding": [...], <- only need this
218
                "index": 0
219
              },
220
              {
221
                "object": "embedding",
222
                "embedding": [...],
223
                "index": 1
224
              }, ...
225
            ],
226
            "model"....
227
        }*/
228
6
        const auto& data = doc["data"];
229
6
        results.reserve(data.Size());
230
14
        for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
231
10
            if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
232
2
                return Status::InternalError("Invalid {} response format: {}",
233
2
                                             _config.provider_type, response_body);
234
2
            }
235
236
8
            std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
237
8
                           std::back_inserter(results.emplace_back()),
238
20
                           [](const auto& val) { return val.GetFloat(); });
239
8
        }
240
241
4
        return Status::OK();
242
6
    }
243
244
protected:
245
4
    bool supports_dimension_param(const std::string& model_name) const override {
246
4
        static const std::unordered_set<std::string> no_dimension_models = {
247
4
                "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2",
248
4
                "voyage-multimodal-3"};
249
4
        return !no_dimension_models.contains(model_name);
250
4
    }
251
252
2
    std::string get_dimension_param_name() const override { return "output_dimension"; }
253
};
254
255
// Local AI adapter for locally hosted models (Ollama, LLaMA, etc.)
256
class LocalAdapter : public AIAdapter {
257
public:
258
    // Local deployments typically don't need authentication
259
4
    Status set_authentication(HttpClient* client) const override {
260
4
        client->set_content_type("application/json");
261
4
        return Status::OK();
262
4
    }
263
264
    Status build_request_payload(const std::vector<std::string>& inputs,
265
                                 const char* const system_prompt,
266
6
                                 std::string& request_body) const override {
267
6
        rapidjson::Document doc;
268
6
        doc.SetObject();
269
6
        auto& allocator = doc.GetAllocator();
270
271
6
        std::string end_point = _config.endpoint;
272
6
        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
273
4
            RETURN_IF_ERROR(
274
4
                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
275
4
        } else {
276
2
            RETURN_IF_ERROR(
277
2
                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
278
2
        }
279
280
6
        rapidjson::StringBuffer buffer;
281
6
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
282
6
        doc.Accept(writer);
283
6
        request_body = buffer.GetString();
284
285
6
        return Status::OK();
286
6
    }
287
288
    Status parse_response(const std::string& response_body,
289
14
                          std::vector<std::string>& results) const override {
290
14
        rapidjson::Document doc;
291
14
        doc.Parse(response_body.c_str());
292
293
14
        if (doc.HasParseError() || !doc.IsObject()) {
294
2
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
295
2
                                         response_body);
296
2
        }
297
298
        // Handle various response formats from local LLMs
299
        // Format 1: OpenAI-compatible format with choices/message/content
300
12
        if (doc.HasMember("choices") && doc["choices"].IsArray()) {
301
2
            const auto& choices = doc["choices"];
302
2
            results.reserve(choices.Size());
303
304
4
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
305
2
                if (choices[i].HasMember("message") && choices[i]["message"].HasMember("content") &&
306
2
                    choices[i]["message"]["content"].IsString()) {
307
2
                    results.emplace_back(choices[i]["message"]["content"].GetString());
308
2
                } else if (choices[i].HasMember("text") && choices[i]["text"].IsString()) {
309
                    // Some local LLMs use a simpler format
310
0
                    results.emplace_back(choices[i]["text"].GetString());
311
0
                }
312
2
            }
313
10
        } else if (doc.HasMember("text") && doc["text"].IsString()) {
314
            // Format 2: Simple response with just "text" or "content" field
315
2
            results.emplace_back(doc["text"].GetString());
316
8
        } else if (doc.HasMember("content") && doc["content"].IsString()) {
317
2
            results.emplace_back(doc["content"].GetString());
318
6
        } else if (doc.HasMember("response") && doc["response"].IsString()) {
319
            // Format 3: Response field (Ollama `generate` format)
320
2
            results.emplace_back(doc["response"].GetString());
321
4
        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
322
4
                   doc["message"].HasMember("content") && doc["message"]["content"].IsString()) {
323
            // Format 4: message/content field (Ollama `chat` format)
324
2
            results.emplace_back(doc["message"]["content"].GetString());
325
2
        } else {
326
2
            return Status::NotSupported("Unsupported response format from local AI.");
327
2
        }
328
329
10
        return Status::OK();
330
12
    }
331
332
    Status build_embedding_request(const std::vector<std::string>& inputs,
333
2
                                   std::string& request_body) const override {
334
2
        rapidjson::Document doc;
335
2
        doc.SetObject();
336
2
        auto& allocator = doc.GetAllocator();
337
338
2
        if (!_config.model_name.empty()) {
339
2
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
340
2
                          allocator);
341
2
        }
342
343
2
        add_dimension_params(doc, allocator);
344
345
2
        rapidjson::Value input(rapidjson::kArrayType);
346
2
        for (const auto& msg : inputs) {
347
2
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
348
2
        }
349
2
        doc.AddMember("input", input, allocator);
350
351
2
        rapidjson::StringBuffer buffer;
352
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
353
2
        doc.Accept(writer);
354
2
        request_body = buffer.GetString();
355
356
2
        return Status::OK();
357
2
    }
358
359
    Status parse_embedding_response(const std::string& response_body,
360
6
                                    std::vector<std::vector<float>>& results) const override {
361
6
        rapidjson::Document doc;
362
6
        doc.Parse(response_body.c_str());
363
364
6
        if (doc.HasParseError() || !doc.IsObject()) {
365
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
366
0
                                         response_body);
367
0
        }
368
369
        // parse different response format
370
6
        rapidjson::Value embedding;
371
6
        if (doc.HasMember("data") && doc["data"].IsArray()) {
372
            // "data":["object":"embedding", "embedding":[0.1, 0.2...], "index":0]
373
2
            const auto& data = doc["data"];
374
2
            results.reserve(data.Size());
375
6
            for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
376
4
                if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
377
0
                    return Status::InternalError("Invalid {} response format",
378
0
                                                 _config.provider_type);
379
0
                }
380
381
4
                std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
382
4
                               std::back_inserter(results.emplace_back()),
383
10
                               [](const auto& val) { return val.GetFloat(); });
384
4
            }
385
4
        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
386
            // "embeddings":[[0.1, 0.2, ...]]
387
2
            results.reserve(1);
388
4
            for (int i = 0; i < doc["embeddings"].Size(); i++) {
389
2
                embedding = doc["embeddings"][i];
390
2
                std::transform(embedding.Begin(), embedding.End(),
391
2
                               std::back_inserter(results.emplace_back()),
392
4
                               [](const auto& val) { return val.GetFloat(); });
393
2
            }
394
2
        } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
395
            // "embedding":[0.1, 0.2, ...]
396
2
            results.reserve(1);
397
2
            embedding = doc["embedding"];
398
2
            std::transform(embedding.Begin(), embedding.End(),
399
2
                           std::back_inserter(results.emplace_back()),
400
6
                           [](const auto& val) { return val.GetFloat(); });
401
2
        } else {
402
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
403
0
                                         response_body);
404
0
        }
405
406
6
        return Status::OK();
407
6
    }
408
409
private:
410
    Status build_ollama_request(rapidjson::Document& doc,
411
                                rapidjson::Document::AllocatorType& allocator,
412
                                const std::vector<std::string>& inputs,
413
4
                                const char* const system_prompt, std::string& request_body) const {
414
        /*
415
        for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
416
        {
417
            "model": <model_name>,
418
            "stream": false,
419
            "think": false,
420
            "options": {
421
                "temperature": <temperature>,
422
                "max_token": <max_token>
423
            },
424
            "messages": [
425
                {"role": "system", "content": <system_prompt>},
426
                {"role": "user", "content": <user_prompt>}
427
            ]
428
        }
429
        
430
        for endpoints end_with `/generate` like 'http://localhost:11434/api/generate':
431
        {
432
            "model": <model_name>,
433
            "stream": false,
434
            "think": false
435
            "options": {
436
                "temperature": <temperature>,
437
                "max_token": <max_token>
438
            },
439
            "system": <system_prompt>,
440
            "prompt": <user_prompt>
441
        }
442
        */
443
444
        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint;
445
        // The rest remains identical.
446
4
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
447
4
        doc.AddMember("stream", false, allocator);
448
4
        doc.AddMember("think", false, allocator);
449
450
        // option section
451
4
        rapidjson::Value options(rapidjson::kObjectType);
452
4
        if (_config.temperature != -1) {
453
4
            options.AddMember("temperature", _config.temperature, allocator);
454
4
        }
455
4
        if (_config.max_tokens != -1) {
456
4
            options.AddMember("max_token", _config.max_tokens, allocator);
457
4
        }
458
4
        doc.AddMember("options", options, allocator);
459
460
        // prompt section
461
4
        if (_config.endpoint.ends_with("chat")) {
462
2
            rapidjson::Value messages(rapidjson::kArrayType);
463
2
            if (system_prompt && *system_prompt) {
464
2
                rapidjson::Value sys_msg(rapidjson::kObjectType);
465
2
                sys_msg.AddMember("role", "system", allocator);
466
2
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
467
2
                messages.PushBack(sys_msg, allocator);
468
2
            }
469
2
            for (const auto& input : inputs) {
470
2
                rapidjson::Value message(rapidjson::kObjectType);
471
2
                message.AddMember("role", "user", allocator);
472
2
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
473
2
                messages.PushBack(message, allocator);
474
2
            }
475
2
            doc.AddMember("messages", messages, allocator);
476
2
        } else {
477
2
            if (system_prompt && *system_prompt) {
478
2
                doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
479
2
            }
480
2
            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator);
481
2
        }
482
483
4
        return Status::OK();
484
4
    }
485
486
    Status build_default_request(rapidjson::Document& doc,
487
                                 rapidjson::Document::AllocatorType& allocator,
488
                                 const std::vector<std::string>& inputs,
489
2
                                 const char* const system_prompt, std::string& request_body) const {
490
        /*
491
        Default format(OpenAI-compatible):
492
        {
493
            "model": <model_name>,
494
            "temperature": <temperature>,
495
            "max_tokens": <max_tokens>,
496
            "messages": [
497
                {"role": "system", "content": <system_prompt>},
498
                {"role": "user", "content": <user_prompt>}
499
            ]
500
        }
501
        */
502
503
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
504
505
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
506
2
        if (_config.temperature != -1) {
507
2
            doc.AddMember("temperature", _config.temperature, allocator);
508
2
        }
509
2
        if (_config.max_tokens != -1) {
510
2
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
511
2
        }
512
513
2
        rapidjson::Value messages(rapidjson::kArrayType);
514
2
        if (system_prompt && *system_prompt) {
515
2
            rapidjson::Value sys_msg(rapidjson::kObjectType);
516
2
            sys_msg.AddMember("role", "system", allocator);
517
2
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
518
2
            messages.PushBack(sys_msg, allocator);
519
2
        }
520
2
        for (const auto& input : inputs) {
521
2
            rapidjson::Value message(rapidjson::kObjectType);
522
2
            message.AddMember("role", "user", allocator);
523
2
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
524
2
            messages.PushBack(message, allocator);
525
2
        }
526
2
        doc.AddMember("messages", messages, allocator);
527
2
        return Status::OK();
528
2
    }
529
};
530
531
// The OpenAI API format can be reused with some compatible AIs.
532
class OpenAIAdapter : public VoyageAIAdapter {
533
public:
534
16
    Status set_authentication(HttpClient* client) const override {
535
16
        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key);
536
16
        client->set_content_type("application/json");
537
538
16
        return Status::OK();
539
16
    }
540
541
    Status build_request_payload(const std::vector<std::string>& inputs,
542
                                 const char* const system_prompt,
543
4
                                 std::string& request_body) const override {
544
4
        rapidjson::Document doc;
545
4
        doc.SetObject();
546
4
        auto& allocator = doc.GetAllocator();
547
548
4
        if (_config.endpoint.ends_with("responses")) {
549
            /*{
550
              "model": "gpt-4.1-mini",
551
              "input": [
552
                {"role": "system", "content": "system_prompt here"},
553
                {"role": "user", "content": "xxx"}
554
              ],
555
              "temperature": 0.7,
556
              "max_output_tokens": 150
557
            }*/
558
2
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
559
2
                          allocator);
560
561
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
562
2
            if (_config.temperature != -1) {
563
2
                doc.AddMember("temperature", _config.temperature, allocator);
564
2
            }
565
2
            if (_config.max_tokens != -1) {
566
2
                doc.AddMember("max_output_tokens", _config.max_tokens, allocator);
567
2
            }
568
569
            // input
570
2
            rapidjson::Value input(rapidjson::kArrayType);
571
2
            if (system_prompt && *system_prompt) {
572
2
                rapidjson::Value sys_msg(rapidjson::kObjectType);
573
2
                sys_msg.AddMember("role", "system", allocator);
574
2
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
575
2
                input.PushBack(sys_msg, allocator);
576
2
            }
577
2
            for (const auto& msg : inputs) {
578
2
                rapidjson::Value message(rapidjson::kObjectType);
579
2
                message.AddMember("role", "user", allocator);
580
2
                message.AddMember("content", rapidjson::Value(msg.c_str(), allocator), allocator);
581
2
                input.PushBack(message, allocator);
582
2
            }
583
2
            doc.AddMember("input", input, allocator);
584
2
        } else {
585
            /*{
586
              "model": "gpt-4",
587
              "messages": [
588
                {"role": "system", "content": "system_prompt here"},
589
                {"role": "user", "content": "xxx"}
590
              ],
591
              "temperature": x,
592
              "max_tokens": x,
593
            }*/
594
2
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
595
2
                          allocator);
596
597
            // If 'temperature' and 'max_tokens' are set, add them to the request body.
598
2
            if (_config.temperature != -1) {
599
2
                doc.AddMember("temperature", _config.temperature, allocator);
600
2
            }
601
2
            if (_config.max_tokens != -1) {
602
2
                doc.AddMember("max_tokens", _config.max_tokens, allocator);
603
2
            }
604
605
2
            rapidjson::Value messages(rapidjson::kArrayType);
606
2
            if (system_prompt && *system_prompt) {
607
2
                rapidjson::Value sys_msg(rapidjson::kObjectType);
608
2
                sys_msg.AddMember("role", "system", allocator);
609
2
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
610
2
                messages.PushBack(sys_msg, allocator);
611
2
            }
612
2
            for (const auto& input : inputs) {
613
2
                rapidjson::Value message(rapidjson::kObjectType);
614
2
                message.AddMember("role", "user", allocator);
615
2
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
616
2
                messages.PushBack(message, allocator);
617
2
            }
618
2
            doc.AddMember("messages", messages, allocator);
619
2
        }
620
621
4
        rapidjson::StringBuffer buffer;
622
4
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
623
4
        doc.Accept(writer);
624
4
        request_body = buffer.GetString();
625
626
4
        return Status::OK();
627
4
    }
628
629
    Status parse_response(const std::string& response_body,
630
12
                          std::vector<std::string>& results) const override {
631
12
        rapidjson::Document doc;
632
12
        doc.Parse(response_body.c_str());
633
634
12
        if (doc.HasParseError() || !doc.IsObject()) {
635
2
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
636
2
                                         response_body);
637
2
        }
638
639
10
        if (doc.HasMember("output") && doc["output"].IsArray()) {
640
            /// for responses endpoint
641
            /*{
642
              "output": [
643
                {
644
                  "id": "msg_123",
645
                  "type": "message",
646
                  "role": "assistant",
647
                  "content": [
648
                    {
649
                      "type": "text",
650
                      "text": "result text here"   <- result
651
                    }
652
                  ]
653
                }
654
              ]
655
            }*/
656
2
            const auto& output = doc["output"];
657
2
            results.reserve(output.Size());
658
659
4
            for (rapidjson::SizeType i = 0; i < output.Size(); i++) {
660
2
                if (!output[i].HasMember("content") || !output[i]["content"].IsArray() ||
661
2
                    output[i]["content"].Empty() || !output[i]["content"][0].HasMember("text") ||
662
2
                    !output[i]["content"][0]["text"].IsString()) {
663
0
                    return Status::InternalError("Invalid output format in {} response: {}",
664
0
                                                 _config.provider_type, response_body);
665
0
                }
666
667
2
                results.emplace_back(output[i]["content"][0]["text"].GetString());
668
2
            }
669
8
        } else if (doc.HasMember("choices") && doc["choices"].IsArray()) {
670
            /// for completions endpoint
671
            /*{
672
              "object": "chat.completion",
673
              "model": "gpt-4",
674
              "choices": [
675
                {
676
                  ...
677
                  "message": {
678
                    "role": "assistant",
679
                    "content": "xxx"      <- result
680
                  },
681
                  ...
682
                }
683
              ],
684
              ...
685
            }*/
686
6
            const auto& choices = doc["choices"];
687
6
            results.reserve(choices.Size());
688
689
8
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
690
6
                if (!choices[i].HasMember("message") ||
691
6
                    !choices[i]["message"].HasMember("content") ||
692
6
                    !choices[i]["message"]["content"].IsString()) {
693
4
                    return Status::InternalError("Invalid choice format in {} response: {}",
694
4
                                                 _config.provider_type, response_body);
695
4
                }
696
697
2
                results.emplace_back(choices[i]["message"]["content"].GetString());
698
2
            }
699
6
        } else {
700
2
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
701
2
                                         response_body);
702
2
        }
703
704
4
        return Status::OK();
705
10
    }
706
707
protected:
708
4
    bool supports_dimension_param(const std::string& model_name) const override {
709
4
        return !(model_name == "text-embedding-ada-002");
710
4
    }
711
712
4
    std::string get_dimension_param_name() const override { return "dimensions"; }
713
};
714
715
class DeepSeekAdapter : public OpenAIAdapter {
716
public:
717
    Status build_embedding_request(const std::vector<std::string>& inputs,
718
2
                                   std::string& request_body) const override {
719
2
        return Status::NotSupported("{} does not support the Embed feature.",
720
2
                                    _config.provider_type);
721
2
    }
722
723
    Status parse_embedding_response(const std::string& response_body,
724
2
                                    std::vector<std::vector<float>>& results) const override {
725
2
        return Status::NotSupported("{} does not support the Embed feature.",
726
2
                                    _config.provider_type);
727
2
    }
728
};
729
730
class MoonShotAdapter : public OpenAIAdapter {
731
public:
732
    Status build_embedding_request(const std::vector<std::string>& inputs,
733
2
                                   std::string& request_body) const override {
734
2
        return Status::NotSupported("{} does not support the Embed feature.",
735
2
                                    _config.provider_type);
736
2
    }
737
738
    Status parse_embedding_response(const std::string& response_body,
739
2
                                    std::vector<std::vector<float>>& results) const override {
740
2
        return Status::NotSupported("{} does not support the Embed feature.",
741
2
                                    _config.provider_type);
742
2
    }
743
};
744
745
class MinimaxAdapter : public OpenAIAdapter {
746
public:
747
    Status build_embedding_request(const std::vector<std::string>& inputs,
748
2
                                   std::string& request_body) const override {
749
2
        rapidjson::Document doc;
750
2
        doc.SetObject();
751
2
        auto& allocator = doc.GetAllocator();
752
753
        /*{
754
          "text": ["xxx", "xxx", ...],
755
          "model": "embo-1",
756
          "type": "db"
757
        }*/
758
2
        rapidjson::Value texts(rapidjson::kArrayType);
759
2
        for (const auto& input : inputs) {
760
2
            texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator);
761
2
        }
762
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
763
2
        doc.AddMember("texts", texts, allocator);
764
2
        doc.AddMember("type", rapidjson::Value("db", allocator), allocator);
765
766
2
        rapidjson::StringBuffer buffer;
767
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
768
2
        doc.Accept(writer);
769
2
        request_body = buffer.GetString();
770
771
2
        return Status::OK();
772
2
    }
773
};
774
775
class ZhipuAdapter : public OpenAIAdapter {
776
protected:
777
4
    bool supports_dimension_param(const std::string& model_name) const override {
778
4
        return !(model_name == "embedding-2");
779
4
    }
780
};
781
782
class QwenAdapter : public OpenAIAdapter {
783
protected:
784
4
    bool supports_dimension_param(const std::string& model_name) const override {
785
4
        static const std::unordered_set<std::string> no_dimension_models = {
786
4
                "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"};
787
4
        return !no_dimension_models.contains(model_name);
788
4
    }
789
790
2
    std::string get_dimension_param_name() const override { return "dimension"; }
791
};
792
793
class BaichuanAdapter : public OpenAIAdapter {
794
protected:
795
0
    bool supports_dimension_param(const std::string& model_name) const override { return false; }
796
};
797
798
// Gemini's embedding format is different from VoyageAI, so it requires a separate adapter
799
class GeminiAdapter : public AIAdapter {
800
public:
801
4
    Status set_authentication(HttpClient* client) const override {
802
4
        client->set_header("x-goog-api-key", _config.api_key);
803
4
        client->set_content_type("application/json");
804
4
        return Status::OK();
805
4
    }
806
807
    Status build_request_payload(const std::vector<std::string>& inputs,
808
                                 const char* const system_prompt,
809
2
                                 std::string& request_body) const override {
810
2
        rapidjson::Document doc;
811
2
        doc.SetObject();
812
2
        auto& allocator = doc.GetAllocator();
813
814
        /*{
815
          "systemInstruction": {
816
              "parts": [
817
                {
818
                  "text": "system_prompt here"
819
                }
820
              ]
821
            }
822
          ],
823
          "contents": [
824
            {
825
              "parts": [
826
                {
827
                  "text": "xxx"
828
                }
829
              ]
830
            }
831
          ],
832
          "generationConfig": {
833
          "temperature": 0.7,
834
          "maxOutputTokens": 1024
835
          }
836
837
        }*/
838
2
        if (system_prompt && *system_prompt) {
839
2
            rapidjson::Value system_instruction(rapidjson::kObjectType);
840
2
            rapidjson::Value parts(rapidjson::kArrayType);
841
842
2
            rapidjson::Value part(rapidjson::kObjectType);
843
2
            part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator);
844
2
            parts.PushBack(part, allocator);
845
            // system_instruction.PushBack(content, allocator);
846
2
            system_instruction.AddMember("parts", parts, allocator);
847
2
            doc.AddMember("systemInstruction", system_instruction, allocator);
848
2
        }
849
850
2
        rapidjson::Value contents(rapidjson::kArrayType);
851
2
        for (const auto& input : inputs) {
852
2
            rapidjson::Value content(rapidjson::kObjectType);
853
2
            rapidjson::Value parts(rapidjson::kArrayType);
854
855
2
            rapidjson::Value part(rapidjson::kObjectType);
856
2
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
857
858
2
            parts.PushBack(part, allocator);
859
2
            content.AddMember("parts", parts, allocator);
860
2
            contents.PushBack(content, allocator);
861
2
        }
862
2
        doc.AddMember("contents", contents, allocator);
863
864
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
865
2
        rapidjson::Value generationConfig(rapidjson::kObjectType);
866
2
        if (_config.temperature != -1) {
867
2
            generationConfig.AddMember("temperature", _config.temperature, allocator);
868
2
        }
869
2
        if (_config.max_tokens != -1) {
870
2
            generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator);
871
2
        }
872
2
        doc.AddMember("generationConfig", generationConfig, allocator);
873
874
2
        rapidjson::StringBuffer buffer;
875
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
876
2
        doc.Accept(writer);
877
2
        request_body = buffer.GetString();
878
879
2
        return Status::OK();
880
2
    }
881
882
    Status parse_response(const std::string& response_body,
883
6
                          std::vector<std::string>& results) const override {
884
6
        rapidjson::Document doc;
885
6
        doc.Parse(response_body.c_str());
886
887
6
        if (doc.HasParseError() || !doc.IsObject()) {
888
2
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
889
2
                                         response_body);
890
2
        }
891
4
        if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) {
892
2
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
893
2
                                         response_body);
894
2
        }
895
896
        /*{
897
          "candidates":[
898
            {
899
              "content": {
900
                "parts": [
901
                  {
902
                    "text": "xxx"
903
                  }
904
                ]
905
              }
906
            }
907
          ]
908
        }*/
909
2
        const auto& candidates = doc["candidates"];
910
2
        results.reserve(candidates.Size());
911
912
4
        for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) {
913
2
            if (!candidates[i].HasMember("content") ||
914
2
                !candidates[i]["content"].HasMember("parts") ||
915
2
                !candidates[i]["content"]["parts"].IsArray() ||
916
2
                candidates[i]["content"]["parts"].Empty() ||
917
2
                !candidates[i]["content"]["parts"][0].HasMember("text") ||
918
2
                !candidates[i]["content"]["parts"][0]["text"].IsString()) {
919
0
                return Status::InternalError("Invalid candidate format in {} response",
920
0
                                             _config.provider_type);
921
0
            }
922
923
2
            results.emplace_back(candidates[i]["content"]["parts"][0]["text"].GetString());
924
2
        }
925
926
2
        return Status::OK();
927
2
    }
928
929
    Status build_embedding_request(const std::vector<std::string>& inputs,
930
4
                                   std::string& request_body) const override {
931
4
        rapidjson::Document doc;
932
4
        doc.SetObject();
933
4
        auto& allocator = doc.GetAllocator();
934
935
        /*{
936
          "model": "models/gemini-embedding-001",
937
          "content": {
938
              "parts": [
939
                {
940
                  "text": "xxx"
941
                }
942
              ]
943
            }
944
          "outputDimensionality": 1024
945
        }*/
946
947
        // gemini requires the model format as `models/{model}`
948
4
        std::string model_name = _config.model_name;
949
4
        if (!model_name.starts_with("models/")) {
950
4
            model_name = "models/" + model_name;
951
4
        }
952
4
        doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);
953
4
        add_dimension_params(doc, allocator);
954
955
4
        rapidjson::Value content(rapidjson::kObjectType);
956
4
        for (const auto& input : inputs) {
957
4
            rapidjson::Value parts(rapidjson::kArrayType);
958
4
            rapidjson::Value part(rapidjson::kObjectType);
959
4
            part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator);
960
4
            parts.PushBack(part, allocator);
961
4
            content.AddMember("parts", parts, allocator);
962
4
        }
963
4
        doc.AddMember("content", content, allocator);
964
965
4
        rapidjson::StringBuffer buffer;
966
4
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
967
4
        doc.Accept(writer);
968
4
        request_body = buffer.GetString();
969
970
4
        return Status::OK();
971
4
    }
972
973
    Status parse_embedding_response(const std::string& response_body,
974
2
                                    std::vector<std::vector<float>>& results) const override {
975
2
        rapidjson::Document doc;
976
2
        doc.Parse(response_body.c_str());
977
978
2
        if (doc.HasParseError() || !doc.IsObject()) {
979
0
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
980
0
                                         response_body);
981
0
        }
982
2
        if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) {
983
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
984
0
                                         response_body);
985
0
        }
986
987
        /*{
988
          "embedding":{
989
            "values": [0.1, 0.2, 0.3]
990
          }
991
        }*/
992
2
        const auto& embedding = doc["embedding"];
993
2
        if (!embedding.HasMember("values") || !embedding["values"].IsArray()) {
994
0
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
995
0
                                         response_body);
996
0
        }
997
2
        std::transform(embedding["values"].Begin(), embedding["values"].End(),
998
2
                       std::back_inserter(results.emplace_back()),
999
6
                       [](const auto& val) { return val.GetFloat(); });
1000
1001
2
        return Status::OK();
1002
2
    }
1003
1004
protected:
1005
4
    bool supports_dimension_param(const std::string& model_name) const override {
1006
4
        static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001",
1007
4
                                                                            "embedding-001"};
1008
4
        return !no_dimension_models.contains(model_name);
1009
4
    }
1010
1011
2
    std::string get_dimension_param_name() const override { return "outputDimensionality"; }
1012
};
1013
1014
class AnthropicAdapter : public VoyageAIAdapter {
1015
public:
1016
2
    Status set_authentication(HttpClient* client) const override {
1017
2
        client->set_header("x-api-key", _config.api_key);
1018
2
        client->set_header("anthropic-version", _config.anthropic_version);
1019
2
        client->set_content_type("application/json");
1020
1021
2
        return Status::OK();
1022
2
    }
1023
1024
    Status build_request_payload(const std::vector<std::string>& inputs,
1025
                                 const char* const system_prompt,
1026
2
                                 std::string& request_body) const override {
1027
2
        rapidjson::Document doc;
1028
2
        doc.SetObject();
1029
2
        auto& allocator = doc.GetAllocator();
1030
1031
        /*
1032
            "model": "claude-opus-4-1-20250805",
1033
            "max_tokens": 1024,
1034
            "system": "system_prompt here",
1035
            "messages": [
1036
              {"role": "user", "content": "xxx"}
1037
            ],
1038
            "temperature": 0.7
1039
        */
1040
1041
        // If 'temperature' and 'max_tokens' are set, add them to the request body.
1042
2
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
1043
2
        if (_config.temperature != -1) {
1044
2
            doc.AddMember("temperature", _config.temperature, allocator);
1045
2
        }
1046
2
        if (_config.max_tokens != -1) {
1047
2
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
1048
2
        } else {
1049
            // Keep the default value, Anthropic requires this parameter
1050
0
            doc.AddMember("max_tokens", 2048, allocator);
1051
0
        }
1052
2
        if (system_prompt && *system_prompt) {
1053
2
            doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
1054
2
        }
1055
1056
2
        rapidjson::Value messages(rapidjson::kArrayType);
1057
2
        for (const auto& input : inputs) {
1058
2
            rapidjson::Value message(rapidjson::kObjectType);
1059
2
            message.AddMember("role", "user", allocator);
1060
2
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
1061
2
            messages.PushBack(message, allocator);
1062
2
        }
1063
2
        doc.AddMember("messages", messages, allocator);
1064
1065
2
        rapidjson::StringBuffer buffer;
1066
2
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
1067
2
        doc.Accept(writer);
1068
2
        request_body = buffer.GetString();
1069
1070
2
        return Status::OK();
1071
2
    }
1072
1073
    Status parse_response(const std::string& response_body,
1074
6
                          std::vector<std::string>& results) const override {
1075
6
        rapidjson::Document doc;
1076
6
        doc.Parse(response_body.c_str());
1077
6
        if (doc.HasParseError() || !doc.IsObject()) {
1078
2
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
1079
2
                                         response_body);
1080
2
        }
1081
4
        if (!doc.HasMember("content") || !doc["content"].IsArray()) {
1082
2
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
1083
2
                                         response_body);
1084
2
        }
1085
1086
        /*{
1087
            "content": [
1088
              {
1089
                "text": "xxx",
1090
                "type": "text"
1091
              }
1092
            ]
1093
        }*/
1094
2
        const auto& content = doc["content"];
1095
2
        results.reserve(1);
1096
1097
2
        std::string result;
1098
4
        for (rapidjson::SizeType i = 0; i < content.Size(); i++) {
1099
2
            if (!content[i].HasMember("type") || !content[i]["type"].IsString() ||
1100
2
                !content[i].HasMember("text") || !content[i]["text"].IsString()) {
1101
0
                continue;
1102
0
            }
1103
1104
2
            if (std::string(content[i]["type"].GetString()) == "text") {
1105
2
                if (!result.empty()) {
1106
0
                    result += "\n";
1107
0
                }
1108
2
                result += content[i]["text"].GetString();
1109
2
            }
1110
2
        }
1111
1112
2
        results.emplace_back(std::move(result));
1113
2
        return Status::OK();
1114
4
    }
1115
};
1116
1117
// Mock adapter used only for UT to bypass real HTTP calls and return deterministic data.
1118
class MockAdapter : public AIAdapter {
1119
public:
1120
0
    Status set_authentication(HttpClient* client) const override { return Status::OK(); }
1121
1122
    Status build_request_payload(const std::vector<std::string>& inputs,
1123
                                 const char* const system_prompt,
1124
98
                                 std::string& request_body) const override {
1125
98
        return Status::OK();
1126
98
    }
1127
1128
    Status parse_response(const std::string& response_body,
1129
98
                          std::vector<std::string>& results) const override {
1130
98
        results.emplace_back(response_body);
1131
98
        return Status::OK();
1132
98
    }
1133
1134
    Status build_embedding_request(const std::vector<std::string>& inputs,
1135
2
                                   std::string& request_body) const override {
1136
2
        return Status::OK();
1137
2
    }
1138
1139
    Status parse_embedding_response(const std::string& response_body,
1140
2
                                    std::vector<std::vector<float>>& results) const override {
1141
2
        rapidjson::Document doc;
1142
2
        doc.SetObject();
1143
2
        doc.Parse(response_body.c_str());
1144
2
        if (doc.HasParseError() || !doc.IsObject()) {
1145
0
            return Status::InternalError("Failed to parse embedding response");
1146
0
        }
1147
2
        if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) {
1148
0
            return Status::InternalError("Invalid embedding response format");
1149
0
        }
1150
1151
2
        results.reserve(1);
1152
2
        std::transform(doc["embedding"].Begin(), doc["embedding"].End(),
1153
2
                       std::back_inserter(results.emplace_back()),
1154
10
                       [](const auto& val) { return val.GetFloat(); });
1155
2
        return Status::OK();
1156
2
    }
1157
};
1158
1159
class AIAdapterFactory {
1160
public:
1161
148
    static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) {
1162
148
        static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>>
1163
148
                adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }},
1164
148
                            {"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }},
1165
148
                            {"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }},
1166
148
                            {"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }},
1167
148
                            {"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }},
1168
148
                            {"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }},
1169
148
                            {"QWEN", []() { return std::make_shared<QwenAdapter>(); }},
1170
148
                            {"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }},
1171
148
                            {"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }},
1172
148
                            {"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }},
1173
148
                            {"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }},
1174
148
                            {"MOCK", []() { return std::make_shared<MockAdapter>(); }}};
1175
1176
148
        auto it = adapters.find(provider_type);
1177
148
        return (it != adapters.end()) ? it->second() : nullptr;
1178
148
    }
1179
};
1180
1181
#include "common/compile_check_end.h"
1182
} // namespace doris