be/src/exprs/function/ai/ai_adapter.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
#include <gen_cpp/PaloInternalService_types.h>
#include <rapidjson/rapidjson.h>

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/status.h"
#include "core/string_buffer.hpp"
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include "service/http/http_client.h"
#include "service/http/http_headers.h"
#include "util/security.h"
38 | | |
39 | | namespace doris { |
40 | | |
41 | | struct AIResource { |
42 | 19 | AIResource() = default; |
43 | | AIResource(const TAIResource& tai) |
44 | 12 | : endpoint(tai.endpoint), |
45 | 12 | provider_type(tai.provider_type), |
46 | 12 | model_name(tai.model_name), |
47 | 12 | api_key(tai.api_key), |
48 | 12 | temperature(tai.temperature), |
49 | 12 | max_tokens(tai.max_tokens), |
50 | 12 | max_retries(tai.max_retries), |
51 | 12 | retry_delay_second(tai.retry_delay_second), |
52 | 12 | anthropic_version(tai.anthropic_version), |
53 | 12 | dimensions(tai.dimensions) {} |
54 | | |
55 | | std::string endpoint; |
56 | | std::string provider_type; |
57 | | std::string model_name; |
58 | | std::string api_key; |
59 | | double temperature; |
60 | | int64_t max_tokens; |
61 | | int32_t max_retries; |
62 | | int32_t retry_delay_second; |
63 | | std::string anthropic_version; |
64 | | int32_t dimensions; |
65 | | |
66 | 1 | void serialize(BufferWritable& buf) const { |
67 | 1 | buf.write_binary(endpoint); |
68 | 1 | buf.write_binary(provider_type); |
69 | 1 | buf.write_binary(model_name); |
70 | 1 | buf.write_binary(api_key); |
71 | 1 | buf.write_binary(temperature); |
72 | 1 | buf.write_binary(max_tokens); |
73 | 1 | buf.write_binary(max_retries); |
74 | 1 | buf.write_binary(retry_delay_second); |
75 | 1 | buf.write_binary(anthropic_version); |
76 | 1 | buf.write_binary(dimensions); |
77 | 1 | } |
78 | | |
79 | 1 | void deserialize(BufferReadable& buf) { |
80 | 1 | buf.read_binary(endpoint); |
81 | 1 | buf.read_binary(provider_type); |
82 | 1 | buf.read_binary(model_name); |
83 | 1 | buf.read_binary(api_key); |
84 | 1 | buf.read_binary(temperature); |
85 | 1 | buf.read_binary(max_tokens); |
86 | 1 | buf.read_binary(max_retries); |
87 | 1 | buf.read_binary(retry_delay_second); |
88 | 1 | buf.read_binary(anthropic_version); |
89 | 1 | buf.read_binary(dimensions); |
90 | 1 | } |
91 | | }; |
92 | | |
// Categories of media accepted by multimodal embedding endpoints.
enum class MultimodalType { IMAGE, VIDEO, AUDIO };

// Maps a MultimodalType to its lowercase display/wire name; any value
// outside the enumerators yields "unknown".
inline const char* multimodal_type_to_string(MultimodalType type) {
    if (type == MultimodalType::IMAGE) {
        return "image";
    }
    if (type == MultimodalType::VIDEO) {
        return "video";
    }
    if (type == MultimodalType::AUDIO) {
        return "audio";
    }
    return "unknown";
}
106 | | |
107 | | class AIAdapter { |
108 | | public: |
109 | 167 | virtual ~AIAdapter() = default; |
110 | | |
111 | | // Set authentication headers for the HTTP client |
112 | | virtual Status set_authentication(HttpClient* client) const = 0; |
113 | | |
114 | 123 | virtual void init(const TAIResource& config) { _config = config; } |
115 | 13 | virtual void init(const AIResource& config) { |
116 | 13 | _config.endpoint = config.endpoint; |
117 | 13 | _config.provider_type = config.provider_type; |
118 | 13 | _config.model_name = config.model_name; |
119 | 13 | _config.api_key = config.api_key; |
120 | 13 | _config.temperature = config.temperature; |
121 | 13 | _config.max_tokens = config.max_tokens; |
122 | 13 | _config.max_retries = config.max_retries; |
123 | 13 | _config.retry_delay_second = config.retry_delay_second; |
124 | 13 | _config.anthropic_version = config.anthropic_version; |
125 | 13 | } |
126 | | |
127 | | // Build request payload based on input text strings |
128 | | virtual Status build_request_payload(const std::vector<std::string>& inputs, |
129 | | const char* const system_prompt, |
130 | 1 | std::string& request_body) const { |
131 | 1 | return Status::NotSupported("{} don't support text generation", _config.provider_type); |
132 | 1 | } |
133 | | |
134 | | // Parse response from AI service and extract generated text results |
135 | | virtual Status parse_response(const std::string& response_body, |
136 | 1 | std::vector<std::string>& results) const { |
137 | 1 | return Status::NotSupported("{} don't support text generation", _config.provider_type); |
138 | 1 | } |
139 | | |
140 | | virtual Status build_embedding_request(const std::vector<std::string>& inputs, |
141 | 0 | std::string& request_body) const { |
142 | 0 | return embed_not_supported_status(); |
143 | 0 | } |
144 | | |
145 | | virtual Status build_multimodal_embedding_request( |
146 | | const std::vector<MultimodalType>& /*media_types*/, |
147 | | const std::vector<std::string>& /*media_urls*/, |
148 | | const std::vector<std::string>& /*media_content_types*/, |
149 | 0 | std::string& /*request_body*/) const { |
150 | 0 | return Status::NotSupported("{} does not support multimodal Embed feature.", |
151 | 0 | _config.provider_type); |
152 | 0 | } |
153 | | |
154 | | virtual Status parse_embedding_response(const std::string& response_body, |
155 | 0 | std::vector<std::vector<float>>& results) const { |
156 | 0 | return embed_not_supported_status(); |
157 | 0 | } |
158 | | |
159 | | protected: |
160 | | TAIResource _config; |
161 | | |
162 | 4 | Status embed_not_supported_status() const { |
163 | 4 | return Status::NotSupported( |
164 | 4 | "{} does not support the Embed feature. Currently supported providers are " |
165 | 4 | "OpenAI, Gemini, Voyage, Jina, Qwen, and Minimax.", |
166 | 4 | _config.provider_type); |
167 | 4 | } |
168 | | |
169 | | // Appends one provider-parsed text result to `results`. |
170 | | // The adapter has already parsed the provider's outer response envelope before calling here. |
171 | | // Example: |
172 | | // provider response -> choices[0].message.content = "[\"1\",\"0\",\"1\"]" |
173 | | // this helper -> appends "1", "0", "1" into `results` |
174 | | static Status append_parsed_text_result(std::string_view text, |
175 | 87 | std::vector<std::string>& results) { |
176 | 87 | size_t begin = 0; |
177 | 87 | size_t end = text.size(); |
178 | 117 | while (begin < end && std::isspace(static_cast<unsigned char>(text[begin]))) { |
179 | 30 | ++begin; |
180 | 30 | } |
181 | 111 | while (begin < end && std::isspace(static_cast<unsigned char>(text[end - 1]))) { |
182 | 24 | --end; |
183 | 24 | } |
184 | | |
185 | 87 | if (begin < end && text[begin] == '[' && text[end - 1] == ']') { |
186 | 66 | rapidjson::Document doc; |
187 | 66 | doc.Parse(text.data() + begin, end - begin); |
188 | 66 | if (!doc.HasParseError() && doc.IsArray()) { |
189 | 139 | for (rapidjson::SizeType i = 0; i < doc.Size(); ++i) { |
190 | 76 | if (!doc[i].IsString()) { |
191 | 1 | return Status::InternalError( |
192 | 1 | "Invalid batch result format, array element {} is not a string", i); |
193 | 1 | } |
194 | 75 | results.emplace_back(doc[i].GetString(), doc[i].GetStringLength()); |
195 | 75 | } |
196 | 63 | return Status::OK(); |
197 | 64 | } |
198 | 66 | } |
199 | | |
200 | 23 | results.emplace_back(text.data(), text.size()); |
201 | 23 | return Status::OK(); |
202 | 87 | } |
203 | | |
204 | | // return true if the model support dimension parameter |
205 | 1 | virtual bool supports_dimension_param(const std::string& model_name) const { return false; } |
206 | | |
207 | | // Different providers may have different dimension parameter names. |
208 | 0 | virtual std::string get_dimension_param_name() const { return "dimensions"; } |
209 | | |
210 | | virtual void add_dimension_params(rapidjson::Value& doc, |
211 | 20 | rapidjson::Document::AllocatorType& allocator) const { |
212 | 20 | if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) { |
213 | 13 | std::string param_name = get_dimension_param_name(); |
214 | 13 | rapidjson::Value name(param_name.c_str(), allocator); |
215 | 13 | doc.AddMember(name, _config.dimensions, allocator); |
216 | 13 | } |
217 | 20 | } |
218 | | |
219 | | // Validates common multimodal embedding request invariants shared by providers. |
220 | | Status validate_multimodal_embedding_inputs( |
221 | | std::string_view provider_name, const std::vector<MultimodalType>& media_types, |
222 | | const std::vector<std::string>& media_urls, |
223 | 16 | std::initializer_list<MultimodalType> supported_types) const { |
224 | 16 | if (media_urls.empty()) { |
225 | 1 | return Status::InvalidArgument("{} multimodal embed inputs can not be empty", |
226 | 1 | provider_name); |
227 | 1 | } |
228 | 15 | if (media_types.size() != media_urls.size()) { |
229 | 1 | return Status::InvalidArgument( |
230 | 1 | "{} multimodal embed input size mismatch, media_types={}, media_urls={}", |
231 | 1 | provider_name, media_types.size(), media_urls.size()); |
232 | 1 | } |
233 | 19 | for (MultimodalType media_type : media_types) { |
234 | 19 | bool supported = false; |
235 | 31 | for (MultimodalType supported_type : supported_types) { |
236 | 31 | if (media_type == supported_type) { |
237 | 18 | supported = true; |
238 | 18 | break; |
239 | 18 | } |
240 | 31 | } |
241 | 19 | if (!supported) [[unlikely]] { |
242 | 1 | return Status::InvalidArgument( |
243 | 1 | "{} only supports {} multimodal embed, got {}", provider_name, |
244 | 1 | supported_multimodal_types_to_string(supported_types), |
245 | 1 | multimodal_type_to_string(media_type)); |
246 | 1 | } |
247 | 19 | } |
248 | 13 | return Status::OK(); |
249 | 14 | } |
250 | | |
251 | | static std::string supported_multimodal_types_to_string( |
252 | 1 | std::initializer_list<MultimodalType> supported_types) { |
253 | 1 | std::string result; |
254 | 2 | for (MultimodalType type : supported_types) { |
255 | 2 | if (!result.empty()) { |
256 | 1 | result += "/"; |
257 | 1 | } |
258 | 2 | result += multimodal_type_to_string(type); |
259 | 2 | } |
260 | 1 | return result; |
261 | 1 | } |
262 | | }; |
263 | | |
264 | | // Most LLM-providers' Embedding formats are based on VoyageAI. |
265 | | // The following adapters inherit from VoyageAIAdapter to directly reuse its embedding logic. |
266 | | class VoyageAIAdapter : public AIAdapter { |
267 | | public: |
268 | 2 | Status set_authentication(HttpClient* client) const override { |
269 | 2 | client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key); |
270 | 2 | client->set_content_type("application/json"); |
271 | | |
272 | 2 | return Status::OK(); |
273 | 2 | } |
274 | | |
275 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
276 | 8 | std::string& request_body) const override { |
277 | 8 | rapidjson::Document doc; |
278 | 8 | doc.SetObject(); |
279 | 8 | auto& allocator = doc.GetAllocator(); |
280 | | |
281 | | /*{ |
282 | | "model": "xxx", |
283 | | "input": [ |
284 | | "xxx", |
285 | | "xxx", |
286 | | ... |
287 | | ], |
288 | | "output_dimensions": 512 |
289 | | }*/ |
290 | 8 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
291 | 8 | add_dimension_params(doc, allocator); |
292 | | |
293 | 8 | rapidjson::Value input(rapidjson::kArrayType); |
294 | 8 | for (const auto& msg : inputs) { |
295 | 8 | input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator); |
296 | 8 | } |
297 | 8 | doc.AddMember("input", input, allocator); |
298 | | |
299 | 8 | rapidjson::StringBuffer buffer; |
300 | 8 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
301 | 8 | doc.Accept(writer); |
302 | 8 | request_body = buffer.GetString(); |
303 | | |
304 | 8 | return Status::OK(); |
305 | 8 | } |
306 | | |
307 | | Status build_multimodal_embedding_request( |
308 | | const std::vector<MultimodalType>& media_types, |
309 | | const std::vector<std::string>& media_urls, |
310 | | const std::vector<std::string>& /*media_content_types*/, |
311 | 2 | std::string& request_body) const override { |
312 | 2 | RETURN_IF_ERROR(validate_multimodal_embedding_inputs( |
313 | 2 | "VoyageAI", media_types, media_urls, |
314 | 2 | {MultimodalType::IMAGE, MultimodalType::VIDEO})); |
315 | 2 | if (_config.dimensions != -1) { |
316 | 2 | LOG(WARNING) << "VoyageAI multimodal embedding currently ignores dimensions parameter, " |
317 | 2 | << "model=" << _config.model_name << ", dimensions=" << _config.dimensions; |
318 | 2 | } |
319 | | |
320 | 2 | rapidjson::Document doc; |
321 | 2 | doc.SetObject(); |
322 | 2 | auto& allocator = doc.GetAllocator(); |
323 | | |
324 | | /*{ |
325 | | "inputs": [ |
326 | | { |
327 | | "content": [ |
328 | | {"type": "image_url", "image_url": "<url>"} |
329 | | ] |
330 | | }, |
331 | | { |
332 | | "content": [ |
333 | | {"type": "video_url", "video_url": "<url>"} |
334 | | ] |
335 | | } |
336 | | ], |
337 | | "model": "voyage-multimodal-3.5" |
338 | | }*/ |
339 | 2 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
340 | | |
341 | 2 | rapidjson::Value request_inputs(rapidjson::kArrayType); |
342 | 5 | for (size_t i = 0; i < media_urls.size(); ++i) { |
343 | 3 | rapidjson::Value input(rapidjson::kObjectType); |
344 | 3 | rapidjson::Value content(rapidjson::kArrayType); |
345 | 3 | rapidjson::Value media_item(rapidjson::kObjectType); |
346 | 3 | if (media_types[i] == MultimodalType::IMAGE) { |
347 | 1 | media_item.AddMember("type", "image_url", allocator); |
348 | 1 | media_item.AddMember("image_url", |
349 | 1 | rapidjson::Value(media_urls[i].c_str(), allocator), allocator); |
350 | 2 | } else { |
351 | 2 | media_item.AddMember("type", "video_url", allocator); |
352 | 2 | media_item.AddMember("video_url", |
353 | 2 | rapidjson::Value(media_urls[i].c_str(), allocator), allocator); |
354 | 2 | } |
355 | 3 | content.PushBack(media_item, allocator); |
356 | 3 | input.AddMember("content", content, allocator); |
357 | 3 | request_inputs.PushBack(input, allocator); |
358 | 3 | } |
359 | | |
360 | 2 | doc.AddMember("inputs", request_inputs, allocator); |
361 | | |
362 | 2 | rapidjson::StringBuffer buffer; |
363 | 2 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
364 | 2 | doc.Accept(writer); |
365 | 2 | request_body = buffer.GetString(); |
366 | 2 | return Status::OK(); |
367 | 2 | } |
368 | | |
369 | | Status parse_embedding_response(const std::string& response_body, |
370 | 5 | std::vector<std::vector<float>>& results) const override { |
371 | 5 | rapidjson::Document doc; |
372 | 5 | doc.Parse(response_body.c_str()); |
373 | | |
374 | 5 | if (doc.HasParseError() || !doc.IsObject()) { |
375 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
376 | 1 | response_body); |
377 | 1 | } |
378 | 4 | if (!doc.HasMember("data") || !doc["data"].IsArray()) { |
379 | 1 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
380 | 1 | response_body); |
381 | 1 | } |
382 | | |
383 | | /*{ |
384 | | "data":[ |
385 | | { |
386 | | "object": "embedding", |
387 | | "embedding": [...], <- only need this |
388 | | "index": 0 |
389 | | }, |
390 | | { |
391 | | "object": "embedding", |
392 | | "embedding": [...], |
393 | | "index": 1 |
394 | | }, ... |
395 | | ], |
396 | | "model".... |
397 | | }*/ |
398 | 3 | const auto& data = doc["data"]; |
399 | 3 | results.reserve(data.Size()); |
400 | 7 | for (rapidjson::SizeType i = 0; i < data.Size(); i++) { |
401 | 5 | if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) { |
402 | 1 | return Status::InternalError("Invalid {} response format: {}", |
403 | 1 | _config.provider_type, response_body); |
404 | 1 | } |
405 | | |
406 | 4 | std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(), |
407 | 4 | std::back_inserter(results.emplace_back()), |
408 | 10 | [](const auto& val) { return val.GetFloat(); }); |
409 | 4 | } |
410 | | |
411 | 2 | return Status::OK(); |
412 | 3 | } |
413 | | |
414 | | protected: |
415 | 4 | bool supports_dimension_param(const std::string& model_name) const override { |
416 | 4 | static const std::unordered_set<std::string> no_dimension_models = { |
417 | 4 | "voyage-law-2", "voyage-2", "voyage-code-2", "voyage-finance-2", |
418 | 4 | "voyage-multimodal-3"}; |
419 | 4 | return !no_dimension_models.contains(model_name); |
420 | 4 | } |
421 | | |
422 | 1 | std::string get_dimension_param_name() const override { return "output_dimension"; } |
423 | | }; |
424 | | |
425 | | // Local AI adapter for locally hosted models (Ollama, LLaMA, etc.) |
class LocalAdapter : public AIAdapter {
public:
    // Local deployments typically don't need authentication
    Status set_authentication(HttpClient* client) const override {
        client->set_content_type("application/json");
        return Status::OK();
    }

    // Dispatches on the endpoint suffix: ".../chat" or ".../generate" gets
    // the Ollama request shape, anything else the OpenAI-compatible default.
    // The helpers only populate `doc`; serialization into `request_body`
    // happens once here.
    Status build_request_payload(const std::vector<std::string>& inputs,
                                 const char* const system_prompt,
                                 std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        std::string end_point = _config.endpoint;
        if (end_point.ends_with("chat") || end_point.ends_with("generate")) {
            RETURN_IF_ERROR(
                    build_ollama_request(doc, allocator, inputs, system_prompt, request_body));
        } else {
            RETURN_IF_ERROR(
                    build_default_request(doc, allocator, inputs, system_prompt, request_body));
        }

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    // Tries the known local-LLM response shapes in order and extracts the
    // generated text via append_parsed_text_result (which also unwraps
    // JSON-array batch results). Unrecognized shapes are an error.
    Status parse_response(const std::string& response_body,
                          std::vector<std::string>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }

        // Handle various response formats from local LLMs
        // Format 1: OpenAI-compatible format with choices/message/content
        if (doc.HasMember("choices") && doc["choices"].IsArray()) {
            const auto& choices = doc["choices"];
            results.reserve(choices.Size());

            // NOTE(review): a choice matching neither "message.content" nor
            // "text" is silently skipped (no error) — confirm intentional.
            for (rapidjson::SizeType i = 0; i < choices.Size(); i++) {
                if (choices[i].HasMember("message") && choices[i]["message"].HasMember("content") &&
                    choices[i]["message"]["content"].IsString()) {
                    RETURN_IF_ERROR(append_parsed_text_result(
                            choices[i]["message"]["content"].GetString(), results));
                } else if (choices[i].HasMember("text") && choices[i]["text"].IsString()) {
                    // Some local LLMs use a simpler format
                    RETURN_IF_ERROR(
                            append_parsed_text_result(choices[i]["text"].GetString(), results));
                }
            }
        } else if (doc.HasMember("text") && doc["text"].IsString()) {
            // Format 2: Simple response with just "text" or "content" field
            RETURN_IF_ERROR(append_parsed_text_result(doc["text"].GetString(), results));
        } else if (doc.HasMember("content") && doc["content"].IsString()) {
            RETURN_IF_ERROR(append_parsed_text_result(doc["content"].GetString(), results));
        } else if (doc.HasMember("response") && doc["response"].IsString()) {
            // Format 3: Response field (Ollama `generate` format)
            RETURN_IF_ERROR(append_parsed_text_result(doc["response"].GetString(), results));
        } else if (doc.HasMember("message") && doc["message"].IsObject() &&
                   doc["message"].HasMember("content") && doc["message"]["content"].IsString()) {
            // Format 4: message/content field (Ollama `chat` format)
            RETURN_IF_ERROR(
                    append_parsed_text_result(doc["message"]["content"].GetString(), results));
        } else {
            return Status::NotSupported("Unsupported response format from local AI.");
        }
        return Status::OK();
    }

    // Builds an OpenAI-style embedding request; "model" is optional for
    // local servers, so it is only emitted when configured.
    Status build_embedding_request(const std::vector<std::string>& inputs,
                                   std::string& request_body) const override {
        rapidjson::Document doc;
        doc.SetObject();
        auto& allocator = doc.GetAllocator();

        if (!_config.model_name.empty()) {
            doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator),
                          allocator);
        }

        add_dimension_params(doc, allocator);

        rapidjson::Value input(rapidjson::kArrayType);
        for (const auto& msg : inputs) {
            input.PushBack(rapidjson::Value(msg.c_str(), allocator), allocator);
        }
        doc.AddMember("input", input, allocator);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        doc.Accept(writer);
        request_body = buffer.GetString();

        return Status::OK();
    }

    // Local deployments: multimodal embedding is explicitly unsupported.
    Status build_multimodal_embedding_request(
            const std::vector<MultimodalType>& /*media_types*/,
            const std::vector<std::string>& /*media_urls*/,
            const std::vector<std::string>& /*media_content_types*/,
            std::string& /*request_body*/) const override {
        return Status::NotSupported("{} does not support multimodal Embed feature.",
                                    _config.provider_type);
    }

    // Accepts three embedding response shapes (see the inline examples):
    // OpenAI-style "data", Ollama-style "embeddings", and a bare "embedding".
    Status parse_embedding_response(const std::string& response_body,
                                    std::vector<std::vector<float>>& results) const override {
        rapidjson::Document doc;
        doc.Parse(response_body.c_str());

        if (doc.HasParseError() || !doc.IsObject()) {
            return Status::InternalError("Failed to parse {} response: {}", _config.provider_type,
                                         response_body);
        }

        // parse different response format
        rapidjson::Value embedding;
        if (doc.HasMember("data") && doc["data"].IsArray()) {
            // "data":["object":"embedding", "embedding":[0.1, 0.2...], "index":0]
            const auto& data = doc["data"];
            results.reserve(data.Size());
            for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
                if (!data[i].HasMember("embedding") || !data[i]["embedding"].IsArray()) {
                    return Status::InternalError("Invalid {} response format",
                                                 _config.provider_type);
                }

                std::transform(data[i]["embedding"].Begin(), data[i]["embedding"].End(),
                               std::back_inserter(results.emplace_back()),
                               [](const auto& val) { return val.GetFloat(); });
            }
        } else if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) {
            // "embeddings":[[0.1, 0.2, ...]]
            // NOTE(review): rapidjson's Value assignment is a destructive
            // move — `embedding = doc["embeddings"][i]` empties element i in
            // `doc`. Each element is visited once, so this works, but it
            // mutates the parsed document.
            results.reserve(1);
            for (int i = 0; i < doc["embeddings"].Size(); i++) {
                embedding = doc["embeddings"][i];
                std::transform(embedding.Begin(), embedding.End(),
                               std::back_inserter(results.emplace_back()),
                               [](const auto& val) { return val.GetFloat(); });
            }
        } else if (doc.HasMember("embedding") && doc["embedding"].IsArray()) {
            // "embedding":[0.1, 0.2, ...]
            results.reserve(1);
            embedding = doc["embedding"];
            std::transform(embedding.Begin(), embedding.End(),
                           std::back_inserter(results.emplace_back()),
                           [](const auto& val) { return val.GetFloat(); });
        } else {
            return Status::InternalError("Invalid {} response format: {}", _config.provider_type,
                                         response_body);
        }

        return Status::OK();
    }

private:
    // Populates `doc` with the Ollama request shape (see comment below).
    // NOTE(review): `request_body` is not written here — serialization is
    // done by the caller; the parameter appears unused.
    Status build_ollama_request(rapidjson::Document& doc,
                                rapidjson::Document::AllocatorType& allocator,
                                const std::vector<std::string>& inputs,
                                const char* const system_prompt, std::string& request_body) const {
        /*
        for endpoints end_with `/chat` like 'http://localhost:11434/api/chat':
        {
            "model": <model_name>,
            "stream": false,
            "think": false,
            "options": {
                "temperature": <temperature>,
                "max_token": <max_token>
            },
            "messages": [
                {"role": "system", "content": <system_prompt>},
                {"role": "user", "content": <user_prompt>}
            ]
        }

        for endpoints end_with `/generate` like 'http://localhost:11434/api/generate':
        {
            "model": <model_name>,
            "stream": false,
            "think": false
            "options": {
                "temperature": <temperature>,
                "max_token": <max_token>
            },
            "system": <system_prompt>,
            "prompt": <user_prompt>
        }
        */

        // For Ollama, only the prompt section ("system" + "prompt" or "role" + "content") is affected by the endpoint;
        // The rest remains identical.
        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);
        doc.AddMember("stream", false, allocator);
        doc.AddMember("think", false, allocator);

        // option section: -1 means "unset" — omit from the request.
        rapidjson::Value options(rapidjson::kObjectType);
        if (_config.temperature != -1) {
            options.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            options.AddMember("max_token", _config.max_tokens, allocator);
        }
        doc.AddMember("options", options, allocator);

        // prompt section
        if (_config.endpoint.ends_with("chat")) {
            rapidjson::Value messages(rapidjson::kArrayType);
            if (system_prompt && *system_prompt) {
                rapidjson::Value sys_msg(rapidjson::kObjectType);
                sys_msg.AddMember("role", "system", allocator);
                sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
                messages.PushBack(sys_msg, allocator);
            }
            for (const auto& input : inputs) {
                rapidjson::Value message(rapidjson::kObjectType);
                message.AddMember("role", "user", allocator);
                message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
                messages.PushBack(message, allocator);
            }
            doc.AddMember("messages", messages, allocator);
        } else {
            // `generate` endpoint takes a single prompt.
            // NOTE(review): assumes inputs is non-empty; extra inputs beyond
            // inputs[0] are ignored on this path — confirm callers guarantee this.
            if (system_prompt && *system_prompt) {
                doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator);
            }
            doc.AddMember("prompt", rapidjson::Value(inputs[0].c_str(), allocator), allocator);
        }

        return Status::OK();
    }

    // Populates `doc` with the OpenAI-compatible chat request shape.
    // NOTE(review): as above, `request_body` is unused here; the caller
    // serializes `doc`.
    Status build_default_request(rapidjson::Document& doc,
                                 rapidjson::Document::AllocatorType& allocator,
                                 const std::vector<std::string>& inputs,
                                 const char* const system_prompt, std::string& request_body) const {
        /*
        Default format(OpenAI-compatible):
        {
            "model": <model_name>,
            "temperature": <temperature>,
            "max_tokens": <max_tokens>,
            "messages": [
                {"role": "system", "content": <system_prompt>},
                {"role": "user", "content": <user_prompt>}
            ]
        }
        */

        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator);

        // If 'temperature' and 'max_tokens' are set, add them to the request body.
        if (_config.temperature != -1) {
            doc.AddMember("temperature", _config.temperature, allocator);
        }
        if (_config.max_tokens != -1) {
            doc.AddMember("max_tokens", _config.max_tokens, allocator);
        }

        rapidjson::Value messages(rapidjson::kArrayType);
        if (system_prompt && *system_prompt) {
            rapidjson::Value sys_msg(rapidjson::kObjectType);
            sys_msg.AddMember("role", "system", allocator);
            sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator);
            messages.PushBack(sys_msg, allocator);
        }
        for (const auto& input : inputs) {
            rapidjson::Value message(rapidjson::kObjectType);
            message.AddMember("role", "user", allocator);
            message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator);
            messages.PushBack(message, allocator);
        }
        doc.AddMember("messages", messages, allocator);
        return Status::OK();
    }
};
711 | | |
712 | | // The OpenAI API format can be reused with some compatible AIs. |
713 | | class OpenAIAdapter : public VoyageAIAdapter { |
714 | | public: |
715 | 13 | Status set_authentication(HttpClient* client) const override { |
716 | 13 | client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + _config.api_key); |
717 | 13 | client->set_content_type("application/json"); |
718 | | |
719 | 13 | return Status::OK(); |
720 | 13 | } |
721 | | |
722 | | Status build_request_payload(const std::vector<std::string>& inputs, |
723 | | const char* const system_prompt, |
724 | 4 | std::string& request_body) const override { |
725 | 4 | rapidjson::Document doc; |
726 | 4 | doc.SetObject(); |
727 | 4 | auto& allocator = doc.GetAllocator(); |
728 | | |
729 | 4 | if (_config.endpoint.ends_with("responses")) { |
730 | | /*{ |
731 | | "model": "gpt-4.1-mini", |
732 | | "input": [ |
733 | | {"role": "system", "content": "system_prompt here"}, |
734 | | {"role": "user", "content": "xxx"} |
735 | | ], |
736 | | "temperature": 0.7, |
737 | | "max_output_tokens": 150 |
738 | | }*/ |
739 | 1 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), |
740 | 1 | allocator); |
741 | | |
742 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
743 | 1 | if (_config.temperature != -1) { |
744 | 1 | doc.AddMember("temperature", _config.temperature, allocator); |
745 | 1 | } |
746 | 1 | if (_config.max_tokens != -1) { |
747 | 1 | doc.AddMember("max_output_tokens", _config.max_tokens, allocator); |
748 | 1 | } |
749 | | |
750 | | // input |
751 | 1 | rapidjson::Value input(rapidjson::kArrayType); |
752 | 1 | if (system_prompt && *system_prompt) { |
753 | 1 | rapidjson::Value sys_msg(rapidjson::kObjectType); |
754 | 1 | sys_msg.AddMember("role", "system", allocator); |
755 | 1 | sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator); |
756 | 1 | input.PushBack(sys_msg, allocator); |
757 | 1 | } |
758 | 1 | for (const auto& msg : inputs) { |
759 | 1 | rapidjson::Value message(rapidjson::kObjectType); |
760 | 1 | message.AddMember("role", "user", allocator); |
761 | 1 | message.AddMember("content", rapidjson::Value(msg.c_str(), allocator), allocator); |
762 | 1 | input.PushBack(message, allocator); |
763 | 1 | } |
764 | 1 | doc.AddMember("input", input, allocator); |
765 | 3 | } else { |
766 | | /*{ |
767 | | "model": "gpt-4", |
768 | | "messages": [ |
769 | | {"role": "system", "content": "system_prompt here"}, |
770 | | {"role": "user", "content": "xxx"} |
771 | | ], |
772 | | "temperature": x, |
773 | | "max_tokens": x, |
774 | | }*/ |
775 | 3 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), |
776 | 3 | allocator); |
777 | | |
778 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
779 | 3 | if (_config.temperature != -1) { |
780 | 3 | doc.AddMember("temperature", _config.temperature, allocator); |
781 | 3 | } |
782 | 3 | if (_config.max_tokens != -1) { |
783 | 3 | doc.AddMember("max_tokens", _config.max_tokens, allocator); |
784 | 3 | } |
785 | | |
786 | 3 | rapidjson::Value messages(rapidjson::kArrayType); |
787 | 3 | if (system_prompt && *system_prompt) { |
788 | 3 | rapidjson::Value sys_msg(rapidjson::kObjectType); |
789 | 3 | sys_msg.AddMember("role", "system", allocator); |
790 | 3 | sys_msg.AddMember("content", rapidjson::Value(system_prompt, allocator), allocator); |
791 | 3 | messages.PushBack(sys_msg, allocator); |
792 | 3 | } |
793 | 3 | for (const auto& input : inputs) { |
794 | 3 | rapidjson::Value message(rapidjson::kObjectType); |
795 | 3 | message.AddMember("role", "user", allocator); |
796 | 3 | message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator); |
797 | 3 | messages.PushBack(message, allocator); |
798 | 3 | } |
799 | 3 | doc.AddMember("messages", messages, allocator); |
800 | 3 | } |
801 | | |
802 | 4 | rapidjson::StringBuffer buffer; |
803 | 4 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
804 | 4 | doc.Accept(writer); |
805 | 4 | request_body = buffer.GetString(); |
806 | | |
807 | 4 | return Status::OK(); |
808 | 4 | } |
809 | | |
810 | | Status parse_response(const std::string& response_body, |
811 | 10 | std::vector<std::string>& results) const override { |
812 | 10 | rapidjson::Document doc; |
813 | 10 | doc.Parse(response_body.c_str()); |
814 | | |
815 | 10 | if (doc.HasParseError() || !doc.IsObject()) { |
816 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
817 | 1 | response_body); |
818 | 1 | } |
819 | | |
820 | 9 | if (doc.HasMember("output") && doc["output"].IsArray()) { |
821 | | /// for responses endpoint |
822 | | /*{ |
823 | | "output": [ |
824 | | { |
825 | | "id": "msg_123", |
826 | | "type": "message", |
827 | | "role": "assistant", |
828 | | "content": [ |
829 | | { |
830 | | "type": "text", |
831 | | "text": "result text here" <- result |
832 | | } |
833 | | ] |
834 | | } |
835 | | ] |
836 | | }*/ |
837 | 1 | const auto& output = doc["output"]; |
838 | 1 | results.reserve(output.Size()); |
839 | | |
840 | 2 | for (rapidjson::SizeType i = 0; i < output.Size(); i++) { |
841 | 1 | if (!output[i].HasMember("content") || !output[i]["content"].IsArray() || |
842 | 1 | output[i]["content"].Empty() || !output[i]["content"][0].HasMember("text") || |
843 | 1 | !output[i]["content"][0]["text"].IsString()) { |
844 | 0 | return Status::InternalError("Invalid output format in {} response: {}", |
845 | 0 | _config.provider_type, response_body); |
846 | 0 | } |
847 | | |
848 | 1 | RETURN_IF_ERROR(append_parsed_text_result( |
849 | 1 | output[i]["content"][0]["text"].GetString(), results)); |
850 | 1 | } |
851 | 8 | } else if (doc.HasMember("choices") && doc["choices"].IsArray()) { |
852 | | /// for completions endpoint |
853 | | /*{ |
854 | | "object": "chat.completion", |
855 | | "model": "gpt-4", |
856 | | "choices": [ |
857 | | { |
858 | | ... |
859 | | "message": { |
860 | | "role": "assistant", |
861 | | "content": "xxx" <- result |
862 | | }, |
863 | | ... |
864 | | } |
865 | | ], |
866 | | ... |
867 | | }*/ |
868 | 7 | const auto& choices = doc["choices"]; |
869 | 7 | results.reserve(choices.Size()); |
870 | | |
871 | 12 | for (rapidjson::SizeType i = 0; i < choices.Size(); i++) { |
872 | 7 | if (!choices[i].HasMember("message") || |
873 | 7 | !choices[i]["message"].HasMember("content") || |
874 | 7 | !choices[i]["message"]["content"].IsString()) { |
875 | 2 | return Status::InternalError("Invalid choice format in {} response: {}", |
876 | 2 | _config.provider_type, response_body); |
877 | 2 | } |
878 | | |
879 | 5 | RETURN_IF_ERROR(append_parsed_text_result( |
880 | 5 | choices[i]["message"]["content"].GetString(), results)); |
881 | 5 | } |
882 | 7 | } else { |
883 | 1 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
884 | 1 | response_body); |
885 | 1 | } |
886 | | |
887 | 6 | return Status::OK(); |
888 | 9 | } |
889 | | |
890 | | Status build_multimodal_embedding_request( |
891 | | const std::vector<MultimodalType>& /*media_types*/, |
892 | | const std::vector<std::string>& /*media_urls*/, |
893 | | const std::vector<std::string>& /*media_content_types*/, |
894 | 1 | std::string& /*request_body*/) const override { |
895 | 1 | return Status::NotSupported("{} does not support multimodal Embed feature.", |
896 | 1 | _config.provider_type); |
897 | 1 | } |
898 | | |
899 | | protected: |
900 | 2 | bool supports_dimension_param(const std::string& model_name) const override { |
901 | 2 | return !(model_name == "text-embedding-ada-002"); |
902 | 2 | } |
903 | | |
904 | 2 | std::string get_dimension_param_name() const override { return "dimensions"; } |
905 | | }; |
906 | | |
907 | | class DeepSeekAdapter : public OpenAIAdapter { |
908 | | public: |
909 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
910 | 1 | std::string& request_body) const override { |
911 | 1 | return embed_not_supported_status(); |
912 | 1 | } |
913 | | |
914 | | Status parse_embedding_response(const std::string& response_body, |
915 | 1 | std::vector<std::vector<float>>& results) const override { |
916 | 1 | return embed_not_supported_status(); |
917 | 1 | } |
918 | | }; |
919 | | |
920 | | class MoonShotAdapter : public OpenAIAdapter { |
921 | | public: |
922 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
923 | 1 | std::string& request_body) const override { |
924 | 1 | return embed_not_supported_status(); |
925 | 1 | } |
926 | | |
927 | | Status parse_embedding_response(const std::string& response_body, |
928 | 1 | std::vector<std::vector<float>>& results) const override { |
929 | 1 | return embed_not_supported_status(); |
930 | 1 | } |
931 | | }; |
932 | | |
933 | | class MinimaxAdapter : public OpenAIAdapter { |
934 | | public: |
935 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
936 | 1 | std::string& request_body) const override { |
937 | 1 | rapidjson::Document doc; |
938 | 1 | doc.SetObject(); |
939 | 1 | auto& allocator = doc.GetAllocator(); |
940 | | |
941 | | /*{ |
942 | | "text": ["xxx", "xxx", ...], |
943 | | "model": "embo-1", |
944 | | "type": "db" |
945 | | }*/ |
946 | 1 | rapidjson::Value texts(rapidjson::kArrayType); |
947 | 1 | for (const auto& input : inputs) { |
948 | 1 | texts.PushBack(rapidjson::Value(input.c_str(), allocator), allocator); |
949 | 1 | } |
950 | 1 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
951 | 1 | doc.AddMember("texts", texts, allocator); |
952 | 1 | doc.AddMember("type", rapidjson::Value("db", allocator), allocator); |
953 | | |
954 | 1 | rapidjson::StringBuffer buffer; |
955 | 1 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
956 | 1 | doc.Accept(writer); |
957 | 1 | request_body = buffer.GetString(); |
958 | | |
959 | 1 | return Status::OK(); |
960 | 1 | } |
961 | | }; |
962 | | |
963 | | class ZhipuAdapter : public OpenAIAdapter { |
964 | | protected: |
965 | 2 | bool supports_dimension_param(const std::string& model_name) const override { |
966 | 2 | return !(model_name == "embedding-2"); |
967 | 2 | } |
968 | | }; |
969 | | |
970 | | class QwenAdapter : public OpenAIAdapter { |
971 | | public: |
972 | | Status build_multimodal_embedding_request( |
973 | | const std::vector<MultimodalType>& media_types, |
974 | | const std::vector<std::string>& media_urls, |
975 | | const std::vector<std::string>& /*media_content_types*/, |
976 | 4 | std::string& request_body) const override { |
977 | 4 | RETURN_IF_ERROR(validate_multimodal_embedding_inputs( |
978 | 4 | "QWEN", media_types, media_urls, {MultimodalType::IMAGE, MultimodalType::VIDEO})); |
979 | | |
980 | 3 | rapidjson::Document doc; |
981 | 3 | doc.SetObject(); |
982 | 3 | auto& allocator = doc.GetAllocator(); |
983 | | |
984 | | /*{ |
985 | | "model": "tongyi-embedding-vision-plus", |
986 | | "input": { |
987 | | "contents": [ |
988 | | {"image": "<url>"}, |
989 | | {"video": "<url>"} |
990 | | ] |
991 | | } |
992 | | "parameters": { |
993 | | "dimension": 512 |
994 | | } |
995 | | }*/ |
996 | 3 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
997 | 3 | rapidjson::Value input(rapidjson::kObjectType); |
998 | 3 | rapidjson::Value contents(rapidjson::kArrayType); |
999 | | |
1000 | 7 | for (size_t i = 0; i < media_urls.size(); ++i) { |
1001 | 4 | rapidjson::Value media_item(rapidjson::kObjectType); |
1002 | 4 | if (media_types[i] == MultimodalType::IMAGE) { |
1003 | 2 | media_item.AddMember("image", rapidjson::Value(media_urls[i].c_str(), allocator), |
1004 | 2 | allocator); |
1005 | 2 | } else { |
1006 | 2 | media_item.AddMember("video", rapidjson::Value(media_urls[i].c_str(), allocator), |
1007 | 2 | allocator); |
1008 | 2 | } |
1009 | 4 | contents.PushBack(media_item, allocator); |
1010 | 4 | } |
1011 | | |
1012 | 3 | input.AddMember("contents", contents, allocator); |
1013 | 3 | doc.AddMember("input", input, allocator); |
1014 | 3 | if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) { |
1015 | 3 | rapidjson::Value parameters(rapidjson::kObjectType); |
1016 | 3 | std::string param_name = get_dimension_param_name(); |
1017 | 3 | rapidjson::Value dimension_name(param_name.c_str(), allocator); |
1018 | 3 | parameters.AddMember(dimension_name, _config.dimensions, allocator); |
1019 | 3 | doc.AddMember("parameters", parameters, allocator); |
1020 | 3 | } |
1021 | | |
1022 | 3 | rapidjson::StringBuffer buffer; |
1023 | 3 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
1024 | 3 | doc.Accept(writer); |
1025 | 3 | request_body = buffer.GetString(); |
1026 | 3 | return Status::OK(); |
1027 | 4 | } |
1028 | | |
1029 | | Status parse_embedding_response(const std::string& response_body, |
1030 | 0 | std::vector<std::vector<float>>& results) const override { |
1031 | 0 | rapidjson::Document doc; |
1032 | 0 | doc.Parse(response_body.c_str()); |
1033 | |
|
1034 | 0 | if (doc.HasParseError() || !doc.IsObject()) [[unlikely]] { |
1035 | 0 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
1036 | 0 | response_body); |
1037 | 0 | } |
1038 | | // Qwen multimodal embedding usually returns: |
1039 | | // { |
1040 | | // "output": { |
1041 | | // "embeddings": [ |
1042 | | // {"index":0, "embedding":[...], "type":"image|video|text"}, |
1043 | | // ... |
1044 | | // ] |
1045 | | // } |
1046 | | // } |
1047 | | // |
1048 | | // In text-only or compatibility endpoints, Qwen may also return OpenAI-style |
1049 | | // "data":[{"embedding":[...]}]. For compatibility we first parse native |
1050 | | // output.embeddings and then fallback to OpenAIAdapter parser. |
1051 | 0 | if (doc.HasMember("output") && doc["output"].IsObject() && |
1052 | 0 | doc["output"].HasMember("embeddings") && doc["output"]["embeddings"].IsArray()) { |
1053 | 0 | const auto& embeddings = doc["output"]["embeddings"]; |
1054 | 0 | results.reserve(embeddings.Size()); |
1055 | 0 | for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) { |
1056 | 0 | if (!embeddings[i].HasMember("embedding") || |
1057 | 0 | !embeddings[i]["embedding"].IsArray()) { |
1058 | 0 | return Status::InternalError("Invalid {} response format: {}", |
1059 | 0 | _config.provider_type, response_body); |
1060 | 0 | } |
1061 | 0 | std::transform(embeddings[i]["embedding"].Begin(), embeddings[i]["embedding"].End(), |
1062 | 0 | std::back_inserter(results.emplace_back()), |
1063 | 0 | [](const auto& val) { return val.GetFloat(); }); |
1064 | 0 | } |
1065 | 0 | return Status::OK(); |
1066 | 0 | } |
1067 | 0 | return OpenAIAdapter::parse_embedding_response(response_body, results); |
1068 | 0 | } |
1069 | | |
1070 | | protected: |
1071 | 5 | bool supports_dimension_param(const std::string& model_name) const override { |
1072 | 5 | static const std::unordered_set<std::string> no_dimension_models = { |
1073 | 5 | "text-embedding-v1", "text-embedding-v2", "text2vec", "m3e-base", "m3e-small"}; |
1074 | 5 | return !no_dimension_models.contains(model_name); |
1075 | 5 | } |
1076 | | |
1077 | 4 | std::string get_dimension_param_name() const override { return "dimension"; } |
1078 | | }; |
1079 | | |
1080 | | class JinaAdapter : public VoyageAIAdapter { |
1081 | | public: |
1082 | | Status build_multimodal_embedding_request( |
1083 | | const std::vector<MultimodalType>& media_types, |
1084 | | const std::vector<std::string>& media_urls, |
1085 | | const std::vector<std::string>& /*media_content_types*/, |
1086 | 2 | std::string& request_body) const override { |
1087 | 2 | RETURN_IF_ERROR(validate_multimodal_embedding_inputs( |
1088 | 2 | "JINA", media_types, media_urls, {MultimodalType::IMAGE, MultimodalType::VIDEO})); |
1089 | | |
1090 | 2 | rapidjson::Document doc; |
1091 | 2 | doc.SetObject(); |
1092 | 2 | auto& allocator = doc.GetAllocator(); |
1093 | | |
1094 | | /*{ |
1095 | | "model": "jina-embeddings-v4", |
1096 | | "task": "text-matching", |
1097 | | "input": [ |
1098 | | {"image": "<url>"}, |
1099 | | {"video": "<url>"} |
1100 | | ] |
1101 | | }*/ |
1102 | 2 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
1103 | 2 | doc.AddMember("task", "text-matching", allocator); |
1104 | | |
1105 | 2 | rapidjson::Value input(rapidjson::kArrayType); |
1106 | 5 | for (size_t i = 0; i < media_urls.size(); ++i) { |
1107 | 3 | rapidjson::Value media_item(rapidjson::kObjectType); |
1108 | 3 | if (media_types[i] == MultimodalType::IMAGE) { |
1109 | 2 | media_item.AddMember("image", rapidjson::Value(media_urls[i].c_str(), allocator), |
1110 | 2 | allocator); |
1111 | 2 | } else { |
1112 | 1 | media_item.AddMember("video", rapidjson::Value(media_urls[i].c_str(), allocator), |
1113 | 1 | allocator); |
1114 | 1 | } |
1115 | 3 | input.PushBack(media_item, allocator); |
1116 | 3 | } |
1117 | 2 | if (_config.dimensions != -1 && supports_dimension_param(_config.model_name)) { |
1118 | 2 | doc.AddMember("dimensions", _config.dimensions, allocator); |
1119 | 2 | } |
1120 | 2 | doc.AddMember("input", input, allocator); |
1121 | | |
1122 | 2 | rapidjson::StringBuffer buffer; |
1123 | 2 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
1124 | 2 | doc.Accept(writer); |
1125 | 2 | request_body = buffer.GetString(); |
1126 | 2 | return Status::OK(); |
1127 | 2 | } |
1128 | | }; |
1129 | | |
1130 | | class BaichuanAdapter : public OpenAIAdapter { |
1131 | | protected: |
1132 | 0 | bool supports_dimension_param(const std::string& model_name) const override { return false; } |
1133 | | }; |
1134 | | |
1135 | | // Gemini's embedding format is different from VoyageAI, so it requires a separate adapter |
1136 | | class GeminiAdapter : public AIAdapter { |
1137 | | public: |
1138 | 2 | Status set_authentication(HttpClient* client) const override { |
1139 | 2 | client->set_header("x-goog-api-key", _config.api_key); |
1140 | 2 | client->set_content_type("application/json"); |
1141 | 2 | return Status::OK(); |
1142 | 2 | } |
1143 | | |
1144 | | Status build_request_payload(const std::vector<std::string>& inputs, |
1145 | | const char* const system_prompt, |
1146 | 1 | std::string& request_body) const override { |
1147 | 1 | rapidjson::Document doc; |
1148 | 1 | doc.SetObject(); |
1149 | 1 | auto& allocator = doc.GetAllocator(); |
1150 | | |
1151 | | /*{ |
1152 | | "systemInstruction": { |
1153 | | "parts": [ |
1154 | | { |
1155 | | "text": "system_prompt here" |
1156 | | } |
1157 | | ] |
1158 | | } |
1159 | | ], |
1160 | | "contents": [ |
1161 | | { |
1162 | | "parts": [ |
1163 | | { |
1164 | | "text": "xxx" |
1165 | | } |
1166 | | ] |
1167 | | } |
1168 | | ], |
1169 | | "generationConfig": { |
1170 | | "temperature": 0.7, |
1171 | | "maxOutputTokens": 1024 |
1172 | | } |
1173 | | |
1174 | | }*/ |
1175 | 1 | if (system_prompt && *system_prompt) { |
1176 | 1 | rapidjson::Value system_instruction(rapidjson::kObjectType); |
1177 | 1 | rapidjson::Value parts(rapidjson::kArrayType); |
1178 | | |
1179 | 1 | rapidjson::Value part(rapidjson::kObjectType); |
1180 | 1 | part.AddMember("text", rapidjson::Value(system_prompt, allocator), allocator); |
1181 | 1 | parts.PushBack(part, allocator); |
1182 | | // system_instruction.PushBack(content, allocator); |
1183 | 1 | system_instruction.AddMember("parts", parts, allocator); |
1184 | 1 | doc.AddMember("systemInstruction", system_instruction, allocator); |
1185 | 1 | } |
1186 | | |
1187 | 1 | rapidjson::Value contents(rapidjson::kArrayType); |
1188 | 1 | for (const auto& input : inputs) { |
1189 | 1 | rapidjson::Value content(rapidjson::kObjectType); |
1190 | 1 | rapidjson::Value parts(rapidjson::kArrayType); |
1191 | | |
1192 | 1 | rapidjson::Value part(rapidjson::kObjectType); |
1193 | 1 | part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator); |
1194 | | |
1195 | 1 | parts.PushBack(part, allocator); |
1196 | 1 | content.AddMember("parts", parts, allocator); |
1197 | 1 | contents.PushBack(content, allocator); |
1198 | 1 | } |
1199 | 1 | doc.AddMember("contents", contents, allocator); |
1200 | | |
1201 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
1202 | 1 | rapidjson::Value generationConfig(rapidjson::kObjectType); |
1203 | 1 | if (_config.temperature != -1) { |
1204 | 1 | generationConfig.AddMember("temperature", _config.temperature, allocator); |
1205 | 1 | } |
1206 | 1 | if (_config.max_tokens != -1) { |
1207 | 1 | generationConfig.AddMember("maxOutputTokens", _config.max_tokens, allocator); |
1208 | 1 | } |
1209 | 1 | doc.AddMember("generationConfig", generationConfig, allocator); |
1210 | | |
1211 | 1 | rapidjson::StringBuffer buffer; |
1212 | 1 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
1213 | 1 | doc.Accept(writer); |
1214 | 1 | request_body = buffer.GetString(); |
1215 | | |
1216 | 1 | return Status::OK(); |
1217 | 1 | } |
1218 | | |
1219 | | Status parse_response(const std::string& response_body, |
1220 | 3 | std::vector<std::string>& results) const override { |
1221 | 3 | rapidjson::Document doc; |
1222 | 3 | doc.Parse(response_body.c_str()); |
1223 | | |
1224 | 3 | if (doc.HasParseError() || !doc.IsObject()) { |
1225 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
1226 | 1 | response_body); |
1227 | 1 | } |
1228 | 2 | if (!doc.HasMember("candidates") || !doc["candidates"].IsArray()) { |
1229 | 1 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
1230 | 1 | response_body); |
1231 | 1 | } |
1232 | | |
1233 | | /*{ |
1234 | | "candidates":[ |
1235 | | { |
1236 | | "content": { |
1237 | | "parts": [ |
1238 | | { |
1239 | | "text": "xxx" |
1240 | | } |
1241 | | ] |
1242 | | } |
1243 | | } |
1244 | | ] |
1245 | | }*/ |
1246 | 1 | const auto& candidates = doc["candidates"]; |
1247 | 1 | results.reserve(candidates.Size()); |
1248 | | |
1249 | 2 | for (rapidjson::SizeType i = 0; i < candidates.Size(); i++) { |
1250 | 1 | if (!candidates[i].HasMember("content") || |
1251 | 1 | !candidates[i]["content"].HasMember("parts") || |
1252 | 1 | !candidates[i]["content"]["parts"].IsArray() || |
1253 | 1 | candidates[i]["content"]["parts"].Empty() || |
1254 | 1 | !candidates[i]["content"]["parts"][0].HasMember("text") || |
1255 | 1 | !candidates[i]["content"]["parts"][0]["text"].IsString()) { |
1256 | 0 | return Status::InternalError("Invalid candidate format in {} response", |
1257 | 0 | _config.provider_type); |
1258 | 0 | } |
1259 | | |
1260 | 1 | RETURN_IF_ERROR(append_parsed_text_result( |
1261 | 1 | candidates[i]["content"]["parts"][0]["text"].GetString(), results)); |
1262 | 1 | } |
1263 | 1 | return Status::OK(); |
1264 | 1 | } |
1265 | | |
1266 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
1267 | 2 | std::string& request_body) const override { |
1268 | 2 | rapidjson::Document doc; |
1269 | 2 | doc.SetObject(); |
1270 | 2 | auto& allocator = doc.GetAllocator(); |
1271 | | |
1272 | | /*{ |
1273 | | "requests": [ |
1274 | | { |
1275 | | "model": "models/gemini-embedding-001", |
1276 | | "content": { |
1277 | | "parts": [ |
1278 | | { |
1279 | | "text": "xxx" |
1280 | | } |
1281 | | ] |
1282 | | }, |
1283 | | "outputDimensionality": 1024 |
1284 | | }, |
1285 | | { |
1286 | | "model": "models/gemini-embedding-001", |
1287 | | "content": { |
1288 | | "parts": [ |
1289 | | { |
1290 | | "text": "yyy" |
1291 | | } |
1292 | | ] |
1293 | | }, |
1294 | | "outputDimensionality": 1024 |
1295 | | } |
1296 | | ] |
1297 | | }*/ |
1298 | | |
1299 | | // gemini requires the model format as `models/{model}` |
1300 | 2 | std::string model_name = _config.model_name; |
1301 | 2 | if (!model_name.starts_with("models/")) { |
1302 | 2 | model_name = "models/" + model_name; |
1303 | 2 | } |
1304 | | |
1305 | 2 | rapidjson::Value requests(rapidjson::kArrayType); |
1306 | 4 | for (const auto& input : inputs) { |
1307 | 4 | rapidjson::Value request(rapidjson::kObjectType); |
1308 | 4 | request.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator); |
1309 | 4 | add_dimension_params(request, allocator); |
1310 | | |
1311 | 4 | rapidjson::Value content(rapidjson::kObjectType); |
1312 | 4 | rapidjson::Value parts(rapidjson::kArrayType); |
1313 | 4 | rapidjson::Value part(rapidjson::kObjectType); |
1314 | 4 | part.AddMember("text", rapidjson::Value(input.c_str(), allocator), allocator); |
1315 | 4 | parts.PushBack(part, allocator); |
1316 | 4 | content.AddMember("parts", parts, allocator); |
1317 | 4 | request.AddMember("content", content, allocator); |
1318 | 4 | requests.PushBack(request, allocator); |
1319 | 4 | } |
1320 | 2 | doc.AddMember("requests", requests, allocator); |
1321 | | |
1322 | 2 | rapidjson::StringBuffer buffer; |
1323 | 2 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
1324 | 2 | doc.Accept(writer); |
1325 | 2 | request_body = buffer.GetString(); |
1326 | | |
1327 | 2 | return Status::OK(); |
1328 | 2 | } |
1329 | | |
1330 | | Status build_multimodal_embedding_request(const std::vector<MultimodalType>& media_types, |
1331 | | const std::vector<std::string>& media_urls, |
1332 | | const std::vector<std::string>& media_content_types, |
1333 | 8 | std::string& request_body) const override { |
1334 | 8 | RETURN_IF_ERROR(validate_multimodal_embedding_inputs( |
1335 | 8 | "Gemini", media_types, media_urls, |
1336 | 8 | {MultimodalType::IMAGE, MultimodalType::AUDIO, MultimodalType::VIDEO})); |
1337 | 6 | if (media_content_types.size() != media_urls.size()) { |
1338 | 1 | return Status::InvalidArgument( |
1339 | 1 | "Gemini multimodal embed input size mismatch, media_content_types={}, " |
1340 | 1 | "media_urls={}", |
1341 | 1 | media_content_types.size(), media_urls.size()); |
1342 | 1 | } |
1343 | | |
1344 | 5 | rapidjson::Document doc; |
1345 | 5 | doc.SetObject(); |
1346 | 5 | auto& allocator = doc.GetAllocator(); |
1347 | | |
1348 | | /*{ |
1349 | | "requests": [ |
1350 | | { |
1351 | | "model": "models/gemini-embedding-2-preview", |
1352 | | "content": { |
1353 | | "parts": [ |
1354 | | {"file_data": {"mime_type": "<original content_type>", "file_uri": "<url>"}} |
1355 | | ] |
1356 | | }, |
1357 | | "outputDimensionality": 768 |
1358 | | }, |
1359 | | { |
1360 | | "model": "models/gemini-embedding-2-preview", |
1361 | | "content": { |
1362 | | "parts": [ |
1363 | | {"file_data": {"mime_type": "<original content_type>", "file_uri": "<url>"}} |
1364 | | ] |
1365 | | }, |
1366 | | "outputDimensionality": 768 |
1367 | | } |
1368 | | ] |
1369 | | }*/ |
1370 | 5 | std::string model_name = _config.model_name; |
1371 | 5 | if (!model_name.starts_with("models/")) { |
1372 | 5 | model_name = "models/" + model_name; |
1373 | 5 | } |
1374 | | |
1375 | 5 | rapidjson::Value requests(rapidjson::kArrayType); |
1376 | 12 | for (size_t i = 0; i < media_urls.size(); ++i) { |
1377 | 7 | rapidjson::Value request(rapidjson::kObjectType); |
1378 | 7 | request.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator); |
1379 | 7 | add_dimension_params(request, allocator); |
1380 | | |
1381 | 7 | rapidjson::Value content(rapidjson::kObjectType); |
1382 | 7 | rapidjson::Value parts(rapidjson::kArrayType); |
1383 | 7 | rapidjson::Value part(rapidjson::kObjectType); |
1384 | 7 | rapidjson::Value file_data(rapidjson::kObjectType); |
1385 | 7 | file_data.AddMember("mime_type", |
1386 | 7 | rapidjson::Value(media_content_types[i].c_str(), allocator), |
1387 | 7 | allocator); |
1388 | 7 | file_data.AddMember("file_uri", rapidjson::Value(media_urls[i].c_str(), allocator), |
1389 | 7 | allocator); |
1390 | 7 | part.AddMember("file_data", file_data, allocator); |
1391 | 7 | parts.PushBack(part, allocator); |
1392 | 7 | content.AddMember("parts", parts, allocator); |
1393 | 7 | request.AddMember("content", content, allocator); |
1394 | 7 | requests.PushBack(request, allocator); |
1395 | 7 | } |
1396 | 5 | doc.AddMember("requests", requests, allocator); |
1397 | | |
1398 | 5 | rapidjson::StringBuffer buffer; |
1399 | 5 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
1400 | 5 | doc.Accept(writer); |
1401 | 5 | request_body = buffer.GetString(); |
1402 | 5 | return Status::OK(); |
1403 | 6 | } |
1404 | | |
1405 | | Status parse_embedding_response(const std::string& response_body, |
1406 | 3 | std::vector<std::vector<float>>& results) const override { |
1407 | 3 | rapidjson::Document doc; |
1408 | 3 | doc.Parse(response_body.c_str()); |
1409 | | |
1410 | 3 | if (doc.HasParseError() || !doc.IsObject()) { |
1411 | 0 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
1412 | 0 | response_body); |
1413 | 0 | } |
1414 | 3 | if (doc.HasMember("embeddings") && doc["embeddings"].IsArray()) { |
1415 | | /*{ |
1416 | | "embeddings": [ |
1417 | | {"values": [0.1, 0.2, 0.3]}, |
1418 | | {"values": [0.4, 0.5, 0.6]} |
1419 | | ] |
1420 | | }*/ |
1421 | 2 | const auto& embeddings = doc["embeddings"]; |
1422 | 2 | results.reserve(embeddings.Size()); |
1423 | 6 | for (rapidjson::SizeType i = 0; i < embeddings.Size(); i++) { |
1424 | 4 | if (!embeddings[i].HasMember("values") || !embeddings[i]["values"].IsArray()) { |
1425 | 0 | return Status::InternalError("Invalid {} response format: {}", |
1426 | 0 | _config.provider_type, response_body); |
1427 | 0 | } |
1428 | 4 | std::transform(embeddings[i]["values"].Begin(), embeddings[i]["values"].End(), |
1429 | 4 | std::back_inserter(results.emplace_back()), |
1430 | 10 | [](const auto& val) { return val.GetFloat(); }); |
1431 | 4 | } |
1432 | 2 | return Status::OK(); |
1433 | 2 | } |
1434 | 1 | if (!doc.HasMember("embedding") || !doc["embedding"].IsObject()) { |
1435 | 0 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
1436 | 0 | response_body); |
1437 | 0 | } |
1438 | | |
1439 | | /*{ |
1440 | | "embedding":{ |
1441 | | "values": [0.1, 0.2, 0.3] |
1442 | | } |
1443 | | }*/ |
1444 | 1 | const auto& embedding = doc["embedding"]; |
1445 | 1 | if (!embedding.HasMember("values") || !embedding["values"].IsArray()) { |
1446 | 0 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
1447 | 0 | response_body); |
1448 | 0 | } |
1449 | 1 | std::transform(embedding["values"].Begin(), embedding["values"].End(), |
1450 | 1 | std::back_inserter(results.emplace_back()), |
1451 | 3 | [](const auto& val) { return val.GetFloat(); }); |
1452 | | |
1453 | 1 | return Status::OK(); |
1454 | 1 | } |
1455 | | |
1456 | | protected: |
1457 | 11 | bool supports_dimension_param(const std::string& model_name) const override { |
1458 | 11 | static const std::unordered_set<std::string> no_dimension_models = {"models/embedding-001", |
1459 | 11 | "embedding-001"}; |
1460 | 11 | return !no_dimension_models.contains(model_name); |
1461 | 11 | } |
1462 | | |
1463 | 9 | std::string get_dimension_param_name() const override { return "outputDimensionality"; } |
1464 | | }; |
1465 | | |
1466 | | class AnthropicAdapter : public VoyageAIAdapter { |
1467 | | public: |
1468 | 1 | Status set_authentication(HttpClient* client) const override { |
1469 | 1 | client->set_header("x-api-key", _config.api_key); |
1470 | 1 | client->set_header("anthropic-version", _config.anthropic_version); |
1471 | 1 | client->set_content_type("application/json"); |
1472 | | |
1473 | 1 | return Status::OK(); |
1474 | 1 | } |
1475 | | |
1476 | | Status build_request_payload(const std::vector<std::string>& inputs, |
1477 | | const char* const system_prompt, |
1478 | 1 | std::string& request_body) const override { |
1479 | 1 | rapidjson::Document doc; |
1480 | 1 | doc.SetObject(); |
1481 | 1 | auto& allocator = doc.GetAllocator(); |
1482 | | |
1483 | | /* |
1484 | | "model": "claude-opus-4-1-20250805", |
1485 | | "max_tokens": 1024, |
1486 | | "system": "system_prompt here", |
1487 | | "messages": [ |
1488 | | {"role": "user", "content": "xxx"} |
1489 | | ], |
1490 | | "temperature": 0.7 |
1491 | | */ |
1492 | | |
1493 | | // If 'temperature' and 'max_tokens' are set, add them to the request body. |
1494 | 1 | doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), allocator), allocator); |
1495 | 1 | if (_config.temperature != -1) { |
1496 | 1 | doc.AddMember("temperature", _config.temperature, allocator); |
1497 | 1 | } |
1498 | 1 | if (_config.max_tokens != -1) { |
1499 | 1 | doc.AddMember("max_tokens", _config.max_tokens, allocator); |
1500 | 1 | } else { |
1501 | | // Keep the default value, Anthropic requires this parameter |
1502 | 0 | doc.AddMember("max_tokens", 2048, allocator); |
1503 | 0 | } |
1504 | 1 | if (system_prompt && *system_prompt) { |
1505 | 1 | doc.AddMember("system", rapidjson::Value(system_prompt, allocator), allocator); |
1506 | 1 | } |
1507 | | |
1508 | 1 | rapidjson::Value messages(rapidjson::kArrayType); |
1509 | 1 | for (const auto& input : inputs) { |
1510 | 1 | rapidjson::Value message(rapidjson::kObjectType); |
1511 | 1 | message.AddMember("role", "user", allocator); |
1512 | 1 | message.AddMember("content", rapidjson::Value(input.c_str(), allocator), allocator); |
1513 | 1 | messages.PushBack(message, allocator); |
1514 | 1 | } |
1515 | 1 | doc.AddMember("messages", messages, allocator); |
1516 | | |
1517 | 1 | rapidjson::StringBuffer buffer; |
1518 | 1 | rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); |
1519 | 1 | doc.Accept(writer); |
1520 | 1 | request_body = buffer.GetString(); |
1521 | | |
1522 | 1 | return Status::OK(); |
1523 | 1 | } |
1524 | | |
1525 | | Status parse_response(const std::string& response_body, |
1526 | 3 | std::vector<std::string>& results) const override { |
1527 | 3 | rapidjson::Document doc; |
1528 | 3 | doc.Parse(response_body.c_str()); |
1529 | 3 | if (doc.HasParseError() || !doc.IsObject()) { |
1530 | 1 | return Status::InternalError("Failed to parse {} response: {}", _config.provider_type, |
1531 | 1 | response_body); |
1532 | 1 | } |
1533 | 2 | if (!doc.HasMember("content") || !doc["content"].IsArray()) { |
1534 | 1 | return Status::InternalError("Invalid {} response format: {}", _config.provider_type, |
1535 | 1 | response_body); |
1536 | 1 | } |
1537 | | |
1538 | | /*{ |
1539 | | "content": [ |
1540 | | { |
1541 | | "text": "xxx", |
1542 | | "type": "text" |
1543 | | } |
1544 | | ] |
1545 | | }*/ |
1546 | 1 | const auto& content = doc["content"]; |
1547 | 1 | results.reserve(1); |
1548 | | |
1549 | 1 | std::string result; |
1550 | 2 | for (rapidjson::SizeType i = 0; i < content.Size(); i++) { |
1551 | 1 | if (!content[i].HasMember("type") || !content[i]["type"].IsString() || |
1552 | 1 | !content[i].HasMember("text") || !content[i]["text"].IsString()) { |
1553 | 0 | continue; |
1554 | 0 | } |
1555 | | |
1556 | 1 | if (std::string(content[i]["type"].GetString()) == "text") { |
1557 | 1 | if (!result.empty()) { |
1558 | 0 | result += "\n"; |
1559 | 0 | } |
1560 | 1 | result += content[i]["text"].GetString(); |
1561 | 1 | } |
1562 | 1 | } |
1563 | | |
1564 | 1 | return append_parsed_text_result(result, results); |
1565 | 2 | } |
1566 | | }; |
1567 | | |
1568 | | // Mock adapter used only for UT to bypass real HTTP calls and return deterministic data. |
1569 | | class MockAdapter : public AIAdapter { |
1570 | | public: |
1571 | 0 | Status set_authentication(HttpClient* client) const override { return Status::OK(); } |
1572 | | |
1573 | | Status build_request_payload(const std::vector<std::string>& inputs, |
1574 | | const char* const system_prompt, |
1575 | 2 | std::string& request_body) const override { |
1576 | 2 | return Status::OK(); |
1577 | 2 | } |
1578 | | |
1579 | | Status parse_response(const std::string& response_body, |
1580 | 74 | std::vector<std::string>& results) const override { |
1581 | 74 | return append_parsed_text_result(response_body, results); |
1582 | 74 | } |
1583 | | |
1584 | | Status build_embedding_request(const std::vector<std::string>& inputs, |
1585 | 2 | std::string& request_body) const override { |
1586 | 2 | return Status::OK(); |
1587 | 2 | } |
1588 | | |
1589 | | Status build_multimodal_embedding_request( |
1590 | | const std::vector<MultimodalType>& /*media_types*/, |
1591 | | const std::vector<std::string>& /*media_urls*/, |
1592 | | const std::vector<std::string>& /*media_content_types*/, |
1593 | 2 | std::string& /*request_body*/) const override { |
1594 | 2 | return Status::OK(); |
1595 | 2 | } |
1596 | | |
1597 | | Status parse_embedding_response(const std::string& response_body, |
1598 | 0 | std::vector<std::vector<float>>& results) const override { |
1599 | 0 | rapidjson::Document doc; |
1600 | 0 | doc.SetObject(); |
1601 | 0 | doc.Parse(response_body.c_str()); |
1602 | 0 | if (doc.HasParseError() || !doc.IsObject()) { |
1603 | 0 | return Status::InternalError("Failed to parse embedding response"); |
1604 | 0 | } |
1605 | 0 | if (!doc.HasMember("embedding") || !doc["embedding"].IsArray()) { |
1606 | 0 | return Status::InternalError("Invalid embedding response format"); |
1607 | 0 | } |
1608 | | |
1609 | 0 | results.reserve(1); |
1610 | 0 | std::transform(doc["embedding"].Begin(), doc["embedding"].End(), |
1611 | 0 | std::back_inserter(results.emplace_back()), |
1612 | 0 | [](const auto& val) { return val.GetFloat(); }); |
1613 | 0 | return Status::OK(); |
1614 | 0 | } |
1615 | | }; |
1616 | | |
1617 | | class AIAdapterFactory { |
1618 | | public: |
1619 | 106 | static std::shared_ptr<AIAdapter> create_adapter(const std::string& provider_type) { |
1620 | 106 | static const std::unordered_map<std::string, std::function<std::shared_ptr<AIAdapter>()>> |
1621 | 106 | adapters = {{"LOCAL", []() { return std::make_shared<LocalAdapter>(); }}, |
1622 | 106 | {"OPENAI", []() { return std::make_shared<OpenAIAdapter>(); }}, |
1623 | 106 | {"MOONSHOT", []() { return std::make_shared<MoonShotAdapter>(); }}, |
1624 | 106 | {"DEEPSEEK", []() { return std::make_shared<DeepSeekAdapter>(); }}, |
1625 | 106 | {"MINIMAX", []() { return std::make_shared<MinimaxAdapter>(); }}, |
1626 | 106 | {"ZHIPU", []() { return std::make_shared<ZhipuAdapter>(); }}, |
1627 | 106 | {"QWEN", []() { return std::make_shared<QwenAdapter>(); }}, |
1628 | 106 | {"JINA", []() { return std::make_shared<JinaAdapter>(); }}, |
1629 | 106 | {"BAICHUAN", []() { return std::make_shared<BaichuanAdapter>(); }}, |
1630 | 106 | {"ANTHROPIC", []() { return std::make_shared<AnthropicAdapter>(); }}, |
1631 | 106 | {"GEMINI", []() { return std::make_shared<GeminiAdapter>(); }}, |
1632 | 106 | {"VOYAGEAI", []() { return std::make_shared<VoyageAIAdapter>(); }}, |
1633 | 106 | {"MOCK", []() { return std::make_shared<MockAdapter>(); }}}; |
1634 | | |
1635 | 106 | auto it = adapters.find(provider_type); |
1636 | 106 | return (it != adapters.end()) ? it->second() : nullptr; |
1637 | 106 | } |
1638 | | }; |
1639 | | |
1640 | | } // namespace doris |