contrib/faiss/faiss/utils/approx_topk/avx2-inl.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
10 | | #include <immintrin.h> |
11 | | |
12 | | #include <limits> |
13 | | |
14 | | #include <faiss/impl/FaissAssert.h> |
15 | | #include <faiss/utils/Heap.h> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | | template <typename C, uint32_t NBUCKETS, uint32_t N> |
20 | | struct HeapWithBuckets { |
21 | | // this case was not implemented yet. |
22 | | }; |
23 | | |
24 | | template <uint32_t NBUCKETS, uint32_t N> |
25 | | struct HeapWithBuckets<CMax<float, int>, NBUCKETS, N> { |
26 | | static constexpr uint32_t NBUCKETS_8 = NBUCKETS / 8; |
27 | | static_assert( |
28 | | (NBUCKETS) > 0 && ((NBUCKETS % 8) == 0), |
29 | | "Number of buckets needs to be 8, 16, 24, ..."); |
30 | | |
31 | | static void addn( |
32 | | // number of elements |
33 | | const uint32_t n, |
34 | | // distances. It is assumed to have n elements. |
35 | | const float* const __restrict distances, |
36 | | // number of best elements to keep |
37 | | const uint32_t k, |
38 | | // output distances |
39 | | float* const __restrict bh_val, |
40 | | // output indices, each being within [0, n) range |
41 | 0 | int32_t* const __restrict bh_ids) { |
42 | | // forward a call to bs_addn with 1 beam |
43 | 0 | bs_addn(1, n, distances, k, bh_val, bh_ids); |
44 | 0 | } |
45 | | |
46 | | static void bs_addn( |
47 | | // beam_size parameter of Beam Search algorithm |
48 | | const uint32_t beam_size, |
49 | | // number of elements per beam |
50 | | const uint32_t n_per_beam, |
51 | | // distances. It is assumed to have (n_per_beam * beam_size) |
52 | | // elements. |
53 | | const float* const __restrict distances, |
54 | | // number of best elements to keep |
55 | | const uint32_t k, |
56 | | // output distances |
57 | | float* const __restrict bh_val, |
58 | | // output indices, each being within [0, n_per_beam * beam_size) |
59 | | // range |
60 | 0 | int32_t* const __restrict bh_ids) { |
61 | | // // Basically, the function runs beam_size iterations. |
62 | | // // Every iteration NBUCKETS * N elements are added to a regular heap. |
63 | | // // So, maximum number of added elements is beam_size * NBUCKETS * N. |
64 | | // // This number is expected to be less or equal than k. |
65 | | // FAISS_THROW_IF_NOT_FMT( |
66 | | // beam_size * NBUCKETS * N >= k, |
67 | | // "Cannot pick %d elements, only %d. " |
68 | | // "Check the function and template arguments values.", |
69 | | // k, |
70 | | // beam_size * NBUCKETS * N); |
71 | |
|
72 | 0 | using C = CMax<float, int>; |
73 | | |
74 | | // main loop |
75 | 0 | for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) { |
76 | 0 | __m256 min_distances_i[NBUCKETS_8][N]; |
77 | 0 | __m256i min_indices_i[NBUCKETS_8][N]; |
78 | |
|
79 | 0 | for (uint32_t j = 0; j < NBUCKETS_8; j++) { |
80 | 0 | for (uint32_t p = 0; p < N; p++) { |
81 | 0 | min_distances_i[j][p] = |
82 | 0 | _mm256_set1_ps(std::numeric_limits<float>::max()); |
83 | 0 | min_indices_i[j][p] = |
84 | 0 | _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); |
85 | 0 | } |
86 | 0 | } |
87 | |
|
88 | 0 | __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); |
89 | 0 | __m256i indices_delta = _mm256_set1_epi32(NBUCKETS); |
90 | |
|
91 | 0 | const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS; |
92 | | |
93 | | // put the data into buckets |
94 | 0 | for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) { |
95 | 0 | for (uint32_t j = 0; j < NBUCKETS_8; j++) { |
96 | 0 | const __m256 distances_reg = _mm256_loadu_ps( |
97 | 0 | distances + j * 8 + ip + n_per_beam * beam_index); |
98 | | |
99 | | // loop. Compiler should get rid of unneeded ops |
100 | 0 | __m256 distance_candidate = distances_reg; |
101 | 0 | __m256i indices_candidate = current_indices; |
102 | |
|
103 | 0 | for (uint32_t p = 0; p < N; p++) { |
104 | 0 | const __m256 comparison = _mm256_cmp_ps( |
105 | 0 | min_distances_i[j][p], |
106 | 0 | distance_candidate, |
107 | 0 | _CMP_LE_OS); |
108 | | |
109 | | // // blend seems to be slower that min |
110 | | // const __m256 min_distances_new = _mm256_blendv_ps( |
111 | | // distance_candidate, |
112 | | // min_distances_i[j][p], |
113 | | // comparison); |
114 | 0 | const __m256 min_distances_new = _mm256_min_ps( |
115 | 0 | distance_candidate, min_distances_i[j][p]); |
116 | 0 | const __m256i min_indices_new = |
117 | 0 | _mm256_castps_si256(_mm256_blendv_ps( |
118 | 0 | _mm256_castsi256_ps(indices_candidate), |
119 | 0 | _mm256_castsi256_ps( |
120 | 0 | min_indices_i[j][p]), |
121 | 0 | comparison)); |
122 | | |
123 | | // // blend seems to be slower that min |
124 | | // const __m256 max_distances_new = _mm256_blendv_ps( |
125 | | // min_distances_i[j][p], |
126 | | // distance_candidate, |
127 | | // comparison); |
128 | 0 | const __m256 max_distances_new = _mm256_max_ps( |
129 | 0 | min_distances_i[j][p], distances_reg); |
130 | 0 | const __m256i max_indices_new = |
131 | 0 | _mm256_castps_si256(_mm256_blendv_ps( |
132 | 0 | _mm256_castsi256_ps( |
133 | 0 | min_indices_i[j][p]), |
134 | 0 | _mm256_castsi256_ps(indices_candidate), |
135 | 0 | comparison)); |
136 | |
|
137 | 0 | distance_candidate = max_distances_new; |
138 | 0 | indices_candidate = max_indices_new; |
139 | |
|
140 | 0 | min_distances_i[j][p] = min_distances_new; |
141 | 0 | min_indices_i[j][p] = min_indices_new; |
142 | 0 | } |
143 | 0 | } |
144 | |
|
145 | 0 | current_indices = |
146 | 0 | _mm256_add_epi32(current_indices, indices_delta); |
147 | 0 | } |
148 | | |
149 | | // fix the indices |
150 | 0 | for (uint32_t j = 0; j < NBUCKETS_8; j++) { |
151 | 0 | const __m256i offset = |
152 | 0 | _mm256_set1_epi32(n_per_beam * beam_index + j * 8); |
153 | 0 | for (uint32_t p = 0; p < N; p++) { |
154 | 0 | min_indices_i[j][p] = |
155 | 0 | _mm256_add_epi32(min_indices_i[j][p], offset); |
156 | 0 | } |
157 | 0 | } |
158 | | |
159 | | // merge every bucket into the regular heap |
160 | 0 | for (uint32_t p = 0; p < N; p++) { |
161 | 0 | for (uint32_t j = 0; j < NBUCKETS_8; j++) { |
162 | 0 | int32_t min_indices_scalar[8]; |
163 | 0 | float min_distances_scalar[8]; |
164 | |
|
165 | 0 | _mm256_storeu_si256( |
166 | 0 | (__m256i*)min_indices_scalar, min_indices_i[j][p]); |
167 | 0 | _mm256_storeu_ps( |
168 | 0 | min_distances_scalar, min_distances_i[j][p]); |
169 | | |
170 | | // this exact way is needed to maintain the order as if the |
171 | | // input elements were pushed to the heap sequentially |
172 | 0 | for (size_t j8 = 0; j8 < 8; j8++) { |
173 | 0 | const auto value = min_distances_scalar[j8]; |
174 | 0 | const auto index = min_indices_scalar[j8]; |
175 | 0 | if (C::cmp2(bh_val[0], value, bh_ids[0], index)) { |
176 | 0 | heap_replace_top<C>( |
177 | 0 | k, bh_val, bh_ids, value, index); |
178 | 0 | } |
179 | 0 | } |
180 | 0 | } |
181 | 0 | } |
182 | | |
183 | | // process leftovers |
184 | 0 | for (uint32_t ip = nb; ip < n_per_beam; ip++) { |
185 | 0 | const int32_t index = ip + n_per_beam * beam_index; |
186 | 0 | const float value = distances[index]; |
187 | |
|
188 | 0 | if (C::cmp(bh_val[0], value)) { |
189 | 0 | heap_replace_top<C>(k, bh_val, bh_ids, value, index); |
190 | 0 | } |
191 | 0 | } |
192 | 0 | } |
193 | 0 | } Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj16ELj1EE7bs_addnEjjPKfjPfPi Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj8ELj3EE7bs_addnEjjPKfjPfPi Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj8ELj2EE7bs_addnEjjPKfjPfPi Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj16ELj2EE7bs_addnEjjPKfjPfPi Unexecuted instantiation: _ZN5faiss15HeapWithBucketsINS_4CMaxIfiEELj32ELj2EE7bs_addnEjjPKfjPfPi |
194 | | }; |
195 | | |
196 | | } // namespace faiss |