be/src/exec/common/string_searcher.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // This file is copied from |
18 | | // https://github.com/ClickHouse/ClickHouse/blob/master/src/Common/StringSearcher.h |
19 | | // and modified by Doris |
20 | | |
21 | | #pragma once |
22 | | |
23 | | #include <stdint.h> |
24 | | #include <string.h> |
25 | | |
26 | | #include <algorithm> |
27 | | #include <limits> |
28 | | #include <vector> |
29 | | |
30 | | #include "core/string_ref.h" |
31 | | #include "exec/common/string_utils/string_utils.h" |
32 | | #include "util/sse_util.hpp" |
33 | | |
34 | | namespace doris { |
35 | | #include "common/compile_check_begin.h" |
36 | | // namespace ErrorCodes |
37 | | // { |
38 | | // extern const int BAD_ARGUMENTS; |
39 | | // } |
40 | | |
41 | | /** Variants for searching a substring in a string. |
42 | | */ |
43 | | |
/// Common base of all searchers in this header.
///
/// Holds the public `force_fallback` switch (not read anywhere in this header)
/// and, on SIMD-capable targets, the helpers derived searchers use to decide
/// whether an unaligned 16-byte load starting at a given address stays inside
/// one memory page.
class StringSearcherBase {
public:
    bool force_fallback = false;
#if defined(__SSE2__) || defined(__aarch64__)
protected:
    /// Width in bytes of one SIMD register load.
    static constexpr auto n = sizeof(__m128i);
    /// Page size used by `page_safe`; queried once per searcher instance.
    const long page_size = detect_page_size();

    /// Returns true when the `n` bytes starting at `ptr` lie within a single page,
    /// i.e. the offset of `ptr` inside its page is at most `page_size - n`.
    bool page_safe(const void* const ptr) const {
        return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
    }

private:
    /// Queries the system page size.
    ///
    /// `sysconf` returns -1 on failure. Using that value would turn the mask in
    /// `page_safe` into ~1 and the bound into a huge unsigned number, so nearly
    /// every address would be reported as safe. Fall back to 4096 instead: real
    /// page sizes are power-of-two multiples of it, so every real page boundary is
    /// also a 4096 boundary and the answer can only become more conservative.
    static long detect_page_size() {
        const long reported = sysconf(_SC_PAGESIZE);
        return reported > 0 ? reported : 4096;
    }
#endif
};
57 | | |
58 | | /// Performs case-sensitive and case-insensitive search of UTF-8 strings |
59 | | template <bool CaseSensitive, bool ASCII> |
60 | | class StringSearcher; |
61 | | |
62 | | /// Case-sensitive searcher (both ASCII and UTF-8) |
63 | | template <bool ASCII> |
64 | | class StringSearcher<true, ASCII> : public StringSearcherBase { |
65 | | private: |
66 | | /// string to be searched for |
67 | | const uint8_t* const needle; |
68 | | const uint8_t* const needle_end; |
69 | | /// first character in `needle` |
70 | | uint8_t first {}; |
71 | | |
72 | | #if defined(__SSE4_1__) || defined(__aarch64__) |
73 | | uint8_t second {}; |
74 | | /// vector filled `first` or `second` for determining leftmost position of the first and second symbols |
75 | | __m128i first_pattern; |
76 | | __m128i second_pattern; |
77 | | /// vector of first 16 characters of `needle` |
78 | | __m128i cache = _mm_setzero_si128(); |
79 | | int cachemask {}; |
80 | | #endif |
81 | | |
82 | | public: |
83 | | template <typename CharT> |
84 | | // requires (sizeof(CharT) == 1) |
85 | | StringSearcher(const CharT* needle_, const size_t needle_size) |
86 | 6.04k | : needle {reinterpret_cast<const uint8_t*>(needle_)}, |
87 | 6.04k | needle_end {needle + needle_size} { |
88 | 6.04k | if (0 == needle_size) return; |
89 | | |
90 | 5.50k | first = *needle; |
91 | | |
92 | 5.50k | #if defined(__SSE4_1__) || defined(__aarch64__) |
93 | 5.50k | first_pattern = _mm_set1_epi8(first); |
94 | 5.50k | if (needle + 1 < needle_end) { |
95 | 3.68k | second = *(needle + 1); |
96 | 3.68k | second_pattern = _mm_set1_epi8(second); |
97 | 3.68k | } |
98 | 5.50k | const auto* needle_pos = needle; |
99 | | |
100 | | //for (const auto i : collections::range(0, n)) |
101 | 93.5k | for (size_t i = 0; i < n; i++) { |
102 | 88.0k | cache = _mm_srli_si128(cache, 1); |
103 | | |
104 | 88.0k | if (needle_pos != needle_end) { |
105 | 24.3k | cache = _mm_insert_epi8(cache, *needle_pos, n - 1); |
106 | 24.3k | cachemask |= 1 << i; |
107 | 24.3k | ++needle_pos; |
108 | 24.3k | } |
109 | 88.0k | } |
110 | 5.50k | #endif |
111 | 5.50k | } |
112 | | |
113 | | template <typename CharT> |
114 | | // requires (sizeof(CharT) == 1) |
115 | 275k | const CharT* search(const CharT* haystack, size_t haystack_size) const { |
116 | | // cast to unsigned int8 to be consitent with needle type |
117 | | // ensure unsigned type compare |
118 | 275k | return reinterpret_cast<const CharT*>( |
119 | 275k | _search(reinterpret_cast<const uint8_t*>(haystack), haystack_size)); |
120 | 275k | } |
121 | | |
122 | | template <typename CharT> |
123 | | // requires (sizeof(CharT) == 1) |
124 | 168 | const CharT* search(const CharT* haystack, const CharT* haystack_end) const { |
125 | | // cast to unsigned int8 to be consitent with needle type |
126 | | // ensure unsigned type compare |
127 | 168 | return reinterpret_cast<const CharT*>( |
128 | 168 | _search(reinterpret_cast<const uint8_t*>(haystack), |
129 | 168 | reinterpret_cast<const uint8_t*>(haystack_end))); |
130 | 168 | } |
131 | | |
132 | | template <typename CharT> |
133 | | // requires (sizeof(CharT) == 1) |
134 | | ALWAYS_INLINE bool compare(const CharT* haystack, const CharT* haystack_end, CharT* pos) const { |
135 | | // cast to unsigned int8 to be consitent with needle type |
136 | | // ensure unsigned type compare |
137 | | return _compare(reinterpret_cast<const uint8_t*>(haystack), |
138 | | reinterpret_cast<const uint8_t*>(haystack_end), |
139 | | reinterpret_cast<const uint8_t*>(pos)); |
140 | | } |
141 | | |
142 | | private: |
143 | | ALWAYS_INLINE bool _compare(uint8_t* /*haystack*/, uint8_t* /*haystack_end*/, |
144 | | uint8_t* pos) const { |
145 | | #if defined(__SSE4_1__) || defined(__aarch64__) |
146 | | if (needle_end - needle > n && page_safe(pos)) { |
147 | | const auto v_haystack = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pos)); |
148 | | const auto v_against_cache = _mm_cmpeq_epi8(v_haystack, cache); |
149 | | const auto mask = _mm_movemask_epi8(v_against_cache); |
150 | | |
151 | | if (0xffff == cachemask) { |
152 | | if (mask == cachemask) { |
153 | | pos += n; |
154 | | const auto* needle_pos = needle + n; |
155 | | |
156 | | while (needle_pos < needle_end && *pos == *needle_pos) ++pos, ++needle_pos; |
157 | | |
158 | | if (needle_pos == needle_end) return true; |
159 | | } |
160 | | } else if ((mask & cachemask) == cachemask) |
161 | | return true; |
162 | | |
163 | | return false; |
164 | | } |
165 | | #endif |
166 | | |
167 | | if (*pos == first) { |
168 | | ++pos; |
169 | | const auto* needle_pos = needle + 1; |
170 | | |
171 | | while (needle_pos < needle_end && *pos == *needle_pos) ++pos, ++needle_pos; |
172 | | |
173 | | if (needle_pos == needle_end) return true; |
174 | | } |
175 | | |
176 | | return false; |
177 | | } |
178 | | |
179 | 275k | const uint8_t* _search(const uint8_t* haystack, const uint8_t* haystack_end) const { |
180 | 275k | if (needle == needle_end) return haystack; |
181 | | |
182 | 275k | const auto needle_size = needle_end - needle; |
183 | 275k | #if defined(__SSE4_1__) || defined(__aarch64__) |
184 | | /// Here is the quick path when needle_size is 1. |
185 | 275k | if (needle_size == 1) { |
186 | 73.8k | while (haystack < haystack_end) { |
187 | 72.5k | if (haystack + n <= haystack_end && page_safe(haystack)) { |
188 | 60.9k | const auto v_haystack = |
189 | 60.9k | _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack)); |
190 | 60.9k | const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, first_pattern); |
191 | 60.9k | const auto mask = _mm_movemask_epi8(v_against_pattern); |
192 | 60.9k | if (mask == 0) { |
193 | 12.2k | haystack += n; |
194 | 12.2k | continue; |
195 | 12.2k | } |
196 | | |
197 | 48.6k | const auto offset = __builtin_ctz(mask); |
198 | 48.6k | haystack += offset; |
199 | | |
200 | 48.6k | return haystack; |
201 | 60.9k | } |
202 | | |
203 | 11.6k | if (haystack == haystack_end) { |
204 | 0 | return haystack_end; |
205 | 0 | } |
206 | | |
207 | 11.6k | if (*haystack == first) { |
208 | 1.75k | return haystack; |
209 | 1.75k | } |
210 | 9.89k | ++haystack; |
211 | 9.89k | } |
212 | 1.28k | return haystack_end; |
213 | 51.6k | } |
214 | 224k | #endif |
215 | | |
216 | 22.7M | while (haystack < haystack_end && haystack_end - haystack >= needle_size) { |
217 | 22.7M | #if defined(__SSE4_1__) || defined(__aarch64__) |
218 | 22.7M | if ((haystack + 1 + n) <= haystack_end && page_safe(haystack)) { |
219 | | /// find first and second characters |
220 | 21.8M | const auto v_haystack_block_first = |
221 | 21.8M | _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack)); |
222 | 21.8M | const auto v_haystack_block_second = |
223 | 21.8M | _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack + 1)); |
224 | | |
225 | 21.8M | const auto v_against_pattern_first = |
226 | 21.8M | _mm_cmpeq_epi8(v_haystack_block_first, first_pattern); |
227 | 21.8M | const auto v_against_pattern_second = |
228 | 21.8M | _mm_cmpeq_epi8(v_haystack_block_second, second_pattern); |
229 | | |
230 | 21.8M | const auto mask = _mm_movemask_epi8( |
231 | 21.8M | _mm_and_si128(v_against_pattern_first, v_against_pattern_second)); |
232 | | /// first and second characters not present in 16 octets starting at `haystack` |
233 | 21.8M | if (mask == 0) { |
234 | 18.9M | haystack += n; |
235 | 18.9M | continue; |
236 | 18.9M | } |
237 | | |
238 | 2.83M | const auto offset = __builtin_ctz(mask); |
239 | 2.83M | haystack += offset; |
240 | | |
241 | 2.84M | if (haystack + n <= haystack_end && page_safe(haystack)) { |
242 | | /// check for first 16 octets |
243 | 2.84M | const auto v_haystack_offset = |
244 | 2.84M | _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack)); |
245 | 2.84M | const auto v_against_cache = _mm_cmpeq_epi8(v_haystack_offset, cache); |
246 | 2.84M | const auto mask_offset = _mm_movemask_epi8(v_against_cache); |
247 | | |
248 | 2.84M | if (0xffff == cachemask) { |
249 | 16.1k | if (mask_offset == cachemask) { |
250 | 24 | const auto* haystack_pos = haystack + n; |
251 | 24 | const auto* needle_pos = needle + n; |
252 | | |
253 | 170 | while (haystack_pos < haystack_end && needle_pos < needle_end && |
254 | 170 | *haystack_pos == *needle_pos) |
255 | 146 | ++haystack_pos, ++needle_pos; |
256 | | |
257 | 24 | if (needle_pos == needle_end) return haystack; |
258 | 24 | } |
259 | 2.82M | } else if ((mask_offset & cachemask) == cachemask) |
260 | 162k | return haystack; |
261 | | |
262 | 2.67M | ++haystack; |
263 | 2.67M | continue; |
264 | 2.84M | } |
265 | 2.83M | } |
266 | 912k | #endif |
267 | | |
268 | 912k | if (haystack == haystack_end) return haystack_end; |
269 | | |
270 | 912k | if (*haystack == first) { |
271 | 26.0k | const auto* haystack_pos = haystack + 1; |
272 | 26.0k | const auto* needle_pos = needle + 1; |
273 | | |
274 | 49.4k | while (haystack_pos < haystack_end && needle_pos < needle_end && |
275 | 49.4k | *haystack_pos == *needle_pos) |
276 | 23.4k | ++haystack_pos, ++needle_pos; |
277 | | |
278 | 26.0k | if (needle_pos == needle_end) return haystack; |
279 | 26.0k | } |
280 | | |
281 | 909k | ++haystack; |
282 | 909k | } |
283 | | |
284 | 59.3k | return haystack_end; |
285 | 224k | } |
286 | | |
287 | 275k | const uint8_t* _search(const uint8_t* haystack, const size_t haystack_size) const { |
288 | 275k | return _search(haystack, haystack + haystack_size); |
289 | 275k | } |
290 | | }; |
291 | | |
292 | | // Searches for needle surrounded by token-separators. |
293 | | // Separators are anything inside ASCII (0-128) and not alphanum. |
294 | | // Any value outside of basic ASCII (>=128) is considered a non-separator symbol, hence UTF-8 strings |
295 | | // should work just fine. But any Unicode whitespace is not considered a token separtor. |
296 | | template <typename StringSearcher> |
297 | | class TokenSearcher : public StringSearcherBase { |
298 | | StringSearcher searcher; |
299 | | size_t needle_size; |
300 | | |
301 | | public: |
302 | | template <typename CharT> |
303 | | // requires (sizeof(CharT) == 1) |
304 | | TokenSearcher(const CharT* needle_, const size_t needle_size_) |
305 | | : searcher {needle_, needle_size_}, needle_size(needle_size_) { |
306 | | if (std::any_of(needle_, needle_ + needle_size_, isTokenSeparator)) { |
307 | | //throw Exception{"Needle must not contain whitespace or separator characters", ErrorCodes::BAD_ARGUMENTS}; |
308 | | } |
309 | | } |
310 | | |
311 | | template <typename CharT> |
312 | | // requires (sizeof(CharT) == 1) |
313 | | ALWAYS_INLINE bool compare(const CharT* haystack, const CharT* haystack_end, |
314 | | const CharT* pos) const { |
315 | | // use searcher only if pos is in the beginning of token and pos + searcher.needle_size is end of token. |
316 | | if (isToken(haystack, haystack_end, pos)) |
317 | | return searcher.compare(haystack, haystack_end, pos); |
318 | | |
319 | | return false; |
320 | | } |
321 | | |
322 | | template <typename CharT> |
323 | | // requires (sizeof(CharT) == 1) |
324 | | const CharT* search(const CharT* haystack, const CharT* const haystack_end) const { |
325 | | // use searcher.search(), then verify that returned value is a token |
326 | | // if it is not, skip it and re-run |
327 | | |
328 | | const auto* pos = haystack; |
329 | | while (pos < haystack_end) { |
330 | | pos = searcher.search(pos, haystack_end); |
331 | | if (pos == haystack_end || isToken(haystack, haystack_end, pos)) return pos; |
332 | | |
333 | | // assuming that heendle does not contain any token separators. |
334 | | pos += needle_size; |
335 | | } |
336 | | return haystack_end; |
337 | | } |
338 | | |
339 | | template <typename CharT> |
340 | | // requires (sizeof(CharT) == 1) |
341 | | const CharT* search(const CharT* haystack, const size_t haystack_size) const { |
342 | | return search(haystack, haystack + haystack_size); |
343 | | } |
344 | | |
345 | | template <typename CharT> |
346 | | // requires (sizeof(CharT) == 1) |
347 | | ALWAYS_INLINE bool isToken(const CharT* haystack, const CharT* const haystack_end, |
348 | | const CharT* p) const { |
349 | | return (p == haystack || isTokenSeparator(*(p - 1))) && |
350 | | (p + needle_size >= haystack_end || isTokenSeparator(*(p + needle_size))); |
351 | | } |
352 | | |
353 | | ALWAYS_INLINE static bool isTokenSeparator(const uint8_t c) { |
354 | | return !(is_alpha_numeric_ascii(c) || !is_ascii(c)); |
355 | | } |
356 | | }; |
357 | | |
// Convenience aliases for the searchers defined in this header.
// Only the case-sensitive specialization of StringSearcher is defined here, so the
// case-insensitive aliases stay commented out.
using ASCIICaseSensitiveStringSearcher = StringSearcher<true, true>;
// using ASCIICaseInsensitiveStringSearcher = StringSearcher<false, true>;
using UTF8CaseSensitiveStringSearcher = StringSearcher<true, false>;
// using UTF8CaseInsensitiveStringSearcher = StringSearcher<false, false>;
// Whole-token search built on top of the ASCII case-sensitive searcher.
using ASCIICaseSensitiveTokenSearcher = TokenSearcher<ASCIICaseSensitiveStringSearcher>;
// using ASCIICaseInsensitiveTokenSearcher = TokenSearcher<ASCIICaseInsensitiveStringSearcher>;
364 | | |
365 | | /** Uses functions from libc. |
366 | | * It makes sense to use only with short haystacks when cheap initialization is required. |
367 | | * There is no option for case-insensitive search for UTF-8 strings. |
368 | | * It is required that strings are zero-terminated. |
369 | | */ |
370 | | |
371 | | struct LibCASCIICaseSensitiveStringSearcher : public StringSearcherBase { |
372 | | const char* const needle; |
373 | | |
374 | | template <typename CharT> |
375 | | // requires (sizeof(CharT) == 1) |
376 | | LibCASCIICaseSensitiveStringSearcher(const CharT* const needle_, const size_t /* needle_size */) |
377 | | : needle(reinterpret_cast<const char*>(needle_)) {} |
378 | | |
379 | | template <typename CharT> |
380 | | // requires (sizeof(CharT) == 1) |
381 | | const CharT* search(const CharT* haystack, const CharT* const haystack_end) const { |
382 | | const auto* res = strstr(reinterpret_cast<const char*>(haystack), |
383 | | reinterpret_cast<const char*>(needle)); |
384 | | if (!res) return haystack_end; |
385 | | return reinterpret_cast<const CharT*>(res); |
386 | | } |
387 | | |
388 | | template <typename CharT> |
389 | | // requires (sizeof(CharT) == 1) |
390 | | const CharT* search(const CharT* haystack, const size_t haystack_size) const { |
391 | | return search(haystack, haystack + haystack_size); |
392 | | } |
393 | | }; |
394 | | |
395 | | struct LibCASCIICaseInsensitiveStringSearcher : public StringSearcherBase { |
396 | | const char* const needle; |
397 | | |
398 | | template <typename CharT> |
399 | | // requires (sizeof(CharT) == 1) |
400 | | LibCASCIICaseInsensitiveStringSearcher(const CharT* const needle_, |
401 | | const size_t /* needle_size */) |
402 | | : needle(reinterpret_cast<const char*>(needle_)) {} |
403 | | |
404 | | template <typename CharT> |
405 | | // requires (sizeof(CharT) == 1) |
406 | | const CharT* search(const CharT* haystack, const CharT* const haystack_end) const { |
407 | | const auto* res = strcasestr(reinterpret_cast<const char*>(haystack), |
408 | | reinterpret_cast<const char*>(needle)); |
409 | | if (!res) return haystack_end; |
410 | | return reinterpret_cast<const CharT*>(res); |
411 | | } |
412 | | |
413 | | template <typename CharT> |
414 | | // requires (sizeof(CharT) == 1) |
415 | | const CharT* search(const CharT* haystack, const size_t haystack_size) const { |
416 | | return search(haystack, haystack + haystack_size); |
417 | | } |
418 | | }; |
419 | | #include "common/compile_check_end.h" |
420 | | } // namespace doris |