contrib/faiss/faiss/impl/ProductQuantizer.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/impl/ProductQuantizer.h> |
11 | | |
12 | | #include <cstddef> |
13 | | #include <cstdio> |
14 | | #include <cstring> |
15 | | #include <memory> |
16 | | |
17 | | #include <algorithm> |
18 | | |
19 | | #include <faiss/IndexFlat.h> |
20 | | #include <faiss/VectorTransform.h> |
21 | | #include <faiss/impl/FaissAssert.h> |
22 | | #include <faiss/utils/distances.h> |
23 | | |
extern "C" {

/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

// Single-precision GEMM: C = alpha * op(A) * op(B) + beta * C.
// Fortran calling convention (all scalars by pointer); only the first
// character of transa/transb is significant to BLAS.
int sgemm_(
        const char* transa,
        const char* transb,
        FINTEGER* m,
        FINTEGER* n,
        FINTEGER* k,
        const float* alpha,
        const float* a,
        FINTEGER* lda,
        const float* b,
        FINTEGER* ldb,
        float* beta,
        float* c,
        FINTEGER* ldc);
}
43 | | |
44 | | namespace faiss { |
45 | | |
46 | | /********************************************* |
47 | | * PQ implementation |
48 | | *********************************************/ |
49 | | |
// Construct a PQ that splits d-dimensional vectors into M sub-vectors,
// each encoded on nbits bits. The second base-class argument (presumably
// the code size — TODO confirm against Quantizer) is 0 here; the real
// value is derived in set_derived_values().
ProductQuantizer::ProductQuantizer(size_t d, size_t M, size_t nbits)
        : Quantizer(d, 0), M(M), nbits(nbits), assign_index(nullptr) {
    set_derived_values();
}
54 | | |
// Default: empty quantizer (d = 0, one subquantizer, 0 bits), typically
// re-parameterized later (e.g. when deserializing).
ProductQuantizer::ProductQuantizer() : ProductQuantizer(0, 1, 0) {}
56 | | |
// Recompute all values derived from (d, M, nbits); must be called whenever
// those change.
void ProductQuantizer::set_derived_values() {
    // quite a few derived values
    FAISS_THROW_IF_NOT_MSG(
            d % M == 0,
            "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
    dsub = d / M; // dimension of each sub-vector
    code_size = (nbits * M + 7) / 8; // bytes per code, rounded up to a byte
    FAISS_THROW_IF_MSG(nbits > 24, "nbits larger than 24 is not practical.");
    ksub = 1 << nbits; // number of centroids per subquantizer
    centroids.resize(d * ksub);
    verbose = false;
    train_type = Train_default;
}
70 | | |
// Overwrite the centroid table of subquantizer m with ksub * dsub floats.
void ProductQuantizer::set_params(const float* centroids_, int m) {
    memcpy(get_centroids(m, 0),
           centroids_,
           ksub * dsub * sizeof(centroids_[0]));
}
76 | | |
// Place 2^nbits centroids on the corners of a hypercube centered on the
// data mean: the first nbits coordinates are mean[j] +/- maxm (sign taken
// from bit j of the centroid index), the remaining coordinates are set to
// the mean. maxm is the largest absolute mean coordinate.
static void init_hypercube(
        int d,
        int nbits,
        int n,
        const float* x,
        float* centroids) {
    // per-dimension mean over the n training vectors
    std::vector<float> mean(d);
    for (int i = 0; i < n; i++) {
        const float* xi = x + i * d;
        for (int j = 0; j < d; j++) {
            mean[j] += xi[j];
        }
    }

    // amplitude of the cube = max |mean coordinate|
    float maxm = 0;
    for (int j = 0; j < d; j++) {
        mean[j] /= n;
        const float am = std::fabs(mean[j]);
        if (am > maxm) {
            maxm = am;
        }
    }

    const int ncent = 1 << nbits;
    for (int i = 0; i < ncent; i++) {
        float* cent = centroids + i * d;
        int j = 0;
        for (; j < nbits; j++) {
            cent[j] = mean[j] + (((i >> j) & 1) ? maxm : -maxm);
        }
        for (; j < d; j++) {
            cent[j] = mean[j];
        }
    }
}
103 | | |
104 | | static void init_hypercube_pca( |
105 | | int d, |
106 | | int nbits, |
107 | | int n, |
108 | | const float* x, |
109 | 0 | float* centroids) { |
110 | 0 | PCAMatrix pca(d, nbits); |
111 | 0 | pca.train(n, x); |
112 | |
|
113 | 0 | for (int i = 0; i < (1 << nbits); i++) { |
114 | 0 | float* cent = centroids + i * d; |
115 | 0 | for (int j = 0; j < d; j++) { |
116 | 0 | cent[j] = pca.mean[j]; |
117 | 0 | float f = 1.0; |
118 | 0 | for (int k = 0; k < nbits; k++) |
119 | 0 | cent[j] += f * sqrt(pca.eigenvalues[k]) * |
120 | 0 | (((i >> k) & 1) ? 1 : -1) * pca.PCAMat[j + k * d]; |
121 | 0 | } |
122 | 0 | } |
123 | 0 | } |
124 | | |
/** Train the M codebooks on n training vectors of dimension d.
 *
 * Unless train_type == Train_shared, each subquantizer is trained
 * independently by k-means on its own dsub-dimensional slice of the input.
 * train_type may also request a hypercube / hypercube-PCA / warm-start
 * initialization of the k-means centroids.
 */
void ProductQuantizer::train(size_t n, const float* x) {
    if (train_type != Train_shared) {
        train_type_t final_train_type;
        final_train_type = train_type;
        if (train_type == Train_hypercube ||
            train_type == Train_hypercube_pca) {
            if (dsub < nbits) {
                // cannot spread 2^nbits corners over fewer than nbits dims
                final_train_type = Train_default;
                printf("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
                       nbits,
                       dsub);
            }
        }

        // scratch buffer holding the m-th slice of all training vectors
        std::unique_ptr<float[]> xslice(new float[n * dsub]);
        for (int m = 0; m < M; m++) {
            // gather the m-th dsub-dimensional slice of every vector
            for (int j = 0; j < n; j++)
                memcpy(xslice.get() + j * dsub,
                       x + j * d + m * dsub,
                       dsub * sizeof(float));

            Clustering clus(dsub, ksub, cp);

            // we have some initialization for the centroids
            if (final_train_type != Train_default) {
                clus.centroids.resize(dsub * ksub);
            }

            switch (final_train_type) {
                case Train_hypercube:
                    init_hypercube(
                            dsub,
                            nbits,
                            n,
                            xslice.get(),
                            clus.centroids.data());
                    break;
                case Train_hypercube_pca:
                    init_hypercube_pca(
                            dsub,
                            nbits,
                            n,
                            xslice.get(),
                            clus.centroids.data());
                    break;
                case Train_hot_start:
                    // warm-start k-means from the current codebook
                    memcpy(clus.centroids.data(),
                           get_centroids(m, 0),
                           dsub * ksub * sizeof(float));
                    break;
                default:;
            }

            if (verbose) {
                clus.verbose = true;
                printf("Training PQ slice %d/%zd\n", m, M);
            }
            IndexFlatL2 index(dsub);
            clus.train(n, xslice.get(), assign_index ? *assign_index : index);
            set_params(clus.centroids.data(), m);
        }

    } else {
        // Train_shared: one codebook shared by all M subquantizers, trained
        // on the n * M sub-vectors (contiguous in x since d = M * dsub)
        Clustering clus(dsub, ksub, cp);

        if (verbose) {
            clus.verbose = true;
            printf("Training all PQ slices at once\n");
        }

        IndexFlatL2 index(dsub);

        clus.train(n * M, x, assign_index ? *assign_index : index);
        for (int m = 0; m < M; m++) {
            set_params(clus.centroids.data(), m);
        }
    }
}
203 | | |
/** Encode one d-dimensional vector into a code: for each subquantizer,
 *  find the nearest centroid and feed its index to the encoder.
 *
 * The `distances` buffer exists because computing all ksub distances into a
 * cache and then scanning for the minimum vectorizes far better than a fused
 * compute-and-compare loop; the nearest-neighbor kernels below take it even
 * though manually optimized implementations may leave it unused (the
 * compiler can hopefully elide the allocation then).
 */
template <class PQEncoder>
void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
    std::vector<float> distances(pq.ksub);

    PQEncoder encoder(code, pq.nbits);
    for (size_t m = 0; m < pq.M; m++) {
        const float* xsub = x + m * pq.dsub;

        uint64_t idxm = 0;
        if (pq.transposed_centroids.empty()) {
            // the regular version
            idxm = fvec_L2sqr_ny_nearest(
                    distances.data(),
                    xsub,
                    pq.get_centroids(m, 0),
                    pq.dsub,
                    pq.ksub);
        } else {
            // transposed centroids are available, use'em
            idxm = fvec_L2sqr_ny_nearest_y_transposed(
                    distances.data(),
                    xsub,
                    pq.transposed_centroids.data() + m * pq.ksub,
                    pq.centroids_sq_lengths.data() + m * pq.ksub,
                    pq.dsub,
                    pq.M * pq.ksub,
                    pq.ksub);
        }

        encoder.encode(idxm);
    }
}
272 | | |
273 | 0 | void ProductQuantizer::compute_code(const float* x, uint8_t* code) const { |
274 | 0 | switch (nbits) { |
275 | 0 | case 8: |
276 | 0 | faiss::compute_code<PQEncoder8>(*this, x, code); |
277 | 0 | break; |
278 | | |
279 | 0 | case 16: |
280 | 0 | faiss::compute_code<PQEncoder16>(*this, x, code); |
281 | 0 | break; |
282 | | |
283 | 0 | default: |
284 | 0 | faiss::compute_code<PQEncoderGeneric>(*this, x, code); |
285 | 0 | break; |
286 | 0 | } |
287 | 0 | } |
288 | | |
289 | | template <class PQDecoder> |
290 | 0 | void decode(const ProductQuantizer& pq, const uint8_t* code, float* x) { |
291 | 0 | PQDecoder decoder(code, pq.nbits); |
292 | 0 | for (size_t m = 0; m < pq.M; m++) { |
293 | 0 | uint64_t c = decoder.decode(); |
294 | 0 | memcpy(x + m * pq.dsub, |
295 | 0 | pq.get_centroids(m, c), |
296 | 0 | sizeof(float) * pq.dsub); |
297 | 0 | } |
298 | 0 | } Unexecuted instantiation: _ZN5faiss6decodeINS_10PQDecoder8EEEvRKNS_16ProductQuantizerEPKhPf Unexecuted instantiation: _ZN5faiss6decodeINS_11PQDecoder16EEEvRKNS_16ProductQuantizerEPKhPf Unexecuted instantiation: _ZN5faiss6decodeINS_16PQDecoderGenericEEEvRKNS_16ProductQuantizerEPKhPf |
299 | | |
300 | 0 | void ProductQuantizer::decode(const uint8_t* code, float* x) const { |
301 | 0 | switch (nbits) { |
302 | 0 | case 8: |
303 | 0 | faiss::decode<PQDecoder8>(*this, code, x); |
304 | 0 | break; |
305 | | |
306 | 0 | case 16: |
307 | 0 | faiss::decode<PQDecoder16>(*this, code, x); |
308 | 0 | break; |
309 | | |
310 | 0 | default: |
311 | 0 | faiss::decode<PQDecoderGeneric>(*this, code, x); |
312 | 0 | break; |
313 | 0 | } |
314 | 0 | } |
315 | | |
// Decode n codes into n vectors; parallelized only when the batch is large
// enough to amortize the OpenMP overhead.
void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
#pragma omp parallel for if (n > 100)
    for (int64_t i = 0; i < n; i++) {
        this->decode(code + code_size * i, x + d * i);
    }
}
322 | | |
323 | | void ProductQuantizer::compute_code_from_distance_table( |
324 | | const float* tab, |
325 | 0 | uint8_t* code) const { |
326 | 0 | PQEncoderGeneric encoder(code, nbits); |
327 | 0 | for (size_t m = 0; m < M; m++) { |
328 | 0 | float mindis = 1e20; |
329 | 0 | uint64_t idxm = 0; |
330 | | |
331 | | /* Find best centroid */ |
332 | 0 | for (size_t j = 0; j < ksub; j++) { |
333 | 0 | float dis = *tab++; |
334 | 0 | if (dis < mindis) { |
335 | 0 | mindis = dis; |
336 | 0 | idxm = j; |
337 | 0 | } |
338 | 0 | } |
339 | |
|
340 | 0 | encoder.encode(idxm); |
341 | 0 | } |
342 | 0 | } |
343 | | |
/** Encode n vectors using the external assign_index for the nearest-centroid
 *  search, one subquantizer at a time, in blocks of at most 65536 vectors to
 *  bound scratch memory.
 */
void ProductQuantizer::compute_codes_with_assign_index(
        const float* x,
        uint8_t* codes,
        size_t n) {
    FAISS_THROW_IF_NOT(assign_index && assign_index->d == dsub);

    for (size_t m = 0; m < M; m++) {
        // load subquantizer m's centroids into the assign index
        assign_index->reset();
        assign_index->add(ksub, get_centroids(m, 0));
        size_t bs = 65536;

        std::unique_ptr<float[]> xslice(new float[bs * dsub]);
        std::unique_ptr<idx_t[]> assign(new idx_t[bs]);

        for (size_t i0 = 0; i0 < n; i0 += bs) {
            size_t i1 = std::min(i0 + bs, n);

            // gather the m-th slice of vectors [i0, i1)
            for (size_t i = i0; i < i1; i++) {
                memcpy(xslice.get() + (i - i0) * dsub,
                       x + i * d + m * dsub,
                       dsub * sizeof(float));
            }

            assign_index->assign(i1 - i0, xslice.get(), assign.get());

            // scatter the assignments into the codes
            if (nbits == 8) {
                // one byte per subquantizer, stride M bytes between codes
                uint8_t* c = codes + code_size * i0 + m;
                for (size_t i = i0; i < i1; i++) {
                    *c = assign[i - i0];
                    c += M;
                }
            } else if (nbits == 16) {
                // two bytes per subquantizer
                uint16_t* c = (uint16_t*)(codes + code_size * i0 + m * 2);
                for (size_t i = i0; i < i1; i++) {
                    *c = assign[i - i0];
                    c += M;
                }
            } else {
                // generic bit-packed case: locate the byte and bit offset of
                // subquantizer m inside each code and write there
                for (size_t i = i0; i < i1; ++i) {
                    uint8_t* c = codes + code_size * i + ((m * nbits) / 8);
                    uint8_t offset = (m * nbits) % 8;
                    uint64_t ass = assign[i - i0];

                    PQEncoderGeneric encoder(c, nbits, offset);
                    encoder.encode(ass);
                }
            }
        }
    }
}
394 | | |
// block size used in ProductQuantizer::compute_codes
int product_quantizer_compute_codes_bs = 256 * 1024;

/** Encode n vectors into codes. Large inputs are processed in blocks to
 *  bound the memory used by intermediate distance tables; for small
 *  sub-vector dimensions a direct per-vector computation is used,
 *  otherwise distance tables are built with BLAS first.
 */
void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
        const {
    // process by blocks to avoid using too much RAM
    size_t bs = product_quantizer_compute_codes_bs;
    if (n > bs) {
        for (size_t i0 = 0; i0 < n; i0 += bs) {
            size_t i1 = std::min(i0 + bs, n);
            compute_codes(x + d * i0, codes + code_size * i0, i1 - i0);
        }
        return;
    }

    if (dsub < 16) { // simple direct computation

#pragma omp parallel for
        for (int64_t i = 0; i < n; i++)
            compute_code(x + i * d, codes + i * code_size);

    } else { // worthwhile to use BLAS
        // precompute all query-to-centroid distances, then take the argmins
        std::unique_ptr<float[]> dis_tables(new float[n * ksub * M]);
        compute_distance_tables(n, x, dis_tables.get());

#pragma omp parallel for
        for (int64_t i = 0; i < n; i++) {
            uint8_t* code = codes + i * code_size;
            const float* tab = dis_tables.get() + i * ksub * M;
            compute_code_from_distance_table(tab, code);
        }
    }
}
428 | | |
/** Fill dis_table (ksub * M floats) with the squared L2 distances between
 *  each sub-vector of x and all centroids of the matching subquantizer.
 */
void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
        const {
    if (transposed_centroids.empty()) {
        // use regular version
        for (size_t m = 0; m < M; m++) {
            fvec_L2sqr_ny(
                    dis_table + m * ksub,
                    x + m * dsub,
                    get_centroids(m, 0),
                    dsub,
                    ksub);
        }
    } else {
        // transposed centroids are available, use'em
        for (size_t m = 0; m < M; m++) {
            fvec_L2sqr_ny_transposed(
                    dis_table + m * ksub,
                    x + m * dsub,
                    transposed_centroids.data() + m * ksub,
                    centroids_sq_lengths.data() + m * ksub,
                    dsub,
                    M * ksub,
                    ksub);
        }
    }
}
455 | | |
456 | | void ProductQuantizer::compute_inner_prod_table( |
457 | | const float* x, |
458 | 0 | float* dis_table) const { |
459 | 0 | size_t m; |
460 | |
|
461 | 0 | for (m = 0; m < M; m++) { |
462 | 0 | fvec_inner_products_ny( |
463 | 0 | dis_table + m * ksub, |
464 | 0 | x + m * dsub, |
465 | 0 | get_centroids(m, 0), |
466 | 0 | dsub, |
467 | 0 | ksub); |
468 | 0 | } |
469 | 0 | } |
470 | | |
/** Compute nx distance tables at once (nx * ksub * M floats). Dispatches
 *  between a SIMD kernel for dsub == 2 with small codebooks, a per-query
 *  loop for small dsub, and a BLAS-backed pairwise distance computation.
 */
void ProductQuantizer::compute_distance_tables(
        size_t nx,
        const float* x,
        float* dis_tables) const {
#if defined(__AVX2__) || defined(__aarch64__)
    if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
        compute_PQ_dis_tables_dsub2(
                d, ksub, centroids.data(), nx, x, false, dis_tables);
    } else
#endif
            if (dsub < 16) {

#pragma omp parallel for if (nx > 1)
        for (int64_t i = 0; i < nx; i++) {
            compute_distance_table(x + i * d, dis_tables + i * ksub * M);
        }

    } else { // use BLAS
        // one pairwise-distance call per subquantizer; output rows are
        // interleaved per query with leading dimension ksub * M
        for (int m = 0; m < M; m++) {
            pairwise_L2sqr(
                    dsub,
                    nx,
                    x + dsub * m,
                    ksub,
                    centroids.data() + m * dsub * ksub,
                    dis_tables + ksub * m,
                    d,
                    dsub,
                    ksub * M);
        }
    }
}
504 | | |
/** Inner-product counterpart of compute_distance_tables. The BLAS branch
 *  computes centroids^T * x per subquantizer with sgemm, writing strided
 *  into the per-query tables (leading dimension ksub * M).
 */
void ProductQuantizer::compute_inner_prod_tables(
        size_t nx,
        const float* x,
        float* dis_tables) const {
#if defined(__AVX2__) || defined(__aarch64__)
    if (dsub == 2 && nbits < 8) {
        compute_PQ_dis_tables_dsub2(
                d, ksub, centroids.data(), nx, x, true, dis_tables);
    } else
#endif
            if (dsub < 16) {

#pragma omp parallel for if (nx > 1)
        for (int64_t i = 0; i < nx; i++) {
            compute_inner_prod_table(x + i * d, dis_tables + i * ksub * M);
        }

    } else { // use BLAS

        // compute distance tables
        for (int m = 0; m < M; m++) {
            // BLAS reads only the first character of the transpose flags
            FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub, dsubi = dsub,
                     di = d;
            float one = 1.0, zero = 0;

            sgemm_("Transposed",
                   "Not transposed",
                   &ksubi,
                   &nxi,
                   &dsubi,
                   &one,
                   &centroids[m * dsub * ksub],
                   &dsubi,
                   x + dsub * m,
                   &di,
                   &zero,
                   dis_tables + ksub * m,
                   &ldc);
        }
    }
}
546 | | |
547 | | /********************************************** |
548 | | * Templatized search functions |
549 | | * The template class C indicates whether to keep the highest or smallest values |
550 | | **********************************************/ |
551 | | |
552 | | namespace { |
553 | | |
/* compute an estimator using look-up tables for typical values of M */
// Version for M a multiple of 4: the per-code accumulation is unrolled 4x.
// CT is the storage type of one code component (uint8_t or uint16_t); the
// comparator C decides whether the heap keeps smallest or largest values.
template <typename CT, class C>
void pq_estimators_from_tables_Mmul4(
        int M,
        const CT* codes,
        size_t ncodes,
        const float* __restrict dis_table,
        size_t ksub,
        size_t k,
        float* heap_dis,
        int64_t* heap_ids) {
    for (size_t j = 0; j < ncodes; j++) {
        float dis = 0;
        const float* dt = dis_table;

        // sum the table entries selected by 4 code components at a time
        for (size_t m = 0; m < M; m += 4) {
            float dism = 0;
            dism = dt[*codes++];
            dt += ksub;
            dism += dt[*codes++];
            dt += ksub;
            dism += dt[*codes++];
            dt += ksub;
            dism += dt[*codes++];
            dt += ksub;
            dis += dism;
        }

        // push into the heap only if it beats the current worst
        if (C::cmp(heap_dis[0], dis)) {
            heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
        }
    }
}
587 | | |
// Specialization for M == 4: the accumulation is fully unrolled, no inner
// loop at all.
template <typename CT, class C>
void pq_estimators_from_tables_M4(
        const CT* codes,
        size_t ncodes,
        const float* __restrict dis_table,
        size_t ksub,
        size_t k,
        float* heap_dis,
        int64_t* heap_ids) {
    for (size_t j = 0; j < ncodes; j++) {
        float dis = 0;
        const float* dt = dis_table;
        dis = dt[*codes++];
        dt += ksub;
        dis += dt[*codes++];
        dt += ksub;
        dis += dt[*codes++];
        dt += ksub;
        dis += dt[*codes++];

        if (C::cmp(heap_dis[0], dis)) {
            heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
        }
    }
}
613 | | |
// Dispatch to the specialized estimator kernels when M allows it; otherwise
// fall back to the generic one-lookup-per-subquantizer loop.
template <typename CT, class C>
void pq_estimators_from_tables(
        const ProductQuantizer& pq,
        const CT* codes,
        size_t ncodes,
        const float* dis_table,
        size_t k,
        float* heap_dis,
        int64_t* heap_ids) {
    if (pq.M == 4) {
        pq_estimators_from_tables_M4<CT, C>(
                codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
        return;
    }

    if (pq.M % 4 == 0) {
        pq_estimators_from_tables_Mmul4<CT, C>(
                pq.M, codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
        return;
    }

    /* Default is relatively slow */
    const size_t M = pq.M;
    const size_t ksub = pq.ksub;
    for (size_t j = 0; j < ncodes; j++) {
        float dis = 0;
        const float* __restrict dt = dis_table;
        for (int m = 0; m < M; m++) {
            dis += dt[*codes++];
            dt += ksub;
        }
        if (C::cmp(heap_dis[0], dis)) {
            heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
        }
    }
}
650 | | |
// Fallback for non-byte-aligned codes: each component is extracted with
// PQDecoderGeneric before the table lookup.
template <class C>
void pq_estimators_from_tables_generic(
        const ProductQuantizer& pq,
        size_t nbits,
        const uint8_t* codes,
        size_t ncodes,
        const float* dis_table,
        size_t k,
        float* heap_dis,
        int64_t* heap_ids) {
    const size_t M = pq.M;
    const size_t ksub = pq.ksub;
    for (size_t j = 0; j < ncodes; ++j) {
        PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
        float dis = 0;
        const float* __restrict dt = dis_table;
        for (size_t m = 0; m < M; m++) {
            uint64_t c = decoder.decode();
            dis += dt[c];
            dt += ksub;
        }

        if (C::cmp(heap_dis[0], dis)) {
            heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
        }
    }
}
678 | | |
// kNN search given precomputed per-query distance tables; results are
// accumulated into the heaps of res. The comparator C selects min- or
// max-heap semantics (inner product vs L2).
template <class C>
void pq_knn_search_with_tables(
        const ProductQuantizer& pq,
        size_t nbits,
        const float* dis_tables,
        const uint8_t* codes,
        const size_t ncodes,
        HeapArray<C>* res,
        bool init_finalize_heap) {
    size_t k = res->k, nx = res->nh;
    size_t ksub = pq.ksub, M = pq.M;

#pragma omp parallel for if (nx > 1)
    for (int64_t i = 0; i < nx; i++) {
        /* query preparation for asymmetric search: compute look-up tables */
        const float* dis_table = dis_tables + i * ksub * M;

        /* Compute distances and keep smallest values */
        int64_t* __restrict heap_ids = res->ids + i * k;
        float* __restrict heap_dis = res->val + i * k;

        if (init_finalize_heap) {
            heap_heapify<C>(k, heap_dis, heap_ids);
        }

        // dispatch on the code component width
        switch (nbits) {
            case 8:
                pq_estimators_from_tables<uint8_t, C>(
                        pq, codes, ncodes, dis_table, k, heap_dis, heap_ids);
                break;

            case 16:
                pq_estimators_from_tables<uint16_t, C>(
                        pq,
                        (uint16_t*)codes,
                        ncodes,
                        dis_table,
                        k,
                        heap_dis,
                        heap_ids);
                break;

            default:
                pq_estimators_from_tables_generic<C>(
                        pq,
                        nbits,
                        codes,
                        ncodes,
                        dis_table,
                        k,
                        heap_dis,
                        heap_ids);
                break;
        }

        if (init_finalize_heap) {
            heap_reorder<C>(k, heap_dis, heap_ids);
        }
    }
}
739 | | |
740 | | } // anonymous namespace |
741 | | |
/** Asymmetric kNN search (squared L2): compare nx query vectors against
 *  ncodes encoded database vectors; results go into the max-heaps of res.
 */
void ProductQuantizer::search(
        const float* __restrict x,
        size_t nx,
        const uint8_t* codes,
        const size_t ncodes,
        float_maxheap_array_t* res,
        bool init_finalize_heap) const {
    FAISS_THROW_IF_NOT(nx == res->nh);
    std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
    compute_distance_tables(nx, x, dis_tables.get());

    pq_knn_search_with_tables<CMax<float, int64_t>>(
            *this,
            nbits,
            dis_tables.get(),
            codes,
            ncodes,
            res,
            init_finalize_heap);
}
762 | | |
/** Asymmetric kNN search for inner product: same as search() but with
 *  inner-product tables and min-heap semantics (largest products kept).
 */
void ProductQuantizer::search_ip(
        const float* __restrict x,
        size_t nx,
        const uint8_t* codes,
        const size_t ncodes,
        float_minheap_array_t* res,
        bool init_finalize_heap) const {
    FAISS_THROW_IF_NOT(nx == res->nh);
    std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
    compute_inner_prod_tables(nx, x, dis_tables.get());

    pq_knn_search_with_tables<CMin<float, int64_t>>(
            *this,
            nbits,
            dis_tables.get(),
            codes,
            ncodes,
            res,
            init_finalize_heap);
}
783 | | |
/** Precompute the symmetric-distance tables used by search_sdc: for each
 *  subquantizer, the ksub x ksub pairwise squared L2 distances between its
 *  centroids (M * ksub * ksub floats in total).
 */
void ProductQuantizer::compute_sdc_table() {
    sdc_table.resize(M * ksub * ksub);

    if (dsub < 4) {
#pragma omp parallel for
        for (int mk = 0; mk < M * ksub; mk++) {
            // allow omp to schedule in a more fine-grained way
            // `collapse` is not supported in OpenMP 2.x
            int m = mk / ksub;
            int k = mk % ksub;
            const float* cents = centroids.data() + m * ksub * dsub;
            const float* centi = cents + k * dsub;
            float* dis_tab = sdc_table.data() + m * ksub * ksub;
            fvec_L2sqr_ny(dis_tab + k * ksub, centi, cents, dsub, ksub);
        }
    } else {
        // NOTE: it would disable the omp loop in pairwise_L2sqr
        // but still accelerate especially when M >= 4
#pragma omp parallel for
        for (int m = 0; m < M; m++) {
            const float* cents = centroids.data() + m * ksub * dsub;
            float* dis_tab = sdc_table.data() + m * ksub * ksub;
            pairwise_L2sqr(
                    dsub, ksub, cents, ksub, cents, dis_tab, dsub, dsub, ksub);
        }
    }
}
811 | | |
/** Symmetric kNN search: both queries and database are given as codes and
 *  distances are approximated by summing precomputed centroid-to-centroid
 *  distances from sdc_table. Only nbits == 8 is supported.
 */
void ProductQuantizer::search_sdc(
        const uint8_t* qcodes,
        size_t nq,
        const uint8_t* bcodes,
        const size_t nb,
        float_maxheap_array_t* res,
        bool init_finalize_heap) const {
    FAISS_THROW_IF_NOT(sdc_table.size() == M * ksub * ksub);
    FAISS_THROW_IF_NOT(nbits == 8);
    size_t k = res->k;

#pragma omp parallel for
    for (int64_t i = 0; i < nq; i++) {
        /* Compute distances and keep smallest values */
        idx_t* heap_ids = res->ids + i * k;
        float* heap_dis = res->val + i * k;
        const uint8_t* qcode = qcodes + i * code_size;

        if (init_finalize_heap)
            maxheap_heapify(k, heap_dis, heap_ids);

        const uint8_t* bcode = bcodes;
        for (size_t j = 0; j < nb; j++) {
            // distance = sum over m of d(centroid[qcode[m]], centroid[bcode[m]])
            float dis = 0;
            const float* tab = sdc_table.data();
            for (int m = 0; m < M; m++) {
                dis += tab[bcode[m] + qcode[m] * ksub];
                tab += ksub * ksub;
            }
            if (dis < heap_dis[0]) {
                maxheap_replace_top(k, heap_dis, heap_ids, dis, j);
            }
            bcode += code_size;
        }

        if (init_finalize_heap)
            maxheap_reorder(k, heap_dis, heap_ids);
    }
}
851 | | |
/** Build the dimension-major (transposed) copy of the centroid table plus
 *  per-centroid squared lengths; this layout is picked up by compute_code
 *  and compute_distance_table. Must be called again if centroids change.
 */
void ProductQuantizer::sync_transposed_centroids() {
    transposed_centroids.resize(d * ksub);
    centroids_sq_lengths.resize(ksub * M);

    for (size_t mi = 0; mi < M; mi++) {
        for (size_t ki = 0; ki < ksub; ki++) {
            float sqlen = 0;

            for (size_t di = 0; di < dsub; di++) {
                const float q = centroids[(mi * ksub + ki) * dsub + di];

                // transposed layout: indexed [dsub][M][ksub]
                transposed_centroids[(di * M + mi) * ksub + ki] = q;
                sqlen += q * q;
            }

            centroids_sq_lengths[mi * ksub + ki] = sqlen;
        }
    }
}
871 | | |
872 | 0 | void ProductQuantizer::clear_transposed_centroids() { |
873 | 0 | transposed_centroids.clear(); |
874 | 0 | transposed_centroids.shrink_to_fit(); |
875 | |
|
876 | 0 | centroids_sq_lengths.clear(); |
877 | 0 | centroids_sq_lengths.shrink_to_fit(); |
878 | 0 | } |
879 | | |
880 | | } // namespace faiss |