be/src/exprs/function/ai/ai_functions.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "core/column/column_array.h" |
19 | | #include "exprs/function/ai/ai_classify.h" |
20 | | #include "exprs/function/ai/ai_extract.h" |
21 | | #include "exprs/function/ai/ai_filter.h" |
22 | | #include "exprs/function/ai/ai_fix_grammar.h" |
23 | | #include "exprs/function/ai/ai_generate.h" |
24 | | #include "exprs/function/ai/ai_mask.h" |
25 | | #include "exprs/function/ai/ai_sentiment.h" |
26 | | #include "exprs/function/ai/ai_similarity.h" |
27 | | #include "exprs/function/ai/ai_summarize.h" |
28 | | #include "exprs/function/ai/ai_translate.h" |
29 | | #include "exprs/function/ai/embed.h" |
30 | | #include "exprs/function/simple_function_factory.h" |
31 | | |
32 | | namespace doris { |
33 | | Status FunctionAIClassify::build_prompt(const Block& block, const ColumnNumbers& arguments, |
34 | 1 | size_t row_num, std::string& prompt) const { |
35 | | // Get the text column |
36 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
37 | 1 | StringRef text = text_column.column->get_data_at(row_num); |
38 | 1 | std::string text_str = std::string(text.data, text.size); |
39 | | |
40 | | // Get the labels array column |
41 | 1 | const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
42 | 1 | const auto& [array_column, array_row_num] = |
43 | 1 | check_column_const_set_readability(*labels_column.column, row_num); |
44 | 1 | const ColumnArray* col_array_ptr = nullptr; |
45 | 1 | if (const auto col_array = check_and_get_column<ColumnArray>(*array_column)) { |
46 | 1 | col_array_ptr = col_array.get(); |
47 | 1 | } else { |
48 | 0 | return Status::InternalError( |
49 | 0 | "labels argument for {} must be Array(String) or Array(Varchar)", name); |
50 | 0 | } |
51 | | |
52 | 1 | std::vector<std::string> label_values; |
53 | 1 | const auto& data = col_array_ptr->get_data(); |
54 | 1 | const auto& offsets = col_array_ptr->get_offsets(); |
55 | 1 | size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
56 | 1 | size_t end = offsets[array_row_num]; |
57 | 4 | for (size_t i = start; i < end; ++i) { |
58 | 3 | Field field; |
59 | 3 | data.get(i, field); |
60 | 3 | label_values.emplace_back(field.template get<TYPE_STRING>()); |
61 | 3 | } |
62 | | |
63 | 1 | std::string labels_str = "["; |
64 | 4 | for (size_t i = 0; i < label_values.size(); ++i) { |
65 | 3 | if (i > 0) { |
66 | 2 | labels_str += ", "; |
67 | 2 | } |
68 | 3 | labels_str += "\"" + label_values[i] + "\""; |
69 | 3 | } |
70 | 1 | labels_str += "]"; |
71 | | |
72 | 1 | prompt = "Labels: " + labels_str + "\nText: " + text_str; |
73 | | |
74 | 1 | return Status::OK(); |
75 | 1 | } |
76 | | |
77 | | Status FunctionAIExtract::build_prompt(const Block& block, const ColumnNumbers& arguments, |
78 | 1 | size_t row_num, std::string& prompt) const { |
79 | | // Get the text column |
80 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
81 | 1 | StringRef text = text_column.column->get_data_at(row_num); |
82 | 1 | std::string text_str = std::string(text.data, text.size); |
83 | | |
84 | | // Get the labels array column |
85 | 1 | const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
86 | 1 | const auto& [array_column, array_row_num] = |
87 | 1 | check_column_const_set_readability(*labels_column.column, row_num); |
88 | 1 | const ColumnArray* col_array_ptr = nullptr; |
89 | 1 | if (const auto col_array = check_and_get_column<ColumnArray>(*array_column)) { |
90 | 1 | col_array_ptr = col_array.get(); |
91 | 1 | } else { |
92 | 0 | return Status::InternalError( |
93 | 0 | "labels argument for {} must be Array(String) or Array(Varchar)", name); |
94 | 0 | } |
95 | | |
96 | 1 | std::vector<std::string> label_values; |
97 | 1 | const auto& offsets = col_array_ptr->get_offsets(); |
98 | 1 | const auto& data = col_array_ptr->get_data(); |
99 | 1 | size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
100 | 1 | size_t end = offsets[array_row_num]; |
101 | 3 | for (size_t i = start; i < end; ++i) { |
102 | 2 | Field field; |
103 | 2 | data.get(i, field); |
104 | 2 | label_values.emplace_back(field.template get<TYPE_STRING>()); |
105 | 2 | } |
106 | | |
107 | 1 | std::string labels_str = "["; |
108 | 3 | for (size_t i = 0; i < label_values.size(); ++i) { |
109 | 2 | if (i > 0) { |
110 | 1 | labels_str += ", "; |
111 | 1 | } |
112 | 2 | labels_str += "\"" + label_values[i] + "\""; |
113 | 2 | } |
114 | 1 | labels_str += "]"; |
115 | | |
116 | 1 | prompt = "Labels: " + labels_str + "\nText: " + text_str; |
117 | | |
118 | 1 | return Status::OK(); |
119 | 1 | } |
120 | | |
121 | | Status FunctionAIGenerate::build_prompt(const Block& block, const ColumnNumbers& arguments, |
122 | 1 | size_t row_num, std::string& prompt) const { |
123 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
124 | 1 | StringRef text_ref = text_column.column->get_data_at(row_num); |
125 | 1 | prompt = std::string(text_ref.data, text_ref.size); |
126 | | |
127 | 1 | return Status::OK(); |
128 | 1 | } |
129 | | |
130 | | Status FunctionAIMask::build_prompt(const Block& block, const ColumnNumbers& arguments, |
131 | 1 | size_t row_num, std::string& prompt) const { |
132 | | // Get the text column |
133 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
134 | 1 | StringRef text = text_column.column->get_data_at(row_num); |
135 | 1 | std::string text_str = std::string(text.data, text.size); |
136 | | |
137 | | // Get the labels array column |
138 | 1 | const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
139 | 1 | const auto& [array_column, array_row_num] = |
140 | 1 | check_column_const_set_readability(*labels_column.column, row_num); |
141 | 1 | const ColumnArray* col_array_ptr = nullptr; |
142 | 1 | if (const auto col_array = check_and_get_column<ColumnArray>(*array_column)) { |
143 | 1 | col_array_ptr = col_array.get(); |
144 | 1 | } else { |
145 | 0 | return Status::InternalError( |
146 | 0 | "labels argument for {} must be Array(String) or Array(Varchar)", name); |
147 | 0 | } |
148 | | |
149 | 1 | std::vector<std::string> label_values; |
150 | 1 | const auto& offsets = col_array_ptr->get_offsets(); |
151 | 1 | const auto& data = col_array_ptr->get_data(); |
152 | 1 | size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
153 | 1 | size_t end = offsets[array_row_num]; |
154 | 3 | for (size_t i = start; i < end; ++i) { |
155 | 2 | Field field; |
156 | 2 | data.get(i, field); |
157 | 2 | label_values.emplace_back(field.template get<TYPE_STRING>()); |
158 | 2 | } |
159 | | |
160 | 1 | std::string labels_str = "["; |
161 | 3 | for (size_t i = 0; i < label_values.size(); ++i) { |
162 | 2 | if (i > 0) { |
163 | 1 | labels_str += ", "; |
164 | 1 | } |
165 | 2 | labels_str += "\"" + label_values[i] + "\""; |
166 | 2 | } |
167 | 1 | labels_str += "]"; |
168 | | |
169 | 1 | prompt = "Labels: " + labels_str + "\nText: " + text_str; |
170 | | |
171 | 1 | return Status::OK(); |
172 | 1 | } |
173 | | |
174 | | Status FunctionAISimilarity::build_prompt(const Block& block, const ColumnNumbers& arguments, |
175 | 24 | size_t row_num, std::string& prompt) const { |
176 | | // text1 |
177 | 24 | const ColumnWithTypeAndName& text_column_1 = block.get_by_position(arguments[1]); |
178 | 24 | StringRef text_1 = text_column_1.column.get()->get_data_at(row_num); |
179 | 24 | std::string text_str_1 = std::string(text_1.data, text_1.size); |
180 | | |
181 | | // text2 |
182 | 24 | const ColumnWithTypeAndName& text_column_2 = block.get_by_position(arguments[2]); |
183 | 24 | StringRef text_2 = text_column_2.column.get()->get_data_at(row_num); |
184 | 24 | std::string text_str_2 = std::string(text_2.data, text_2.size); |
185 | | |
186 | 24 | prompt = "Text 1: " + text_str_1 + "\nText 2: " + text_str_2; |
187 | | |
188 | 24 | return Status::OK(); |
189 | 24 | } |
190 | | |
191 | | Status FunctionAITranslate::build_prompt(const Block& block, const ColumnNumbers& arguments, |
192 | 1 | size_t row_num, std::string& prompt) const { |
193 | | // text |
194 | 1 | const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
195 | 1 | StringRef text = text_column.column.get()->get_data_at(row_num); |
196 | 1 | std::string text_str = std::string(text.data, text.size); |
197 | | |
198 | | // target language |
199 | 1 | const ColumnWithTypeAndName& lang_column = block.get_by_position(arguments[2]); |
200 | 1 | StringRef lang = lang_column.column.get()->get_data_at(row_num); |
201 | 1 | std::string target_lang = std::string(lang.data, lang.size); |
202 | | |
203 | 1 | prompt = "Translate the following text to " + target_lang + ".\nText: " + text_str; |
204 | | |
205 | 1 | return Status::OK(); |
206 | 1 | } |
207 | | |
208 | 8 | void register_function_ai(SimpleFunctionFactory& factory) { |
209 | 8 | factory.register_function<FunctionEmbed>(); |
210 | 8 | factory.register_function<FunctionAIClassify>(); |
211 | 8 | factory.register_function<FunctionAIExtract>(); |
212 | 8 | factory.register_function<FunctionAIFilter>(); |
213 | 8 | factory.register_function<FunctionAIFixGrammar>(); |
214 | 8 | factory.register_function<FunctionAIGenerate>(); |
215 | 8 | factory.register_function<FunctionAIMask>(); |
216 | 8 | factory.register_function<FunctionAISentiment>(); |
217 | 8 | factory.register_function<FunctionAISimilarity>(); |
218 | 8 | factory.register_function<FunctionAISummarize>(); |
219 | 8 | factory.register_function<FunctionAITranslate>(); |
220 | 8 | } |
221 | | |
222 | | } // namespace doris |