Coverage Report

Created: 2026-03-25 11:10

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/impl/lattice_Zn.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/impl/lattice_Zn.h>
11
12
#include <cassert>
13
#include <cmath>
14
#include <cstdlib>
15
#include <cstring>
16
17
#include <algorithm>
18
#include <queue>
19
#include <unordered_set>
20
21
#include <faiss/impl/platform_macros.h>
22
#include <faiss/utils/distances.h>
23
24
namespace faiss {
25
26
/********************************************
27
 * small utility functions
28
 ********************************************/
29
30
namespace {
31
32
0
/// Square of a float; a named helper so the lattice formulas read like math.
inline float sqr(float x) {
    const float squared = x * x;
    return squared;
}
35
36
// Flat list of points: each point is a run of consecutive floats.
using point_list_t = std::vector<float>;
37
38
/// Table of binomial coefficients C(n, p) for 0 <= n, p < nmax, filled once
/// with Pascal's rule C(n, p) = C(n-1, p) + C(n-1, p-1). Entries are uint64_t,
/// so coefficients larger than 2^64 are stored modulo 2^64.
struct Comb {
    std::vector<uint64_t> tab; // Pascal's triangle, row-major nmax x nmax
    int nmax;

    explicit Comb(int nmax) : nmax(nmax) {
        tab.assign(nmax * nmax, uint64_t{0});
        tab[0] = 1;
        for (int row = 1; row < nmax; row++) {
            uint64_t* cur = &tab[row * nmax];
            const uint64_t* prev = cur - nmax;
            cur[0] = 1;
            // prev[row] is still 0, which gives C(row, row) = 1
            for (int col = 1; col <= row; col++) {
                cur[col] = prev[col] + prev[col - 1];
            }
        }
    }

    /// Returns C(n, p), or 0 when p > n. Requires n, p < nmax (asserted).
    uint64_t operator()(int n, int p) const {
        assert(n < nmax && p < nmax);
        return p > n ? 0 : tab[n * nmax + p];
    }
};
61
62
// Shared binomial-coefficient table C(n, p) for 0 <= n, p < 100, used by the
// combination ranking/unranking helpers below. Central entries with n >= 68
// exceed 2^64 and are therefore stored modulo 2^64.
Comb comb(100);
63
64
// compute combinations of n integer values <= v that sum up to total (squared)
65
0
point_list_t sum_of_sq(float total, int v, int n, float add = 0) {
66
0
    if (total < 0) {
67
0
        return point_list_t();
68
0
    } else if (n == 1) {
69
0
        while (sqr(v + add) > total)
70
0
            v--;
71
0
        if (sqr(v + add) == total) {
72
0
            return point_list_t(1, v + add);
73
0
        } else {
74
0
            return point_list_t();
75
0
        }
76
0
    } else {
77
0
        point_list_t res;
78
0
        while (v >= 0) {
79
0
            point_list_t sub_points =
80
0
                    sum_of_sq(total - sqr(v + add), v, n - 1, add);
81
0
            for (size_t i = 0; i < sub_points.size(); i += n - 1) {
82
0
                res.push_back(v + add);
83
0
                for (int j = 0; j < n - 1; j++) {
84
0
                    res.push_back(sub_points[i + j]);
85
0
                }
86
0
            }
87
0
            v--;
88
0
        }
89
0
        return res;
90
0
    }
91
0
}
92
93
0
int decode_comb_1(uint64_t* n, int k1, int r) {
94
0
    while (comb(r, k1) > *n) {
95
0
        r--;
96
0
    }
97
0
    *n -= comb(r, k1);
98
0
    return r;
99
0
}
100
101
// optimized version for < 64 bits
102
uint64_t repeats_encode_64(
103
        const std::vector<Repeat>& repeats,
104
        int dim,
105
0
        const float* c) {
106
0
    uint64_t coded = 0;
107
0
    int nfree = dim;
108
0
    uint64_t code = 0, shift = 1;
109
0
    for (auto r = repeats.begin(); r != repeats.end(); ++r) {
110
0
        int rank = 0, occ = 0;
111
0
        uint64_t code_comb = 0;
112
0
        uint64_t tosee = ~coded;
113
0
        for (;;) {
114
            // directly jump to next available slot.
115
0
            int i = __builtin_ctzll(tosee);
116
0
            tosee &= ~(uint64_t{1} << i);
117
0
            if (c[i] == r->val) {
118
0
                code_comb += comb(rank, occ + 1);
119
0
                occ++;
120
0
                coded |= uint64_t{1} << i;
121
0
                if (occ == r->n)
122
0
                    break;
123
0
            }
124
0
            rank++;
125
0
        }
126
0
        uint64_t max_comb = comb(nfree, r->n);
127
0
        code += shift * code_comb;
128
0
        shift *= max_comb;
129
0
        nfree -= r->n;
130
0
    }
131
0
    return code;
132
0
}
133
134
void repeats_decode_64(
135
        const std::vector<Repeat>& repeats,
136
        int dim,
137
        uint64_t code,
138
0
        float* c) {
139
0
    uint64_t decoded = 0;
140
0
    int nfree = dim;
141
0
    for (auto r = repeats.begin(); r != repeats.end(); ++r) {
142
0
        uint64_t max_comb = comb(nfree, r->n);
143
0
        uint64_t code_comb = code % max_comb;
144
0
        code /= max_comb;
145
146
0
        int occ = 0;
147
0
        int rank = nfree;
148
0
        int next_rank = decode_comb_1(&code_comb, r->n, rank);
149
0
        uint64_t tosee = ((uint64_t{1} << dim) - 1) ^ decoded;
150
0
        for (;;) {
151
0
            int i = 63 - __builtin_clzll(tosee);
152
0
            tosee &= ~(uint64_t{1} << i);
153
0
            rank--;
154
0
            if (rank == next_rank) {
155
0
                decoded |= uint64_t{1} << i;
156
0
                c[i] = r->val;
157
0
                occ++;
158
0
                if (occ == r->n)
159
0
                    break;
160
0
                next_rank = decode_comb_1(&code_comb, r->n - occ, next_rank);
161
0
            }
162
0
        }
163
0
        nfree -= r->n;
164
0
    }
165
0
}
166
167
} // anonymous namespace
168
169
0
// Summarize c[0..dim) as one Repeat per distinct value, in order of first
// appearance, where n counts the occurrences of that value.
Repeats::Repeats(int dim, const float* c) : dim(dim) {
    for (int i = 0; i < dim; ++i) {
        const float value = c[i];
        auto hit = std::find_if(
                repeats.begin(), repeats.end(), [value](const Repeat& rep) {
                    return rep.val == value;
                });
        if (hit != repeats.end()) {
            hit->n++;
        } else {
            repeats.push_back(Repeat{value, 1});
        }
    }
}
185
186
0
uint64_t Repeats::count() const {
187
0
    uint64_t accu = 1;
188
0
    int remain = dim;
189
0
    for (int i = 0; i < repeats.size(); i++) {
190
0
        accu *= comb(remain, repeats[i].n);
191
0
        remain -= repeats[i].n;
192
0
    }
193
0
    return accu;
194
0
}
195
196
// version with a bool vector that works for > 64 dim
197
0
uint64_t Repeats::encode(const float* c) const {
198
0
    if (dim < 64) {
199
0
        return repeats_encode_64(repeats, dim, c);
200
0
    }
201
0
    std::vector<bool> coded(dim, false);
202
0
    int nfree = dim;
203
0
    uint64_t code = 0, shift = 1;
204
0
    for (auto r = repeats.begin(); r != repeats.end(); ++r) {
205
0
        int rank = 0, occ = 0;
206
0
        uint64_t code_comb = 0;
207
0
        for (int i = 0; i < dim; i++) {
208
0
            if (!coded[i]) {
209
0
                if (c[i] == r->val) {
210
0
                    code_comb += comb(rank, occ + 1);
211
0
                    occ++;
212
0
                    coded[i] = true;
213
0
                    if (occ == r->n)
214
0
                        break;
215
0
                }
216
0
                rank++;
217
0
            }
218
0
        }
219
0
        uint64_t max_comb = comb(nfree, r->n);
220
0
        code += shift * code_comb;
221
0
        shift *= max_comb;
222
0
        nfree -= r->n;
223
0
    }
224
0
    return code;
225
0
}
226
227
0
void Repeats::decode(uint64_t code, float* c) const {
228
0
    if (dim < 64) {
229
0
        repeats_decode_64(repeats, dim, code, c);
230
0
        return;
231
0
    }
232
233
0
    std::vector<bool> decoded(dim, false);
234
0
    int nfree = dim;
235
0
    for (auto r = repeats.begin(); r != repeats.end(); ++r) {
236
0
        uint64_t max_comb = comb(nfree, r->n);
237
0
        uint64_t code_comb = code % max_comb;
238
0
        code /= max_comb;
239
240
0
        int occ = 0;
241
0
        int rank = nfree;
242
0
        int next_rank = decode_comb_1(&code_comb, r->n, rank);
243
0
        for (int i = dim - 1; i >= 0; i--) {
244
0
            if (!decoded[i]) {
245
0
                rank--;
246
0
                if (rank == next_rank) {
247
0
                    decoded[i] = true;
248
0
                    c[i] = r->val;
249
0
                    occ++;
250
0
                    if (occ == r->n)
251
0
                        break;
252
0
                    next_rank =
253
0
                            decode_comb_1(&code_comb, r->n - occ, next_rank);
254
0
                }
255
0
            }
256
0
        }
257
0
        nfree -= r->n;
258
0
    }
259
0
}
260
261
/********************************************
262
 * EnumeratedVectors functions
263
 ********************************************/
264
265
void EnumeratedVectors::encode_multi(size_t n, const float* c, uint64_t* codes)
266
0
        const {
267
0
#pragma omp parallel if (n > 1000)
268
0
    {
269
0
#pragma omp for
270
0
        for (int i = 0; i < n; i++) {
271
0
            codes[i] = encode(c + i * dim);
272
0
        }
273
0
    }
274
0
}
275
276
void EnumeratedVectors::decode_multi(size_t n, const uint64_t* codes, float* c)
277
0
        const {
278
0
#pragma omp parallel if (n > 1000)
279
0
    {
280
0
#pragma omp for
281
0
        for (int i = 0; i < n; i++) {
282
0
            decode(codes[i], c + i * dim);
283
0
        }
284
0
    }
285
0
}
286
287
void EnumeratedVectors::find_nn(
288
        size_t nc,
289
        const uint64_t* codes,
290
        size_t nq,
291
        const float* xq,
292
        int64_t* labels,
293
0
        float* distances) {
294
0
    for (size_t i = 0; i < nq; i++) {
295
0
        distances[i] = -1e20;
296
0
        labels[i] = -1;
297
0
    }
298
299
0
    std::vector<float> c(dim);
300
0
    for (size_t i = 0; i < nc; i++) {
301
0
        uint64_t code = codes[nc];
302
0
        decode(code, c.data());
303
0
        for (size_t j = 0; j < nq; j++) {
304
0
            const float* x = xq + j * dim;
305
0
            float dis = fvec_inner_product(x, c.data(), dim);
306
0
            if (dis > distances[j]) {
307
0
                distances[j] = dis;
308
0
                labels[j] = i;
309
0
            }
310
0
        }
311
0
    }
312
0
}
313
314
/**********************************************************
315
 * ZnSphereSearch
316
 **********************************************************/
317
318
0
// Enumerate the atoms of the Z^dim sphere of squared radius r2. The atoms are
// the non-increasing vectors of non-negative integers with squared norm r2.
// They are stored flat in voc: natom rows of dim floats.
ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
    // upper bound for any single coordinate
    const int max_coord = int(ceil(sqrt(r2)) + 1);
    voc = sum_of_sq(r2, max_coord, dim);
    natom = voc.size() / dim;
}
322
323
0
float ZnSphereSearch::search(const float* x, float* c) const {
324
0
    std::vector<float> tmp(dimS * 2);
325
0
    std::vector<int> tmp_int(dimS);
326
0
    return search(x, c, tmp.data(), tmp_int.data());
327
0
}
328
329
float ZnSphereSearch::search(
330
        const float* x,
331
        float* c,
332
        float* tmp,   // size 2 *dim
333
        int* tmp_int, // size dim
334
0
        int* ibest_out) const {
335
0
    int dim = dimS;
336
0
    assert(natom > 0);
337
0
    int* o = tmp_int;
338
0
    float* xabs = tmp;
339
0
    float* xperm = tmp + dim;
340
341
    // argsort
342
0
    for (int i = 0; i < dim; i++) {
343
0
        o[i] = i;
344
0
        xabs[i] = fabsf(x[i]);
345
0
    }
346
0
    std::sort(o, o + dim, [xabs](int a, int b) { return xabs[a] > xabs[b]; });
347
0
    for (int i = 0; i < dim; i++) {
348
0
        xperm[i] = xabs[o[i]];
349
0
    }
350
    // find best
351
0
    int ibest = -1;
352
0
    float dpbest = -100;
353
0
    for (int i = 0; i < natom; i++) {
354
0
        float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
355
0
        if (dp > dpbest) {
356
0
            dpbest = dp;
357
0
            ibest = i;
358
0
        }
359
0
    }
360
    // revert sort
361
0
    const float* cin = voc.data() + ibest * dim;
362
0
    for (int i = 0; i < dim; i++) {
363
0
        c[o[i]] = copysignf(cin[i], x[o[i]]);
364
0
    }
365
0
    if (ibest_out) {
366
0
        *ibest_out = ibest;
367
0
    }
368
0
    return dpbest;
369
0
}
370
371
void ZnSphereSearch::search_multi(
372
        int n,
373
        const float* x,
374
        float* c_out,
375
0
        float* dp_out) {
376
0
#pragma omp parallel if (n > 1000)
377
0
    {
378
0
#pragma omp for
379
0
        for (int i = 0; i < n; i++) {
380
0
            dp_out[i] = search(x + i * dimS, c_out + i * dimS);
381
0
        }
382
0
    }
383
0
}
384
385
/**********************************************************
386
 * ZnSphereCodec
387
 **********************************************************/
388
389
// Build one CodeSegment per atom. Atom i owns the code range
// [c0, c0 + (count << signbits)), where:
//  - count is the number of distinct permutations of the atom;
//  - signbits is the number of non-zero coordinates, one sign bit each.
// code_size is the number of bytes in the binary representation of nv.
ZnSphereCodec::ZnSphereCodec(int dim, int r2)
        : ZnSphereSearch(dim, r2), EnumeratedVectors(dim) {
    nv = 0;
    for (int i = 0; i < natom; i++) {
        Repeats repeats(dim, &voc[i * dim]);
        CodeSegment cs(repeats);
        cs.c0 = nv;
        // Atoms are non-increasing, so the zeros (if any) form the last
        // Repeat and carry no sign bit.
        const Repeat& smallest = repeats.repeats.back();
        cs.signbits = smallest.val == 0 ? dim - smallest.n : dim;
        code_segments.push_back(cs);
        nv += repeats.count() << cs.signbits;
    }

    code_size = 0;
    for (uint64_t rest = nv; rest != 0; rest >>= 8) {
        code_size++;
    }
}
409
410
0
uint64_t ZnSphereCodec::search_and_encode(const float* x) const {
411
0
    std::vector<float> tmp(dim * 2);
412
0
    std::vector<int> tmp_int(dim);
413
0
    int ano; // atom number
414
0
    std::vector<float> c(dim);
415
0
    search(x, c.data(), tmp.data(), tmp_int.data(), &ano);
416
0
    uint64_t signs = 0;
417
0
    std::vector<float> cabs(dim);
418
0
    int nnz = 0;
419
0
    for (int i = 0; i < dim; i++) {
420
0
        cabs[i] = fabs(c[i]);
421
0
        if (c[i] != 0) {
422
0
            if (c[i] < 0) {
423
0
                signs |= uint64_t{1} << nnz;
424
0
            }
425
0
            nnz++;
426
0
        }
427
0
    }
428
0
    const CodeSegment& cs = code_segments[ano];
429
0
    assert(nnz == cs.signbits);
430
0
    uint64_t code = cs.c0 + signs;
431
0
    code += cs.encode(cabs.data()) << cs.signbits;
432
0
    return code;
433
0
}
434
435
0
uint64_t ZnSphereCodec::encode(const float* x) const {
436
0
    return search_and_encode(x);
437
0
}
438
439
0
void ZnSphereCodec::decode(uint64_t code, float* c) const {
440
0
    int i0 = 0, i1 = natom;
441
0
    while (i0 + 1 < i1) {
442
0
        int imed = (i0 + i1) / 2;
443
0
        if (code_segments[imed].c0 <= code)
444
0
            i0 = imed;
445
0
        else
446
0
            i1 = imed;
447
0
    }
448
0
    const CodeSegment& cs = code_segments[i0];
449
0
    code -= cs.c0;
450
0
    uint64_t signs = code;
451
0
    code >>= cs.signbits;
452
0
    cs.decode(code, c);
453
454
0
    int nnz = 0;
455
0
    for (int i = 0; i < dim; i++) {
456
0
        if (c[i] != 0) {
457
0
            if (signs & (uint64_t(1) << nnz)) {
458
0
                c[i] = -c[i];
459
0
            }
460
0
            nnz++;
461
0
        }
462
0
    }
463
0
}
464
465
/**************************************************************
466
 * ZnSphereCodecRec
467
 **************************************************************/
468
469
0
uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const {
470
0
    return all_nv[ld * (r2 + 1) + r2a];
471
0
}
472
473
0
uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const {
474
0
    return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
475
0
}
476
477
0
void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
478
0
    all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
479
0
}
480
481
// Recursive sphere codec for power-of-2 dimensions.
//  - Tables: all_nv[ld][r2a] counts the codes of dimension 2^ld and squared
//    norm r2a. all_nv_cum gives the offset of each split (r2a, r2sub - r2a)
//    between the two halves.
//  - code_size is the number of bytes in the binary representation of nv.
//  - Decode cache: for each squared norm, the decodings of all sub-vectors of
//    dimension 2^decode_cache_ld, so decode() can stop its recursion early.
ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
        : EnumeratedVectors(dim), r2(r2) {
    log2_dim = 0;
    while (dim > (1 << log2_dim)) {
        log2_dim++;
    }
    assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");

    all_nv.resize((log2_dim + 1) * (r2 + 1));
    all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));

    // Level 0 (scalars): r2a is reachable iff it is a perfect square. 0 has
    // one representation; k^2 > 0 has two (+k and -k).
    for (int r2a = 0; r2a <= r2; r2a++) {
        int r = int(sqrt(r2a));
        if (r * r == r2a) {
            all_nv[r2a] = r == 0 ? 1 : 2;
        } else {
            all_nv[r2a] = 0;
        }
    }

    // Level ld: a vector splits into two halves of dimension 2^(ld-1) with
    // r2a + r2b = r2sub. Codes are grouped by r2a, and nv_cum holds each
    // group's offset.
    for (int ld = 1; ld <= log2_dim; ld++) {
        for (int r2sub = 0; r2sub <= r2; r2sub++) {
            uint64_t nv = 0;
            for (int r2a = 0; r2a <= r2sub; r2a++) {
                int r2b = r2sub - r2a;
                set_nv_cum(ld, r2sub, r2a, nv);
                nv += get_nv(ld - 1, r2a) * get_nv(ld - 1, r2b);
            }
            all_nv[ld * (r2 + 1) + r2sub] = nv;
        }
    }
    nv = get_nv(log2_dim, r2);

    uint64_t nvx = nv;
    code_size = 0;
    while (nvx > 0) {
        nvx >>= 8;
        code_size++;
    }

    // Filling a cache at cache_level reads get_nv_cum(cache_level + 1, ...),
    // which needs cache_level >= 0. For dim == 1 (log2_dim == 0) the formula
    // gives -1. That would index all_nv with a negative offset and evaluate
    // 1 << -1, so no cache is built and decode() uses the direct path
    // (decode_cache_ld == 0).
    int cache_level = std::min(3, log2_dim - 1);
    decode_cache_ld = 0;
    assert(cache_level <= log2_dim);
    decode_cache.resize((r2 + 1));

    if (cache_level >= 0) {
        const int dimsub = 1 << cache_level;
        for (int r2sub = 0; r2sub <= r2; r2sub++) {
            uint64_t nvi = get_nv(cache_level, r2sub);
            std::vector<float>& cache = decode_cache[r2sub];
            cache.resize(nvi * dimsub);
            std::vector<float> c(dim);
            // Codes code0 + i (i < nvi) put the i-th sub-vector of norm
            // r2sub in the last dimsub coordinates.
            uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
            // 64-bit index: nvi is uint64_t, and i * dimsub can exceed INT_MAX
            for (uint64_t i = 0; i < nvi; i++) {
                decode(i + code0, c.data());
                memcpy(&cache[i * dimsub],
                       c.data() + dim - dimsub,
                       dimsub * sizeof(*c.data()));
            }
        }
        decode_cache_ld = cache_level;
    }
}
543
544
0
uint64_t ZnSphereCodecRec::encode(const float* c) const {
545
0
    return encode_centroid(c);
546
0
}
547
548
0
uint64_t ZnSphereCodecRec::encode_centroid(const float* c) const {
549
0
    std::vector<uint64_t> codes(dim);
550
0
    std::vector<int> norm2s(dim);
551
0
    for (int i = 0; i < dim; i++) {
552
0
        if (c[i] == 0) {
553
0
            codes[i] = 0;
554
0
            norm2s[i] = 0;
555
0
        } else {
556
0
            int r2i = int(c[i] * c[i]);
557
0
            norm2s[i] = r2i;
558
0
            codes[i] = c[i] >= 0 ? 0 : 1;
559
0
        }
560
0
    }
561
0
    int dim2 = dim / 2;
562
0
    for (int ld = 1; ld <= log2_dim; ld++) {
563
0
        for (int i = 0; i < dim2; i++) {
564
0
            int r2a = norm2s[2 * i];
565
0
            int r2b = norm2s[2 * i + 1];
566
567
0
            uint64_t code_a = codes[2 * i];
568
0
            uint64_t code_b = codes[2 * i + 1];
569
570
0
            codes[i] = get_nv_cum(ld, r2a + r2b, r2a) +
571
0
                    code_a * get_nv(ld - 1, r2b) + code_b;
572
0
            norm2s[i] = r2a + r2b;
573
0
        }
574
0
        dim2 /= 2;
575
0
    }
576
0
    return codes[0];
577
0
}
578
579
0
void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
580
0
    std::vector<uint64_t> codes(dim);
581
0
    std::vector<int> norm2s(dim);
582
0
    codes[0] = code;
583
0
    norm2s[0] = r2;
584
585
0
    int dim2 = 1;
586
0
    for (int ld = log2_dim; ld > decode_cache_ld; ld--) {
587
0
        for (int i = dim2 - 1; i >= 0; i--) {
588
0
            int r2sub = norm2s[i];
589
0
            int i0 = 0, i1 = r2sub + 1;
590
0
            uint64_t codei = codes[i];
591
0
            const uint64_t* cum =
592
0
                    &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
593
0
            while (i1 > i0 + 1) {
594
0
                int imed = (i0 + i1) / 2;
595
0
                if (cum[imed] <= codei)
596
0
                    i0 = imed;
597
0
                else
598
0
                    i1 = imed;
599
0
            }
600
0
            int r2a = i0, r2b = r2sub - i0;
601
0
            codei -= cum[r2a];
602
0
            norm2s[2 * i] = r2a;
603
0
            norm2s[2 * i + 1] = r2b;
604
605
0
            uint64_t code_a = codei / get_nv(ld - 1, r2b);
606
0
            uint64_t code_b = codei % get_nv(ld - 1, r2b);
607
608
0
            codes[2 * i] = code_a;
609
0
            codes[2 * i + 1] = code_b;
610
0
        }
611
0
        dim2 *= 2;
612
0
    }
613
614
0
    if (decode_cache_ld == 0) {
615
0
        for (int i = 0; i < dim; i++) {
616
0
            if (norm2s[i] == 0) {
617
0
                c[i] = 0;
618
0
            } else {
619
0
                float r = sqrt(norm2s[i]);
620
0
                assert(r * r == norm2s[i]);
621
0
                c[i] = codes[i] == 0 ? r : -r;
622
0
            }
623
0
        }
624
0
    } else {
625
0
        int subdim = 1 << decode_cache_ld;
626
0
        assert((dim2 * subdim) == dim);
627
628
0
        for (int i = 0; i < dim2; i++) {
629
0
            const std::vector<float>& cache = decode_cache[norm2s[i]];
630
0
            assert(codes[i] < cache.size());
631
0
            memcpy(c + i * subdim,
632
0
                   &cache[codes[i] * subdim],
633
0
                   sizeof(*c) * subdim);
634
0
        }
635
0
    }
636
0
}
637
638
// if not use_rec, instantiate an arbitrary harmless znc_rec
639
ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
640
0
        : ZnSphereCodec(dim, r2),
641
0
          use_rec((dim & (dim - 1)) == 0),
642
0
          znc_rec(use_rec ? dim : 8, use_rec ? r2 : 14) {}
643
644
0
uint64_t ZnSphereCodecAlt::encode(const float* x) const {
645
0
    if (!use_rec) {
646
        // it's ok if the vector is not normalized
647
0
        return ZnSphereCodec::encode(x);
648
0
    } else {
649
        // find nearest centroid
650
0
        std::vector<float> centroid(dim);
651
0
        search(x, centroid.data());
652
0
        return znc_rec.encode(centroid.data());
653
0
    }
654
0
}
655
656
0
void ZnSphereCodecAlt::decode(uint64_t code, float* c) const {
657
0
    if (!use_rec) {
658
0
        ZnSphereCodec::decode(code, c);
659
0
    } else {
660
0
        znc_rec.decode(code, c);
661
0
    }
662
0
}
663
664
} // namespace faiss