/root/doris/contrib/faiss/faiss/impl/AdditiveQuantizer.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/impl/AdditiveQuantizer.h> |
11 | | |
12 | | #include <cstddef> |
13 | | #include <cstdio> |
14 | | #include <cstring> |
15 | | #include <memory> |
16 | | #include <random> |
17 | | |
18 | | #include <algorithm> |
19 | | |
20 | | #include <faiss/Clustering.h> |
21 | | #include <faiss/impl/FaissAssert.h> |
22 | | #include <faiss/impl/LocalSearchQuantizer.h> |
23 | | #include <faiss/impl/ResidualQuantizer.h> |
24 | | #include <faiss/utils/Heap.h> |
25 | | #include <faiss/utils/distances.h> |
26 | | #include <faiss/utils/hamming.h> |
27 | | |
28 | | extern "C" { |
29 | | |
30 | | // general matrix multiplication |
31 | | int sgemm_( |
32 | | const char* transa, |
33 | | const char* transb, |
34 | | FINTEGER* m, |
35 | | FINTEGER* n, |
36 | | FINTEGER* k, |
37 | | const float* alpha, |
38 | | const float* a, |
39 | | FINTEGER* lda, |
40 | | const float* b, |
41 | | FINTEGER* ldb, |
42 | | float* beta, |
43 | | float* c, |
44 | | FINTEGER* ldc); |
45 | | } |
46 | | |
47 | | namespace faiss { |
48 | | |
49 | | AdditiveQuantizer::AdditiveQuantizer( |
50 | | size_t d, |
51 | | const std::vector<size_t>& nbits, |
52 | | Search_type_t search_type) |
53 | 0 | : Quantizer(d), |
54 | 0 | M(nbits.size()), |
55 | 0 | nbits(nbits), |
56 | 0 | search_type(search_type) { |
57 | 0 | set_derived_values(); |
58 | 0 | } |
59 | | |
60 | | AdditiveQuantizer::AdditiveQuantizer() |
61 | 0 | : AdditiveQuantizer(0, std::vector<size_t>()) {} |
62 | | |
63 | 0 | void AdditiveQuantizer::set_derived_values() { |
64 | 0 | tot_bits = 0; |
65 | 0 | only_8bit = true; |
66 | 0 | codebook_offsets.resize(M + 1, 0); |
67 | 0 | for (int i = 0; i < M; i++) { |
68 | 0 | int nbit = nbits[i]; |
69 | 0 | size_t k = 1 << nbit; |
70 | 0 | codebook_offsets[i + 1] = codebook_offsets[i] + k; |
71 | 0 | tot_bits += nbit; |
72 | 0 | if (nbit != 0) { |
73 | 0 | only_8bit = false; |
74 | 0 | } |
75 | 0 | } |
76 | 0 | total_codebook_size = codebook_offsets[M]; |
77 | 0 | switch (search_type) { |
78 | 0 | case ST_norm_float: |
79 | 0 | norm_bits = 32; |
80 | 0 | break; |
81 | 0 | case ST_norm_qint8: |
82 | 0 | case ST_norm_cqint8: |
83 | 0 | case ST_norm_lsq2x4: |
84 | 0 | case ST_norm_rq2x4: |
85 | 0 | norm_bits = 8; |
86 | 0 | break; |
87 | 0 | case ST_norm_qint4: |
88 | 0 | case ST_norm_cqint4: |
89 | 0 | norm_bits = 4; |
90 | 0 | break; |
91 | 0 | case ST_decompress: |
92 | 0 | case ST_LUT_nonorm: |
93 | 0 | case ST_norm_from_LUT: |
94 | 0 | default: |
95 | 0 | norm_bits = 0; |
96 | 0 | break; |
97 | 0 | } |
98 | 0 | tot_bits += norm_bits; |
99 | | |
100 | | // convert bits to bytes |
101 | 0 | code_size = (tot_bits + 7) / 8; |
102 | 0 | } |
103 | | |
104 | 0 | void AdditiveQuantizer::train_norm(size_t n, const float* norms) { |
105 | 0 | norm_min = HUGE_VALF; |
106 | 0 | norm_max = -HUGE_VALF; |
107 | 0 | for (idx_t i = 0; i < n; i++) { |
108 | 0 | if (norms[i] < norm_min) { |
109 | 0 | norm_min = norms[i]; |
110 | 0 | } |
111 | 0 | if (norms[i] > norm_max) { |
112 | 0 | norm_max = norms[i]; |
113 | 0 | } |
114 | 0 | } |
115 | |
|
116 | 0 | if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) { |
117 | 0 | size_t k = (1 << 8); |
118 | 0 | if (search_type == ST_norm_cqint4) { |
119 | 0 | k = (1 << 4); |
120 | 0 | } |
121 | 0 | Clustering1D clus(k); |
122 | 0 | clus.train_exact(n, norms); |
123 | 0 | qnorm.add(clus.k, clus.centroids.data()); |
124 | 0 | } else if (search_type == ST_norm_lsq2x4 || search_type == ST_norm_rq2x4) { |
125 | 0 | std::unique_ptr<AdditiveQuantizer> aq; |
126 | 0 | if (search_type == ST_norm_lsq2x4) { |
127 | 0 | aq.reset(new LocalSearchQuantizer(1, 2, 4)); |
128 | 0 | } else { |
129 | 0 | aq.reset(new ResidualQuantizer(1, 2, 4)); |
130 | 0 | } |
131 | |
|
132 | 0 | aq->train(n, norms); |
133 | | // flatten aq codebooks |
134 | 0 | std::vector<float> flat_codebooks(1 << 8); |
135 | 0 | FAISS_THROW_IF_NOT(aq->codebooks.size() == 32); |
136 | | |
137 | | // save norm tables for 4-bit fastscan search |
138 | 0 | norm_tabs = aq->codebooks; |
139 | | |
140 | | // assume big endian |
141 | 0 | const float* c = norm_tabs.data(); |
142 | 0 | for (size_t i = 0; i < 16; i++) { |
143 | 0 | for (size_t j = 0; j < 16; j++) { |
144 | 0 | flat_codebooks[i * 16 + j] = c[j] + c[16 + i]; |
145 | 0 | } |
146 | 0 | } |
147 | |
|
148 | 0 | qnorm.reset(); |
149 | 0 | qnorm.add(1 << 8, flat_codebooks.data()); |
150 | 0 | FAISS_THROW_IF_NOT(qnorm.ntotal == (1 << 8)); |
151 | 0 | } |
152 | 0 | } |
153 | | |
154 | 0 | void AdditiveQuantizer::compute_codebook_tables() { |
155 | 0 | centroid_norms.resize(total_codebook_size); |
156 | 0 | fvec_norms_L2sqr( |
157 | 0 | centroid_norms.data(), codebooks.data(), d, total_codebook_size); |
158 | 0 | size_t cross_table_size = 0; |
159 | 0 | for (int m = 0; m < M; m++) { |
160 | 0 | size_t K = (size_t)1 << nbits[m]; |
161 | 0 | cross_table_size += K * codebook_offsets[m]; |
162 | 0 | } |
163 | 0 | codebook_cross_products.resize(cross_table_size); |
164 | 0 | size_t ofs = 0; |
165 | 0 | for (int m = 1; m < M; m++) { |
166 | 0 | FINTEGER ki = (size_t)1 << nbits[m]; |
167 | 0 | FINTEGER kk = codebook_offsets[m]; |
168 | 0 | FINTEGER di = d; |
169 | 0 | float zero = 0, one = 1; |
170 | 0 | assert(ofs + ki * kk <= cross_table_size); |
171 | 0 | sgemm_("Transposed", |
172 | 0 | "Not transposed", |
173 | 0 | &ki, |
174 | 0 | &kk, |
175 | 0 | &di, |
176 | 0 | &one, |
177 | 0 | codebooks.data() + d * kk, |
178 | 0 | &di, |
179 | 0 | codebooks.data(), |
180 | 0 | &di, |
181 | 0 | &zero, |
182 | 0 | codebook_cross_products.data() + ofs, |
183 | 0 | &ki); |
184 | 0 | ofs += ki * kk; |
185 | 0 | } |
186 | 0 | } |
187 | | |
namespace {

// TODO
// https://stackoverflow.com/questions/31631224/hacks-for-clamping-integer-to-0-255-and-doubles-to-0-0-1-0

/// Uniformly quantize x to 8 bits, mapping [amin, amax) onto [0, 255].
uint8_t encode_qint8(float x, float amin, float amax) {
    const float scaled = (x - amin) / (amax - amin) * 256;
    const int32_t bin = int32_t(std::floor(scaled));
    // clamp out-of-range values to the representable bounds
    if (bin < 0) {
        return 0;
    }
    return bin > 255 ? 255 : uint8_t(bin);
}

/// Uniformly quantize x to 4 bits, mapping [amin, amax) onto [0, 15].
uint8_t encode_qint4(float x, float amin, float amax) {
    const float scaled = (x - amin) / (amax - amin) * 16;
    const int32_t bin = int32_t(std::floor(scaled));
    if (bin < 0) {
        return 0;
    }
    return bin > 15 ? 15 : uint8_t(bin);
}

/// Inverse of encode_qint8: return the center of bin i in [amin, amax).
float decode_qint8(uint8_t i, float amin, float amax) {
    return (i + 0.5) / 256 * (amax - amin) + amin;
}

/// Inverse of encode_qint4: return the center of bin i in [amin, amax).
float decode_qint4(uint8_t i, float amin, float amax) {
    return (i + 0.5) / 16 * (amax - amin) + amin;
}

} // anonymous namespace
216 | | |
217 | 0 | uint32_t AdditiveQuantizer::encode_qcint(float x) const { |
218 | 0 | idx_t id; |
219 | 0 | qnorm.assign(1, &x, &id, 1); |
220 | 0 | return uint32_t(id); |
221 | 0 | } |
222 | | |
223 | 0 | float AdditiveQuantizer::decode_qcint(uint32_t c) const { |
224 | 0 | return qnorm.get_xb()[c]; |
225 | 0 | } |
226 | | |
227 | 0 | uint64_t AdditiveQuantizer::encode_norm(float norm) const { |
228 | 0 | switch (search_type) { |
229 | 0 | case ST_norm_float: |
230 | 0 | uint32_t inorm; |
231 | 0 | memcpy(&inorm, &norm, 4); |
232 | 0 | return inorm; |
233 | 0 | case ST_norm_qint8: |
234 | 0 | return encode_qint8(norm, norm_min, norm_max); |
235 | 0 | case ST_norm_qint4: |
236 | 0 | return encode_qint4(norm, norm_min, norm_max); |
237 | 0 | case ST_norm_lsq2x4: |
238 | 0 | case ST_norm_rq2x4: |
239 | 0 | case ST_norm_cqint8: |
240 | 0 | return encode_qcint(norm); |
241 | 0 | case ST_norm_cqint4: |
242 | 0 | return encode_qcint(norm); |
243 | 0 | case ST_decompress: |
244 | 0 | case ST_LUT_nonorm: |
245 | 0 | case ST_norm_from_LUT: |
246 | 0 | default: |
247 | 0 | return 0; |
248 | 0 | } |
249 | 0 | } |
250 | | |
251 | | void AdditiveQuantizer::pack_codes( |
252 | | size_t n, |
253 | | const int32_t* codes, |
254 | | uint8_t* packed_codes, |
255 | | int64_t ld_codes, |
256 | | const float* norms, |
257 | 0 | const float* centroids) const { |
258 | 0 | if (ld_codes == -1) { |
259 | 0 | ld_codes = M; |
260 | 0 | } |
261 | 0 | std::vector<float> norm_buf; |
262 | 0 | if (search_type == ST_norm_float || search_type == ST_norm_qint4 || |
263 | 0 | search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 || |
264 | 0 | search_type == ST_norm_cqint4 || search_type == ST_norm_lsq2x4 || |
265 | 0 | search_type == ST_norm_rq2x4) { |
266 | 0 | if (centroids != nullptr || !norms) { |
267 | 0 | norm_buf.resize(n); |
268 | 0 | std::vector<float> x_recons(n * d); |
269 | 0 | decode_unpacked(codes, x_recons.data(), n, ld_codes); |
270 | |
|
271 | 0 | if (centroids != nullptr) { |
272 | | // x = x + c |
273 | 0 | fvec_add(n * d, x_recons.data(), centroids, x_recons.data()); |
274 | 0 | } |
275 | 0 | fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n); |
276 | 0 | norms = norm_buf.data(); |
277 | 0 | } |
278 | 0 | } |
279 | 0 | #pragma omp parallel for if (n > 1000) |
280 | 0 | for (int64_t i = 0; i < n; i++) { |
281 | 0 | const int32_t* codes1 = codes + i * ld_codes; |
282 | 0 | BitstringWriter bsw(packed_codes + i * code_size, code_size); |
283 | 0 | for (int m = 0; m < M; m++) { |
284 | 0 | bsw.write(codes1[m], nbits[m]); |
285 | 0 | } |
286 | 0 | if (norm_bits != 0) { |
287 | 0 | bsw.write(encode_norm(norms[i]), norm_bits); |
288 | 0 | } |
289 | 0 | } |
290 | 0 | } |
291 | | |
292 | 0 | void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const { |
293 | 0 | FAISS_THROW_IF_NOT_MSG( |
294 | 0 | is_trained, "The additive quantizer is not trained yet."); |
295 | | |
296 | | // standard additive quantizer decoding |
297 | 0 | #pragma omp parallel for if (n > 100) |
298 | 0 | for (int64_t i = 0; i < n; i++) { |
299 | 0 | BitstringReader bsr(code + i * code_size, code_size); |
300 | 0 | float* xi = x + i * d; |
301 | 0 | for (int m = 0; m < M; m++) { |
302 | 0 | int idx = bsr.read(nbits[m]); |
303 | 0 | const float* c = codebooks.data() + d * (codebook_offsets[m] + idx); |
304 | 0 | if (m == 0) { |
305 | 0 | memcpy(xi, c, sizeof(*x) * d); |
306 | 0 | } else { |
307 | 0 | fvec_add(d, xi, c, xi); |
308 | 0 | } |
309 | 0 | } |
310 | 0 | } |
311 | 0 | } |
312 | | |
313 | | void AdditiveQuantizer::decode_unpacked( |
314 | | const int32_t* code, |
315 | | float* x, |
316 | | size_t n, |
317 | 0 | int64_t ld_codes) const { |
318 | 0 | FAISS_THROW_IF_NOT_MSG( |
319 | 0 | is_trained, "The additive quantizer is not trained yet."); |
320 | | |
321 | 0 | if (ld_codes == -1) { |
322 | 0 | ld_codes = M; |
323 | 0 | } |
324 | | |
325 | | // standard additive quantizer decoding |
326 | 0 | #pragma omp parallel for if (n > 1000) |
327 | 0 | for (int64_t i = 0; i < n; i++) { |
328 | 0 | const int32_t* codesi = code + i * ld_codes; |
329 | 0 | float* xi = x + i * d; |
330 | 0 | for (int m = 0; m < M; m++) { |
331 | 0 | int idx = codesi[m]; |
332 | 0 | const float* c = codebooks.data() + d * (codebook_offsets[m] + idx); |
333 | 0 | if (m == 0) { |
334 | 0 | memcpy(xi, c, sizeof(*x) * d); |
335 | 0 | } else { |
336 | 0 | fvec_add(d, xi, c, xi); |
337 | 0 | } |
338 | 0 | } |
339 | 0 | } |
340 | 0 | } |
341 | | |
342 | 0 | AdditiveQuantizer::~AdditiveQuantizer() {} |
343 | | |
344 | | /**************************************************************************** |
345 | | * Support for fast distance computations in centroids |
346 | | ****************************************************************************/ |
347 | | |
348 | 0 | void AdditiveQuantizer::compute_centroid_norms(float* norms) const { |
349 | 0 | size_t ntotal = (size_t)1 << tot_bits; |
350 | | // TODO: make tree of partial sums |
351 | 0 | #pragma omp parallel |
352 | 0 | { |
353 | 0 | std::vector<float> tmp(d); |
354 | 0 | #pragma omp for |
355 | 0 | for (int64_t i = 0; i < ntotal; i++) { |
356 | 0 | decode_64bit(i, tmp.data()); |
357 | 0 | norms[i] = fvec_norm_L2sqr(tmp.data(), d); |
358 | 0 | } |
359 | 0 | } |
360 | 0 | } |
361 | | |
362 | 0 | void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const { |
363 | 0 | for (int m = 0; m < M; m++) { |
364 | 0 | idx_t idx = bits & (((size_t)1 << nbits[m]) - 1); |
365 | 0 | bits >>= nbits[m]; |
366 | 0 | const float* c = codebooks.data() + d * (codebook_offsets[m] + idx); |
367 | 0 | if (m == 0) { |
368 | 0 | memcpy(xi, c, sizeof(*xi) * d); |
369 | 0 | } else { |
370 | 0 | fvec_add(d, xi, c, xi); |
371 | 0 | } |
372 | 0 | } |
373 | 0 | } |
374 | | |
375 | | void AdditiveQuantizer::compute_LUT( |
376 | | size_t n, |
377 | | const float* xq, |
378 | | float* LUT, |
379 | | float alpha, |
380 | 0 | long ld_lut) const { |
381 | | // in all cases, it is large matrix multiplication |
382 | |
|
383 | 0 | FINTEGER ncenti = total_codebook_size; |
384 | 0 | FINTEGER di = d; |
385 | 0 | FINTEGER nqi = n; |
386 | 0 | FINTEGER ldc = ld_lut > 0 ? ld_lut : ncenti; |
387 | 0 | float zero = 0; |
388 | |
|
389 | 0 | sgemm_("Transposed", |
390 | 0 | "Not transposed", |
391 | 0 | &ncenti, |
392 | 0 | &nqi, |
393 | 0 | &di, |
394 | 0 | &alpha, |
395 | 0 | codebooks.data(), |
396 | 0 | &di, |
397 | 0 | xq, |
398 | 0 | &di, |
399 | 0 | &zero, |
400 | 0 | LUT, |
401 | 0 | &ldc); |
402 | 0 | } |
403 | | |
404 | | namespace { |
405 | | |
406 | | /* compute inner products of one query with all centroids, given a look-up |
407 | | * table of all inner producst with codebook entries */ |
408 | | void compute_inner_prod_with_LUT( |
409 | | const AdditiveQuantizer& aq, |
410 | | const float* LUT, |
411 | 0 | float* ips) { |
412 | 0 | size_t prev_size = 1; |
413 | 0 | for (int m = 0; m < aq.M; m++) { |
414 | 0 | const float* LUTm = LUT + aq.codebook_offsets[m]; |
415 | 0 | int nb = aq.nbits[m]; |
416 | 0 | size_t nc = (size_t)1 << nb; |
417 | |
|
418 | 0 | if (m == 0) { |
419 | 0 | memcpy(ips, LUT, sizeof(*ips) * nc); |
420 | 0 | } else { |
421 | 0 | for (int64_t i = nc - 1; i >= 0; i--) { |
422 | 0 | float v = LUTm[i]; |
423 | 0 | fvec_add(prev_size, ips, v, ips + i * prev_size); |
424 | 0 | } |
425 | 0 | } |
426 | 0 | prev_size *= nc; |
427 | 0 | } |
428 | 0 | } |
429 | | |
430 | | } // anonymous namespace |
431 | | |
432 | | void AdditiveQuantizer::knn_centroids_inner_product( |
433 | | idx_t n, |
434 | | const float* xq, |
435 | | idx_t k, |
436 | | float* distances, |
437 | 0 | idx_t* labels) const { |
438 | 0 | std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]); |
439 | 0 | compute_LUT(n, xq, LUT.get()); |
440 | 0 | size_t ntotal = (size_t)1 << tot_bits; |
441 | |
|
442 | 0 | #pragma omp parallel if (n > 100) |
443 | 0 | { |
444 | 0 | std::vector<float> dis(ntotal); |
445 | 0 | #pragma omp for |
446 | 0 | for (idx_t i = 0; i < n; i++) { |
447 | 0 | const float* LUTi = LUT.get() + i * total_codebook_size; |
448 | 0 | compute_inner_prod_with_LUT(*this, LUTi, dis.data()); |
449 | 0 | float* distances_i = distances + i * k; |
450 | 0 | idx_t* labels_i = labels + i * k; |
451 | 0 | minheap_heapify(k, distances_i, labels_i); |
452 | 0 | minheap_addn(k, distances_i, labels_i, dis.data(), nullptr, ntotal); |
453 | 0 | minheap_reorder(k, distances_i, labels_i); |
454 | 0 | } |
455 | 0 | } |
456 | 0 | } |
457 | | |
458 | | void AdditiveQuantizer::knn_centroids_L2( |
459 | | idx_t n, |
460 | | const float* xq, |
461 | | idx_t k, |
462 | | float* distances, |
463 | | idx_t* labels, |
464 | 0 | const float* norms) const { |
465 | 0 | std::unique_ptr<float[]> LUT(new float[n * total_codebook_size]); |
466 | 0 | compute_LUT(n, xq, LUT.get()); |
467 | 0 | std::unique_ptr<float[]> q_norms(new float[n]); |
468 | 0 | fvec_norms_L2sqr(q_norms.get(), xq, d, n); |
469 | 0 | size_t ntotal = (size_t)1 << tot_bits; |
470 | |
|
471 | 0 | #pragma omp parallel if (n > 100) |
472 | 0 | { |
473 | 0 | std::vector<float> dis(ntotal); |
474 | 0 | #pragma omp for |
475 | 0 | for (idx_t i = 0; i < n; i++) { |
476 | 0 | const float* LUTi = LUT.get() + i * total_codebook_size; |
477 | 0 | float* distances_i = distances + i * k; |
478 | 0 | idx_t* labels_i = labels + i * k; |
479 | |
|
480 | 0 | compute_inner_prod_with_LUT(*this, LUTi, dis.data()); |
481 | | |
482 | | // update distances using |
483 | | // ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x,y> |
484 | |
|
485 | 0 | maxheap_heapify(k, distances_i, labels_i); |
486 | 0 | for (idx_t j = 0; j < ntotal; j++) { |
487 | 0 | float disj = q_norms[i] + norms[j] - 2 * dis[j]; |
488 | 0 | if (disj < distances_i[0]) { |
489 | 0 | heap_replace_top<CMax<float, int64_t>>( |
490 | 0 | k, distances_i, labels_i, disj, j); |
491 | 0 | } |
492 | 0 | } |
493 | 0 | maxheap_reorder(k, distances_i, labels_i); |
494 | 0 | } |
495 | 0 | } |
496 | 0 | } |
497 | | |
498 | | /**************************************************************************** |
499 | | * Support for fast distance computations in codes |
500 | | ****************************************************************************/ |
501 | | |
502 | | namespace { |
503 | | |
504 | | float accumulate_IPs( |
505 | | const AdditiveQuantizer& aq, |
506 | | BitstringReader& bs, |
507 | 0 | const float* LUT) { |
508 | 0 | float accu = 0; |
509 | 0 | for (int m = 0; m < aq.M; m++) { |
510 | 0 | size_t nbit = aq.nbits[m]; |
511 | 0 | int idx = bs.read(nbit); |
512 | 0 | accu += LUT[idx]; |
513 | 0 | LUT += (uint64_t)1 << nbit; |
514 | 0 | } |
515 | 0 | return accu; |
516 | 0 | } |
517 | | |
518 | 0 | float compute_norm_from_LUT(const AdditiveQuantizer& aq, BitstringReader& bs) { |
519 | 0 | float accu = 0; |
520 | 0 | std::vector<int> idx(aq.M); |
521 | 0 | const float* c = aq.codebook_cross_products.data(); |
522 | 0 | for (int m = 0; m < aq.M; m++) { |
523 | 0 | size_t nbit = aq.nbits[m]; |
524 | 0 | int i = bs.read(nbit); |
525 | 0 | size_t K = 1 << nbit; |
526 | 0 | idx[m] = i; |
527 | |
|
528 | 0 | accu += aq.centroid_norms[aq.codebook_offsets[m] + i]; |
529 | |
|
530 | 0 | for (int l = 0; l < m; l++) { |
531 | 0 | int j = idx[l]; |
532 | 0 | accu += 2 * c[j * K + i]; |
533 | 0 | c += (1 << aq.nbits[l]) * K; |
534 | 0 | } |
535 | 0 | } |
536 | | // FAISS_THROW_IF_NOT(c == aq.codebook_cross_products.data() + |
537 | | // aq.codebook_cross_products.size()); |
538 | 0 | return accu; |
539 | 0 | } |
540 | | |
541 | | } // anonymous namespace |
542 | | |
543 | | template <> |
544 | | float AdditiveQuantizer:: |
545 | | compute_1_distance_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>( |
546 | | const uint8_t* codes, |
547 | 0 | const float* LUT) const { |
548 | 0 | BitstringReader bs(codes, code_size); |
549 | 0 | return accumulate_IPs(*this, bs, LUT); |
550 | 0 | } |
551 | | |
552 | | template <> |
553 | | float AdditiveQuantizer:: |
554 | | compute_1_distance_LUT<false, AdditiveQuantizer::ST_LUT_nonorm>( |
555 | | const uint8_t* codes, |
556 | 0 | const float* LUT) const { |
557 | 0 | BitstringReader bs(codes, code_size); |
558 | 0 | return -accumulate_IPs(*this, bs, LUT); |
559 | 0 | } |
560 | | |
561 | | template <> |
562 | | float AdditiveQuantizer:: |
563 | | compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_float>( |
564 | | const uint8_t* codes, |
565 | 0 | const float* LUT) const { |
566 | 0 | BitstringReader bs(codes, code_size); |
567 | 0 | float accu = accumulate_IPs(*this, bs, LUT); |
568 | 0 | uint32_t norm_i = bs.read(32); |
569 | 0 | float norm2; |
570 | 0 | memcpy(&norm2, &norm_i, 4); |
571 | 0 | return norm2 - 2 * accu; |
572 | 0 | } |
573 | | |
574 | | template <> |
575 | | float AdditiveQuantizer:: |
576 | | compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint8>( |
577 | | const uint8_t* codes, |
578 | 0 | const float* LUT) const { |
579 | 0 | BitstringReader bs(codes, code_size); |
580 | 0 | float accu = accumulate_IPs(*this, bs, LUT); |
581 | 0 | uint32_t norm_i = bs.read(8); |
582 | 0 | float norm2 = decode_qcint(norm_i); |
583 | 0 | return norm2 - 2 * accu; |
584 | 0 | } |
585 | | |
586 | | template <> |
587 | | float AdditiveQuantizer:: |
588 | | compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint4>( |
589 | | const uint8_t* codes, |
590 | 0 | const float* LUT) const { |
591 | 0 | BitstringReader bs(codes, code_size); |
592 | 0 | float accu = accumulate_IPs(*this, bs, LUT); |
593 | 0 | uint32_t norm_i = bs.read(4); |
594 | 0 | float norm2 = decode_qcint(norm_i); |
595 | 0 | return norm2 - 2 * accu; |
596 | 0 | } |
597 | | |
598 | | template <> |
599 | | float AdditiveQuantizer:: |
600 | | compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint8>( |
601 | | const uint8_t* codes, |
602 | 0 | const float* LUT) const { |
603 | 0 | BitstringReader bs(codes, code_size); |
604 | 0 | float accu = accumulate_IPs(*this, bs, LUT); |
605 | 0 | uint32_t norm_i = bs.read(8); |
606 | 0 | float norm2 = decode_qint8(norm_i, norm_min, norm_max); |
607 | 0 | return norm2 - 2 * accu; |
608 | 0 | } |
609 | | |
610 | | template <> |
611 | | float AdditiveQuantizer:: |
612 | | compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint4>( |
613 | | const uint8_t* codes, |
614 | 0 | const float* LUT) const { |
615 | 0 | BitstringReader bs(codes, code_size); |
616 | 0 | float accu = accumulate_IPs(*this, bs, LUT); |
617 | 0 | uint32_t norm_i = bs.read(4); |
618 | 0 | float norm2 = decode_qint4(norm_i, norm_min, norm_max); |
619 | 0 | return norm2 - 2 * accu; |
620 | 0 | } |
621 | | |
622 | | template <> |
623 | | float AdditiveQuantizer:: |
624 | | compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_from_LUT>( |
625 | | const uint8_t* codes, |
626 | 0 | const float* LUT) const { |
627 | 0 | FAISS_THROW_IF_NOT(codebook_cross_products.size() > 0); |
628 | 0 | BitstringReader bs(codes, code_size); |
629 | 0 | float accu = accumulate_IPs(*this, bs, LUT); |
630 | 0 | BitstringReader bs2(codes, code_size); |
631 | 0 | float norm2 = compute_norm_from_LUT(*this, bs2); |
632 | 0 | return norm2 - 2 * accu; |
633 | 0 | } |
634 | | |
635 | | } // namespace faiss |