Coverage Report

Created: 2025-09-14 10:44

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexLSH.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexLSH.h>
9
10
#include <cstdio>
11
#include <cstring>
12
13
#include <algorithm>
14
#include <memory>
15
16
#include <faiss/impl/FaissAssert.h>
17
#include <faiss/utils/hamming.h>
18
19
namespace faiss {
20
21
/***************************************************************
22
 * IndexLSH
23
 ***************************************************************/
24
25
/**
 * Construct an LSH index.
 *
 * The base class is constructed with (nbits + 7) / 8, i.e. nbits rounded up
 * to whole bytes, and with the input dimension d.
 *
 * @param d                 dimension of the input vectors
 * @param nbits             number of bits per encoded vector
 * @param rotate_data       if true, inputs are passed through `rrot` before
 *                          being binarized
 * @param train_thresholds  if true, the index starts untrained and expects
 *                          train() to compute per-bit thresholds
 *
 * Throws (via FAISS_THROW_IF_NOT) when rotate_data is false and d < nbits.
 */
IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
        : IndexFlatCodes((nbits + 7) / 8, d),
          nbits(nbits),
          rotate_data(rotate_data),
          train_thresholds(train_thresholds),
          rrot(d, nbits) {
    // Training is only required when thresholds have to be learned.
    is_trained = !train_thresholds;

    if (!rotate_data) {
        // Without the transform, bits are taken from the leading components
        // of the input, so the input needs at least nbits of them.
        FAISS_THROW_IF_NOT(d >= nbits);
        return;
    }

    // 5 is the argument handed to init (presumably a fixed seed -- confirm
    // against the declaration of rrot's type).
    rrot.init(5);
}
39
40
0
/// Default constructor: zero bits, no data rotation, no threshold training.
IndexLSH::IndexLSH()
        : nbits(0), rotate_data(false), train_thresholds(false) {}
41
42
0
/**
 * Prepare n input vectors for binarization.
 *
 * Depending on the index settings this applies `rrot`, or keeps only the
 * first nbits components of each d-dimensional vector, and then subtracts
 * the per-bit thresholds when train_thresholds is set.
 *
 * @param n  number of vectors
 * @param x  input vectors, read as n rows of d floats
 * @return   x itself when no preprocessing was needed; otherwise a separate
 *           buffer of n * nbits floats. Callers in this file compare the
 *           result with x and release it with delete[] when it differs.
 */
const float* IndexLSH::apply_preprocess(idx_t n, const float* x) const {
    float* buf = nullptr;

    if (rotate_data) {
        // also applies bias if exists
        buf = rrot.apply(n, x);
    } else if (d != nbits) {
        assert(nbits < d);
        // Truncate every row to its leading nbits components.
        buf = new float[nbits * n];
        for (idx_t row = 0; row < n; row++) {
            memcpy(buf + row * nbits, x + row * d, sizeof(*x) * nbits);
        }
    }

    if (train_thresholds) {
        if (buf == nullptr) {
            // Work on a private copy so the caller's data is left untouched.
            buf = new float[nbits * n];
            memcpy(buf, x, sizeof(*x) * n * nbits);
        }

        for (idx_t row = 0; row < n; row++) {
            float* v = buf + row * nbits;
            for (int bit = 0; bit < nbits; bit++) {
                v[bit] -= thresholds[bit];
            }
        }
    }

    if (buf != nullptr) {
        return buf;
    }
    return x;
}
72
73
0
/**
 * Train the index.
 *
 * When train_thresholds is set, the threshold of each of the nbits output
 * components is set to the median of that component over the n preprocessed
 * training vectors (mean of the two middle values when n is even).
 * Otherwise nothing is learned. In both cases is_trained is set to true on
 * success.
 *
 * @param n  number of training vectors; must be > 0 when train_thresholds
 *           is set, since a median of zero values is undefined
 * @param x  training vectors, passed through apply_preprocess
 *
 * Throws (via FAISS_THROW_IF_NOT_MSG) when thresholds are requested and
 * n <= 0.
 */
void IndexLSH::train(idx_t n, const float* x) {
    if (train_thresholds) {
        // With n == 0 the median code below would read outside the buffer.
        FAISS_THROW_IF_NOT_MSG(
                n > 0, "IndexLSH: need at least one vector to train thresholds");

        // Preprocess without the threshold subtraction step: temporarily
        // clear the flag, and restore it even if preprocessing throws so the
        // index is not left configured as "no thresholds".
        const float* xt = nullptr;
        train_thresholds = false;
        try {
            xt = apply_preprocess(n, x);
        } catch (...) {
            train_thresholds = true;
            throw;
        }
        train_thresholds = true;
        std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);

        // Transpose to nbits rows of n values so that each component is
        // contiguous.
        std::unique_ptr<float[]> transposed_x(new float[n * nbits]);
        for (idx_t i = 0; i < n; i++)
            for (idx_t j = 0; j < nbits; j++)
                transposed_x[j * n + i] = xt[i * nbits + j];

        thresholds.resize(nbits);

        const idx_t mid = n / 2;
        for (idx_t i = 0; i < nbits; i++) {
            float* xi = transposed_x.get() + i * n;
            // Partial selection is enough for a median: after this call
            // xi[mid] holds the value a full sort would put there, and
            // everything before it is <= xi[mid].
            std::nth_element(xi, xi + mid, xi + n);
            if (n % 2 == 1) {
                thresholds[i] = xi[mid];
            } else {
                // n is even and > 0, hence mid >= 1 and the lower half is
                // non-empty; its maximum is the lower middle value.
                const float lower_mid = *std::max_element(xi, xi + mid);
                thresholds[i] = (lower_mid + xi[mid]) / 2;
            }
        }
    }
    is_trained = true;
}
99
100
void IndexLSH::search(
101
        idx_t n,
102
        const float* x,
103
        idx_t k,
104
        float* distances,
105
        idx_t* labels,
106
0
        const SearchParameters* params) const {
107
0
    FAISS_THROW_IF_NOT_MSG(
108
0
            !params, "search params not supported for this index");
109
0
    FAISS_THROW_IF_NOT(k > 0);
110
0
    FAISS_THROW_IF_NOT(is_trained);
111
0
    const float* xt = apply_preprocess(n, x);
112
0
    std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);
113
114
0
    std::unique_ptr<uint8_t[]> qcodes(new uint8_t[n * code_size]);
115
116
0
    fvecs2bitvecs(xt, qcodes.get(), nbits, n);
117
118
0
    std::unique_ptr<int[]> idistances(new int[n * k]);
119
120
0
    int_maxheap_array_t res = {size_t(n), size_t(k), labels, idistances.get()};
121
122
0
    hammings_knn_hc(&res, qcodes.get(), codes.data(), ntotal, code_size, true);
123
124
    // convert distances to floats
125
0
    for (int i = 0; i < k * n; i++)
126
0
        distances[i] = idistances[i];
127
0
}
128
129
0
/**
 * Fold the trained thresholds into the bias of a linear transform.
 *
 * Does nothing when train_thresholds is false. Otherwise each threshold is
 * subtracted from the matching entry of vt->b (which is first created as
 * nbits zeros if the transform had no bias), after which this index stops
 * using thresholds and drops them.
 *
 * @param vt  transform to update; its d_out must equal nbits, otherwise
 *            FAISS_THROW_IF_NOT raises
 */
void IndexLSH::transfer_thresholds(LinearTransform* vt) {
    if (train_thresholds) {
        FAISS_THROW_IF_NOT(nbits == vt->d_out);

        if (!vt->have_bias) {
            // Start from an all-zero bias.
            vt->b.resize(nbits, 0);
            vt->have_bias = true;
        }

        for (int bit = 0; bit < nbits; bit++) {
            vt->b[bit] -= thresholds[bit];
        }

        // The thresholds now live in the transform's bias.
        train_thresholds = false;
        thresholds.clear();
    }
}
142
143
0
/**
 * Encode n vectors into binary codes.
 *
 * @param n      number of vectors
 * @param x      input vectors
 * @param bytes  output buffer receiving the codes written by fvecs2bitvecs
 *
 * Throws (via FAISS_THROW_IF_NOT) when the index is not trained.
 */
void IndexLSH::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
    FAISS_THROW_IF_NOT(is_trained);

    const float* preprocessed = apply_preprocess(n, x);

    // Take ownership only when apply_preprocess returned something other
    // than the caller's own buffer.
    std::unique_ptr<const float[]> owner;
    if (preprocessed != x) {
        owner.reset(preprocessed);
    }

    fvecs2bitvecs(preprocessed, bytes, nbits, n);
}
149
150
0
/**
 * Decode n binary codes back into float vectors.
 *
 * The codes are expanded with bitvecs2fvecs, the thresholds are added back
 * when train_thresholds is set, and the result is then either passed through
 * rrot.reverse_transform (rotate_data) or copied row by row into x.
 *
 * @param n      number of codes
 * @param bytes  input codes
 * @param x      output vectors, addressed as n rows of d floats
 *
 * NOTE(review): when rotate_data is false and nbits != d, only the first
 * nbits components of each output row are written; the remaining d - nbits
 * are left as the caller provided them -- confirm this is intended.
 */
void IndexLSH::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
    // A scratch buffer is needed whenever the decoded rows are not already
    // laid out like the output rows.
    const bool needs_scratch = rotate_data || nbits != d;

    std::unique_ptr<float[]> scratch;
    if (needs_scratch) {
        scratch.reset(new float[n * nbits]);
    }
    float* decoded = needs_scratch ? scratch.get() : x;

    bitvecs2fvecs(bytes, decoded, nbits, n);

    if (train_thresholds) {
        for (idx_t row = 0; row < n; row++) {
            float* v = decoded + row * nbits;
            for (int bit = 0; bit < nbits; bit++) {
                v[bit] += thresholds[bit];
            }
        }
    }

    if (rotate_data) {
        rrot.reverse_transform(n, decoded, x);
        return;
    }

    if (nbits != d) {
        for (idx_t row = 0; row < n; row++) {
            memcpy(x + row * d,
                   decoded + row * nbits,
                   nbits * sizeof(decoded[0]));
        }
    }
}
176
177
} // namespace faiss