Coverage Report

Created: 2025-10-24 10:16

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexIVFIndependentQuantizer.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexIVFIndependentQuantizer.h>
9
#include <faiss/IndexIVFPQ.h>
10
#include <faiss/impl/FaissAssert.h>
11
#include <faiss/utils/utils.h>
12
13
namespace faiss {
14
15
/**
 * Build a wrapper around an IVF index whose list assignment comes from a
 * separate quantizer.
 *
 * @param quantizer  quantizer whose dimension becomes the dimension of this
 *                   index
 * @param index_ivf  wrapped IVF index; provides the metric type and ntotal
 * @param vt         optional transform (may be nullptr); when present its
 *                   input dimension must equal this index's dimension and
 *                   its output dimension must equal index_ivf->d
 *
 * The FAISS_THROW_* checks below reject inconsistent dimensions and
 * inconsistent training states.
 */
IndexIVFIndependentQuantizer::IndexIVFIndependentQuantizer(
        Index* quantizer,
        IndexIVF* index_ivf,
        VectorTransform* vt)
        : Index(quantizer->d, index_ivf->metric_type),
          quantizer(quantizer),
          vt(vt),
          index_ivf(index_ivf) {
    // dimension consistency: without a transform both indexes must share d
    if (vt == nullptr) {
        FAISS_THROW_IF_NOT_MSG(index_ivf->d == d, "invalid vector dimensions");
    } else {
        FAISS_THROW_IF_NOT_MSG(
                vt->d_in == d && vt->d_out == index_ivf->d,
                "invalid vector dimensions");
    }

    // a trained, non-empty quantizer must hold exactly one entry per list
    const bool quantizer_ready = quantizer->is_trained;
    if (quantizer_ready && quantizer->ntotal != 0) {
        FAISS_THROW_IF_NOT(quantizer->ntotal == index_ivf->nlist);
    }

    // a trained IVF index implies that its transform was trained as well
    if (vt != nullptr && index_ivf->is_trained) {
        FAISS_THROW_IF_NOT(vt->is_trained);
    }

    ntotal = index_ivf->ntotal;

    // trained only if every component is trained and the quantizer is full
    const bool centroids_complete =
            quantizer_ready && quantizer->ntotal == index_ivf->nlist;
    const bool transform_ready = (vt == nullptr) || vt->is_trained;
    is_trained =
            centroids_complete && transform_ready && index_ivf->is_trained;

    // disable precomputed tables because they use the distances that are
    // provided by the coarse quantizer (that are out of sync with the IVFPQ)
    if (auto* ivfpq = dynamic_cast<IndexIVFPQ*>(index_ivf)) {
        ivfpq->use_precomputed_table = -1;
    }
}
48
49
0
/// Release the three sub-objects, but only when this index owns them.
IndexIVFIndependentQuantizer::~IndexIVFIndependentQuantizer() {
    if (!own_fields) {
        return; // borrowed pointers: nothing to free
    }
    delete quantizer;
    delete index_ivf;
    delete vt;
}
56
57
namespace {
58
59
struct VTransformedVectors : TransformedVectors {
60
    VTransformedVectors(const VectorTransform* vt, idx_t n, const float* x)
61
0
            : TransformedVectors(x, vt ? vt->apply(n, x) : x) {}
62
};
63
64
struct SubsampledVectors : TransformedVectors {
65
    SubsampledVectors(int d, idx_t* n, idx_t max_n, const float* x)
66
0
            : TransformedVectors(
67
0
                      x,
68
0
                      fvecs_maybe_subsample(d, (size_t*)n, max_n, x, true)) {}
69
};
70
71
} // anonymous namespace
72
73
0
/**
 * Add n vectors to the wrapped IVF index.
 *
 * The list of each vector is the nearest neighbor returned by the independent
 * quantizer on the untransformed input; the vectors handed to index_ivf are
 * the (optionally) transformed ones.
 *
 * @param n  number of vectors
 * @param x  input vectors, in the dimension of this index
 */
void IndexIVFIndependentQuantizer::add(idx_t n, const float* x) {
    // 1-NN search in the independent quantizer gives the list assignment;
    // the distances are required by the search API but not used afterwards
    std::vector<float> D(n);
    std::vector<idx_t> I(n);
    quantizer->search(n, x, 1, D.data(), I.data());

    VTransformedVectors tv(vt, n, x);

    index_ivf->add_core(n, tv.x, nullptr, I.data());

    // keep this index's counter in sync with the wrapped index, mirroring
    // what the constructor and reset() do; previously it stayed stale
    ntotal = index_ivf->ntotal;
}
82
83
void IndexIVFIndependentQuantizer::search(
84
        idx_t n,
85
        const float* x,
86
        idx_t k,
87
        float* distances,
88
        idx_t* labels,
89
0
        const SearchParameters* params) const {
90
0
    FAISS_THROW_IF_NOT_MSG(!params, "search parameters not supported");
91
0
    int nprobe = index_ivf->nprobe;
92
0
    std::vector<float> D(n * nprobe);
93
0
    std::vector<idx_t> I(n * nprobe);
94
0
    quantizer->search(n, x, nprobe, D.data(), I.data());
95
96
0
    VTransformedVectors tv(vt, n, x);
97
98
0
    index_ivf->search_preassigned(
99
0
            n, tv.x, k, I.data(), D.data(), distances, labels, false);
100
0
}
101
102
0
void IndexIVFIndependentQuantizer::reset() {
103
0
    index_ivf->reset();
104
0
    ntotal = 0;
105
0
}
106
107
0
/**
 * Train all components on n vectors.
 *
 * Steps, in order:
 *  1. train the independent quantizer with index_ivf->nlist lists,
 *  2. train the VectorTransform if present and not yet trained,
 *  3. reconstruct the quantizer's centroids, transform them and add them to
 *     index_ivf's own quantizer,
 *  4. train index_ivf's encoder on (optionally subsampled) transformed
 *     vectors, or on their residuals when index_ivf->by_residual is set.
 *
 * @param n  number of training vectors
 * @param x  training vectors, in the dimension d of this index
 */
void IndexIVFIndependentQuantizer::train(idx_t n, const float* x) {
    // quantizer training
    size_t nlist = index_ivf->nlist;
    Level1Quantizer l1(quantizer, nlist);
    l1.train_q1(n, x, verbose, metric_type);

    // train the VectorTransform
    if (vt && !vt->is_trained) {
        if (verbose) {
            printf("IndexIVFIndependentQuantizer: train the VectorTransform\n");
        }
        vt->train(n, x);
    }

    // get the centroids from the quantizer, transform them and
    // add them to the index_ivf's quantizer
    if (verbose) {
        printf("IndexIVFIndependentQuantizer: extract the main quantizer centroids\n");
    }
    std::vector<float> centroids(nlist * d);
    quantizer->reconstruct_n(0, nlist, centroids.data());
    VTransformedVectors tcent(vt, nlist, centroids.data());

    if (verbose) {
        printf("IndexIVFIndependentQuantizer: add centroids to the secondary quantizer\n");
    }
    if (!index_ivf->quantizer->is_trained) {
        index_ivf->quantizer->train(nlist, tcent.x);
    }
    // NOTE(review): this add is unconditional; calling train() a second time
    // presumably appends the centroids again -- confirm against callers.
    index_ivf->quantizer->add(nlist, tcent.x);

    // train the payload

    // optional subsampling
    idx_t max_nt = index_ivf->train_encoder_num_vectors();
    if (max_nt <= 0) {
        // no limit requested by the encoder: use a very large cap
        max_nt = (size_t)1 << 35;
    }
    // x is still in the input space here (it is fed to quantizer->assign and
    // to the transform below), so rows must be sliced with this index's d.
    // Using index_ivf->d would be wrong whenever vt changes the dimension.
    SubsampledVectors sv(d, &n, max_nt, x);

    // transform subsampled vectors
    VTransformedVectors tv(vt, n, sv.x);

    if (verbose) {
        printf("IndexIVFIndependentQuantizer: train encoder\n");
    }

    if (index_ivf->by_residual) {
        // assign with quantizer
        std::vector<idx_t> assign(n);
        quantizer->assign(n, sv.x, assign.data());

        // compute residual with IVF quantizer
        std::vector<float> residuals(n * index_ivf->d);
        index_ivf->quantizer->compute_residual_n(
                n, tv.x, residuals.data(), assign.data());

        index_ivf->train_encoder(n, residuals.data(), assign.data());
    } else {
        index_ivf->train_encoder(n, tv.x, nullptr);
    }
    index_ivf->is_trained = true;
    is_trained = true;
}
171
172
} // namespace faiss