Coverage Report

Created: 2025-10-31 14:22

/root/doris/contrib/faiss/faiss/IndexIVFPQFastScan.cpp
All instrumented lines below were executed 0 times (the file is entirely uncovered).

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/IndexIVFPQFastScan.h>

#include <cassert>
#include <cinttypes>
#include <cstdio>

#include <memory>

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/simdlib.h>

#include <faiss/invlists/BlockInvertedLists.h>

#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/impl/simd_result_handlers.h>

namespace faiss {

using namespace simd_result_handlers;

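// Round a up to the nearest multiple of b, e.g. roundup(5, 32) == 32;
// used below to pad inverted-list sizes to whole blocks of bbs codes.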
inline size_t roundup(size_t a, size_t b) {
    return (a + b - 1) / b * b;
}

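// Main constructor: an IVFPQ index backed by the 4-bit fast-scan
// kernels, with M sub-quantizers of nbits bits each, nlist inverted
// lists, and SIMD blocks of bbs codes (bbs is presumably a multiple
// of 32, as in the other fast-scan indexes).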
IndexIVFPQFastScan::IndexIVFPQFastScan(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits,
        MetricType metric,
        int bbs)
        : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
    by_residual = false; // set to false by default because it's faster

    init_fastscan(&pq, M, nbits, nlist, metric, bbs);
}

IndexIVFPQFastScan::IndexIVFPQFastScan() {
    by_residual = false;
    bbs = 0;
    M2 = 0;
}

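// Conversion constructor: re-packs the codes of an existing IndexIVFPQ
// (4-bit PQ only) into the bbs-aligned block layout that the pq4
// kernels scan.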
IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
        : IndexIVFFastScan(
                  orig.quantizer,
                  orig.d,
                  orig.nlist,
                  orig.pq.code_size,
                  orig.metric_type),
          pq(orig.pq) {
    FAISS_THROW_IF_NOT(orig.pq.nbits == 4);

    init_fastscan(
            &pq, orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);

    by_residual = orig.by_residual;
    ntotal = orig.ntotal;
    is_trained = orig.is_trained;
    nprobe = orig.nprobe;

    precomputed_table.resize(orig.precomputed_table.size());

    if (precomputed_table.nbytes() > 0) {
        memcpy(precomputed_table.get(),
               orig.precomputed_table.data(),
               precomputed_table.nbytes());
    }

#pragma omp parallel for if (nlist > 100)
    for (idx_t i = 0; i < nlist; i++) {
        size_t nb = orig.invlists->list_size(i);
        size_t nb2 = roundup(nb, bbs);
        AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
        pq4_pack_codes(
                InvertedLists::ScopedCodes(orig.invlists, i).get(),
                nb,
                M,
                nb2,
                bbs,
                M2,
                tmp.get());
        invlists->add_entries(
                i,
                nb,
                InvertedLists::ScopedIds(orig.invlists, i).get(),
                tmp.get());
    }

    orig_invlists = orig.invlists;
}

/*********************************************************
 * Training
 *********************************************************/

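// Train the PQ codebooks on the training set prepared by the caller
// (presumably already converted to residuals when by_residual is set);
// residual L2 indexes additionally need the precomputed tables.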
void IndexIVFPQFastScan::train_encoder(
        idx_t n,
        const float* x,
        const idx_t* assign) {
    pq.verbose = verbose;
    pq.train(n, x);

    if (by_residual && metric_type == METRIC_L2) {
        precompute_table();
    }
}

idx_t IndexIVFPQFastScan::train_encoder_num_vectors() const {
    return pq.cp.max_points_per_centroid * pq.ksub;
}

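// Delegates to the shared IVFPQ helper. Judging from compute_LUT below,
// precomputed_table holds the query-independent term of the L2 distance
// decomposition, one dim12-sized table per coarse centroid.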
void IndexIVFPQFastScan::precompute_table() {
    initialize_IVFPQ_precomputed_table(
            use_precomputed_table,
            quantizer,
            pq,
            precomputed_table,
            by_residual,
            verbose);
}

/*********************************************************
 * Code management functions
 *********************************************************/

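// Encodes vectors as plain (unpacked) PQ codes; packing into the bbs
// block layout happens at insertion time. With include_listnos, each
// code is prefixed by its encoded list number, hence the in-place
// shift that runs backwards over the buffer.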
void IndexIVFPQFastScan::encode_vectors(
        idx_t n,
        const float* x,
        const idx_t* list_nos,
        uint8_t* codes,
        bool include_listnos) const {
    if (by_residual) {
        AlignedTable<float> residuals(n * d);
        for (size_t i = 0; i < n; i++) {
            if (list_nos[i] < 0) {
                memset(residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
            } else {
                quantizer->compute_residual(
                        x + i * d, residuals.data() + i * d, list_nos[i]);
            }
        }
        pq.compute_codes(residuals.data(), codes, n);
    } else {
        pq.compute_codes(x, codes, n);
    }

    if (include_listnos) {
        size_t coarse_size = coarse_code_size();
        for (idx_t i = n - 1; i >= 0; i--) {
            uint8_t* code = codes + i * (coarse_size + code_size);
            memmove(code + coarse_size, codes + i * code_size, code_size);
            encode_listno(list_nos[i], code);
        }
    }
}

/*********************************************************
 * Look-Up Table functions
 *********************************************************/

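// Computes c[i] = a[i] + bf * b[i] over n floats, 8 lanes at a time;
// all three pointers must be SIMD-aligned and n a multiple of 8.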
void fvec_madd_simd(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    assert(is_aligned_pointer(a));
    assert(is_aligned_pointer(b));
    assert(is_aligned_pointer(c));
    assert(n % 8 == 0);
    simd8float32 bf8(bf);
    n /= 8;
    for (size_t i = 0; i < n; i++) {
        simd8float32 ai(a);
        simd8float32 bi(b);

        simd8float32 ci = fmadd(bf8, bi, ai);
        ci.store(c);
        c += 8;
        a += 8;
        b += 8;
    }
}

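// The LUTs are 3D (one table per query-probe pair) only for residual
// L2 search; otherwise a single 2D table per query is enough.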
bool IndexIVFPQFastScan::lookup_table_is_3d() const {
    return by_residual && metric_type == METRIC_L2;
}

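// Builds the look-up tables consumed by the fast-scan kernels. In the
// residual L2 case with use_precomputed_table == 1, the table for
// query x and centroid c is precomputed_table[c] - 2 * <x, y_R>, with
// the coarse distance as a per-probe bias; memset with -1 fills the
// tables of invalid probes with NaNs so they are dropped during LUT
// quantization.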
void IndexIVFPQFastScan::compute_LUT(
        size_t n,
        const float* x,
        const CoarseQuantized& cq,
        AlignedTable<float>& dis_tables,
        AlignedTable<float>& biases) const {
    size_t dim12 = pq.ksub * pq.M;
    size_t d = pq.d;
    size_t nprobe = this->nprobe;

    if (by_residual) {
        if (metric_type == METRIC_L2) {
            dis_tables.resize(n * nprobe * dim12);

            if (use_precomputed_table == 1) {
                biases.resize(n * nprobe);
                memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe);

                AlignedTable<float> ip_table(n * dim12);
                pq.compute_inner_prod_tables(n, x, ip_table.get());

#pragma omp parallel for if (n * nprobe > 8000)
                for (idx_t ij = 0; ij < n * nprobe; ij++) {
                    idx_t i = ij / nprobe;
                    float* tab = dis_tables.get() + ij * dim12;
                    idx_t cij = cq.ids[ij];

                    if (cij >= 0) {
                        fvec_madd_simd(
                                dim12,
                                precomputed_table.get() + cij * dim12,
                                -2,
                                ip_table.get() + i * dim12,
                                tab);
                    } else {
                        // fill with NaNs so that they are ignored during
                        // LUT quantization
                        memset(tab, -1, sizeof(float) * dim12);
                    }
                }

            } else {
                std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
                biases.resize(n * nprobe);
                memset(biases.get(), 0, sizeof(float) * n * nprobe);

#pragma omp parallel for if (n * nprobe > 8000)
                for (idx_t ij = 0; ij < n * nprobe; ij++) {
                    idx_t i = ij / nprobe;
                    float* xij = &xrel[ij * d];
                    idx_t cij = cq.ids[ij];

                    if (cij >= 0) {
                        quantizer->compute_residual(x + i * d, xij, cij);
                    } else {
                        // will fill with NaNs
                        memset(xij, -1, sizeof(float) * d);
                    }
                }

                pq.compute_distance_tables(
                        n * nprobe, xrel.get(), dis_tables.get());
            }

        } else if (metric_type == METRIC_INNER_PRODUCT) {
            dis_tables.resize(n * dim12);
            pq.compute_inner_prod_tables(n, x, dis_tables.get());
            // compute_inner_prod_tables(pq, n, x, dis_tables.get());

            biases.resize(n * nprobe);
            memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe);
        } else {
            FAISS_THROW_FMT("metric %d not supported", metric_type);
        }

    } else {
        dis_tables.resize(n * dim12);
        if (metric_type == METRIC_L2) {
            pq.compute_distance_tables(n, x, dis_tables.get());
        } else if (metric_type == METRIC_INNER_PRODUCT) {
            pq.compute_inner_prod_tables(n, x, dis_tables.get());
        } else {
            FAISS_THROW_FMT("metric %d not supported", metric_type);
        }
    }
}

} // namespace faiss