be/src/exprs/function/ai/ai_functions.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "core/column/column_array.h" |
19 | | #include "exprs/function/ai/ai_classify.h" |
20 | | #include "exprs/function/ai/ai_extract.h" |
21 | | #include "exprs/function/ai/ai_filter.h" |
22 | | #include "exprs/function/ai/ai_fix_grammar.h" |
23 | | #include "exprs/function/ai/ai_generate.h" |
24 | | #include "exprs/function/ai/ai_mask.h" |
25 | | #include "exprs/function/ai/ai_sentiment.h" |
26 | | #include "exprs/function/ai/ai_similarity.h" |
27 | | #include "exprs/function/ai/ai_summarize.h" |
28 | | #include "exprs/function/ai/ai_translate.h" |
29 | | #include "exprs/function/ai/embed.h" |
30 | | #include "exprs/function/simple_function_factory.h" |
31 | | |
32 | | namespace doris { |
33 | | Status FunctionAIClassify::build_prompt(const Block& block, const ColumnNumbers& arguments, |
34 | 1 | size_t row_num, std::string& prompt) const { |
35 | | // Get the text column |
36 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
37 | 1 | StringRef text = text_column.column->get_data_at(row_num); |
38 | 1 | std::string text_str = std::string(text.data, text.size); |
39 | | |
40 | | // Get the labels array column |
41 | 1 | const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
42 | 1 | const auto& [array_column, array_row_num] = |
43 | 1 | check_column_const_set_readability(*labels_column.column, row_num); |
44 | 1 | const auto* col_array = check_and_get_column<ColumnArray>(*array_column); |
45 | 1 | if (col_array == nullptr) { |
46 | 0 | return Status::InternalError( |
47 | 0 | "labels argument for {} must be Array(String) or Array(Varchar)", name); |
48 | 0 | } |
49 | | |
50 | 1 | std::vector<std::string> label_values; |
51 | 1 | const auto& data = col_array->get_data(); |
52 | 1 | const auto& offsets = col_array->get_offsets(); |
53 | 1 | size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
54 | 1 | size_t end = offsets[array_row_num]; |
55 | 4 | for (size_t i = start; i < end; ++i) { |
56 | 3 | Field field; |
57 | 3 | data.get(i, field); |
58 | 3 | label_values.emplace_back(field.template get<TYPE_STRING>()); |
59 | 3 | } |
60 | | |
61 | 1 | std::string labels_str = "["; |
62 | 4 | for (size_t i = 0; i < label_values.size(); ++i) { |
63 | 3 | if (i > 0) { |
64 | 2 | labels_str += ", "; |
65 | 2 | } |
66 | 3 | labels_str += "\"" + label_values[i] + "\""; |
67 | 3 | } |
68 | 1 | labels_str += "]"; |
69 | | |
70 | 1 | prompt = "Labels: " + labels_str + "\nText: " + text_str; |
71 | | |
72 | 1 | return Status::OK(); |
73 | 1 | } |
74 | | |
75 | | Status FunctionAIExtract::build_prompt(const Block& block, const ColumnNumbers& arguments, |
76 | 1 | size_t row_num, std::string& prompt) const { |
77 | | // Get the text column |
78 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
79 | 1 | StringRef text = text_column.column->get_data_at(row_num); |
80 | 1 | std::string text_str = std::string(text.data, text.size); |
81 | | |
82 | | // Get the labels array column |
83 | 1 | const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
84 | 1 | const auto& [array_column, array_row_num] = |
85 | 1 | check_column_const_set_readability(*labels_column.column, row_num); |
86 | 1 | const auto* col_array = check_and_get_column<ColumnArray>(*array_column); |
87 | 1 | if (col_array == nullptr) { |
88 | 0 | return Status::InternalError( |
89 | 0 | "labels argument for {} must be Array(String) or Array(Varchar)", name); |
90 | 0 | } |
91 | | |
92 | 1 | std::vector<std::string> label_values; |
93 | 1 | const auto& offsets = col_array->get_offsets(); |
94 | 1 | const auto& data = col_array->get_data(); |
95 | 1 | size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
96 | 1 | size_t end = offsets[array_row_num]; |
97 | 3 | for (size_t i = start; i < end; ++i) { |
98 | 2 | Field field; |
99 | 2 | data.get(i, field); |
100 | 2 | label_values.emplace_back(field.template get<TYPE_STRING>()); |
101 | 2 | } |
102 | | |
103 | 1 | std::string labels_str = "["; |
104 | 3 | for (size_t i = 0; i < label_values.size(); ++i) { |
105 | 2 | if (i > 0) { |
106 | 1 | labels_str += ", "; |
107 | 1 | } |
108 | 2 | labels_str += "\"" + label_values[i] + "\""; |
109 | 2 | } |
110 | 1 | labels_str += "]"; |
111 | | |
112 | 1 | prompt = "Labels: " + labels_str + "\nText: " + text_str; |
113 | | |
114 | 1 | return Status::OK(); |
115 | 1 | } |
116 | | |
117 | | Status FunctionAIGenerate::build_prompt(const Block& block, const ColumnNumbers& arguments, |
118 | 1 | size_t row_num, std::string& prompt) const { |
119 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
120 | 1 | StringRef text_ref = text_column.column->get_data_at(row_num); |
121 | 1 | prompt = std::string(text_ref.data, text_ref.size); |
122 | | |
123 | 1 | return Status::OK(); |
124 | 1 | } |
125 | | |
126 | | Status FunctionAIMask::build_prompt(const Block& block, const ColumnNumbers& arguments, |
127 | 1 | size_t row_num, std::string& prompt) const { |
128 | | // Get the text column |
129 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
130 | 1 | StringRef text = text_column.column->get_data_at(row_num); |
131 | 1 | std::string text_str = std::string(text.data, text.size); |
132 | | |
133 | | // Get the labels array column |
134 | 1 | const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
135 | 1 | const auto& [array_column, array_row_num] = |
136 | 1 | check_column_const_set_readability(*labels_column.column, row_num); |
137 | 1 | const auto* col_array = check_and_get_column<ColumnArray>(*array_column); |
138 | 1 | if (col_array == nullptr) { |
139 | 0 | return Status::InternalError( |
140 | 0 | "labels argument for {} must be Array(String) or Array(Varchar)", name); |
141 | 0 | } |
142 | | |
143 | 1 | std::vector<std::string> label_values; |
144 | 1 | const auto& offsets = col_array->get_offsets(); |
145 | 1 | const auto& data = col_array->get_data(); |
146 | 1 | size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
147 | 1 | size_t end = offsets[array_row_num]; |
148 | 3 | for (size_t i = start; i < end; ++i) { |
149 | 2 | Field field; |
150 | 2 | data.get(i, field); |
151 | 2 | label_values.emplace_back(field.template get<TYPE_STRING>()); |
152 | 2 | } |
153 | | |
154 | 1 | std::string labels_str = "["; |
155 | 3 | for (size_t i = 0; i < label_values.size(); ++i) { |
156 | 2 | if (i > 0) { |
157 | 1 | labels_str += ", "; |
158 | 1 | } |
159 | 2 | labels_str += "\"" + label_values[i] + "\""; |
160 | 2 | } |
161 | 1 | labels_str += "]"; |
162 | | |
163 | 1 | prompt = "Labels: " + labels_str + "\nText: " + text_str; |
164 | | |
165 | 1 | return Status::OK(); |
166 | 1 | } |
167 | | |
168 | | Status FunctionAISimilarity::build_prompt(const Block& block, const ColumnNumbers& arguments, |
169 | 15 | size_t row_num, std::string& prompt) const { |
170 | | // text1 |
171 | 15 | const ColumnWithTypeAndName& text_column_1 = block.get_by_position(arguments[1]); |
172 | 15 | StringRef text_1 = text_column_1.column.get()->get_data_at(row_num); |
173 | 15 | std::string text_str_1 = std::string(text_1.data, text_1.size); |
174 | | |
175 | | // text2 |
176 | 15 | const ColumnWithTypeAndName& text_column_2 = block.get_by_position(arguments[2]); |
177 | 15 | StringRef text_2 = text_column_2.column.get()->get_data_at(row_num); |
178 | 15 | std::string text_str_2 = std::string(text_2.data, text_2.size); |
179 | | |
180 | 15 | prompt = "Text 1: " + text_str_1 + "\nText 2: " + text_str_2; |
181 | | |
182 | 15 | return Status::OK(); |
183 | 15 | } |
184 | | |
185 | | Status FunctionAITranslate::build_prompt(const Block& block, const ColumnNumbers& arguments, |
186 | 1 | size_t row_num, std::string& prompt) const { |
187 | | // text |
188 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
189 | 1 | StringRef text = text_column.column.get()->get_data_at(row_num); |
190 | 1 | std::string text_str = std::string(text.data, text.size); |
191 | | |
192 | | // target language |
193 | 1 | const ColumnWithTypeAndName& lang_column = block.get_by_position(arguments[2]); |
194 | 1 | StringRef lang = lang_column.column.get()->get_data_at(row_num); |
195 | 1 | std::string target_lang = std::string(lang.data, lang.size); |
196 | | |
197 | 1 | prompt = "Translate the following text to " + target_lang + ".\nText: " + text_str; |
198 | | |
199 | 1 | return Status::OK(); |
200 | 1 | } |
201 | | |
202 | 8 | void register_function_ai(SimpleFunctionFactory& factory) { |
203 | 8 | factory.register_function<FunctionEmbed>(); |
204 | 8 | factory.register_function<FunctionAIClassify>(); |
205 | 8 | factory.register_function<FunctionAIExtract>(); |
206 | 8 | factory.register_function<FunctionAIFilter>(); |
207 | 8 | factory.register_function<FunctionAIFixGrammar>(); |
208 | 8 | factory.register_function<FunctionAIGenerate>(); |
209 | 8 | factory.register_function<FunctionAIMask>(); |
210 | 8 | factory.register_function<FunctionAISentiment>(); |
211 | 8 | factory.register_function<FunctionAISimilarity>(); |
212 | 8 | factory.register_function<FunctionAISummarize>(); |
213 | 8 | factory.register_function<FunctionAITranslate>(); |
214 | 8 | } |
215 | | |
216 | | } // namespace doris |