Coverage Report

Created: 2026-04-11 00:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exec/common/string_searcher.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Common/StringSearcher.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include <stdint.h>
24
#include <string.h>
25
26
#include <algorithm>
27
#include <limits>
28
#include <vector>
29
30
#include "core/string_ref.h"
31
#include "exec/common/string_utils/string_utils.h"
32
#include "util/sse_util.hpp"
33
34
namespace doris {
35
// namespace ErrorCodes
36
// {
37
//     extern const int BAD_ARGUMENTS;
38
// }
39
40
/** Variants for searching a substring in a string.
41
  */
42
43
/// Common base of the searchers in this file: carries the fallback flag and, on SIMD-capable
/// builds, the helpers needed to decide whether an unaligned 16-byte load is safe.
class StringSearcherBase {
public:
    /// Not read by any of the searchers defined in this file; presumably consulted by
    /// callers that choose between search strategies — TODO confirm against call sites.
    bool force_fallback = false;
#if defined(__SSE2__) || defined(__aarch64__)
protected:
    /// Size in bytes of one `__m128i` register, i.e. the width of every SIMD load below.
    static constexpr auto n = sizeof(__m128i);
    /// Page size as reported by sysconf(_SC_PAGESIZE); queried when the object is constructed.
    const long page_size = sysconf(_SC_PAGESIZE); //::getPageSize();

    /// Returns true when the offset of `ptr` inside its page is at most `page_size - n`,
    /// so that reading `n` bytes starting at `ptr` does not run into the next page.
    /// The masking assumes `page_size` is a power of two.
    bool page_safe(const void* const ptr) const {
        const auto address = reinterpret_cast<std::uintptr_t>(ptr);
        const auto offset_in_page = (page_size - 1) & address;
        const auto last_safe_offset = page_size - n;
        return offset_in_page <= last_safe_offset;
    }
#endif
};
56
57
/// Performs case-sensitive and case-insensitive search of UTF-8 strings
58
template <bool CaseSensitive, bool ASCII>
59
class StringSearcher;
60
61
/// Case-sensitive searcher (both ASCII and UTF-8)
62
template <bool ASCII>
63
class StringSearcher<true, ASCII> : public StringSearcherBase {
64
private:
65
    /// string to be searched for
66
    const uint8_t* const needle;
67
    const uint8_t* const needle_end;
68
    /// first character in `needle`
69
    uint8_t first {};
70
71
#if defined(__SSE4_1__) || defined(__aarch64__)
72
    uint8_t second {};
73
    /// vector filled `first` or `second` for determining leftmost position of the first and second symbols
74
    __m128i first_pattern;
75
    __m128i second_pattern;
76
    /// vector of first 16 characters of `needle`
77
    __m128i cache = _mm_setzero_si128();
78
    int cachemask {};
79
#endif
80
81
public:
82
    template <typename CharT>
83
    // requires (sizeof(CharT) == 1)
84
    StringSearcher(const CharT* needle_, const size_t needle_size)
85
1.27k
            : needle {reinterpret_cast<const uint8_t*>(needle_)},
86
1.27k
              needle_end {needle + needle_size} {
87
1.27k
        if (0 == needle_size) return;
88
89
1.05k
        first = *needle;
90
91
1.05k
#if defined(__SSE4_1__) || defined(__aarch64__)
92
1.05k
        first_pattern = _mm_set1_epi8(first);
93
1.05k
        if (needle + 1 < needle_end) {
94
781
            second = *(needle + 1);
95
781
            second_pattern = _mm_set1_epi8(second);
96
781
        }
97
1.05k
        const auto* needle_pos = needle;
98
99
        //for (const auto i : collections::range(0, n))
100
17.9k
        for (size_t i = 0; i < n; i++) {
101
16.8k
            cache = _mm_srli_si128(cache, 1);
102
103
16.8k
            if (needle_pos != needle_end) {
104
6.43k
                cache = _mm_insert_epi8(cache, *needle_pos, n - 1);
105
6.43k
                cachemask |= 1 << i;
106
6.43k
                ++needle_pos;
107
6.43k
            }
108
16.8k
        }
109
1.05k
#endif
110
1.05k
    }
111
112
    template <typename CharT>
113
    // requires (sizeof(CharT) == 1)
114
888
    const CharT* search(const CharT* haystack, size_t haystack_size) const {
115
        // cast to unsigned int8 to be consitent with needle type
116
        // ensure unsigned type compare
117
888
        return reinterpret_cast<const CharT*>(
118
888
                _search(reinterpret_cast<const uint8_t*>(haystack), haystack_size));
119
888
    }
120
121
    template <typename CharT>
122
    // requires (sizeof(CharT) == 1)
123
0
    const CharT* search(const CharT* haystack, const CharT* haystack_end) const {
124
        // cast to unsigned int8 to be consitent with needle type
125
        // ensure unsigned type compare
126
0
        return reinterpret_cast<const CharT*>(
127
0
                _search(reinterpret_cast<const uint8_t*>(haystack),
128
0
                        reinterpret_cast<const uint8_t*>(haystack_end)));
129
0
    }
130
131
    template <typename CharT>
132
    // requires (sizeof(CharT) == 1)
133
    ALWAYS_INLINE bool compare(const CharT* haystack, const CharT* haystack_end, CharT* pos) const {
134
        // cast to unsigned int8 to be consitent with needle type
135
        // ensure unsigned type compare
136
        return _compare(reinterpret_cast<const uint8_t*>(haystack),
137
                        reinterpret_cast<const uint8_t*>(haystack_end),
138
                        reinterpret_cast<const uint8_t*>(pos));
139
    }
140
141
private:
142
    ALWAYS_INLINE bool _compare(uint8_t* /*haystack*/, uint8_t* /*haystack_end*/,
143
                                uint8_t* pos) const {
144
#if defined(__SSE4_1__) || defined(__aarch64__)
145
        if (needle_end - needle > n && page_safe(pos)) {
146
            const auto v_haystack = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pos));
147
            const auto v_against_cache = _mm_cmpeq_epi8(v_haystack, cache);
148
            const auto mask = _mm_movemask_epi8(v_against_cache);
149
150
            if (0xffff == cachemask) {
151
                if (mask == cachemask) {
152
                    pos += n;
153
                    const auto* needle_pos = needle + n;
154
155
                    while (needle_pos < needle_end && *pos == *needle_pos) ++pos, ++needle_pos;
156
157
                    if (needle_pos == needle_end) return true;
158
                }
159
            } else if ((mask & cachemask) == cachemask)
160
                return true;
161
162
            return false;
163
        }
164
#endif
165
166
        if (*pos == first) {
167
            ++pos;
168
            const auto* needle_pos = needle + 1;
169
170
            while (needle_pos < needle_end && *pos == *needle_pos) ++pos, ++needle_pos;
171
172
            if (needle_pos == needle_end) return true;
173
        }
174
175
        return false;
176
    }
177
178
888
    const uint8_t* _search(const uint8_t* haystack, const uint8_t* haystack_end) const {
179
888
        if (needle == needle_end) return haystack;
180
181
888
        const auto needle_size = needle_end - needle;
182
888
#if defined(__SSE4_1__) || defined(__aarch64__)
183
        /// Here is the quick path when needle_size is 1.
184
888
        if (needle_size == 1) {
185
3.94k
            while (haystack < haystack_end) {
186
3.80k
                if (haystack + n <= haystack_end && page_safe(haystack)) {
187
2.40k
                    const auto v_haystack =
188
2.40k
                            _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack));
189
2.40k
                    const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, first_pattern);
190
2.40k
                    const auto mask = _mm_movemask_epi8(v_against_pattern);
191
2.40k
                    if (mask == 0) {
192
2.30k
                        haystack += n;
193
2.30k
                        continue;
194
2.30k
                    }
195
196
98
                    const auto offset = __builtin_ctz(mask);
197
98
                    haystack += offset;
198
199
98
                    return haystack;
200
2.40k
                }
201
202
1.40k
                if (haystack == haystack_end) {
203
0
                    return haystack_end;
204
0
                }
205
206
1.40k
                if (*haystack == first) {
207
123
                    return haystack;
208
123
                }
209
1.28k
                ++haystack;
210
1.28k
            }
211
142
            return haystack_end;
212
363
        }
213
525
#endif
214
215
1.29k
        while (haystack < haystack_end && haystack_end - haystack >= needle_size) {
216
1.04k
#if defined(__SSE4_1__) || defined(__aarch64__)
217
1.04k
            if ((haystack + 1 + n) <= haystack_end && page_safe(haystack)) {
218
                /// find first and second characters
219
194
                const auto v_haystack_block_first =
220
194
                        _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack));
221
194
                const auto v_haystack_block_second =
222
194
                        _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack + 1));
223
224
194
                const auto v_against_pattern_first =
225
194
                        _mm_cmpeq_epi8(v_haystack_block_first, first_pattern);
226
194
                const auto v_against_pattern_second =
227
194
                        _mm_cmpeq_epi8(v_haystack_block_second, second_pattern);
228
229
194
                const auto mask = _mm_movemask_epi8(
230
194
                        _mm_and_si128(v_against_pattern_first, v_against_pattern_second));
231
                /// first and second characters not present in 16 octets starting at `haystack`
232
194
                if (mask == 0) {
233
55
                    haystack += n;
234
55
                    continue;
235
55
                }
236
237
139
                const auto offset = __builtin_ctz(mask);
238
139
                haystack += offset;
239
240
139
                if (haystack + n <= haystack_end && page_safe(haystack)) {
241
                    /// check for first 16 octets
242
124
                    const auto v_haystack_offset =
243
124
                            _mm_loadu_si128(reinterpret_cast<const __m128i*>(haystack));
244
124
                    const auto v_against_cache = _mm_cmpeq_epi8(v_haystack_offset, cache);
245
124
                    const auto mask_offset = _mm_movemask_epi8(v_against_cache);
246
247
124
                    if (0xffff == cachemask) {
248
16
                        if (mask_offset == cachemask) {
249
16
                            const auto* haystack_pos = haystack + n;
250
16
                            const auto* needle_pos = needle + n;
251
252
136
                            while (haystack_pos < haystack_end && needle_pos < needle_end &&
253
136
                                   *haystack_pos == *needle_pos)
254
120
                                ++haystack_pos, ++needle_pos;
255
256
16
                            if (needle_pos == needle_end) return haystack;
257
16
                        }
258
108
                    } else if ((mask_offset & cachemask) == cachemask)
259
108
                        return haystack;
260
261
0
                    ++haystack;
262
0
                    continue;
263
124
                }
264
139
            }
265
866
#endif
266
267
866
            if (haystack == haystack_end) return haystack_end;
268
269
866
            if (*haystack == first) {
270
186
                const auto* haystack_pos = haystack + 1;
271
186
                const auto* needle_pos = needle + 1;
272
273
735
                while (haystack_pos < haystack_end && needle_pos < needle_end &&
274
735
                       *haystack_pos == *needle_pos)
275
549
                    ++haystack_pos, ++needle_pos;
276
277
186
                if (needle_pos == needle_end) return haystack;
278
186
            }
279
280
717
            ++haystack;
281
717
        }
282
283
252
        return haystack_end;
284
525
    }
285
286
888
    const uint8_t* _search(const uint8_t* haystack, const size_t haystack_size) const {
287
888
        return _search(haystack, haystack + haystack_size);
288
888
    }
289
};
290
291
// Searches for needle surrounded by token-separators.
292
// Separators are all ASCII characters (0-127) that are not alphanumeric.
293
// Any value outside of basic ASCII (>=128) is considered a non-separator symbol, hence UTF-8 strings
294
// should work just fine. But Unicode whitespace is not considered a token separator.
295
template <typename StringSearcher>
296
class TokenSearcher : public StringSearcherBase {
297
    StringSearcher searcher;
298
    size_t needle_size;
299
300
public:
301
    template <typename CharT>
302
    // requires (sizeof(CharT) == 1)
303
    TokenSearcher(const CharT* needle_, const size_t needle_size_)
304
            : searcher {needle_, needle_size_}, needle_size(needle_size_) {
305
        if (std::any_of(needle_, needle_ + needle_size_, isTokenSeparator)) {
306
            //throw Exception{"Needle must not contain whitespace or separator characters", ErrorCodes::BAD_ARGUMENTS};
307
        }
308
    }
309
310
    template <typename CharT>
311
    // requires (sizeof(CharT) == 1)
312
    ALWAYS_INLINE bool compare(const CharT* haystack, const CharT* haystack_end,
313
                               const CharT* pos) const {
314
        // use searcher only if pos is in the beginning of token and pos + searcher.needle_size is end of token.
315
        if (isToken(haystack, haystack_end, pos))
316
            return searcher.compare(haystack, haystack_end, pos);
317
318
        return false;
319
    }
320
321
    template <typename CharT>
322
    // requires (sizeof(CharT) == 1)
323
    const CharT* search(const CharT* haystack, const CharT* const haystack_end) const {
324
        // use searcher.search(), then verify that returned value is a token
325
        // if it is not, skip it and re-run
326
327
        const auto* pos = haystack;
328
        while (pos < haystack_end) {
329
            pos = searcher.search(pos, haystack_end);
330
            if (pos == haystack_end || isToken(haystack, haystack_end, pos)) return pos;
331
332
            // assuming that heendle does not contain any token separators.
333
            pos += needle_size;
334
        }
335
        return haystack_end;
336
    }
337
338
    template <typename CharT>
339
    // requires (sizeof(CharT) == 1)
340
    const CharT* search(const CharT* haystack, const size_t haystack_size) const {
341
        return search(haystack, haystack + haystack_size);
342
    }
343
344
    template <typename CharT>
345
    // requires (sizeof(CharT) == 1)
346
    ALWAYS_INLINE bool isToken(const CharT* haystack, const CharT* const haystack_end,
347
                               const CharT* p) const {
348
        return (p == haystack || isTokenSeparator(*(p - 1))) &&
349
               (p + needle_size >= haystack_end || isTokenSeparator(*(p + needle_size)));
350
    }
351
352
    ALWAYS_INLINE static bool isTokenSeparator(const uint8_t c) {
353
        return !(is_alpha_numeric_ascii(c) || !is_ascii(c));
354
    }
355
};
356
357
using ASCIICaseSensitiveStringSearcher = StringSearcher<true, true>;
358
// using ASCIICaseInsensitiveStringSearcher = StringSearcher<false, true>;
359
using UTF8CaseSensitiveStringSearcher = StringSearcher<true, false>;
360
// using UTF8CaseInsensitiveStringSearcher = StringSearcher<false, false>;
361
using ASCIICaseSensitiveTokenSearcher = TokenSearcher<ASCIICaseSensitiveStringSearcher>;
362
// using ASCIICaseInsensitiveTokenSearcher = TokenSearcher<ASCIICaseInsensitiveStringSearcher>;
363
364
/** Uses functions from libc.
365
  * It makes sense to use only with short haystacks when cheap initialization is required.
366
  * There is no option for case-insensitive search for UTF-8 strings.
367
  * It is required that strings are zero-terminated.
368
  */
369
370
/// Case-sensitive searcher that delegates to libc `strstr`.
struct LibCASCIICaseSensitiveStringSearcher : public StringSearcherBase {
    /// Needle as passed to the constructor; handed to `strstr` unchanged, so it has to be
    /// zero-terminated.
    const char* const needle;

    /// Stores the needle pointer. The size argument is ignored.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    LibCASCIICaseSensitiveStringSearcher(const CharT* const needle_, const size_t /* needle_size */)
            : needle(reinterpret_cast<const char*>(needle_)) {}

    /// Runs `strstr(haystack, needle)` and returns the match, or `haystack_end` when
    /// `strstr` finds nothing. `haystack_end` is not used to bound the scan itself.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    const CharT* search(const CharT* haystack, const CharT* const haystack_end) const {
        const char* const hit = strstr(reinterpret_cast<const char*>(haystack), needle);
        return hit != nullptr ? reinterpret_cast<const CharT*>(hit) : haystack_end;
    }

    /// Same as above with the end given as `haystack + haystack_size`.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    const CharT* search(const CharT* haystack, const size_t haystack_size) const {
        const CharT* const haystack_end = haystack + haystack_size;
        return search(haystack, haystack_end);
    }
};
393
394
/// Case-insensitive searcher that delegates to `strcasestr`.
struct LibCASCIICaseInsensitiveStringSearcher : public StringSearcherBase {
    /// Needle as passed to the constructor; handed to `strcasestr` unchanged, so it has to
    /// be zero-terminated.
    const char* const needle;

    /// Stores the needle pointer. The size argument is ignored.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    LibCASCIICaseInsensitiveStringSearcher(const CharT* const needle_,
                                           const size_t /* needle_size */)
            : needle(reinterpret_cast<const char*>(needle_)) {}

    /// Runs `strcasestr(haystack, needle)` and returns the match, or `haystack_end` when
    /// nothing is found. `haystack_end` is not used to bound the scan itself.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    const CharT* search(const CharT* haystack, const CharT* const haystack_end) const {
        const char* const hit = strcasestr(reinterpret_cast<const char*>(haystack), needle);
        return hit != nullptr ? reinterpret_cast<const CharT*>(hit) : haystack_end;
    }

    /// Same as above with the end given as `haystack + haystack_size`.
    template <typename CharT>
    // requires (sizeof(CharT) == 1)
    const CharT* search(const CharT* haystack, const size_t haystack_size) const {
        const CharT* const haystack_end = haystack + haystack_size;
        return search(haystack, haystack_end);
    }
};
418
} // namespace doris