be/src/exec/common/string_searcher.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // This file is copied from |
18 | | // https://github.com/ClickHouse/ClickHouse/blob/master/src/Common/StringSearcher.h |
19 | | // and modified by Doris |
20 | | |
21 | | #pragma once |
22 | | |
23 | | #include <stdint.h> |
24 | | #include <string.h> |
25 | | |
26 | | #include <algorithm> |
27 | | #include <limits> |
28 | | #include <vector> |
29 | | |
30 | | #include "core/string_ref.h" |
31 | | #include "exec/common/string_utils/string_utils.h" |
32 | | #include "util/sse_util.hpp" |
33 | | |
34 | | namespace doris { |
35 | | // namespace ErrorCodes |
36 | | // { |
37 | | // extern const int BAD_ARGUMENTS; |
38 | | // } |
39 | | |
/** Variants for searching a substring in a string.
 */

/// Common base of the searchers below. Provides the SIMD register width and a helper that
/// tells whether an unaligned n-byte load from a given address stays inside one memory page.
class StringSearcherBase {
public:
    /// Not read anywhere in this header; presumably consulted by callers to opt out of the
    /// vectorized paths — TODO confirm against users of this class.
    bool force_fallback = false;
#if defined(__SSE2__) || defined(__aarch64__)
protected:
    /// Width in bytes of one 128-bit vector; every unaligned load reads exactly this many bytes.
    static constexpr auto n = sizeof(__m128i);
    /// Page size reported by sysconf(); evaluated by each instance when it is constructed.
    const long page_size = sysconf(_SC_PAGESIZE); //::getPageSize();

    /// Returns true when `ptr`'s offset inside its page leaves room for n more bytes, i.e. an
    /// n-byte load starting at `ptr` does not cross into the next page.
    bool page_safe(const void* const ptr) const {
        const auto address = reinterpret_cast<std::uintptr_t>(ptr);
        const auto offset_in_page = address & (page_size - 1);
        return offset_in_page <= page_size - n;
    }
#endif
};
56 | | |
57 | | /// Performs case-sensitive and case-insensitive search of UTF-8 strings |
58 | | template <bool CaseSensitive, bool ASCII> |
59 | | class StringSearcher; |
60 | | |
61 | | /// Case-sensitive searcher (both ASCII and UTF-8) |
62 | | template <bool ASCII> |
63 | | class StringSearcher<true, ASCII> : public StringSearcherBase { |
64 | | private: |
65 | | /// string to be searched for |
66 | | const uint8_t* const needle; |
67 | | const uint8_t* const needle_end; |
68 | | /// first character in `needle` |
69 | | uint8_t first {}; |
70 | | |
71 | | #if defined(__SSE4_1__) || defined(__aarch64__) |
72 | | uint8_t second {}; |
73 | | /// vector filled `first` or `second` for determining leftmost position of the first and second symbols |
74 | | __m128i first_pattern; |
75 | | __m128i second_pattern; |
76 | | /// vector of first 16 characters of `needle` |
77 | | __m128i cache = _mm_setzero_si128(); |
78 | | int cachemask {}; |
79 | | #endif |
80 | | |
81 | | public: |
82 | | template <typename CharT> |
83 | | // requires (sizeof(CharT) == 1) |
84 | | StringSearcher(const CharT* needle_, const size_t needle_size) |
85 | 1.27k | : needle {reinterpret_cast<const uint8_t*>(needle_)}, |
86 | 1.27k | needle_end {needle + needle_size} { |
87 | 1.27k | if (0 == needle_size) return; |
88 | | |
89 | 1.05k | first = *needle; |
90 | | |
91 | 1.05k | #if defined(__SSE4_1__) || defined(__aarch64__) |
92 | 1.05k | first_pattern = _mm_set1_epi8(first); |
93 | 1.05k | if (needle + 1 < needle_end) { |
94 | 781 | second = *(needle + 1); |
95 | 781 | second_pattern = _mm_set1_epi8(second); |
96 | 781 | } |
97 | 1.05k | const auto* needle_pos = needle; |
98 | | |
99 | | //for (const auto i : collections::range(0, n)) |
100 | 17.9k | for (size_t i = 0; i < n; i++) { |
101 | 16.8k | cache = _mm_srli_si128(cache, 1); |
102 | | |
103 | 16.8k | if (needle_pos != needle_end) { |
104 | 6.43k | cache = _mm_insert_epi8(cache, *needle_pos, n - 1); |
105 | 6.43k | cachemask |= 1 << i; |
106 | 6.43k | ++needle_pos; |
107 | 6.43k | } |
108 | 16.8k | } |
109 | 1.05k | #endif |
110 | 1.05k | } |
111 | | |
112 | | template <typename CharT> |
113 | | // requires (sizeof(CharT) == 1) |
114 | 888 | const CharT* search(const CharT* haystack, size_t haystack_size) const { |
115 | | // cast to unsigned int8 to be consitent with needle type |
116 | | // ensure unsigned type compare |
117 | 888 | return reinterpret_cast<const CharT*>( |
118 | 888 | _search(reinterpret_cast<const uint8_t*>(haystack), haystack_size)); |
119 | 888 | } |
120 | | |
121 | | template <typename CharT> |
122 | | // requires (sizeof(CharT) == 1) |
123 | 0 | const CharT* search(const CharT* haystack, const CharT* haystack_end) const { |
124 | | // cast to unsigned int8 to be consitent with needle type |
125 | | // ensure unsigned type compare |
126 | 0 | return reinterpret_cast<const CharT*>( |
127 | 0 | _search(reinterpret_cast<const uint8_t*>(haystack), |
128 | 0 | reinterpret_cast<const uint8_t*>(haystack_end))); |
129 | 0 | } |
130 | | |
131 | | template <typename CharT> |
132 | | // requires (sizeof(CharT) == 1) |
133 | | ALWAYS_INLINE bool compare(const CharT* haystack, const CharT* haystack_end, CharT* pos) const { |
134 | | // cast to unsigned int8 to be consitent with needle type |
135 | | // ensure unsigned type compare |
136 | | return _compare(reinterpret_cast<const uint8_t*>(haystack), |
137 | | reinterpret_cast<const uint8_t*>(haystack_end), |
138 | | reinterpret_cast<const uint8_t*>(pos)); |
139 | | } |
140 | | |
141 | | private: |
142 | | ALWAYS_INLINE bool _compare(uint8_t* /*haystack*/, uint8_t* /*haystack_end*/, |
143 | | uint8_t* pos) const { |
144 | | #if defined(__SSE4_1__) || defined(__aarch64__) |
145 | | if (needle_end - needle > n && page_safe(pos)) { |
146 | | const auto v_haystack = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pos)); |
147 | | const auto v_against_cache = _mm_cmpeq_epi8(v_haystack, cache); |
148 | | const auto mask = _mm_movemask_epi8(v_against_cache); |
149 | | |
150 | | if (0xffff == cachemask) { |
151 | | if (mask == cachemask) { |
152 | | pos += n; |
153 | | const auto* needle_pos = needle + n; |
154 | | |
155 | | while (needle_pos < needle_end && *pos == *needle_pos) ++pos, ++needle_pos; |
156 | | |
157 | | if (needle_pos == needle_end) return true; |
158 | | } |
159 | | } else if ((mask & cachemask) == cachemask) |
160 | | return true; |
161 | | |
162 | | return false; |
163 | | } |
164 | | #endif |
165 | | |
166 | | if (*pos == first) { |
167 | | ++pos; |
168 | | const auto* needle_pos = needle + 1; |
169 | | |
170 | | while (needle_pos < needle_end && *pos == *needle_pos) ++pos, ++needle_pos; |
171 | | |
172 | | if (needle_pos == needle_end) return true; |
173 | | } |
174 | | |
175 | | return false; |
176 | | } |
177 | | |
178 | 888 | const uint8_t* _search(const uint8_t* haystack, const uint8_t* haystack_end) const { |
179 | 888 | if (needle == needle_end) return haystack; |
180 | | |
181 | 888 | const auto needle_size = needle_end - needle; |
182 | 888 | #if defined(__SSE4_1__) || defined(__aarch64__) |
183 | | /// Here is the quick path when needle_size is 1. |
184 | 888 | if (needle_size == 1) { |
185 | 3.94k | while (haystack < haystack_end) { |
186 | 3.80k | if (haystack + n <= haystack_end && page_safe(haystack)) { |
187 | 2.40k | const auto v_haystack = |
188 | 2.40k | _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack)); |
189 | 2.40k | const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, first_pattern); |
190 | 2.40k | const auto mask = _mm_movemask_epi8(v_against_pattern); |
191 | 2.40k | if (mask == 0) { |
192 | 2.30k | haystack += n; |
193 | 2.30k | continue; |
194 | 2.30k | } |
195 | | |
196 | 98 | const auto offset = __builtin_ctz(mask); |
197 | 98 | haystack += offset; |
198 | | |
199 | 98 | return haystack; |
200 | 2.40k | } |
201 | | |
202 | 1.40k | if (haystack == haystack_end) { |
203 | 0 | return haystack_end; |
204 | 0 | } |
205 | | |
206 | 1.40k | if (*haystack == first) { |
207 | 123 | return haystack; |
208 | 123 | } |
209 | 1.28k | ++haystack; |
210 | 1.28k | } |
211 | 142 | return haystack_end; |
212 | 363 | } |
213 | 525 | #endif |
214 | | |
215 | 1.29k | while (haystack < haystack_end && haystack_end - haystack >= needle_size) { |
216 | 1.04k | #if defined(__SSE4_1__) || defined(__aarch64__) |
217 | 1.04k | if ((haystack + 1 + n) <= haystack_end && page_safe(haystack)) { |
218 | | /// find first and second characters |
219 | 194 | const auto v_haystack_block_first = |
220 | 194 | _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack)); |
221 | 194 | const auto v_haystack_block_second = |
222 | 194 | _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack + 1)); |
223 | | |
224 | 194 | const auto v_against_pattern_first = |
225 | 194 | _mm_cmpeq_epi8(v_haystack_block_first, first_pattern); |
226 | 194 | const auto v_against_pattern_second = |
227 | 194 | _mm_cmpeq_epi8(v_haystack_block_second, second_pattern); |
228 | | |
229 | 194 | const auto mask = _mm_movemask_epi8( |
230 | 194 | _mm_and_si128(v_against_pattern_first, v_against_pattern_second)); |
231 | | /// first and second characters not present in 16 octets starting at `haystack` |
232 | 194 | if (mask == 0) { |
233 | 55 | haystack += n; |
234 | 55 | continue; |
235 | 55 | } |
236 | | |
237 | 139 | const auto offset = __builtin_ctz(mask); |
238 | 139 | haystack += offset; |
239 | | |
240 | 139 | if (haystack + n <= haystack_end && page_safe(haystack)) { |
241 | | /// check for first 16 octets |
242 | 124 | const auto v_haystack_offset = |
243 | 124 | _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack)); |
244 | 124 | const auto v_against_cache = _mm_cmpeq_epi8(v_haystack_offset, cache); |
245 | 124 | const auto mask_offset = _mm_movemask_epi8(v_against_cache); |
246 | | |
247 | 124 | if (0xffff == cachemask) { |
248 | 16 | if (mask_offset == cachemask) { |
249 | 16 | const auto* haystack_pos = haystack + n; |
250 | 16 | const auto* needle_pos = needle + n; |
251 | | |
252 | 136 | while (haystack_pos < haystack_end && needle_pos < needle_end && |
253 | 136 | *haystack_pos == *needle_pos) |
254 | 120 | ++haystack_pos, ++needle_pos; |
255 | | |
256 | 16 | if (needle_pos == needle_end) return haystack; |
257 | 16 | } |
258 | 108 | } else if ((mask_offset & cachemask) == cachemask) |
259 | 108 | return haystack; |
260 | | |
261 | 0 | ++haystack; |
262 | 0 | continue; |
263 | 124 | } |
264 | 139 | } |
265 | 866 | #endif |
266 | | |
267 | 866 | if (haystack == haystack_end) return haystack_end; |
268 | | |
269 | 866 | if (*haystack == first) { |
270 | 186 | const auto* haystack_pos = haystack + 1; |
271 | 186 | const auto* needle_pos = needle + 1; |
272 | | |
273 | 735 | while (haystack_pos < haystack_end && needle_pos < needle_end && |
274 | 735 | *haystack_pos == *needle_pos) |
275 | 549 | ++haystack_pos, ++needle_pos; |
276 | | |
277 | 186 | if (needle_pos == needle_end) return haystack; |
278 | 186 | } |
279 | | |
280 | 717 | ++haystack; |
281 | 717 | } |
282 | | |
283 | 252 | return haystack_end; |
284 | 525 | } |
285 | | |
286 | 888 | const uint8_t* _search(const uint8_t* haystack, const size_t haystack_size) const { |
287 | 888 | return _search(haystack, haystack + haystack_size); |
288 | 888 | } |
289 | | }; |
290 | | |
291 | | // Searches for needle surrounded by token-separators. |
292 | | // Separators are anything inside ASCII (0-128) and not alphanum. |
293 | | // Any value outside of basic ASCII (>=128) is considered a non-separator symbol, hence UTF-8 strings |
294 | | // should work just fine. But any Unicode whitespace is not considered a token separtor. |
295 | | template <typename StringSearcher> |
296 | | class TokenSearcher : public StringSearcherBase { |
297 | | StringSearcher searcher; |
298 | | size_t needle_size; |
299 | | |
300 | | public: |
301 | | template <typename CharT> |
302 | | // requires (sizeof(CharT) == 1) |
303 | | TokenSearcher(const CharT* needle_, const size_t needle_size_) |
304 | | : searcher {needle_, needle_size_}, needle_size(needle_size_) { |
305 | | if (std::any_of(needle_, needle_ + needle_size_, isTokenSeparator)) { |
306 | | //throw Exception{"Needle must not contain whitespace or separator characters", ErrorCodes::BAD_ARGUMENTS}; |
307 | | } |
308 | | } |
309 | | |
310 | | template <typename CharT> |
311 | | // requires (sizeof(CharT) == 1) |
312 | | ALWAYS_INLINE bool compare(const CharT* haystack, const CharT* haystack_end, |
313 | | const CharT* pos) const { |
314 | | // use searcher only if pos is in the beginning of token and pos + searcher.needle_size is end of token. |
315 | | if (isToken(haystack, haystack_end, pos)) |
316 | | return searcher.compare(haystack, haystack_end, pos); |
317 | | |
318 | | return false; |
319 | | } |
320 | | |
321 | | template <typename CharT> |
322 | | // requires (sizeof(CharT) == 1) |
323 | | const CharT* search(const CharT* haystack, const CharT* const haystack_end) const { |
324 | | // use searcher.search(), then verify that returned value is a token |
325 | | // if it is not, skip it and re-run |
326 | | |
327 | | const auto* pos = haystack; |
328 | | while (pos < haystack_end) { |
329 | | pos = searcher.search(pos, haystack_end); |
330 | | if (pos == haystack_end || isToken(haystack, haystack_end, pos)) return pos; |
331 | | |
332 | | // assuming that heendle does not contain any token separators. |
333 | | pos += needle_size; |
334 | | } |
335 | | return haystack_end; |
336 | | } |
337 | | |
338 | | template <typename CharT> |
339 | | // requires (sizeof(CharT) == 1) |
340 | | const CharT* search(const CharT* haystack, const size_t haystack_size) const { |
341 | | return search(haystack, haystack + haystack_size); |
342 | | } |
343 | | |
344 | | template <typename CharT> |
345 | | // requires (sizeof(CharT) == 1) |
346 | | ALWAYS_INLINE bool isToken(const CharT* haystack, const CharT* const haystack_end, |
347 | | const CharT* p) const { |
348 | | return (p == haystack || isTokenSeparator(*(p - 1))) && |
349 | | (p + needle_size >= haystack_end || isTokenSeparator(*(p + needle_size))); |
350 | | } |
351 | | |
352 | | ALWAYS_INLINE static bool isTokenSeparator(const uint8_t c) { |
353 | | return !(is_alpha_numeric_ascii(c) || !is_ascii(c)); |
354 | | } |
355 | | }; |
356 | | |
357 | | using ASCIICaseSensitiveStringSearcher = StringSearcher<true, true>; |
358 | | // using ASCIICaseInsensitiveStringSearcher = StringSearcher<false, true>; |
359 | | using UTF8CaseSensitiveStringSearcher = StringSearcher<true, false>; |
360 | | // using UTF8CaseInsensitiveStringSearcher = StringSearcher<false, false>; |
361 | | using ASCIICaseSensitiveTokenSearcher = TokenSearcher<ASCIICaseSensitiveStringSearcher>; |
362 | | // using ASCIICaseInsensitiveTokenSearcher = TokenSearcher<ASCIICaseInsensitiveStringSearcher>; |
363 | | |
364 | | /** Uses functions from libc. |
365 | | * It makes sense to use only with short haystacks when cheap initialization is required. |
366 | | * There is no option for case-insensitive search for UTF-8 strings. |
367 | | * It is required that strings are zero-terminated. |
368 | | */ |
369 | | |
370 | | struct LibCASCIICaseSensitiveStringSearcher : public StringSearcherBase { |
371 | | const char* const needle; |
372 | | |
373 | | template <typename CharT> |
374 | | // requires (sizeof(CharT) == 1) |
375 | | LibCASCIICaseSensitiveStringSearcher(const CharT* const needle_, const size_t /* needle_size */) |
376 | | : needle(reinterpret_cast<const char*>(needle_)) {} |
377 | | |
378 | | template <typename CharT> |
379 | | // requires (sizeof(CharT) == 1) |
380 | | const CharT* search(const CharT* haystack, const CharT* const haystack_end) const { |
381 | | const auto* res = strstr(reinterpret_cast<const char*>(haystack), |
382 | | reinterpret_cast<const char*>(needle)); |
383 | | if (!res) return haystack_end; |
384 | | return reinterpret_cast<const CharT*>(res); |
385 | | } |
386 | | |
387 | | template <typename CharT> |
388 | | // requires (sizeof(CharT) == 1) |
389 | | const CharT* search(const CharT* haystack, const size_t haystack_size) const { |
390 | | return search(haystack, haystack + haystack_size); |
391 | | } |
392 | | }; |
393 | | |
394 | | struct LibCASCIICaseInsensitiveStringSearcher : public StringSearcherBase { |
395 | | const char* const needle; |
396 | | |
397 | | template <typename CharT> |
398 | | // requires (sizeof(CharT) == 1) |
399 | | LibCASCIICaseInsensitiveStringSearcher(const CharT* const needle_, |
400 | | const size_t /* needle_size */) |
401 | | : needle(reinterpret_cast<const char*>(needle_)) {} |
402 | | |
403 | | template <typename CharT> |
404 | | // requires (sizeof(CharT) == 1) |
405 | | const CharT* search(const CharT* haystack, const CharT* const haystack_end) const { |
406 | | const auto* res = strcasestr(reinterpret_cast<const char*>(haystack), |
407 | | reinterpret_cast<const char*>(needle)); |
408 | | if (!res) return haystack_end; |
409 | | return reinterpret_cast<const CharT*>(res); |
410 | | } |
411 | | |
412 | | template <typename CharT> |
413 | | // requires (sizeof(CharT) == 1) |
414 | | const CharT* search(const CharT* haystack, const size_t haystack_size) const { |
415 | | return search(haystack, haystack + haystack_size); |
416 | | } |
417 | | }; |
418 | | } // namespace doris |