be/src/exprs/function/ai/ai_adapter.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
#include <gen_cpp/PaloInternalService_types.h>
#include <rapidjson/rapidjson.h>

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/status.h"
#include "core/string_buffer.hpp"
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include "service/http/http_client.h"
#include "service/http/http_headers.h"
36 | | |
37 | | namespace doris { |
38 | | #include "common/compile_check_begin.h" |
39 | | |
40 | | struct AIResource { |
41 | 16 | AIResource() = default; |
42 | | AIResource(const TAIResource& tai) |
43 | 11 | : endpoint(tai.endpoint), |
44 | 11 | provider_type(tai.provider_type), |
45 | 11 | model_name(tai.model_name), |
46 | 11 | api_key(tai.api_key), |
47 | 11 | temperature(tai.temperature), |
48 | 11 | max_tokens(tai.max_tokens), |
49 | 11 | max_retries(tai.max_retries), |
50 | 11 | retry_delay_second(tai.retry_delay_second), |
51 | 11 | anthropic_version(tai.anthropic_version), |
52 | 11 | dimensions(tai.dimensions) {} |
53 | | |
54 | | std::string endpoint; |
55 | | std::string provider_type; |
56 | | std::string model_name; |
57 | | std::string api_key; |
58 | | double temperature; |
59 | | int64_t max_tokens; |
60 | | int32_t max_retries; |
61 | | int32_t retry_delay_second; |
62 | | std::string anthropic_version; |
63 | | int32_t dimensions; |
64 | | |
65 | 1 | void serialize(BufferWritable& buf) const { |
66 | 1 | buf.write_binary(endpoint); |
67 | 1 | buf.write_binary(provider_type); |
68 | 1 | buf.write_binary(model_name); |
69 | 1 | buf.write_binary(api_key); |
70 | 1 | buf.write_binary(temperature); |
71 | 1 | buf.write_binary(max_tokens); |
72 | 1 | buf.write_binary(max_retries); |
73 | 1 | buf.write_binary(retry_delay_second); |
74 | 1 | buf.write_binary(anthropic_version); |
75 | 1 | buf.write_binary(dimensions); |
76 | 1 | } |
77 | | |
78 | 1 | void deserialize(BufferReadable& buf) { |
79 | 1 | buf.read_binary(endpoint); |
80 | 1 | buf.read_binary(provider_type); |
81 | 1 | buf.read_binary(model_name); |
82 | 1 | buf.read_binary(api_key); |
83 | 1 | buf.read_binary(temperature); |
84 | 1 | buf.read_binary(max_tokens); |
85 | 1 | buf.read_binary(max_retries); |
86 | 1 | buf.read_binary(retry_delay_second); |
87 | 1 | buf.read_binary(anthropic_version); |
88 | 1 | buf.read_binary(dimensions); |
89 | 1 | } |
90 | | }; |
91 | | |
92 | | class AIAdapter { |
93 | | public: |
94 | 111 | virtual ~AIAdapter() = default; |
95 | | |
96 | | // Set authentication headers for the HTTP client |
97 | | virtual Status set_authentication(HttpClient* client) const = 0; |
98 | | |
99 | 71 | virtual void init(const TAIResource& config) { _config = config; } |
100 | 12 | virtual void init(const AIResource& config) { |
101 | 12 | _config.endpoint = config.endpoint; |
102 | 12 | _config.provider_type = config.provider_type; |
103 | 12 | _config.model_name = config.model_name; |
104 | 12 | _config.api_key = config.api_key; |
105 | 12 | _config.temperature = config.temperature; |
106 | 12 | _config.max_tokens = config.max_tokens; |
107 | 12 | _config.max_retries = config.max_retries; |
108 | 12 | _config.retry_delay_second = config.retry_delay_second; |
109 | 12 | _config.anthropic_version = config.anthropic_version; |
110 | 12 | } |
111 | | |
112 | | // Build request payload based on input text strings |
113 | | virtual Status build_request_payload(const std::vector<std::string>& inputs, |
114 | | const char* const system_prompt, |
115 | 1 | std::string& request_body) const { |
116 | 1 | return Status::NotSupported("{} don't support text generation", _config.provider_type); |
117 | 1 | } |
118 | | |
119 | | // Parse response from AI service and extract generated text results |
120 | | virtual Status parse_response(const std::string& response_body, |
121 | 1 | std::vector<std::string>& results) const { |
122 | 1 | return Status::NotSupported("{} don't support text generation", _config.provider_type); |
123 | 1 | } |
124 | | |
125 | | virtual Status build_embedding_request(const std::vector<std::string>& inputs, |
126 | 0 | std::string& request_body) const { |
127 | 0 | return Status::NotSupported("{} does not support the Embed feature.", |
128 | 0 | _config.provider_type); |
129 | 0 | } |
130 | | |
131 | | virtual Status parse_embedding_response(const std::string& response_body, |
132 | 0 | std::vector<std::vector<float>>& results) const { |
133 | 0 | return Status::NotSupported("{} does not support the Embed feature.", |
134 | 0 | _config.provider_type); |
135 | 0 | } |
136 | | |
137 | | protected: |
138 | | TAIResource _config; |
139 | | |
140 | | // return true if the model support dimension parameter |
141 | 1 | virtual bool supports_dimension_param(const std::string& model_name) const { return false; } |
142 | | |
143 | | // Different providers may have different dimension parameter names. |
144 | 0 | virtual std::string get_dimension_param_name() const { return "dimensions"; } |
145 | | |
146 | | virtual void add_dimension_params(rapidjson::Value& doc, |
147 | 11 | rapidjson::Document::AllocatorType& allocator) const { |
148 | 11 | if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) { |
149 | 5 | std::string param_name = get_dimension_param_name(); |
150 | 5 | rapidjson::Value name(param_name.c_str(), allocator); |
151 | 5 | doc.AddMember(name, _config.dimensions, allocator); |
152 | 5 | } |
153 | 11 | } |
154 | | }; |
155 | | |
156 | | // Most LLM-providers' Embedding formats are based on VoyageAI. |
157 | | // The following adapters inherit from VoyageAIAdapter to directly reuse its embedding logic. |
158 | | class VoyageAIAdapter : public AIAdapter { |
159 | | public: |
160 | 2 | Status set_authentication(HttpClient* client) const override { |
161 | 2 | client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key); |
162 | 2 | client->set_content_type("application/json"); |
163 | | |
164 | 2 | return Status::OK(); |
165 | 2 | } |
166 | | |
167 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
168 | 8 | std::string& request_body) const override { |
169 | 8 | rapidjson::Document doc; |
170 | 8 | doc.SetObject(); |
171 | 8 | auto& allocator = doc.GetAllocator(); |
172 | | |
173 | | /*{ |
174 | | "model": "xxx", |
175 | | "input": [ |
176 | | "xxx", |
177 | | "xxx", |
178 | | ... |
179 | | ], |
180 | | "output_dimensions": 512 |
181 | | }*/ |
182 | 8 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
183 | 8 | add_dimension_params(doc, allocator); |
184 | | |
185 | 8 | rapidjson::Value input(rapidjson::kArrayType); |
186 | 8 | for (const auto& msg : inputs) { |
187 | 8 | input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator); |
188 | 8 | } |
189 | 8 | doc.AddMember("input", input, allocator); |
190 | | |
191 | 8 | rapidjson::StringBuffer buffer; |
192 | 8 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
193 | 8 | doc.Accept(writer); |
194 | 8 | request_body = buffer.GetString(); |
195 | | |
196 | 8 | return Status::OK(); |
197 | 8 | } |
198 | | |
199 | | Status parse_embedding_response(const std::string& response_body, |
200 | 5 | std::vector<std::vector<float>>& results) const override { |
201 | 5 | rapidjson::Document doc; |
202 | 5 | doc.Parse(response_body.c_str()); |
203 | | |
204 | 5 | if (doc.HasParseError() || !doc.IsObject()) { |
205 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
206 | 1 | response_body); |
207 | 1 | } |
208 | 4 | if (!doc.HasMember("data") || !doc["data"].IsArray()) { |
209 | 1 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
210 | 1 | response_body); |
211 | 1 | } |
212 | | |
213 | | /*{ |
214 | | "data":[ |
215 | | { |
216 | | "object": "embedding", |
217 | | "embedding": [...], <- only need this |
218 | | "index": 0 |
219 | | }, |
220 | | { |
221 | | "object": "embedding", |
222 | | "embedding": [...], |
223 | | "index": 1 |
224 | | }, ... |
225 | | ], |
226 | | "model".... |
227 | | }*/ |
228 | 3 | const auto& data = doc["data"]; |
229 | 3 | results.reserve(data.Size()); |
230 | 7 | for (rapidjson::SizeType i = 0; i < data.Size(); i++) { |
231 | 5 | if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) { |
232 | 1 | return Status::InternalError("Invalid {} response format: {}", |
233 | 1 | _config.provider_type, response_body); |
234 | 1 | } |
235 | | |
236 | 4 | std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(), |
237 | 4 | std::back_inserter(results.emplace_back()), |
238 | 10 | [](const auto& val) { return val.GetFloat(); }); |
239 | 4 | } |
240 | | |
241 | 2 | return Status::OK(); |
242 | 3 | } |
243 | | |
244 | | protected: |
245 | 2 | bool supports_dimension_param(const std::string& model_name) const override { |
246 | 2 | static const std::unordered_set<std::string> no_dimension_models = { |
247 | 2 | "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2", |
248 | 2 | "voyage-multimodal-3"}; |
249 | 2 | return !no_dimension_models.contains(model_name); |
250 | 2 | } |
251 | | |
252 | 1 | std::string get_dimension_param_name() const override { return "output_dimension"; } |
253 | | }; |
254 | | |
255 | | // Local AI adapter for locally hosted models (Ollama, LLaMA, etc.) |
256 | | class LocalAdapter : public AIAdapter { |
257 | | public: |
258 | | // Local deployments typically don't need authentication |
259 | 2 | Status set_authentication(HttpClient* client) const override { |
260 | 2 | client->set_content_type("application/json"); |
261 | 2 | return Status::OK(); |
262 | 2 | } |
263 | | |
264 | | Status build_request_payload(const std::vector<std::string>& inputs, |
265 | | const char* const system_prompt, |
266 | 3 | std::string& request_body) const override { |
267 | 3 | rapidjson::Document doc; |
268 | 3 | doc.SetObject(); |
269 | 3 | auto& allocator = doc.GetAllocator(); |
270 | | |
271 | 3 | std::string end_point = _config.endpoint; |
272 | 3 | if (end_point.ends_with("chat") || end_point.ends_with("generate")) { |
273 | 2 | RETURN_IF_ERROR( |
274 | 2 | build_ollama_request(doc, allocator, inputs, system_prompt, request_body)); |
275 | 2 | } else { |
276 | 1 | RETURN_IF_ERROR( |
277 | 1 | build_default_request(doc, allocator, inputs, system_prompt, request_body)); |
278 | 1 | } |
279 | | |
280 | 3 | rapidjson::StringBuffer buffer; |
281 | 3 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
282 | 3 | doc.Accept(writer); |
283 | 3 | request_body = buffer.GetString(); |
284 | | |
285 | 3 | return Status::OK(); |
286 | 3 | } |
287 | | |
288 | | Status parse_response(const std::string& response_body, |
289 | 7 | std::vector<std::string>& results) const override { |
290 | 7 | rapidjson::Document doc; |
291 | 7 | doc.Parse(response_body.c_str()); |
292 | | |
293 | 7 | if (doc.HasParseError() || !doc.IsObject()) { |
294 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
295 | 1 | response_body); |
296 | 1 | } |
297 | | |
298 | | // Handle various response formats from local LLMs |
299 | | // Format 1: OpenAI-compatible format with choices/message/content |
300 | 6 | if (doc.HasMember("choices") && doc["choices"].IsArray()) { |
301 | 1 | const auto& choices = doc["choices"]; |
302 | 1 | results.reserve(choices.Size()); |
303 | | |
304 | 2 | for (rapidjson::SizeType i = 0; i < choices.Size(); i++) { |
305 | 1 | if (choices[i].HasMember("message") && choices[i]["message"].HasMember("content") && |
306 | 1 | choices[i]["message"]["content"].IsString()) { |
307 | 1 | results.emplace_back(choices[i]["message"]["content"].GetString()); |
308 | 1 | } else if (choices[i].HasMember("text") && choices[i]["text"].IsString()) { |
309 | | // Some local LLMs use a simpler format |
310 | 0 | results.emplace_back(choices[i]["text"].GetString()); |
311 | 0 | } |
312 | 1 | } |
313 | 5 | } else if (doc.HasMember("text") && doc["text"].IsString()) { |
314 | | // Format 2: Simple response with just "text" or "content" field |
315 | 1 | results.emplace_back(doc["text"].GetString()); |
316 | 4 | } else if (doc.HasMember("content") && doc["content"].IsString()) { |
317 | 1 | results.emplace_back(doc["content"].GetString()); |
318 | 3 | } else if (doc.HasMember("response") && doc["response"].IsString()) { |
319 | | // Format 3: Response field (Ollama `generate` format) |
320 | 1 | results.emplace_back(doc["response"].GetString()); |
321 | 2 | } else if (doc.HasMember("message") && doc["message"].IsObject() && |
322 | 2 | doc["message"].HasMember("content") && doc["message"]["content"].IsString()) { |
323 | | // Format 4: message/content field (Ollama `chat` format) |
324 | 1 | results.emplace_back(doc["message"]["content"].GetString()); |
325 | 1 | } else { |
326 | 1 | return Status::NotSupported("Unsupported response format from local AI."); |
327 | 1 | } |
328 | | |
329 | 5 | return Status::OK(); |
330 | 6 | } |
331 | | |
332 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
333 | 1 | std::string& request_body) const override { |
334 | 1 | rapidjson::Document doc; |
335 | 1 | doc.SetObject(); |
336 | 1 | auto& allocator = doc.GetAllocator(); |
337 | | |
338 | 1 | if (!_config.model_name.empty()) { |
339 | 1 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), |
340 | 1 | allocator); |
341 | 1 | } |
342 | | |
343 | 1 | add_dimension_params(doc, allocator); |
344 | | |
345 | 1 | rapidjson::Value input(rapidjson::kArrayType); |
346 | 1 | for (const auto& msg : inputs) { |
347 | 1 | input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator); |
348 | 1 | } |
349 | 1 | doc.AddMember("input", input, allocator); |
350 | | |
351 | 1 | rapidjson::StringBuffer buffer; |
352 | 1 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
353 | 1 | doc.Accept(writer); |
354 | 1 | request_body = buffer.GetString(); |
355 | | |
356 | 1 | return Status::OK(); |
357 | 1 | } |
358 | | |
359 | | Status parse_embedding_response(const std::string& response_body, |
360 | 3 | std::vector<std::vector<float>>& results) const override { |
361 | 3 | rapidjson::Document doc; |
362 | 3 | doc.Parse(response_body.c_str()); |
363 | | |
364 | 3 | if (doc.HasParseError() || !doc.IsObject()) { |
365 | 0 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
366 | 0 | response_body); |
367 | 0 | } |
368 | | |
369 | | // parse different response format |
370 | 3 | rapidjson::Value embedding; |
371 | 3 | if (doc.HasMember("data") && doc["data"].IsArray()) { |
372 | | // "data":["object":"embedding", "embedding":[0.1, 0.2...], "index":0] |
373 | 1 | const auto& data = doc["data"]; |
374 | 1 | results.reserve(data.Size()); |
375 | 3 | for (rapidjson::SizeType i = 0; i < data.Size(); i++) { |
376 | 2 | if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) { |
377 | 0 | return Status::InternalError("Invalid {} response format", |
378 | 0 | _config.provider_type); |
379 | 0 | } |
380 | | |
381 | 2 | std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(), |
382 | 2 | std::back_inserter(results.emplace_back()), |
383 | 5 | [](const auto& val) { return val.GetFloat(); }); |
384 | 2 | } |
385 | 2 | } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) { |
386 | | // "embeddings":[[0.1, 0.2, ...]] |
387 | 1 | results.reserve(1); |
388 | 2 | for (int i = 0; i < doc["embeddings"].Size(); i++) { |
389 | 1 | embedding = doc["embeddings"][i]; |
390 | 1 | std::transform(embedding.Begin(), embedding.End(), |
391 | 1 | std::back_inserter(results.emplace_back()), |
392 | 2 | [](const auto& val) { return val.GetFloat(); }); |
393 | 1 | } |
394 | 1 | } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) { |
395 | | // "embedding":[0.1, 0.2, ...] |
396 | 1 | results.reserve(1); |
397 | 1 | embedding = doc["embedding"]; |
398 | 1 | std::transform(embedding.Begin(), embedding.End(), |
399 | 1 | std::back_inserter(results.emplace_back()), |
400 | 3 | [](const auto& val) { return val.GetFloat(); }); |
401 | 1 | } else { |
402 | 0 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
403 | 0 | response_body); |
404 | 0 | } |
405 | | |
406 | 3 | return Status::OK(); |
407 | 3 | } |
408 | | |
409 | | private: |
410 | | Status build_ollama_request(rapidjson::Document& doc, |
411 | | rapidjson::Document::AllocatorType& allocator, |
412 | | const std::vector<std::string>& inputs, |
413 | 2 | const char* const system_prompt, std::string& request_body) const { |
414 | | /* |
415 | | for endpoints end_with `/chat` like 'http://localhost:11434/api/chat': |
416 | | { |
417 | | "model": <model_name>, |
418 | | "stream": false, |
419 | | "think": false, |
420 | | "options": { |
421 | | "temperature": <temperature>, |
422 | | "max_token": <max_token> |
423 | | }, |
424 | | "messages": [ |
425 | | {"role": "system", "content": <system_prompt>}, |
426 | | {"role": "user", "content": <user_prompt>} |
427 | | ] |
428 | | } |
429 | | |
430 | | for endpoints end_with `/generate` like 'http://localhost:11434/api/generate': |
431 | | { |
432 | | "model": <model_name>, |
433 | | "stream": false, |
434 | | "think": false |
435 | | "options": { |
436 | | "temperature": <temperature>, |
437 | | "max_token": <max_token> |
438 | | }, |
439 | | "system": <system_prompt>, |
440 | | "prompt": <user_prompt> |
441 | | } |
442 | | */ |
443 | | |
444 | | // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint; |
445 | | // The rest remains identical. |
446 | 2 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
447 | 2 | doc.AddMember("stream", false, allocator); |
448 | 2 | doc.AddMember("think", false, allocator); |
449 | | |
450 | | // option section |
451 | 2 | rapidjson::Value options(rapidjson::kObjectType); |
452 | 2 | if (_config.temperature != -1) { |
453 | 2 | options.AddMember("temperature", _config.temperature, allocator); |
454 | 2 | } |
455 | 2 | if (_config.max_tokens != -1) { |
456 | 2 | options.AddMember("max_token", _config.max_tokens, allocator); |
457 | 2 | } |
458 | 2 | doc.AddMember("options", options, allocator); |
459 | | |
460 | | // prompt section |
461 | 2 | if (_config.endpoint.ends_with("chat")) { |
462 | 1 | rapidjson::Value messages(rapidjson::kArrayType); |
463 | 1 | if (system_prompt && *system_prompt) { |
464 | 1 | rapidjson::Value sys_msg(rapidjson::kObjectType); |
465 | 1 | sys_msg.AddMember("role", "system", allocator); |
466 | 1 | sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator); |
467 | 1 | messages.PushBack(sys_msg, allocator); |
468 | 1 | } |
469 | 1 | for (const auto& input : inputs) { |
470 | 1 | rapidjson::Value message(rapidjson::kObjectType); |
471 | 1 | message.AddMember("role", "user", allocator); |
472 | 1 | message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator); |
473 | 1 | messages.PushBack(message, allocator); |
474 | 1 | } |
475 | 1 | doc.AddMember("messages", messages, allocator); |
476 | 1 | } else { |
477 | 1 | if (system_prompt && *system_prompt) { |
478 | 1 | doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator); |
479 | 1 | } |
480 | 1 | doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator); |
481 | 1 | } |
482 | | |
483 | 2 | return Status::OK(); |
484 | 2 | } |
485 | | |
486 | | Status build_default_request(rapidjson::Document& doc, |
487 | | rapidjson::Document::AllocatorType& allocator, |
488 | | const std::vector<std::string>& inputs, |
489 | 1 | const char* const system_prompt, std::string& request_body) const { |
490 | | /* |
491 | | Default format(OpenAI-compatible): |
492 | | { |
493 | | "model": <model_name>, |
494 | | "temperature": <temperature>, |
495 | | "max_tokens": <max_tokens>, |
496 | | "messages": [ |
497 | | {"role": "system", "content": <system_prompt>}, |
498 | | {"role": "user", "content": <user_prompt>} |
499 | | ] |
500 | | } |
501 | | */ |
502 | | |
503 | 1 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
504 | | |
505 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
506 | 1 | if (_config.temperature != -1) { |
507 | 1 | doc.AddMember("temperature", _config.temperature, allocator); |
508 | 1 | } |
509 | 1 | if (_config.max_tokens != -1) { |
510 | 1 | doc.AddMember("max_tokens", _config.max_tokens, allocator); |
511 | 1 | } |
512 | | |
513 | 1 | rapidjson::Value messages(rapidjson::kArrayType); |
514 | 1 | if (system_prompt && *system_prompt) { |
515 | 1 | rapidjson::Value sys_msg(rapidjson::kObjectType); |
516 | 1 | sys_msg.AddMember("role", "system", allocator); |
517 | 1 | sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator); |
518 | 1 | messages.PushBack(sys_msg, allocator); |
519 | 1 | } |
520 | 1 | for (const auto& input : inputs) { |
521 | 1 | rapidjson::Value message(rapidjson::kObjectType); |
522 | 1 | message.AddMember("role", "user", allocator); |
523 | 1 | message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator); |
524 | 1 | messages.PushBack(message, allocator); |
525 | 1 | } |
526 | 1 | doc.AddMember("messages", messages, allocator); |
527 | 1 | return Status::OK(); |
528 | 1 | } |
529 | | }; |
530 | | |
531 | | // The OpenAI API format can be reused with some compatible AIs. |
532 | | class OpenAIAdapter : public VoyageAIAdapter { |
533 | | public: |
534 | 8 | Status set_authentication(HttpClient* client) const override { |
535 | 8 | client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key); |
536 | 8 | client->set_content_type("application/json"); |
537 | | |
538 | 8 | return Status::OK(); |
539 | 8 | } |
540 | | |
541 | | Status build_request_payload(const std::vector<std::string>& inputs, |
542 | | const char* const system_prompt, |
543 | 2 | std::string& request_body) const override { |
544 | 2 | rapidjson::Document doc; |
545 | 2 | doc.SetObject(); |
546 | 2 | auto& allocator = doc.GetAllocator(); |
547 | | |
548 | 2 | if (_config.endpoint.ends_with("responses")) { |
549 | | /*{ |
550 | | "model": "gpt-4.1-mini", |
551 | | "input": [ |
552 | | {"role": "system", "content": "system_prompt here"}, |
553 | | {"role": "user", "content": "xxx"} |
554 | | ], |
555 | | "temperature": 0.7, |
556 | | "max_output_tokens": 150 |
557 | | }*/ |
558 | 1 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), |
559 | 1 | allocator); |
560 | | |
561 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
562 | 1 | if (_config.temperature != -1) { |
563 | 1 | doc.AddMember("temperature", _config.temperature, allocator); |
564 | 1 | } |
565 | 1 | if (_config.max_tokens != -1) { |
566 | 1 | doc.AddMember("max_output_tokens", _config.max_tokens, allocator); |
567 | 1 | } |
568 | | |
569 | | // input |
570 | 1 | rapidjson::Value input(rapidjson::kArrayType); |
571 | 1 | if (system_prompt && *system_prompt) { |
572 | 1 | rapidjson::Value sys_msg(rapidjson::kObjectType); |
573 | 1 | sys_msg.AddMember("role", "system", allocator); |
574 | 1 | sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator); |
575 | 1 | input.PushBack(sys_msg, allocator); |
576 | 1 | } |
577 | 1 | for (const auto& msg : inputs) { |
578 | 1 | rapidjson::Value message(rapidjson::kObjectType); |
579 | 1 | message.AddMember("role", "user", allocator); |
580 | 1 | message.AddMember("content", rapidjson::Value(msg.c_str(), allocator), allocator); |
581 | 1 | input.PushBack(message, allocator); |
582 | 1 | } |
583 | 1 | doc.AddMember("input", input, allocator); |
584 | 1 | } else { |
585 | | /*{ |
586 | | "model": "gpt-4", |
587 | | "messages": [ |
588 | | {"role": "system", "content": "system_prompt here"}, |
589 | | {"role": "user", "content": "xxx"} |
590 | | ], |
591 | | "temperature": x, |
592 | | "max_tokens": x, |
593 | | }*/ |
594 | 1 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), |
595 | 1 | allocator); |
596 | | |
597 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
598 | 1 | if (_config.temperature != -1) { |
599 | 1 | doc.AddMember("temperature", _config.temperature, allocator); |
600 | 1 | } |
601 | 1 | if (_config.max_tokens != -1) { |
602 | 1 | doc.AddMember("max_tokens", _config.max_tokens, allocator); |
603 | 1 | } |
604 | | |
605 | 1 | rapidjson::Value messages(rapidjson::kArrayType); |
606 | 1 | if (system_prompt && *system_prompt) { |
607 | 1 | rapidjson::Value sys_msg(rapidjson::kObjectType); |
608 | 1 | sys_msg.AddMember("role", "system", allocator); |
609 | 1 | sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator); |
610 | 1 | messages.PushBack(sys_msg, allocator); |
611 | 1 | } |
612 | 1 | for (const auto& input : inputs) { |
613 | 1 | rapidjson::Value message(rapidjson::kObjectType); |
614 | 1 | message.AddMember("role", "user", allocator); |
615 | 1 | message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator); |
616 | 1 | messages.PushBack(message, allocator); |
617 | 1 | } |
618 | 1 | doc.AddMember("messages", messages, allocator); |
619 | 1 | } |
620 | | |
621 | 2 | rapidjson::StringBuffer buffer; |
622 | 2 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
623 | 2 | doc.Accept(writer); |
624 | 2 | request_body = buffer.GetString(); |
625 | | |
626 | 2 | return Status::OK(); |
627 | 2 | } |
628 | | |
629 | | Status parse_response(const std::string& response_body, |
630 | 6 | std::vector<std::string>& results) const override { |
631 | 6 | rapidjson::Document doc; |
632 | 6 | doc.Parse(response_body.c_str()); |
633 | | |
634 | 6 | if (doc.HasParseError() || !doc.IsObject()) { |
635 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
636 | 1 | response_body); |
637 | 1 | } |
638 | | |
639 | 5 | if (doc.HasMember("output") && doc["output"].IsArray()) { |
640 | | /// for responses endpoint |
641 | | /*{ |
642 | | "output": [ |
643 | | { |
644 | | "id": "msg_123", |
645 | | "type": "message", |
646 | | "role": "assistant", |
647 | | "content": [ |
648 | | { |
649 | | "type": "text", |
650 | | "text": "result text here" <- result |
651 | | } |
652 | | ] |
653 | | } |
654 | | ] |
655 | | }*/ |
656 | 1 | const auto& output = doc["output"]; |
657 | 1 | results.reserve(output.Size()); |
658 | | |
659 | 2 | for (rapidjson::SizeType i = 0; i < output.Size(); i++) { |
660 | 1 | if (!output[i].HasMember("content") || !output[i]["content"].IsArray() || |
661 | 1 | output[i]["content"].Empty() || !output[i]["content"][0].HasMember("text") || |
662 | 1 | !output[i]["content"][0]["text"].IsString()) { |
663 | 0 | return Status::InternalError("Invalid output format in {} response: {}", |
664 | 0 | _config.provider_type, response_body); |
665 | 0 | } |
666 | | |
667 | 1 | results.emplace_back(output[i]["content"][0]["text"].GetString()); |
668 | 1 | } |
669 | 4 | } else if (doc.HasMember("choices") && doc["choices"].IsArray()) { |
670 | | /// for completions endpoint |
671 | | /*{ |
672 | | "object": "chat.completion", |
673 | | "model": "gpt-4", |
674 | | "choices": [ |
675 | | { |
676 | | ... |
677 | | "message": { |
678 | | "role": "assistant", |
679 | | "content": "xxx" <- result |
680 | | }, |
681 | | ... |
682 | | } |
683 | | ], |
684 | | ... |
685 | | }*/ |
686 | 3 | const auto& choices = doc["choices"]; |
687 | 3 | results.reserve(choices.Size()); |
688 | | |
689 | 4 | for (rapidjson::SizeType i = 0; i < choices.Size(); i++) { |
690 | 3 | if (!choices[i].HasMember("message") || |
691 | 3 | !choices[i]["message"].HasMember("content") || |
692 | 3 | !choices[i]["message"]["content"].IsString()) { |
693 | 2 | return Status::InternalError("Invalid choice format in {} response: {}", |
694 | 2 | _config.provider_type, response_body); |
695 | 2 | } |
696 | | |
697 | 1 | results.emplace_back(choices[i]["message"]["content"].GetString()); |
698 | 1 | } |
699 | 3 | } else { |
700 | 1 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
701 | 1 | response_body); |
702 | 1 | } |
703 | | |
704 | 2 | return Status::OK(); |
705 | 5 | } |
706 | | |
707 | | protected: |
708 | 2 | bool supports_dimension_param(const std::string& model_name) const override { |
709 | 2 | return !(model_name == "text-embedding-ada-002"); |
710 | 2 | } |
711 | | |
712 | 2 | std::string get_dimension_param_name() const override { return "dimensions"; } |
713 | | }; |
714 | | |
// DeepSeek exposes only chat completions through the OpenAI-compatible API;
// the embedding entry points are explicitly rejected so callers get a clear
// NotSupported error instead of an opaque HTTP failure.
class DeepSeekAdapter : public OpenAIAdapter {
public:
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        return Status::NotSupported("{} does not support the Embed feature.",
                                    _config.provider_type);
    }

    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        return Status::NotSupported("{} does not support the Embed feature.",
                                    _config.provider_type);
    }
};
729 | | |
// Moonshot (Kimi) speaks the OpenAI chat format but offers no embedding API,
// so both embedding entry points are overridden to fail fast.
class MoonShotAdapter : public OpenAIAdapter {
public:
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        return Status::NotSupported("{} does not support the Embed feature.",
                                    _config.provider_type);
    }

    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        return Status::NotSupported("{} does not support the Embed feature.",
                                    _config.provider_type);
    }
};
744 | | |
// MiniMax uses an OpenAI-style chat API but its own embedding request shape,
// so only build_embedding_request is overridden; the response is still parsed
// by the inherited OpenAI/VoyageAI logic.
class MinimaxAdapter : public OpenAIAdapter {
public:
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        /* MiniMax embedding request body (note the field name is "texts",
           plural — the earlier comment said "text", which did not match the
           code below):
        {
            "texts": ["xxx", "xxx", ...],
            "model": "embo-1",
            "type": "db"  <- NOTE(review): presumably "db" = embeddings for
                             storage vs "query" for retrieval; confirm against
                             MiniMax API docs.
        }*/
        rapidjson::Value texts(rapidjson::kArrayType);
        for (const auto& input : inputs) {
            texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator);
        }
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        doc.AddMember("texts", texts, allocator);
        doc.AddMember("type", rapidjson::Value("db", allocator), allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }
};
774 | | |
775 | | class ZhipuAdapter : public OpenAIAdapter { |
776 | | protected: |
777 | 2 | bool supports_dimension_param(const std::string& model_name) const override { |
778 | 2 | return !(model_name == "embedding-2"); |
779 | 2 | } |
780 | | }; |
781 | | |
782 | | class QwenAdapter : public OpenAIAdapter { |
783 | | protected: |
784 | 2 | bool supports_dimension_param(const std::string& model_name) const override { |
785 | 2 | static const std::unordered_set<std::string> no_dimension_models = { |
786 | 2 | "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"}; |
787 | 2 | return !no_dimension_models.contains(model_name); |
788 | 2 | } |
789 | | |
790 | 1 | std::string get_dimension_param_name() const override { return "dimension"; } |
791 | | }; |
792 | | |
793 | | class BaichuanAdapter : public OpenAIAdapter { |
794 | | protected: |
795 | 0 | bool supports_dimension_param(const std::string& model_name) const override { return false; } |
796 | | }; |
797 | | |
798 | | // Gemini's embedding format is different from VoyageAI, so it requires a separate adapter |
799 | | class GeminiAdapter : public AIAdapter { |
800 | | public: |
801 | 2 | Status set_authentication(HttpClient* client) const override { |
802 | 2 | client->set_header("x-goog-api-key", _config.api_key); |
803 | 2 | client->set_content_type("application/json"); |
804 | 2 | return Status::OK(); |
805 | 2 | } |
806 | | |
807 | | Status build_request_payload(const std::vector<std::string>& inputs, |
808 | | const char* const system_prompt, |
809 | 1 | std::string& request_body) const override { |
810 | 1 | rapidjson::Document doc; |
811 | 1 | doc.SetObject(); |
812 | 1 | auto& allocator = doc.GetAllocator(); |
813 | | |
814 | | /*{ |
815 | | "systemInstruction": { |
816 | | "parts": [ |
817 | | { |
818 | | "text": "system_prompt here" |
819 | | } |
820 | | ] |
821 | | } |
822 | | ], |
823 | | "contents": [ |
824 | | { |
825 | | "parts": [ |
826 | | { |
827 | | "text": "xxx" |
828 | | } |
829 | | ] |
830 | | } |
831 | | ], |
832 | | "generationConfig": { |
833 | | "temperature": 0.7, |
834 | | "maxOutputTokens": 1024 |
835 | | } |
836 | | |
837 | | }*/ |
838 | 1 | if (system_prompt && *system_prompt) { |
839 | 1 | rapidjson::Value system_instruction(rapidjson::kObjectType); |
840 | 1 | rapidjson::Value parts(rapidjson::kArrayType); |
841 | | |
842 | 1 | rapidjson::Value part(rapidjson::kObjectType); |
843 | 1 | part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator); |
844 | 1 | parts.PushBack(part, allocator); |
845 | | // system_instruction.PushBack(content, allocator); |
846 | 1 | system_instruction.AddMember("parts", parts, allocator); |
847 | 1 | doc.AddMember("systemInstruction", system_instruction, allocator); |
848 | 1 | } |
849 | | |
850 | 1 | rapidjson::Value contents(rapidjson::kArrayType); |
851 | 1 | for (const auto& input : inputs) { |
852 | 1 | rapidjson::Value content(rapidjson::kObjectType); |
853 | 1 | rapidjson::Value parts(rapidjson::kArrayType); |
854 | | |
855 | 1 | rapidjson::Value part(rapidjson::kObjectType); |
856 | 1 | part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator); |
857 | | |
858 | 1 | parts.PushBack(part, allocator); |
859 | 1 | content.AddMember("parts", parts, allocator); |
860 | 1 | contents.PushBack(content, allocator); |
861 | 1 | } |
862 | 1 | doc.AddMember("contents", contents, allocator); |
863 | | |
864 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
865 | 1 | rapidjson::Value generationConfig(rapidjson::kObjectType); |
866 | 1 | if (_config.temperature != -1) { |
867 | 1 | generationConfig.AddMember("temperature", _config.temperature, allocator); |
868 | 1 | } |
869 | 1 | if (_config.max_tokens != -1) { |
870 | 1 | generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator); |
871 | 1 | } |
872 | 1 | doc.AddMember("generationConfig", generationConfig, allocator); |
873 | | |
874 | 1 | rapidjson::StringBuffer buffer; |
875 | 1 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
876 | 1 | doc.Accept(writer); |
877 | 1 | request_body = buffer.GetString(); |
878 | | |
879 | 1 | return Status::OK(); |
880 | 1 | } |
881 | | |
882 | | Status parse_response(const std::string& response_body, |
883 | 3 | std::vector<std::string>& results) const override { |
884 | 3 | rapidjson::Document doc; |
885 | 3 | doc.Parse(response_body.c_str()); |
886 | | |
887 | 3 | if (doc.HasParseError() || !doc.IsObject()) { |
888 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
889 | 1 | response_body); |
890 | 1 | } |
891 | 2 | if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) { |
892 | 1 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
893 | 1 | response_body); |
894 | 1 | } |
895 | | |
896 | | /*{ |
897 | | "candidates":[ |
898 | | { |
899 | | "content": { |
900 | | "parts": [ |
901 | | { |
902 | | "text": "xxx" |
903 | | } |
904 | | ] |
905 | | } |
906 | | } |
907 | | ] |
908 | | }*/ |
909 | 1 | const auto& candidates = doc["candidates"]; |
910 | 1 | results.reserve(candidates.Size()); |
911 | | |
912 | 2 | for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) { |
913 | 1 | if (!candidates[i].HasMember("content") || |
914 | 1 | !candidates[i]["content"].HasMember("parts") || |
915 | 1 | !candidates[i]["content"]["parts"].IsArray() || |
916 | 1 | candidates[i]["content"]["parts"].Empty() || |
917 | 1 | !candidates[i]["content"]["parts"][0].HasMember("text") || |
918 | 1 | !candidates[i]["content"]["parts"][0]["text"].IsString()) { |
919 | 0 | return Status::InternalError("Invalid candidate format in {} response", |
920 | 0 | _config.provider_type); |
921 | 0 | } |
922 | | |
923 | 1 | results.emplace_back(candidates[i]["content"]["parts"][0]["text"].GetString()); |
924 | 1 | } |
925 | | |
926 | 1 | return Status::OK(); |
927 | 1 | } |
928 | | |
929 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
930 | 2 | std::string& request_body) const override { |
931 | 2 | rapidjson::Document doc; |
932 | 2 | doc.SetObject(); |
933 | 2 | auto& allocator = doc.GetAllocator(); |
934 | | |
935 | | /*{ |
936 | | "model": "models/gemini-embedding-001", |
937 | | "content": { |
938 | | "parts": [ |
939 | | { |
940 | | "text": "xxx" |
941 | | } |
942 | | ] |
943 | | } |
944 | | "outputDimensionality": 1024 |
945 | | }*/ |
946 | | |
947 | | // gemini requires the model format as `models/{model}` |
948 | 2 | std::string model_name = _config.model_name; |
949 | 2 | if (!model_name.starts_with("models/")) { |
950 | 2 | model_name = "models/" + model_name; |
951 | 2 | } |
952 | 2 | doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator); |
953 | 2 | add_dimension_params(doc, allocator); |
954 | | |
955 | 2 | rapidjson::Value content(rapidjson::kObjectType); |
956 | 2 | for (const auto& input : inputs) { |
957 | 2 | rapidjson::Value parts(rapidjson::kArrayType); |
958 | 2 | rapidjson::Value part(rapidjson::kObjectType); |
959 | 2 | part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator); |
960 | 2 | parts.PushBack(part, allocator); |
961 | 2 | content.AddMember("parts", parts, allocator); |
962 | 2 | } |
963 | 2 | doc.AddMember("content", content, allocator); |
964 | | |
965 | 2 | rapidjson::StringBuffer buffer; |
966 | 2 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
967 | 2 | doc.Accept(writer); |
968 | 2 | request_body = buffer.GetString(); |
969 | | |
970 | 2 | return Status::OK(); |
971 | 2 | } |
972 | | |
973 | | Status parse_embedding_response(const std::string& response_body, |
974 | 1 | std::vector<std::vector<float>>& results) const override { |
975 | 1 | rapidjson::Document doc; |
976 | 1 | doc.Parse(response_body.c_str()); |
977 | | |
978 | 1 | if (doc.HasParseError() || !doc.IsObject()) { |
979 | 0 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
980 | 0 | response_body); |
981 | 0 | } |
982 | 1 | if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) { |
983 | 0 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
984 | 0 | response_body); |
985 | 0 | } |
986 | | |
987 | | /*{ |
988 | | "embedding":{ |
989 | | "values": [0.1, 0.2, 0.3] |
990 | | } |
991 | | }*/ |
992 | 1 | const auto& embedding = doc["embedding"]; |
993 | 1 | if (!embedding.HasMember("values") || !embedding["values"].IsArray()) { |
994 | 0 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
995 | 0 | response_body); |
996 | 0 | } |
997 | 1 | std::transform(embedding["values"].Begin(), embedding["values"].End(), |
998 | 1 | std::back_inserter(results.emplace_back()), |
999 | 3 | [](const auto& val) { return val.GetFloat(); }); |
1000 | | |
1001 | 1 | return Status::OK(); |
1002 | 1 | } |
1003 | | |
1004 | | protected: |
1005 | 2 | bool supports_dimension_param(const std::string& model_name) const override { |
1006 | 2 | static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001", |
1007 | 2 | "embedding-001"}; |
1008 | 2 | return !no_dimension_models.contains(model_name); |
1009 | 2 | } |
1010 | | |
1011 | 1 | std::string get_dimension_param_name() const override { return "outputDimensionality"; } |
1012 | | }; |
1013 | | |
1014 | | class AnthropicAdapter : public VoyageAIAdapter { |
1015 | | public: |
1016 | 1 | Status set_authentication(HttpClient* client) const override { |
1017 | 1 | client->set_header("x-api-key", _config.api_key); |
1018 | 1 | client->set_header("anthropic-version", _config.anthropic_version); |
1019 | 1 | client->set_content_type("application/json"); |
1020 | | |
1021 | 1 | return Status::OK(); |
1022 | 1 | } |
1023 | | |
1024 | | Status build_request_payload(const std::vector<std::string>& inputs, |
1025 | | const char* const system_prompt, |
1026 | 1 | std::string& request_body) const override { |
1027 | 1 | rapidjson::Document doc; |
1028 | 1 | doc.SetObject(); |
1029 | 1 | auto& allocator = doc.GetAllocator(); |
1030 | | |
1031 | | /* |
1032 | | "model": "claude-opus-4-1-20250805", |
1033 | | "max_tokens": 1024, |
1034 | | "system": "system_prompt here", |
1035 | | "messages": [ |
1036 | | {"role": "user", "content": "xxx"} |
1037 | | ], |
1038 | | "temperature": 0.7 |
1039 | | */ |
1040 | | |
1041 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
1042 | 1 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
1043 | 1 | if (_config.temperature != -1) { |
1044 | 1 | doc.AddMember("temperature", _config.temperature, allocator); |
1045 | 1 | } |
1046 | 1 | if (_config.max_tokens != -1) { |
1047 | 1 | doc.AddMember("max_tokens", _config.max_tokens, allocator); |
1048 | 1 | } else { |
1049 | | // Keep the default value, Anthropic requires this parameter |
1050 | 0 | doc.AddMember("max_tokens", 2048, allocator); |
1051 | 0 | } |
1052 | 1 | if (system_prompt && *system_prompt) { |
1053 | 1 | doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator); |
1054 | 1 | } |
1055 | | |
1056 | 1 | rapidjson::Value messages(rapidjson::kArrayType); |
1057 | 1 | for (const auto& input : inputs) { |
1058 | 1 | rapidjson::Value message(rapidjson::kObjectType); |
1059 | 1 | message.AddMember("role", "user", allocator); |
1060 | 1 | message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator); |
1061 | 1 | messages.PushBack(message, allocator); |
1062 | 1 | } |
1063 | 1 | doc.AddMember("messages", messages, allocator); |
1064 | | |
1065 | 1 | rapidjson::StringBuffer buffer; |
1066 | 1 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
1067 | 1 | doc.Accept(writer); |
1068 | 1 | request_body = buffer.GetString(); |
1069 | | |
1070 | 1 | return Status::OK(); |
1071 | 1 | } |
1072 | | |
1073 | | Status parse_response(const std::string& response_body, |
1074 | 3 | std::vector<std::string>& results) const override { |
1075 | 3 | rapidjson::Document doc; |
1076 | 3 | doc.Parse(response_body.c_str()); |
1077 | 3 | if (doc.HasParseError() || !doc.IsObject()) { |
1078 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
1079 | 1 | response_body); |
1080 | 1 | } |
1081 | 2 | if (!doc.HasMember("content") || !doc["content"].IsArray()) { |
1082 | 1 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
1083 | 1 | response_body); |
1084 | 1 | } |
1085 | | |
1086 | | /*{ |
1087 | | "content": [ |
1088 | | { |
1089 | | "text": "xxx", |
1090 | | "type": "text" |
1091 | | } |
1092 | | ] |
1093 | | }*/ |
1094 | 1 | const auto& content = doc["content"]; |
1095 | 1 | results.reserve(1); |
1096 | | |
1097 | 1 | std::string result; |
1098 | 2 | for (rapidjson::SizeType i = 0; i < content.Size(); i++) { |
1099 | 1 | if (!content[i].HasMember("type") || !content[i]["type"].IsString() || |
1100 | 1 | !content[i].HasMember("text") || !content[i]["text"].IsString()) { |
1101 | 0 | continue; |
1102 | 0 | } |
1103 | | |
1104 | 1 | if (std::string(content[i]["type"].GetString()) == "text") { |
1105 | 1 | if (!result.empty()) { |
1106 | 0 | result += "\n"; |
1107 | 0 | } |
1108 | 1 | result += content[i]["text"].GetString(); |
1109 | 1 | } |
1110 | 1 | } |
1111 | | |
1112 | 1 | results.emplace_back(std::move(result)); |
1113 | 1 | return Status::OK(); |
1114 | 2 | } |
1115 | | }; |
1116 | | |
1117 | | // Mock adapter used only for UT to bypass real HTTP calls and return deterministic data. |
1118 | | class MockAdapter : public AIAdapter { |
1119 | | public: |
1120 | 0 | Status set_authentication(HttpClient* client) const override { return Status::OK(); } |
1121 | | |
1122 | | Status build_request_payload(const std::vector<std::string>& inputs, |
1123 | | const char* const system_prompt, |
1124 | 49 | std::string& request_body) const override { |
1125 | 49 | return Status::OK(); |
1126 | 49 | } |
1127 | | |
1128 | | Status parse_response(const std::string& response_body, |
1129 | 49 | std::vector<std::string>& results) const override { |
1130 | 49 | results.emplace_back(response_body); |
1131 | 49 | return Status::OK(); |
1132 | 49 | } |
1133 | | |
1134 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
1135 | 1 | std::string& request_body) const override { |
1136 | 1 | return Status::OK(); |
1137 | 1 | } |
1138 | | |
1139 | | Status parse_embedding_response(const std::string& response_body, |
1140 | 1 | std::vector<std::vector<float>>& results) const override { |
1141 | 1 | rapidjson::Document doc; |
1142 | 1 | doc.SetObject(); |
1143 | 1 | doc.Parse(response_body.c_str()); |
1144 | 1 | if (doc.HasParseError() || !doc.IsObject()) { |
1145 | 0 | return Status::InternalError("Failed to parse embedding response"); |
1146 | 0 | } |
1147 | 1 | if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) { |
1148 | 0 | return Status::InternalError("Invalid embedding response format"); |
1149 | 0 | } |
1150 | | |
1151 | 1 | results.reserve(1); |
1152 | 1 | std::transform(doc["embedding"].Begin(), doc["embedding"].End(), |
1153 | 1 | std::back_inserter(results.emplace_back()), |
1154 | 5 | [](const auto& val) { return val.GetFloat(); }); |
1155 | 1 | return Status::OK(); |
1156 | 1 | } |
1157 | | }; |
1158 | | |
1159 | | class AIAdapterFactory { |
1160 | | public: |
1161 | 74 | static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) { |
1162 | 74 | static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>> |
1163 | 74 | adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }}, |
1164 | 74 | {"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }}, |
1165 | 74 | {"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }}, |
1166 | 74 | {"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }}, |
1167 | 74 | {"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }}, |
1168 | 74 | {"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }}, |
1169 | 74 | {"QWEN", []() { return std::make_shared<QwenAdapter>(); }}, |
1170 | 74 | {"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }}, |
1171 | 74 | {"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }}, |
1172 | 74 | {"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }}, |
1173 | 74 | {"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }}, |
1174 | 74 | {"MOCK", []() { return std::make_shared<MockAdapter>(); }}}; |
1175 | | |
1176 | 74 | auto it = adapters.find(provider_type); |
1177 | 74 | return (it != adapters.end()) ? it->second() : nullptr; |
1178 | 74 | } |
1179 | | }; |
1180 | | |
1181 | | #include "common/compile_check_end.h" |
1182 | | } // namespace doris |