Coverage Report

Created: 2026-03-16 04:30

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/utils/approx_topk/avx2-inl.h
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#pragma once
9
10
#include <immintrin.h>
11
12
#include <limits>
13
14
#include <faiss/impl/FaissAssert.h>
15
#include <faiss/utils/Heap.h>
16
17
namespace faiss {
18
19
template <typename C, uint32_t NBUCKETS, uint32_t N>
20
struct HeapWithBuckets {
21
    // this case was not implemented yet.
22
};
23
24
template <uint32_t NBUCKETS, uint32_t N>
25
struct HeapWithBuckets<CMax<float, int>, NBUCKETS, N> {
26
    static constexpr uint32_t NBUCKETS_8 = NBUCKETS / 8;
27
    static_assert(
28
            (NBUCKETS) > 0 && ((NBUCKETS % 8) == 0),
29
            "Number of buckets needs to be 8, 16, 24, ...");
30
31
    static void addn(
32
            // number of elements
33
            const uint32_t n,
34
            // distances. It is assumed to have n elements.
35
            const float* const __restrict distances,
36
            // number of best elements to keep
37
            const uint32_t k,
38
            // output distances
39
            float* const __restrict bh_val,
40
            // output indices, each being within [0, n) range
41
0
            int32_t* const __restrict bh_ids) {
42
        // forward a call to bs_addn with 1 beam
43
0
        bs_addn(1, n, distances, k, bh_val, bh_ids);
44
0
    }
45
46
    static void bs_addn(
47
            // beam_size parameter of Beam Search algorithm
48
            const uint32_t beam_size,
49
            // number of elements per beam
50
            const uint32_t n_per_beam,
51
            // distances. It is assumed to have (n_per_beam * beam_size)
52
            // elements.
53
            const float* const __restrict distances,
54
            // number of best elements to keep
55
            const uint32_t k,
56
            // output distances
57
            float* const __restrict bh_val,
58
            // output indices, each being within [0, n_per_beam * beam_size)
59
            // range
60
0
            int32_t* const __restrict bh_ids) {
61
        // // Basically, the function runs beam_size iterations.
62
        // // Every iteration NBUCKETS * N elements are added to a regular heap.
63
        // // So, maximum number of added elements is beam_size * NBUCKETS * N.
64
        // // This number is expected to be less or equal than k.
65
        // FAISS_THROW_IF_NOT_FMT(
66
        //         beam_size * NBUCKETS * N >= k,
67
        //         "Cannot pick %d elements, only %d. "
68
        //         "Check the function and template arguments values.",
69
        //         k,
70
        //         beam_size * NBUCKETS * N);
71
72
0
        using C = CMax<float, int>;
73
74
        // main loop
75
0
        for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
76
0
            __m256 min_distances_i[NBUCKETS_8][N];
77
0
            __m256i min_indices_i[NBUCKETS_8][N];
78
79
0
            for (uint32_t j = 0; j < NBUCKETS_8; j++) {
80
0
                for (uint32_t p = 0; p < N; p++) {
81
0
                    min_distances_i[j][p] =
82
0
                            _mm256_set1_ps(std::numeric_limits<float>::max());
83
0
                    min_indices_i[j][p] =
84
0
                            _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
85
0
                }
86
0
            }
87
88
0
            __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
89
0
            __m256i indices_delta = _mm256_set1_epi32(NBUCKETS);
90
91
0
            const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;
92
93
            // put the data into buckets
94
0
            for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
95
0
                for (uint32_t j = 0; j < NBUCKETS_8; j++) {
96
0
                    const __m256 distances_reg = _mm256_loadu_ps(
97
0
                            distances + j * 8 + ip + n_per_beam * beam_index);
98
99
                    // loop. Compiler should get rid of unneeded ops
100
0
                    __m256 distance_candidate = distances_reg;
101
0
                    __m256i indices_candidate = current_indices;
102
103
0
                    for (uint32_t p = 0; p < N; p++) {
104
0
                        const __m256 comparison = _mm256_cmp_ps(
105
0
                                min_distances_i[j][p],
106
0
                                distance_candidate,
107
0
                                _CMP_LE_OS);
108
109
                        // // blend seems to be slower that min
110
                        // const __m256 min_distances_new = _mm256_blendv_ps(
111
                        //         distance_candidate,
112
                        //         min_distances_i[j][p],
113
                        //         comparison);
114
0
                        const __m256 min_distances_new = _mm256_min_ps(
115
0
                                distance_candidate, min_distances_i[j][p]);
116
0
                        const __m256i min_indices_new =
117
0
                                _mm256_castps_si256(_mm256_blendv_ps(
118
0
                                        _mm256_castsi256_ps(indices_candidate),
119
0
                                        _mm256_castsi256_ps(
120
0
                                                min_indices_i[j][p]),
121
0
                                        comparison));
122
123
                        // // blend seems to be slower that min
124
                        // const __m256 max_distances_new = _mm256_blendv_ps(
125
                        //         min_distances_i[j][p],
126
                        //         distance_candidate,
127
                        //         comparison);
128
0
                        const __m256 max_distances_new = _mm256_max_ps(
129
0
                                min_distances_i[j][p], distances_reg);
130
0
                        const __m256i max_indices_new =
131
0
                                _mm256_castps_si256(_mm256_blendv_ps(
132
0
                                        _mm256_castsi256_ps(
133
0
                                                min_indices_i[j][p]),
134
0
                                        _mm256_castsi256_ps(indices_candidate),
135
0
                                        comparison));
136
137
0
                        distance_candidate = max_distances_new;
138
0
                        indices_candidate = max_indices_new;
139
140
0
                        min_distances_i[j][p] = min_distances_new;
141
0
                        min_indices_i[j][p] = min_indices_new;
142
0
                    }
143
0
                }
144
145
0
                current_indices =
146
0
                        _mm256_add_epi32(current_indices, indices_delta);
147
0
            }
148
149
            // fix the indices
150
0
            for (uint32_t j = 0; j < NBUCKETS_8; j++) {
151
0
                const __m256i offset =
152
0
                        _mm256_set1_epi32(n_per_beam * beam_index + j * 8);
153
0
                for (uint32_t p = 0; p < N; p++) {
154
0
                    min_indices_i[j][p] =
155
0
                            _mm256_add_epi32(min_indices_i[j][p], offset);
156
0
                }
157
0
            }
158
159
            // merge every bucket into the regular heap
160
0
            for (uint32_t p = 0; p < N; p++) {
161
0
                for (uint32_t j = 0; j < NBUCKETS_8; j++) {
162
0
                    int32_t min_indices_scalar[8];
163
0
                    float min_distances_scalar[8];
164
165
0
                    _mm256_storeu_si256(
166
0
                            (__m256i*)min_indices_scalar, min_indices_i[j][p]);
167
0
                    _mm256_storeu_ps(
168
0
                            min_distances_scalar, min_distances_i[j][p]);
169
170
                    // this exact way is needed to maintain the order as if the
171
                    // input elements were pushed to the heap sequentially
172
0
                    for (size_t j8 = 0; j8 < 8; j8++) {
173
0
                        const auto value = min_distances_scalar[j8];
174
0
                        const auto index = min_indices_scalar[j8];
175
0
                        if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
176
0
                            heap_replace_top<C>(
177
0
                                    k, bh_val, bh_ids, value, index);
178
0
                        }
179
0
                    }
180
0
                }
181
0
            }
182
183
            // process leftovers
184
0
            for (uint32_t ip = nb; ip < n_per_beam; ip++) {
185
0
                const int32_t index = ip + n_per_beam * beam_index;
186
0
                const float value = distances[index];
187
188
0
                if (C::cmp(bh_val[0], value)) {
189
0
                    heap_replace_top<C>(k, bh_val, bh_ids, value, index);
190
0
                }
191
0
            }
192
0
        }
193
0
    }
Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj16ELj1EE7bs_addnEjjPKfjPfPi
Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj8ELj3EE7bs_addnEjjPKfjPfPi
Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj8ELj2EE7bs_addnEjjPKfjPfPi
Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj16ELj2EE7bs_addnEjjPKfjPfPi
Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj32ELj2EE7bs_addnEjjPKfjPfPi
194
};
195
196
} // namespace faiss