contrib/faiss/faiss/utils/hamming_distance/avx2-inl.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #ifndef HAMMING_AVX2_INL_H |
9 | | #define HAMMING_AVX2_INL_H |
10 | | |
11 | | // AVX2 version |
12 | | |
13 | | #include <cassert> |
14 | | #include <cstddef> |
15 | | #include <cstdint> |
16 | | |
17 | | #include <faiss/impl/platform_macros.h> |
18 | | |
19 | | #include <immintrin.h> |
20 | | |
21 | | namespace faiss { |
22 | | |
23 | | /* Elementary Hamming distance computation: unoptimized */ |
24 | | template <size_t nbits, typename T> |
25 | | inline T hamming(const uint8_t* bs1, const uint8_t* bs2) { |
26 | | const size_t nbytes = nbits / 8; |
27 | | size_t i; |
28 | | T h = 0; |
29 | | for (i = 0; i < nbytes; i++) { |
30 | | h += (T)hamdis_tab_ham_bytes[bs1[i] ^ bs2[i]]; |
31 | | } |
32 | | return h; |
33 | | } |
34 | | |
35 | | /* Hamming distances for multiples of 64 bits */ |
36 | | template <size_t nbits> |
37 | 0 | inline hamdis_t hamming(const uint64_t* bs1, const uint64_t* bs2) { |
38 | 0 | const size_t nwords = nbits / 64; |
39 | 0 | size_t i; |
40 | 0 | hamdis_t h = 0; |
41 | 0 | for (i = 0; i < nwords; i++) { |
42 | 0 | h += popcount64(bs1[i] ^ bs2[i]); |
43 | 0 | } |
44 | 0 | return h; |
45 | 0 | } |
46 | | |
47 | | /* specialized (optimized) functions */ |
48 | | template <> |
49 | 0 | inline hamdis_t hamming<64>(const uint64_t* pa, const uint64_t* pb) { |
50 | 0 | return popcount64(pa[0] ^ pb[0]); |
51 | 0 | } |
52 | | |
53 | | template <> |
54 | 0 | inline hamdis_t hamming<128>(const uint64_t* pa, const uint64_t* pb) { |
55 | 0 | return popcount64(pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]); |
56 | 0 | } |
57 | | |
58 | | template <> |
59 | 0 | inline hamdis_t hamming<256>(const uint64_t* pa, const uint64_t* pb) { |
60 | 0 | return popcount64(pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]) + |
61 | 0 | popcount64(pa[2] ^ pb[2]) + popcount64(pa[3] ^ pb[3]); |
62 | 0 | } |
63 | | |
64 | | /* Hamming distances for multiple of 64 bits */ |
65 | | inline hamdis_t hamming( |
66 | | const uint64_t* bs1, |
67 | | const uint64_t* bs2, |
68 | 0 | size_t nwords) { |
69 | 0 | hamdis_t h = 0; |
70 | 0 | for (size_t i = 0; i < nwords; i++) { |
71 | 0 | h += popcount64(bs1[i] ^ bs2[i]); |
72 | 0 | } |
73 | 0 | return h; |
74 | 0 | } |
75 | | |
76 | | /****************************************************************** |
77 | | * The HammingComputer series of classes compares a single code of |
78 | | * size 4 to 32 to incoming codes. They are intended for use as a |
79 | | * template class where it would be inefficient to switch on the code |
80 | | * size in the inner loop. Hopefully the compiler will inline the |
81 | | * hamming() functions and put the a0, a1, ... in registers. |
82 | | ******************************************************************/ |
83 | | |
84 | | struct HammingComputer4 { |
85 | | uint32_t a0; |
86 | | |
87 | 0 | HammingComputer4() {} |
88 | | |
89 | 0 | HammingComputer4(const uint8_t* a, int code_size) { |
90 | 0 | set(a, code_size); |
91 | 0 | } |
92 | | |
93 | 0 | void set(const uint8_t* a, int code_size) { |
94 | 0 | assert(code_size == 4); |
95 | 0 | a0 = *(uint32_t*)a; |
96 | 0 | } |
97 | | |
98 | 0 | inline int hamming(const uint8_t* b) const { |
99 | 0 | return popcount64(*(uint32_t*)b ^ a0); |
100 | 0 | } |
101 | | |
102 | 0 | inline static constexpr int get_code_size() { |
103 | 0 | return 4; |
104 | 0 | } |
105 | | }; |
106 | | |
107 | | struct HammingComputer8 { |
108 | | uint64_t a0; |
109 | | |
110 | 0 | HammingComputer8() {} |
111 | | |
112 | 0 | HammingComputer8(const uint8_t* a, int code_size) { |
113 | 0 | set(a, code_size); |
114 | 0 | } |
115 | | |
116 | 0 | void set(const uint8_t* a, int code_size) { |
117 | 0 | assert(code_size == 8); |
118 | 0 | a0 = *(uint64_t*)a; |
119 | 0 | } |
120 | | |
121 | 0 | inline int hamming(const uint8_t* b) const { |
122 | 0 | return popcount64(*(uint64_t*)b ^ a0); |
123 | 0 | } |
124 | | |
125 | 0 | inline static constexpr int get_code_size() { |
126 | 0 | return 8; |
127 | 0 | } |
128 | | }; |
129 | | |
130 | | struct HammingComputer16 { |
131 | | uint64_t a0, a1; |
132 | | |
133 | 0 | HammingComputer16() {} |
134 | | |
135 | 0 | HammingComputer16(const uint8_t* a8, int code_size) { |
136 | 0 | set(a8, code_size); |
137 | 0 | } |
138 | | |
139 | 0 | void set(const uint8_t* a8, int code_size) { |
140 | 0 | assert(code_size == 16); |
141 | 0 | const uint64_t* a = (uint64_t*)a8; |
142 | 0 | a0 = a[0]; |
143 | 0 | a1 = a[1]; |
144 | 0 | } |
145 | | |
146 | 0 | inline int hamming(const uint8_t* b8) const { |
147 | 0 | const uint64_t* b = (uint64_t*)b8; |
148 | 0 | return popcount64(b[0] ^ a0) + popcount64(b[1] ^ a1); |
149 | 0 | } |
150 | | |
151 | 0 | inline static constexpr int get_code_size() { |
152 | 0 | return 16; |
153 | 0 | } |
154 | | }; |
155 | | |
156 | | // when applied to an array, 1/2 of the 64-bit accesses are unaligned. |
157 | | // This incurs a penalty of ~10% wrt. fully aligned accesses. |
158 | | struct HammingComputer20 { |
159 | | uint64_t a0, a1; |
160 | | uint32_t a2; |
161 | | |
162 | 0 | HammingComputer20() {} |
163 | | |
164 | 0 | HammingComputer20(const uint8_t* a8, int code_size) { |
165 | 0 | set(a8, code_size); |
166 | 0 | } |
167 | | |
168 | 0 | void set(const uint8_t* a8, int code_size) { |
169 | 0 | assert(code_size == 20); |
170 | 0 | const uint64_t* a = (uint64_t*)a8; |
171 | 0 | a0 = a[0]; |
172 | 0 | a1 = a[1]; |
173 | 0 | a2 = a[2]; |
174 | 0 | } |
175 | | |
176 | 0 | inline int hamming(const uint8_t* b8) const { |
177 | 0 | const uint64_t* b = (uint64_t*)b8; |
178 | 0 | return popcount64(b[0] ^ a0) + popcount64(b[1] ^ a1) + |
179 | 0 | popcount64(*(uint32_t*)(b + 2) ^ a2); |
180 | 0 | } |
181 | | |
182 | 0 | inline static constexpr int get_code_size() { |
183 | 0 | return 20; |
184 | 0 | } |
185 | | }; |
186 | | |
187 | | struct HammingComputer32 { |
188 | | uint64_t a0, a1, a2, a3; |
189 | | |
190 | 0 | HammingComputer32() {} |
191 | | |
192 | 0 | HammingComputer32(const uint8_t* a8, int code_size) { |
193 | 0 | set(a8, code_size); |
194 | 0 | } |
195 | | |
196 | 0 | void set(const uint8_t* a8, int code_size) { |
197 | 0 | assert(code_size == 32); |
198 | 0 | const uint64_t* a = (uint64_t*)a8; |
199 | 0 | a0 = a[0]; |
200 | 0 | a1 = a[1]; |
201 | 0 | a2 = a[2]; |
202 | 0 | a3 = a[3]; |
203 | 0 | } |
204 | | |
205 | 0 | inline int hamming(const uint8_t* b8) const { |
206 | 0 | const uint64_t* b = (uint64_t*)b8; |
207 | 0 | return popcount64(b[0] ^ a0) + popcount64(b[1] ^ a1) + |
208 | 0 | popcount64(b[2] ^ a2) + popcount64(b[3] ^ a3); |
209 | 0 | } |
210 | | |
211 | 0 | inline static constexpr int get_code_size() { |
212 | 0 | return 32; |
213 | 0 | } |
214 | | }; |
215 | | |
216 | | struct HammingComputer64 { |
217 | | uint64_t a0, a1, a2, a3, a4, a5, a6, a7; |
218 | | |
219 | 0 | HammingComputer64() {} |
220 | | |
221 | 0 | HammingComputer64(const uint8_t* a8, int code_size) { |
222 | 0 | set(a8, code_size); |
223 | 0 | } |
224 | | |
225 | 0 | void set(const uint8_t* a8, int code_size) { |
226 | 0 | assert(code_size == 64); |
227 | 0 | const uint64_t* a = (uint64_t*)a8; |
228 | 0 | a0 = a[0]; |
229 | 0 | a1 = a[1]; |
230 | 0 | a2 = a[2]; |
231 | 0 | a3 = a[3]; |
232 | 0 | a4 = a[4]; |
233 | 0 | a5 = a[5]; |
234 | 0 | a6 = a[6]; |
235 | 0 | a7 = a[7]; |
236 | 0 | } |
237 | | |
238 | 0 | inline int hamming(const uint8_t* b8) const { |
239 | 0 | const uint64_t* b = (uint64_t*)b8; |
240 | 0 | return popcount64(b[0] ^ a0) + popcount64(b[1] ^ a1) + |
241 | 0 | popcount64(b[2] ^ a2) + popcount64(b[3] ^ a3) + |
242 | 0 | popcount64(b[4] ^ a4) + popcount64(b[5] ^ a5) + |
243 | 0 | popcount64(b[6] ^ a6) + popcount64(b[7] ^ a7); |
244 | 0 | } |
245 | | |
246 | 0 | inline static constexpr int get_code_size() { |
247 | 0 | return 64; |
248 | 0 | } |
249 | | }; |
250 | | |
// Fallback computer for arbitrary code sizes: handles the code as
// quotient8 64-bit words plus remainder8 trailing bytes.
struct HammingComputerDefault {
    // Query code; not copied — caller must keep the buffer alive.
    const uint8_t* a8;
    // code_size / 8: number of whole 64-bit words in the code.
    int quotient8;
    // code_size % 8: number of trailing bytes after the last whole word.
    int remainder8;

    HammingComputerDefault() {}

    HammingComputerDefault(const uint8_t* a8, int code_size) {
        set(a8, code_size);
    }

    // Record the query code pointer and split code_size into word/byte parts.
    void set(const uint8_t* a8_2, int code_size) {
        this->a8 = a8_2;
        quotient8 = code_size / 8;
        remainder8 = code_size % 8;
    }

    // Hamming distance between the stored code and b8.
    // The word part is processed with a Duff's-device loop unrolled by 8;
    // the 1..7 tail bytes go through the byte popcount lookup table.
    int hamming(const uint8_t* b8) const {
        int accu = 0;

        // NOTE(review): reinterpret as 64-bit words; assumes unaligned
        // 64-bit loads are acceptable on the target (true on x86).
        const uint64_t* a64 = reinterpret_cast<const uint64_t*>(a8);
        const uint64_t* b64 = reinterpret_cast<const uint64_t*>(b8);
        int i = 0, len = quotient8;
        // Duff's device: `switch (len & 7)` jumps into the middle of the
        // unrolled body to consume len % 8 words first, after which each
        // pass of the while loop consumes exactly 8 words.
        switch (len & 7) {
            default:
                while (len > 7) {
                    len -= 8;
                    accu += popcount64(a64[i] ^ b64[i]);
                    i++;
                    [[fallthrough]];
                    case 7:
                        accu += popcount64(a64[i] ^ b64[i]);
                        i++;
                        [[fallthrough]];
                    case 6:
                        accu += popcount64(a64[i] ^ b64[i]);
                        i++;
                        [[fallthrough]];
                    case 5:
                        accu += popcount64(a64[i] ^ b64[i]);
                        i++;
                        [[fallthrough]];
                    case 4:
                        accu += popcount64(a64[i] ^ b64[i]);
                        i++;
                        [[fallthrough]];
                    case 3:
                        accu += popcount64(a64[i] ^ b64[i]);
                        i++;
                        [[fallthrough]];
                    case 2:
                        accu += popcount64(a64[i] ^ b64[i]);
                        i++;
                        [[fallthrough]];
                    case 1:
                        accu += popcount64(a64[i] ^ b64[i]);
                        i++;
                }
        }
        if (remainder8) {
            // Tail bytes: fall through from the highest remaining byte
            // down to byte 0, one table lookup each.
            const uint8_t* a = a8 + 8 * quotient8;
            const uint8_t* b = b8 + 8 * quotient8;
            switch (remainder8) {
                case 7:
                    accu += hamdis_tab_ham_bytes[a[6] ^ b[6]];
                    [[fallthrough]];
                case 6:
                    accu += hamdis_tab_ham_bytes[a[5] ^ b[5]];
                    [[fallthrough]];
                case 5:
                    accu += hamdis_tab_ham_bytes[a[4] ^ b[4]];
                    [[fallthrough]];
                case 4:
                    accu += hamdis_tab_ham_bytes[a[3] ^ b[3]];
                    [[fallthrough]];
                case 3:
                    accu += hamdis_tab_ham_bytes[a[2] ^ b[2]];
                    [[fallthrough]];
                case 2:
                    accu += hamdis_tab_ham_bytes[a[1] ^ b[1]];
                    [[fallthrough]];
                case 1:
                    accu += hamdis_tab_ham_bytes[a[0] ^ b[0]];
                    [[fallthrough]];
                default:
                    break;
            }
        }

        return accu;
    }

    inline int get_code_size() const {
        return quotient8 * 8 + remainder8;
    }
};
347 | | |
348 | | /*************************************************************************** |
349 | | * generalized Hamming = number of bytes that are different between |
350 | | * two codes. |
351 | | ***************************************************************************/ |
352 | | |
353 | 0 | inline int generalized_hamming_64(uint64_t a) { |
354 | 0 | a |= a >> 1; |
355 | 0 | a |= a >> 2; |
356 | 0 | a |= a >> 4; |
357 | 0 | a &= 0x0101010101010101UL; |
358 | 0 | return popcount64(a); |
359 | 0 | } |
360 | | |
361 | | struct GenHammingComputer8 { |
362 | | uint64_t a0; |
363 | | |
364 | 0 | GenHammingComputer8(const uint8_t* a, int code_size) { |
365 | 0 | assert(code_size == 8); |
366 | 0 | a0 = *(uint64_t*)a; |
367 | 0 | } |
368 | | |
369 | 0 | inline int hamming(const uint8_t* b) const { |
370 | 0 | return generalized_hamming_64(*(uint64_t*)b ^ a0); |
371 | 0 | } |
372 | | |
373 | 0 | inline static constexpr int get_code_size() { |
374 | 0 | return 8; |
375 | 0 | } |
376 | | }; |
377 | | |
// I'm not sure whether this version is faster or slower, tbh
// todo: test on different CPUs
380 | | struct GenHammingComputer16 { |
381 | | __m128i a; |
382 | | |
383 | 0 | GenHammingComputer16(const uint8_t* a8, int code_size) { |
384 | 0 | assert(code_size == 16); |
385 | 0 | a = _mm_loadu_si128((const __m128i_u*)a8); |
386 | 0 | } |
387 | | |
388 | 0 | inline int hamming(const uint8_t* b8) const { |
389 | 0 | const __m128i b = _mm_loadu_si128((const __m128i_u*)b8); |
390 | 0 | const __m128i cmp = _mm_cmpeq_epi8(a, b); |
391 | 0 | const auto movemask = _mm_movemask_epi8(cmp); |
392 | 0 | return 16 - popcount32(movemask); |
393 | 0 | } |
394 | | |
395 | 0 | inline static constexpr int get_code_size() { |
396 | 0 | return 16; |
397 | 0 | } |
398 | | }; |
399 | | |
400 | | struct GenHammingComputer32 { |
401 | | __m256i a; |
402 | | |
403 | 0 | GenHammingComputer32(const uint8_t* a8, int code_size) { |
404 | 0 | assert(code_size == 32); |
405 | 0 | a = _mm256_loadu_si256((const __m256i_u*)a8); |
406 | 0 | } |
407 | | |
408 | 0 | inline int hamming(const uint8_t* b8) const { |
409 | 0 | const __m256i b = _mm256_loadu_si256((const __m256i_u*)b8); |
410 | 0 | const __m256i cmp = _mm256_cmpeq_epi8(a, b); |
411 | 0 | const uint32_t movemask = _mm256_movemask_epi8(cmp); |
412 | 0 | return 32 - popcount32(movemask); |
413 | 0 | } |
414 | | |
415 | 0 | inline static constexpr int get_code_size() { |
416 | 0 | return 32; |
417 | 0 | } |
418 | | }; |
419 | | |
420 | | // A specialized version might be needed for the very long |
421 | | // GenHamming code_size. In such a case, one may accumulate |
422 | | // counts using _mm256_sub_epi8 and then compute a horizontal |
423 | | // sum (using _mm256_sad_epu8, maybe, in blocks of no larger |
424 | | // than 256 * 32 bytes). |
425 | | |
426 | | struct GenHammingComputerM8 { |
427 | | const uint64_t* a; |
428 | | int n; |
429 | | |
430 | 0 | GenHammingComputerM8(const uint8_t* a8, int code_size) { |
431 | 0 | assert(code_size % 8 == 0); |
432 | 0 | a = (uint64_t*)a8; |
433 | 0 | n = code_size / 8; |
434 | 0 | } |
435 | | |
436 | 0 | int hamming(const uint8_t* b8) const { |
437 | 0 | const uint64_t* b = (uint64_t*)b8; |
438 | 0 | int accu = 0; |
439 | |
|
440 | 0 | int i = 0; |
441 | 0 | int n4 = (n / 4) * 4; |
442 | 0 | for (; i < n4; i += 4) { |
443 | 0 | const __m256i av = _mm256_loadu_si256((const __m256i_u*)(a + i)); |
444 | 0 | const __m256i bv = _mm256_loadu_si256((const __m256i_u*)(b + i)); |
445 | 0 | const __m256i cmp = _mm256_cmpeq_epi8(av, bv); |
446 | 0 | const uint32_t movemask = _mm256_movemask_epi8(cmp); |
447 | 0 | accu += 32 - popcount32(movemask); |
448 | 0 | } |
449 | |
|
450 | 0 | for (; i < n; i++) |
451 | 0 | accu += generalized_hamming_64(a[i] ^ b[i]); |
452 | 0 | return accu; |
453 | 0 | } |
454 | | |
455 | 0 | inline int get_code_size() const { |
456 | 0 | return n * 8; |
457 | 0 | } |
458 | | }; |
459 | | |
460 | | } // namespace faiss |
461 | | |
462 | | #endif |