Coverage Report

Created: 2025-09-18 06:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexLattice.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/IndexLattice.h>
11
#include <faiss/impl/FaissAssert.h>
12
#include <faiss/utils/distances.h>
13
#include <faiss/utils/hamming.h> // for the bitstring routines
14
15
namespace faiss {
16
17
IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
18
0
        : IndexFlatCodes(0, d, METRIC_L2),
19
0
          nsq(nsq),
20
0
          dsq(d / nsq),
21
0
          zn_sphere_codec(dsq, r2),
22
0
          scale_nbit(scale_nbit) {
23
0
    FAISS_THROW_IF_NOT(d % nsq == 0);
24
25
0
    lattice_nbit = 0;
26
0
    while (!(((uint64_t)1 << lattice_nbit) >= zn_sphere_codec.nv)) {
27
0
        lattice_nbit++;
28
0
    }
29
30
0
    int total_nbit = (lattice_nbit + scale_nbit) * nsq;
31
32
0
    code_size = (total_nbit + 7) / 8;
33
34
0
    is_trained = false;
35
0
}
36
37
0
/// Learn, for each of the nsq sub-vectors, the min and max L2 norm observed
/// over the training set.
///
/// The results are stored in `trained`: entries [0, nsq) hold the minimum
/// norms and entries [nsq, 2*nsq) hold the maximum norms.
///
/// @param n  number of training vectors; must be > 0
/// @param x  training vectors laid out contiguously, n * d floats
/// @throws FaissException if n <= 0
void IndexLattice::train(idx_t n, const float* x) {
    // With no training vectors the bounds would keep their sentinels
    // (+inf / -1), and sqrtf(-1) would store NaN, yet the index would still
    // be flagged as trained.
    FAISS_THROW_IF_NOT(n > 0);

    // compute ranges per sub-block
    trained.resize(nsq * 2);
    float* mins = trained.data();
    float* maxs = trained.data() + nsq;
    for (int sq = 0; sq < nsq; sq++) {
        mins[sq] = HUGE_VAL;
        maxs[sq] = -1;
    }

    // Track squared norms during the scan; take square roots once at the end.
    for (idx_t i = 0; i < n; i++) {
        for (int sq = 0; sq < nsq; sq++) {
            float norm2 = fvec_norm_L2sqr(x + i * d + sq * dsq, dsq);
            if (norm2 > maxs[sq])
                maxs[sq] = norm2;
            if (norm2 < mins[sq])
                mins[sq] = norm2;
        }
    }

    for (int sq = 0; sq < nsq; sq++) {
        mins[sq] = sqrtf(mins[sq]);
        maxs[sq] = sqrtf(maxs[sq]);
    }

    is_trained = true;
}
64
65
/* The standalone codec interface */
66
0
size_t IndexLattice::sa_code_size() const {
67
0
    return code_size;
68
0
}
69
70
0
/// Encode n vectors into packed codes of code_size bytes each.
///
/// For each of the nsq sub-vectors, writes:
///   1. its L2 norm, linearly quantized over the trained [min, max] range
///      into scale_nbit bits, then
///   2. the zn_sphere_codec index of the sub-vector, in lattice_nbit bits.
///
/// @param n      number of vectors
/// @param x      input vectors, n * d floats
/// @param codes  output buffer, n * code_size bytes
/// @throws FaissException if the index has not been trained
void IndexLattice::sa_encode(idx_t n, const float* x, uint8_t* codes) const {
    // Before train(), `trained` is empty and mins/maxs would not point at
    // valid data.
    FAISS_THROW_IF_NOT(is_trained);
    const float* mins = trained.data();
    const float* maxs = mins + nsq;
    int64_t sc = int64_t(1) << scale_nbit;

#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        BitstringWriter wr(codes + i * code_size, code_size);
        const float* xi = x + i * d;
        for (int j = 0; j < nsq; j++) {
            float nj = (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j]) * sc /
                    (maxs[j] - mins[j]);
            // `!(nj >= 0)` also catches NaN. NaN arises as 0/0 when the
            // trained range is empty (maxs[j] == mins[j]), and converting it
            // to int64_t would be undefined behavior. Any scale value decodes
            // to mins[j] in that case, so 0 is safe.
            if (!(nj >= 0))
                nj = 0;
            if (nj >= sc)
                nj = sc - 1;
            wr.write((int64_t)nj, scale_nbit);
            wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
            xi += dsq;
        }
    }
}
92
93
0
/// Decode n packed codes back into approximate vectors.
///
/// For each of the nsq sub-vectors, the scale index (scale_nbit bits) is read
/// first and mapped to the center of its bin in the trained [min, max] norm
/// range. Then the lattice index (lattice_nbit bits) is decoded by
/// zn_sphere_codec, and the resulting point is scaled by that norm divided by
/// sqrt(r2).
///
/// @param n      number of codes
/// @param codes  input codes, n * code_size bytes
/// @param x      output vectors, n * d floats
void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
    const float* lo_bound = trained.data();
    const float* hi_bound = lo_bound + nsq;
    const float nlevels = int64_t(1) << scale_nbit;
    const float radius = sqrtf(zn_sphere_codec.r2);

#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        BitstringReader reader(codes + i * code_size, code_size);
        for (int j = 0; j < nsq; j++) {
            float* sub = x + i * d + j * dsq;
            // Read order mirrors the writer: scale bits, then lattice bits.
            const double level = reader.read(scale_nbit) + 0.5;
            const float range = hi_bound[j] - lo_bound[j];
            float norm = level * range / nlevels + lo_bound[j];
            norm /= radius;
            zn_sphere_codec.decode(reader.read(lattice_nbit), sub);
            for (int l = 0; l < dsq; l++) {
                sub[l] *= norm;
            }
        }
    }
}
116
117
} // namespace faiss