Coverage Report

Created: 2025-09-18 20:22

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexIVFRaBitQ.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexIVFRaBitQ.h>
9
10
#include <omp.h>
11
12
#include <cstddef>
13
#include <cstdint>
14
#include <memory>
15
#include <vector>
16
17
#include <faiss/impl/FaissAssert.h>
18
#include <faiss/impl/RaBitQuantizer.h>
19
20
namespace faiss {
21
22
/// Construct an IVF index whose per-vector codes are produced by RaBitQ.
///
/// @param quantizer  coarse quantizer, forwarded to the IndexIVF base
/// @param d          dimensionality, forwarded to the base and to rabitq
/// @param nlist      forwarded to the IndexIVF base
/// @param metric     metric type, forwarded to the base and to rabitq
///
/// The base class is constructed with 0 as its fourth argument; the real
/// code size is only known once the RaBitQ quantizer exists, so it is copied
/// over in the body.
IndexIVFRaBitQ::IndexIVFRaBitQ(
        Index* quantizer,
        const size_t d,
        const size_t nlist,
        MetricType metric)
        : IndexIVF(quantizer, d, nlist, 0, metric), rabitq(d, metric) {
    // A freshly built index is untrained and flagged as residual-based.
    by_residual = true;
    is_trained = false;

    // The RaBitQ quantizer dictates the code size; the inverted lists are
    // set to the same value.
    code_size = rabitq.code_size;
    invlists->code_size = code_size;
}
34
35
0
/// Default constructor: the only thing it sets explicitly is the
/// residual-based flag.
IndexIVFRaBitQ::IndexIVFRaBitQ() {
    this->by_residual = true;
}
38
39
void IndexIVFRaBitQ::train_encoder(
40
        idx_t n,
41
        const float* x,
42
0
        const idx_t* assign) {
43
0
    rabitq.train(n, x);
44
0
}
45
46
void IndexIVFRaBitQ::encode_vectors(
47
        idx_t n,
48
        const float* x,
49
        const idx_t* list_nos,
50
        uint8_t* codes,
51
0
        bool include_listnos) const {
52
0
    size_t coarse_size = include_listnos ? coarse_code_size() : 0;
53
0
    memset(codes, 0, (code_size + coarse_size) * n);
54
55
0
#pragma omp parallel if (n > 1000)
56
0
    {
57
0
        std::vector<float> centroid(d);
58
59
0
#pragma omp for
60
0
        for (idx_t i = 0; i < n; i++) {
61
0
            int64_t list_no = list_nos[i];
62
0
            if (list_no >= 0) {
63
0
                const float* xi = x + i * d;
64
0
                uint8_t* code = codes + i * (code_size + coarse_size);
65
66
                // both by_residual and !by_residual lead to the same code
67
0
                quantizer->reconstruct(list_no, centroid.data());
68
0
                rabitq.compute_codes_core(
69
0
                        xi, code + coarse_size, 1, centroid.data());
70
71
0
                if (coarse_size) {
72
0
                    encode_listno(list_no, code);
73
0
                }
74
0
            }
75
0
        }
76
0
    }
77
0
}
78
79
void IndexIVFRaBitQ::add_core(
80
        idx_t n,
81
        const float* x,
82
        const idx_t* xids,
83
        const idx_t* precomputed_idx,
84
0
        void* inverted_list_context) {
85
0
    FAISS_THROW_IF_NOT(is_trained);
86
87
0
    DirectMapAdd dm_add(direct_map, n, xids);
88
89
0
#pragma omp parallel
90
0
    {
91
0
        std::vector<uint8_t> one_code(code_size);
92
0
        std::vector<float> centroid(d);
93
94
0
        int nt = omp_get_num_threads();
95
0
        int rank = omp_get_thread_num();
96
97
        // each thread takes care of a subset of lists
98
0
        for (size_t i = 0; i < n; i++) {
99
0
            int64_t list_no = precomputed_idx[i];
100
0
            if (list_no >= 0 && list_no % nt == rank) {
101
0
                int64_t id = xids ? xids[i] : ntotal + i;
102
103
0
                const float* xi = x + i * d;
104
105
                // both by_residual and !by_residual lead to the same code
106
0
                quantizer->reconstruct(list_no, centroid.data());
107
0
                rabitq.compute_codes_core(
108
0
                        xi, one_code.data(), 1, centroid.data());
109
110
0
                size_t ofs = invlists->add_entry(
111
0
                        list_no, id, one_code.data(), inverted_list_context);
112
113
0
                dm_add.add(i, list_no, ofs);
114
115
0
            } else if (rank == 0 && list_no == -1) {
116
0
                dm_add.add(i, -1, 0);
117
0
            }
118
0
        }
119
0
    }
120
121
0
    ntotal += n;
122
0
}
123
124
struct RaBitInvertedListScanner : InvertedListScanner {
125
    const IndexIVFRaBitQ& ivf_rabitq;
126
127
    std::vector<float> reconstructed_centroid;
128
    std::vector<float> query_vector;
129
130
    std::unique_ptr<FlatCodesDistanceComputer> dc;
131
132
    uint8_t qb = 0;
133
134
    RaBitInvertedListScanner(
135
            const IndexIVFRaBitQ& ivf_rabitq_in,
136
            bool store_pairs = false,
137
            const IDSelector* sel = nullptr,
138
            uint8_t qb_in = 0)
139
0
            : InvertedListScanner(store_pairs, sel),
140
0
              ivf_rabitq{ivf_rabitq_in},
141
0
              qb{qb_in} {
142
0
        keep_max = is_similarity_metric(ivf_rabitq.metric_type);
143
0
        code_size = ivf_rabitq.code_size;
144
0
    }
145
146
    /// from now on we handle this query.
147
0
    void set_query(const float* query_vector_in) override {
148
0
        query_vector.assign(query_vector_in, query_vector_in + ivf_rabitq.d);
149
150
0
        internal_try_setup_dc();
151
0
    }
152
153
    /// following codes come from this inverted list
154
0
    void set_list(idx_t list_no, float coarse_dis) override {
155
0
        this->list_no = list_no;
156
157
0
        reconstructed_centroid.resize(ivf_rabitq.d);
158
0
        ivf_rabitq.quantizer->reconstruct(
159
0
                list_no, reconstructed_centroid.data());
160
161
0
        internal_try_setup_dc();
162
0
    }
163
164
    /// compute a single query-to-code distance
165
0
    float distance_to_code(const uint8_t* code) const override {
166
0
        return dc->distance_to_code(code);
167
0
    }
168
169
0
    void internal_try_setup_dc() {
170
0
        if (!query_vector.empty() && !reconstructed_centroid.empty()) {
171
            // both query_vector and centroid are available!
172
            // set up DistanceComputer
173
0
            dc.reset(ivf_rabitq.rabitq.get_distance_computer(
174
0
                    qb, reconstructed_centroid.data()));
175
176
0
            dc->set_query(query_vector.data());
177
0
        }
178
0
    }
179
};
180
181
/// Create a scanner for this index; the caller receives a pointer to a
/// newly allocated RaBitInvertedListScanner.
///
/// @param store_pairs       forwarded to the scanner
/// @param sel               optional id selector, forwarded to the scanner
/// @param search_params_in  when this is an IVFRaBitQSearchParameters, its
///                          qb overrides the index-level qb
InvertedListScanner* IndexIVFRaBitQ::get_InvertedListScanner(
        bool store_pairs,
        const IDSelector* sel,
        const IVFSearchParameters* search_params_in) const {
    const auto* rabitq_params =
            dynamic_cast<const IVFRaBitQSearchParameters*>(search_params_in);

    // Fall back to the index default when no RaBitQ-specific parameters
    // were supplied.
    const uint8_t used_qb = rabitq_params ? rabitq_params->qb : qb;

    return new RaBitInvertedListScanner(*this, store_pairs, sel, used_qb);
}
193
194
void IndexIVFRaBitQ::reconstruct_from_offset(
195
        int64_t list_no,
196
        int64_t offset,
197
0
        float* recons) const {
198
0
    const uint8_t* code = invlists->get_single_code(list_no, offset);
199
200
0
    std::vector<float> centroid(d);
201
0
    quantizer->reconstruct(list_no, centroid.data());
202
203
0
    rabitq.decode_core(code, recons, 1, centroid.data());
204
0
}
205
206
0
/// Decode n stand-alone codes back into vectors.
///
/// @param n      number of codes
/// @param codes  input; each entry is coarse_code_size() bytes of encoded
///               list number followed by code_size bytes of RaBitQ code
/// @param x      output, d floats per decoded vector
void IndexIVFRaBitQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
    const size_t coarse_size = coarse_code_size();
    // Size in bytes of one input entry (list prefix + code).
    const size_t stride = code_size + coarse_size;

#pragma omp parallel
    {
        // Per-thread scratch buffer for the reconstructed centroid.
        std::vector<float> centroid_buf(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* entry = codes + i * stride;
            const int64_t list_no = decode_listno(entry);

            quantizer->reconstruct(list_no, centroid_buf.data());
            rabitq.decode_core(
                    entry + coarse_size, x + i * d, 1, centroid_buf.data());
        }
    }
}
224
225
struct IVFRaBitDistanceComputer : DistanceComputer {
226
    const float* q = nullptr;
227
    const IndexIVFRaBitQ* parent = nullptr;
228
229
    void set_query(const float* x) override;
230
231
    float operator()(idx_t i) override;
232
233
    float symmetric_dis(idx_t i, idx_t j) override;
234
};
235
236
0
void IVFRaBitDistanceComputer::set_query(const float* x) {
237
0
    q = x;
238
0
}
239
240
0
/// Distance between the current query and the stored vector with id i.
///
/// The id is resolved through the parent's direct map to a
/// (list number, offset) pair, the list centroid is reconstructed, and a
/// RaBitQ distance computer for that centroid evaluates the stored code.
float IVFRaBitDistanceComputer::operator()(idx_t i) {
    // find the appropriate list
    const idx_t lo = parent->direct_map.get(i);
    const uint64_t list_no = lo_listno(lo);
    const uint64_t offset = lo_offset(lo);

    // ok, we know the appropriate cluster that we need
    std::vector<float> centroid(parent->d);
    parent->quantizer->reconstruct(list_no, centroid.data());

    // Set up the distance computer before touching the inverted list, so
    // that a failure in any of these steps cannot leave an acquired code
    // unreleased.
    std::unique_ptr<FlatCodesDistanceComputer> dc(
            parent->rabitq.get_distance_computer(parent->qb, centroid.data()));
    dc->set_query(q);

    // compute the distance, holding the code only for as long as needed
    const uint8_t* code = parent->invlists->get_single_code(list_no, offset);
    const float distance = dc->distance_to_code(code);
    parent->invlists->release_codes(list_no, code);

    return distance;
}
266
267
0
/// Code-to-code distance is not supported: this always throws through
/// FAISS_THROW_MSG and never returns a value.
float IVFRaBitDistanceComputer::symmetric_dis(idx_t /* i */, idx_t /* j */) {
    FAISS_THROW_MSG("Not implemented");
}
270
271
0
/// Create a distance computer bound to this index. The returned object is
/// heap-allocated and returned as a raw pointer; it keeps a pointer back to
/// this index.
DistanceComputer* IndexIVFRaBitQ::get_distance_computer() const {
    auto computer = std::make_unique<IVFRaBitDistanceComputer>();
    computer->parent = this;
    return computer.release();
}
276
277
} // namespace faiss