be/src/exprs/function/function_levenshtein.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include <algorithm> |
19 | | #include <cstring> |
20 | | #include <vector> |
21 | | |
22 | | #include "common/status.h" |
23 | | #include "core/data_type/data_type_number.h" |
24 | | #include "core/string_ref.h" |
25 | | #include "exprs/function/function_totype.h" |
26 | | #include "exprs/function/simple_function_factory.h" |
27 | | #include "util/simd/vstring_function.h" |
28 | | |
29 | | namespace doris { |
30 | | #include "common/compile_check_begin.h" |
31 | | |
// Compile-time name tag for the binary string function below; `name` is the registered
// function name.
struct NameLevenshtein {
    static constexpr const char* name = "levenshtein";
};
35 | | |
36 | | template <typename LeftDataType, typename RightDataType> |
37 | | struct LevenshteinImpl { |
38 | | using ResultDataType = DataTypeInt32; |
39 | | using ResultPaddedPODArray = PaddedPODArray<Int32>; |
40 | | |
41 | | static Status vector_vector(const ColumnString::Chars& ldata, |
42 | | const ColumnString::Offsets& loffsets, |
43 | | const ColumnString::Chars& rdata, |
44 | 2 | const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { |
45 | 2 | DCHECK_EQ(loffsets.size(), roffsets.size()); |
46 | | |
47 | 2 | const size_t size = loffsets.size(); |
48 | 2 | res.resize(size); |
49 | 2 | std::vector<size_t> left_offsets; |
50 | 2 | std::vector<size_t> right_offsets; |
51 | 16 | for (size_t i = 0; i < size; ++i) { |
52 | 14 | res[i] = levenshtein_distance(string_ref_at(ldata, loffsets, i), |
53 | 14 | string_ref_at(rdata, roffsets, i), left_offsets, |
54 | 14 | right_offsets); |
55 | 14 | } |
56 | 2 | return Status::OK(); |
57 | 2 | } |
58 | | |
59 | | static Status vector_scalar(const ColumnString::Chars& ldata, |
60 | | const ColumnString::Offsets& loffsets, const StringRef& rdata, |
61 | 0 | ResultPaddedPODArray& res) { |
62 | 0 | const size_t size = loffsets.size(); |
63 | 0 | res.resize(size); |
64 | 0 | const bool right_ascii = simd::VStringFunctions::is_ascii(rdata); |
65 | 0 | std::vector<size_t> right_offsets; |
66 | 0 | utf8_char_offsets(rdata, right_offsets); |
67 | 0 | std::vector<size_t> left_offsets; |
68 | 0 | for (size_t i = 0; i < size; ++i) { |
69 | 0 | res[i] = levenshtein_distance_with_right_offsets(string_ref_at(ldata, loffsets, i), |
70 | 0 | left_offsets, rdata, right_offsets, |
71 | 0 | right_ascii); |
72 | 0 | } |
73 | 0 | return Status::OK(); |
74 | 0 | } |
75 | | |
76 | | static Status scalar_vector(const StringRef& ldata, const ColumnString::Chars& rdata, |
77 | 0 | const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { |
78 | 0 | const size_t size = roffsets.size(); |
79 | 0 | res.resize(size); |
80 | 0 | const bool left_ascii = simd::VStringFunctions::is_ascii(ldata); |
81 | 0 | std::vector<size_t> left_offsets; |
82 | 0 | utf8_char_offsets(ldata, left_offsets); |
83 | 0 | std::vector<size_t> right_offsets; |
84 | 0 | for (size_t i = 0; i < size; ++i) { |
85 | 0 | res[i] = levenshtein_distance_with_left_offsets(ldata, left_offsets, left_ascii, |
86 | 0 | string_ref_at(rdata, roffsets, i), |
87 | 0 | right_offsets); |
88 | 0 | } |
89 | 0 | return Status::OK(); |
90 | 0 | } |
91 | | |
92 | | private: |
93 | | static StringRef string_ref_at(const ColumnString::Chars& data, |
94 | 28 | const ColumnString::Offsets& offsets, size_t i) { |
95 | 28 | DCHECK_LT(i, offsets.size()); |
96 | 28 | const auto idx = static_cast<ssize_t>(i); |
97 | 28 | return StringRef(data.data() + offsets[idx - 1], offsets[idx] - offsets[idx - 1]) |
98 | 28 | .trim_tail_padding_zero(); |
99 | 28 | } |
100 | | |
101 | 8 | static void utf8_char_offsets(const StringRef& ref, std::vector<size_t>& offsets) { |
102 | 8 | offsets.clear(); |
103 | 8 | offsets.reserve(ref.size); |
104 | 8 | simd::VStringFunctions::get_char_len(ref.data, ref.size, offsets); |
105 | 8 | } |
106 | | |
107 | | static bool utf8_char_equal(const StringRef& left, size_t left_off, size_t left_next, |
108 | 24 | const StringRef& right, size_t right_off, size_t right_next) { |
109 | 24 | const size_t left_len = left_next - left_off; |
110 | 24 | const size_t right_len = right_next - right_off; |
111 | 24 | return left_len == right_len && |
112 | 24 | std::memcmp(left.data + left_off, right.data + right_off, left_len) == 0; |
113 | 24 | } |
114 | | |
115 | | static Int32 levenshtein_distance_utf8(const StringRef& left, |
116 | | const std::vector<size_t>& left_offsets, |
117 | | const StringRef& right, |
118 | 4 | const std::vector<size_t>& right_offsets) { |
119 | 4 | const StringRef* left_ref = &left; |
120 | 4 | const StringRef* right_ref = &right; |
121 | 4 | const std::vector<size_t>* left_offsets_ref = &left_offsets; |
122 | 4 | const std::vector<size_t>* right_offsets_ref = &right_offsets; |
123 | 4 | if (right_offsets_ref->size() > left_offsets_ref->size()) { |
124 | 0 | std::swap(left_offsets_ref, right_offsets_ref); |
125 | 0 | std::swap(left_ref, right_ref); |
126 | 0 | } |
127 | | |
128 | 4 | const size_t m = left_offsets_ref->size(); |
129 | 4 | const size_t n = right_offsets_ref->size(); |
130 | | |
131 | 4 | std::vector<Int32> prev(n + 1); |
132 | 4 | std::vector<Int32> curr(n + 1); |
133 | 16 | for (size_t j = 0; j <= n; ++j) { |
134 | 12 | prev[j] = static_cast<Int32>(j); |
135 | 12 | } |
136 | | |
137 | 16 | for (size_t i = 1; i <= m; ++i) { |
138 | 12 | curr[0] = static_cast<Int32>(i); |
139 | 12 | const size_t left_off = (*left_offsets_ref)[i - 1]; |
140 | 12 | const size_t left_next = i < m ? (*left_offsets_ref)[i] : left_ref->size; |
141 | | |
142 | 36 | for (size_t j = 1; j <= n; ++j) { |
143 | 24 | const size_t right_off = (*right_offsets_ref)[j - 1]; |
144 | 24 | const size_t right_next = j < n ? (*right_offsets_ref)[j] : right_ref->size; |
145 | | |
146 | 24 | const Int32 cost = utf8_char_equal(*left_ref, left_off, left_next, *right_ref, |
147 | 24 | right_off, right_next) |
148 | 24 | ? 0 |
149 | 24 | : 1; |
150 | | |
151 | 24 | const Int32 insert_cost = curr[j - 1] + 1; |
152 | 24 | const Int32 delete_cost = prev[j] + 1; |
153 | 24 | const Int32 replace_cost = prev[j - 1] + cost; |
154 | 24 | curr[j] = std::min(std::min(insert_cost, delete_cost), replace_cost); |
155 | 24 | } |
156 | 12 | std::swap(prev, curr); |
157 | 12 | } |
158 | | |
159 | 4 | return prev[n]; |
160 | 4 | } |
161 | | |
162 | 8 | static Int32 levenshtein_distance_ascii(const StringRef& left, const StringRef& right) { |
163 | 8 | const StringRef* left_ref = &left; |
164 | 8 | const StringRef* right_ref = &right; |
165 | 8 | size_t m = left.size; |
166 | 8 | size_t n = right.size; |
167 | | |
168 | 8 | if (n > m) { |
169 | 4 | std::swap(left_ref, right_ref); |
170 | 4 | std::swap(m, n); |
171 | 4 | } |
172 | | |
173 | 8 | std::vector<Int32> prev(n + 1); |
174 | 8 | std::vector<Int32> curr(n + 1); |
175 | 42 | for (size_t j = 0; j <= n; ++j) { |
176 | 34 | prev[j] = static_cast<Int32>(j); |
177 | 34 | } |
178 | | |
179 | 42 | for (size_t i = 1; i <= m; ++i) { |
180 | 34 | curr[0] = static_cast<Int32>(i); |
181 | 34 | const char left_char = left_ref->data[i - 1]; |
182 | | |
183 | 168 | for (size_t j = 1; j <= n; ++j) { |
184 | 134 | const Int32 cost = left_char == right_ref->data[j - 1] ? 0 : 1; |
185 | 134 | const Int32 insert_cost = curr[j - 1] + 1; |
186 | 134 | const Int32 delete_cost = prev[j] + 1; |
187 | 134 | const Int32 replace_cost = prev[j - 1] + cost; |
188 | 134 | curr[j] = std::min(std::min(insert_cost, delete_cost), replace_cost); |
189 | 134 | } |
190 | 34 | std::swap(prev, curr); |
191 | 34 | } |
192 | | |
193 | 8 | return prev[n]; |
194 | 8 | } |
195 | | |
196 | | static Int32 levenshtein_distance(const StringRef& left, const StringRef& right, |
197 | | std::vector<size_t>& left_offsets, |
198 | 14 | std::vector<size_t>& right_offsets) { |
199 | 14 | const bool left_ascii = simd::VStringFunctions::is_ascii(left); |
200 | 14 | const bool right_ascii = simd::VStringFunctions::is_ascii(right); |
201 | 14 | if (left_ascii && right_ascii) { |
202 | 8 | return levenshtein_distance_ascii(left, right); |
203 | 8 | } |
204 | | |
205 | 6 | if (left.size == 0) { |
206 | 2 | return static_cast<Int32>(simd::VStringFunctions::get_char_len(right.data, right.size)); |
207 | 2 | } |
208 | 4 | if (right.size == 0) { |
209 | 0 | return static_cast<Int32>(simd::VStringFunctions::get_char_len(left.data, left.size)); |
210 | 0 | } |
211 | | |
212 | 4 | utf8_char_offsets(left, left_offsets); |
213 | 4 | utf8_char_offsets(right, right_offsets); |
214 | 4 | return levenshtein_distance_utf8(left, left_offsets, right, right_offsets); |
215 | 4 | } |
216 | | |
217 | | static Int32 levenshtein_distance_with_right_offsets(const StringRef& left, |
218 | | std::vector<size_t>& left_offsets, |
219 | | const StringRef& right, |
220 | | const std::vector<size_t>& right_offsets, |
221 | 0 | bool right_ascii) { |
222 | 0 | const bool left_ascii = simd::VStringFunctions::is_ascii(left); |
223 | 0 | if (left_ascii && right_ascii) { |
224 | 0 | return levenshtein_distance_ascii(left, right); |
225 | 0 | } |
226 | | |
227 | 0 | if (left.size == 0) { |
228 | 0 | return static_cast<Int32>(right_offsets.size()); |
229 | 0 | } |
230 | 0 | if (right.size == 0) { |
231 | 0 | return left_ascii ? static_cast<Int32>(left.size) |
232 | 0 | : static_cast<Int32>( |
233 | 0 | simd::VStringFunctions::get_char_len(left.data, left.size)); |
234 | 0 | } |
235 | | |
236 | 0 | utf8_char_offsets(left, left_offsets); |
237 | 0 | return levenshtein_distance_utf8(left, left_offsets, right, right_offsets); |
238 | 0 | } |
239 | | |
240 | | static Int32 levenshtein_distance_with_left_offsets(const StringRef& left, |
241 | | const std::vector<size_t>& left_offsets, |
242 | | bool left_ascii, const StringRef& right, |
243 | 0 | std::vector<size_t>& right_offsets) { |
244 | 0 | const bool right_ascii = simd::VStringFunctions::is_ascii(right); |
245 | 0 | if (left_ascii && right_ascii) { |
246 | 0 | return levenshtein_distance_ascii(left, right); |
247 | 0 | } |
248 | | |
249 | 0 | if (left.size == 0) { |
250 | 0 | return static_cast<Int32>( |
251 | 0 | right_ascii ? right.size |
252 | 0 | : simd::VStringFunctions::get_char_len(right.data, right.size)); |
253 | 0 | } |
254 | 0 | if (right.size == 0) { |
255 | 0 | return static_cast<Int32>(left_offsets.size()); |
256 | 0 | } |
257 | | |
258 | 0 | utf8_char_offsets(right, right_offsets); |
259 | 0 | return levenshtein_distance_utf8(left, left_offsets, right, right_offsets); |
260 | 0 | } |
261 | | |
262 | | static Int32 levenshtein_distance(const StringRef& left, const StringRef& right) { |
263 | | std::vector<size_t> left_offsets; |
264 | | std::vector<size_t> right_offsets; |
265 | | return levenshtein_distance(left, right, left_offsets, right_offsets); |
266 | | } |
267 | | }; |
268 | | |
269 | | using FunctionLevenshtein = |
270 | | FunctionBinaryToType<DataTypeString, DataTypeString, LevenshteinImpl, NameLevenshtein>; |
271 | | |
272 | 8 | void register_function_levenshtein(SimpleFunctionFactory& factory) { |
273 | 8 | factory.register_function<FunctionLevenshtein>(); |
274 | 8 | } |
275 | | |
276 | | #include "common/compile_check_end.h" |
277 | | } // namespace doris |