Coverage Report

Created: 2026-03-12 17:42

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/function/ai/ai_functions.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "core/column/column_array.h"
19
#include "exprs/function/ai/ai_classify.h"
20
#include "exprs/function/ai/ai_extract.h"
21
#include "exprs/function/ai/ai_filter.h"
22
#include "exprs/function/ai/ai_fix_grammar.h"
23
#include "exprs/function/ai/ai_generate.h"
24
#include "exprs/function/ai/ai_mask.h"
25
#include "exprs/function/ai/ai_sentiment.h"
26
#include "exprs/function/ai/ai_similarity.h"
27
#include "exprs/function/ai/ai_summarize.h"
28
#include "exprs/function/ai/ai_translate.h"
29
#include "exprs/function/ai/embed.h"
30
#include "exprs/function/simple_function_factory.h"
31
32
namespace doris {
33
Status FunctionAIClassify::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                        size_t row_num, std::string& prompt) const {
    // Builds the per-row prompt for AI_CLASSIFY in the form
    //   Labels: ["l1", "l2", ...]
    //   Text: <text>
    // arguments[1] is the text column; arguments[2] is the labels array column.
    // Returns InternalError when the labels argument is not a ColumnArray.

    // Get the text column
    const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]);
    StringRef text = text_column.column->get_data_at(row_num);
    std::string text_str = std::string(text.data, text.size);

    // Get the labels array column. The helper yields the column to read and the row
    // index inside it (NOTE(review): presumably differs from row_num when the labels
    // argument is a constant column — confirm against the helper).
    const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]);
    const auto& [array_column, array_row_num] =
            check_column_const_set_readability(*labels_column.column, row_num);
    const auto* col_array = check_and_get_column<ColumnArray>(*array_column);
    if (col_array == nullptr) {
        return Status::InternalError(
                "labels argument for {} must be Array(String) or Array(Varchar)", name);
    }

    // The elements of row array_row_num occupy [offsets[row - 1], offsets[row]) in data.
    const auto& data = col_array->get_data();
    const auto& offsets = col_array->get_offsets();
    size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0;
    size_t end = offsets[array_row_num];

    std::vector<std::string> label_values;
    label_values.reserve(end - start);
    for (size_t i = start; i < end; ++i) {
        // A NULL element has no string payload; reading it as TYPE_STRING is invalid,
        // so NULL labels are skipped instead of being formatted into the prompt.
        if (data.is_null_at(i)) {
            continue;
        }
        Field field;
        data.get(i, field);
        label_values.emplace_back(field.template get<TYPE_STRING>());
    }

    // Render the labels as a bracketed, comma-separated list of double-quoted strings.
    std::string labels_str = "[";
    for (size_t i = 0; i < label_values.size(); ++i) {
        if (i > 0) {
            labels_str += ", ";
        }
        labels_str += "\"" + label_values[i] + "\"";
    }
    labels_str += "]";

    prompt = "Labels: " + labels_str + "\nText: " + text_str;

    return Status::OK();
}
74
75
Status FunctionAIExtract::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                       size_t row_num, std::string& prompt) const {
    // Builds the per-row prompt for AI_EXTRACT in the form
    //   Labels: ["l1", "l2", ...]
    //   Text: <text>
    // arguments[1] is the text column; arguments[2] is the labels array column.
    // Returns InternalError when the labels argument is not a ColumnArray.

    // Get the text column
    const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]);
    StringRef text = text_column.column->get_data_at(row_num);
    std::string text_str = std::string(text.data, text.size);

    // Get the labels array column. The helper yields the column to read and the row
    // index inside it (NOTE(review): presumably differs from row_num when the labels
    // argument is a constant column — confirm against the helper).
    const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]);
    const auto& [array_column, array_row_num] =
            check_column_const_set_readability(*labels_column.column, row_num);
    const auto* col_array = check_and_get_column<ColumnArray>(*array_column);
    if (col_array == nullptr) {
        return Status::InternalError(
                "labels argument for {} must be Array(String) or Array(Varchar)", name);
    }

    // The elements of row array_row_num occupy [offsets[row - 1], offsets[row]) in data.
    const auto& offsets = col_array->get_offsets();
    const auto& data = col_array->get_data();
    size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0;
    size_t end = offsets[array_row_num];

    std::vector<std::string> label_values;
    label_values.reserve(end - start);
    for (size_t i = start; i < end; ++i) {
        // A NULL element has no string payload; reading it as TYPE_STRING is invalid,
        // so NULL labels are skipped instead of being formatted into the prompt.
        if (data.is_null_at(i)) {
            continue;
        }
        Field field;
        data.get(i, field);
        label_values.emplace_back(field.template get<TYPE_STRING>());
    }

    // Render the labels as a bracketed, comma-separated list of double-quoted strings.
    std::string labels_str = "[";
    for (size_t i = 0; i < label_values.size(); ++i) {
        if (i > 0) {
            labels_str += ", ";
        }
        labels_str += "\"" + label_values[i] + "\"";
    }
    labels_str += "]";

    prompt = "Labels: " + labels_str + "\nText: " + text_str;

    return Status::OK();
}
116
117
Status FunctionAIGenerate::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                        size_t row_num, std::string& prompt) const {
    // AI_GENERATE sends the user text (arguments[1], row row_num) as the prompt verbatim.
    const StringRef user_text =
            block.get_by_position(arguments[1]).column->get_data_at(row_num);
    prompt.assign(user_text.data, user_text.size);
    return Status::OK();
}
125
126
Status FunctionAIMask::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                    size_t row_num, std::string& prompt) const {
    // Builds the per-row prompt for AI_MASK in the form
    //   Labels: ["l1", "l2", ...]
    //   Text: <text>
    // arguments[1] is the text column; arguments[2] is the labels array column.
    // Returns InternalError when the labels argument is not a ColumnArray.

    // Get the text column
    const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]);
    StringRef text = text_column.column->get_data_at(row_num);
    std::string text_str = std::string(text.data, text.size);

    // Get the labels array column. The helper yields the column to read and the row
    // index inside it (NOTE(review): presumably differs from row_num when the labels
    // argument is a constant column — confirm against the helper).
    const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]);
    const auto& [array_column, array_row_num] =
            check_column_const_set_readability(*labels_column.column, row_num);
    const auto* col_array = check_and_get_column<ColumnArray>(*array_column);
    if (col_array == nullptr) {
        return Status::InternalError(
                "labels argument for {} must be Array(String) or Array(Varchar)", name);
    }

    // The elements of row array_row_num occupy [offsets[row - 1], offsets[row]) in data.
    const auto& offsets = col_array->get_offsets();
    const auto& data = col_array->get_data();
    size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0;
    size_t end = offsets[array_row_num];

    std::vector<std::string> label_values;
    label_values.reserve(end - start);
    for (size_t i = start; i < end; ++i) {
        // A NULL element has no string payload; reading it as TYPE_STRING is invalid,
        // so NULL labels are skipped instead of being formatted into the prompt.
        if (data.is_null_at(i)) {
            continue;
        }
        Field field;
        data.get(i, field);
        label_values.emplace_back(field.template get<TYPE_STRING>());
    }

    // Render the labels as a bracketed, comma-separated list of double-quoted strings.
    std::string labels_str = "[";
    for (size_t i = 0; i < label_values.size(); ++i) {
        if (i > 0) {
            labels_str += ", ";
        }
        labels_str += "\"" + label_values[i] + "\"";
    }
    labels_str += "]";

    prompt = "Labels: " + labels_str + "\nText: " + text_str;

    return Status::OK();
}
167
168
Status FunctionAISimilarity::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                          size_t row_num, std::string& prompt) const {
    // Pairs the two texts (arguments[1] and arguments[2], row row_num) as
    //   Text 1: <first>
    //   Text 2: <second>

    // Copies row row_num of the argument column at arguments[pos] into an owned string.
    const auto read_text = [&](size_t pos) -> std::string {
        const StringRef ref = block.get_by_position(arguments[pos]).column->get_data_at(row_num);
        return std::string(ref.data, ref.size);
    };

    const std::string first = read_text(1);
    const std::string second = read_text(2);

    prompt = "Text 1: " + first + "\nText 2: " + second;
    return Status::OK();
}
184
185
Status FunctionAITranslate::build_prompt(const Block& block, const ColumnNumbers& arguments,
                                         size_t row_num, std::string& prompt) const {
    // Builds the AI_TRANSLATE prompt from the source text (arguments[1]) and the
    // target language string (arguments[2]), both read at row row_num.

    // Copies row row_num of the argument column at arguments[pos] into an owned string.
    const auto value_at = [&](size_t pos) -> std::string {
        const StringRef ref = block.get_by_position(arguments[pos]).column->get_data_at(row_num);
        return std::string(ref.data, ref.size);
    };

    const std::string source_text = value_at(1);
    const std::string language = value_at(2);

    prompt = "Translate the following text to " + language + ".\nText: " + source_text;
    return Status::OK();
}
201
202
8
void register_function_ai(SimpleFunctionFactory& factory) {
203
8
    factory.register_function<FunctionEmbed>();
204
8
    factory.register_function<FunctionAIClassify>();
205
8
    factory.register_function<FunctionAIExtract>();
206
8
    factory.register_function<FunctionAIFilter>();
207
8
    factory.register_function<FunctionAIFixGrammar>();
208
8
    factory.register_function<FunctionAIGenerate>();
209
8
    factory.register_function<FunctionAIMask>();
210
8
    factory.register_function<FunctionAISentiment>();
211
8
    factory.register_function<FunctionAISimilarity>();
212
8
    factory.register_function<FunctionAISummarize>();
213
8
    factory.register_function<FunctionAITranslate>();
214
8
}
215
216
} // namespace doris