Coverage Report

Created: 2026-03-13 09:58

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exec/common/string_searcher.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Common/StringSearcher.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include <stdint.h>
24
#include <string.h>
25
26
#include <algorithm>
27
#include <limits>
28
#include <vector>
29
30
#include "core/string_ref.h"
31
#include "exec/common/string_utils/string_utils.h"
32
#include "util/sse_util.hpp"
33
34
namespace doris {
35
#include "common/compile_check_begin.h"
36
// namespace ErrorCodes
37
// {
38
//     extern const int BAD_ARGUMENTS;
39
// }
40
41
/** Variants for searching a substring in a string.
42
  */
43
44
class StringSearcherBase {
45
public:
46
    bool force_fallback = false;
47
#if defined(__SSE2__) || defined(__aarch64__)
48
protected:
49
    static constexpr auto n = sizeof(__m128i);
50
    const long page_size = sysconf(_SC_PAGESIZE); //::getPageSize();
51
52
25.3M
    bool page_safe(const void* const ptr) const {
53
25.3M
        return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
54
25.3M
    }
55
#endif
56
};
57
58
/// Primary template of the substring searcher, selected by case sensitivity and by whether
/// the text is treated as ASCII or UTF-8. It is only declared here; of its specializations,
/// only the case-sensitive one (`StringSearcher<true, ASCII>`) is defined in this file.
59
template <bool CaseSensitive, bool ASCII>
60
class StringSearcher;
61
62
/// Case-sensitive searcher (both ASCII and UTF-8)
63
template <bool ASCII>
64
class StringSearcher<true, ASCII> : public StringSearcherBase {
65
private:
66
    /// string to be searched for
67
    const uint8_t* const needle;
68
    const uint8_t* const needle_end;
69
    /// first character in `needle`
70
    uint8_t first {};
71
72
#if defined(__SSE4_1__) || defined(__aarch64__)
73
    uint8_t second {};
74
    /// vector filled `first` or `second` for determining leftmost position of the first and second symbols
75
    __m128i first_pattern;
76
    __m128i second_pattern;
77
    /// vector of first 16 characters of `needle`
78
    __m128i cache = _mm_setzero_si128();
79
    int cachemask {};
80
#endif
81
82
public:
83
    template <typename CharT>
84
    // requires (sizeof(CharT) == 1)
85
    StringSearcher(const CharT* needle_, const size_t needle_size)
86
6.04k
            : needle {reinterpret_cast<const uint8_t*>(needle_)},
87
6.04k
              needle_end {needle + needle_size} {
88
6.04k
        if (0 == needle_size) return;
89
90
5.50k
        first = *needle;
91
92
5.50k
#if defined(__SSE4_1__) || defined(__aarch64__)
93
5.50k
        first_pattern = _mm_set1_epi8(first);
94
5.50k
        if (needle + 1 < needle_end) {
95
3.68k
            second = *(needle + 1);
96
3.68k
            second_pattern = _mm_set1_epi8(second);
97
3.68k
        }
98
5.50k
        const auto* needle_pos = needle;
99
100
        //for (const auto i : collections::range(0, n))
101
93.5k
        for (size_t i = 0; i < n; i++) {
102
88.0k
            cache = _mm_srli_si128(cache, 1);
103
104
88.0k
            if (needle_pos != needle_end) {
105
24.3k
                cache = _mm_insert_epi8(cache, *needle_pos, n - 1);
106
24.3k
                cachemask |= 1 << i;
107
24.3k
                ++needle_pos;
108
24.3k
            }
109
88.0k
        }
110
5.50k
#endif
111
5.50k
    }
112
113
    template <typename CharT>
114
    // requires (sizeof(CharT) == 1)
115
275k
    const CharT* search(const CharT* haystack, size_t haystack_size) const {
116
        // cast to unsigned int8 to be consitent with needle type
117
        // ensure unsigned type compare
118
275k
        return reinterpret_cast<const CharT*>(
119
275k
                _search(reinterpret_cast<const uint8_t*>(haystack), haystack_size));
120
275k
    }
121
122
    template <typename CharT>
123
    // requires (sizeof(CharT) == 1)
124
168
    const CharT* search(const CharT* haystack, const CharT* haystack_end) const {
125
        // cast to unsigned int8 to be consitent with needle type
126
        // ensure unsigned type compare
127
168
        return reinterpret_cast<const CharT*>(
128
168
                _search(reinterpret_cast<const uint8_t*>(haystack),
129
168
                        reinterpret_cast<const uint8_t*>(haystack_end)));
130
168
    }
131
132
    template <typename CharT>
133
    // requires (sizeof(CharT) == 1)
134
    ALWAYS_INLINE bool compare(const CharT* haystack, const CharT* haystack_end, CharT* pos) const {
135
        // cast to unsigned int8 to be consitent with needle type
136
        // ensure unsigned type compare
137
        return _compare(reinterpret_cast<const uint8_t*>(haystack),
138
                        reinterpret_cast<const uint8_t*>(haystack_end),
139
                        reinterpret_cast<const uint8_t*>(pos));
140
    }
141
142
private:
143
    ALWAYS_INLINE bool _compare(uint8_t* /*haystack*/, uint8_t* /*haystack_end*/,
144
                                uint8_t* pos) const {
145
#if defined(__SSE4_1__) || defined(__aarch64__)
146
        if (needle_end - needle > n && page_safe(pos)) {
147
            const auto v_haystack = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pos));
148
            const auto v_against_cache = _mm_cmpeq_epi8(v_haystack, cache);
149
            const auto mask = _mm_movemask_epi8(v_against_cache);
150
151
            if (0xffff == cachemask) {
152
                if (mask == cachemask) {
153
                    pos += n;
154
                    const auto* needle_pos = needle + n;
155
156
                    while (needle_pos < needle_end && *pos == *needle_pos) ++pos, ++needle_pos;
157
158
                    if (needle_pos == needle_end) return true;
159
                }
160
            } else if ((mask & cachemask) == cachemask)
161
                return true;
162
163
            return false;
164
        }
165
#endif
166
167
        if (*pos == first) {
168
            ++pos;
169
            const auto* needle_pos = needle + 1;
170
171
            while (needle_pos < needle_end && *pos == *needle_pos) ++pos, ++needle_pos;
172
173
            if (needle_pos == needle_end) return true;
174
        }
175
176
        return false;
177
    }
178
179
275k
    const uint8_t* _search(const uint8_t* haystack, const uint8_t* haystack_end) const {
180
275k
        if (needle == needle_end) return haystack;
181
182
275k
        const auto needle_size = needle_end - needle;
183
275k
#if defined(__SSE4_1__) || defined(__aarch64__)
184
        /// Here is the quick path when needle_size is 1.
185
275k
        if (needle_size == 1) {
186
73.8k
            while (haystack < haystack_end) {
187
72.5k
                if (haystack + n <= haystack_end && page_safe(haystack)) {
188
60.9k
                    const auto v_haystack =
189
60.9k
                            _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack));
190
60.9k
                    const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, first_pattern);
191
60.9k
                    const auto mask = _mm_movemask_epi8(v_against_pattern);
192
60.9k
                    if (mask == 0) {
193
12.2k
                        haystack += n;
194
12.2k
                        continue;
195
12.2k
                    }
196
197
48.6k
                    const auto offset = __builtin_ctz(mask);
198
48.6k
                    haystack += offset;
199
200
48.6k
                    return haystack;
201
60.9k
                }
202
203
11.6k
                if (haystack == haystack_end) {
204
0
                    return haystack_end;
205
0
                }
206
207
11.6k
                if (*haystack == first) {
208
1.75k
                    return haystack;
209
1.75k
                }
210
9.89k
                ++haystack;
211
9.89k
            }
212
1.28k
            return haystack_end;
213
51.6k
        }
214
224k
#endif
215
216
22.7M
        while (haystack < haystack_end && haystack_end - haystack >= needle_size) {
217
22.7M
#if defined(__SSE4_1__) || defined(__aarch64__)
218
22.7M
            if ((haystack + 1 + n) <= haystack_end && page_safe(haystack)) {
219
                /// find first and second characters
220
21.8M
                const auto v_haystack_block_first =
221
21.8M
                        _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack));
222
21.8M
                const auto v_haystack_block_second =
223
21.8M
                        _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack + 1));
224
225
21.8M
                const auto v_against_pattern_first =
226
21.8M
                        _mm_cmpeq_epi8(v_haystack_block_first, first_pattern);
227
21.8M
                const auto v_against_pattern_second =
228
21.8M
                        _mm_cmpeq_epi8(v_haystack_block_second, second_pattern);
229
230
21.8M
                const auto mask = _mm_movemask_epi8(
231
21.8M
                        _mm_and_si128(v_against_pattern_first, v_against_pattern_second));
232
                /// first and second characters not present in 16 octets starting at `haystack`
233
21.8M
                if (mask == 0) {
234
18.9M
                    haystack += n;
235
18.9M
                    continue;
236
18.9M
                }
237
238
2.83M
                const auto offset = __builtin_ctz(mask);
239
2.83M
                haystack += offset;
240
241
2.84M
                if (haystack + n <= haystack_end && page_safe(haystack)) {
242
                    /// check for first 16 octets
243
2.84M
                    const auto v_haystack_offset =
244
2.84M
                            _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack));
245
2.84M
                    const auto v_against_cache = _mm_cmpeq_epi8(v_haystack_offset, cache);
246
2.84M
                    const auto mask_offset = _mm_movemask_epi8(v_against_cache);
247
248
2.84M
                    if (0xffff == cachemask) {
249
16.1k
                        if (mask_offset == cachemask) {
250
24
                            const auto* haystack_pos = haystack + n;
251
24
                            const auto* needle_pos = needle + n;
252
253
170
                            while (haystack_pos < haystack_end && needle_pos < needle_end &&
254
170
                                   *haystack_pos == *needle_pos)
255
146
                                ++haystack_pos, ++needle_pos;
256
257
24
                            if (needle_pos == needle_end) return haystack;
258
24
                        }
259
2.82M
                    } else if ((mask_offset & cachemask) == cachemask)
260
162k
                        return haystack;
261
262
2.67M
                    ++haystack;
263
2.67M
                    continue;
264
2.84M
                }
265
2.83M
            }
266
912k
#endif
267
268
912k
            if (haystack == haystack_end) return haystack_end;
269
270
912k
            if (*haystack == first) {
271
26.0k
                const auto* haystack_pos = haystack + 1;
272
26.0k
                const auto* needle_pos = needle + 1;
273
274
49.4k
                while (haystack_pos < haystack_end && needle_pos < needle_end &&
275
49.4k
                       *haystack_pos == *needle_pos)
276
23.4k
                    ++haystack_pos, ++needle_pos;
277
278
26.0k
                if (needle_pos == needle_end) return haystack;
279
26.0k
            }
280
281
909k
            ++haystack;
282
909k
        }
283
284
59.3k
        return haystack_end;
285
224k
    }
286
287
275k
    const uint8_t* _search(const uint8_t* haystack, const size_t haystack_size) const {
288
275k
        return _search(haystack, haystack + haystack_size);
289
275k
    }
290
};
291
292
// Searches for needle surrounded by token-separators.
293
// Separators are anything inside ASCII (0-128) and not alphanum.
294
// Any value outside of basic ASCII (>=128) is considered a non-separator symbol, hence UTF-8 strings
295
// should work just fine. But any Unicode whitespace is not considered a token separtor.
296
template <typename StringSearcher>
297
class TokenSearcher : public StringSearcherBase {
298
    StringSearcher searcher;
299
    size_t needle_size;
300
301
public:
302
    template <typename CharT>
303
    // requires (sizeof(CharT) == 1)
304
    TokenSearcher(const CharT* needle_, const size_t needle_size_)
305
            : searcher {needle_, needle_size_}, needle_size(needle_size_) {
306
        if (std::any_of(needle_, needle_ + needle_size_, isTokenSeparator)) {
307
            //throw Exception{"Needle must not contain whitespace or separator characters", ErrorCodes::BAD_ARGUMENTS};
308
        }
309
    }
310
311
    template <typename CharT>
312
    // requires (sizeof(CharT) == 1)
313
    ALWAYS_INLINE bool compare(const CharT* haystack, const CharT* haystack_end,
314
                               const CharT* pos) const {
315
        // use searcher only if pos is in the beginning of token and pos + searcher.needle_size is end of token.
316
        if (isToken(haystack, haystack_end, pos))
317
            return searcher.compare(haystack, haystack_end, pos);
318
319
        return false;
320
    }
321
322
    template <typename CharT>
323
    // requires (sizeof(CharT) == 1)
324
    const CharT* search(const CharT* haystack, const CharT* const haystack_end) const {
325
        // use searcher.search(), then verify that returned value is a token
326
        // if it is not, skip it and re-run
327
328
        const auto* pos = haystack;
329
        while (pos < haystack_end) {
330
            pos = searcher.search(pos, haystack_end);
331
            if (pos == haystack_end || isToken(haystack, haystack_end, pos)) return pos;
332
333
            // assuming that heendle does not contain any token separators.
334
            pos += needle_size;
335
        }
336
        return haystack_end;
337
    }
338
339
    template <typename CharT>
340
    // requires (sizeof(CharT) == 1)
341
    const CharT* search(const CharT* haystack, const size_t haystack_size) const {
342
        return search(haystack, haystack + haystack_size);
343
    }
344
345
    template <typename CharT>
346
    // requires (sizeof(CharT) == 1)
347
    ALWAYS_INLINE bool isToken(const CharT* haystack, const CharT* const haystack_end,
348
                               const CharT* p) const {
349
        return (p == haystack || isTokenSeparator(*(p - 1))) &&
350
               (p + needle_size >= haystack_end || isTokenSeparator(*(p + needle_size)));
351
    }
352
353
    ALWAYS_INLINE static bool isTokenSeparator(const uint8_t c) {
354
        return !(is_alpha_numeric_ascii(c) || !is_ascii(c));
355
    }
356
};
357
358
// Ready-made searcher types built from the templates above.
// Only case-sensitive variants are enabled; the case-insensitive aliases are kept as
// comments, and no `StringSearcher<false, ...>` specialization is defined in this file.
using ASCIICaseSensitiveStringSearcher = StringSearcher<true, true>;
359
// using ASCIICaseInsensitiveStringSearcher = StringSearcher<false, true>;
360
using UTF8CaseSensitiveStringSearcher = StringSearcher<true, false>;
361
// using UTF8CaseInsensitiveStringSearcher = StringSearcher<false, false>;
362
// Token-aware search on top of the ASCII case-sensitive searcher.
using ASCIICaseSensitiveTokenSearcher = TokenSearcher<ASCIICaseSensitiveStringSearcher>;
363
// using ASCIICaseInsensitiveTokenSearcher = TokenSearcher<ASCIICaseInsensitiveStringSearcher>;
364
365
/** Uses functions from libc.
366
  * It makes sense to use only with short haystacks when cheap initialization is required.
367
  * There is no option for case-insensitive search for UTF-8 strings.
368
  * It is required that strings are zero-terminated.
369
  */
370
371
/// Case-sensitive searcher that delegates to libc `strstr`.
/// Both the needle and the haystack are handed to `strstr` as C strings, so they must be
/// zero-terminated; the needle is referenced, not copied.
struct LibCASCIICaseSensitiveStringSearcher : public StringSearcherBase {
    const char* const needle;

    /// Remembers the needle pointer; the size argument is ignored because `strstr`
    /// relies on the terminating zero.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    LibCASCIICaseSensitiveStringSearcher(const CharT* const needle_, const size_t /* needle_size */)
            : needle(reinterpret_cast<const char*>(needle_)) {}

    /// Returns a pointer to the first occurrence of the needle in the C string starting at
    /// `haystack`, or `haystack_end` when `strstr` finds nothing. `haystack_end` is used
    /// only as the "not found" value.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    const CharT* search(const CharT* haystack, const CharT* const haystack_end) const {
        const char* const hit = strstr(reinterpret_cast<const char*>(haystack), needle);
        return hit != nullptr ? reinterpret_cast<const CharT*>(hit) : haystack_end;
    }

    /// Size-based convenience overload of `search`.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    const CharT* search(const CharT* haystack, const size_t haystack_size) const {
        const CharT* const haystack_end = haystack + haystack_size;
        return search(haystack, haystack_end);
    }
};
394
395
/// Case-insensitive searcher that delegates to libc `strcasestr`.
/// Both the needle and the haystack are handed to `strcasestr` as C strings, so they must be
/// zero-terminated; the needle is referenced, not copied.
struct LibCASCIICaseInsensitiveStringSearcher : public StringSearcherBase {
    const char* const needle;

    /// Remembers the needle pointer; the size argument is ignored because `strcasestr`
    /// relies on the terminating zero.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    LibCASCIICaseInsensitiveStringSearcher(const CharT* const needle_,
                                           const size_t /* needle_size */)
            : needle(reinterpret_cast<const char*>(needle_)) {}

    /// Returns a pointer to the first occurrence of the needle, ignoring case, in the
    /// C string starting at `haystack`, or `haystack_end` when `strcasestr` finds nothing.
    /// `haystack_end` is used only as the "not found" value.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    const CharT* search(const CharT* haystack, const CharT* const haystack_end) const {
        const char* const hit = strcasestr(reinterpret_cast<const char*>(haystack), needle);
        return hit != nullptr ? reinterpret_cast<const CharT*>(hit) : haystack_end;
    }

    /// Size-based convenience overload of `search`.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    const CharT* search(const CharT* haystack, const size_t haystack_size) const {
        const CharT* const haystack_end = haystack + haystack_size;
        return search(haystack, haystack_end);
    }
};
419
#include "common/compile_check_end.h"
420
} // namespace doris