contrib/faiss/faiss/impl/code_distance/code_distance-avx2.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
10 | | #ifdef __AVX2__ |
11 | | |
12 | | #include <immintrin.h> |
13 | | |
14 | | #include <type_traits> |
15 | | |
16 | | #include <faiss/impl/ProductQuantizer.h> |
17 | | #include <faiss/impl/code_distance/code_distance-generic.h> |
18 | | |
19 | | // GCC-compatible compilers reporting __GNUC__ < 9 get a local _mm_loadu_si64 fallback built on _mm_loadl_epi64; background: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782 -- NOTE(review): the macro argument x is not parenthesized inside the cast |
20 | | #if defined(__GNUC__) && __GNUC__ < 9 |
21 | 0 | #define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x)) |
22 | | #endif |
23 | | |
24 | | namespace { |
25 | | |
26 | 0 | inline float horizontal_sum(const __m128 v) { /* returns v[0] + v[1] + v[2] + v[3] */ |
27 | 0 | const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); /* lanes 0 and 1 of v0 hold lanes 2 and 3 of v */ |
28 | 0 | const __m128 v1 = _mm_add_ps(v, v0); /* lane 0 = v[0] + v[2], lane 1 = v[1] + v[3] */ |
29 | 0 | __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); /* lane 0 of v2 holds lane 1 of v1 */ |
30 | 0 | const __m128 v3 = _mm_add_ps(v1, v2); /* lane 0 now holds the sum of all four lanes */ |
31 | 0 | return _mm_cvtss_f32(v3); /* extract lane 0 as a scalar */ |
32 | 0 | } Unexecuted instantiation: IndexIVFPQ.cpp:_ZN12_GLOBAL__N_114horizontal_sumEDv4_f Unexecuted instantiation: IndexPQ.cpp:_ZN12_GLOBAL__N_114horizontal_sumEDv4_f |
33 | | |
34 | | // Computes a horizontal sum over an __m256 register: the upper 128-bit half is added lane-wise to the lower half, and the resulting __m128 is reduced by the __m128 overload of horizontal_sum |
35 | 0 | inline float horizontal_sum(const __m256 v) { |
36 | 0 | const __m128 v0 = |
37 | 0 | _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); |
38 | 0 | return horizontal_sum(v0); |
39 | 0 | } Unexecuted instantiation: IndexIVFPQ.cpp:_ZN12_GLOBAL__N_114horizontal_sumEDv8_f Unexecuted instantiation: IndexPQ.cpp:_ZN12_GLOBAL__N_114horizontal_sumEDv8_f |
40 | | |
41 | | // processes a single code for M=4, ksub=256, nbits=8: returns the sum over m = 0..3 of sim_table[m * 256 + code[m]] |
42 | | float inline distance_single_code_avx2_pqdecoder8_m4( |
43 | | // precomputed distances, layout (4, 256) |
44 | | const float* sim_table, |
45 | 0 | const uint8_t* code) { |
46 | 0 | float result = 0; |
47 |

48 | 0 | const float* tab = sim_table; |
49 | 0 | constexpr size_t ksub = 1 << 8; |
50 |

51 | 0 | const __m128i vksub = _mm_set1_epi32(ksub); |
52 | 0 | __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); |
53 | 0 | offsets_0 = _mm_mullo_epi32(offsets_0, vksub); /* offsets_0 = {0, 256, 512, 768}: first element of each sub-quantizer row of the table */ |
54 | | |
55 | | // holds the four gathered table entries, one per sub-quantizer |
56 | 0 | __m128 partialSum; |
57 | | |
58 | | // load 4 uint8 values with a single 32-bit read; NOTE(review): the read goes through an int32_t* cast of code, so 4 bytes must be readable at code -- confirm that alignment and aliasing are acceptable for the targeted compilers |
59 | 0 | const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code)); |
60 | 0 | { |
61 | | // convert uint8 values (low part of __m128i) to int32 |
62 | | // values |
63 | 0 | const __m128i idx1 = _mm_cvtepu8_epi32(mm1); |
64 | | |
65 | | // add offsets |
66 | 0 | const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); |
67 | | |
68 | | // gather 4 values, similar to 4 operations of tab[idx] |
69 | 0 | __m128 collected = |
70 | 0 | _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); |
71 | | |
72 | | // a single gather covers all of M=4, so there is nothing to accumulate |
73 | 0 | partialSum = collected; |
74 | 0 | } |
75 | | |
76 | | // horizontal sum for partialSum |
77 | 0 | result = horizontal_sum(partialSum); |
78 | 0 | return result; |
79 | 0 | } Unexecuted instantiation: IndexIVFPQ.cpp:_ZN12_GLOBAL__N_139distance_single_code_avx2_pqdecoder8_m4EPKfPKh Unexecuted instantiation: IndexPQ.cpp:_ZN12_GLOBAL__N_139distance_single_code_avx2_pqdecoder8_m4EPKfPKh |
80 | | |
81 | | // processes a single code for M=8, ksub=256, nbits=8: returns the sum over m = 0..7 of sim_table[m * 256 + code[m]] |
82 | | float inline distance_single_code_avx2_pqdecoder8_m8( |
83 | | // precomputed distances, layout (8, 256) |
84 | | const float* sim_table, |
85 | 0 | const uint8_t* code) { |
86 | 0 | float result = 0; |
87 |

88 | 0 | const float* tab = sim_table; |
89 | 0 | constexpr size_t ksub = 1 << 8; |
90 |

91 | 0 | const __m256i vksub = _mm256_set1_epi32(ksub); |
92 | 0 | __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); |
93 | 0 | offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); /* offsets_0[i] = i * 256: first element of sub-quantizer row i of the table */ |
94 | | |
95 | | // holds the eight gathered table entries, one per sub-quantizer |
96 | 0 | __m256 partialSum; |
97 | | |
98 | | // load 8 uint8 values with one unaligned 64-bit load (8 bytes must be readable at code) |
99 | 0 | const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code); |
100 | 0 | { |
101 | | // convert uint8 values (low part of __m128i) to int32 |
102 | | // values |
103 | 0 | const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); |
104 | | |
105 | | // add offsets |
106 | 0 | const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); |
107 | | |
108 | | // gather 8 values, similar to 8 operations of tab[idx] |
109 | 0 | __m256 collected = |
110 | 0 | _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); |
111 | | |
112 | | // a single gather covers all of M=8, so there is nothing to accumulate |
113 | 0 | partialSum = collected; |
114 | 0 | } |
115 | | |
116 | | // horizontal sum for partialSum |
117 | 0 | result = horizontal_sum(partialSum); |
118 | 0 | return result; |
119 | 0 | } Unexecuted instantiation: IndexIVFPQ.cpp:_ZN12_GLOBAL__N_139distance_single_code_avx2_pqdecoder8_m8EPKfPKh Unexecuted instantiation: IndexPQ.cpp:_ZN12_GLOBAL__N_139distance_single_code_avx2_pqdecoder8_m8EPKfPKh |
120 | | |
121 | | // processes four codes for M=4, ksub=256, nbits=8: for each code k, resultk is overwritten with the sum over m = 0..3 of sim_table[m * 256 + codek[m]] |
122 | | inline void distance_four_codes_avx2_pqdecoder8_m4( |
123 | | // precomputed distances, layout (4, 256) |
124 | | const float* sim_table, |
125 | | // codes |
126 | | const uint8_t* __restrict code0, |
127 | | const uint8_t* __restrict code1, |
128 | | const uint8_t* __restrict code2, |
129 | | const uint8_t* __restrict code3, |
130 | | // computed distances |
131 | | float& result0, |
132 | | float& result1, |
133 | | float& result2, |
134 | 0 | float& result3) { |
135 | 0 | constexpr intptr_t N = 4; |
136 |

137 | 0 | const float* tab = sim_table; |
138 | 0 | constexpr size_t ksub = 1 << 8; |
139 | | |
140 | | // process 4 values per code; offsets_0 ends up as {0, 256, 512, 768} |
141 | 0 | const __m128i vksub = _mm_set1_epi32(ksub); |
142 | 0 | __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3); |
143 | 0 | offsets_0 = _mm_mullo_epi32(offsets_0, vksub); |
144 | | |
145 | | // one vector of gathered table entries per code |
146 | 0 | __m128 partialSums[N]; |
147 | | |
148 | | // load 4 uint8 values per code with a single 32-bit read each; NOTE(review): the reads go through int32_t* casts, so 4 bytes must be readable at every code pointer -- confirm that alignment and aliasing are acceptable for the targeted compilers |
149 | 0 | __m128i mm1[N]; |
150 | 0 | mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0)); |
151 | 0 | mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1)); |
152 | 0 | mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2)); |
153 | 0 | mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3)); |
154 |

155 | 0 | for (intptr_t j = 0; j < N; j++) { |
156 | | // convert uint8 values (low part of __m128i) to int32 |
157 | | // values |
158 | 0 | const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]); |
159 | | |
160 | | // add offsets |
161 | 0 | const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0); |
162 | | |
163 | | // gather 4 values, similar to 4 operations of tab[idx] |
164 | 0 | __m128 collected = |
165 | 0 | _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float)); |
166 | | |
167 | | // a single gather covers all of M=4 for code j, so there is nothing to accumulate |
168 | 0 | partialSums[j] = collected; |
169 | 0 | } |
170 | | |
171 | | // horizontal sum of each code's gathered values |
172 | 0 | result0 = horizontal_sum(partialSums[0]); |
173 | 0 | result1 = horizontal_sum(partialSums[1]); |
174 | 0 | result2 = horizontal_sum(partialSums[2]); |
175 | 0 | result3 = horizontal_sum(partialSums[3]); |
176 | 0 | } Unexecuted instantiation: IndexIVFPQ.cpp:_ZN12_GLOBAL__N_138distance_four_codes_avx2_pqdecoder8_m4EPKfPKhS3_S3_S3_RfS4_S4_S4_ Unexecuted instantiation: IndexPQ.cpp:_ZN12_GLOBAL__N_138distance_four_codes_avx2_pqdecoder8_m4EPKfPKhS3_S3_S3_RfS4_S4_S4_ |
177 | | |
178 | | // processes four codes for M=8, ksub=256, nbits=8: for each code k, resultk is overwritten with the sum over m = 0..7 of sim_table[m * 256 + codek[m]] |
179 | | inline void distance_four_codes_avx2_pqdecoder8_m8( |
180 | | // precomputed distances, layout (8, 256) |
181 | | const float* sim_table, |
182 | | // codes |
183 | | const uint8_t* __restrict code0, |
184 | | const uint8_t* __restrict code1, |
185 | | const uint8_t* __restrict code2, |
186 | | const uint8_t* __restrict code3, |
187 | | // computed distances |
188 | | float& result0, |
189 | | float& result1, |
190 | | float& result2, |
191 | 0 | float& result3) { |
192 | 0 | constexpr intptr_t N = 4; |
193 |

194 | 0 | const float* tab = sim_table; |
195 | 0 | constexpr size_t ksub = 1 << 8; |
196 | | |
197 | | // process 8 values per code; offsets_0[i] ends up as i * 256 |
198 | 0 | const __m256i vksub = _mm256_set1_epi32(ksub); |
199 | 0 | __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); |
200 | 0 | offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); |
201 | | |
202 | | // one vector of gathered table entries per code |
203 | 0 | __m256 partialSums[N]; |
204 | | |
205 | | // load 8 uint8 values per code with one unaligned 64-bit load each |
206 | 0 | __m128i mm1[N]; |
207 | 0 | mm1[0] = _mm_loadu_si64((const __m128i_u*)code0); |
208 | 0 | mm1[1] = _mm_loadu_si64((const __m128i_u*)code1); |
209 | 0 | mm1[2] = _mm_loadu_si64((const __m128i_u*)code2); |
210 | 0 | mm1[3] = _mm_loadu_si64((const __m128i_u*)code3); |
211 |

212 | 0 | for (intptr_t j = 0; j < N; j++) { |
213 | | // convert uint8 values (low part of __m128i) to int32 |
214 | | // values |
215 | 0 | const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); |
216 | | |
217 | | // add offsets |
218 | 0 | const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0); |
219 | | |
220 | | // gather 8 values, similar to 8 operations of tab[idx] |
221 | 0 | __m256 collected = |
222 | 0 | _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float)); |
223 | | |
224 | | // a single gather covers all of M=8 for code j, so there is nothing to accumulate |
225 | 0 | partialSums[j] = collected; |
226 | 0 | } |
227 | | |
228 | | // horizontal sum of each code's gathered values |
229 | 0 | result0 = horizontal_sum(partialSums[0]); |
230 | 0 | result1 = horizontal_sum(partialSums[1]); |
231 | 0 | result2 = horizontal_sum(partialSums[2]); |
232 | 0 | result3 = horizontal_sum(partialSums[3]); |
233 | 0 | } Unexecuted instantiation: IndexIVFPQ.cpp:_ZN12_GLOBAL__N_138distance_four_codes_avx2_pqdecoder8_m8EPKfPKhS3_S3_S3_RfS4_S4_S4_ Unexecuted instantiation: IndexPQ.cpp:_ZN12_GLOBAL__N_138distance_four_codes_avx2_pqdecoder8_m8EPKfPKhS3_S3_S3_RfS4_S4_S4_ |
234 | | |
235 | | } // namespace |
236 | | |
237 | | namespace faiss { |
238 | | |
239 | | template <typename PQDecoderT> /* overload selected when PQDecoderT is not PQDecoder8 */ |
240 | | typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, float>:: |
241 | | type inline distance_single_code_avx2( |
242 | | // number of subquantizers |
243 | | const size_t M, |
244 | | // number of bits per quantization index |
245 | | const size_t nbits, |
246 | | // precomputed distances, layout (M, ksub) |
247 | | const float* sim_table, |
248 | 0 | const uint8_t* code) { |
249 | | // this overload has no AVX2-specific code: forward all arguments unchanged to the generic implementation and return its result |
250 | 0 | return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code); |
251 | 0 | } Unexecuted instantiation: _ZN5faiss25distance_single_code_avx2INS_11PQDecoder16EEENSt9enable_ifIXntsr3std7is_sameIT_NS_10PQDecoder8EEE5valueEfE4typeEmmPKfPKh Unexecuted instantiation: _ZN5faiss25distance_single_code_avx2INS_16PQDecoderGenericEEENSt9enable_ifIXntsr3std7is_sameIT_NS_10PQDecoder8EEE5valueEfE4typeEmmPKfPKh |
252 | | |
253 | | template <typename PQDecoderT> /* overload selected when PQDecoderT is PQDecoder8: M == 4 and M == 8 use dedicated kernels, other M use a 16-wide AVX2 loop plus a scalar tail */ |
254 | | typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>:: |
255 | | type inline distance_single_code_avx2( |
256 | | // number of subquantizers |
257 | | const size_t M, |
258 | | // number of bits per quantization index |
259 | | const size_t nbits, |
260 | | // precomputed distances, layout (M, ksub) |
261 | | const float* sim_table, |
262 | 0 | const uint8_t* code) { |
263 | 0 | if (M == 4) { |
264 | 0 | return distance_single_code_avx2_pqdecoder8_m4(sim_table, code); |
265 | 0 | } |
266 | 0 | if (M == 8) { |
267 | 0 | return distance_single_code_avx2_pqdecoder8_m8(sim_table, code); |
268 | 0 | } |
269 | | |
270 | 0 | float result = 0; |
271 | 0 | constexpr size_t ksub = 1 << 8; |
272 |

273 | 0 | size_t m = 0; |
274 | 0 | const size_t pqM16 = M / 16; |
275 |

276 | 0 | const float* tab = sim_table; |
277 |

278 | 0 | if (pqM16 > 0) { |
279 | | // process 16 values per loop |
280 |

281 | 0 | const __m256i vksub = _mm256_set1_epi32(ksub); |
282 | 0 | __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); |
283 | 0 | offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); /* offsets_0[i] = i * 256, relative to the current tab position */ |
284 | | |
285 | | // accumulators of partial sums |
286 | 0 | __m256 partialSum = _mm256_setzero_ps(); |
287 | | |
288 | | // loop over complete groups of 16 sub-quantizers |
289 | 0 | for (m = 0; m < pqM16 * 16; m += 16) { |
290 | | // load 16 uint8 values |
291 | 0 | const __m128i mm1 = _mm_loadu_si128((const __m128i_u*)(code + m)); |
292 | 0 | { |
293 | | // convert uint8 values (low part of __m128i) to int32 |
294 | | // values |
295 | 0 | const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); |
296 | | |
297 | | // add offsets |
298 | 0 | const __m256i indices_to_read_from = |
299 | 0 | _mm256_add_epi32(idx1, offsets_0); |
300 | | |
301 | | // gather 8 values, similar to 8 operations of tab[idx] |
302 | 0 | __m256 collected = _mm256_i32gather_ps( |
303 | 0 | tab, indices_to_read_from, sizeof(float)); |
304 | 0 | tab += ksub * 8; /* step past the 8 table rows just read */ |
305 | | |
306 | | // collect partial sums |
307 | 0 | partialSum = _mm256_add_ps(partialSum, collected); |
308 | 0 | } |
309 | | |
310 | | // move high 8 uint8 to low ones |
311 | 0 | const __m128i mm2 = _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); |
312 | 0 | { |
313 | | // convert uint8 values (low part of __m128i) to int32 |
314 | | // values |
315 | 0 | const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); |
316 | | |
317 | | // add offsets |
318 | 0 | const __m256i indices_to_read_from = |
319 | 0 | _mm256_add_epi32(idx1, offsets_0); |
320 | | |
321 | | // gather 8 values, similar to 8 operations of tab[idx] |
322 | 0 | __m256 collected = _mm256_i32gather_ps( |
323 | 0 | tab, indices_to_read_from, sizeof(float)); |
324 | 0 | tab += ksub * 8; /* step past the 8 table rows just read */ |
325 | | |
326 | | // collect partial sums |
327 | 0 | partialSum = _mm256_add_ps(partialSum, collected); |
328 | 0 | } |
329 | 0 | } |
330 | | |
331 | | // horizontal sum for partialSum |
332 | 0 | result += horizontal_sum(partialSum); |
333 | 0 | } |
334 | | |
335 | | // scalar tail: handles the M % 16 sub-quantizers left after the vector loop (all of them when M < 16); tab already points at the row for sub-quantizer m |
336 | 0 | if (m < M) { |
337 | | // process leftovers; presumably each decode() call yields the next 8-bit index of the code -- confirm against PQDecoder8 |
338 | 0 | PQDecoder8 decoder(code + m, nbits); |
339 |

340 | 0 | for (; m < M; m++) { |
341 | 0 | result += tab[decoder.decode()]; |
342 | 0 | tab += ksub; |
343 | 0 | } |
344 | 0 | } |
345 |

346 | 0 | return result; |
347 | 0 | } |
348 | | |
349 | | template <typename PQDecoderT> /* overload selected when PQDecoderT is not PQDecoder8: forwards every argument unchanged to distance_four_codes_generic */ |
350 | | typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>:: |
351 | | type |
352 | | distance_four_codes_avx2( |
353 | | // number of subquantizers |
354 | | const size_t M, |
355 | | // number of bits per quantization index |
356 | | const size_t nbits, |
357 | | // precomputed distances, layout (M, ksub) |
358 | | const float* sim_table, |
359 | | // codes |
360 | | const uint8_t* __restrict code0, |
361 | | const uint8_t* __restrict code1, |
362 | | const uint8_t* __restrict code2, |
363 | | const uint8_t* __restrict code3, |
364 | | // computed distances (passed by reference to the generic implementation) |
365 | | float& result0, |
366 | | float& result1, |
367 | | float& result2, |
368 | 0 | float& result3) { |
369 | 0 | distance_four_codes_generic<PQDecoderT>( |
370 | 0 | M, |
371 | 0 | nbits, |
372 | 0 | sim_table, |
373 | 0 | code0, |
374 | 0 | code1, |
375 | 0 | code2, |
376 | 0 | code3, |
377 | 0 | result0, |
378 | 0 | result1, |
379 | 0 | result2, |
380 | 0 | result3); |
381 | 0 | } Unexecuted instantiation: _ZN5faiss24distance_four_codes_avx2INS_11PQDecoder16EEENSt9enable_ifIXntsr3std7is_sameIT_NS_10PQDecoder8EEE5valueEvE4typeEmmPKfPKhSA_SA_SA_RfSB_SB_SB_ Unexecuted instantiation: _ZN5faiss24distance_four_codes_avx2INS_16PQDecoderGenericEEENSt9enable_ifIXntsr3std7is_sameIT_NS_10PQDecoder8EEE5valueEvE4typeEmmPKfPKhSA_SA_SA_RfSB_SB_SB_ |
382 | | |
383 | | // Combines 4 operations of distance_single_code(): overload selected when PQDecoderT is PQDecoder8; M == 4 and M == 8 use dedicated kernels, other M use a 16-wide AVX2 loop plus a scalar tail |
384 | | template <typename PQDecoderT> |
385 | | typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, void>::type |
386 | | distance_four_codes_avx2( |
387 | | // number of subquantizers |
388 | | const size_t M, |
389 | | // number of bits per quantization index |
390 | | const size_t nbits, |
391 | | // precomputed distances, layout (M, ksub) |
392 | | const float* sim_table, |
393 | | // codes |
394 | | const uint8_t* __restrict code0, |
395 | | const uint8_t* __restrict code1, |
396 | | const uint8_t* __restrict code2, |
397 | | const uint8_t* __restrict code3, |
398 | | // computed distances; every path overwrites all four (any previous value is discarded) |
399 | | float& result0, |
400 | | float& result1, |
401 | | float& result2, |
402 | 0 | float& result3) { |
403 | 0 | if (M == 4) { |
404 | 0 | distance_four_codes_avx2_pqdecoder8_m4( |
405 | 0 | sim_table, |
406 | 0 | code0, |
407 | 0 | code1, |
408 | 0 | code2, |
409 | 0 | code3, |
410 | 0 | result0, |
411 | 0 | result1, |
412 | 0 | result2, |
413 | 0 | result3); |
414 | 0 | return; |
415 | 0 | } |
416 | 0 | if (M == 8) { |
417 | 0 | distance_four_codes_avx2_pqdecoder8_m8( |
418 | 0 | sim_table, |
419 | 0 | code0, |
420 | 0 | code1, |
421 | 0 | code2, |
422 | 0 | code3, |
423 | 0 | result0, |
424 | 0 | result1, |
425 | 0 | result2, |
426 | 0 | result3); |
427 | 0 | return; |
428 | 0 | } |
429 | | |
430 | 0 | result0 = 0; |
431 | 0 | result1 = 0; |
432 | 0 | result2 = 0; |
433 | 0 | result3 = 0; |
434 | 0 | constexpr size_t ksub = 1 << 8; |
435 |

436 | 0 | size_t m = 0; |
437 | 0 | const size_t pqM16 = M / 16; |
438 |

439 | 0 | constexpr intptr_t N = 4; |
440 |

441 | 0 | const float* tab = sim_table; |
442 |

443 | 0 | if (pqM16 > 0) { |
444 | | // process 16 values per loop; offsets_0[i] ends up as i * 256, relative to the current tab position |
445 | 0 | const __m256i vksub = _mm256_set1_epi32(ksub); |
446 | 0 | __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); |
447 | 0 | offsets_0 = _mm256_mullo_epi32(offsets_0, vksub); |
448 | | |
449 | | // accumulators of partial sums, one per code |
450 | 0 | __m256 partialSums[N]; |
451 | 0 | for (intptr_t j = 0; j < N; j++) { |
452 | 0 | partialSums[j] = _mm256_setzero_ps(); |
453 | 0 | } |
454 | | |
455 | | // loop over complete groups of 16 sub-quantizers |
456 | 0 | for (m = 0; m < pqM16 * 16; m += 16) { |
457 | | // load 16 uint8 values from each of the four codes |
458 | 0 | __m128i mm1[N]; |
459 | 0 | mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); |
460 | 0 | mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); |
461 | 0 | mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); |
462 | 0 | mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); |
463 | | |
464 | | // process sub-quantizers m .. m+7 (the low 8 bytes) of each of the four codes |
465 | 0 | for (intptr_t j = 0; j < N; j++) { |
466 | | // convert uint8 values (low part of __m128i) to int32 |
467 | | // values |
468 | 0 | const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]); |
469 | | |
470 | | // add offsets |
471 | 0 | const __m256i indices_to_read_from = |
472 | 0 | _mm256_add_epi32(idx1, offsets_0); |
473 | | |
474 | | // gather 8 values, similar to 8 operations of tab[idx] |
475 | 0 | __m256 collected = _mm256_i32gather_ps( |
476 | 0 | tab, indices_to_read_from, sizeof(float)); |
477 | | |
478 | | // collect partial sums |
479 | 0 | partialSums[j] = _mm256_add_ps(partialSums[j], collected); |
480 | 0 | } |
481 | 0 | tab += ksub * 8; /* step past the 8 table rows shared by the four gathers above */ |
482 | | |
483 | | // process sub-quantizers m+8 .. m+15 (the high 8 bytes) of each of the four codes |
484 | 0 | for (intptr_t j = 0; j < N; j++) { |
485 | | // move high 8 uint8 to low ones |
486 | 0 | const __m128i mm2 = |
487 | 0 | _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128()); |
488 | | |
489 | | // convert uint8 values (low part of __m128i) to int32 |
490 | | // values |
491 | 0 | const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); |
492 | | |
493 | | // add offsets |
494 | 0 | const __m256i indices_to_read_from = |
495 | 0 | _mm256_add_epi32(idx1, offsets_0); |
496 | | |
497 | | // gather 8 values, similar to 8 operations of tab[idx] |
498 | 0 | __m256 collected = _mm256_i32gather_ps( |
499 | 0 | tab, indices_to_read_from, sizeof(float)); |
500 | | |
501 | | // collect partial sums |
502 | 0 | partialSums[j] = _mm256_add_ps(partialSums[j], collected); |
503 | 0 | } |
504 |

505 | 0 | tab += ksub * 8; /* step past the 8 table rows shared by the four gathers above */ |
506 | 0 | } |
507 | | |
508 | | // horizontal sum of each code's accumulator |
509 | 0 | result0 += horizontal_sum(partialSums[0]); |
510 | 0 | result1 += horizontal_sum(partialSums[1]); |
511 | 0 | result2 += horizontal_sum(partialSums[2]); |
512 | 0 | result3 += horizontal_sum(partialSums[3]); |
513 | 0 | } |
514 | | |
515 | | // scalar tail: handles the M % 16 sub-quantizers left after the vector loop (all of them when M < 16); tab already points at the row for sub-quantizer m |
516 | 0 | if (m < M) { |
517 | | // process leftovers; presumably each decode() call yields the next 8-bit index of its code -- confirm against PQDecoder8 |
518 | 0 | PQDecoder8 decoder0(code0 + m, nbits); |
519 | 0 | PQDecoder8 decoder1(code1 + m, nbits); |
520 | 0 | PQDecoder8 decoder2(code2 + m, nbits); |
521 | 0 | PQDecoder8 decoder3(code3 + m, nbits); |
522 | 0 | for (; m < M; m++) { |
523 | 0 | result0 += tab[decoder0.decode()]; |
524 | 0 | result1 += tab[decoder1.decode()]; |
525 | 0 | result2 += tab[decoder2.decode()]; |
526 | 0 | result3 += tab[decoder3.decode()]; |
527 | 0 | tab += ksub; |
528 | 0 | } |
529 | 0 | } |
530 | 0 | } |
531 | | |
532 | | } // namespace faiss |
533 | | |
534 | | #endif |