Coverage Report

Created: 2026-05-18 05:00

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/ai_functions.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "core/column/column_array.h"
19
#include "exprs/function/ai/ai_classify.h"
20
#include "exprs/function/ai/ai_extract.h"
21
#include "exprs/function/ai/ai_filter.h"
22
#include "exprs/function/ai/ai_fix_grammar.h"
23
#include "exprs/function/ai/ai_generate.h"
24
#include "exprs/function/ai/ai_mask.h"
25
#include "exprs/function/ai/ai_sentiment.h"
26
#include "exprs/function/ai/ai_similarity.h"
27
#include "exprs/function/ai/ai_summarize.h"
28
#include "exprs/function/ai/ai_translate.h"
29
#include "exprs/function/ai/embed.h"
30
#include "exprs/function/simple_function_factory.h"
31
32
namespace doris {
33
namespace {

// Shared prompt builder for the AI functions whose arguments are
// (<arguments[0]>, text, labels). It produces
//
//     Labels: ["a", "b", ...]
//     Text: <text>
//
// Classify, Extract and Mask previously each carried their own copy of this
// logic; keeping a single implementation guarantees they cannot drift apart.
//
// block         - block holding the argument columns.
// arguments     - positions of the argument columns inside `block`;
//                 arguments[1] is the text, arguments[2] the labels array.
// row_num       - row for which the prompt is built.
// function_name - name reported in the error message (the caller's `name`).
// prompt        - output; assigned only when Status::OK() is returned.
//
// Returns Status::InternalError when the labels argument is not a ColumnArray.
template <typename FunctionName>
Status build_labeled_text_prompt(const Block& block, const ColumnNumbers& arguments,
                                 size_t row_num, const FunctionName& function_name,
                                 std::string& prompt) {
    // Text argument.
    const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]);
    const StringRef text = text_column.column->get_data_at(row_num);

    // Labels argument. check_column_const_set_readability yields the column and
    // the row index that must actually be read (presumably unwrapping constant
    // columns -- confirm against its definition).
    const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]);
    const auto& [array_column, array_row_num] =
            check_column_const_set_readability(*labels_column.column, row_num);
    const auto col_array = check_and_get_column<ColumnArray>(*array_column);
    if (!col_array) {
        return Status::InternalError(
                "labels argument for {} must be Array(String) or Array(Varchar)", function_name);
    }
    const ColumnArray* labels = col_array.get();

    // Elements of row r live in [offsets[r - 1], offsets[r]) of the nested data
    // column, with an implicit 0 before the first row.
    const auto& offsets = labels->get_offsets();
    const auto& data = labels->get_data();
    const size_t first = array_row_num > 0 ? offsets[array_row_num - 1] : 0;
    const size_t last = offsets[array_row_num];

    // Build the quoted, comma-separated list in one pass instead of first
    // copying every label into a temporary vector.
    // NOTE(review): a NULL array element would presumably produce a null Field;
    // confirm that get<TYPE_STRING>() is safe in that case or reject NULL labels.
    std::string labels_str = "[";
    for (size_t i = first; i < last; ++i) {
        if (i > first) {
            labels_str += ", ";
        }
        Field field;
        data.get(i, field);
        const std::string label(field.get<TYPE_STRING>());
        labels_str += "\"";
        labels_str += label;
        labels_str += "\"";
    }
    labels_str += "]";

    prompt = "Labels: " + labels_str + "\nText: " + std::string(text.data, text.size);

    return Status::OK();
}

} // namespace

// Builds the classification prompt: the candidate labels (arguments[2]) followed
// by the text to classify (arguments[1]) for row `row_num`.
Status FunctionAIClassify::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                        size_t row_num, std::string& prompt) const {
    return build_labeled_text_prompt(block, arguments, row_num, name, prompt);
}

// Builds the extraction prompt: the labels to extract (arguments[2]) followed by
// the source text (arguments[1]) for row `row_num`.
Status FunctionAIExtract::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                       size_t row_num, std::string& prompt) const {
    return build_labeled_text_prompt(block, arguments, row_num, name, prompt);
}

// Builds the generation prompt, which is the text argument (arguments[1]) of row
// `row_num` copied verbatim.
Status FunctionAIGenerate::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                        size_t row_num, std::string& prompt) const {
    const StringRef text_ref = block.get_by_position(arguments[1]).column->get_data_at(row_num);
    prompt.assign(text_ref.data, text_ref.size);

    return Status::OK();
}

// Builds the masking prompt: the labels to mask (arguments[2]) followed by the
// source text (arguments[1]) for row `row_num`.
Status FunctionAIMask::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                    size_t row_num, std::string& prompt) const {
    return build_labeled_text_prompt(block, arguments, row_num, name, prompt);
}
173
174
// Builds the similarity prompt for row `row_num`:
//
//     Text 1: <arguments[1]>
//     Text 2: <arguments[2]>
//
// Always returns Status::OK().
Status FunctionAISimilarity::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                          size_t row_num, std::string& prompt) const {
    // Copies the value of the given argument column at `row_num` into a string.
    const auto argument_text = [&](size_t arg_index) {
        const StringRef value =
                block.get_by_position(arguments[arg_index]).column->get_data_at(row_num);
        return std::string(value.data, value.size);
    };

    prompt = "Text 1: " + argument_text(1) + "\nText 2: " + argument_text(2);

    return Status::OK();
}
190
191
// Builds the translation prompt for row `row_num`:
//
//     Translate the following text to <arguments[2]>.
//     Text: <arguments[1]>
//
// Always returns Status::OK().
Status FunctionAITranslate::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                         size_t row_num, std::string& prompt) const {
    // Copies the value of the given argument column at `row_num` into a string.
    const auto argument_text = [&](size_t arg_index) {
        const StringRef value =
                block.get_by_position(arguments[arg_index]).column->get_data_at(row_num);
        return std::string(value.data, value.size);
    };

    const std::string source_text = argument_text(1);
    const std::string target_language = argument_text(2);

    prompt = "Translate the following text to " + target_language + ".\nText: " + source_text;

    return Status::OK();
}
207
208
8
void register_function_ai(SimpleFunctionFactory& factory) {
209
8
    factory.register_function<FunctionEmbed>();
210
8
    factory.register_function<FunctionAIClassify>();
211
8
    factory.register_function<FunctionAIExtract>();
212
8
    factory.register_function<FunctionAIFilter>();
213
8
    factory.register_function<FunctionAIFixGrammar>();
214
8
    factory.register_function<FunctionAIGenerate>();
215
8
    factory.register_function<FunctionAIMask>();
216
8
    factory.register_function<FunctionAISentiment>();
217
8
    factory.register_function<FunctionAISimilarity>();
218
8
    factory.register_function<FunctionAISummarize>();
219
8
    factory.register_function<FunctionAITranslate>();
220
8
}
221
222
} // namespace doris