Coverage Report

Created: 2025-11-01 13:43

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexIVFAdditiveQuantizerFastScan.h>
9
10
#include <cinttypes>
11
#include <cstdio>
12
13
#include <memory>
14
15
#include <faiss/impl/AuxIndexStructures.h>
16
#include <faiss/impl/FaissAssert.h>
17
#include <faiss/impl/LookupTableScaler.h>
18
#include <faiss/impl/pq4_fast_scan.h>
19
#include <faiss/invlists/BlockInvertedLists.h>
20
#include <faiss/utils/distances.h>
21
#include <faiss/utils/hamming.h>
22
#include <faiss/utils/quantize_lut.h>
23
#include <faiss/utils/utils.h>
24
25
namespace faiss {
26
27
/// Round `a` up to the nearest multiple of `b` (`b` must be non-zero).
inline size_t roundup(size_t a, size_t b) {
    const size_t nblocks = (a + b - 1) / b;
    return nblocks * b;
}
30
31
/// Build an IVF fast-scan index on top of an additive quantizer.
///
/// When `aq` is null, initialization is deferred: the caller (e.g. a
/// subclass that holds the quantizer as a member) must call `init()` itself
/// once the quantizer object exists.
IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan(
        Index* quantizer,
        AdditiveQuantizer* aq,
        size_t d,
        size_t nlist,
        MetricType metric,
        int bbs)
        : IndexIVFFastScan(quantizer, d, nlist, 0, metric) {
    if (!aq) {
        return; // deferred initialization
    }
    init(aq, nlist, metric, bbs);
}
43
44
void IndexIVFAdditiveQuantizerFastScan::init(
45
        AdditiveQuantizer* aq,
46
        size_t nlist,
47
        MetricType metric,
48
0
        int bbs) {
49
0
    FAISS_THROW_IF_NOT(aq != nullptr);
50
0
    FAISS_THROW_IF_NOT(!aq->nbits.empty());
51
0
    FAISS_THROW_IF_NOT(aq->nbits[0] == 4);
52
0
    if (metric == METRIC_INNER_PRODUCT) {
53
0
        FAISS_THROW_IF_NOT_MSG(
54
0
                aq->search_type == AdditiveQuantizer::ST_LUT_nonorm,
55
0
                "Search type must be ST_LUT_nonorm for IP metric");
56
0
    } else {
57
0
        FAISS_THROW_IF_NOT_MSG(
58
0
                aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
59
0
                        aq->search_type == AdditiveQuantizer::ST_norm_rq2x4,
60
0
                "Search type must be lsq2x4 or rq2x4 for L2 metric");
61
0
    }
62
63
0
    this->aq = aq;
64
0
    if (metric_type == METRIC_L2) {
65
0
        M = aq->M + 2; // 2x4 bits AQ
66
0
    } else {
67
0
        M = aq->M;
68
0
    }
69
0
    init_fastscan(aq, M, 4, nlist, metric, bbs);
70
71
0
    max_train_points = 1024 * ksub * M;
72
0
    by_residual = true;
73
0
}
74
75
/// Convert a populated IndexIVFAdditiveQuantizer into its fast-scan
/// counterpart by repacking every inverted list into the 4-bit block layout.
///
/// The additive quantizer pointer is shared with `orig`, not copied.
///
/// @param orig  source index; for L2 it must not be by_residual
/// @param bbs   block size used to pack the inverted lists
IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan(
        const IndexIVFAdditiveQuantizer& orig,
        int bbs)
        : IndexIVFFastScan(
                  orig.quantizer,
                  orig.d,
                  orig.nlist,
                  0,
                  orig.metric_type),
          aq(orig.aq) {
    // For L2 the fast-scan layout encodes the norm of the full quantized
    // vector (not of the residual), so only non-residual L2 sources can be
    // converted as-is.
    FAISS_THROW_IF_NOT(
            metric_type == METRIC_INNER_PRODUCT || !orig.by_residual);

    init(aq, nlist, metric_type, bbs);

    // init() unconditionally enables by_residual. The codes copied below were
    // produced with orig.by_residual, so the search-time coarse-centroid
    // biases (compute_LUT) must follow the source index; otherwise a
    // non-residual source would get spurious <q, c> terms added to every
    // distance.
    by_residual = orig.by_residual;

    is_trained = orig.is_trained;
    ntotal = orig.ntotal;
    nprobe = orig.nprobe;

    for (size_t i = 0; i < nlist; i++) {
        // Pack list i into blocks of bbs vectors (padded up to a multiple
        // of bbs).
        const size_t nb = orig.invlists->list_size(i);
        const size_t nb2 = roundup(nb, bbs);
        AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
        pq4_pack_codes(
                InvertedLists::ScopedCodes(orig.invlists, i).get(),
                nb,
                M,
                nb2,
                bbs,
                M2,
                tmp.get());
        invlists->add_entries(
                i,
                nb,
                InvertedLists::ScopedIds(orig.invlists, i).get(),
                tmp.get());
    }

    orig_invlists = orig.invlists;
}
115
116
0
/// Default constructor: leaves the index empty and untrained, with no
/// additive quantizer attached (the subclasses in this file re-point `aq`
/// at their own quantizer member).
IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan() {
    aq = nullptr;
    is_trained = false;
    bbs = 0;
    M2 = 0;
}
123
124
0
// Defaulted destructor, defined out of line in this translation unit.
IndexIVFAdditiveQuantizerFastScan::~IndexIVFAdditiveQuantizerFastScan() =
        default;
126
127
/*********************************************************
128
 * Training
129
 *********************************************************/
130
131
0
/// Number of training vectors requested for the fine (additive) quantizer:
/// the `max_train_points` budget (set in init()).
idx_t IndexIVFAdditiveQuantizerFastScan::train_encoder_num_vectors() const {
    const idx_t budget = max_train_points;
    return budget;
}
134
135
void IndexIVFAdditiveQuantizerFastScan::train_encoder(
136
        idx_t n,
137
        const float* x,
138
0
        const idx_t* assign) {
139
0
    if (aq->is_trained) {
140
0
        return;
141
0
    }
142
143
0
    if (verbose) {
144
0
        printf("training additive quantizer on %d vectors\n", int(n));
145
0
    }
146
147
0
    if (verbose) {
148
0
        printf("training %zdx%zd additive quantizer on "
149
0
               "%" PRId64 " vectors in %dD\n",
150
0
               aq->M,
151
0
               ksub,
152
0
               n,
153
0
               d);
154
0
    }
155
0
    aq->verbose = verbose;
156
0
    aq->train(n, x);
157
158
    // train norm quantizer
159
0
    if (by_residual && metric_type == METRIC_L2) {
160
0
        std::vector<float> decoded_x(n * d);
161
0
        std::vector<uint8_t> x_codes(n * aq->code_size);
162
0
        aq->compute_codes(x, x_codes.data(), n);
163
0
        aq->decode(x_codes.data(), decoded_x.data(), n);
164
165
        // add coarse centroids
166
0
        std::vector<float> centroid(d);
167
0
        for (idx_t i = 0; i < n; i++) {
168
0
            auto xi = decoded_x.data() + i * d;
169
0
            quantizer->reconstruct(assign[i], centroid.data());
170
0
            fvec_add(d, centroid.data(), xi, xi);
171
0
        }
172
173
0
        std::vector<float> norms(n, 0);
174
0
        fvec_norms_L2sqr(norms.data(), decoded_x.data(), d, n);
175
176
        // re-train norm tables
177
0
        aq->train_norm(n, norms.data());
178
0
    }
179
180
0
    if (metric_type == METRIC_L2) {
181
0
        estimate_norm_scale(n, x);
182
0
    }
183
0
}
184
185
void IndexIVFAdditiveQuantizerFastScan::estimate_norm_scale(
186
        idx_t n,
187
0
        const float* x_in) {
188
0
    FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
189
190
0
    constexpr int seed = 0x980903;
191
0
    constexpr size_t max_points_estimated = 65536;
192
0
    size_t ns = n;
193
0
    const float* x = fvecs_maybe_subsample(
194
0
            d, &ns, max_points_estimated, x_in, verbose, seed);
195
0
    n = ns;
196
0
    std::unique_ptr<float[]> del_x;
197
0
    if (x != x_in) {
198
0
        del_x.reset((float*)x);
199
0
    }
200
201
0
    std::vector<idx_t> coarse_ids(n);
202
0
    std::vector<float> coarse_dis(n);
203
0
    quantizer->search(n, x, 1, coarse_dis.data(), coarse_ids.data());
204
205
0
    AlignedTable<float> dis_tables;
206
0
    AlignedTable<float> biases;
207
208
0
    size_t index_nprobe = nprobe;
209
0
    nprobe = 1;
210
0
    CoarseQuantized cq{index_nprobe, coarse_dis.data(), coarse_ids.data()};
211
0
    compute_LUT(n, x, cq, dis_tables, biases);
212
0
    nprobe = index_nprobe;
213
214
0
    float scale = 0;
215
216
0
#pragma omp parallel for reduction(+ : scale)
217
0
    for (idx_t i = 0; i < n; i++) {
218
0
        const float* lut = dis_tables.get() + i * M * ksub;
219
0
        scale += quantize_lut::aq_estimate_norm_scale(M, ksub, 2, lut);
220
0
    }
221
0
    scale /= n;
222
0
    norm_scale = (int)std::roundf(std::max(scale, 1.0f));
223
224
0
    if (verbose) {
225
0
        printf("estimated norm scale: %lf\n", scale);
226
0
        printf("rounded norm scale: %d\n", norm_scale);
227
0
    }
228
0
}
229
230
/*********************************************************
231
 * Code management functions
232
 *********************************************************/
233
234
void IndexIVFAdditiveQuantizerFastScan::encode_vectors(
235
        idx_t n,
236
        const float* x,
237
        const idx_t* list_nos,
238
        uint8_t* codes,
239
0
        bool include_listnos) const {
240
0
    idx_t bs = 65536;
241
0
    if (n > bs) {
242
0
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
243
0
            idx_t i1 = std::min(n, i0 + bs);
244
0
            encode_vectors(
245
0
                    i1 - i0,
246
0
                    x + i0 * d,
247
0
                    list_nos + i0,
248
0
                    codes + i0 * code_size,
249
0
                    include_listnos);
250
0
        }
251
0
        return;
252
0
    }
253
254
0
    if (by_residual) {
255
0
        std::vector<float> residuals(n * d);
256
0
        std::vector<float> centroids(n * d);
257
258
0
#pragma omp parallel for if (n > 1000)
259
0
        for (idx_t i = 0; i < n; i++) {
260
0
            if (list_nos[i] < 0) {
261
0
                memset(residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
262
0
            } else {
263
0
                quantizer->compute_residual(
264
0
                        x + i * d, residuals.data() + i * d, list_nos[i]);
265
0
            }
266
0
        }
267
268
0
#pragma omp parallel for if (n > 1000)
269
0
        for (idx_t i = 0; i < n; i++) {
270
0
            auto c = centroids.data() + i * d;
271
0
            quantizer->reconstruct(list_nos[i], c);
272
0
        }
273
274
0
        aq->compute_codes_add_centroids(
275
0
                residuals.data(), codes, n, centroids.data());
276
277
0
    } else {
278
0
        aq->compute_codes(x, codes, n);
279
0
    }
280
281
0
    if (include_listnos) {
282
0
        size_t coarse_size = coarse_code_size();
283
0
        for (idx_t i = n - 1; i >= 0; i--) {
284
0
            uint8_t* code = codes + i * (coarse_size + code_size);
285
0
            memmove(code + coarse_size, codes + i * code_size, code_size);
286
0
            encode_listno(list_nos[i], code);
287
0
        }
288
0
    }
289
0
}
290
291
/*********************************************************
292
 * Search functions
293
 *********************************************************/
294
295
void IndexIVFAdditiveQuantizerFastScan::search(
296
        idx_t n,
297
        const float* x,
298
        idx_t k,
299
        float* distances,
300
        idx_t* labels,
301
0
        const SearchParameters* params) const {
302
0
    FAISS_THROW_IF_NOT_MSG(
303
0
            !params, "search params not supported for this index");
304
305
0
    FAISS_THROW_IF_NOT(k > 0);
306
0
    bool rescale = (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2);
307
0
    if (!rescale) {
308
0
        IndexIVFFastScan::search(n, x, k, distances, labels);
309
0
        return;
310
0
    }
311
312
0
    NormTableScaler scaler(norm_scale);
313
0
    IndexIVFFastScan::CoarseQuantized cq{nprobe};
314
0
    search_dispatch_implem(n, x, k, distances, labels, cq, &scaler);
315
0
}
316
317
/*********************************************************
318
 * Look-Up Table functions
319
 *********************************************************/
320
321
/********************************************************
322
323
Let q denote the query vector,
324
    x denote the quantized database vector,
325
    c denote the corresponding IVF centroid,
326
    r denote the residual (x - c).
327
328
The L2 distance between q and x is:
329
330
    d(q, x) = (q - x)^2
331
            = (q - c - r)^2
332
            = q^2 - 2<q, c> - 2<q, r> + x^2
333
334
where q^2 is constant for all x, <q,c> depends only on c,
335
and x^2 is the quantized database vector norm.
336
337
Different from IVFAdditiveQuantizer, we encode the quantized vector norm x^2
338
instead of r^2, so that we only need to compute one LUT for each query vector:
339
340
    LUT[m][k] = -2 * <q, codebooks[m][k]>
341
342
`-2<q,c>` can be precomputed in `compute_LUT` and stored in `biases`.
343
if `by_residual=False`, `<q,c>` is simply 0.
344
345
346
347
About norm look-up tables:
348
349
To take advantage of the fast SIMD table lookups, we encode the norm by a 2x4
350
bits 1D additive quantizer (simply treat the scalar norm as a 1D vector).
351
352
Let `cm` denote the codebooks of the trained 2x4 bits 1D additive quantizer,
353
size (2, 16); `bm` denote the encoding code of the norm, an 8-bit integer; `cb`
354
denote the codebooks of the additive quantizer to encode the database vector,
355
size (M, 16).
356
357
The decoded norm is:
358
359
    decoded_norm = cm[0][bm & 15] + cm[1][bm >> 4]
360
361
The decoding is actually doing a table look-up.
362
363
We combine the norm LUTs and the IP LUTs together:
364
365
    LUT is a 2D table, size (M + 2, 16)
366
    if m < M :
367
        LUT[m][k] = -2 * <q, cb[m][k]>
368
    else:
369
        LUT[m][k] = cm[m - M][k]
370
371
********************************************************/
372
373
0
/// compute_LUT produces one table per query (n * M * ksub floats), shared by
/// all probed lists. Per-list terms go into `biases`, so the tables are 2D.
bool IndexIVFAdditiveQuantizerFastScan::lookup_table_is_3d() const {
    constexpr bool kIs3d = false;
    return kIs3d;
}
376
377
void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
378
        size_t n,
379
        const float* x,
380
        const CoarseQuantized& cq,
381
        AlignedTable<float>& dis_tables,
382
0
        AlignedTable<float>& biases) const {
383
0
    const size_t dim12 = ksub * M;
384
0
    const size_t ip_dim12 = aq->M * ksub;
385
0
    const size_t nprobe = cq.nprobe;
386
387
0
    dis_tables.resize(n * dim12);
388
389
0
    float coef = 1.0f;
390
0
    if (metric_type == METRIC_L2) {
391
0
        coef = -2.0f;
392
0
    }
393
394
0
    if (by_residual) {
395
        // bias = coef * <q, c>
396
        // NOTE: q^2 is not added to `biases`
397
0
        biases.resize(n * nprobe);
398
0
#pragma omp parallel
399
0
        {
400
0
            std::vector<float> centroid(d);
401
0
            float* c = centroid.data();
402
403
0
#pragma omp for
404
0
            for (idx_t ij = 0; ij < n * nprobe; ij++) {
405
0
                int i = ij / nprobe;
406
0
                quantizer->reconstruct(cq.ids[ij], c);
407
0
                biases[ij] = coef * fvec_inner_product(c, x + i * d, d);
408
0
            }
409
0
        }
410
0
    }
411
412
0
    if (metric_type == METRIC_L2) {
413
0
        const size_t norm_dim12 = 2 * ksub;
414
415
        // inner product look-up tables
416
0
        aq->compute_LUT(n, x, dis_tables.data(), -2.0f, dim12);
417
418
        // copy and rescale norm look-up tables
419
0
        auto norm_tabs = aq->norm_tabs;
420
0
        if (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2) {
421
0
            for (size_t i = 0; i < norm_tabs.size(); i++) {
422
0
                norm_tabs[i] /= norm_scale;
423
0
            }
424
0
        }
425
0
        const float* norm_lut = norm_tabs.data();
426
0
        FAISS_THROW_IF_NOT(norm_tabs.size() == norm_dim12);
427
428
        // combine them
429
0
#pragma omp parallel for if (n > 100)
430
0
        for (idx_t i = 0; i < n; i++) {
431
0
            float* tab = dis_tables.data() + i * dim12 + ip_dim12;
432
0
            memcpy(tab, norm_lut, norm_dim12 * sizeof(*tab));
433
0
        }
434
435
0
    } else if (metric_type == METRIC_INNER_PRODUCT) {
436
0
        aq->compute_LUT(n, x, dis_tables.get());
437
0
    } else {
438
0
        FAISS_THROW_FMT("metric %d not supported", metric_type);
439
0
    }
440
0
}
441
442
/********** IndexIVFLocalSearchQuantizerFastScan ************/
443
/// Fast-scan IVF index backed by its own `lsq` member quantizer, built from
/// (d, M, nbits, search_type). Only nbits == 4 is accepted.
IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type,
        int bbs)
        : IndexIVFAdditiveQuantizerFastScan(
                  quantizer,
                  nullptr, // deferred: lsq does not exist yet
                  d,
                  nlist,
                  metric,
                  bbs),
          lsq(d, M, nbits, search_type) {
    FAISS_THROW_IF_NOT(nbits == 4);
    this->init(&this->lsq, nlist, metric, bbs);
}

/// Default constructor: points the generic quantizer pointer at `lsq`.
IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan() {
    this->aq = &this->lsq;
}
467
468
/********** IndexIVFResidualQuantizerFastScan ************/
469
/// Fast-scan IVF index backed by its own `rq` member quantizer, built from
/// (d, M, nbits, search_type). Only nbits == 4 is accepted.
IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type,
        int bbs)
        : IndexIVFAdditiveQuantizerFastScan(
                  quantizer,
                  nullptr, // deferred: rq does not exist yet
                  d,
                  nlist,
                  metric,
                  bbs),
          rq(d, M, nbits, search_type) {
    FAISS_THROW_IF_NOT(nbits == 4);
    this->init(&this->rq, nlist, metric, bbs);
}

/// Default constructor: points the generic quantizer pointer at `rq`.
IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan() {
    this->aq = &this->rq;
}
493
494
/********** IndexIVFProductLocalSearchQuantizerFastScan ************/
495
/// Fast-scan IVF index backed by its own `plsq` member quantizer, built from
/// (d, nsplits, Msub, nbits, search_type). Only nbits == 4 is accepted.
IndexIVFProductLocalSearchQuantizerFastScan::
        IndexIVFProductLocalSearchQuantizerFastScan(
                Index* quantizer,
                size_t d,
                size_t nlist,
                size_t nsplits,
                size_t Msub,
                size_t nbits,
                MetricType metric,
                Search_type_t search_type,
                int bbs)
        : IndexIVFAdditiveQuantizerFastScan(
                  quantizer,
                  nullptr, // deferred: plsq does not exist yet
                  d,
                  nlist,
                  metric,
                  bbs),
          plsq(d, nsplits, Msub, nbits, search_type) {
    FAISS_THROW_IF_NOT(nbits == 4);
    this->init(&this->plsq, nlist, metric, bbs);
}

/// Default constructor: points the generic quantizer pointer at `plsq`.
IndexIVFProductLocalSearchQuantizerFastScan::
        IndexIVFProductLocalSearchQuantizerFastScan() {
    this->aq = &this->plsq;
}
522
523
/********** IndexIVFProductResidualQuantizerFastScan ************/
524
/// Fast-scan IVF index backed by its own `prq` member quantizer, built from
/// (d, nsplits, Msub, nbits, search_type). Only nbits == 4 is accepted.
IndexIVFProductResidualQuantizerFastScan::
        IndexIVFProductResidualQuantizerFastScan(
                Index* quantizer,
                size_t d,
                size_t nlist,
                size_t nsplits,
                size_t Msub,
                size_t nbits,
                MetricType metric,
                Search_type_t search_type,
                int bbs)
        : IndexIVFAdditiveQuantizerFastScan(
                  quantizer,
                  nullptr, // deferred: prq does not exist yet
                  d,
                  nlist,
                  metric,
                  bbs),
          prq(d, nsplits, Msub, nbits, search_type) {
    FAISS_THROW_IF_NOT(nbits == 4);
    this->init(&this->prq, nlist, metric, bbs);
}

/// Default constructor: points the generic quantizer pointer at `prq`.
IndexIVFProductResidualQuantizerFastScan::
        IndexIVFProductResidualQuantizerFastScan() {
    this->aq = &this->prq;
}
551
552
} // namespace faiss