Coverage Report

Created: 2026-03-19 18:35

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/utils/simdlib_avx2.h
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#pragma once
9
10
#include <cstdint>
11
#include <string>
12
13
#include <immintrin.h>
14
15
#include <faiss/impl/platform_macros.h>
16
17
namespace faiss {
18
19
/** Simple wrapper around the AVX 256-bit registers
20
 *
21
 * The objective is to separate the different interpretations of the same
22
 * registers (as a vector of uint8, uint16 or uint32), to provide printing
23
 * functions, and to give more readable names to the AVX intrinsics. It does not
24
 * pretend to be exhaustive, functions are added as needed.
25
 */
26
27
/// 256-bit representation without interpretation as a vector
28
struct simd256bit {
29
    union {
30
        __m256i i;
31
        __m256 f;
32
    };
33
34
0
    simd256bit() {}
35
36
2
    explicit simd256bit(__m256i i) : i(i) {}
37
38
0
    explicit simd256bit(__m256 f) : f(f) {}
39
40
    explicit simd256bit(const void* x)
41
0
            : i(_mm256_load_si256((__m256i const*)x)) {}
42
43
0
    void clear() {
44
0
        i = _mm256_setzero_si256();
45
0
    }
46
47
0
    void storeu(void* ptr) const {
48
0
        _mm256_storeu_si256((__m256i*)ptr, i);
49
0
    }
50
51
0
    void loadu(const void* ptr) {
52
0
        i = _mm256_loadu_si256((__m256i*)ptr);
53
0
    }
54
55
0
    void store(void* ptr) const {
56
0
        _mm256_store_si256((__m256i*)ptr, i);
57
0
    }
58
59
0
    void bin(char bits[257]) const {
60
0
        char bytes[32];
61
0
        storeu((void*)bytes);
62
0
        for (int i = 0; i < 256; i++) {
63
0
            bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
64
0
        }
65
0
        bits[256] = 0;
66
0
    }
67
68
0
    std::string bin() const {
69
0
        char bits[257];
70
0
        bin(bits);
71
0
        return std::string(bits);
72
0
    }
73
74
    // Checks whether the other holds exactly the same bytes.
75
0
    bool is_same_as(simd256bit other) const {
76
0
        const __m256i pcmp = _mm256_cmpeq_epi32(i, other.i);
77
0
        unsigned bitmask = _mm256_movemask_epi8(pcmp);
78
0
        return (bitmask == 0xffffffffU);
79
0
    }
80
};
81
82
/// vector of 16 elements in uint16
83
struct simd16uint16 : simd256bit {
84
0
    simd16uint16() {}
85
86
0
    explicit simd16uint16(__m256i i) : simd256bit(i) {}
87
88
0
    explicit simd16uint16(int x) : simd256bit(_mm256_set1_epi16(x)) {}
89
90
0
    explicit simd16uint16(uint16_t x) : simd256bit(_mm256_set1_epi16(x)) {}
91
92
0
    explicit simd16uint16(simd256bit x) : simd256bit(x) {}
93
94
0
    explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
95
96
    explicit simd16uint16(
97
            uint16_t u0,
98
            uint16_t u1,
99
            uint16_t u2,
100
            uint16_t u3,
101
            uint16_t u4,
102
            uint16_t u5,
103
            uint16_t u6,
104
            uint16_t u7,
105
            uint16_t u8,
106
            uint16_t u9,
107
            uint16_t u10,
108
            uint16_t u11,
109
            uint16_t u12,
110
            uint16_t u13,
111
            uint16_t u14,
112
            uint16_t u15)
113
            : simd256bit(_mm256_setr_epi16(
114
                      u0,
115
                      u1,
116
                      u2,
117
                      u3,
118
                      u4,
119
                      u5,
120
                      u6,
121
                      u7,
122
                      u8,
123
                      u9,
124
                      u10,
125
                      u11,
126
                      u12,
127
                      u13,
128
                      u14,
129
0
                      u15)) {}
130
131
0
    std::string elements_to_string(const char* fmt) const {
132
0
        uint16_t bytes[16];
133
0
        storeu((void*)bytes);
134
0
        char res[1000];
135
0
        char* ptr = res;
136
0
        for (int i = 0; i < 16; i++) {
137
0
            ptr += sprintf(ptr, fmt, bytes[i]);
138
0
        }
139
0
        // strip last ,
140
0
        ptr[-1] = 0;
141
0
        return std::string(res);
142
0
    }
143
144
0
    std::string hex() const {
145
0
        return elements_to_string("%02x,");
146
0
    }
147
148
0
    std::string dec() const {
149
0
        return elements_to_string("%3d,");
150
0
    }
151
152
0
    void set1(uint16_t x) {
153
0
        i = _mm256_set1_epi16((short)x);
154
0
    }
155
156
0
    simd16uint16 operator*(const simd16uint16& other) const {
157
0
        return simd16uint16(_mm256_mullo_epi16(i, other.i));
158
0
    }
159
160
    // shift must be known at compile time
161
0
    simd16uint16 operator>>(const int shift) const {
162
0
        return simd16uint16(_mm256_srli_epi16(i, shift));
163
0
    }
164
165
    // shift must be known at compile time
166
0
    simd16uint16 operator<<(const int shift) const {
167
0
        return simd16uint16(_mm256_slli_epi16(i, shift));
168
0
    }
169
170
0
    simd16uint16 operator+=(simd16uint16 other) {
171
0
        i = _mm256_add_epi16(i, other.i);
172
0
        return *this;
173
0
    }
174
175
0
    simd16uint16 operator-=(simd16uint16 other) {
176
0
        i = _mm256_sub_epi16(i, other.i);
177
0
        return *this;
178
0
    }
179
180
0
    simd16uint16 operator+(simd16uint16 other) const {
181
0
        return simd16uint16(_mm256_add_epi16(i, other.i));
182
0
    }
183
184
0
    simd16uint16 operator-(simd16uint16 other) const {
185
0
        return simd16uint16(_mm256_sub_epi16(i, other.i));
186
0
    }
187
188
0
    simd16uint16 operator&(simd256bit other) const {
189
0
        return simd16uint16(_mm256_and_si256(i, other.i));
190
0
    }
191
192
0
    simd16uint16 operator|(simd256bit other) const {
193
0
        return simd16uint16(_mm256_or_si256(i, other.i));
194
0
    }
195
196
0
    simd16uint16 operator^(simd256bit other) const {
197
0
        return simd16uint16(_mm256_xor_si256(i, other.i));
198
0
    }
199
200
    // returns binary masks
201
0
    friend simd16uint16 operator==(const simd256bit lhs, const simd256bit rhs) {
202
0
        return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i));
203
0
    }
204
205
0
    simd16uint16 operator~() const {
206
0
        return simd16uint16(_mm256_xor_si256(i, _mm256_set1_epi32(-1)));
207
0
    }
208
209
    // get scalar at index 0
210
0
    uint16_t get_scalar_0() const {
211
0
        return _mm256_extract_epi16(i, 0);
212
0
    }
213
214
    // mask of elements where this >= thresh
215
    // 2 bit per component: 16 * 2 = 32 bit
216
0
    uint32_t ge_mask(simd16uint16 thresh) const {
217
0
        __m256i j = thresh.i;
218
0
        __m256i max = _mm256_max_epu16(i, j);
219
0
        __m256i ge = _mm256_cmpeq_epi16(i, max);
220
0
        return _mm256_movemask_epi8(ge);
221
0
    }
222
223
0
    uint32_t le_mask(simd16uint16 thresh) const {
224
0
        return thresh.ge_mask(*this);
225
0
    }
226
227
0
    uint32_t gt_mask(simd16uint16 thresh) const {
228
0
        return ~le_mask(thresh);
229
0
    }
230
231
0
    bool all_gt(simd16uint16 thresh) const {
232
0
        return le_mask(thresh) == 0;
233
0
    }
234
235
    // for debugging only
236
0
    uint16_t operator[](int i) const {
237
0
        ALIGNED(32) uint16_t tab[16];
238
0
        store(tab);
239
0
        return tab[i];
240
0
    }
241
242
0
    void accu_min(simd16uint16 incoming) {
243
0
        i = _mm256_min_epu16(i, incoming.i);
244
0
    }
245
246
0
    void accu_max(simd16uint16 incoming) {
247
0
        i = _mm256_max_epu16(i, incoming.i);
248
0
    }
249
};
250
251
// not really a std::min because it returns an elementwise min
252
0
inline simd16uint16 min(simd16uint16 a, simd16uint16 b) {
253
0
    return simd16uint16(_mm256_min_epu16(a.i, b.i));
254
0
}
255
256
0
inline simd16uint16 max(simd16uint16 a, simd16uint16 b) {
257
0
    return simd16uint16(_mm256_max_epu16(a.i, b.i));
258
0
}
259
260
// decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
261
// return (a0 + a1, b0 + b1)
262
// TODO find a better name
263
0
inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
264
0
    __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
265
0
    __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);
266
267
0
    return simd16uint16(a1b0) + simd16uint16(a0b1);
268
0
}
269
270
// compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
271
// of d0 and d1 with thr
272
0
inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
273
0
    __m256i max0 = _mm256_max_epu16(d0.i, thr.i);
274
0
    __m256i ge0 = _mm256_cmpeq_epi16(d0.i, max0);
275
276
0
    __m256i max1 = _mm256_max_epu16(d1.i, thr.i);
277
0
    __m256i ge1 = _mm256_cmpeq_epi16(d1.i, max1);
278
279
0
    __m256i ge01 = _mm256_packs_epi16(ge0, ge1);
280
281
    // easier than manipulating bit fields afterwards
282
0
    ge01 = _mm256_permute4x64_epi64(ge01, 0 | (2 << 2) | (1 << 4) | (3 << 6));
283
0
    uint32_t ge = _mm256_movemask_epi8(ge01);
284
285
0
    return ge;
286
0
}
287
288
0
inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
289
0
    __m256i max0 = _mm256_min_epu16(d0.i, thr.i);
290
0
    __m256i ge0 = _mm256_cmpeq_epi16(d0.i, max0);
291
292
0
    __m256i max1 = _mm256_min_epu16(d1.i, thr.i);
293
0
    __m256i ge1 = _mm256_cmpeq_epi16(d1.i, max1);
294
295
0
    __m256i ge01 = _mm256_packs_epi16(ge0, ge1);
296
297
    // easier than manipulating bit fields afterwards
298
0
    ge01 = _mm256_permute4x64_epi64(ge01, 0 | (2 << 2) | (1 << 4) | (3 << 6));
299
0
    uint32_t ge = _mm256_movemask_epi8(ge01);
300
301
0
    return ge;
302
0
}
303
304
0
inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
305
0
    return simd16uint16(_mm256_hadd_epi16(a.i, b.i));
306
0
}
307
308
// Vectorized version of the following code:
309
//   for (size_t i = 0; i < n; i++) {
310
//      bool flag = (candidateValues[i] < currentValues[i]);
311
//      minValues[i] = flag ? candidateValues[i] : currentValues[i];
312
//      minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
313
//      maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
314
//      maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
315
//   }
316
// Max indices evaluation is inaccurate in case of equal values (the index of
317
// the last equal value is saved instead of the first one), but this behavior
318
// saves instructions.
319
//
320
// Works in i16 mode in order to save instructions. One may
321
// switch from i16 to u16.
322
inline void cmplt_min_max_fast(
323
        const simd16uint16 candidateValues,
324
        const simd16uint16 candidateIndices,
325
        const simd16uint16 currentValues,
326
        const simd16uint16 currentIndices,
327
        simd16uint16& minValues,
328
        simd16uint16& minIndices,
329
        simd16uint16& maxValues,
330
0
        simd16uint16& maxIndices) {
331
0
    // there's no lt instruction, so we'll need to emulate one
332
0
    __m256i comparison = _mm256_cmpgt_epi16(currentValues.i, candidateValues.i);
333
0
    comparison = _mm256_andnot_si256(comparison, _mm256_set1_epi16(-1));
334
0
335
0
    minValues.i = _mm256_min_epi16(candidateValues.i, currentValues.i);
336
0
    minIndices.i = _mm256_blendv_epi8(
337
0
            candidateIndices.i, currentIndices.i, comparison);
338
0
    maxValues.i = _mm256_max_epi16(candidateValues.i, currentValues.i);
339
0
    maxIndices.i = _mm256_blendv_epi8(
340
0
            currentIndices.i, candidateIndices.i, comparison);
341
0
}
342
343
// vector of 32 unsigned 8-bit integers
344
struct simd32uint8 : simd256bit {
345
0
    simd32uint8() {}
346
347
2
    explicit simd32uint8(__m256i i) : simd256bit(i) {}
348
349
0
    explicit simd32uint8(int x) : simd256bit(_mm256_set1_epi8(x)) {}
350
351
0
    explicit simd32uint8(uint8_t x) : simd256bit(_mm256_set1_epi8(x)) {}
352
353
    template <
354
            uint8_t _0,
355
            uint8_t _1,
356
            uint8_t _2,
357
            uint8_t _3,
358
            uint8_t _4,
359
            uint8_t _5,
360
            uint8_t _6,
361
            uint8_t _7,
362
            uint8_t _8,
363
            uint8_t _9,
364
            uint8_t _10,
365
            uint8_t _11,
366
            uint8_t _12,
367
            uint8_t _13,
368
            uint8_t _14,
369
            uint8_t _15,
370
            uint8_t _16,
371
            uint8_t _17,
372
            uint8_t _18,
373
            uint8_t _19,
374
            uint8_t _20,
375
            uint8_t _21,
376
            uint8_t _22,
377
            uint8_t _23,
378
            uint8_t _24,
379
            uint8_t _25,
380
            uint8_t _26,
381
            uint8_t _27,
382
            uint8_t _28,
383
            uint8_t _29,
384
            uint8_t _30,
385
            uint8_t _31>
386
2
    static simd32uint8 create() {
387
2
        return simd32uint8(_mm256_setr_epi8(
388
2
                (char)_0,
389
2
                (char)_1,
390
2
                (char)_2,
391
2
                (char)_3,
392
2
                (char)_4,
393
2
                (char)_5,
394
2
                (char)_6,
395
2
                (char)_7,
396
2
                (char)_8,
397
2
                (char)_9,
398
2
                (char)_10,
399
2
                (char)_11,
400
2
                (char)_12,
401
2
                (char)_13,
402
2
                (char)_14,
403
2
                (char)_15,
404
2
                (char)_16,
405
2
                (char)_17,
406
2
                (char)_18,
407
2
                (char)_19,
408
2
                (char)_20,
409
2
                (char)_21,
410
2
                (char)_22,
411
2
                (char)_23,
412
2
                (char)_24,
413
2
                (char)_25,
414
2
                (char)_26,
415
2
                (char)_27,
416
2
                (char)_28,
417
2
                (char)_29,
418
2
                (char)_30,
419
2
                (char)_31));
420
2
    }
_ZN5faiss11simd32uint86createILh1ELh16ELh0ELh0ELh4ELh64ELh0ELh0ELh0ELh0ELh1ELh16ELh0ELh0ELh4ELh64ELh1ELh16ELh0ELh0ELh4ELh64ELh0ELh0ELh0ELh0ELh1ELh16ELh0ELh0ELh4ELh64EEES0_v
Line
Count
Source
386
1
    static simd32uint8 create() {
387
1
        return simd32uint8(_mm256_setr_epi8(
388
1
                (char)_0,
389
1
                (char)_1,
390
1
                (char)_2,
391
1
                (char)_3,
392
1
                (char)_4,
393
1
                (char)_5,
394
1
                (char)_6,
395
1
                (char)_7,
396
1
                (char)_8,
397
1
                (char)_9,
398
1
                (char)_10,
399
1
                (char)_11,
400
1
                (char)_12,
401
1
                (char)_13,
402
1
                (char)_14,
403
1
                (char)_15,
404
1
                (char)_16,
405
1
                (char)_17,
406
1
                (char)_18,
407
1
                (char)_19,
408
1
                (char)_20,
409
1
                (char)_21,
410
1
                (char)_22,
411
1
                (char)_23,
412
1
                (char)_24,
413
1
                (char)_25,
414
1
                (char)_26,
415
1
                (char)_27,
416
1
                (char)_28,
417
1
                (char)_29,
418
1
                (char)_30,
419
1
                (char)_31));
420
1
    }
_ZN5faiss11simd32uint86createILh1ELh2ELh4ELh8ELh16ELh32ELh64ELh128ELh1ELh2ELh4ELh8ELh16ELh32ELh64ELh128ELh1ELh2ELh4ELh8ELh16ELh32ELh64ELh128ELh1ELh2ELh4ELh8ELh16ELh32ELh64ELh128EEES0_v
Line
Count
Source
386
1
    static simd32uint8 create() {
387
1
        return simd32uint8(_mm256_setr_epi8(
388
1
                (char)_0,
389
1
                (char)_1,
390
1
                (char)_2,
391
1
                (char)_3,
392
1
                (char)_4,
393
1
                (char)_5,
394
1
                (char)_6,
395
1
                (char)_7,
396
1
                (char)_8,
397
1
                (char)_9,
398
1
                (char)_10,
399
1
                (char)_11,
400
1
                (char)_12,
401
1
                (char)_13,
402
1
                (char)_14,
403
1
                (char)_15,
404
1
                (char)_16,
405
1
                (char)_17,
406
1
                (char)_18,
407
1
                (char)_19,
408
1
                (char)_20,
409
1
                (char)_21,
410
1
                (char)_22,
411
1
                (char)_23,
412
1
                (char)_24,
413
1
                (char)_25,
414
1
                (char)_26,
415
1
                (char)_27,
416
1
                (char)_28,
417
1
                (char)_29,
418
1
                (char)_30,
419
1
                (char)_31));
420
1
    }
421
422
0
    explicit simd32uint8(simd256bit x) : simd256bit(x) {}
423
424
0
    explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
425
426
0
    std::string elements_to_string(const char* fmt) const {
427
0
        uint8_t bytes[32];
428
0
        storeu((void*)bytes);
429
0
        char res[1000];
430
0
        char* ptr = res;
431
0
        for (int i = 0; i < 32; i++) {
432
0
            ptr += sprintf(ptr, fmt, bytes[i]);
433
0
        }
434
0
        // strip last ,
435
0
        ptr[-1] = 0;
436
0
        return std::string(res);
437
0
    }
438
439
0
    std::string hex() const {
440
0
        return elements_to_string("%02x,");
441
0
    }
442
443
0
    std::string dec() const {
444
0
        return elements_to_string("%3d,");
445
0
    }
446
447
0
    void set1(uint8_t x) {
448
0
        i = _mm256_set1_epi8((char)x);
449
0
    }
450
451
0
    simd32uint8 operator&(simd256bit other) const {
452
0
        return simd32uint8(_mm256_and_si256(i, other.i));
453
0
    }
454
455
0
    simd32uint8 operator+(simd32uint8 other) const {
456
0
        return simd32uint8(_mm256_add_epi8(i, other.i));
457
0
    }
458
459
0
    simd32uint8 lookup_2_lanes(simd32uint8 idx) const {
460
0
        return simd32uint8(_mm256_shuffle_epi8(i, idx.i));
461
0
    }
462
463
    // extract + 0-extend lane
464
    // this operation is slow (3 cycles)
465
0
    simd16uint16 lane0_as_uint16() const {
466
0
        __m128i x = _mm256_extracti128_si256(i, 0);
467
0
        return simd16uint16(_mm256_cvtepu8_epi16(x));
468
0
    }
469
470
0
    simd16uint16 lane1_as_uint16() const {
471
0
        __m128i x = _mm256_extracti128_si256(i, 1);
472
0
        return simd16uint16(_mm256_cvtepu8_epi16(x));
473
0
    }
474
475
0
    simd32uint8 operator+=(simd32uint8 other) {
476
0
        i = _mm256_add_epi8(i, other.i);
477
0
        return *this;
478
0
    }
479
480
    // for debugging only
481
0
    uint8_t operator[](int i) const {
482
0
        ALIGNED(32) uint8_t tab[32];
483
0
        store(tab);
484
0
        return tab[i];
485
0
    }
486
};
487
488
// convert with saturation
489
// careful: this does not cross lanes, so the order is weird
490
0
inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
491
0
    return simd32uint8(_mm256_packs_epi16(a.i, b.i));
492
0
}
493
494
/// get most significant bit of each byte
495
0
inline uint32_t get_MSBs(simd32uint8 a) {
496
0
    return _mm256_movemask_epi8(a.i);
497
0
}
498
499
/// use MSB of each byte of mask to select a byte between a and b
500
0
inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
501
0
    return simd32uint8(_mm256_blendv_epi8(a.i, b.i, mask.i));
502
0
}
503
504
/// vector of 8 unsigned 32-bit integers
505
struct simd8uint32 : simd256bit {
506
0
    simd8uint32() {}
507
508
0
    explicit simd8uint32(__m256i i) : simd256bit(i) {}
509
510
0
    explicit simd8uint32(uint32_t x) : simd256bit(_mm256_set1_epi32(x)) {}
511
512
0
    explicit simd8uint32(simd256bit x) : simd256bit(x) {}
513
514
0
    explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {}
515
516
    explicit simd8uint32(
517
            uint32_t u0,
518
            uint32_t u1,
519
            uint32_t u2,
520
            uint32_t u3,
521
            uint32_t u4,
522
            uint32_t u5,
523
            uint32_t u6,
524
            uint32_t u7)
525
0
            : simd256bit(_mm256_setr_epi32(u0, u1, u2, u3, u4, u5, u6, u7)) {}
526
527
0
    simd8uint32 operator+(simd8uint32 other) const {
528
0
        return simd8uint32(_mm256_add_epi32(i, other.i));
529
0
    }
530
531
0
    simd8uint32 operator-(simd8uint32 other) const {
532
0
        return simd8uint32(_mm256_sub_epi32(i, other.i));
533
0
    }
534
535
0
    simd8uint32& operator+=(const simd8uint32& other) {
536
0
        i = _mm256_add_epi32(i, other.i);
537
0
        return *this;
538
0
    }
539
540
0
    bool operator==(simd8uint32 other) const {
541
0
        const __m256i pcmp = _mm256_cmpeq_epi32(i, other.i);
542
0
        unsigned bitmask = _mm256_movemask_epi8(pcmp);
543
0
        return (bitmask == 0xffffffffU);
544
0
    }
545
546
0
    bool operator!=(simd8uint32 other) const {
547
0
        return !(*this == other);
548
0
    }
549
550
0
    std::string elements_to_string(const char* fmt) const {
551
0
        uint32_t bytes[8];
552
0
        storeu((void*)bytes);
553
0
        char res[1000];
554
0
        char* ptr = res;
555
0
        for (int i = 0; i < 8; i++) {
556
0
            ptr += sprintf(ptr, fmt, bytes[i]);
557
0
        }
558
0
        // strip last ,
559
0
        ptr[-1] = 0;
560
0
        return std::string(res);
561
0
    }
562
563
0
    std::string hex() const {
564
0
        return elements_to_string("%08x,");
565
0
    }
566
567
0
    std::string dec() const {
568
0
        return elements_to_string("%10d,");
569
0
    }
570
571
0
    void set1(uint32_t x) {
572
0
        i = _mm256_set1_epi32((int)x);
573
0
    }
574
575
0
    simd8uint32 unzip() const {
576
0
        return simd8uint32(_mm256_permutevar8x32_epi32(
577
0
                i, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)));
578
0
    }
579
};
580
581
// Vectorized version of the following code:
582
//   for (size_t i = 0; i < n; i++) {
583
//      bool flag = (candidateValues[i] < currentValues[i]);
584
//      minValues[i] = flag ? candidateValues[i] : currentValues[i];
585
//      minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
586
//      maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
587
//      maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
588
//   }
589
// Max indices evaluation is inaccurate in case of equal values (the index of
590
// the last equal value is saved instead of the first one), but this behavior
591
// saves instructions.
592
inline void cmplt_min_max_fast(
593
        const simd8uint32 candidateValues,
594
        const simd8uint32 candidateIndices,
595
        const simd8uint32 currentValues,
596
        const simd8uint32 currentIndices,
597
        simd8uint32& minValues,
598
        simd8uint32& minIndices,
599
        simd8uint32& maxValues,
600
0
        simd8uint32& maxIndices) {
601
    // there's no lt instruction, so we'll need to emulate one
602
0
    __m256i comparison = _mm256_cmpgt_epi32(currentValues.i, candidateValues.i);
603
0
    comparison = _mm256_andnot_si256(comparison, _mm256_set1_epi32(-1));
604
605
0
    minValues.i = _mm256_min_epi32(candidateValues.i, currentValues.i);
606
0
    minIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
607
0
            _mm256_castsi256_ps(candidateIndices.i),
608
0
            _mm256_castsi256_ps(currentIndices.i),
609
0
            _mm256_castsi256_ps(comparison)));
610
0
    maxValues.i = _mm256_max_epi32(candidateValues.i, currentValues.i);
611
0
    maxIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
612
0
            _mm256_castsi256_ps(currentIndices.i),
613
0
            _mm256_castsi256_ps(candidateIndices.i),
614
0
            _mm256_castsi256_ps(comparison)));
615
0
}
616
617
struct simd8float32 : simd256bit {
618
0
    simd8float32() {}
619
620
0
    explicit simd8float32(simd256bit x) : simd256bit(x) {}
621
622
0
    explicit simd8float32(__m256 x) : simd256bit(x) {}
623
624
0
    explicit simd8float32(float x) : simd256bit(_mm256_set1_ps(x)) {}
625
626
0
    explicit simd8float32(const float* x) : simd256bit(_mm256_loadu_ps(x)) {}
627
628
    explicit simd8float32(
629
            float f0,
630
            float f1,
631
            float f2,
632
            float f3,
633
            float f4,
634
            float f5,
635
            float f6,
636
            float f7)
637
0
            : simd256bit(_mm256_setr_ps(f0, f1, f2, f3, f4, f5, f6, f7)) {}
638
639
0
    simd8float32 operator*(simd8float32 other) const {
640
0
        return simd8float32(_mm256_mul_ps(f, other.f));
641
0
    }
642
643
0
    simd8float32 operator+(simd8float32 other) const {
644
0
        return simd8float32(_mm256_add_ps(f, other.f));
645
0
    }
646
647
0
    simd8float32 operator-(simd8float32 other) const {
648
0
        return simd8float32(_mm256_sub_ps(f, other.f));
649
0
    }
650
651
0
    simd8float32& operator+=(const simd8float32& other) {
652
0
        f = _mm256_add_ps(f, other.f);
653
0
        return *this;
654
0
    }
655
656
0
    bool operator==(simd8float32 other) const {
657
0
        const __m256i pcmp =
658
0
                _mm256_castps_si256(_mm256_cmp_ps(f, other.f, _CMP_EQ_OQ));
659
0
        unsigned bitmask = _mm256_movemask_epi8(pcmp);
660
0
        return (bitmask == 0xffffffffU);
661
0
    }
662
663
0
    bool operator!=(simd8float32 other) const {
664
0
        return !(*this == other);
665
0
    }
666
667
0
    std::string tostring() const {
668
0
        float tab[8];
669
0
        storeu((void*)tab);
670
0
        char res[1000];
671
0
        char* ptr = res;
672
0
        for (int i = 0; i < 8; i++) {
673
0
            ptr += sprintf(ptr, "%g,", tab[i]);
674
0
        }
675
0
        // strip last ,
676
0
        ptr[-1] = 0;
677
0
        return std::string(res);
678
0
    }
679
};
680
681
0
inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
682
0
    return simd8float32(_mm256_hadd_ps(a.f, b.f));
683
0
}
684
685
0
inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
686
0
    return simd8float32(_mm256_unpacklo_ps(a.f, b.f));
687
0
}
688
689
0
inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
690
0
    return simd8float32(_mm256_unpackhi_ps(a.f, b.f));
691
0
}
692
693
// compute a * b + c
694
0
inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
695
0
    return simd8float32(_mm256_fmadd_ps(a.f, b.f, c.f));
696
0
}
697
698
// The following primitive is a vectorized version of the following code
699
// snippet:
700
//   float lowestValue = HUGE_VAL;
701
//   uint lowestIndex = 0;
702
//   for (size_t i = 0; i < n; i++) {
703
//     if (values[i] < lowestValue) {
704
//       lowestValue = values[i];
705
//       lowestIndex = i;
706
//     }
707
//   }
708
// Vectorized version can be implemented via two operations: cmp and blend
709
// with something like this:
710
//   lowestValues = [HUGE_VAL; 8];
711
//   lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
712
//   for (size_t i = 0; i < n; i += 8) {
713
//     auto comparison = cmp(values + i, lowestValues);
714
//     lowestValues = blend(
715
//         comparison,
716
//         values + i,
717
//         lowestValues);
718
//     lowestIndices = blend(
719
//         comparison,
720
//         i + {0, 1, 2, 3, 4, 5, 6, 7},
721
//         lowestIndices);
722
//     lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
723
//   }
724
// The problem is that blend primitive needs very different instruction
725
// order for AVX and ARM.
726
// So, let's introduce a combination of these two in order to avoid
727
// confusion for ppl who write in low-level SIMD instructions. Additionally,
728
// these two ops (cmp and blend) are very often used together.
729
inline void cmplt_and_blend_inplace(
730
        const simd8float32 candidateValues,
731
        const simd8uint32 candidateIndices,
732
        simd8float32& lowestValues,
733
0
        simd8uint32& lowestIndices) {
734
0
    const __m256 comparison =
735
0
            _mm256_cmp_ps(lowestValues.f, candidateValues.f, _CMP_LE_OS);
736
0
    lowestValues.f = _mm256_min_ps(candidateValues.f, lowestValues.f);
737
0
    lowestIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
738
0
            _mm256_castsi256_ps(candidateIndices.i),
739
0
            _mm256_castsi256_ps(lowestIndices.i),
740
0
            comparison));
741
0
}
742
743
// Vectorized version of the following code:
744
//   for (size_t i = 0; i < n; i++) {
745
//      bool flag = (candidateValues[i] < currentValues[i]);
746
//      minValues[i] = flag ? candidateValues[i] : currentValues[i];
747
//      minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
748
//      maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
749
//      maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
750
//   }
751
// Max indices evaluation is inaccurate in case of equal values (the index of
752
// the last equal value is saved instead of the first one), but this behavior
753
// saves instructions.
754
inline void cmplt_min_max_fast(
755
        const simd8float32 candidateValues,
756
        const simd8uint32 candidateIndices,
757
        const simd8float32 currentValues,
758
        const simd8uint32 currentIndices,
759
        simd8float32& minValues,
760
        simd8uint32& minIndices,
761
        simd8float32& maxValues,
762
0
        simd8uint32& maxIndices) {
763
0
    const __m256 comparison =
764
0
            _mm256_cmp_ps(currentValues.f, candidateValues.f, _CMP_LE_OS);
765
0
    minValues.f = _mm256_min_ps(candidateValues.f, currentValues.f);
766
0
    minIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
767
0
            _mm256_castsi256_ps(candidateIndices.i),
768
0
            _mm256_castsi256_ps(currentIndices.i),
769
0
            comparison));
770
0
    maxValues.f = _mm256_max_ps(candidateValues.f, currentValues.f);
771
0
    maxIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
772
0
            _mm256_castsi256_ps(currentIndices.i),
773
0
            _mm256_castsi256_ps(candidateIndices.i),
774
0
            comparison));
775
0
}
776
777
namespace {
778
779
// get even float32's of a and b, interleaved
780
0
inline simd8float32 geteven(simd8float32 a, simd8float32 b) {
781
0
    return simd8float32(
782
0
            _mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6));
783
0
}
Unexecuted instantiation: IndexFastScan.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFAdditiveQuantizerFastScan.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: IndexAdditiveQuantizerFastScan.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFPQFastScan.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan_search_1.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan_search_qbs.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: hamming.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: partitioning.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
Unexecuted instantiation: simdlib_based.cpp:_ZN5faiss12_GLOBAL__N_17getevenENS_12simd8float32ES1_
784
785
// get odd float32's of a and b, interleaved
786
0
inline simd8float32 getodd(simd8float32 a, simd8float32 b) {
787
0
    return simd8float32(
788
0
            _mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6));
789
0
}
Unexecuted instantiation: IndexFastScan.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFAdditiveQuantizerFastScan.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: IndexAdditiveQuantizerFastScan.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFPQFastScan.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan_search_1.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan_search_qbs.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: hamming.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: partitioning.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
Unexecuted instantiation: simdlib_based.cpp:_ZN5faiss12_GLOBAL__N_16getoddENS_12simd8float32ES1_
790
791
// 3 cycles
792
// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
793
0
inline simd8float32 getlow128(simd8float32 a, simd8float32 b) {
794
0
    return simd8float32(_mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4));
795
0
}
Unexecuted instantiation: IndexFastScan.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFAdditiveQuantizerFastScan.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: IndexAdditiveQuantizerFastScan.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFPQFastScan.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan_search_1.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan_search_qbs.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: hamming.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: partitioning.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
Unexecuted instantiation: simdlib_based.cpp:_ZN5faiss12_GLOBAL__N_19getlow128ENS_12simd8float32ES1_
796
797
0
inline simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
798
    return simd8float32(_mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4));
799
0
}
Unexecuted instantiation: IndexFastScan.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFAdditiveQuantizerFastScan.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: IndexAdditiveQuantizerFastScan.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: IndexIVFPQFastScan.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan_search_1.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: pq4_fast_scan_search_qbs.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: hamming.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: partitioning.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
Unexecuted instantiation: simdlib_based.cpp:_ZN5faiss12_GLOBAL__N_110gethigh128ENS_12simd8float32ES1_
800
801
} // namespace
802
803
} // namespace faiss