Coverage Report

Created: 2026-03-12 16:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/utils/distances_simd.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/utils/distances.h>
11
12
#include <algorithm>
13
#include <cassert>
14
#include <cmath>
15
#include <cstdio>
16
#include <cstring>
17
18
#include <faiss/impl/FaissAssert.h>
19
#include <faiss/impl/platform_macros.h>
20
#include <faiss/utils/simdlib.h>
21
22
#ifdef __SSE3__
23
#include <immintrin.h>
24
#endif
25
26
#if defined(__AVX512F__)
27
#include <faiss/utils/transpose/transpose-avx512-inl.h>
28
#elif defined(__AVX2__)
29
#include <faiss/utils/transpose/transpose-avx2-inl.h>
30
#endif
31
32
#ifdef __ARM_FEATURE_SVE
33
#include <arm_sve.h>
34
#endif
35
36
#ifdef __aarch64__
37
#include <arm_neon.h>
38
#endif
39
40
namespace faiss {
41
42
#ifdef __AVX__
43
#define USE_AVX
44
#endif
45
46
/*********************************************************
47
 * Optimized distance computations
48
 *********************************************************/
49
50
/* Functions to compute:
51
   - L2 distance between 2 vectors
52
   - inner product between 2 vectors
53
   - L2 norm of a vector
54
55
   The functions should probably not be invoked when a large number of
56
   vectors are processed in batch (in which case matrix multiply
57
   is faster), but may be useful for comparing vectors isolated in
58
   memory.
59
60
   Works with any vectors of any dimension, even unaligned (in which
61
   case they are slower).
62
63
*/
64
65
/*********************************************************
66
 * Reference implementations
67
 */
68
69
0
/// Reference L1 (Manhattan) distance: returns sum_i |x[i] - y[i]| over the
/// first d components. Plain scalar loop, kept as a baseline for the
/// optimized kernels. Returns 0 when d == 0.
float fvec_L1_ref(const float* x, const float* y, size_t d) {
    float res = 0;
    for (size_t i = 0; i < d; i++) {
        // std::fabs selects the float overload; the unqualified C fabs took a
        // double, silently promoting every term to double and back.
        res += std::fabs(x[i] - y[i]);
    }
    return res;
}
78
79
0
/// Reference L-infinity (Chebyshev) distance: returns max_i |x[i] - y[i]|
/// over the first d components, or 0 when d == 0.
float fvec_Linf_ref(const float* x, const float* y, size_t d) {
    float res = 0;
    for (size_t i = 0; i < d; i++) {
        // std::fmax / std::fabs pick the float overloads instead of the C
        // double versions. Like C fmax, std::fmax returns the non-NaN operand.
        res = std::fmax(res, std::fabs(x[i] - y[i]));
    }
    return res;
}
87
88
void fvec_L2sqr_ny_ref(
89
        float* dis,
90
        const float* x,
91
        const float* y,
92
        size_t d,
93
16
        size_t ny) {
94
4.11k
    for (size_t i = 0; i < ny; i++) {
95
4.09k
        dis[i] = fvec_L2sqr(x, y, d);
96
4.09k
        y += d;
97
4.09k
    }
98
16
}
99
100
/// Reference squared L2 distances against a transposed database.
///
/// Component j of database vector i is read from y[i + j * d_offset], and
/// y_sqlen[i] supplies that vector's squared norm. Each result uses the
/// expansion ||x - y_i||^2 = ||x||^2 + ||y_i||^2 - 2 <x, y_i>.
void fvec_L2sqr_ny_y_transposed_ref(
        float* dis,
        const float* x,
        const float* y,
        const float* y_sqlen,
        size_t d,
        size_t d_offset,
        size_t ny) {
    float x_norm2 = 0;
    for (size_t dim = 0; dim < d; ++dim) {
        x_norm2 += x[dim] * x[dim];
    }

    for (size_t vec = 0; vec < ny; ++vec) {
        const float* column = y + vec; // component 0 of vector `vec`
        float dot = 0;
        for (size_t dim = 0; dim < d; ++dim) {
            dot += x[dim] * column[dim * d_offset];
        }
        dis[vec] = x_norm2 + y_sqlen[vec] - 2 * dot;
    }
}
122
123
size_t fvec_L2sqr_ny_nearest_ref(
124
        float* distances_tmp_buffer,
125
        const float* x,
126
        const float* y,
127
        size_t d,
128
0
        size_t ny) {
129
0
    fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
130
131
0
    size_t nearest_idx = 0;
132
0
    float min_dis = HUGE_VALF;
133
134
0
    for (size_t i = 0; i < ny; i++) {
135
0
        if (distances_tmp_buffer[i] < min_dis) {
136
0
            min_dis = distances_tmp_buffer[i];
137
0
            nearest_idx = i;
138
0
        }
139
0
    }
140
141
0
    return nearest_idx;
142
0
}
143
144
size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
145
        float* distances_tmp_buffer,
146
        const float* x,
147
        const float* y,
148
        const float* y_sqlen,
149
        size_t d,
150
        size_t d_offset,
151
0
        size_t ny) {
152
0
    fvec_L2sqr_ny_y_transposed_ref(
153
0
            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
154
155
0
    size_t nearest_idx = 0;
156
0
    float min_dis = HUGE_VALF;
157
158
0
    for (size_t i = 0; i < ny; i++) {
159
0
        if (distances_tmp_buffer[i] < min_dis) {
160
0
            min_dis = distances_tmp_buffer[i];
161
0
            nearest_idx = i;
162
0
        }
163
0
    }
164
165
0
    return nearest_idx;
166
0
}
167
168
void fvec_inner_products_ny_ref(
169
        float* ip,
170
        const float* x,
171
        const float* y,
172
        size_t d,
173
0
        size_t ny) {
174
    // BLAS slower for the use cases here
175
#if 0
176
    {
177
        FINTEGER di = d;
178
        FINTEGER nyi = ny;
179
        float one = 1.0, zero = 0.0;
180
        FINTEGER onei = 1;
181
        sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
182
    }
183
#endif
184
0
    for (size_t i = 0; i < ny; i++) {
185
0
        ip[i] = fvec_inner_product(x, y, d);
186
0
        y += d;
187
0
    }
188
0
}
189
190
/*********************************************************
191
 * Autovectorized implementations
192
 */
193
194
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
195
1.35M
float fvec_inner_product(const float* x, const float* y, size_t d) {
196
1.35M
    float res = 0.F;
197
1.35M
    FAISS_PRAGMA_IMPRECISE_LOOP
198
99.9M
    for (size_t i = 0; i != d; ++i) {
199
98.6M
        res += x[i] * y[i];
200
98.6M
    }
201
1.35M
    return res;
202
1.35M
}
203
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
204
205
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
206
18.8k
float fvec_norm_L2sqr(const float* x, size_t d) {
207
    // the double in the _ref is suspected to be a typo. Some of the manual
208
    // implementations this replaces used float.
209
18.8k
    float res = 0;
210
18.8k
    FAISS_PRAGMA_IMPRECISE_LOOP
211
4.18M
    for (size_t i = 0; i != d; ++i) {
212
4.16M
        res += x[i] * x[i];
213
4.16M
    }
214
215
18.8k
    return res;
216
18.8k
}
217
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
218
219
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
220
7.58M
float fvec_L2sqr(const float* x, const float* y, size_t d) {
221
7.58M
    size_t i;
222
7.58M
    float res = 0;
223
7.58M
    FAISS_PRAGMA_IMPRECISE_LOOP
224
1.17G
    for (i = 0; i < d; i++) {
225
1.16G
        const float tmp = x[i] - y[i];
226
1.16G
        res += tmp * tmp;
227
1.16G
    }
228
7.58M
    return res;
229
7.58M
}
230
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
231
232
/// Special version of inner product that computes 4 distances
233
/// between x and yi
234
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
235
void fvec_inner_product_batch_4(
236
        const float* __restrict x,
237
        const float* __restrict y0,
238
        const float* __restrict y1,
239
        const float* __restrict y2,
240
        const float* __restrict y3,
241
        const size_t d,
242
        float& dis0,
243
        float& dis1,
244
        float& dis2,
245
272k
        float& dis3) {
246
272k
    float d0 = 0;
247
272k
    float d1 = 0;
248
272k
    float d2 = 0;
249
272k
    float d3 = 0;
250
272k
    FAISS_PRAGMA_IMPRECISE_LOOP
251
19.3M
    for (size_t i = 0; i < d; ++i) {
252
19.0M
        d0 += x[i] * y0[i];
253
19.0M
        d1 += x[i] * y1[i];
254
19.0M
        d2 += x[i] * y2[i];
255
19.0M
        d3 += x[i] * y3[i];
256
19.0M
    }
257
258
272k
    dis0 = d0;
259
272k
    dis1 = d1;
260
272k
    dis2 = d2;
261
272k
    dis3 = d3;
262
272k
}
263
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
264
265
/// Special version of L2sqr that computes 4 distances
266
/// between x and yi, which is performance oriented.
267
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
268
void fvec_L2sqr_batch_4(
269
        const float* x,
270
        const float* y0,
271
        const float* y1,
272
        const float* y2,
273
        const float* y3,
274
        const size_t d,
275
        float& dis0,
276
        float& dis1,
277
        float& dis2,
278
1.32M
        float& dis3) {
279
1.32M
    float d0 = 0;
280
1.32M
    float d1 = 0;
281
1.32M
    float d2 = 0;
282
1.32M
    float d3 = 0;
283
1.32M
    FAISS_PRAGMA_IMPRECISE_LOOP
284
193M
    for (size_t i = 0; i < d; ++i) {
285
191M
        const float q0 = x[i] - y0[i];
286
191M
        const float q1 = x[i] - y1[i];
287
191M
        const float q2 = x[i] - y2[i];
288
191M
        const float q3 = x[i] - y3[i];
289
191M
        d0 += q0 * q0;
290
191M
        d1 += q1 * q1;
291
191M
        d2 += q2 * q2;
292
191M
        d3 += q3 * q3;
293
191M
    }
294
295
1.32M
    dis0 = d0;
296
1.32M
    dis1 = d1;
297
1.32M
    dis2 = d2;
298
1.32M
    dis3 = d3;
299
1.32M
}
300
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
301
302
/*********************************************************
303
 * SSE and AVX implementations
304
 */
305
306
#ifdef __SSE3__
307
308
// reads 0 <= d < 4 floats as __m128
309
4
static inline __m128 masked_read(int d, const float* x) {
310
4
    assert(0 <= d && d < 4);
311
4
    ALIGNED(16) float buf[4] = {0, 0, 0, 0};
312
4
    switch (d) {
313
2
        case 3:
314
2
            buf[2] = x[2];
315
2
            [[fallthrough]];
316
4
        case 2:
317
4
            buf[1] = x[1];
318
4
            [[fallthrough]];
319
4
        case 1:
320
4
            buf[0] = x[0];
321
4
    }
322
4
    return _mm_load_ps(buf);
323
    // cannot use AVX2 _mm_mask_set1_epi32
324
4
}
325
326
namespace {
327
328
/// Returns the sum of the four lanes of v = [x0, x1, x2, x3], computed as
/// (x0 + x2) + (x1 + x3).
inline float horizontal_sum(const __m128 v) {
    // [x2, x3, x2, x3]: move the upper half down
    const __m128 upper = _mm_movehl_ps(v, v);
    // [x0 + x2, x1 + x3, ...]
    const __m128 pair_sums = _mm_add_ps(v, upper);
    // lane 0 becomes x1 + x3
    const __m128 odd =
            _mm_shuffle_ps(pair_sums, pair_sums, _MM_SHUFFLE(0, 0, 0, 1));
    // lane 0: (x0 + x2) + (x1 + x3)
    return _mm_cvtss_f32(_mm_add_ss(pair_sums, odd));
}
343
344
#ifdef __AVX2__
/// Sums the eight lanes of an __m256. It adds the upper 128-bit half onto the
/// lower half, then reduces the result with the __m128 overload.
inline float horizontal_sum(const __m256 v) {
    const __m128 lo = _mm256_castps256_ps128(v);
    const __m128 hi = _mm256_extractf128_ps(v, 1);
    return horizontal_sum(_mm_add_ps(lo, hi));
}
#endif
354
355
#ifdef __AVX512F__
/// Sums the sixteen lanes of an __m512 with _mm512_reduce_add_ps. This
/// performs better here than folding the high and low halves by hand.
inline float horizontal_sum(const __m512 v) {
    const float total = _mm512_reduce_add_ps(v);
    return total;
}
#endif
362
363
/// Component-wise squared difference, op(x, y) = (x - y)^2, with scalar and
/// SIMD overloads. Summing op over all components gives the squared L2
/// distance. Used as the ElementOp parameter of the fvec_op_ny_D* kernels.
struct ElementOpL2 {
    static float op(float x, float y) {
        const float diff = x - y;
        return diff * diff;
    }

    static __m128 op(__m128 x, __m128 y) {
        const __m128 diff = _mm_sub_ps(x, y);
        return _mm_mul_ps(diff, diff);
    }

#ifdef __AVX2__
    static __m256 op(__m256 x, __m256 y) {
        const __m256 diff = _mm256_sub_ps(x, y);
        return _mm256_mul_ps(diff, diff);
    }
#endif

#ifdef __AVX512F__
    static __m512 op(__m512 x, __m512 y) {
        const __m512 diff = _mm512_sub_ps(x, y);
        return _mm512_mul_ps(diff, diff);
    }
#endif
};
391
392
/// Component-wise product, op(x, y) = x * y, with scalar and SIMD overloads.
/// Summing op over all components gives the inner product. Used as the
/// ElementOp parameter of the fvec_op_ny_D* kernels.
struct ElementOpIP {
    static float op(float x, float y) {
        const float prod = x * y;
        return prod;
    }

    static __m128 op(__m128 x, __m128 y) {
        const __m128 prod = _mm_mul_ps(x, y);
        return prod;
    }

#ifdef __AVX2__
    static __m256 op(__m256 x, __m256 y) {
        const __m256 prod = _mm256_mul_ps(x, y);
        return prod;
    }
#endif

#ifdef __AVX512F__
    static __m512 op(__m512 x, __m512 y) {
        const __m512 prod = _mm512_mul_ps(x, y);
        return prod;
    }
#endif
};
415
416
/// Applies ElementOp between the scalar query x[0] and each of the ny
/// one-dimensional database values: dis[i] = ElementOp::op(x[0], y[i]).
/// Four outputs are produced per SSE iteration. The 0-3 leftover values use
/// the scalar overload.
template <class ElementOp>
void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) {
    const float x0s = x[0];
    const __m128 x0 = _mm_set1_ps(x0s); // broadcast the query to all lanes

    size_t i = 0;
    for (; i + 3 < ny; i += 4) {
        const __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y + i));
        // Lanes 0..3 are exactly dis[i..i+3]. Storing them in one unaligned
        // store replaces three shuffles plus four scalar extractions.
        _mm_storeu_ps(dis + i, accu);
    }
    for (; i < ny; i++) { // handle non-multiple-of-4 case
        dis[i] = ElementOp::op(x0s, y[i]);
    }
}
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_113fvec_op_ny_D1INS0_11ElementOpL2EEEvPfPKfS5_m
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_113fvec_op_ny_D1INS0_11ElementOpIPEEEvPfPKfS5_m
437
438
/// Applies ElementOp between the 2-d query x and each of the ny 2-d vectors
/// stored back to back in y, summing the two component results:
/// dis[i] = op(x[0], y_i[0]) + op(x[1], y_i[1]).
/// Two vectors (four floats) are handled per SSE iteration. A trailing odd
/// vector uses the scalar overload.
template <class ElementOp>
void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
    // lanes [x0, x1, x0, x1]
    const __m128 xx = _mm_setr_ps(x[0], x[1], x[0], x[1]);

    size_t i = 0;
    for (; i + 1 < ny; i += 2, y += 4) {
        __m128 prod = ElementOp::op(xx, _mm_loadu_ps(y));
        // lanes [p0 + p1, p2 + p3, p0 + p1, p2 + p3]
        prod = _mm_hadd_ps(prod, prod);
        dis[i] = _mm_cvtss_f32(prod);
        // move lane 3 (p2 + p3) into lane 0
        dis[i + 1] = _mm_cvtss_f32(_mm_shuffle_ps(prod, prod, 3));
    }
    if (i < ny) { // odd ny: one vector left
        dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
    }
}
455
456
#if defined(__AVX512F__)
457
458
template <>
459
void fvec_op_ny_D2<ElementOpIP>(
460
        float* dis,
461
        const float* x,
462
        const float* y,
463
        size_t ny) {
464
    const size_t ny16 = ny / 16;
465
    size_t i = 0;
466
467
    if (ny16 > 0) {
468
        // process 16 D2-vectors per loop.
469
        _mm_prefetch((const char*)y, _MM_HINT_T0);
470
        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
471
472
        const __m512 m0 = _mm512_set1_ps(x[0]);
473
        const __m512 m1 = _mm512_set1_ps(x[1]);
474
475
        for (i = 0; i < ny16 * 16; i += 16) {
476
            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
477
478
            // load 16x2 matrix and transpose it in registers.
479
            // the typical bottleneck is memory access, so
480
            // let's trade instructions for the bandwidth.
481
482
            __m512 v0;
483
            __m512 v1;
484
485
            transpose_16x2(
486
                    _mm512_loadu_ps(y + 0 * 16),
487
                    _mm512_loadu_ps(y + 1 * 16),
488
                    v0,
489
                    v1);
490
491
            // compute distances (dot product)
492
            __m512 distances = _mm512_mul_ps(m0, v0);
493
            distances = _mm512_fmadd_ps(m1, v1, distances);
494
495
            // store
496
            _mm512_storeu_ps(dis + i, distances);
497
498
            y += 32; // move to the next set of 16x2 elements
499
        }
500
    }
501
502
    if (i < ny) {
503
        // process leftovers
504
        float x0 = x[0];
505
        float x1 = x[1];
506
507
        for (; i < ny; i++) {
508
            float distance = x0 * y[0] + x1 * y[1];
509
            y += 2;
510
            dis[i] = distance;
511
        }
512
    }
513
}
514
515
template <>
516
void fvec_op_ny_D2<ElementOpL2>(
517
        float* dis,
518
        const float* x,
519
        const float* y,
520
        size_t ny) {
521
    const size_t ny16 = ny / 16;
522
    size_t i = 0;
523
524
    if (ny16 > 0) {
525
        // process 16 D2-vectors per loop.
526
        _mm_prefetch((const char*)y, _MM_HINT_T0);
527
        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
528
529
        const __m512 m0 = _mm512_set1_ps(x[0]);
530
        const __m512 m1 = _mm512_set1_ps(x[1]);
531
532
        for (i = 0; i < ny16 * 16; i += 16) {
533
            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
534
535
            // load 16x2 matrix and transpose it in registers.
536
            // the typical bottleneck is memory access, so
537
            // let's trade instructions for the bandwidth.
538
539
            __m512 v0;
540
            __m512 v1;
541
542
            transpose_16x2(
543
                    _mm512_loadu_ps(y + 0 * 16),
544
                    _mm512_loadu_ps(y + 1 * 16),
545
                    v0,
546
                    v1);
547
548
            // compute differences
549
            const __m512 d0 = _mm512_sub_ps(m0, v0);
550
            const __m512 d1 = _mm512_sub_ps(m1, v1);
551
552
            // compute squares of differences
553
            __m512 distances = _mm512_mul_ps(d0, d0);
554
            distances = _mm512_fmadd_ps(d1, d1, distances);
555
556
            // store
557
            _mm512_storeu_ps(dis + i, distances);
558
559
            y += 32; // move to the next set of 16x2 elements
560
        }
561
    }
562
563
    if (i < ny) {
564
        // process leftovers
565
        float x0 = x[0];
566
        float x1 = x[1];
567
568
        for (; i < ny; i++) {
569
            float sub0 = x0 - y[0];
570
            float sub1 = x1 - y[1];
571
            float distance = sub0 * sub0 + sub1 * sub1;
572
573
            y += 2;
574
            dis[i] = distance;
575
        }
576
    }
577
}
578
579
#elif defined(__AVX2__)
580
581
template <>
582
void fvec_op_ny_D2<ElementOpIP>(
583
        float* dis,
584
        const float* x,
585
        const float* y,
586
0
        size_t ny) {
587
0
    const size_t ny8 = ny / 8;
588
0
    size_t i = 0;
589
590
0
    if (ny8 > 0) {
591
        // process 8 D2-vectors per loop.
592
0
        _mm_prefetch((const char*)y, _MM_HINT_T0);
593
0
        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
594
595
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
596
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
597
598
0
        for (i = 0; i < ny8 * 8; i += 8) {
599
0
            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
600
601
            // load 8x2 matrix and transpose it in registers.
602
            // the typical bottleneck is memory access, so
603
            // let's trade instructions for the bandwidth.
604
605
0
            __m256 v0;
606
0
            __m256 v1;
607
608
0
            transpose_8x2(
609
0
                    _mm256_loadu_ps(y + 0 * 8),
610
0
                    _mm256_loadu_ps(y + 1 * 8),
611
0
                    v0,
612
0
                    v1);
613
614
            // compute distances
615
0
            __m256 distances = _mm256_mul_ps(m0, v0);
616
0
            distances = _mm256_fmadd_ps(m1, v1, distances);
617
618
            // store
619
0
            _mm256_storeu_ps(dis + i, distances);
620
621
0
            y += 16;
622
0
        }
623
0
    }
624
625
0
    if (i < ny) {
626
        // process leftovers
627
0
        float x0 = x[0];
628
0
        float x1 = x[1];
629
630
0
        for (; i < ny; i++) {
631
0
            float distance = x0 * y[0] + x1 * y[1];
632
0
            y += 2;
633
0
            dis[i] = distance;
634
0
        }
635
0
    }
636
0
}
637
638
template <>
639
void fvec_op_ny_D2<ElementOpL2>(
640
        float* dis,
641
        const float* x,
642
        const float* y,
643
16
        size_t ny) {
644
16
    const size_t ny8 = ny / 8;
645
16
    size_t i = 0;
646
647
16
    if (ny8 > 0) {
648
        // process 8 D2-vectors per loop.
649
0
        _mm_prefetch((const char*)y, _MM_HINT_T0);
650
0
        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
651
652
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
653
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
654
655
0
        for (i = 0; i < ny8 * 8; i += 8) {
656
0
            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
657
658
            // load 8x2 matrix and transpose it in registers.
659
            // the typical bottleneck is memory access, so
660
            // let's trade instructions for the bandwidth.
661
662
0
            __m256 v0;
663
0
            __m256 v1;
664
665
0
            transpose_8x2(
666
0
                    _mm256_loadu_ps(y + 0 * 8),
667
0
                    _mm256_loadu_ps(y + 1 * 8),
668
0
                    v0,
669
0
                    v1);
670
671
            // compute differences
672
0
            const __m256 d0 = _mm256_sub_ps(m0, v0);
673
0
            const __m256 d1 = _mm256_sub_ps(m1, v1);
674
675
            // compute squares of differences
676
0
            __m256 distances = _mm256_mul_ps(d0, d0);
677
0
            distances = _mm256_fmadd_ps(d1, d1, distances);
678
679
            // store
680
0
            _mm256_storeu_ps(dis + i, distances);
681
682
0
            y += 16;
683
0
        }
684
0
    }
685
686
16
    if (i < ny) {
687
        // process leftovers
688
16
        float x0 = x[0];
689
16
        float x1 = x[1];
690
691
80
        for (; i < ny; i++) {
692
64
            float sub0 = x0 - y[0];
693
64
            float sub1 = x1 - y[1];
694
64
            float distance = sub0 * sub0 + sub1 * sub1;
695
696
64
            y += 2;
697
64
            dis[i] = distance;
698
64
        }
699
16
    }
700
16
}
701
702
#endif
703
704
template <class ElementOp>
705
void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
706
    __m128 x0 = _mm_loadu_ps(x);
707
708
    for (size_t i = 0; i < ny; i++) {
709
        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
710
        y += 4;
711
        dis[i] = horizontal_sum(accu);
712
    }
713
}
714
715
#if defined(__AVX512F__)
716
717
template <>
718
void fvec_op_ny_D4<ElementOpIP>(
719
        float* dis,
720
        const float* x,
721
        const float* y,
722
        size_t ny) {
723
    const size_t ny16 = ny / 16;
724
    size_t i = 0;
725
726
    if (ny16 > 0) {
727
        // process 16 D4-vectors per loop.
728
        const __m512 m0 = _mm512_set1_ps(x[0]);
729
        const __m512 m1 = _mm512_set1_ps(x[1]);
730
        const __m512 m2 = _mm512_set1_ps(x[2]);
731
        const __m512 m3 = _mm512_set1_ps(x[3]);
732
733
        for (i = 0; i < ny16 * 16; i += 16) {
734
            // load 16x4 matrix and transpose it in registers.
735
            // the typical bottleneck is memory access, so
736
            // let's trade instructions for the bandwidth.
737
738
            __m512 v0;
739
            __m512 v1;
740
            __m512 v2;
741
            __m512 v3;
742
743
            transpose_16x4(
744
                    _mm512_loadu_ps(y + 0 * 16),
745
                    _mm512_loadu_ps(y + 1 * 16),
746
                    _mm512_loadu_ps(y + 2 * 16),
747
                    _mm512_loadu_ps(y + 3 * 16),
748
                    v0,
749
                    v1,
750
                    v2,
751
                    v3);
752
753
            // compute distances
754
            __m512 distances = _mm512_mul_ps(m0, v0);
755
            distances = _mm512_fmadd_ps(m1, v1, distances);
756
            distances = _mm512_fmadd_ps(m2, v2, distances);
757
            distances = _mm512_fmadd_ps(m3, v3, distances);
758
759
            // store
760
            _mm512_storeu_ps(dis + i, distances);
761
762
            y += 64; // move to the next set of 16x4 elements
763
        }
764
    }
765
766
    if (i < ny) {
767
        // process leftovers
768
        __m128 x0 = _mm_loadu_ps(x);
769
770
        for (; i < ny; i++) {
771
            __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
772
            y += 4;
773
            dis[i] = horizontal_sum(accu);
774
        }
775
    }
776
}
777
778
template <>
779
void fvec_op_ny_D4<ElementOpL2>(
780
        float* dis,
781
        const float* x,
782
        const float* y,
783
        size_t ny) {
784
    const size_t ny16 = ny / 16;
785
    size_t i = 0;
786
787
    if (ny16 > 0) {
788
        // process 16 D4-vectors per loop.
789
        const __m512 m0 = _mm512_set1_ps(x[0]);
790
        const __m512 m1 = _mm512_set1_ps(x[1]);
791
        const __m512 m2 = _mm512_set1_ps(x[2]);
792
        const __m512 m3 = _mm512_set1_ps(x[3]);
793
794
        for (i = 0; i < ny16 * 16; i += 16) {
795
            // load 16x4 matrix and transpose it in registers.
796
            // the typical bottleneck is memory access, so
797
            // let's trade instructions for the bandwidth.
798
799
            __m512 v0;
800
            __m512 v1;
801
            __m512 v2;
802
            __m512 v3;
803
804
            transpose_16x4(
805
                    _mm512_loadu_ps(y + 0 * 16),
806
                    _mm512_loadu_ps(y + 1 * 16),
807
                    _mm512_loadu_ps(y + 2 * 16),
808
                    _mm512_loadu_ps(y + 3 * 16),
809
                    v0,
810
                    v1,
811
                    v2,
812
                    v3);
813
814
            // compute differences
815
            const __m512 d0 = _mm512_sub_ps(m0, v0);
816
            const __m512 d1 = _mm512_sub_ps(m1, v1);
817
            const __m512 d2 = _mm512_sub_ps(m2, v2);
818
            const __m512 d3 = _mm512_sub_ps(m3, v3);
819
820
            // compute squares of differences
821
            __m512 distances = _mm512_mul_ps(d0, d0);
822
            distances = _mm512_fmadd_ps(d1, d1, distances);
823
            distances = _mm512_fmadd_ps(d2, d2, distances);
824
            distances = _mm512_fmadd_ps(d3, d3, distances);
825
826
            // store
827
            _mm512_storeu_ps(dis + i, distances);
828
829
            y += 64; // move to the next set of 16x4 elements
830
        }
831
    }
832
833
    if (i < ny) {
834
        // process leftovers
835
        __m128 x0 = _mm_loadu_ps(x);
836
837
        for (; i < ny; i++) {
838
            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
839
            y += 4;
840
            dis[i] = horizontal_sum(accu);
841
        }
842
    }
843
}
844
845
#elif defined(__AVX2__)
846
847
template <>
848
void fvec_op_ny_D4<ElementOpIP>(
849
        float* dis,
850
        const float* x,
851
        const float* y,
852
0
        size_t ny) {
853
0
    const size_t ny8 = ny / 8;
854
0
    size_t i = 0;
855
856
0
    if (ny8 > 0) {
857
        // process 8 D4-vectors per loop.
858
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
859
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
860
0
        const __m256 m2 = _mm256_set1_ps(x[2]);
861
0
        const __m256 m3 = _mm256_set1_ps(x[3]);
862
863
0
        for (i = 0; i < ny8 * 8; i += 8) {
864
            // load 8x4 matrix and transpose it in registers.
865
            // the typical bottleneck is memory access, so
866
            // let's trade instructions for the bandwidth.
867
868
0
            __m256 v0;
869
0
            __m256 v1;
870
0
            __m256 v2;
871
0
            __m256 v3;
872
873
0
            transpose_8x4(
874
0
                    _mm256_loadu_ps(y + 0 * 8),
875
0
                    _mm256_loadu_ps(y + 1 * 8),
876
0
                    _mm256_loadu_ps(y + 2 * 8),
877
0
                    _mm256_loadu_ps(y + 3 * 8),
878
0
                    v0,
879
0
                    v1,
880
0
                    v2,
881
0
                    v3);
882
883
            // compute distances
884
0
            __m256 distances = _mm256_mul_ps(m0, v0);
885
0
            distances = _mm256_fmadd_ps(m1, v1, distances);
886
0
            distances = _mm256_fmadd_ps(m2, v2, distances);
887
0
            distances = _mm256_fmadd_ps(m3, v3, distances);
888
889
            // store
890
0
            _mm256_storeu_ps(dis + i, distances);
891
892
0
            y += 32;
893
0
        }
894
0
    }
895
896
0
    if (i < ny) {
897
        // process leftovers
898
0
        __m128 x0 = _mm_loadu_ps(x);
899
900
0
        for (; i < ny; i++) {
901
0
            __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
902
0
            y += 4;
903
0
            dis[i] = horizontal_sum(accu);
904
0
        }
905
0
    }
906
0
}
907
908
template <>
909
void fvec_op_ny_D4<ElementOpL2>(
910
        float* dis,
911
        const float* x,
912
        const float* y,
913
0
        size_t ny) {
914
0
    const size_t ny8 = ny / 8;
915
0
    size_t i = 0;
916
917
0
    if (ny8 > 0) {
918
        // process 8 D4-vectors per loop.
919
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
920
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
921
0
        const __m256 m2 = _mm256_set1_ps(x[2]);
922
0
        const __m256 m3 = _mm256_set1_ps(x[3]);
923
924
0
        for (i = 0; i < ny8 * 8; i += 8) {
925
            // load 8x4 matrix and transpose it in registers.
926
            // the typical bottleneck is memory access, so
927
            // let's trade instructions for the bandwidth.
928
929
0
            __m256 v0;
930
0
            __m256 v1;
931
0
            __m256 v2;
932
0
            __m256 v3;
933
934
0
            transpose_8x4(
935
0
                    _mm256_loadu_ps(y + 0 * 8),
936
0
                    _mm256_loadu_ps(y + 1 * 8),
937
0
                    _mm256_loadu_ps(y + 2 * 8),
938
0
                    _mm256_loadu_ps(y + 3 * 8),
939
0
                    v0,
940
0
                    v1,
941
0
                    v2,
942
0
                    v3);
943
944
            // compute differences
945
0
            const __m256 d0 = _mm256_sub_ps(m0, v0);
946
0
            const __m256 d1 = _mm256_sub_ps(m1, v1);
947
0
            const __m256 d2 = _mm256_sub_ps(m2, v2);
948
0
            const __m256 d3 = _mm256_sub_ps(m3, v3);
949
950
            // compute squares of differences
951
0
            __m256 distances = _mm256_mul_ps(d0, d0);
952
0
            distances = _mm256_fmadd_ps(d1, d1, distances);
953
0
            distances = _mm256_fmadd_ps(d2, d2, distances);
954
0
            distances = _mm256_fmadd_ps(d3, d3, distances);
955
956
            // store
957
0
            _mm256_storeu_ps(dis + i, distances);
958
959
0
            y += 32;
960
0
        }
961
0
    }
962
963
0
    if (i < ny) {
964
        // process leftovers
965
0
        __m128 x0 = _mm_loadu_ps(x);
966
967
0
        for (; i < ny; i++) {
968
0
            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
969
0
            y += 4;
970
0
            dis[i] = horizontal_sum(accu);
971
0
        }
972
0
    }
973
0
}
974
975
#endif
976
977
/// Applies ElementOp between the 8-d query x and each of the ny 8-d vectors
/// stored back to back in y: dis[i] = sum_{j<8} op(x[j], y_i[j]).
/// Each vector is processed as two __m128 halves that are added together and
/// then reduced to lane 0 with two horizontal adds.
template <class ElementOp>
void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
    const __m128 xlo = _mm_loadu_ps(x);
    const __m128 xhi = _mm_loadu_ps(x + 4);

    for (size_t i = 0; i < ny; ++i, y += 8) {
        const __m128 lo = ElementOp::op(xlo, _mm_loadu_ps(y));
        const __m128 hi = ElementOp::op(xhi, _mm_loadu_ps(y + 4));
        __m128 sum = _mm_add_ps(lo, hi);
        sum = _mm_hadd_ps(sum, sum);
        sum = _mm_hadd_ps(sum, sum);
        dis[i] = _mm_cvtss_f32(sum);
    }
}
992
993
#if defined(__AVX512F__)
994
995
template <>
996
void fvec_op_ny_D8<ElementOpIP>(
997
        float* dis,
998
        const float* x,
999
        const float* y,
1000
        size_t ny) {
1001
    const size_t ny16 = ny / 16;
1002
    size_t i = 0;
1003
1004
    if (ny16 > 0) {
1005
        // process 16 D16-vectors per loop.
1006
        const __m512 m0 = _mm512_set1_ps(x[0]);
1007
        const __m512 m1 = _mm512_set1_ps(x[1]);
1008
        const __m512 m2 = _mm512_set1_ps(x[2]);
1009
        const __m512 m3 = _mm512_set1_ps(x[3]);
1010
        const __m512 m4 = _mm512_set1_ps(x[4]);
1011
        const __m512 m5 = _mm512_set1_ps(x[5]);
1012
        const __m512 m6 = _mm512_set1_ps(x[6]);
1013
        const __m512 m7 = _mm512_set1_ps(x[7]);
1014
1015
        for (i = 0; i < ny16 * 16; i += 16) {
1016
            // load 16x8 matrix and transpose it in registers.
1017
            // the typical bottleneck is memory access, so
1018
            // let's trade instructions for the bandwidth.
1019
1020
            __m512 v0;
1021
            __m512 v1;
1022
            __m512 v2;
1023
            __m512 v3;
1024
            __m512 v4;
1025
            __m512 v5;
1026
            __m512 v6;
1027
            __m512 v7;
1028
1029
            transpose_16x8(
1030
                    _mm512_loadu_ps(y + 0 * 16),
1031
                    _mm512_loadu_ps(y + 1 * 16),
1032
                    _mm512_loadu_ps(y + 2 * 16),
1033
                    _mm512_loadu_ps(y + 3 * 16),
1034
                    _mm512_loadu_ps(y + 4 * 16),
1035
                    _mm512_loadu_ps(y + 5 * 16),
1036
                    _mm512_loadu_ps(y + 6 * 16),
1037
                    _mm512_loadu_ps(y + 7 * 16),
1038
                    v0,
1039
                    v1,
1040
                    v2,
1041
                    v3,
1042
                    v4,
1043
                    v5,
1044
                    v6,
1045
                    v7);
1046
1047
            // compute distances
1048
            __m512 distances = _mm512_mul_ps(m0, v0);
1049
            distances = _mm512_fmadd_ps(m1, v1, distances);
1050
            distances = _mm512_fmadd_ps(m2, v2, distances);
1051
            distances = _mm512_fmadd_ps(m3, v3, distances);
1052
            distances = _mm512_fmadd_ps(m4, v4, distances);
1053
            distances = _mm512_fmadd_ps(m5, v5, distances);
1054
            distances = _mm512_fmadd_ps(m6, v6, distances);
1055
            distances = _mm512_fmadd_ps(m7, v7, distances);
1056
1057
            // store
1058
            _mm512_storeu_ps(dis + i, distances);
1059
1060
            y += 128; // 16 floats * 8 rows
1061
        }
1062
    }
1063
1064
    if (i < ny) {
1065
        // process leftovers
1066
        __m256 x0 = _mm256_loadu_ps(x);
1067
1068
        for (; i < ny; i++) {
1069
            __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
1070
            y += 8;
1071
            dis[i] = horizontal_sum(accu);
1072
        }
1073
    }
1074
}
1075
1076
template <>
1077
void fvec_op_ny_D8<ElementOpL2>(
1078
        float* dis,
1079
        const float* x,
1080
        const float* y,
1081
        size_t ny) {
1082
    const size_t ny16 = ny / 16;
1083
    size_t i = 0;
1084
1085
    if (ny16 > 0) {
1086
        // process 16 D16-vectors per loop.
1087
        const __m512 m0 = _mm512_set1_ps(x[0]);
1088
        const __m512 m1 = _mm512_set1_ps(x[1]);
1089
        const __m512 m2 = _mm512_set1_ps(x[2]);
1090
        const __m512 m3 = _mm512_set1_ps(x[3]);
1091
        const __m512 m4 = _mm512_set1_ps(x[4]);
1092
        const __m512 m5 = _mm512_set1_ps(x[5]);
1093
        const __m512 m6 = _mm512_set1_ps(x[6]);
1094
        const __m512 m7 = _mm512_set1_ps(x[7]);
1095
1096
        for (i = 0; i < ny16 * 16; i += 16) {
1097
            // load 16x8 matrix and transpose it in registers.
1098
            // the typical bottleneck is memory access, so
1099
            // let's trade instructions for the bandwidth.
1100
1101
            __m512 v0;
1102
            __m512 v1;
1103
            __m512 v2;
1104
            __m512 v3;
1105
            __m512 v4;
1106
            __m512 v5;
1107
            __m512 v6;
1108
            __m512 v7;
1109
1110
            transpose_16x8(
1111
                    _mm512_loadu_ps(y + 0 * 16),
1112
                    _mm512_loadu_ps(y + 1 * 16),
1113
                    _mm512_loadu_ps(y + 2 * 16),
1114
                    _mm512_loadu_ps(y + 3 * 16),
1115
                    _mm512_loadu_ps(y + 4 * 16),
1116
                    _mm512_loadu_ps(y + 5 * 16),
1117
                    _mm512_loadu_ps(y + 6 * 16),
1118
                    _mm512_loadu_ps(y + 7 * 16),
1119
                    v0,
1120
                    v1,
1121
                    v2,
1122
                    v3,
1123
                    v4,
1124
                    v5,
1125
                    v6,
1126
                    v7);
1127
1128
            // compute differences
1129
            const __m512 d0 = _mm512_sub_ps(m0, v0);
1130
            const __m512 d1 = _mm512_sub_ps(m1, v1);
1131
            const __m512 d2 = _mm512_sub_ps(m2, v2);
1132
            const __m512 d3 = _mm512_sub_ps(m3, v3);
1133
            const __m512 d4 = _mm512_sub_ps(m4, v4);
1134
            const __m512 d5 = _mm512_sub_ps(m5, v5);
1135
            const __m512 d6 = _mm512_sub_ps(m6, v6);
1136
            const __m512 d7 = _mm512_sub_ps(m7, v7);
1137
1138
            // compute squares of differences
1139
            __m512 distances = _mm512_mul_ps(d0, d0);
1140
            distances = _mm512_fmadd_ps(d1, d1, distances);
1141
            distances = _mm512_fmadd_ps(d2, d2, distances);
1142
            distances = _mm512_fmadd_ps(d3, d3, distances);
1143
            distances = _mm512_fmadd_ps(d4, d4, distances);
1144
            distances = _mm512_fmadd_ps(d5, d5, distances);
1145
            distances = _mm512_fmadd_ps(d6, d6, distances);
1146
            distances = _mm512_fmadd_ps(d7, d7, distances);
1147
1148
            // store
1149
            _mm512_storeu_ps(dis + i, distances);
1150
1151
            y += 128; // 16 floats * 8 rows
1152
        }
1153
    }
1154
1155
    if (i < ny) {
1156
        // process leftovers
1157
        __m256 x0 = _mm256_loadu_ps(x);
1158
1159
        for (; i < ny; i++) {
1160
            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
1161
            y += 8;
1162
            dis[i] = horizontal_sum(accu);
1163
        }
1164
    }
1165
}
1166
1167
#elif defined(__AVX2__)
1168
1169
template <>
1170
void fvec_op_ny_D8<ElementOpIP>(
1171
        float* dis,
1172
        const float* x,
1173
        const float* y,
1174
0
        size_t ny) {
1175
0
    const size_t ny8 = ny / 8;
1176
0
    size_t i = 0;
1177
1178
0
    if (ny8 > 0) {
1179
        // process 8 D8-vectors per loop.
1180
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
1181
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
1182
0
        const __m256 m2 = _mm256_set1_ps(x[2]);
1183
0
        const __m256 m3 = _mm256_set1_ps(x[3]);
1184
0
        const __m256 m4 = _mm256_set1_ps(x[4]);
1185
0
        const __m256 m5 = _mm256_set1_ps(x[5]);
1186
0
        const __m256 m6 = _mm256_set1_ps(x[6]);
1187
0
        const __m256 m7 = _mm256_set1_ps(x[7]);
1188
1189
0
        for (i = 0; i < ny8 * 8; i += 8) {
1190
            // load 8x8 matrix and transpose it in registers.
1191
            // the typical bottleneck is memory access, so
1192
            // let's trade instructions for the bandwidth.
1193
1194
0
            __m256 v0;
1195
0
            __m256 v1;
1196
0
            __m256 v2;
1197
0
            __m256 v3;
1198
0
            __m256 v4;
1199
0
            __m256 v5;
1200
0
            __m256 v6;
1201
0
            __m256 v7;
1202
1203
0
            transpose_8x8(
1204
0
                    _mm256_loadu_ps(y + 0 * 8),
1205
0
                    _mm256_loadu_ps(y + 1 * 8),
1206
0
                    _mm256_loadu_ps(y + 2 * 8),
1207
0
                    _mm256_loadu_ps(y + 3 * 8),
1208
0
                    _mm256_loadu_ps(y + 4 * 8),
1209
0
                    _mm256_loadu_ps(y + 5 * 8),
1210
0
                    _mm256_loadu_ps(y + 6 * 8),
1211
0
                    _mm256_loadu_ps(y + 7 * 8),
1212
0
                    v0,
1213
0
                    v1,
1214
0
                    v2,
1215
0
                    v3,
1216
0
                    v4,
1217
0
                    v5,
1218
0
                    v6,
1219
0
                    v7);
1220
1221
            // compute distances
1222
0
            __m256 distances = _mm256_mul_ps(m0, v0);
1223
0
            distances = _mm256_fmadd_ps(m1, v1, distances);
1224
0
            distances = _mm256_fmadd_ps(m2, v2, distances);
1225
0
            distances = _mm256_fmadd_ps(m3, v3, distances);
1226
0
            distances = _mm256_fmadd_ps(m4, v4, distances);
1227
0
            distances = _mm256_fmadd_ps(m5, v5, distances);
1228
0
            distances = _mm256_fmadd_ps(m6, v6, distances);
1229
0
            distances = _mm256_fmadd_ps(m7, v7, distances);
1230
1231
            // store
1232
0
            _mm256_storeu_ps(dis + i, distances);
1233
1234
0
            y += 64;
1235
0
        }
1236
0
    }
1237
1238
0
    if (i < ny) {
1239
        // process leftovers
1240
0
        __m256 x0 = _mm256_loadu_ps(x);
1241
1242
0
        for (; i < ny; i++) {
1243
0
            __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
1244
0
            y += 8;
1245
0
            dis[i] = horizontal_sum(accu);
1246
0
        }
1247
0
    }
1248
0
}
1249
1250
template <>
1251
void fvec_op_ny_D8<ElementOpL2>(
1252
        float* dis,
1253
        const float* x,
1254
        const float* y,
1255
0
        size_t ny) {
1256
0
    const size_t ny8 = ny / 8;
1257
0
    size_t i = 0;
1258
1259
0
    if (ny8 > 0) {
1260
        // process 8 D8-vectors per loop.
1261
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
1262
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
1263
0
        const __m256 m2 = _mm256_set1_ps(x[2]);
1264
0
        const __m256 m3 = _mm256_set1_ps(x[3]);
1265
0
        const __m256 m4 = _mm256_set1_ps(x[4]);
1266
0
        const __m256 m5 = _mm256_set1_ps(x[5]);
1267
0
        const __m256 m6 = _mm256_set1_ps(x[6]);
1268
0
        const __m256 m7 = _mm256_set1_ps(x[7]);
1269
1270
0
        for (i = 0; i < ny8 * 8; i += 8) {
1271
            // load 8x8 matrix and transpose it in registers.
1272
            // the typical bottleneck is memory access, so
1273
            // let's trade instructions for the bandwidth.
1274
1275
0
            __m256 v0;
1276
0
            __m256 v1;
1277
0
            __m256 v2;
1278
0
            __m256 v3;
1279
0
            __m256 v4;
1280
0
            __m256 v5;
1281
0
            __m256 v6;
1282
0
            __m256 v7;
1283
1284
0
            transpose_8x8(
1285
0
                    _mm256_loadu_ps(y + 0 * 8),
1286
0
                    _mm256_loadu_ps(y + 1 * 8),
1287
0
                    _mm256_loadu_ps(y + 2 * 8),
1288
0
                    _mm256_loadu_ps(y + 3 * 8),
1289
0
                    _mm256_loadu_ps(y + 4 * 8),
1290
0
                    _mm256_loadu_ps(y + 5 * 8),
1291
0
                    _mm256_loadu_ps(y + 6 * 8),
1292
0
                    _mm256_loadu_ps(y + 7 * 8),
1293
0
                    v0,
1294
0
                    v1,
1295
0
                    v2,
1296
0
                    v3,
1297
0
                    v4,
1298
0
                    v5,
1299
0
                    v6,
1300
0
                    v7);
1301
1302
            // compute differences
1303
0
            const __m256 d0 = _mm256_sub_ps(m0, v0);
1304
0
            const __m256 d1 = _mm256_sub_ps(m1, v1);
1305
0
            const __m256 d2 = _mm256_sub_ps(m2, v2);
1306
0
            const __m256 d3 = _mm256_sub_ps(m3, v3);
1307
0
            const __m256 d4 = _mm256_sub_ps(m4, v4);
1308
0
            const __m256 d5 = _mm256_sub_ps(m5, v5);
1309
0
            const __m256 d6 = _mm256_sub_ps(m6, v6);
1310
0
            const __m256 d7 = _mm256_sub_ps(m7, v7);
1311
1312
            // compute squares of differences
1313
0
            __m256 distances = _mm256_mul_ps(d0, d0);
1314
0
            distances = _mm256_fmadd_ps(d1, d1, distances);
1315
0
            distances = _mm256_fmadd_ps(d2, d2, distances);
1316
0
            distances = _mm256_fmadd_ps(d3, d3, distances);
1317
0
            distances = _mm256_fmadd_ps(d4, d4, distances);
1318
0
            distances = _mm256_fmadd_ps(d5, d5, distances);
1319
0
            distances = _mm256_fmadd_ps(d6, d6, distances);
1320
0
            distances = _mm256_fmadd_ps(d7, d7, distances);
1321
1322
            // store
1323
0
            _mm256_storeu_ps(dis + i, distances);
1324
1325
0
            y += 64;
1326
0
        }
1327
0
    }
1328
1329
0
    if (i < ny) {
1330
        // process leftovers
1331
0
        __m256 x0 = _mm256_loadu_ps(x);
1332
1333
0
        for (; i < ny; i++) {
1334
0
            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
1335
0
            y += 8;
1336
0
            dis[i] = horizontal_sum(accu);
1337
0
        }
1338
0
    }
1339
0
}
1340
1341
#endif
1342
1343
template <class ElementOp>
1344
0
void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
1345
0
    __m128 x0 = _mm_loadu_ps(x);
1346
0
    __m128 x1 = _mm_loadu_ps(x + 4);
1347
0
    __m128 x2 = _mm_loadu_ps(x + 8);
1348
1349
0
    for (size_t i = 0; i < ny; i++) {
1350
0
        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
1351
0
        y += 4;
1352
0
        accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
1353
0
        y += 4;
1354
0
        accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
1355
0
        y += 4;
1356
0
        dis[i] = horizontal_sum(accu);
1357
0
    }
1358
0
}
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_114fvec_op_ny_D12INS0_11ElementOpL2EEEvPfPKfS5_m
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_114fvec_op_ny_D12INS0_11ElementOpIPEEEvPfPKfS5_m
1359
1360
} // anonymous namespace
1361
1362
void fvec_L2sqr_ny(
1363
        float* dis,
1364
        const float* x,
1365
        const float* y,
1366
        size_t d,
1367
32
        size_t ny) {
1368
    // optimized for a few special cases
1369
1370
32
#define DISPATCH(dval)                                  \
1371
32
    case dval:                                          \
1372
16
        fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
1373
16
        return;
1374
1375
32
    switch (d) {
1376
0
        DISPATCH(1)
1377
16
        DISPATCH(2)
1378
0
        DISPATCH(4)
1379
0
        DISPATCH(8)
1380
0
        DISPATCH(12)
1381
16
        default:
1382
16
            fvec_L2sqr_ny_ref(dis, x, y, d, ny);
1383
16
            return;
1384
32
    }
1385
32
#undef DISPATCH
1386
32
}
1387
1388
void fvec_inner_products_ny(
1389
        float* dis,
1390
        const float* x,
1391
        const float* y,
1392
        size_t d,
1393
0
        size_t ny) {
1394
0
#define DISPATCH(dval)                                  \
1395
0
    case dval:                                          \
1396
0
        fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
1397
0
        return;
1398
1399
0
    switch (d) {
1400
0
        DISPATCH(1)
1401
0
        DISPATCH(2)
1402
0
        DISPATCH(4)
1403
0
        DISPATCH(8)
1404
0
        DISPATCH(12)
1405
0
        default:
1406
0
            fvec_inner_products_ny_ref(dis, x, y, d, ny);
1407
0
            return;
1408
0
    }
1409
0
#undef DISPATCH
1410
0
}
1411
1412
#if defined(__AVX512F__)
1413
1414
template <size_t DIM>
void fvec_L2sqr_ny_y_transposed_D(
        float* distances,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // Squared L2 distances between the DIM-dimensional query x and ny
    // vectors stored transposed: component j of vector k is read from
    // y[j * d_offset + k]. y_sqlen[k] is used as ||y_k||^2 (presumably
    // precomputed by the caller). distances[k] receives
    // ||y_k||^2 - 2 * (x, y_k) + ||x||^2.

    // current index being processed
    size_t i = 0;

    // squared length of x
    float x_sqlen = 0;
    for (size_t j = 0; j < DIM; j++) {
        x_sqlen += x[j] * x[j];
    }

    // process 16 vectors per loop
    const size_t ny16 = ny / 16;

    if (ny16 > 0) {
        // m[j] = (2 * x[j], ... 2 * x[j])
        __m512 m[DIM];
        for (size_t j = 0; j < DIM; j++) {
            m[j] = _mm512_set1_ps(x[j]);
            m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j]
        }

        // ||x||^2 broadcast to all 16 lanes (a 512-bit zmm register)
        const __m512 x_sqlen_zmm = _mm512_set1_ps(x_sqlen);

        for (; i < ny16 * 16; i += 16) {
            // Load all DIM components of 16 consecutive vectors:
            // v[j] holds component j of vectors i .. i + 15.
            __m512 v[DIM];
            for (size_t j = 0; j < DIM; j++) {
                v[j] = _mm512_loadu_ps(y + j * d_offset);
            }

            // dp = ||x||^2 - sum_j 2 * x[j] * y[j]
            __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_zmm);
            for (size_t j = 1; j < DIM; j++) {
                dp = _mm512_fnmadd_ps(m[j], v[j], dp);
            }

            // Compute y^2 - (2 * x, y) + x^2
            __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp);

            _mm512_storeu_ps(distances + i, distances_v);

            // Scroll y and y_sqlen forward
            y += 16;
            y_sqlen += 16;
        }
    }

    if (i < ny) {
        // Process the remaining ny % 16 vectors one at a time
        for (; i < ny; i++) {
            float dp = 0;
            for (size_t j = 0; j < DIM; j++) {
                dp += x[j] * y[j * d_offset];
            }

            // Same formula as the SIMD path, in scalar form:
            // y^2 - 2 * (x, y) + x^2.
            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
            distances[i] = distance;

            y += 1;
            y_sqlen += 1;
        }
    }
}
1486
1487
#elif defined(__AVX2__)
1488
1489
template <size_t DIM>
void fvec_L2sqr_ny_y_transposed_D(
        float* distances,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // Squared L2 distances between the DIM-dimensional query x and ny
    // vectors stored transposed: component j of vector k is read from
    // y[j * d_offset + k]. y_sqlen[k] is used as ||y_k||^2 (presumably
    // precomputed by the caller). distances[k] receives
    // ||y_k||^2 - 2 * (x, y_k) + ||x||^2.

    // current index being processed
    size_t i = 0;

    // squared length of x
    float x_sqlen = 0;
    for (size_t j = 0; j < DIM; j++) {
        x_sqlen += x[j] * x[j];
    }

    // process 8 vectors per loop.
    const size_t ny8 = ny / 8;

    if (ny8 > 0) {
        // m[j] = (2 * x[j], ... 2 * x[j])
        __m256 m[DIM];
        for (size_t j = 0; j < DIM; j++) {
            m[j] = _mm256_set1_ps(x[j]);
            m[j] = _mm256_add_ps(m[j], m[j]);
        }

        __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);

        for (; i < ny8 * 8; i += 8) {
            // collect dim 0 of 8 consecutive DIM-dimensional vectors.
            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);

            // compute dot products
            // this is x^2 - 2x[0]*y[0]
            __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);

            for (size_t j = 1; j < DIM; j++) {
                // collect dim j of the same 8 vectors.
                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
                dp = _mm256_fnmadd_ps(m[j], vj, dp);
            }

            // we've got x^2 - (2x, y) at this point

            // y^2 - (2x, y) + x^2
            __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);

            _mm256_storeu_ps(distances + i, distances_v);

            // scroll y and y_sqlen forward.
            y += 8;
            y_sqlen += 8;
        }
    }

    if (i < ny) {
        // process the remaining ny % 8 vectors one at a time
        for (; i < ny; i++) {
            float dp = 0;
            for (size_t j = 0; j < DIM; j++) {
                dp += x[j] * y[j * d_offset];
            }

            // same formula as the SIMD path, in scalar form:
            //   y^2 - 2 * (x, y) + x^2.
            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
            distances[i] = distance;

            y += 1;
            y_sqlen += 1;
        }
    }
}
Unexecuted instantiation: _ZN5faiss28fvec_L2sqr_ny_y_transposed_DILm1EEEvPfPKfS3_S3_mm
Unexecuted instantiation: _ZN5faiss28fvec_L2sqr_ny_y_transposed_DILm2EEEvPfPKfS3_S3_mm
Unexecuted instantiation: _ZN5faiss28fvec_L2sqr_ny_y_transposed_DILm4EEEvPfPKfS3_S3_mm
Unexecuted instantiation: _ZN5faiss28fvec_L2sqr_ny_y_transposed_DILm8EEEvPfPKfS3_S3_mm
1564
1565
#endif
1566
1567
void fvec_L2sqr_ny_transposed(
1568
        float* dis,
1569
        const float* x,
1570
        const float* y,
1571
        const float* y_sqlen,
1572
        size_t d,
1573
        size_t d_offset,
1574
0
        size_t ny) {
1575
    // optimized for a few special cases
1576
1577
0
#ifdef __AVX2__
1578
0
#define DISPATCH(dval)                             \
1579
0
    case dval:                                     \
1580
0
        return fvec_L2sqr_ny_y_transposed_D<dval>( \
1581
0
                dis, x, y, y_sqlen, d_offset, ny);
1582
1583
0
    switch (d) {
1584
0
        DISPATCH(1)
1585
0
        DISPATCH(2)
1586
0
        DISPATCH(4)
1587
0
        DISPATCH(8)
1588
0
        default:
1589
0
            return fvec_L2sqr_ny_y_transposed_ref(
1590
0
                    dis, x, y, y_sqlen, d, d_offset, ny);
1591
0
    }
1592
0
#undef DISPATCH
1593
#else
1594
    // non-AVX2 case
1595
    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
1596
#endif
1597
0
}
1598
1599
#if defined(__AVX512F__)
1600
1601
size_t fvec_L2sqr_ny_nearest_D2(
1602
        float* distances_tmp_buffer,
1603
        const float* x,
1604
        const float* y,
1605
        size_t ny) {
1606
    // this implementation does not use distances_tmp_buffer.
1607
1608
    size_t i = 0;
1609
    float current_min_distance = HUGE_VALF;
1610
    size_t current_min_index = 0;
1611
1612
    const size_t ny16 = ny / 16;
1613
    if (ny16 > 0) {
1614
        _mm_prefetch((const char*)y, _MM_HINT_T0);
1615
        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
1616
1617
        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
1618
        __m512i min_indices = _mm512_set1_epi32(0);
1619
1620
        __m512i current_indices = _mm512_setr_epi32(
1621
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1622
        const __m512i indices_increment = _mm512_set1_epi32(16);
1623
1624
        const __m512 m0 = _mm512_set1_ps(x[0]);
1625
        const __m512 m1 = _mm512_set1_ps(x[1]);
1626
1627
        for (; i < ny16 * 16; i += 16) {
1628
            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
1629
1630
            __m512 v0;
1631
            __m512 v1;
1632
1633
            transpose_16x2(
1634
                    _mm512_loadu_ps(y + 0 * 16),
1635
                    _mm512_loadu_ps(y + 1 * 16),
1636
                    v0,
1637
                    v1);
1638
1639
            const __m512 d0 = _mm512_sub_ps(m0, v0);
1640
            const __m512 d1 = _mm512_sub_ps(m1, v1);
1641
1642
            __m512 distances = _mm512_mul_ps(d0, d0);
1643
            distances = _mm512_fmadd_ps(d1, d1, distances);
1644
1645
            __mmask16 comparison =
1646
                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
1647
1648
            min_distances = _mm512_min_ps(distances, min_distances);
1649
            min_indices = _mm512_mask_blend_epi32(
1650
                    comparison, min_indices, current_indices);
1651
1652
            current_indices =
1653
                    _mm512_add_epi32(current_indices, indices_increment);
1654
1655
            y += 32;
1656
        }
1657
1658
        alignas(64) float min_distances_scalar[16];
1659
        alignas(64) uint32_t min_indices_scalar[16];
1660
        _mm512_store_ps(min_distances_scalar, min_distances);
1661
        _mm512_store_epi32(min_indices_scalar, min_indices);
1662
1663
        for (size_t j = 0; j < 16; j++) {
1664
            if (current_min_distance > min_distances_scalar[j]) {
1665
                current_min_distance = min_distances_scalar[j];
1666
                current_min_index = min_indices_scalar[j];
1667
            }
1668
        }
1669
    }
1670
1671
    if (i < ny) {
1672
        float x0 = x[0];
1673
        float x1 = x[1];
1674
1675
        for (; i < ny; i++) {
1676
            float sub0 = x0 - y[0];
1677
            float sub1 = x1 - y[1];
1678
            float distance = sub0 * sub0 + sub1 * sub1;
1679
1680
            y += 2;
1681
1682
            if (current_min_distance > distance) {
1683
                current_min_distance = distance;
1684
                current_min_index = i;
1685
            }
1686
        }
1687
    }
1688
1689
    return current_min_index;
1690
}
1691
1692
size_t fvec_L2sqr_ny_nearest_D4(
1693
        float* distances_tmp_buffer,
1694
        const float* x,
1695
        const float* y,
1696
        size_t ny) {
1697
    // this implementation does not use distances_tmp_buffer.
1698
1699
    size_t i = 0;
1700
    float current_min_distance = HUGE_VALF;
1701
    size_t current_min_index = 0;
1702
1703
    const size_t ny16 = ny / 16;
1704
1705
    if (ny16 > 0) {
1706
        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
1707
        __m512i min_indices = _mm512_set1_epi32(0);
1708
1709
        __m512i current_indices = _mm512_setr_epi32(
1710
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1711
        const __m512i indices_increment = _mm512_set1_epi32(16);
1712
1713
        const __m512 m0 = _mm512_set1_ps(x[0]);
1714
        const __m512 m1 = _mm512_set1_ps(x[1]);
1715
        const __m512 m2 = _mm512_set1_ps(x[2]);
1716
        const __m512 m3 = _mm512_set1_ps(x[3]);
1717
1718
        for (; i < ny16 * 16; i += 16) {
1719
            __m512 v0;
1720
            __m512 v1;
1721
            __m512 v2;
1722
            __m512 v3;
1723
1724
            transpose_16x4(
1725
                    _mm512_loadu_ps(y + 0 * 16),
1726
                    _mm512_loadu_ps(y + 1 * 16),
1727
                    _mm512_loadu_ps(y + 2 * 16),
1728
                    _mm512_loadu_ps(y + 3 * 16),
1729
                    v0,
1730
                    v1,
1731
                    v2,
1732
                    v3);
1733
1734
            const __m512 d0 = _mm512_sub_ps(m0, v0);
1735
            const __m512 d1 = _mm512_sub_ps(m1, v1);
1736
            const __m512 d2 = _mm512_sub_ps(m2, v2);
1737
            const __m512 d3 = _mm512_sub_ps(m3, v3);
1738
1739
            __m512 distances = _mm512_mul_ps(d0, d0);
1740
            distances = _mm512_fmadd_ps(d1, d1, distances);
1741
            distances = _mm512_fmadd_ps(d2, d2, distances);
1742
            distances = _mm512_fmadd_ps(d3, d3, distances);
1743
1744
            __mmask16 comparison =
1745
                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
1746
1747
            min_distances = _mm512_min_ps(distances, min_distances);
1748
            min_indices = _mm512_mask_blend_epi32(
1749
                    comparison, min_indices, current_indices);
1750
1751
            current_indices =
1752
                    _mm512_add_epi32(current_indices, indices_increment);
1753
1754
            y += 64;
1755
        }
1756
1757
        alignas(64) float min_distances_scalar[16];
1758
        alignas(64) uint32_t min_indices_scalar[16];
1759
        _mm512_store_ps(min_distances_scalar, min_distances);
1760
        _mm512_store_epi32(min_indices_scalar, min_indices);
1761
1762
        for (size_t j = 0; j < 16; j++) {
1763
            if (current_min_distance > min_distances_scalar[j]) {
1764
                current_min_distance = min_distances_scalar[j];
1765
                current_min_index = min_indices_scalar[j];
1766
            }
1767
        }
1768
    }
1769
1770
    if (i < ny) {
1771
        __m128 x0 = _mm_loadu_ps(x);
1772
1773
        for (; i < ny; i++) {
1774
            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
1775
            y += 4;
1776
            const float distance = horizontal_sum(accu);
1777
1778
            if (current_min_distance > distance) {
1779
                current_min_distance = distance;
1780
                current_min_index = i;
1781
            }
1782
        }
1783
    }
1784
1785
    return current_min_index;
1786
}
1787
1788
size_t fvec_L2sqr_ny_nearest_D8(
1789
        float* distances_tmp_buffer,
1790
        const float* x,
1791
        const float* y,
1792
        size_t ny) {
1793
    // this implementation does not use distances_tmp_buffer.
1794
1795
    size_t i = 0;
1796
    float current_min_distance = HUGE_VALF;
1797
    size_t current_min_index = 0;
1798
1799
    const size_t ny16 = ny / 16;
1800
    if (ny16 > 0) {
1801
        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
1802
        __m512i min_indices = _mm512_set1_epi32(0);
1803
1804
        __m512i current_indices = _mm512_setr_epi32(
1805
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1806
        const __m512i indices_increment = _mm512_set1_epi32(16);
1807
1808
        const __m512 m0 = _mm512_set1_ps(x[0]);
1809
        const __m512 m1 = _mm512_set1_ps(x[1]);
1810
        const __m512 m2 = _mm512_set1_ps(x[2]);
1811
        const __m512 m3 = _mm512_set1_ps(x[3]);
1812
1813
        const __m512 m4 = _mm512_set1_ps(x[4]);
1814
        const __m512 m5 = _mm512_set1_ps(x[5]);
1815
        const __m512 m6 = _mm512_set1_ps(x[6]);
1816
        const __m512 m7 = _mm512_set1_ps(x[7]);
1817
1818
        for (; i < ny16 * 16; i += 16) {
1819
            __m512 v0;
1820
            __m512 v1;
1821
            __m512 v2;
1822
            __m512 v3;
1823
            __m512 v4;
1824
            __m512 v5;
1825
            __m512 v6;
1826
            __m512 v7;
1827
1828
            transpose_16x8(
1829
                    _mm512_loadu_ps(y + 0 * 16),
1830
                    _mm512_loadu_ps(y + 1 * 16),
1831
                    _mm512_loadu_ps(y + 2 * 16),
1832
                    _mm512_loadu_ps(y + 3 * 16),
1833
                    _mm512_loadu_ps(y + 4 * 16),
1834
                    _mm512_loadu_ps(y + 5 * 16),
1835
                    _mm512_loadu_ps(y + 6 * 16),
1836
                    _mm512_loadu_ps(y + 7 * 16),
1837
                    v0,
1838
                    v1,
1839
                    v2,
1840
                    v3,
1841
                    v4,
1842
                    v5,
1843
                    v6,
1844
                    v7);
1845
1846
            const __m512 d0 = _mm512_sub_ps(m0, v0);
1847
            const __m512 d1 = _mm512_sub_ps(m1, v1);
1848
            const __m512 d2 = _mm512_sub_ps(m2, v2);
1849
            const __m512 d3 = _mm512_sub_ps(m3, v3);
1850
            const __m512 d4 = _mm512_sub_ps(m4, v4);
1851
            const __m512 d5 = _mm512_sub_ps(m5, v5);
1852
            const __m512 d6 = _mm512_sub_ps(m6, v6);
1853
            const __m512 d7 = _mm512_sub_ps(m7, v7);
1854
1855
            __m512 distances = _mm512_mul_ps(d0, d0);
1856
            distances = _mm512_fmadd_ps(d1, d1, distances);
1857
            distances = _mm512_fmadd_ps(d2, d2, distances);
1858
            distances = _mm512_fmadd_ps(d3, d3, distances);
1859
            distances = _mm512_fmadd_ps(d4, d4, distances);
1860
            distances = _mm512_fmadd_ps(d5, d5, distances);
1861
            distances = _mm512_fmadd_ps(d6, d6, distances);
1862
            distances = _mm512_fmadd_ps(d7, d7, distances);
1863
1864
            __mmask16 comparison =
1865
                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
1866
1867
            min_distances = _mm512_min_ps(distances, min_distances);
1868
            min_indices = _mm512_mask_blend_epi32(
1869
                    comparison, min_indices, current_indices);
1870
1871
            current_indices =
1872
                    _mm512_add_epi32(current_indices, indices_increment);
1873
1874
            y += 128;
1875
        }
1876
1877
        alignas(64) float min_distances_scalar[16];
1878
        alignas(64) uint32_t min_indices_scalar[16];
1879
        _mm512_store_ps(min_distances_scalar, min_distances);
1880
        _mm512_store_epi32(min_indices_scalar, min_indices);
1881
1882
        for (size_t j = 0; j < 16; j++) {
1883
            if (current_min_distance > min_distances_scalar[j]) {
1884
                current_min_distance = min_distances_scalar[j];
1885
                current_min_index = min_indices_scalar[j];
1886
            }
1887
        }
1888
    }
1889
1890
    if (i < ny) {
1891
        __m256 x0 = _mm256_loadu_ps(x);
1892
1893
        for (; i < ny; i++) {
1894
            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
1895
            y += 8;
1896
            const float distance = horizontal_sum(accu);
1897
1898
            if (current_min_distance > distance) {
1899
                current_min_distance = distance;
1900
                current_min_index = i;
1901
            }
1902
        }
1903
    }
1904
1905
    return current_min_index;
1906
}
1907
1908
#elif defined(__AVX2__)
1909
1910
size_t fvec_L2sqr_ny_nearest_D2(
1911
        float* distances_tmp_buffer,
1912
        const float* x,
1913
        const float* y,
1914
8
        size_t ny) {
1915
    // this implementation does not use distances_tmp_buffer.
1916
1917
    // current index being processed
1918
8
    size_t i = 0;
1919
1920
    // min distance and the index of the closest vector so far
1921
8
    float current_min_distance = HUGE_VALF;
1922
8
    size_t current_min_index = 0;
1923
1924
    // process 8 D2-vectors per loop.
1925
8
    const size_t ny8 = ny / 8;
1926
8
    if (ny8 > 0) {
1927
0
        _mm_prefetch((const char*)y, _MM_HINT_T0);
1928
0
        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
1929
1930
        // track min distance and the closest vector independently
1931
        // for each of 8 AVX2 components.
1932
0
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
1933
0
        __m256i min_indices = _mm256_set1_epi32(0);
1934
1935
0
        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1936
0
        const __m256i indices_increment = _mm256_set1_epi32(8);
1937
1938
        // 1 value per register
1939
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
1940
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
1941
1942
0
        for (; i < ny8 * 8; i += 8) {
1943
0
            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
1944
1945
0
            __m256 v0;
1946
0
            __m256 v1;
1947
1948
0
            transpose_8x2(
1949
0
                    _mm256_loadu_ps(y + 0 * 8),
1950
0
                    _mm256_loadu_ps(y + 1 * 8),
1951
0
                    v0,
1952
0
                    v1);
1953
1954
            // compute differences
1955
0
            const __m256 d0 = _mm256_sub_ps(m0, v0);
1956
0
            const __m256 d1 = _mm256_sub_ps(m1, v1);
1957
1958
            // compute squares of differences
1959
0
            __m256 distances = _mm256_mul_ps(d0, d0);
1960
0
            distances = _mm256_fmadd_ps(d1, d1, distances);
1961
1962
            // compare the new distances to the min distances
1963
            // for each of 8 AVX2 components.
1964
0
            __m256 comparison =
1965
0
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
1966
1967
            // update min distances and indices with closest vectors if needed.
1968
0
            min_distances = _mm256_min_ps(distances, min_distances);
1969
0
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
1970
0
                    _mm256_castsi256_ps(current_indices),
1971
0
                    _mm256_castsi256_ps(min_indices),
1972
0
                    comparison));
1973
1974
            // update current indices values. Basically, +8 to each of the
1975
            // 8 AVX2 components.
1976
0
            current_indices =
1977
0
                    _mm256_add_epi32(current_indices, indices_increment);
1978
1979
            // scroll y forward (8 vectors 2 DIM each).
1980
0
            y += 16;
1981
0
        }
1982
1983
        // dump values and find the minimum distance / minimum index
1984
0
        float min_distances_scalar[8];
1985
0
        uint32_t min_indices_scalar[8];
1986
0
        _mm256_storeu_ps(min_distances_scalar, min_distances);
1987
0
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
1988
1989
0
        for (size_t j = 0; j < 8; j++) {
1990
0
            if (current_min_distance > min_distances_scalar[j]) {
1991
0
                current_min_distance = min_distances_scalar[j];
1992
0
                current_min_index = min_indices_scalar[j];
1993
0
            }
1994
0
        }
1995
0
    }
1996
1997
8
    if (i < ny) {
1998
        // process leftovers.
1999
        // the following code is not optimal, but it is rarely invoked.
2000
8
        float x0 = x[0];
2001
8
        float x1 = x[1];
2002
2003
40
        for (; i < ny; i++) {
2004
32
            float sub0 = x0 - y[0];
2005
32
            float sub1 = x1 - y[1];
2006
32
            float distance = sub0 * sub0 + sub1 * sub1;
2007
2008
32
            y += 2;
2009
2010
32
            if (current_min_distance > distance) {
2011
20
                current_min_distance = distance;
2012
20
                current_min_index = i;
2013
20
            }
2014
32
        }
2015
8
    }
2016
2017
8
    return current_min_index;
2018
8
}
2019
2020
size_t fvec_L2sqr_ny_nearest_D4(
2021
        float* distances_tmp_buffer,
2022
        const float* x,
2023
        const float* y,
2024
0
        size_t ny) {
2025
    // this implementation does not use distances_tmp_buffer.
2026
2027
    // current index being processed
2028
0
    size_t i = 0;
2029
2030
    // min distance and the index of the closest vector so far
2031
0
    float current_min_distance = HUGE_VALF;
2032
0
    size_t current_min_index = 0;
2033
2034
    // process 8 D4-vectors per loop.
2035
0
    const size_t ny8 = ny / 8;
2036
2037
0
    if (ny8 > 0) {
2038
        // track min distance and the closest vector independently
2039
        // for each of 8 AVX2 components.
2040
0
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
2041
0
        __m256i min_indices = _mm256_set1_epi32(0);
2042
2043
0
        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
2044
0
        const __m256i indices_increment = _mm256_set1_epi32(8);
2045
2046
        // 1 value per register
2047
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
2048
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
2049
0
        const __m256 m2 = _mm256_set1_ps(x[2]);
2050
0
        const __m256 m3 = _mm256_set1_ps(x[3]);
2051
2052
0
        for (; i < ny8 * 8; i += 8) {
2053
0
            __m256 v0;
2054
0
            __m256 v1;
2055
0
            __m256 v2;
2056
0
            __m256 v3;
2057
2058
0
            transpose_8x4(
2059
0
                    _mm256_loadu_ps(y + 0 * 8),
2060
0
                    _mm256_loadu_ps(y + 1 * 8),
2061
0
                    _mm256_loadu_ps(y + 2 * 8),
2062
0
                    _mm256_loadu_ps(y + 3 * 8),
2063
0
                    v0,
2064
0
                    v1,
2065
0
                    v2,
2066
0
                    v3);
2067
2068
            // compute differences
2069
0
            const __m256 d0 = _mm256_sub_ps(m0, v0);
2070
0
            const __m256 d1 = _mm256_sub_ps(m1, v1);
2071
0
            const __m256 d2 = _mm256_sub_ps(m2, v2);
2072
0
            const __m256 d3 = _mm256_sub_ps(m3, v3);
2073
2074
            // compute squares of differences
2075
0
            __m256 distances = _mm256_mul_ps(d0, d0);
2076
0
            distances = _mm256_fmadd_ps(d1, d1, distances);
2077
0
            distances = _mm256_fmadd_ps(d2, d2, distances);
2078
0
            distances = _mm256_fmadd_ps(d3, d3, distances);
2079
2080
            // compare the new distances to the min distances
2081
            // for each of 8 AVX2 components.
2082
0
            __m256 comparison =
2083
0
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
2084
2085
            // update min distances and indices with closest vectors if needed.
2086
0
            min_distances = _mm256_min_ps(distances, min_distances);
2087
0
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
2088
0
                    _mm256_castsi256_ps(current_indices),
2089
0
                    _mm256_castsi256_ps(min_indices),
2090
0
                    comparison));
2091
2092
            // update current indices values. Basically, +8 to each of the
2093
            // 8 AVX2 components.
2094
0
            current_indices =
2095
0
                    _mm256_add_epi32(current_indices, indices_increment);
2096
2097
            // scroll y forward (8 vectors 4 DIM each).
2098
0
            y += 32;
2099
0
        }
2100
2101
        // dump values and find the minimum distance / minimum index
2102
0
        float min_distances_scalar[8];
2103
0
        uint32_t min_indices_scalar[8];
2104
0
        _mm256_storeu_ps(min_distances_scalar, min_distances);
2105
0
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
2106
2107
0
        for (size_t j = 0; j < 8; j++) {
2108
0
            if (current_min_distance > min_distances_scalar[j]) {
2109
0
                current_min_distance = min_distances_scalar[j];
2110
0
                current_min_index = min_indices_scalar[j];
2111
0
            }
2112
0
        }
2113
0
    }
2114
2115
0
    if (i < ny) {
2116
        // process leftovers
2117
0
        __m128 x0 = _mm_loadu_ps(x);
2118
2119
0
        for (; i < ny; i++) {
2120
0
            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
2121
0
            y += 4;
2122
0
            const float distance = horizontal_sum(accu);
2123
2124
0
            if (current_min_distance > distance) {
2125
0
                current_min_distance = distance;
2126
0
                current_min_index = i;
2127
0
            }
2128
0
        }
2129
0
    }
2130
2131
0
    return current_min_index;
2132
0
}
2133
2134
size_t fvec_L2sqr_ny_nearest_D8(
2135
        float* distances_tmp_buffer,
2136
        const float* x,
2137
        const float* y,
2138
0
        size_t ny) {
2139
    // this implementation does not use distances_tmp_buffer.
2140
2141
    // current index being processed
2142
0
    size_t i = 0;
2143
2144
    // min distance and the index of the closest vector so far
2145
0
    float current_min_distance = HUGE_VALF;
2146
0
    size_t current_min_index = 0;
2147
2148
    // process 8 D8-vectors per loop.
2149
0
    const size_t ny8 = ny / 8;
2150
0
    if (ny8 > 0) {
2151
        // track min distance and the closest vector independently
2152
        // for each of 8 AVX2 components.
2153
0
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
2154
0
        __m256i min_indices = _mm256_set1_epi32(0);
2155
2156
0
        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
2157
0
        const __m256i indices_increment = _mm256_set1_epi32(8);
2158
2159
        // 1 value per register
2160
0
        const __m256 m0 = _mm256_set1_ps(x[0]);
2161
0
        const __m256 m1 = _mm256_set1_ps(x[1]);
2162
0
        const __m256 m2 = _mm256_set1_ps(x[2]);
2163
0
        const __m256 m3 = _mm256_set1_ps(x[3]);
2164
2165
0
        const __m256 m4 = _mm256_set1_ps(x[4]);
2166
0
        const __m256 m5 = _mm256_set1_ps(x[5]);
2167
0
        const __m256 m6 = _mm256_set1_ps(x[6]);
2168
0
        const __m256 m7 = _mm256_set1_ps(x[7]);
2169
2170
0
        for (; i < ny8 * 8; i += 8) {
2171
0
            __m256 v0;
2172
0
            __m256 v1;
2173
0
            __m256 v2;
2174
0
            __m256 v3;
2175
0
            __m256 v4;
2176
0
            __m256 v5;
2177
0
            __m256 v6;
2178
0
            __m256 v7;
2179
2180
0
            transpose_8x8(
2181
0
                    _mm256_loadu_ps(y + 0 * 8),
2182
0
                    _mm256_loadu_ps(y + 1 * 8),
2183
0
                    _mm256_loadu_ps(y + 2 * 8),
2184
0
                    _mm256_loadu_ps(y + 3 * 8),
2185
0
                    _mm256_loadu_ps(y + 4 * 8),
2186
0
                    _mm256_loadu_ps(y + 5 * 8),
2187
0
                    _mm256_loadu_ps(y + 6 * 8),
2188
0
                    _mm256_loadu_ps(y + 7 * 8),
2189
0
                    v0,
2190
0
                    v1,
2191
0
                    v2,
2192
0
                    v3,
2193
0
                    v4,
2194
0
                    v5,
2195
0
                    v6,
2196
0
                    v7);
2197
2198
            // compute differences
2199
0
            const __m256 d0 = _mm256_sub_ps(m0, v0);
2200
0
            const __m256 d1 = _mm256_sub_ps(m1, v1);
2201
0
            const __m256 d2 = _mm256_sub_ps(m2, v2);
2202
0
            const __m256 d3 = _mm256_sub_ps(m3, v3);
2203
0
            const __m256 d4 = _mm256_sub_ps(m4, v4);
2204
0
            const __m256 d5 = _mm256_sub_ps(m5, v5);
2205
0
            const __m256 d6 = _mm256_sub_ps(m6, v6);
2206
0
            const __m256 d7 = _mm256_sub_ps(m7, v7);
2207
2208
            // compute squares of differences
2209
0
            __m256 distances = _mm256_mul_ps(d0, d0);
2210
0
            distances = _mm256_fmadd_ps(d1, d1, distances);
2211
0
            distances = _mm256_fmadd_ps(d2, d2, distances);
2212
0
            distances = _mm256_fmadd_ps(d3, d3, distances);
2213
0
            distances = _mm256_fmadd_ps(d4, d4, distances);
2214
0
            distances = _mm256_fmadd_ps(d5, d5, distances);
2215
0
            distances = _mm256_fmadd_ps(d6, d6, distances);
2216
0
            distances = _mm256_fmadd_ps(d7, d7, distances);
2217
2218
            // compare the new distances to the min distances
2219
            // for each of 8 AVX2 components.
2220
0
            __m256 comparison =
2221
0
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
2222
2223
            // update min distances and indices with closest vectors if needed.
2224
0
            min_distances = _mm256_min_ps(distances, min_distances);
2225
0
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
2226
0
                    _mm256_castsi256_ps(current_indices),
2227
0
                    _mm256_castsi256_ps(min_indices),
2228
0
                    comparison));
2229
2230
            // update current indices values. Basically, +8 to each of the
2231
            // 8 AVX2 components.
2232
0
            current_indices =
2233
0
                    _mm256_add_epi32(current_indices, indices_increment);
2234
2235
            // scroll y forward (8 vectors 8 DIM each).
2236
0
            y += 64;
2237
0
        }
2238
2239
        // dump values and find the minimum distance / minimum index
2240
0
        float min_distances_scalar[8];
2241
0
        uint32_t min_indices_scalar[8];
2242
0
        _mm256_storeu_ps(min_distances_scalar, min_distances);
2243
0
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
2244
2245
0
        for (size_t j = 0; j < 8; j++) {
2246
0
            if (current_min_distance > min_distances_scalar[j]) {
2247
0
                current_min_distance = min_distances_scalar[j];
2248
0
                current_min_index = min_indices_scalar[j];
2249
0
            }
2250
0
        }
2251
0
    }
2252
2253
0
    if (i < ny) {
2254
        // process leftovers
2255
0
        __m256 x0 = _mm256_loadu_ps(x);
2256
2257
0
        for (; i < ny; i++) {
2258
0
            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
2259
0
            y += 8;
2260
0
            const float distance = horizontal_sum(accu);
2261
2262
0
            if (current_min_distance > distance) {
2263
0
                current_min_distance = distance;
2264
0
                current_min_index = i;
2265
0
            }
2266
0
        }
2267
0
    }
2268
2269
0
    return current_min_index;
2270
0
}
2271
2272
#else
2273
size_t fvec_L2sqr_ny_nearest_D2(
2274
        float* distances_tmp_buffer,
2275
        const float* x,
2276
        const float* y,
2277
        size_t ny) {
2278
    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny);
2279
}
2280
2281
size_t fvec_L2sqr_ny_nearest_D4(
2282
        float* distances_tmp_buffer,
2283
        const float* x,
2284
        const float* y,
2285
        size_t ny) {
2286
    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
2287
}
2288
2289
size_t fvec_L2sqr_ny_nearest_D8(
2290
        float* distances_tmp_buffer,
2291
        const float* x,
2292
        const float* y,
2293
        size_t ny) {
2294
    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny);
2295
}
2296
#endif
2297
2298
size_t fvec_L2sqr_ny_nearest(
2299
        float* distances_tmp_buffer,
2300
        const float* x,
2301
        const float* y,
2302
        size_t d,
2303
8
        size_t ny) {
2304
    // optimized for a few special cases
2305
8
#define DISPATCH(dval) \
2306
8
    case dval:         \
2307
8
        return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
2308
2309
8
    switch (d) {
2310
8
        DISPATCH(2)
2311
0
        DISPATCH(4)
2312
0
        DISPATCH(8)
2313
0
        default:
2314
0
            return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
2315
8
    }
2316
8
#undef DISPATCH
2317
8
}
2318
2319
#if defined(__AVX512F__)
2320
2321
// AVX-512 kernel: returns the index of the candidate nearest to the
// DIM-dimensional query `x`. Candidates are stored transposed: dimension j
// of candidate i is at y[j * d_offset + i]. y_sqlen[i] holds the squared
// norm of candidate i. `distances_tmp_buffer` is not used.
//
// Candidates are ranked by y_sqlen[i] - 2 * <x, y_i>. The ||x||^2 term is
// the same for every candidate, so dropping it does not change the argmin.
template <size_t DIM>
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    float best_dist = HUGE_VALF;
    size_t best_idx = 0;
    size_t i = 0;

    // 16 candidates per iteration; the remainder is handled by the scalar tail.
    const size_t simd_end = (ny / 16) * 16;
    if (simd_end != 0) {
        // twice_x[j] broadcasts 2 * x[j] into every lane.
        __m512 twice_x[DIM];
        for (size_t j = 0; j < DIM; j++) {
            const __m512 xj = _mm512_set1_ps(x[j]);
            twice_x[j] = _mm512_add_ps(xj, xj);
        }

        // Per-lane running minimum and the candidate index that produced it.
        __m512 lane_best_dist = _mm512_set1_ps(HUGE_VALF);
        __m512i lane_best_idx = _mm512_set1_epi32(0);
        __m512i lane_idx = _mm512_setr_epi32(
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
        const __m512i step = _mm512_set1_epi32(16);

        while (i < simd_end) {
            // <2x, y> for 16 consecutive candidates.
            __m512 dot = _mm512_mul_ps(twice_x[0], _mm512_loadu_ps(y));
            for (size_t j = 1; j < DIM; j++) {
                dot = _mm512_fmadd_ps(
                        twice_x[j], _mm512_loadu_ps(y + j * d_offset), dot);
            }
            const __m512 dist = _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dot);

            // keep_old is set where the lane's current minimum is strictly
            // smaller; every other lane takes the new candidate.
            const __mmask16 keep_old =
                    _mm512_cmp_ps_mask(lane_best_dist, dist, _CMP_LT_OS);
            lane_best_dist =
                    _mm512_mask_blend_ps(keep_old, dist, lane_best_dist);
            lane_best_idx = _mm512_castps_si512(_mm512_mask_blend_ps(
                    keep_old,
                    _mm512_castsi512_ps(lane_idx),
                    _mm512_castsi512_ps(lane_best_idx)));

            lane_idx = _mm512_add_epi32(lane_idx, step);
            y += 16;
            y_sqlen += 16;
            i += 16;
        }

        // Reduce the 16 lane winners to a single best candidate.
        float lane_dist_out[16];
        uint32_t lane_idx_out[16];
        _mm512_storeu_ps(lane_dist_out, lane_best_dist);
        _mm512_storeu_si512(
                reinterpret_cast<__m512i*>(lane_idx_out), lane_best_idx);
        for (size_t lane = 0; lane < 16; lane++) {
            if (lane_dist_out[lane] < best_dist) {
                best_dist = lane_dist_out[lane];
                best_idx = lane_idx_out[lane];
            }
        }
    }

    // Scalar tail for the last ny % 16 candidates.
    for (; i < ny; i++, y++, y_sqlen++) {
        float dot = 0;
        for (size_t j = 0; j < DIM; j++) {
            dot += x[j] * y[j * d_offset];
        }

        const float dist = y_sqlen[0] - 2 * dot;
        if (dist < best_dist) {
            best_dist = dist;
            best_idx = i;
        }
    }

    return best_idx;
}
2433
2434
#elif defined(__AVX2__)
2435
2436
// AVX2 kernel: returns the index of the candidate nearest to the
// DIM-dimensional query `x`. Candidates are stored transposed: dimension j
// of candidate i is at y[j * d_offset + i]. y_sqlen[i] holds the squared
// norm of candidate i. `distances_tmp_buffer` is not used.
//
// Candidates are ranked by y_sqlen[i] - 2 * <x, y_i>. The ||x||^2 term is
// the same for every candidate, so dropping it does not change the argmin.
template <size_t DIM>
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    float best_dist = HUGE_VALF;
    size_t best_idx = 0;
    size_t i = 0;

    // 8 candidates per iteration; the remainder is handled by the scalar tail.
    const size_t simd_end = (ny / 8) * 8;
    if (simd_end != 0) {
        // twice_x[j] broadcasts 2 * x[j] into every lane.
        __m256 twice_x[DIM];
        for (size_t j = 0; j < DIM; j++) {
            const __m256 xj = _mm256_set1_ps(x[j]);
            twice_x[j] = _mm256_add_ps(xj, xj);
        }

        // Per-lane running minimum and the candidate index that produced it.
        __m256 lane_best_dist = _mm256_set1_ps(HUGE_VALF);
        __m256i lane_best_idx = _mm256_set1_epi32(0);
        __m256i lane_idx = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i step = _mm256_set1_epi32(8);

        while (i < simd_end) {
            // <2x, y> for 8 consecutive candidates: dimension j of all 8
            // candidates is one contiguous load at y + j * d_offset.
            __m256 dot = _mm256_mul_ps(twice_x[0], _mm256_loadu_ps(y));
            for (size_t j = 1; j < DIM; j++) {
                dot = _mm256_fmadd_ps(
                        twice_x[j], _mm256_loadu_ps(y + j * d_offset), dot);
            }
            const __m256 dist = _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dot);

            // keep_old is set where the lane's current minimum is strictly
            // smaller; every other lane takes the new candidate.
            const __m256 keep_old =
                    _mm256_cmp_ps(lane_best_dist, dist, _CMP_LT_OS);
            lane_best_dist = _mm256_blendv_ps(dist, lane_best_dist, keep_old);
            lane_best_idx = _mm256_castps_si256(_mm256_blendv_ps(
                    _mm256_castsi256_ps(lane_idx),
                    _mm256_castsi256_ps(lane_best_idx),
                    keep_old));

            lane_idx = _mm256_add_epi32(lane_idx, step);
            y += 8;
            y_sqlen += 8;
            i += 8;
        }

        // Reduce the 8 lane winners to a single best candidate.
        float lane_dist_out[8];
        uint32_t lane_idx_out[8];
        _mm256_storeu_ps(lane_dist_out, lane_best_dist);
        _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(lane_idx_out), lane_best_idx);
        for (size_t lane = 0; lane < 8; lane++) {
            if (lane_dist_out[lane] < best_dist) {
                best_dist = lane_dist_out[lane];
                best_idx = lane_idx_out[lane];
            }
        }
    }

    // Scalar tail for the last ny % 8 candidates.
    for (; i < ny; i++, y++, y_sqlen++) {
        float dot = 0;
        for (size_t j = 0; j < DIM; j++) {
            dot += x[j] * y[j * d_offset];
        }

        const float dist = y_sqlen[0] - 2 * dot;
        if (dist < best_dist) {
            best_dist = dist;
            best_idx = i;
        }
    }

    return best_idx;
}
Unexecuted instantiation: _ZN5faiss36fvec_L2sqr_ny_nearest_y_transposed_DILm1EEEmPfPKfS3_S3_mm
Unexecuted instantiation: _ZN5faiss36fvec_L2sqr_ny_nearest_y_transposed_DILm2EEEmPfPKfS3_S3_mm
Unexecuted instantiation: _ZN5faiss36fvec_L2sqr_ny_nearest_y_transposed_DILm4EEEmPfPKfS3_S3_mm
Unexecuted instantiation: _ZN5faiss36fvec_L2sqr_ny_nearest_y_transposed_DILm8EEEmPfPKfS3_S3_mm
2551
2552
#endif
2553
2554
size_t fvec_L2sqr_ny_nearest_y_transposed(
2555
        float* distances_tmp_buffer,
2556
        const float* x,
2557
        const float* y,
2558
        const float* y_sqlen,
2559
        size_t d,
2560
        size_t d_offset,
2561
0
        size_t ny) {
2562
    // optimized for a few special cases
2563
0
#ifdef __AVX2__
2564
0
#define DISPATCH(dval)                                     \
2565
0
    case dval:                                             \
2566
0
        return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
2567
0
                distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
2568
2569
0
    switch (d) {
2570
0
        DISPATCH(1)
2571
0
        DISPATCH(2)
2572
0
        DISPATCH(4)
2573
0
        DISPATCH(8)
2574
0
        default:
2575
0
            return fvec_L2sqr_ny_nearest_y_transposed_ref(
2576
0
                    distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
2577
0
    }
2578
0
#undef DISPATCH
2579
#else
2580
    // non-AVX2 case
2581
    return fvec_L2sqr_ny_nearest_y_transposed_ref(
2582
            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
2583
#endif
2584
0
}
2585
2586
#endif
2587
2588
#ifdef USE_AVX
2589
2590
3
float fvec_L1(const float* x, const float* y, size_t d) {
2591
3
    __m256 msum1 = _mm256_setzero_ps();
2592
    // signmask used for absolute value
2593
3
    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
2594
2595
3
    while (d >= 8) {
2596
0
        __m256 mx = _mm256_loadu_ps(x);
2597
0
        x += 8;
2598
0
        __m256 my = _mm256_loadu_ps(y);
2599
0
        y += 8;
2600
        // subtract
2601
0
        const __m256 a_m_b = _mm256_sub_ps(mx, my);
2602
        // find sum of absolute value of distances (manhattan distance)
2603
0
        msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
2604
0
        d -= 8;
2605
0
    }
2606
2607
3
    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
2608
3
    msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
2609
3
    __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
2610
2611
3
    if (d >= 4) {
2612
0
        __m128 mx = _mm_loadu_ps(x);
2613
0
        x += 4;
2614
0
        __m128 my = _mm_loadu_ps(y);
2615
0
        y += 4;
2616
0
        const __m128 a_m_b = _mm_sub_ps(mx, my);
2617
0
        msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
2618
0
        d -= 4;
2619
0
    }
2620
2621
3
    if (d > 0) {
2622
2
        __m128 mx = masked_read(d, x);
2623
2
        __m128 my = masked_read(d, y);
2624
2
        __m128 a_m_b = _mm_sub_ps(mx, my);
2625
2
        msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
2626
2
    }
2627
2628
3
    msum2 = _mm_hadd_ps(msum2, msum2);
2629
3
    msum2 = _mm_hadd_ps(msum2, msum2);
2630
3
    return _mm_cvtss_f32(msum2);
2631
3
}
2632
2633
0
float fvec_Linf(const float* x, const float* y, size_t d) {
2634
0
    __m256 msum1 = _mm256_setzero_ps();
2635
    // signmask used for absolute value
2636
0
    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
2637
2638
0
    while (d >= 8) {
2639
0
        __m256 mx = _mm256_loadu_ps(x);
2640
0
        x += 8;
2641
0
        __m256 my = _mm256_loadu_ps(y);
2642
0
        y += 8;
2643
        // subtract
2644
0
        const __m256 a_m_b = _mm256_sub_ps(mx, my);
2645
        // find max of absolute value of distances (chebyshev distance)
2646
0
        msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
2647
0
        d -= 8;
2648
0
    }
2649
2650
0
    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
2651
0
    msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0));
2652
0
    __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
2653
2654
0
    if (d >= 4) {
2655
0
        __m128 mx = _mm_loadu_ps(x);
2656
0
        x += 4;
2657
0
        __m128 my = _mm_loadu_ps(y);
2658
0
        y += 4;
2659
0
        const __m128 a_m_b = _mm_sub_ps(mx, my);
2660
0
        msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
2661
0
        d -= 4;
2662
0
    }
2663
2664
0
    if (d > 0) {
2665
0
        __m128 mx = masked_read(d, x);
2666
0
        __m128 my = masked_read(d, y);
2667
0
        __m128 a_m_b = _mm_sub_ps(mx, my);
2668
0
        msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
2669
0
    }
2670
2671
0
    msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
2672
0
    msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1));
2673
0
    return _mm_cvtss_f32(msum2);
2674
0
}
2675
2676
#elif defined(__SSE3__) // But not AVX
2677
2678
float fvec_L1(const float* x, const float* y, size_t d) {
2679
    return fvec_L1_ref(x, y, d);
2680
}
2681
2682
float fvec_Linf(const float* x, const float* y, size_t d) {
2683
    return fvec_Linf_ref(x, y, d);
2684
}
2685
2686
#elif defined(__ARM_FEATURE_SVE)
2687
2688
// SVE inner-product element operation.
// op():    lane-wise product of the active lanes.
// merge(): multiply-accumulate, acc + lhs * rhs, on the active lanes.
// Both use the "_x" (don't-care) form for inactive lanes.
struct ElementOpIP {
    static svfloat32_t op(svbool_t active, svfloat32_t lhs, svfloat32_t rhs) {
        return svmul_f32_x(active, lhs, rhs);
    }

    static svfloat32_t merge(
            svbool_t active,
            svfloat32_t acc,
            svfloat32_t lhs,
            svfloat32_t rhs) {
        return svmla_f32_x(active, acc, lhs, rhs);
    }
};
2700
2701
// For a 1-D query x, writes dis[i] = ElementOp::op(x[0], y[i]) for all
// i < ny (SVE version).
// The main loop covers four vector-lengths per iteration. The final partial
// group (at most four vector-lengths) uses svwhilelt predicates, so lanes
// at or past ny are neither loaded nor stored.
template <typename ElementOp>
void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) {
    const size_t vl = svcntw();
    const size_t vl2 = vl * 2;
    const size_t vl3 = vl * 3;
    const size_t group = vl * 4;
    const svbool_t all = svptrue_b32();
    const svfloat32_t xv = svdup_n_f32(x[0]);

    size_t i = 0;
    for (; i + group < ny; i += group) {
        svfloat32_t a0 = svld1_f32(all, y);
        svfloat32_t a1 = svld1_f32(all, y + vl);
        svfloat32_t a2 = svld1_f32(all, y + vl2);
        svfloat32_t a3 = svld1_f32(all, y + vl3);
        a0 = ElementOp::op(all, xv, a0);
        a1 = ElementOp::op(all, xv, a1);
        a2 = ElementOp::op(all, xv, a2);
        a3 = ElementOp::op(all, xv, a3);
        svst1_f32(all, dis, a0);
        svst1_f32(all, dis + vl, a1);
        svst1_f32(all, dis + vl2, a2);
        svst1_f32(all, dis + vl3, a3);
        y += group;
        dis += group;
    }

    // Tail: each predicate enables only lanes whose index is below ny.
    const svbool_t p0 = svwhilelt_b32_u64(i, ny);
    const svbool_t p1 = svwhilelt_b32_u64(i + vl, ny);
    const svbool_t p2 = svwhilelt_b32_u64(i + vl2, ny);
    const svbool_t p3 = svwhilelt_b32_u64(i + vl3, ny);
    svfloat32_t a0 = svld1_f32(p0, y);
    svfloat32_t a1 = svld1_f32(p1, y + vl);
    svfloat32_t a2 = svld1_f32(p2, y + vl2);
    svfloat32_t a3 = svld1_f32(p3, y + vl3);
    a0 = ElementOp::op(p0, xv, a0);
    a1 = ElementOp::op(p1, xv, a1);
    a2 = ElementOp::op(p2, xv, a2);
    a3 = ElementOp::op(p3, xv, a3);
    svst1_f32(p0, dis, a0);
    svst1_f32(p1, dis + vl, a1);
    svst1_f32(p2, dis + vl2, a2);
    svst1_f32(p3, dis + vl3, a3);
}
2743
2744
// Dimension-2 kernel: for j in [0, ny),
//   dis[j] = merge(op(x[0], y[2j]), x[1], y[2j + 1])
// (for ElementOpIP: x[0]*y[2j] + x[1]*y[2j+1]).
// svld2_f32 de-interleaves `lanes` consecutive (y0, y1) pairs into two
// registers: component 0 of each pair and component 1 of each pair.
template <typename ElementOp>
void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes2 = lanes * 2;
    const size_t lanes4 = lanes * 4;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svdup_n_f32(x[0]);
    const svfloat32_t x1 = svdup_n_f32(x[1]);
    size_t i = 0;
    // Main loop: 2 * lanes outputs (= 4 * lanes input floats) per iteration.
    // Exits with at most 2 * lanes outputs left for the tail.
    for (; i + lanes2 < ny; i += lanes2) {
        const svfloat32x2_t y0 = svld2_f32(pg, y);
        // Second group of `lanes` pairs starts 2 * lanes floats further on.
        const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2);
        svfloat32_t y00 = svget2_f32(y0, 0);
        const svfloat32_t y01 = svget2_f32(y0, 1);
        svfloat32_t y10 = svget2_f32(y1, 0);
        const svfloat32_t y11 = svget2_f32(y1, 1);
        y00 = ElementOp::op(pg, x0, y00);
        y10 = ElementOp::op(pg, x0, y10);
        y00 = ElementOp::merge(pg, y00, x1, y01);
        y10 = ElementOp::merge(pg, y10, x1, y11);
        svst1_f32(pg, dis, y00);
        svst1_f32(pg, dis + lanes, y10);
        y += lanes4;
        dis += lanes2;
    }
    // Tail: the predicates count pairs (one lane per output), so pg0 covers
    // outputs [i, i + lanes) and pg1 covers [i + lanes, i + 2 * lanes).
    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
    const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
    const svfloat32x2_t y0 = svld2_f32(pg0, y);
    const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2);
    svfloat32_t y00 = svget2_f32(y0, 0);
    const svfloat32_t y01 = svget2_f32(y0, 1);
    svfloat32_t y10 = svget2_f32(y1, 0);
    const svfloat32_t y11 = svget2_f32(y1, 1);
    y00 = ElementOp::op(pg0, x0, y00);
    y10 = ElementOp::op(pg1, x0, y10);
    y00 = ElementOp::merge(pg0, y00, x1, y01);
    y10 = ElementOp::merge(pg1, y10, x1, y11);
    svst1_f32(pg0, dis, y00);
    svst1_f32(pg1, dis + lanes, y10);
}
2784
2785
// Dimension-4 kernel: for j in [0, ny),
//   dis[j] = merge(op(x0, y[4j]), x1, y[4j+1])
//          + merge(op(x2, y[4j+2]), x3, y[4j+3])
// (for ElementOpIP: (x0*y0 + x1*y1) + (x2*y2 + x3*y3)).
// svld4_f32 de-interleaves `lanes` consecutive 4-float vectors, one
// component per register; one lane corresponds to one output.
template <typename ElementOp>
void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes4 = lanes * 4;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svdup_n_f32(x[0]);
    const svfloat32_t x1 = svdup_n_f32(x[1]);
    const svfloat32_t x2 = svdup_n_f32(x[2]);
    const svfloat32_t x3 = svdup_n_f32(x[3]);
    size_t i = 0;
    // Main loop: `lanes` outputs (4 * lanes input floats) per iteration;
    // exits with at most `lanes` outputs left for the tail.
    for (; i + lanes < ny; i += lanes) {
        const svfloat32x4_t y0 = svld4_f32(pg, y);
        svfloat32_t y00 = svget4_f32(y0, 0);
        const svfloat32_t y01 = svget4_f32(y0, 1);
        svfloat32_t y02 = svget4_f32(y0, 2);
        const svfloat32_t y03 = svget4_f32(y0, 3);
        // Two independent accumulation chains, summed at the end.
        y00 = ElementOp::op(pg, x0, y00);
        y02 = ElementOp::op(pg, x2, y02);
        y00 = ElementOp::merge(pg, y00, x1, y01);
        y02 = ElementOp::merge(pg, y02, x3, y03);
        y00 = svadd_f32_x(pg, y00, y02);
        svst1_f32(pg, dis, y00);
        y += lanes4;
        dis += lanes;
    }
    // Tail: lane k active iff i + k < ny (one lane per 4-float vector).
    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
    const svfloat32x4_t y0 = svld4_f32(pg0, y);
    svfloat32_t y00 = svget4_f32(y0, 0);
    const svfloat32_t y01 = svget4_f32(y0, 1);
    svfloat32_t y02 = svget4_f32(y0, 2);
    const svfloat32_t y03 = svget4_f32(y0, 3);
    y00 = ElementOp::op(pg0, x0, y00);
    y02 = ElementOp::op(pg0, x2, y02);
    y00 = ElementOp::merge(pg0, y00, x1, y01);
    y02 = ElementOp::merge(pg0, y02, x3, y03);
    y00 = svadd_f32_x(pg0, y00, y02);
    svst1_f32(pg0, dis, y00);
}
2823
2824
// Dimension-8 kernel: for j in [0, ny), dis[j] combines the 8 components of
// x with the 8 components of y[8j .. 8j+7] as four op/merge pairs summed
// pairwise (for ElementOpIP: the 8-term inner product).
//
// Each svld4_f32 reads 4 * lanes floats (= lanes / 2 vectors of length 8),
// de-interleaving with stride 4: component register k of `ya` holds
// elements k and k + 4 of each of those vectors, alternating. svuzp1 /
// svuzp2 on (ya_k, yb_k) then separate the even (component k) and odd
// (component k + 4) positions, giving one register per component with one
// lane per output vector.
template <typename ElementOp>
void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes4 = lanes * 4;
    const size_t lanes8 = lanes * 8;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svdup_n_f32(x[0]);
    const svfloat32_t x1 = svdup_n_f32(x[1]);
    const svfloat32_t x2 = svdup_n_f32(x[2]);
    const svfloat32_t x3 = svdup_n_f32(x[3]);
    const svfloat32_t x4 = svdup_n_f32(x[4]);
    const svfloat32_t x5 = svdup_n_f32(x[5]);
    const svfloat32_t x6 = svdup_n_f32(x[6]);
    const svfloat32_t x7 = svdup_n_f32(x[7]);
    size_t i = 0;
    // Main loop: `lanes` outputs (8 * lanes input floats) per iteration;
    // exits with at most `lanes` outputs left for the tail.
    for (; i + lanes < ny; i += lanes) {
        const svfloat32x4_t ya = svld4_f32(pg, y);
        const svfloat32x4_t yb = svld4_f32(pg, y + lanes4);
        const svfloat32_t ya0 = svget4_f32(ya, 0);
        const svfloat32_t ya1 = svget4_f32(ya, 1);
        const svfloat32_t ya2 = svget4_f32(ya, 2);
        const svfloat32_t ya3 = svget4_f32(ya, 3);
        const svfloat32_t yb0 = svget4_f32(yb, 0);
        const svfloat32_t yb1 = svget4_f32(yb, 1);
        const svfloat32_t yb2 = svget4_f32(yb, 2);
        const svfloat32_t yb3 = svget4_f32(yb, 3);
        // Even positions -> components 0..3, odd positions -> 4..7.
        svfloat32_t y0 = svuzp1(ya0, yb0);
        const svfloat32_t y1 = svuzp1(ya1, yb1);
        svfloat32_t y2 = svuzp1(ya2, yb2);
        const svfloat32_t y3 = svuzp1(ya3, yb3);
        svfloat32_t y4 = svuzp2(ya0, yb0);
        const svfloat32_t y5 = svuzp2(ya1, yb1);
        svfloat32_t y6 = svuzp2(ya2, yb2);
        const svfloat32_t y7 = svuzp2(ya3, yb3);
        y0 = ElementOp::op(pg, x0, y0);
        y2 = ElementOp::op(pg, x2, y2);
        y4 = ElementOp::op(pg, x4, y4);
        y6 = ElementOp::op(pg, x6, y6);
        y0 = ElementOp::merge(pg, y0, x1, y1);
        y2 = ElementOp::merge(pg, y2, x3, y3);
        y4 = ElementOp::merge(pg, y4, x5, y5);
        y6 = ElementOp::merge(pg, y6, x7, y7);
        y0 = svadd_f32_x(pg, y0, y2);
        y4 = svadd_f32_x(pg, y4, y6);
        y0 = svadd_f32_x(pg, y0, y4);
        svst1_f32(pg, dis, y0);
        y += lanes8;
        dis += lanes;
    }
    // Tail. pg0 selects the remaining outputs (one lane per output).
    // The load predicates count 4-float groups instead: each output vector is
    // two groups, hence the "* 2"; pga covers the first `lanes` groups and
    // pgb the next `lanes`.
    const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
    const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2);
    const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2);
    const svfloat32x4_t ya = svld4_f32(pga, y);
    const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4);
    const svfloat32_t ya0 = svget4_f32(ya, 0);
    const svfloat32_t ya1 = svget4_f32(ya, 1);
    const svfloat32_t ya2 = svget4_f32(ya, 2);
    const svfloat32_t ya3 = svget4_f32(ya, 3);
    const svfloat32_t yb0 = svget4_f32(yb, 0);
    const svfloat32_t yb1 = svget4_f32(yb, 1);
    const svfloat32_t yb2 = svget4_f32(yb, 2);
    const svfloat32_t yb3 = svget4_f32(yb, 3);
    svfloat32_t y0 = svuzp1(ya0, yb0);
    const svfloat32_t y1 = svuzp1(ya1, yb1);
    svfloat32_t y2 = svuzp1(ya2, yb2);
    const svfloat32_t y3 = svuzp1(ya3, yb3);
    svfloat32_t y4 = svuzp2(ya0, yb0);
    const svfloat32_t y5 = svuzp2(ya1, yb1);
    svfloat32_t y6 = svuzp2(ya2, yb2);
    const svfloat32_t y7 = svuzp2(ya3, yb3);
    y0 = ElementOp::op(pg0, x0, y0);
    y2 = ElementOp::op(pg0, x2, y2);
    y4 = ElementOp::op(pg0, x4, y4);
    y6 = ElementOp::op(pg0, x6, y6);
    y0 = ElementOp::merge(pg0, y0, x1, y1);
    y2 = ElementOp::merge(pg0, y2, x3, y3);
    y4 = ElementOp::merge(pg0, y4, x5, y5);
    y6 = ElementOp::merge(pg0, y6, x7, y7);
    y0 = svadd_f32_x(pg0, y0, y2);
    y4 = svadd_f32_x(pg0, y4, y6);
    y0 = svadd_f32_x(pg0, y0, y4);
    svst1_f32(pg0, dis, y0);
    // NOTE(review): the two increments below modify by-value parameters
    // just before returning and have no effect; candidates for removal.
    y += lanes8;
    dis += lanes;
}
2909
2910
// Kernel for d == svcntw(): every vector fills exactly one SVE register, so
// an all-true predicate is used throughout. dis[j] is the horizontal sum
// (svaddv) of ElementOp::op(x, y_j).
template <typename ElementOp>
void fvec_op_ny_sve_lanes1(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes2 = lanes * 2;
    const size_t lanes3 = lanes * 3;
    const size_t lanes4 = lanes * 4;
    const svbool_t pg = svptrue_b32();
    const svfloat32_t x0 = svld1_f32(pg, x);
    size_t i = 0;
    // Four database vectors per iteration while at least four remain.
    for (; i + 3 < ny; i += 4) {
        svfloat32_t y0 = svld1_f32(pg, y);
        svfloat32_t y1 = svld1_f32(pg, y + lanes);
        svfloat32_t y2 = svld1_f32(pg, y + lanes2);
        svfloat32_t y3 = svld1_f32(pg, y + lanes3);
        y += lanes4;
        y0 = ElementOp::op(pg, x0, y0);
        y1 = ElementOp::op(pg, x0, y1);
        y2 = ElementOp::op(pg, x0, y2);
        y3 = ElementOp::op(pg, x0, y3);
        dis[i] = svaddv_f32(pg, y0);
        dis[i + 1] = svaddv_f32(pg, y1);
        dis[i + 2] = svaddv_f32(pg, y2);
        dis[i + 3] = svaddv_f32(pg, y3);
    }
    // Remaining 0..3 vectors, one at a time.
    for (; i < ny; ++i) {
        svfloat32_t y0 = svld1_f32(pg, y);
        y += lanes;
        y0 = ElementOp::op(pg, x0, y0);
        dis[i] = svaddv_f32(pg, y0);
    }
}
2945
2946
// Kernel for d == 2 * svcntw(): each vector spans exactly two full SVE
// registers. dis[j] is the horizontal sum of merge(op(x_lo, y_lo), x_hi, y_hi).
template <typename ElementOp>
void fvec_op_ny_sve_lanes2(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t lanes = svcntw();
    const size_t lanes2 = lanes * 2;
    const size_t lanes3 = lanes * 3;
    const size_t lanes4 = lanes * 4;
    const svbool_t pg = svptrue_b32();
    // Low and high halves of the query vector.
    const svfloat32_t x0 = svld1_f32(pg, x);
    const svfloat32_t x1 = svld1_f32(pg, x + lanes);
    size_t i = 0;
    // Two database vectors (four registers) per iteration.
    for (; i + 1 < ny; i += 2) {
        svfloat32_t y00 = svld1_f32(pg, y);
        const svfloat32_t y01 = svld1_f32(pg, y + lanes);
        svfloat32_t y10 = svld1_f32(pg, y + lanes2);
        const svfloat32_t y11 = svld1_f32(pg, y + lanes3);
        y += lanes4;
        y00 = ElementOp::op(pg, x0, y00);
        y10 = ElementOp::op(pg, x0, y10);
        y00 = ElementOp::merge(pg, y00, x1, y01);
        y10 = ElementOp::merge(pg, y10, x1, y11);
        dis[i] = svaddv_f32(pg, y00);
        dis[i + 1] = svaddv_f32(pg, y10);
    }
    // Odd ny: one vector left over.
    if (i < ny) {
        svfloat32_t y0 = svld1_f32(pg, y);
        const svfloat32_t y1 = svld1_f32(pg, y + lanes);
        y0 = ElementOp::op(pg, x0, y0);
        y0 = ElementOp::merge(pg, y0, x1, y1);
        dis[i] = svaddv_f32(pg, y0);
    }
}
2981
2982
// Kernel for d == 3 * svcntw(): every vector is exactly three full SVE
// registers, so no predication is required. For each of the ny database
// vectors, dis[row] = svaddv(merge(merge(op(q0, v0), q1, v1), q2, v2)).
template <typename ElementOp>
void fvec_op_ny_sve_lanes3(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t vl = svcntw();
    const size_t stride = vl * 3;
    const svbool_t all = svptrue_b32();

    // The three register-sized slices of the query.
    const svfloat32_t q0 = svld1_f32(all, x);
    const svfloat32_t q1 = svld1_f32(all, x + vl);
    const svfloat32_t q2 = svld1_f32(all, x + vl * 2);

    for (size_t row = 0; row < ny; ++row, y += stride) {
        svfloat32_t acc = ElementOp::op(all, q0, svld1_f32(all, y));
        acc = ElementOp::merge(all, acc, q1, svld1_f32(all, y + vl));
        acc = ElementOp::merge(all, acc, q2, svld1_f32(all, y + vl * 2));
        dis[row] = svaddv_f32(all, acc);
    }
}
3006
3007
// Kernel for d == 4 * svcntw(): every vector is exactly four full SVE
// registers. Two accumulation chains (slices 0+1 and slices 2+3) are built
// independently, added together, then reduced horizontally into dis[row].
template <typename ElementOp>
void fvec_op_ny_sve_lanes4(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t vl = svcntw();
    const size_t stride = vl * 4;
    const svbool_t all = svptrue_b32();

    // The four register-sized slices of the query.
    const svfloat32_t q0 = svld1_f32(all, x);
    const svfloat32_t q1 = svld1_f32(all, x + vl);
    const svfloat32_t q2 = svld1_f32(all, x + vl * 2);
    const svfloat32_t q3 = svld1_f32(all, x + vl * 3);

    for (size_t row = 0; row < ny; ++row, y += stride) {
        svfloat32_t lo = ElementOp::op(all, q0, svld1_f32(all, y));
        svfloat32_t hi = ElementOp::op(all, q2, svld1_f32(all, y + vl * 2));
        lo = ElementOp::merge(all, lo, q1, svld1_f32(all, y + vl));
        hi = ElementOp::merge(all, hi, q3, svld1_f32(all, y + vl * 3));
        dis[row] = svaddv_f32(all, svadd_f32_x(all, lo, hi));
    }
}
3036
3037
void fvec_L2sqr_ny(
3038
        float* dis,
3039
        const float* x,
3040
        const float* y,
3041
        size_t d,
3042
        size_t ny) {
3043
    fvec_L2sqr_ny_ref(dis, x, y, d, ny);
3044
}
3045
3046
void fvec_L2sqr_ny_transposed(
3047
        float* dis,
3048
        const float* x,
3049
        const float* y,
3050
        const float* y_sqlen,
3051
        size_t d,
3052
        size_t d_offset,
3053
        size_t ny) {
3054
    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
3055
}
3056
3057
size_t fvec_L2sqr_ny_nearest(
3058
        float* distances_tmp_buffer,
3059
        const float* x,
3060
        const float* y,
3061
        size_t d,
3062
        size_t ny) {
3063
    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
3064
}
3065
3066
size_t fvec_L2sqr_ny_nearest_y_transposed(
3067
        float* distances_tmp_buffer,
3068
        const float* x,
3069
        const float* y,
3070
        const float* y_sqlen,
3071
        size_t d,
3072
        size_t d_offset,
3073
        size_t ny) {
3074
    return fvec_L2sqr_ny_nearest_y_transposed_ref(
3075
            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
3076
}
3077
3078
float fvec_L1(const float* x, const float* y, size_t d) {
3079
    return fvec_L1_ref(x, y, d);
3080
}
3081
3082
float fvec_Linf(const float* x, const float* y, size_t d) {
3083
    return fvec_Linf_ref(x, y, d);
3084
}
3085
3086
// SVE build: inner products of x with each of the ny vectors in y.
// Dimensions 1, 2, 4 and 8 use dedicated de-interleaving kernels; d equal
// to 1..4 full SVE registers uses the register-sized kernels; any other d
// falls back to the reference implementation. The fixed-dimension checks
// come first, so they take precedence when d also equals a multiple of
// svcntw().
void fvec_inner_products_ny(
        float* dis,
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    if (d == 1) {
        fvec_op_ny_sve_d1<ElementOpIP>(dis, x, y, ny);
        return;
    }
    if (d == 2) {
        fvec_op_ny_sve_d2<ElementOpIP>(dis, x, y, ny);
        return;
    }
    if (d == 4) {
        fvec_op_ny_sve_d4<ElementOpIP>(dis, x, y, ny);
        return;
    }
    if (d == 8) {
        fvec_op_ny_sve_d8<ElementOpIP>(dis, x, y, ny);
        return;
    }

    const size_t vl = svcntw();
    if (d == vl) {
        fvec_op_ny_sve_lanes1<ElementOpIP>(dis, x, y, ny);
    } else if (d == vl * 2) {
        fvec_op_ny_sve_lanes2<ElementOpIP>(dis, x, y, ny);
    } else if (d == vl * 3) {
        fvec_op_ny_sve_lanes3<ElementOpIP>(dis, x, y, ny);
    } else if (d == vl * 4) {
        fvec_op_ny_sve_lanes4<ElementOpIP>(dis, x, y, ny);
    } else {
        fvec_inner_products_ny_ref(dis, x, y, d, ny);
    }
}
3120
3121
#elif defined(__aarch64__)
3122
3123
// not optimized for ARM
3124
void fvec_L2sqr_ny(
3125
        float* dis,
3126
        const float* x,
3127
        const float* y,
3128
        size_t d,
3129
        size_t ny) {
3130
    fvec_L2sqr_ny_ref(dis, x, y, d, ny);
3131
}
3132
3133
void fvec_L2sqr_ny_transposed(
3134
        float* dis,
3135
        const float* x,
3136
        const float* y,
3137
        const float* y_sqlen,
3138
        size_t d,
3139
        size_t d_offset,
3140
        size_t ny) {
3141
    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
3142
}
3143
3144
size_t fvec_L2sqr_ny_nearest(
3145
        float* distances_tmp_buffer,
3146
        const float* x,
3147
        const float* y,
3148
        size_t d,
3149
        size_t ny) {
3150
    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
3151
}
3152
3153
size_t fvec_L2sqr_ny_nearest_y_transposed(
3154
        float* distances_tmp_buffer,
3155
        const float* x,
3156
        const float* y,
3157
        const float* y_sqlen,
3158
        size_t d,
3159
        size_t d_offset,
3160
        size_t ny) {
3161
    return fvec_L2sqr_ny_nearest_y_transposed_ref(
3162
            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
3163
}
3164
3165
float fvec_L1(const float* x, const float* y, size_t d) {
3166
    return fvec_L1_ref(x, y, d);
3167
}
3168
3169
float fvec_Linf(const float* x, const float* y, size_t d) {
3170
    return fvec_Linf_ref(x, y, d);
3171
}
3172
3173
void fvec_inner_products_ny(
3174
        float* dis,
3175
        const float* x,
3176
        const float* y,
3177
        size_t d,
3178
        size_t ny) {
3179
    fvec_inner_products_ny_ref(dis, x, y, d, ny);
3180
}
3181
3182
#else
3183
// scalar implementation
3184
3185
float fvec_L1(const float* x, const float* y, size_t d) {
3186
    return fvec_L1_ref(x, y, d);
3187
}
3188
3189
float fvec_Linf(const float* x, const float* y, size_t d) {
3190
    return fvec_Linf_ref(x, y, d);
3191
}
3192
3193
void fvec_L2sqr_ny(
3194
        float* dis,
3195
        const float* x,
3196
        const float* y,
3197
        size_t d,
3198
        size_t ny) {
3199
    fvec_L2sqr_ny_ref(dis, x, y, d, ny);
3200
}
3201
3202
void fvec_L2sqr_ny_transposed(
3203
        float* dis,
3204
        const float* x,
3205
        const float* y,
3206
        const float* y_sqlen,
3207
        size_t d,
3208
        size_t d_offset,
3209
        size_t ny) {
3210
    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
3211
}
3212
3213
size_t fvec_L2sqr_ny_nearest(
3214
        float* distances_tmp_buffer,
3215
        const float* x,
3216
        const float* y,
3217
        size_t d,
3218
        size_t ny) {
3219
    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
3220
}
3221
3222
size_t fvec_L2sqr_ny_nearest_y_transposed(
3223
        float* distances_tmp_buffer,
3224
        const float* x,
3225
        const float* y,
3226
        const float* y_sqlen,
3227
        size_t d,
3228
        size_t d_offset,
3229
        size_t ny) {
3230
    return fvec_L2sqr_ny_nearest_y_transposed_ref(
3231
            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
3232
}
3233
3234
void fvec_inner_products_ny(
3235
        float* dis,
3236
        const float* x,
3237
        const float* y,
3238
        size_t d,
3239
        size_t ny) {
3240
    fvec_inner_products_ny_ref(dis, x, y, d, ny);
3241
}
3242
3243
#endif
3244
3245
/***************************************************************************
3246
 * heavily optimized table computations
3247
 ***************************************************************************/
3248
3249
// Scalar reference for fvec_madd: c[i] = a[i] + bf * b[i] for i in [0, n).
// Works for any n and any alignment.
[[maybe_unused]] static inline void fvec_madd_ref(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    size_t k = 0;
    while (k < n) {
        c[k] = a[k] + bf * b[k];
        ++k;
    }
}
3258
3259
#if defined(__AVX512F__)
3260
3261
// AVX-512 fvec_madd: c[i] = a[i] + bf * b[i] using 16-wide FMA. Full blocks
// of 16 floats use unaligned loads/stores; the final n % 16 elements use a
// k-mask, so memory past a/b/c + n is never touched.
static inline void fvec_madd_avx512(
        const size_t n,
        const float* __restrict a,
        const float bf,
        const float* __restrict b,
        float* __restrict c) {
    const __m512 scale = _mm512_set1_ps(bf);

    size_t pos = 0;
    for (; pos + 16 <= n; pos += 16) {
        const __m512 va = _mm512_loadu_ps(a + pos);
        const __m512 vb = _mm512_loadu_ps(b + pos);
        _mm512_storeu_ps(c + pos, _mm512_fmadd_ps(scale, vb, va));
    }

    const size_t tail = n - pos; // always < 16
    if (tail != 0) {
        // Low `tail` bits set: one bit per remaining element.
        const __mmask16 live = static_cast<__mmask16>((1u << tail) - 1u);
        const __m512 va = _mm512_maskz_loadu_ps(live, a + pos);
        const __m512 vb = _mm512_maskz_loadu_ps(live, b + pos);
        _mm512_mask_storeu_ps(c + pos, live, _mm512_fmadd_ps(scale, vb, va));
    }
}
3289
3290
#elif defined(__AVX2__)
3291
3292
// AVX2/FMA fvec_madd: c[i] = a[i] + bf * b[i] for i in [0, n).
// Full blocks of 8 floats use unaligned loads/stores; the last n % 8
// elements use maskload/maskstore, so memory past a/b/c + n is not accessed.
static inline void fvec_madd_avx2(
        const size_t n,
        const float* __restrict a,
        const float bf,
        const float* __restrict b,
        float* __restrict c) {
    const size_t n8 = n / 8;
    const size_t n_for_masking = n % 8;

    const __m256 bfmm = _mm256_set1_ps(bf);

    size_t idx = 0;
    for (idx = 0; idx < n8 * 8; idx += 8) {
        const __m256 ax = _mm256_loadu_ps(a + idx);
        const __m256 bx = _mm256_loadu_ps(b + idx);
        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
        _mm256_storeu_ps(c + idx, abmul);
    }

    if (n_for_masking > 0) {
        // Tail mask computed arithmetically: lane k is all-ones iff
        // k < n_for_masking (maskload/maskstore only test each lane's sign
        // bit). This replaces a 7-way switch that left `mask` formally
        // uninitialized on its implicit default path.
        const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i mask = _mm256_cmpgt_epi32(
                _mm256_set1_epi32(static_cast<int>(n_for_masking)), lane_ids);

        const __m256 ax = _mm256_maskload_ps(a + idx, mask);
        const __m256 bx = _mm256_maskload_ps(b + idx, mask);
        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
        _mm256_maskstore_ps(c + idx, mask, abmul);
    }
}
3344
3345
#endif
3346
3347
#ifdef __SSE3__
3348
3349
// SSE fvec_madd: c[i] = a[i] + bf * b[i], four floats per step.
// Preconditions: a, b and c are 16-byte aligned (aligned loads/stores are
// used). Only the first n - n % 4 elements are written; the SSE-path
// fvec_madd dispatcher calls this only when n % 4 == 0 and all three
// pointers are aligned.
// Uses _mm_load_ps/_mm_store_ps instead of C-style casts to __m128*, which
// silently discarded the const qualifier of a and b.
[[maybe_unused]] static inline void fvec_madd_sse(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    const __m128 bf4 = _mm_set_ps1(bf);
    for (size_t i = 0; i + 4 <= n; i += 4) {
        const __m128 a4 = _mm_load_ps(a + i);
        const __m128 b4 = _mm_load_ps(b + i);
        // Separate multiply and add (no FMA), matching the original kernel.
        _mm_store_ps(c + i, _mm_add_ps(a4, _mm_mul_ps(bf4, b4)));
    }
}
3368
3369
0
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
3370
#ifdef __AVX512F__
3371
    fvec_madd_avx512(n, a, bf, b, c);
3372
#elif __AVX2__
3373
    fvec_madd_avx2(n, a, bf, b, c);
3374
#else
3375
    if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
3376
        fvec_madd_sse(n, a, bf, b, c);
3377
    else
3378
        fvec_madd_ref(n, a, bf, b, c);
3379
#endif
3380
0
}
3381
3382
#elif defined(__ARM_FEATURE_SVE)
3383
3384
// SVE fvec_madd: c[i] = a[i] + bf * b[i] for i in [0, n).
// The main loop handles four full registers (4 * svcntw() floats) per
// iteration and exits as soon as n - i <= 4 * svcntw(). A single predicated
// pass then finishes the remainder; whilelt predicates disable lanes at or
// beyond n, so those lanes are neither loaded nor stored.
void fvec_madd(
        const size_t n,
        const float* __restrict a,
        const float bf,
        const float* __restrict b,
        float* __restrict c) {
    const size_t lanes = static_cast<size_t>(svcntw());
    const size_t lanes2 = lanes * 2;
    const size_t lanes3 = lanes * 3;
    const size_t lanes4 = lanes * 4;
    size_t i = 0;
    for (; i + lanes4 < n; i += lanes4) {
        const auto mask = svptrue_b32();
        const auto ai0 = svld1_f32(mask, a + i);
        const auto ai1 = svld1_f32(mask, a + i + lanes);
        const auto ai2 = svld1_f32(mask, a + i + lanes2);
        const auto ai3 = svld1_f32(mask, a + i + lanes3);
        const auto bi0 = svld1_f32(mask, b + i);
        const auto bi1 = svld1_f32(mask, b + i + lanes);
        const auto bi2 = svld1_f32(mask, b + i + lanes2);
        const auto bi3 = svld1_f32(mask, b + i + lanes3);
        // svmla_n_f32_x(pg, acc, v, s) = acc + v * s on active lanes.
        const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf);
        const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf);
        const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf);
        const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf);
        svst1_f32(mask, c + i, ci0);
        svst1_f32(mask, c + i + lanes, ci1);
        svst1_f32(mask, c + i + lanes2, ci2);
        svst1_f32(mask, c + i + lanes3, ci3);
    }
    // Tail: up to 4 * lanes remaining elements (possibly none).
    const auto mask0 = svwhilelt_b32_u64(i, n);
    const auto mask1 = svwhilelt_b32_u64(i + lanes, n);
    const auto mask2 = svwhilelt_b32_u64(i + lanes2, n);
    const auto mask3 = svwhilelt_b32_u64(i + lanes3, n);
    const auto ai0 = svld1_f32(mask0, a + i);
    const auto ai1 = svld1_f32(mask1, a + i + lanes);
    const auto ai2 = svld1_f32(mask2, a + i + lanes2);
    const auto ai3 = svld1_f32(mask3, a + i + lanes3);
    const auto bi0 = svld1_f32(mask0, b + i);
    const auto bi1 = svld1_f32(mask1, b + i + lanes);
    const auto bi2 = svld1_f32(mask2, b + i + lanes2);
    const auto bi3 = svld1_f32(mask3, b + i + lanes3);
    const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf);
    const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf);
    const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf);
    const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf);
    svst1_f32(mask0, c + i, ci0);
    svst1_f32(mask1, c + i + lanes, ci1);
    svst1_f32(mask2, c + i + lanes2, ci2);
    svst1_f32(mask3, c + i + lanes3, ci3);
}
3435
3436
#elif defined(__aarch64__)
3437
3438
// NEON fvec_madd: c[i] = a[i] + bf * b[i] for i in [0, n).
// Whole quads use a fused multiply-add; the 0..3 trailing elements are
// finished with scalar arithmetic.
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
    const float32x4_t scale = vdupq_n_f32(bf);
    size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        const float32x4_t va = vld1q_f32(a + i);
        const float32x4_t vb = vld1q_f32(b + i);
        vst1q_f32(c + i, vfmaq_f32(va, scale, vb));
    }
    while (i < n) {
        c[i] = a[i] + bf * b[i];
        ++i;
    }
}
3451
3452
#else
3453
3454
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
3455
    fvec_madd_ref(n, a, bf, b, c);
3456
}
3457
3458
#endif
3459
3460
// Scalar reference: writes c[i] = a[i] + bf * b[i] for i in [0, n) and
// returns the index of the first smallest c[i]. Only values strictly below
// 1e20 can be selected; -1 is returned when none is (including n == 0).
static inline int fvec_madd_and_argmin_ref(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    float best = 1e20;
    int best_idx = -1;
    for (size_t k = 0; k < n; ++k) {
        c[k] = a[k] + bf * b[k];
        const float v = c[k];
        // Strict comparison keeps the earliest index on ties.
        if (v < best) {
            best = v;
            best_idx = static_cast<int>(k);
        }
    }
    return best_idx;
}
3478
3479
#ifdef __SSE3__
3480
3481
// SSE kernel: c[i] = a[i] + bf * b[i] (separate mul and add) for the first
// n - n % 4 elements, returning the index of a minimal c[i]. Aligned
// dereferences of __m128* require a, b and c to be 16-byte aligned; the
// fvec_madd_and_argmin dispatcher checks n % 4 == 0 and alignment first.
// As in the reference, only values strictly below 1e20 are selected and -1
// is returned when none is.
// NOTE(review): each lane keeps its first strict minimum, but the final
// lane reduction prefers lower lanes on equal values, not lower indices, so
// on ties the result can differ from fvec_madd_and_argmin_ref.
static inline int fvec_madd_and_argmin_sse(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    // Number of 4-float groups.
    n >>= 2;
    __m128 bf4 = _mm_set_ps1(bf);
    __m128 vmin4 = _mm_set_ps1(1e20);
    // Per-lane running argmin and the element indices of the current group.
    __m128i imin4 = _mm_set1_epi32(-1);
    __m128i idx4 = _mm_set_epi32(3, 2, 1, 0);
    __m128i inc4 = _mm_set1_epi32(4);
    __m128* a4 = (__m128*)a;
    __m128* b4 = (__m128*)b;
    __m128* c4 = (__m128*)c;

    while (n--) {
        __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
        *c4 = vc4;
        // All-ones in lanes where the new value is a strict improvement.
        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
        // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!

        // Bitwise select: imin4 = mask ? idx4 : imin4.
        imin4 = _mm_or_si128(
                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
        vmin4 = _mm_min_ps(vmin4, vc4);
        b4++;
        a4++;
        c4++;
        idx4 = _mm_add_epi32(idx4, inc4);
    }

    // 4 values -> 2
    // Shuffle immediate 3 << 2 | 2 brings lanes 2, 3 into lanes 0, 1, so
    // lane 0 compares lanes {0, 2} and lane 1 compares lanes {1, 3}.
    {
        idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2);
        __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2);
        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
        imin4 = _mm_or_si128(
                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
        vmin4 = _mm_min_ps(vmin4, vc4);
    }
    // 2 values -> 1
    // Shuffle immediate 1 brings lane 1 into lane 0.
    {
        idx4 = _mm_shuffle_epi32(imin4, 1);
        __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1);
        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
        imin4 = _mm_or_si128(
                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
        // vmin4 = _mm_min_ps (vmin4, vc4);
    }
    // Lane 0 now holds the selected index.
    return _mm_cvtsi128_si32(imin4);
}
3532
3533
int fvec_madd_and_argmin(
3534
        size_t n,
3535
        const float* a,
3536
        float bf,
3537
        const float* b,
3538
0
        float* c) {
3539
0
    if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
3540
0
        return fvec_madd_and_argmin_sse(n, a, bf, b, c);
3541
0
    else
3542
0
        return fvec_madd_and_argmin_ref(n, a, bf, b, c);
3543
0
}
3544
3545
#elif defined(__aarch64__)
3546
3547
int fvec_madd_and_argmin(
3548
        size_t n,
3549
        const float* a,
3550
        float bf,
3551
        const float* b,
3552
        float* c) {
3553
    float32x4_t vminv = vdupq_n_f32(1e20);
3554
    uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
3555
    size_t i;
3556
    {
3557
        const size_t n_simd = n - (n & 3);
3558
        const uint32_t iota[] = {0, 1, 2, 3};
3559
        uint32x4_t iv = vld1q_u32(iota);
3560
        const uint32x4_t incv = vdupq_n_u32(4);
3561
        const float32x4_t bfv = vdupq_n_f32(bf);
3562
        for (i = 0; i < n_simd; i += 4) {
3563
            const float32x4_t ai = vld1q_f32(a + i);
3564
            const float32x4_t bi = vld1q_f32(b + i);
3565
            const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
3566
            vst1q_f32(c + i, ci);
3567
            const uint32x4_t less_than = vcltq_f32(ci, vminv);
3568
            vminv = vminq_f32(ci, vminv);
3569
            iminv = vorrq_u32(
3570
                    vandq_u32(less_than, iv),
3571
                    vandq_u32(vmvnq_u32(less_than), iminv));
3572
            iv = vaddq_u32(iv, incv);
3573
        }
3574
    }
3575
    float vmin = vminvq_f32(vminv);
3576
    uint32_t imin;
3577
    {
3578
        const float32x4_t vminy = vdupq_n_f32(vmin);
3579
        const uint32x4_t equals = vceqq_f32(vminv, vminy);
3580
        imin = vminvq_u32(vorrq_u32(
3581
                vandq_u32(equals, iminv),
3582
                vandq_u32(
3583
                        vmvnq_u32(equals),
3584
                        vdupq_n_u32(std::numeric_limits<uint32_t>::max()))));
3585
    }
3586
    for (; i < n; ++i) {
3587
        c[i] = a[i] + bf * b[i];
3588
        if (c[i] < vmin) {
3589
            vmin = c[i];
3590
            imin = static_cast<uint32_t>(i);
3591
        }
3592
    }
3593
    return static_cast<int>(imin);
3594
}
3595
3596
#else
3597
3598
int fvec_madd_and_argmin(
3599
        size_t n,
3600
        const float* a,
3601
        float bf,
3602
        const float* b,
3603
        float* c) {
3604
    return fvec_madd_and_argmin_ref(n, a, bf, b, c);
3605
}
3606
3607
#endif
3608
3609
/***************************************************************************
3610
 * PQ tables computations
3611
 ***************************************************************************/
3612
3613
namespace {
3614
3615
/// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
3616
template <bool is_inner_product>
3617
void pq2_8cents_table(
3618
        const simd8float32 centroids[8],
3619
        const simd8float32 x,
3620
        float* out,
3621
        size_t ldo,
3622
0
        size_t nout = 4) {
3623
0
    simd8float32 ips[4];
3624
3625
0
    for (int i = 0; i < 4; i++) {
3626
0
        simd8float32 p1, p2;
3627
0
        if (is_inner_product) {
3628
0
            p1 = x * centroids[2 * i];
3629
0
            p2 = x * centroids[2 * i + 1];
3630
0
        } else {
3631
0
            p1 = (x - centroids[2 * i]);
3632
0
            p1 = p1 * p1;
3633
0
            p2 = (x - centroids[2 * i + 1]);
3634
0
            p2 = p2 * p2;
3635
0
        }
3636
0
        ips[i] = hadd(p1, p2);
3637
0
    }
3638
3639
0
    simd8float32 ip02a = geteven(ips[0], ips[1]);
3640
0
    simd8float32 ip02b = geteven(ips[2], ips[3]);
3641
0
    simd8float32 ip0 = getlow128(ip02a, ip02b);
3642
0
    simd8float32 ip2 = gethigh128(ip02a, ip02b);
3643
3644
0
    simd8float32 ip13a = getodd(ips[0], ips[1]);
3645
0
    simd8float32 ip13b = getodd(ips[2], ips[3]);
3646
0
    simd8float32 ip1 = getlow128(ip13a, ip13b);
3647
0
    simd8float32 ip3 = gethigh128(ip13a, ip13b);
3648
3649
0
    switch (nout) {
3650
0
        case 4:
3651
0
            ip3.storeu(out + 3 * ldo);
3652
0
            [[fallthrough]];
3653
0
        case 3:
3654
0
            ip2.storeu(out + 2 * ldo);
3655
0
            [[fallthrough]];
3656
0
        case 2:
3657
0
            ip1.storeu(out + 1 * ldo);
3658
0
            [[fallthrough]];
3659
0
        case 1:
3660
0
            ip0.storeu(out);
3661
0
    }
3662
0
}
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_116pq2_8cents_tableILb1EEEvPKNS_12simd8float32ES2_Pfmm
Unexecuted instantiation: distances_simd.cpp:_ZN5faiss12_GLOBAL__N_116pq2_8cents_tableILb0EEEvPKNS_12simd8float32ES2_Pfmm
3663
3664
0
simd8float32 load_simd8float32_partial(const float* x, int n) {
3665
0
    ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
3666
0
    float* wp = tmp;
3667
0
    for (int i = 0; i < n; i++) {
3668
0
        *wp++ = *x++;
3669
0
    }
3670
0
    return simd8float32(tmp);
3671
0
}
3672
3673
} // anonymous namespace
3674
3675
void compute_PQ_dis_tables_dsub2(
3676
        size_t d,
3677
        size_t ksub,
3678
        const float* all_centroids,
3679
        size_t nx,
3680
        const float* x,
3681
        bool is_inner_product,
3682
0
        float* dis_tables) {
3683
0
    size_t M = d / 2;
3684
0
    FAISS_THROW_IF_NOT(ksub % 8 == 0);
3685
3686
0
    for (size_t m0 = 0; m0 < M; m0 += 4) {
3687
0
        int m1 = std::min(M, m0 + 4);
3688
0
        for (int k0 = 0; k0 < ksub; k0 += 8) {
3689
0
            simd8float32 centroids[8];
3690
0
            for (int k = 0; k < 8; k++) {
3691
0
                ALIGNED(32) float centroid[8];
3692
0
                size_t wp = 0;
3693
0
                size_t rp = (m0 * ksub + k + k0) * 2;
3694
0
                for (int m = m0; m < m1; m++) {
3695
0
                    centroid[wp++] = all_centroids[rp];
3696
0
                    centroid[wp++] = all_centroids[rp + 1];
3697
0
                    rp += 2 * ksub;
3698
0
                }
3699
0
                centroids[k] = simd8float32(centroid);
3700
0
            }
3701
0
            for (size_t i = 0; i < nx; i++) {
3702
0
                simd8float32 xi;
3703
0
                if (m1 == m0 + 4) {
3704
0
                    xi.loadu(x + i * d + m0 * 2);
3705
0
                } else {
3706
0
                    xi = load_simd8float32_partial(
3707
0
                            x + i * d + m0 * 2, 2 * (m1 - m0));
3708
0
                }
3709
3710
0
                if (is_inner_product) {
3711
0
                    pq2_8cents_table<true>(
3712
0
                            centroids,
3713
0
                            xi,
3714
0
                            dis_tables + (i * M + m0) * ksub + k0,
3715
0
                            ksub,
3716
0
                            m1 - m0);
3717
0
                } else {
3718
0
                    pq2_8cents_table<false>(
3719
0
                            centroids,
3720
0
                            xi,
3721
0
                            dis_tables + (i * M + m0) * ksub + k0,
3722
0
                            ksub,
3723
0
                            m1 - m0);
3724
0
                }
3725
0
            }
3726
0
        }
3727
0
    }
3728
0
}
3729
3730
/*********************************************************
3731
 * Vector to vector functions
3732
 *********************************************************/
3733
3734
0
void fvec_sub(size_t d, const float* a, const float* b, float* c) {
3735
0
    size_t i;
3736
0
    for (i = 0; i + 7 < d; i += 8) {
3737
0
        simd8float32 ci, ai, bi;
3738
0
        ai.loadu(a + i);
3739
0
        bi.loadu(b + i);
3740
0
        ci = ai - bi;
3741
0
        ci.storeu(c + i);
3742
0
    }
3743
    // finish non-multiple of 8 remainder
3744
0
    for (; i < d; i++) {
3745
0
        c[i] = a[i] - b[i];
3746
0
    }
3747
0
}
3748
3749
0
void fvec_add(size_t d, const float* a, const float* b, float* c) {
3750
0
    size_t i;
3751
0
    for (i = 0; i + 7 < d; i += 8) {
3752
0
        simd8float32 ci, ai, bi;
3753
0
        ai.loadu(a + i);
3754
0
        bi.loadu(b + i);
3755
0
        ci = ai + bi;
3756
0
        ci.storeu(c + i);
3757
0
    }
3758
    // finish non-multiple of 8 remainder
3759
0
    for (; i < d; i++) {
3760
0
        c[i] = a[i] + b[i];
3761
0
    }
3762
0
}
3763
3764
0
void fvec_add(size_t d, const float* a, float b, float* c) {
3765
0
    size_t i;
3766
0
    simd8float32 bv(b);
3767
0
    for (i = 0; i + 7 < d; i += 8) {
3768
0
        simd8float32 ci, ai;
3769
0
        ai.loadu(a + i);
3770
0
        ci = ai + bv;
3771
0
        ci.storeu(c + i);
3772
0
    }
3773
    // finish non-multiple of 8 remainder
3774
0
    for (; i < d; i++) {
3775
0
        c[i] = a[i] + b;
3776
0
    }
3777
0
}
3778
3779
} // namespace faiss