contrib/faiss/faiss/impl/residual_quantizer_encode_steps.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/residual_quantizer_encode_steps.h> |
9 | | |
10 | | #include <faiss/impl/AuxIndexStructures.h> |
11 | | #include <faiss/impl/FaissAssert.h> |
12 | | #include <faiss/impl/ResidualQuantizer.h> |
13 | | #include <faiss/utils/Heap.h> |
14 | | #include <faiss/utils/distances.h> |
15 | | #include <faiss/utils/simdlib.h> |
16 | | #include <faiss/utils/utils.h> |
17 | | |
18 | | #include <faiss/utils/approx_topk/approx_topk.h> |
19 | | |
20 | | extern "C" { |
21 | | |
22 | | // general matrix multiplication |
23 | | int sgemm_( |
24 | | const char* transa, |
25 | | const char* transb, |
26 | | FINTEGER* m, |
27 | | FINTEGER* n, |
28 | | FINTEGER* k, |
29 | | const float* alpha, |
30 | | const float* a, |
31 | | FINTEGER* lda, |
32 | | const float* b, |
33 | | FINTEGER* ldb, |
34 | | float* beta, |
35 | | float* c, |
36 | | FINTEGER* ldc); |
37 | | } |
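For readers unfamiliar with the Fortran BLAS calling convention, here is a minimal sketch (not part of the FAISS sources; the helper name and sizes are illustrative) of how sgemm_ is used later in this file, in compute_codes_add_centroids_mp_lut1, to compute query-to-codebook dot products. Because BLAS treats matrices as column-major and takes every argument by pointer, a row-major (n, d) matrix is simply reinterpreted as a column-major (d, n) one, so no data has to be transposed.

// Sketch: query_cp[i * ti + j] = <x_i, codebooks_j>, mirroring the call below.
void example_query_codebook_dot_products(
        size_t n,                   // number of queries
        size_t d,                   // vector dimension
        size_t total_codebook_size, // number of codebook entries, all codebooks
        const float* x,             // row-major (n, d)
        const float* codebooks,     // row-major (total_codebook_size, d)
        float* query_cp) {          // row-major (n, total_codebook_size)
    FINTEGER ti = total_codebook_size, di = d, ni = n;
    float one = 1.0f, zero = 0.0f;
    sgemm_("Transposed",      // op(A) = codebooks^T -> (ti, di)
           "Not transposed",  // op(B) = x as seen by BLAS -> (di, ni)
           &ti, &ni, &di,
           &one,
           codebooks, &di,
           x, &di,
           &zero,
           query_cp, &ti);    // C is (ti, ni) column-major, i.e. (ni, ti) row-major
}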
38 | | |
39 | | namespace faiss { |
40 | | |
41 | | /******************************************************************** |
42 | | * Basic routines |
43 | | ********************************************************************/ |
44 | | |
45 | | namespace { |
46 | | |
47 | | template <size_t M, size_t NK> |
48 | | void accum_and_store_tab( |
49 | | const size_t m_offset, |
50 | | const float* const __restrict codebook_cross_norms, |
51 | | const uint64_t* const __restrict codebook_offsets, |
52 | | const int32_t* const __restrict codes_i, |
53 | | const size_t b, |
54 | | const size_t ldc, |
55 | | const size_t K, |
56 | 0 | float* const __restrict output) { |
57 | | // load pointers into registers |
58 | 0 | const float* cbs[M]; |
59 | 0 | for (size_t ij = 0; ij < M; ij++) { |
60 | 0 | const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]); |
61 | 0 | cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc]; |
62 | 0 | } |
63 | | |
64 | | // do the accumulation in registers using SIMD. |
65 | | // The compiler may be smart enough to make this manual |
66 | | // SIMD unrolling unnecessary. |
67 | 0 | #if defined(__AVX2__) || defined(__aarch64__) |
68 | 0 | const size_t K8 = (K / (8 * NK)) * (8 * NK); |
69 | | |
70 | | // process in chunks of size (8 * NK) floats |
71 | 0 | for (size_t kk = 0; kk < K8; kk += 8 * NK) { |
72 | 0 | simd8float32 regs[NK]; |
73 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
74 | 0 | regs[ik].loadu(cbs[0] + kk + ik * 8); |
75 | 0 | } |
76 | |
77 | 0 | for (size_t ij = 1; ij < M; ij++) { |
78 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
79 | 0 | regs[ik] += simd8float32(cbs[ij] + kk + ik * 8); |
80 | 0 | } |
81 | 0 | } |
82 | | |
83 | | // write the result |
84 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
85 | 0 | regs[ik].storeu(output + kk + ik * 8); |
86 | 0 | } |
87 | 0 | } |
88 | | #else |
89 | | const size_t K8 = 0; |
90 | | #endif |
91 | | |
92 | | // process leftovers |
93 | 0 | for (size_t kk = K8; kk < K; kk++) { |
94 | 0 | float reg = cbs[0][kk]; |
95 | 0 | for (size_t ij = 1; ij < M; ij++) { |
96 | 0 | reg += cbs[ij][kk]; |
97 | 0 | } |
98 | 0 | output[kk] = reg; |
99 | 0 | } |
100 | 0 | } |
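As a reading aid, the SIMD kernel above computes the same thing as the following scalar reference (a sketch added for clarity, not part of the original file): for every column k, it sums the k-th entry of the M rows of codebook_cross_norms selected by the codes of beam entry b.

// Scalar equivalent of accum_and_store_tab<M, NK> (illustrative only).
void accum_and_store_tab_scalar(
        size_t M,
        size_t m_offset,
        const float* codebook_cross_norms,
        const uint64_t* codebook_offsets,
        const int32_t* codes_i,
        size_t b,
        size_t ldc,
        size_t K,
        float* output) {
    for (size_t k = 0; k < K; k++) {
        float acc = 0;
        for (size_t ij = 0; ij < M; ij++) {
            const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
            acc += codebook_cross_norms[(codebook_offsets[ij] + code) * ldc + k];
        }
        output[k] = acc;
    }
}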
101 | | |
102 | | template <size_t M, size_t NK> |
103 | | void accum_and_add_tab( |
104 | | const size_t m_offset, |
105 | | const float* const __restrict codebook_cross_norms, |
106 | | const uint64_t* const __restrict codebook_offsets, |
107 | | const int32_t* const __restrict codes_i, |
108 | | const size_t b, |
109 | | const size_t ldc, |
110 | | const size_t K, |
111 | 0 | float* const __restrict output) { |
112 | | // load pointers into registers |
113 | 0 | const float* cbs[M]; |
114 | 0 | for (size_t ij = 0; ij < M; ij++) { |
115 | 0 | const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]); |
116 | 0 | cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc]; |
117 | 0 | } |
118 | | |
119 | | // do the accumulation in registers using SIMD. |
120 | | // The compiler may be smart enough to make this manual |
121 | | // SIMD unrolling unnecessary. |
122 | 0 | #if defined(__AVX2__) || defined(__aarch64__) |
123 | 0 | const size_t K8 = (K / (8 * NK)) * (8 * NK); |
124 | | |
125 | | // process in chunks of size (8 * NK) floats |
126 | 0 | for (size_t kk = 0; kk < K8; kk += 8 * NK) { |
127 | 0 | simd8float32 regs[NK]; |
128 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
129 | 0 | regs[ik].loadu(cbs[0] + kk + ik * 8); |
130 | 0 | } |
131 | |
132 | 0 | for (size_t ij = 1; ij < M; ij++) { |
133 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
134 | 0 | regs[ik] += simd8float32(cbs[ij] + kk + ik * 8); |
135 | 0 | } |
136 | 0 | } |
137 | | |
138 | | // write the result |
139 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
140 | 0 | simd8float32 existing(output + kk + ik * 8); |
141 | 0 | existing += regs[ik]; |
142 | 0 | existing.storeu(output + kk + ik * 8); |
143 | 0 | } |
144 | 0 | } |
145 | | #else |
146 | | const size_t K8 = 0; |
147 | | #endif |
148 | | |
149 | | // process leftovers |
150 | 0 | for (size_t kk = K8; kk < K; kk++) { |
151 | 0 | float reg = cbs[0][kk]; |
152 | 0 | for (size_t ij = 1; ij < M; ij++) { |
153 | 0 | reg += cbs[ij][kk]; |
154 | 0 | } |
155 | 0 | output[kk] += reg; |
156 | 0 | } |
157 | 0 | } |
Unexecuted instantiations: accum_and_add_tab<M, 4> for M = 1..8 (anonymous namespace), residual_quantizer_encode_steps.cpp
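accum_and_add_tab above is identical to accum_and_store_tab except for the final write: the accumulated sum is added to the existing contents of the destination instead of overwriting it. In scalar terms (sketch):

// accum_and_store_tab:  output[k]  = cbs[0][k] + ... + cbs[M-1][k];
// accum_and_add_tab:    output[k] += cbs[0][k] + ... + cbs[M-1][k];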
158 | | |
159 | | template <size_t M, size_t NK> |
160 | | void accum_and_finalize_tab( |
161 | | const float* const __restrict codebook_cross_norms, |
162 | | const uint64_t* const __restrict codebook_offsets, |
163 | | const int32_t* const __restrict codes_i, |
164 | | const size_t b, |
165 | | const size_t ldc, |
166 | | const size_t K, |
167 | | const float* const __restrict distances_i, |
168 | | const float* const __restrict cd_common, |
169 | 0 | float* const __restrict output) { |
170 | | // load pointers into registers |
171 | 0 | const float* cbs[M]; |
172 | 0 | for (size_t ij = 0; ij < M; ij++) { |
173 | 0 | const size_t code = static_cast<size_t>(codes_i[b * M + ij]); |
174 | 0 | cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc]; |
175 | 0 | } |
176 | | |
177 | | // do the accumulation in registers using SIMD. |
178 | | // The compiler may be smart enough to make this manual |
179 | | // SIMD unrolling unnecessary. |
180 | 0 | #if defined(__AVX2__) || defined(__aarch64__) |
181 | 0 | const size_t K8 = (K / (8 * NK)) * (8 * NK); |
182 | | |
183 | | // process in chunks of size (8 * NK) floats |
184 | 0 | for (size_t kk = 0; kk < K8; kk += 8 * NK) { |
185 | 0 | simd8float32 regs[NK]; |
186 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
187 | 0 | regs[ik].loadu(cbs[0] + kk + ik * 8); |
188 | 0 | } |
189 | |
190 | 0 | for (size_t ij = 1; ij < M; ij++) { |
191 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
192 | 0 | regs[ik] += simd8float32(cbs[ij] + kk + ik * 8); |
193 | 0 | } |
194 | 0 | } |
195 | |
196 | 0 | simd8float32 two(2.0f); |
197 | 0 | for (size_t ik = 0; ik < NK; ik++) { |
198 | | // cent_distances[b * K + k] = distances_i[b] + cd_common[k] |
199 | | // + 2 * dp[k]; |
200 | |
201 | 0 | simd8float32 common_v(cd_common + kk + ik * 8); |
202 | 0 | common_v = fmadd(two, regs[ik], common_v); |
203 | |
204 | 0 | common_v += simd8float32(distances_i[b]); |
205 | 0 | common_v.storeu(output + b * K + kk + ik * 8); |
206 | 0 | } |
207 | 0 | } |
208 | | #else |
209 | | const size_t K8 = 0; |
210 | | #endif |
211 | | |
212 | | // process leftovers |
213 | 0 | for (size_t kk = K8; kk < K; kk++) { |
214 | 0 | float reg = cbs[0][kk]; |
215 | 0 | for (size_t ij = 1; ij < M; ij++) { |
216 | 0 | reg += cbs[ij][kk]; |
217 | 0 | } |
218 | |
219 | 0 | output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg; |
220 | 0 | } |
221 | 0 | } |
Unexecuted instantiations: accum_and_finalize_tab<M, 4> for M = 1..7 (anonymous namespace), residual_quantizer_encode_steps.cpp
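The finalization formula above follows from expanding the squared distance between the current residual and a candidate centroid; a short derivation as a reading aid, in the notation of beam_search_encode_step_tab below. For beam entry b, let r_b = x - sum_{j<m} c_j be the residual after the already-selected codewords c_j, and let c_k be a candidate entry of the current codebook. Then

    || r_b - c_k ||^2
        = || r_b ||^2 + || c_k ||^2 - 2 * <r_b, c_k>
        = || r_b ||^2 + (|| c_k ||^2 - 2 * <x, c_k>) + 2 * sum_{j<m} <c_j, c_k>
        = distances_i[b] + cd_common[k] + 2 * dp[k],

since distances_i[b] = || r_b ||^2, cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k], and dp[k] = sum_{j<m} <c_j, c_k> is exactly the sum of codebook_cross_norms rows accumulated by the functions above.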
222 | | |
223 | | } // anonymous namespace |
224 | | |
225 | | /******************************************************************** |
226 | | * Single encoding step |
227 | | ********************************************************************/ |
228 | | |
229 | | void beam_search_encode_step( |
230 | | size_t d, |
231 | | size_t K, |
232 | | const float* cent, /// size (K, d) |
233 | | size_t n, |
234 | | size_t beam_size, |
235 | | const float* residuals, /// size (n, beam_size, d) |
236 | | size_t m, |
237 | | const int32_t* codes, /// size (n, beam_size, m) |
238 | | size_t new_beam_size, |
239 | | int32_t* new_codes, /// size (n, new_beam_size, m + 1) |
240 | | float* new_residuals, /// size (n, new_beam_size, d) |
241 | | float* new_distances, /// size (n, new_beam_size) |
242 | | Index* assign_index, |
243 | 0 | ApproxTopK_mode_t approx_topk_mode) { |
244 | | // we have to fill in the whole output matrix |
245 | 0 | FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K); |
246 | | |
247 | 0 | std::vector<float> cent_distances; |
248 | 0 | std::vector<idx_t> cent_ids; |
249 | |
250 | 0 | if (assign_index) { |
251 | | // search beam_size distances per query |
252 | 0 | FAISS_THROW_IF_NOT(assign_index->d == d); |
253 | 0 | cent_distances.resize(n * beam_size * new_beam_size); |
254 | 0 | cent_ids.resize(n * beam_size * new_beam_size); |
255 | 0 | if (assign_index->ntotal != 0) { |
256 | | // then we assume the codebooks are already added to the index |
257 | 0 | FAISS_THROW_IF_NOT(assign_index->ntotal == K); |
258 | 0 | } else { |
259 | 0 | assign_index->add(K, cent); |
260 | 0 | } |
261 | | |
262 | | // printf("beam_search_encode_step -- mem usage %zd\n", |
263 | | // get_mem_usage_kb()); |
264 | 0 | assign_index->search( |
265 | 0 | n * beam_size, |
266 | 0 | residuals, |
267 | 0 | new_beam_size, |
268 | 0 | cent_distances.data(), |
269 | 0 | cent_ids.data()); |
270 | 0 | } else { |
271 | | // do one big distance computation |
272 | 0 | cent_distances.resize(n * beam_size * K); |
273 | 0 | pairwise_L2sqr( |
274 | 0 | d, n * beam_size, residuals, K, cent, cent_distances.data()); |
275 | 0 | } |
276 | 0 | InterruptCallback::check(); |
277 | |
278 | 0 | #pragma omp parallel for if (n > 100) |
279 | 0 | for (int64_t i = 0; i < n; i++) { |
280 | 0 | const int32_t* codes_i = codes + i * m * beam_size; |
281 | 0 | int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size; |
282 | 0 | const float* residuals_i = residuals + i * d * beam_size; |
283 | 0 | float* new_residuals_i = new_residuals + i * d * new_beam_size; |
284 | |
285 | 0 | float* new_distances_i = new_distances + i * new_beam_size; |
286 | 0 | using C = CMax<float, int>; |
287 | |
288 | 0 | if (assign_index) { |
289 | 0 | const float* cent_distances_i = |
290 | 0 | cent_distances.data() + i * beam_size * new_beam_size; |
291 | 0 | const idx_t* cent_ids_i = |
292 | 0 | cent_ids.data() + i * beam_size * new_beam_size; |
293 | | |
294 | | // here we could be a tad more efficient by merging sorted arrays |
295 | 0 | for (int i_2 = 0; i_2 < new_beam_size; i_2++) { |
296 | 0 | new_distances_i[i_2] = C::neutral(); |
297 | 0 | } |
298 | 0 | std::vector<int> perm(new_beam_size, -1); |
299 | 0 | heap_addn<C>( |
300 | 0 | new_beam_size, |
301 | 0 | new_distances_i, |
302 | 0 | perm.data(), |
303 | 0 | cent_distances_i, |
304 | 0 | nullptr, |
305 | 0 | beam_size * new_beam_size); |
306 | 0 | heap_reorder<C>(new_beam_size, new_distances_i, perm.data()); |
307 | |
308 | 0 | for (int j = 0; j < new_beam_size; j++) { |
309 | 0 | int js = perm[j] / new_beam_size; |
310 | 0 | int ls = cent_ids_i[perm[j]]; |
311 | 0 | if (m > 0) { |
312 | 0 | memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m); |
313 | 0 | } |
314 | 0 | new_codes_i[m] = ls; |
315 | 0 | new_codes_i += m + 1; |
316 | 0 | fvec_sub( |
317 | 0 | d, |
318 | 0 | residuals_i + js * d, |
319 | 0 | cent + ls * d, |
320 | 0 | new_residuals_i); |
321 | 0 | new_residuals_i += d; |
322 | 0 | } |
323 | |
324 | 0 | } else { |
325 | 0 | const float* cent_distances_i = |
326 | 0 | cent_distances.data() + i * beam_size * K; |
327 | | // then we have to select the best results |
328 | 0 | for (int i_2 = 0; i_2 < new_beam_size; i_2++) { |
329 | 0 | new_distances_i[i_2] = C::neutral(); |
330 | 0 | } |
331 | 0 | std::vector<int> perm(new_beam_size, -1); |
332 | |
333 | 0 | #define HANDLE_APPROX(NB, BD) \ |
334 | 0 | case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \ |
335 | 0 | HeapWithBuckets<C, NB, BD>::bs_addn( \ |
336 | 0 | beam_size, \ |
337 | 0 | K, \ |
338 | 0 | cent_distances_i, \ |
339 | 0 | new_beam_size, \ |
340 | 0 | new_distances_i, \ |
341 | 0 | perm.data()); \ |
342 | 0 | break; |
343 | |
344 | 0 | switch (approx_topk_mode) { |
345 | 0 | HANDLE_APPROX(8, 3) |
346 | 0 | HANDLE_APPROX(8, 2) |
347 | 0 | HANDLE_APPROX(16, 2) |
348 | 0 | HANDLE_APPROX(32, 2) |
349 | 0 | default: |
350 | 0 | heap_addn<C>( |
351 | 0 | new_beam_size, |
352 | 0 | new_distances_i, |
353 | 0 | perm.data(), |
354 | 0 | cent_distances_i, |
355 | 0 | nullptr, |
356 | 0 | beam_size * K); |
357 | 0 | } |
358 | 0 | heap_reorder<C>(new_beam_size, new_distances_i, perm.data()); |
359 | |
360 | 0 | #undef HANDLE_APPROX |
361 | |
362 | 0 | for (int j = 0; j < new_beam_size; j++) { |
363 | 0 | int js = perm[j] / K; |
364 | 0 | int ls = perm[j] % K; |
365 | 0 | if (m > 0) { |
366 | 0 | memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m); |
367 | 0 | } |
368 | 0 | new_codes_i[m] = ls; |
369 | 0 | new_codes_i += m + 1; |
370 | 0 | fvec_sub( |
371 | 0 | d, |
372 | 0 | residuals_i + js * d, |
373 | 0 | cent + ls * d, |
374 | 0 | new_residuals_i); |
375 | 0 | new_residuals_i += d; |
376 | 0 | } |
377 | 0 | } |
378 | 0 | } |
379 | 0 | } |
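A minimal usage sketch of a single encoding step (not from the FAISS sources; sizes are hypothetical, and the exact-top-k enumerator name is assumed — any value outside the APPROX_TOPK_BUCKETS_* cases falls back to the exact heap_addn path):

// Sketch: expand a beam of 2 hypotheses to 4 against a 16-entry codebook.
// Needs <vector> and faiss/impl/residual_quantizer_encode_steps.h.
void example_single_step() {
    size_t d = 8, K = 16, n = 1, beam_size = 2, m = 1, new_beam_size = 4;
    std::vector<float> cent(K * d);                  // (K, d) codebook
    std::vector<float> residuals(n * beam_size * d); // current residuals
    std::vector<int32_t> codes(n * beam_size * m);   // codes selected so far
    std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
    std::vector<float> new_residuals(n * new_beam_size * d);
    std::vector<float> new_distances(n * new_beam_size);

    faiss::beam_search_encode_step(
            d, K, cent.data(),
            n, beam_size, residuals.data(),
            m, codes.data(),
            new_beam_size,
            new_codes.data(), new_residuals.data(), new_distances.data(),
            /*assign_index=*/nullptr, // brute-force pairwise_L2sqr path
            faiss::EXACT_TOPK);       // enumerator name assumed
}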
380 | | |
381 | | // exposed in the faiss namespace |
382 | | void beam_search_encode_step_tab( |
383 | | size_t K, |
384 | | size_t n, |
385 | | size_t beam_size, // input sizes |
386 | | const float* codebook_cross_norms, // size K * ldc |
387 | | size_t ldc, |
388 | | const uint64_t* codebook_offsets, // m |
389 | | const float* query_cp, // size n * ldqc |
390 | | size_t ldqc, // >= K |
391 | | const float* cent_norms_i, // size K |
392 | | size_t m, |
393 | | const int32_t* codes, // n * beam_size * m |
394 | | const float* distances, // n * beam_size |
395 | | size_t new_beam_size, |
396 | | int32_t* new_codes, // n * new_beam_size * (m + 1) |
397 | | float* new_distances, // n * new_beam_size |
398 | | ApproxTopK_mode_t approx_topk_mode) // |
399 | 0 | { |
400 | 0 | FAISS_THROW_IF_NOT(ldc >= K); |
401 | | |
402 | 0 | #pragma omp parallel for if (n > 100) schedule(dynamic) |
403 | 0 | for (int64_t i = 0; i < n; i++) { |
404 | 0 | std::vector<float> cent_distances(beam_size * K); |
405 | 0 | std::vector<float> cd_common(K); |
406 | |
407 | 0 | const int32_t* codes_i = codes + i * m * beam_size; |
408 | 0 | const float* query_cp_i = query_cp + i * ldqc; |
409 | 0 | const float* distances_i = distances + i * beam_size; |
410 | |
411 | 0 | for (size_t k = 0; k < K; k++) { |
412 | 0 | cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k]; |
413 | 0 | } |
414 | |
415 | 0 | bool use_baseline_implementation = false; |
416 | | |
417 | | // This is the baseline implementation. Its primary flaw is |
418 | | // that it writes far too much data to the temporary buffer |
419 | | // called dp. |
420 | | // |
421 | | // This baseline code is kept intentionally so that it is easy to |
422 | | // see exactly what the optimized version optimizes. |
423 | | // |
424 | 0 | if (use_baseline_implementation) { |
425 | 0 | for (size_t b = 0; b < beam_size; b++) { |
426 | 0 | std::vector<float> dp(K); |
427 | |
428 | 0 | for (size_t m1 = 0; m1 < m; m1++) { |
429 | 0 | size_t c = codes_i[b * m + m1]; |
430 | 0 | const float* cb = |
431 | 0 | &codebook_cross_norms |
432 | 0 | [(codebook_offsets[m1] + c) * ldc]; |
433 | 0 | fvec_add(K, cb, dp.data(), dp.data()); |
434 | 0 | } |
435 | |
436 | 0 | for (size_t k = 0; k < K; k++) { |
437 | 0 | cent_distances[b * K + k] = |
438 | 0 | distances_i[b] + cd_common[k] + 2 * dp[k]; |
439 | 0 | } |
440 | 0 | } |
441 | |
442 | 0 | } else { |
443 | | // An optimized implementation that avoids using a temporary buffer |
444 | | // and does the accumulation in registers. |
445 | | |
446 | | // Compute a sum of NK AQ codes. |
447 | 0 | #define ACCUM_AND_FINALIZE_TAB(NK) \ |
448 | 0 | case NK: \ |
449 | 0 | for (size_t b = 0; b < beam_size; b++) { \ |
450 | 0 | accum_and_finalize_tab<NK, 4>( \ |
451 | 0 | codebook_cross_norms, \ |
452 | 0 | codebook_offsets, \ |
453 | 0 | codes_i, \ |
454 | 0 | b, \ |
455 | 0 | ldc, \ |
456 | 0 | K, \ |
457 | 0 | distances_i, \ |
458 | 0 | cd_common.data(), \ |
459 | 0 | cent_distances.data()); \ |
460 | 0 | } \ |
461 | 0 | break; |
462 | | |
463 | | // this version contains many switch-case scenarios, but |
464 | | // they won't affect the branch predictor. |
465 | 0 | switch (m) { |
466 | 0 | case 0: |
467 | | // trivial case |
468 | 0 | for (size_t b = 0; b < beam_size; b++) { |
469 | 0 | for (size_t k = 0; k < K; k++) { |
470 | 0 | cent_distances[b * K + k] = |
471 | 0 | distances_i[b] + cd_common[k]; |
472 | 0 | } |
473 | 0 | } |
474 | 0 | break; |
475 | | |
476 | 0 | ACCUM_AND_FINALIZE_TAB(1) |
477 | 0 | ACCUM_AND_FINALIZE_TAB(2) |
478 | 0 | ACCUM_AND_FINALIZE_TAB(3) |
479 | 0 | ACCUM_AND_FINALIZE_TAB(4) |
480 | 0 | ACCUM_AND_FINALIZE_TAB(5) |
481 | 0 | ACCUM_AND_FINALIZE_TAB(6) |
482 | 0 | ACCUM_AND_FINALIZE_TAB(7) |
483 | | |
484 | 0 | default: { |
485 | | // m >= 8 case. |
486 | | |
487 | | // A temporary buffer has to be used due to the lack of |
488 | | // registers. But we'll try to accumulate up to 8 AQ codes |
489 | | // in registers and issue a single write operation to the |
490 | | // buffer, while the baseline does no accumulation. So, the |
491 | | // number of write operations to the temporary buffer is |
492 | | // reduced 8x. |
493 | | |
494 | | // allocate a temporary buffer |
495 | 0 | std::vector<float> dp(K); |
496 | |
497 | 0 | for (size_t b = 0; b < beam_size; b++) { |
498 | | // Initialize it. Compute a sum of the first 8 AQ codes |
499 | | // because m >= 8. |
500 | 0 | accum_and_store_tab<8, 4>( |
501 | 0 | m, |
502 | 0 | codebook_cross_norms, |
503 | 0 | codebook_offsets, |
504 | 0 | codes_i, |
505 | 0 | b, |
506 | 0 | ldc, |
507 | 0 | K, |
508 | 0 | dp.data()); |
509 | |
510 | 0 | #define ACCUM_AND_ADD_TAB(NK) \ |
511 | 0 | case NK: \ |
512 | 0 | accum_and_add_tab<NK, 4>( \ |
513 | 0 | m, \ |
514 | 0 | codebook_cross_norms, \ |
515 | 0 | codebook_offsets + im, \ |
516 | 0 | codes_i + im, \ |
517 | 0 | b, \ |
518 | 0 | ldc, \ |
519 | 0 | K, \ |
520 | 0 | dp.data()); \ |
521 | 0 | break; |
522 | | |
523 | | // accumulate up to 8 additional AQ codes into |
524 | | // a temporary buffer |
525 | 0 | for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) { |
526 | 0 | size_t m_left = m - im; |
527 | 0 | if (m_left > 8) { |
528 | 0 | m_left = 8; |
529 | 0 | } |
530 | |
531 | 0 | switch (m_left) { |
532 | 0 | ACCUM_AND_ADD_TAB(1) |
533 | 0 | ACCUM_AND_ADD_TAB(2) |
534 | 0 | ACCUM_AND_ADD_TAB(3) |
535 | 0 | ACCUM_AND_ADD_TAB(4) |
536 | 0 | ACCUM_AND_ADD_TAB(5) |
537 | 0 | ACCUM_AND_ADD_TAB(6) |
538 | 0 | ACCUM_AND_ADD_TAB(7) |
539 | 0 | ACCUM_AND_ADD_TAB(8) |
540 | 0 | } |
541 | 0 | } |
542 | | |
543 | | // done. finalize the result |
544 | 0 | for (size_t k = 0; k < K; k++) { |
545 | 0 | cent_distances[b * K + k] = |
546 | 0 | distances_i[b] + cd_common[k] + 2 * dp[k]; |
547 | 0 | } |
548 | 0 | } |
549 | 0 | } |
550 | 0 | } |
551 | | |
552 | | // the optimized implementation ends here |
553 | 0 | } |
554 | 0 | using C = CMax<float, int>; |
555 | 0 | int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size; |
556 | 0 | float* new_distances_i = new_distances + i * new_beam_size; |
557 | |
558 | 0 | const float* cent_distances_i = cent_distances.data(); |
559 | | |
560 | | // then we have to select the best results |
561 | 0 | for (int i_2 = 0; i_2 < new_beam_size; i_2++) { |
562 | 0 | new_distances_i[i_2] = C::neutral(); |
563 | 0 | } |
564 | 0 | std::vector<int> perm(new_beam_size, -1); |
565 | |
566 | 0 | #define HANDLE_APPROX(NB, BD) \ |
567 | 0 | case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \ |
568 | 0 | HeapWithBuckets<C, NB, BD>::bs_addn( \ |
569 | 0 | beam_size, \ |
570 | 0 | K, \ |
571 | 0 | cent_distances_i, \ |
572 | 0 | new_beam_size, \ |
573 | 0 | new_distances_i, \ |
574 | 0 | perm.data()); \ |
575 | 0 | break; |
576 | |
577 | 0 | switch (approx_topk_mode) { |
578 | 0 | HANDLE_APPROX(8, 3) |
579 | 0 | HANDLE_APPROX(8, 2) |
580 | 0 | HANDLE_APPROX(16, 2) |
581 | 0 | HANDLE_APPROX(32, 2) |
582 | 0 | default: |
583 | 0 | heap_addn<C>( |
584 | 0 | new_beam_size, |
585 | 0 | new_distances_i, |
586 | 0 | perm.data(), |
587 | 0 | cent_distances_i, |
588 | 0 | nullptr, |
589 | 0 | beam_size * K); |
590 | 0 | break; |
591 | 0 | } |
592 | | |
593 | 0 | heap_reorder<C>(new_beam_size, new_distances_i, perm.data()); |
594 | |
595 | 0 | #undef HANDLE_APPROX |
596 | |
597 | 0 | for (int j = 0; j < new_beam_size; j++) { |
598 | 0 | int js = perm[j] / K; |
599 | 0 | int ls = perm[j] % K; |
600 | 0 | if (m > 0) { |
601 | 0 | memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m); |
602 | 0 | } |
603 | 0 | new_codes_i[m] = ls; |
604 | 0 | new_codes_i += m + 1; |
605 | 0 | } |
606 | 0 | } |
607 | 0 | } |
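beam_search_encode_step_tab assumes that all dot products have already been precomputed. A naive sketch of the three tables it consumes (illustrative only, single query and ldc = ldqc = K; in the library they come from ResidualQuantizer::compute_codebook_tables() and the sgemm_ call in compute_codes_add_centroids_mp_lut1 below):

// Sketch only, not the library implementation.
//   cross_norms[j * K + k] = <prev_j, cur_k>  (rows selected via codebook_offsets)
//   query_cp[k]            = <x, cur_k>
//   cent_norms[k]          = || cur_k ||^2
void naive_step_tables(
        size_t d,
        size_t n_prev,
        const float* prev,    // (n_prev, d) entries of all previous codebooks
        size_t K,
        const float* cur,     // (K, d) entries of the current codebook
        const float* x,       // (d) one query
        float* cross_norms,   // (n_prev, K)
        float* query_cp,      // (K)
        float* cent_norms) {  // (K)
    for (size_t k = 0; k < K; k++) {
        float qdp = 0, nrm = 0;
        for (size_t t = 0; t < d; t++) {
            qdp += x[t] * cur[k * d + t];
            nrm += cur[k * d + t] * cur[k * d + t];
        }
        query_cp[k] = qdp;
        cent_norms[k] = nrm;
        for (size_t j = 0; j < n_prev; j++) {
            float dp = 0;
            for (size_t t = 0; t < d; t++) {
                dp += prev[j * d + t] * cur[k * d + t];
            }
            cross_norms[j * K + k] = dp;
        }
    }
}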
608 | | |
609 | | /******************************************************************** |
610 | | * Multiple encoding steps |
611 | | ********************************************************************/ |
612 | | |
613 | | namespace rq_encode_steps { |
614 | | |
615 | | void refine_beam_mp( |
616 | | const ResidualQuantizer& rq, |
617 | | size_t n, |
618 | | size_t beam_size, |
619 | | const float* x, |
620 | | int out_beam_size, |
621 | | int32_t* out_codes, |
622 | | float* out_residuals, |
623 | | float* out_distances, |
624 | 0 | RefineBeamMemoryPool& pool) { |
625 | 0 | int cur_beam_size = beam_size; |
626 | |
627 | 0 | double t0 = getmillisecs(); |
628 | | |
629 | | // find the max_beam_size |
630 | 0 | int max_beam_size = 0; |
631 | 0 | { |
632 | 0 | int tmp_beam_size = cur_beam_size; |
633 | 0 | for (int m = 0; m < rq.M; m++) { |
634 | 0 | int K = 1 << rq.nbits[m]; |
635 | 0 | int new_beam_size = std::min(tmp_beam_size * K, out_beam_size); |
636 | 0 | tmp_beam_size = new_beam_size; |
637 | |
638 | 0 | if (max_beam_size < new_beam_size) { |
639 | 0 | max_beam_size = new_beam_size; |
640 | 0 | } |
641 | 0 | } |
642 | 0 | } |
643 | | |
644 | | // preallocate buffers |
645 | 0 | pool.new_codes.resize(n * max_beam_size * (rq.M + 1)); |
646 | 0 | pool.new_residuals.resize(n * max_beam_size * rq.d); |
647 | |
648 | 0 | pool.codes.resize(n * max_beam_size * (rq.M + 1)); |
649 | 0 | pool.distances.resize(n * max_beam_size); |
650 | 0 | pool.residuals.resize(n * rq.d * max_beam_size); |
651 | |
652 | 0 | for (size_t i = 0; i < n * rq.d * beam_size; i++) { |
653 | 0 | pool.residuals[i] = x[i]; |
654 | 0 | } |
655 | | |
656 | | // set up pointers to buffers |
657 | 0 | int32_t* __restrict codes_ptr = pool.codes.data(); |
658 | 0 | float* __restrict residuals_ptr = pool.residuals.data(); |
659 | |
660 | 0 | int32_t* __restrict new_codes_ptr = pool.new_codes.data(); |
661 | 0 | float* __restrict new_residuals_ptr = pool.new_residuals.data(); |
662 | | |
663 | | // index |
664 | 0 | std::unique_ptr<Index> assign_index; |
665 | 0 | if (rq.assign_index_factory) { |
666 | 0 | assign_index.reset((*rq.assign_index_factory)(rq.d)); |
667 | 0 | } |
668 | | |
669 | | // main loop |
670 | 0 | size_t codes_size = 0; |
671 | 0 | size_t distances_size = 0; |
672 | 0 | size_t residuals_size = 0; |
673 | |
674 | 0 | for (int m = 0; m < rq.M; m++) { |
675 | 0 | int K = 1 << rq.nbits[m]; |
676 | |
677 | 0 | const float* __restrict codebooks_m = |
678 | 0 | rq.codebooks.data() + rq.codebook_offsets[m] * rq.d; |
679 | |
|
680 | 0 | const int new_beam_size = std::min(cur_beam_size * K, out_beam_size); |
681 | |
682 | 0 | codes_size = n * new_beam_size * (m + 1); |
683 | 0 | residuals_size = n * new_beam_size * rq.d; |
684 | 0 | distances_size = n * new_beam_size; |
685 | |
686 | 0 | beam_search_encode_step( |
687 | 0 | rq.d, |
688 | 0 | K, |
689 | 0 | codebooks_m, |
690 | 0 | n, |
691 | 0 | cur_beam_size, |
692 | 0 | residuals_ptr, |
693 | 0 | m, |
694 | 0 | codes_ptr, |
695 | 0 | new_beam_size, |
696 | 0 | new_codes_ptr, |
697 | 0 | new_residuals_ptr, |
698 | 0 | pool.distances.data(), |
699 | 0 | assign_index.get(), |
700 | 0 | rq.approx_topk_mode); |
701 | |
702 | 0 | if (assign_index != nullptr) { |
703 | 0 | assign_index->reset(); |
704 | 0 | } |
705 | |
|
706 | 0 | std::swap(codes_ptr, new_codes_ptr); |
707 | 0 | std::swap(residuals_ptr, new_residuals_ptr); |
708 | |
709 | 0 | cur_beam_size = new_beam_size; |
710 | |
|
711 | 0 | if (rq.verbose) { |
712 | 0 | float sum_distances = 0; |
713 | 0 | for (int j = 0; j < distances_size; j++) { |
714 | 0 | sum_distances += pool.distances[j]; |
715 | 0 | } |
716 | |
717 | 0 | printf("[%.3f s] encode stage %d, %d bits, " |
718 | 0 | "total error %g, beam_size %d\n", |
719 | 0 | (getmillisecs() - t0) / 1000, |
720 | 0 | m, |
721 | 0 | int(rq.nbits[m]), |
722 | 0 | sum_distances, |
723 | 0 | cur_beam_size); |
724 | 0 | } |
725 | 0 | } |
726 | |
727 | 0 | if (out_codes) { |
728 | 0 | memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr)); |
729 | 0 | } |
730 | 0 | if (out_residuals) { |
731 | 0 | memcpy(out_residuals, |
732 | 0 | residuals_ptr, |
733 | 0 | residuals_size * sizeof(*residuals_ptr)); |
734 | 0 | } |
735 | 0 | if (out_distances) { |
736 | 0 | memcpy(out_distances, |
737 | 0 | pool.distances.data(), |
738 | 0 | distances_size * sizeof(pool.distances[0])); |
739 | 0 | } |
740 | 0 | } |
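A worked example of the max_beam_size computation at the top of refine_beam_mp (hypothetical parameters, added as a reading aid):

// With beam_size = 1 on entry, out_beam_size = 32 and rq.nbits = {4, 4, 4}
// (so K = 16 at every step), the beam grows as
//   step 0: min( 1 * 16, 32) = 16
//   step 1: min(16 * 16, 32) = 32
//   step 2: min(32 * 16, 32) = 32
// hence max_beam_size = 32 and the pool buffers are sized once for the worst
// case instead of being reallocated at every step.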
741 | | |
742 | | void refine_beam_LUT_mp( |
743 | | const ResidualQuantizer& rq, |
744 | | size_t n, |
745 | | const float* query_norms, // size n |
746 | | const float* query_cp, // |
747 | | int out_beam_size, |
748 | | int32_t* out_codes, |
749 | | float* out_distances, |
750 | 0 | RefineBeamLUTMemoryPool& pool) { |
751 | 0 | int beam_size = 1; |
752 | |
753 | 0 | double t0 = getmillisecs(); |
754 | | |
755 | | // find the max_beam_size |
756 | 0 | int max_beam_size = 0; |
757 | 0 | { |
758 | 0 | int tmp_beam_size = beam_size; |
759 | 0 | for (int m = 0; m < rq.M; m++) { |
760 | 0 | int K = 1 << rq.nbits[m]; |
761 | 0 | int new_beam_size = std::min(tmp_beam_size * K, out_beam_size); |
762 | 0 | tmp_beam_size = new_beam_size; |
763 | |
764 | 0 | if (max_beam_size < new_beam_size) { |
765 | 0 | max_beam_size = new_beam_size; |
766 | 0 | } |
767 | 0 | } |
768 | 0 | } |
769 | | |
770 | | // preallocate buffers |
771 | 0 | pool.new_codes.resize(n * max_beam_size * (rq.M + 1)); |
772 | 0 | pool.new_distances.resize(n * max_beam_size); |
773 | |
774 | 0 | pool.codes.resize(n * max_beam_size * (rq.M + 1)); |
775 | 0 | pool.distances.resize(n * max_beam_size); |
776 | |
777 | 0 | for (size_t i = 0; i < n; i++) { |
778 | 0 | pool.distances[i] = query_norms[i]; |
779 | 0 | } |
780 | | |
781 | | // set up pointers to buffers |
782 | 0 | int32_t* __restrict new_codes_ptr = pool.new_codes.data(); |
783 | 0 | float* __restrict new_distances_ptr = pool.new_distances.data(); |
784 | |
785 | 0 | int32_t* __restrict codes_ptr = pool.codes.data(); |
786 | 0 | float* __restrict distances_ptr = pool.distances.data(); |
787 | | |
788 | | // main loop |
789 | 0 | size_t codes_size = 0; |
790 | 0 | size_t distances_size = 0; |
791 | 0 | size_t cross_ofs = 0; |
792 | 0 | for (int m = 0; m < rq.M; m++) { |
793 | 0 | int K = 1 << rq.nbits[m]; |
794 | | |
795 | | // it is guaranteed that (new_beam_size <= max_beam_size) |
796 | 0 | int new_beam_size = std::min(beam_size * K, out_beam_size); |
797 | |
798 | 0 | codes_size = n * new_beam_size * (m + 1); |
799 | 0 | distances_size = n * new_beam_size; |
800 | 0 | FAISS_THROW_IF_NOT( |
801 | 0 | cross_ofs + rq.codebook_offsets[m] * K <= |
802 | 0 | rq.codebook_cross_products.size()); |
803 | 0 | beam_search_encode_step_tab( |
804 | 0 | K, |
805 | 0 | n, |
806 | 0 | beam_size, |
807 | 0 | rq.codebook_cross_products.data() + cross_ofs, |
808 | 0 | K, |
809 | 0 | rq.codebook_offsets.data(), |
810 | 0 | query_cp + rq.codebook_offsets[m], |
811 | 0 | rq.total_codebook_size, |
812 | 0 | rq.centroid_norms.data() + rq.codebook_offsets[m], |
813 | 0 | m, |
814 | 0 | codes_ptr, |
815 | 0 | distances_ptr, |
816 | 0 | new_beam_size, |
817 | 0 | new_codes_ptr, |
818 | 0 | new_distances_ptr, |
819 | 0 | rq.approx_topk_mode); |
820 | 0 | cross_ofs += rq.codebook_offsets[m] * K; |
821 | 0 | std::swap(codes_ptr, new_codes_ptr); |
822 | 0 | std::swap(distances_ptr, new_distances_ptr); |
823 | |
824 | 0 | beam_size = new_beam_size; |
825 | |
|
826 | 0 | if (rq.verbose) { |
827 | 0 | float sum_distances = 0; |
828 | 0 | for (int j = 0; j < distances_size; j++) { |
829 | 0 | sum_distances += distances_ptr[j]; |
830 | 0 | } |
831 | 0 | printf("[%.3f s] encode stage %d, %d bits, " |
832 | 0 | "total error %g, beam_size %d\n", |
833 | 0 | (getmillisecs() - t0) / 1000, |
834 | 0 | m, |
835 | 0 | int(rq.nbits[m]), |
836 | 0 | sum_distances, |
837 | 0 | beam_size); |
838 | 0 | } |
839 | 0 | } |
840 | 0 | if (out_codes) { |
841 | 0 | memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr)); |
842 | 0 | } |
843 | 0 | if (out_distances) { |
844 | 0 | memcpy(out_distances, |
845 | 0 | distances_ptr, |
846 | 0 | distances_size * sizeof(*distances_ptr)); |
847 | 0 | } |
848 | 0 | } |
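The cross_ofs bookkeeping above walks rq.codebook_cross_products one step at a time; a worked example with hypothetical parameters (reading aid):

// With rq.M = 3 and nbits = {4, 4, 4} (K = 16, codebook_offsets = {0, 16, 32}),
// the slice passed to step m holds the dot products between all previously
// selectable codewords and the K entries of codebook m:
//   step 0: codebook_offsets[0] * K =  0 * 16 =   0 floats, cross_ofs 0   -> 0
//   step 1: codebook_offsets[1] * K = 16 * 16 = 256 floats, cross_ofs 0   -> 256
//   step 2: codebook_offsets[2] * K = 32 * 16 = 512 floats, cross_ofs 256 -> 768
// so codebook_cross_products must hold 768 floats in total, which is what the
// FAISS_THROW_IF_NOT above checks step by step.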
849 | | |
850 | | // this is for use_beam_LUT == 0 |
851 | | void compute_codes_add_centroids_mp_lut0( |
852 | | const ResidualQuantizer& rq, |
853 | | const float* x, |
854 | | uint8_t* codes_out, |
855 | | size_t n, |
856 | | const float* centroids, |
857 | 0 | ComputeCodesAddCentroidsLUT0MemoryPool& pool) { |
858 | 0 | pool.codes.resize(rq.max_beam_size * rq.M * n); |
859 | 0 | pool.distances.resize(rq.max_beam_size * n); |
860 | |
861 | 0 | pool.residuals.resize(rq.max_beam_size * n * rq.d); |
862 | |
863 | 0 | refine_beam_mp( |
864 | 0 | rq, |
865 | 0 | n, |
866 | 0 | 1, |
867 | 0 | x, |
868 | 0 | rq.max_beam_size, |
869 | 0 | pool.codes.data(), |
870 | 0 | pool.residuals.data(), |
871 | 0 | pool.distances.data(), |
872 | 0 | pool.refine_beam_pool); |
873 | |
874 | 0 | if (rq.search_type == ResidualQuantizer::ST_norm_float || |
875 | 0 | rq.search_type == ResidualQuantizer::ST_norm_qint8 || |
876 | 0 | rq.search_type == ResidualQuantizer::ST_norm_qint4) { |
877 | 0 | pool.norms.resize(n); |
878 | | // recover the norms of reconstruction as |
879 | | // || original_vector - residual ||^2 |
880 | 0 | for (size_t i = 0; i < n; i++) { |
881 | 0 | pool.norms[i] = fvec_L2sqr( |
882 | 0 | x + i * rq.d, |
883 | 0 | pool.residuals.data() + i * rq.max_beam_size * rq.d, |
884 | 0 | rq.d); |
885 | 0 | } |
886 | 0 | } |
887 | | |
888 | | // pack only the first code of the beam |
889 | | // (hence the ld_codes=M * max_beam_size) |
890 | 0 | rq.pack_codes( |
891 | 0 | n, |
892 | 0 | pool.codes.data(), |
893 | 0 | codes_out, |
894 | 0 | rq.M * rq.max_beam_size, |
895 | 0 | (pool.norms.size() > 0) ? pool.norms.data() : nullptr, |
896 | 0 | centroids); |
897 | 0 | } |
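A note on the norm recovery above (reading aid):

// For the ST_norm_* search types the extra packed field is the norm of the
// reconstructed vector. Since the final residual of the best beam entry is
//   r = x - decode(codes),
// that norm can be recovered without decoding anything:
//   || decode(codes) ||^2 == || x - r ||^2,
// which is exactly what the fvec_L2sqr call above computes.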
898 | | |
899 | | // use_beam_LUT == 1 |
900 | | void compute_codes_add_centroids_mp_lut1( |
901 | | const ResidualQuantizer& rq, |
902 | | const float* x, |
903 | | uint8_t* codes_out, |
904 | | size_t n, |
905 | | const float* centroids, |
906 | 0 | ComputeCodesAddCentroidsLUT1MemoryPool& pool) { |
907 | | // |
908 | 0 | pool.codes.resize(rq.max_beam_size * rq.M * n); |
909 | 0 | pool.distances.resize(rq.max_beam_size * n); |
910 | |
911 | 0 | FAISS_THROW_IF_NOT_MSG( |
912 | 0 | rq.M == 1 || rq.codebook_cross_products.size() > 0, |
913 | 0 | "call compute_codebook_tables first"); |
914 | | |
915 | 0 | pool.query_norms.resize(n); |
916 | 0 | fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n); |
917 | |
918 | 0 | pool.query_cp.resize(n * rq.total_codebook_size); |
919 | 0 | { |
920 | 0 | FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n; |
921 | 0 | float zero = 0, one = 1; |
922 | 0 | sgemm_("Transposed", |
923 | 0 | "Not transposed", |
924 | 0 | &ti, |
925 | 0 | &ni, |
926 | 0 | &di, |
927 | 0 | &one, |
928 | 0 | rq.codebooks.data(), |
929 | 0 | &di, |
930 | 0 | x, |
931 | 0 | &di, |
932 | 0 | &zero, |
933 | 0 | pool.query_cp.data(), |
934 | 0 | &ti); |
935 | 0 | } |
936 | |
937 | 0 | refine_beam_LUT_mp( |
938 | 0 | rq, |
939 | 0 | n, |
940 | 0 | pool.query_norms.data(), |
941 | 0 | pool.query_cp.data(), |
942 | 0 | rq.max_beam_size, |
943 | 0 | pool.codes.data(), |
944 | 0 | pool.distances.data(), |
945 | 0 | pool.refine_beam_lut_pool); |
946 | | |
947 | | // pack only the first code of the beam |
948 | | // (hence the ld_codes=M * max_beam_size) |
949 | 0 | rq.pack_codes( |
950 | 0 | n, |
951 | 0 | pool.codes.data(), |
952 | 0 | codes_out, |
953 | 0 | rq.M * rq.max_beam_size, |
954 | 0 | nullptr, |
955 | 0 | centroids); |
956 | 0 | } |
957 | | |
958 | | } // namespace rq_encode_steps |
959 | | |
960 | | } // namespace faiss |