contrib/faiss/faiss/impl/lattice_Zn.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/impl/lattice_Zn.h> |
11 | | |
12 | | #include <cassert> |
13 | | #include <cmath> |
14 | | #include <cstdlib> |
15 | | #include <cstring> |
16 | | |
17 | | #include <algorithm> |
18 | | #include <queue> |
19 | | #include <unordered_set> |
20 | | |
21 | | #include <faiss/impl/platform_macros.h> |
22 | | #include <faiss/utils/distances.h> |
23 | | |
24 | | namespace faiss { |
25 | | |
26 | | /******************************************** |
27 | | * small utility functions |
28 | | ********************************************/ |
29 | | |
30 | | namespace { |
31 | | |
// square of a scalar
inline float sqr(float v) {
    return v * v;
}
35 | | |
36 | | typedef std::vector<float> point_list_t; |
37 | | |
// Cache of binomial coefficients C(n, p) for all n, p < nmax,
// built once via Pascal's triangle.
struct Comb {
    std::vector<uint64_t> tab; // Pascal's triangle: tab[n * nmax + p] = C(n, p)
    int nmax;

    explicit Comb(int nmax) : nmax(nmax) {
        tab.resize(nmax * nmax, 0);
        tab[0] = 1; // C(0, 0)
        for (int row = 1; row < nmax; row++) {
            uint64_t* cur = &tab[row * nmax];
            const uint64_t* prev = &tab[(row - 1) * nmax];
            cur[0] = 1; // C(row, 0)
            // C(row, col) = C(row-1, col) + C(row-1, col-1)
            for (int col = 1; col <= row; col++) {
                cur[col] = prev[col] + prev[col - 1];
            }
        }
    }

    // C(n, p); returns 0 when p > n. Requires n, p < nmax.
    uint64_t operator()(int n, int p) const {
        assert(n < nmax && p < nmax);
        if (p > n) {
            return 0;
        }
        return tab[n * nmax + p];
    }
};
61 | | |
62 | | Comb comb(100); |
63 | | |
// compute combinations of n integer values <= v that sum up to total (squared)
// Returned flat: consecutive groups of n floats, leading values descending.
std::vector<float> sum_of_sq(float total, int v, int n, float add = 0) {
    if (total < 0) {
        return {};
    }
    if (n == 1) {
        // largest v with (v + add)^2 <= total, then check for exact match
        while ((v + add) * (v + add) > total) {
            v--;
        }
        if ((v + add) * (v + add) == total) {
            return std::vector<float>(1, v + add);
        }
        return {};
    }
    // try every admissible leading value, recursing on the remainder
    std::vector<float> res;
    for (; v >= 0; v--) {
        float lead = v + add;
        std::vector<float> tails = sum_of_sq(total - lead * lead, v, n - 1, add);
        for (size_t i = 0; i < tails.size(); i += n - 1) {
            res.push_back(lead);
            res.insert(res.end(), tails.begin() + i, tails.begin() + i + (n - 1));
        }
    }
    return res;
}
92 | | |
93 | 0 | int decode_comb_1(uint64_t* n, int k1, int r) { |
94 | 0 | while (comb(r, k1) > *n) { |
95 | 0 | r--; |
96 | 0 | } |
97 | 0 | *n -= comb(r, k1); |
98 | 0 | return r; |
99 | 0 | } |
100 | | |
// optimized version for < 64 bits
// Rank a permutation of repeated values into a single integer: for each
// distinct value (in repeats order), the positions where it occurs among the
// not-yet-coded slots are ranked with the combinatorial number system, then
// the per-value ranks are mixed-radix packed into one code.
uint64_t repeats_encode_64(
        const std::vector<Repeat>& repeats,
        int dim,
        const float* c) {
    uint64_t coded = 0; // bitmask of slots already attributed to a value
    int nfree = dim;    // number of slots not yet coded
    uint64_t code = 0, shift = 1; // mixed-radix accumulator and current radix
    for (auto r = repeats.begin(); r != repeats.end(); ++r) {
        int rank = 0, occ = 0;
        uint64_t code_comb = 0; // combinatorial rank of this value's positions
        uint64_t tosee = ~coded;
        for (;;) {
            // directly jump to next available slot.
            int i = __builtin_ctzll(tosee);
            tosee &= ~(uint64_t{1} << i);
            if (c[i] == r->val) {
                code_comb += comb(rank, occ + 1);
                occ++;
                coded |= uint64_t{1} << i;
                if (occ == r->n)
                    break;
            }
            rank++; // rank counts free slots scanned so far
        }
        // radix for this value = number of ways to place its r->n occurrences
        uint64_t max_comb = comb(nfree, r->n);
        code += shift * code_comb;
        shift *= max_comb;
        nfree -= r->n;
    }
    return code;
}
133 | | |
// Inverse of repeats_encode_64: unpack the mixed-radix code and re-place each
// value's occurrences by decoding combinatorial ranks (scanning slots from
// the high end, mirroring the encoder's low-end scan).
void repeats_decode_64(
        const std::vector<Repeat>& repeats,
        int dim,
        uint64_t code,
        float* c) {
    uint64_t decoded = 0; // bitmask of slots already filled
    int nfree = dim;
    for (auto r = repeats.begin(); r != repeats.end(); ++r) {
        // peel off this value's digit from the mixed-radix code
        uint64_t max_comb = comb(nfree, r->n);
        uint64_t code_comb = code % max_comb;
        code /= max_comb;

        int occ = 0;
        int rank = nfree;
        int next_rank = decode_comb_1(&code_comb, r->n, rank);
        uint64_t tosee = ((uint64_t{1} << dim) - 1) ^ decoded;
        for (;;) {
            // scan free slots from the highest index downwards
            int i = 63 - __builtin_clzll(tosee);
            tosee &= ~(uint64_t{1} << i);
            rank--;
            if (rank == next_rank) {
                decoded |= uint64_t{1} << i;
                c[i] = r->val;
                occ++;
                if (occ == r->n)
                    break;
                next_rank = decode_comb_1(&code_comb, r->n - occ, next_rank);
            }
        }
        nfree -= r->n;
    }
}
166 | | |
167 | | } // anonymous namespace |
168 | | |
169 | 0 | Repeats::Repeats(int dim, const float* c) : dim(dim) { |
170 | 0 | for (int i = 0; i < dim; i++) { |
171 | 0 | int j = 0; |
172 | 0 | for (;;) { |
173 | 0 | if (j == repeats.size()) { |
174 | 0 | repeats.push_back(Repeat{c[i], 1}); |
175 | 0 | break; |
176 | 0 | } |
177 | 0 | if (repeats[j].val == c[i]) { |
178 | 0 | repeats[j].n++; |
179 | 0 | break; |
180 | 0 | } |
181 | 0 | j++; |
182 | 0 | } |
183 | 0 | } |
184 | 0 | } |
185 | | |
186 | 0 | uint64_t Repeats::count() const { |
187 | 0 | uint64_t accu = 1; |
188 | 0 | int remain = dim; |
189 | 0 | for (int i = 0; i < repeats.size(); i++) { |
190 | 0 | accu *= comb(remain, repeats[i].n); |
191 | 0 | remain -= repeats[i].n; |
192 | 0 | } |
193 | 0 | return accu; |
194 | 0 | } |
195 | | |
// version with a bool vector that works for > 64 dim
// Same algorithm as repeats_encode_64, but tracks coded slots with a
// std::vector<bool> instead of a 64-bit mask.
uint64_t Repeats::encode(const float* c) const {
    if (dim < 64) {
        // fast path: bitmask version
        return repeats_encode_64(repeats, dim, c);
    }
    std::vector<bool> coded(dim, false);
    int nfree = dim;
    uint64_t code = 0, shift = 1; // mixed-radix accumulator and current radix
    for (auto r = repeats.begin(); r != repeats.end(); ++r) {
        int rank = 0, occ = 0;
        uint64_t code_comb = 0; // combinatorial rank of this value's positions
        for (int i = 0; i < dim; i++) {
            if (!coded[i]) {
                if (c[i] == r->val) {
                    code_comb += comb(rank, occ + 1);
                    occ++;
                    coded[i] = true;
                    if (occ == r->n)
                        break;
                }
                rank++; // rank counts free slots scanned so far
            }
        }
        uint64_t max_comb = comb(nfree, r->n);
        code += shift * code_comb;
        shift *= max_comb;
        nfree -= r->n;
    }
    return code;
}
226 | | |
// Inverse of Repeats::encode; std::vector<bool> version of
// repeats_decode_64 for dim >= 64.
void Repeats::decode(uint64_t code, float* c) const {
    if (dim < 64) {
        // fast path: bitmask version
        repeats_decode_64(repeats, dim, code, c);
        return;
    }

    std::vector<bool> decoded(dim, false);
    int nfree = dim;
    for (auto r = repeats.begin(); r != repeats.end(); ++r) {
        // peel off this value's digit from the mixed-radix code
        uint64_t max_comb = comb(nfree, r->n);
        uint64_t code_comb = code % max_comb;
        code /= max_comb;

        int occ = 0;
        int rank = nfree;
        int next_rank = decode_comb_1(&code_comb, r->n, rank);
        // scan free slots from the highest index downwards,
        // mirroring the encoder's upward scan
        for (int i = dim - 1; i >= 0; i--) {
            if (!decoded[i]) {
                rank--;
                if (rank == next_rank) {
                    decoded[i] = true;
                    c[i] = r->val;
                    occ++;
                    if (occ == r->n)
                        break;
                    next_rank =
                            decode_comb_1(&code_comb, r->n - occ, next_rank);
                }
            }
        }
        nfree -= r->n;
    }
}
260 | | |
261 | | /******************************************** |
262 | | * EnumeratedVectors functions |
263 | | ********************************************/ |
264 | | |
265 | | void EnumeratedVectors::encode_multi(size_t n, const float* c, uint64_t* codes) |
266 | 0 | const { |
267 | 0 | #pragma omp parallel if (n > 1000) |
268 | 0 | { |
269 | 0 | #pragma omp for |
270 | 0 | for (int i = 0; i < n; i++) { |
271 | 0 | codes[i] = encode(c + i * dim); |
272 | 0 | } |
273 | 0 | } |
274 | 0 | } |
275 | | |
276 | | void EnumeratedVectors::decode_multi(size_t n, const uint64_t* codes, float* c) |
277 | 0 | const { |
278 | 0 | #pragma omp parallel if (n > 1000) |
279 | 0 | { |
280 | 0 | #pragma omp for |
281 | 0 | for (int i = 0; i < n; i++) { |
282 | 0 | decode(codes[i], c + i * dim); |
283 | 0 | } |
284 | 0 | } |
285 | 0 | } |
286 | | |
287 | | void EnumeratedVectors::find_nn( |
288 | | size_t nc, |
289 | | const uint64_t* codes, |
290 | | size_t nq, |
291 | | const float* xq, |
292 | | int64_t* labels, |
293 | 0 | float* distances) { |
294 | 0 | for (size_t i = 0; i < nq; i++) { |
295 | 0 | distances[i] = -1e20; |
296 | 0 | labels[i] = -1; |
297 | 0 | } |
298 | |
|
299 | 0 | std::vector<float> c(dim); |
300 | 0 | for (size_t i = 0; i < nc; i++) { |
301 | 0 | uint64_t code = codes[nc]; |
302 | 0 | decode(code, c.data()); |
303 | 0 | for (size_t j = 0; j < nq; j++) { |
304 | 0 | const float* x = xq + j * dim; |
305 | 0 | float dis = fvec_inner_product(x, c.data(), dim); |
306 | 0 | if (dis > distances[j]) { |
307 | 0 | distances[j] = dis; |
308 | 0 | labels[j] = i; |
309 | 0 | } |
310 | 0 | } |
311 | 0 | } |
312 | 0 | } |
313 | | |
314 | | /********************************************************** |
315 | | * ZnSphereSearch |
316 | | **********************************************************/ |
317 | | |
// Enumerate the atoms of the Z^dim sphere of squared radius r2: voc holds,
// flat, the natom vectors of non-negative, non-increasing integer entries
// whose squares sum to r2 (each actual sphere point is a signed permutation
// of one atom).
ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
    // ceil(sqrt(r2)) + 1 upper-bounds any single coordinate
    voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim);
    natom = voc.size() / dim;
}
322 | | |
323 | 0 | float ZnSphereSearch::search(const float* x, float* c) const { |
324 | 0 | std::vector<float> tmp(dimS * 2); |
325 | 0 | std::vector<int> tmp_int(dimS); |
326 | 0 | return search(x, c, tmp.data(), tmp_int.data()); |
327 | 0 | } |
328 | | |
// Find the sphere point nearest to x (max inner product): sort |x| in
// decreasing order, pick the atom with the best dot product against the
// sorted vector, then undo the permutation and restore the signs of x.
// Returns the best inner product; optionally reports the atom index.
float ZnSphereSearch::search(
        const float* x,
        float* c,
        float* tmp, // size 2 *dim
        int* tmp_int, // size dim
        int* ibest_out) const {
    int dim = dimS;
    assert(natom > 0);
    int* o = tmp_int;       // permutation (argsort of |x|, decreasing)
    float* xabs = tmp;      // |x|
    float* xperm = tmp + dim; // |x| sorted decreasing

    // argsort
    for (int i = 0; i < dim; i++) {
        o[i] = i;
        xabs[i] = fabsf(x[i]);
    }
    std::sort(o, o + dim, [xabs](int a, int b) { return xabs[a] > xabs[b]; });
    for (int i = 0; i < dim; i++) {
        xperm[i] = xabs[o[i]];
    }
    // find best
    int ibest = -1;
    float dpbest = -100; // valid lower bound: atoms have non-negative entries
    for (int i = 0; i < natom; i++) {
        float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
        if (dp > dpbest) {
            dpbest = dp;
            ibest = i;
        }
    }
    // revert sort
    const float* cin = voc.data() + ibest * dim;
    for (int i = 0; i < dim; i++) {
        // give each output coordinate the sign of the matching input
        c[o[i]] = copysignf(cin[i], x[o[i]]);
    }
    if (ibest_out) {
        *ibest_out = ibest;
    }
    return dpbest;
}
370 | | |
371 | | void ZnSphereSearch::search_multi( |
372 | | int n, |
373 | | const float* x, |
374 | | float* c_out, |
375 | 0 | float* dp_out) { |
376 | 0 | #pragma omp parallel if (n > 1000) |
377 | 0 | { |
378 | 0 | #pragma omp for |
379 | 0 | for (int i = 0; i < n; i++) { |
380 | 0 | dp_out[i] = search(x + i * dimS, c_out + i * dimS); |
381 | 0 | } |
382 | 0 | } |
383 | 0 | } |
384 | | |
385 | | /********************************************************** |
386 | | * ZnSphereCodec |
387 | | **********************************************************/ |
388 | | |
// Build one CodeSegment per atom. A code is laid out as:
//   [segment base c0] + [sign bits] + [permutation rank << signbits]
// so segments partition the code range [0, nv).
ZnSphereCodec::ZnSphereCodec(int dim, int r2)
        : ZnSphereSearch(dim, r2), EnumeratedVectors(dim) {
    nv = 0;
    for (int i = 0; i < natom; i++) {
        Repeats repeats(dim, &voc[i * dim]);
        CodeSegment cs(repeats);
        cs.c0 = nv; // first code assigned to this atom
        // atoms are sorted decreasing, so a zero (if any) is the last repeat;
        // zeros carry no sign, so they need no sign bit
        Repeat& br = repeats.repeats.back();
        cs.signbits = br.val == 0 ? dim - br.n : dim;
        code_segments.push_back(cs);
        nv += repeats.count() << cs.signbits;
    }

    // code_size = number of bytes needed to store any code in [0, nv)
    uint64_t nvx = nv;
    code_size = 0;
    while (nvx > 0) {
        nvx >>= 8;
        code_size++;
    }
}
409 | | |
// Quantize x to its nearest sphere point and return that point's code:
// segment base + sign bits of the nonzero coords + permutation rank.
uint64_t ZnSphereCodec::search_and_encode(const float* x) const {
    std::vector<float> tmp(dim * 2);
    std::vector<int> tmp_int(dim);
    int ano; // atom number
    std::vector<float> c(dim);
    search(x, c.data(), tmp.data(), tmp_int.data(), &ano);
    uint64_t signs = 0;
    std::vector<float> cabs(dim);
    int nnz = 0; // number of nonzero coordinates == number of sign bits
    for (int i = 0; i < dim; i++) {
        cabs[i] = fabs(c[i]);
        if (c[i] != 0) {
            // one sign bit per nonzero coordinate, in coordinate order
            if (c[i] < 0) {
                signs |= uint64_t{1} << nnz;
            }
            nnz++;
        }
    }
    const CodeSegment& cs = code_segments[ano];
    assert(nnz == cs.signbits);
    // layout: c0 + signs in the low signbits, permutation rank above them
    uint64_t code = cs.c0 + signs;
    code += cs.encode(cabs.data()) << cs.signbits;
    return code;
}
434 | | |
// EnumeratedVectors interface: encoding implies quantizing x first.
uint64_t ZnSphereCodec::encode(const float* x) const {
    return search_and_encode(x);
}
438 | | |
// Inverse of search_and_encode: locate the owning segment by binary search
// on the segment bases c0, decode the permutation rank, then re-apply signs.
void ZnSphereCodec::decode(uint64_t code, float* c) const {
    // binary search: largest i0 with code_segments[i0].c0 <= code
    int i0 = 0, i1 = natom;
    while (i0 + 1 < i1) {
        int imed = (i0 + i1) / 2;
        if (code_segments[imed].c0 <= code)
            i0 = imed;
        else
            i1 = imed;
    }
    const CodeSegment& cs = code_segments[i0];
    code -= cs.c0;
    // low signbits = signs; the rest is the permutation rank
    uint64_t signs = code;
    code >>= cs.signbits;
    cs.decode(code, c);

    // apply one sign bit per nonzero coordinate, in coordinate order
    int nnz = 0;
    for (int i = 0; i < dim; i++) {
        if (c[i] != 0) {
            if (signs & (uint64_t(1) << nnz)) {
                c[i] = -c[i];
            }
            nnz++;
        }
    }
}
464 | | |
465 | | /************************************************************** |
466 | | * ZnSphereCodecRec |
467 | | **************************************************************/ |
468 | | |
// Number of lattice points of squared norm r2a in dimension 2^ld
// (row ld of the all_nv table).
uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const {
    return all_nv[ld * (r2 + 1) + r2a];
}
472 | | |
// Cumulative count: codes at level ld with total squared norm r2t whose
// first-half norm is < r2a (prefix sums used by encode/decode).
uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const {
    return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
}
476 | | |
// Setter counterpart of get_nv_cum (used only while building the tables).
void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
    all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
}
480 | | |
// Hierarchical codec for power-of-2 dimensions: point counts are built by
// dynamic programming over log2(dim) levels (a level-ld count is a
// convolution of two level-(ld-1) counts over the split of the squared
// norm), then a small decode cache of the trailing subvectors is filled.
ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
        : EnumeratedVectors(dim), r2(r2) {
    log2_dim = 0;
    while (dim > (1 << log2_dim)) {
        log2_dim++;
    }
    assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");

    all_nv.resize((log2_dim + 1) * (r2 + 1));
    all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));

    // base case, dimension 1: a squared norm r2a is reachable iff it is a
    // perfect square; 2 points (+/-r) except for r = 0
    for (int r2a = 0; r2a <= r2; r2a++) {
        int r = int(sqrt(r2a));
        if (r * r == r2a) {
            all_nv[r2a] = r == 0 ? 1 : 2;
        } else {
            all_nv[r2a] = 0;
        }
    }

    // DP over levels: count(ld, r2sub) = sum over splits r2a + r2b = r2sub
    // of count(ld-1, r2a) * count(ld-1, r2b); prefix sums stored on the fly
    for (int ld = 1; ld <= log2_dim; ld++) {
        for (int r2sub = 0; r2sub <= r2; r2sub++) {
            uint64_t nv = 0;
            for (int r2a = 0; r2a <= r2sub; r2a++) {
                int r2b = r2sub - r2a;
                set_nv_cum(ld, r2sub, r2a, nv);
                nv += get_nv(ld - 1, r2a) * get_nv(ld - 1, r2b);
            }
            all_nv[ld * (r2 + 1) + r2sub] = nv;
        }
    }
    nv = get_nv(log2_dim, r2);

    // code_size = number of bytes needed to store any code in [0, nv)
    uint64_t nvx = nv;
    code_size = 0;
    while (nvx > 0) {
        nvx >>= 8;
        code_size++;
    }

    // precompute decoded subvectors of size 2^cache_level; decode_cache_ld
    // stays 0 while filling so that decode() below runs in uncached mode.
    // NOTE(review): for log2_dim == 0, cache_level is -1 — presumably dim=1
    // is never used with this codec; confirm with callers.
    int cache_level = std::min(3, log2_dim - 1);
    decode_cache_ld = 0;
    assert(cache_level <= log2_dim);
    decode_cache.resize((r2 + 1));

    for (int r2sub = 0; r2sub <= r2; r2sub++) {
        int ld = cache_level;
        uint64_t nvi = get_nv(ld, r2sub);
        std::vector<float>& cache = decode_cache[r2sub];
        int dimsub = (1 << cache_level);
        cache.resize(nvi * dimsub);
        std::vector<float> c(dim);
        // code0: first code whose trailing subvector has squared norm r2sub
        uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
        for (int i = 0; i < nvi; i++) {
            decode(i + code0, c.data());
            // keep only the trailing dimsub components
            memcpy(&cache[i * dimsub],
                   c.data() + dim - dimsub,
                   dimsub * sizeof(*c.data()));
        }
    }
    decode_cache_ld = cache_level;
}
543 | | |
// EnumeratedVectors interface: c must already be a centroid (a lattice
// point of squared norm r2); no quantization is performed here.
uint64_t ZnSphereCodecRec::encode(const float* c) const {
    return encode_centroid(c);
}
547 | | |
// Encode a lattice point bottom-up: start from per-coordinate (code, norm)
// pairs, then merge adjacent pairs level by level; at each merge the child
// codes are packed below the cumulative offset for the norm split.
uint64_t ZnSphereCodecRec::encode_centroid(const float* c) const {
    std::vector<uint64_t> codes(dim);
    std::vector<int> norm2s(dim);
    // level 0: code is the sign (0 = non-negative, 1 = negative),
    // norm is the squared coordinate
    for (int i = 0; i < dim; i++) {
        if (c[i] == 0) {
            codes[i] = 0;
            norm2s[i] = 0;
        } else {
            int r2i = int(c[i] * c[i]);
            norm2s[i] = r2i;
            codes[i] = c[i] >= 0 ? 0 : 1;
        }
    }
    int dim2 = dim / 2; // number of merges at the current level
    for (int ld = 1; ld <= log2_dim; ld++) {
        for (int i = 0; i < dim2; i++) {
            int r2a = norm2s[2 * i];
            int r2b = norm2s[2 * i + 1];

            uint64_t code_a = codes[2 * i];
            uint64_t code_b = codes[2 * i + 1];

            // offset for this norm split + row-major pack of the child codes
            codes[i] = get_nv_cum(ld, r2a + r2b, r2a) +
                    code_a * get_nv(ld - 1, r2b) + code_b;
            norm2s[i] = r2a + r2b;
        }
        dim2 /= 2;
    }
    return codes[0]; // root code covers the whole vector
}
578 | | |
// Inverse of encode_centroid: split the code top-down. At each level, a
// binary search in the cumulative table recovers the norm split (r2a, r2b),
// then division/modulo recover the two child codes. The last decode_cache_ld
// levels are skipped by copying precomputed subvectors from decode_cache.
void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
    std::vector<uint64_t> codes(dim);
    std::vector<int> norm2s(dim);
    codes[0] = code;
    norm2s[0] = r2;

    int dim2 = 1; // number of nodes at the current level
    for (int ld = log2_dim; ld > decode_cache_ld; ld--) {
        // iterate backwards so children (2i, 2i+1) don't clobber unread
        // parents stored in the same arrays
        for (int i = dim2 - 1; i >= 0; i--) {
            int r2sub = norm2s[i];
            int i0 = 0, i1 = r2sub + 1;
            uint64_t codei = codes[i];
            const uint64_t* cum =
                    &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
            // binary search: largest i0 with cum[i0] <= codei
            while (i1 > i0 + 1) {
                int imed = (i0 + i1) / 2;
                if (cum[imed] <= codei)
                    i0 = imed;
                else
                    i1 = imed;
            }
            int r2a = i0, r2b = r2sub - i0;
            codei -= cum[r2a];
            norm2s[2 * i] = r2a;
            norm2s[2 * i + 1] = r2b;

            // undo the row-major pack of the two child codes
            uint64_t code_a = codei / get_nv(ld - 1, r2b);
            uint64_t code_b = codei % get_nv(ld - 1, r2b);

            codes[2 * i] = code_a;
            codes[2 * i + 1] = code_b;
        }
        dim2 *= 2;
    }

    if (decode_cache_ld == 0) {
        // fully decoded down to single coordinates: norm gives |c[i]|,
        // code gives the sign
        for (int i = 0; i < dim; i++) {
            if (norm2s[i] == 0) {
                c[i] = 0;
            } else {
                float r = sqrt(norm2s[i]);
                assert(r * r == norm2s[i]);
                c[i] = codes[i] == 0 ? r : -r;
            }
        }
    } else {
        // finish with the precomputed subvectors
        int subdim = 1 << decode_cache_ld;
        assert((dim2 * subdim) == dim);

        for (int i = 0; i < dim2; i++) {
            const std::vector<float>& cache = decode_cache[norm2s[i]];
            assert(codes[i] < cache.size());
            memcpy(c + i * subdim,
                   &cache[codes[i] * subdim],
                   sizeof(*c) * subdim);
        }
    }
}
637 | | |
// if not use_rec, instantiate an arbitrary harmless znc_rec
ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
        : ZnSphereCodec(dim, r2),
          // the recursive codec requires dim to be a power of 2
          use_rec((dim & (dim - 1)) == 0),
          znc_rec(use_rec ? dim : 8, use_rec ? r2 : 14) {}
643 | | |
644 | 0 | uint64_t ZnSphereCodecAlt::encode(const float* x) const { |
645 | 0 | if (!use_rec) { |
646 | | // it's ok if the vector is not normalized |
647 | 0 | return ZnSphereCodec::encode(x); |
648 | 0 | } else { |
649 | | // find nearest centroid |
650 | 0 | std::vector<float> centroid(dim); |
651 | 0 | search(x, centroid.data()); |
652 | 0 | return znc_rec.encode(centroid.data()); |
653 | 0 | } |
654 | 0 | } |
655 | | |
656 | 0 | void ZnSphereCodecAlt::decode(uint64_t code, float* c) const { |
657 | 0 | if (!use_rec) { |
658 | 0 | ZnSphereCodec::decode(code, c); |
659 | 0 | } else { |
660 | 0 | znc_rec.decode(code, c); |
661 | 0 | } |
662 | 0 | } |
663 | | |
664 | | } // namespace faiss |