/root/doris/contrib/faiss/faiss/IndexIVFIndependentQuantizer.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexIVFIndependentQuantizer.h> |
9 | | #include <faiss/IndexIVFPQ.h> |
10 | | #include <faiss/impl/FaissAssert.h> |
11 | | #include <faiss/utils/utils.h> |
12 | | |
13 | | namespace faiss { |
14 | | |
15 | | IndexIVFIndependentQuantizer::IndexIVFIndependentQuantizer( |
16 | | Index* quantizer, |
17 | | IndexIVF* index_ivf, |
18 | | VectorTransform* vt) |
19 | 0 | : Index(quantizer->d, index_ivf->metric_type), |
20 | 0 | quantizer(quantizer), |
21 | 0 | vt(vt), |
22 | 0 | index_ivf(index_ivf) { |
23 | 0 | if (vt) { |
24 | 0 | FAISS_THROW_IF_NOT_MSG( |
25 | 0 | vt->d_in == d && vt->d_out == index_ivf->d, |
26 | 0 | "invalid vector dimensions"); |
27 | 0 | } else { |
28 | 0 | FAISS_THROW_IF_NOT_MSG(index_ivf->d == d, "invalid vector dimensions"); |
29 | 0 | } |
30 | | |
31 | 0 | if (quantizer->is_trained && quantizer->ntotal != 0) { |
32 | 0 | FAISS_THROW_IF_NOT(quantizer->ntotal == index_ivf->nlist); |
33 | 0 | } |
34 | 0 | if (index_ivf->is_trained && vt) { |
35 | 0 | FAISS_THROW_IF_NOT(vt->is_trained); |
36 | 0 | } |
37 | 0 | ntotal = index_ivf->ntotal; |
38 | 0 | is_trained = |
39 | 0 | (quantizer->is_trained && quantizer->ntotal == index_ivf->nlist && |
40 | 0 | (!vt || vt->is_trained) && index_ivf->is_trained); |
41 | | |
42 | | // disable precomputed tables because they use the distances that are |
43 | | // provided by the coarse quantizer (that are out of sync with the IVFPQ) |
44 | 0 | if (auto index_ivfpq = dynamic_cast<IndexIVFPQ*>(index_ivf)) { |
45 | 0 | index_ivfpq->use_precomputed_table = -1; |
46 | 0 | } |
47 | 0 | } |
48 | | |
49 | 0 | IndexIVFIndependentQuantizer::~IndexIVFIndependentQuantizer() { |
50 | 0 | if (own_fields) { |
51 | 0 | delete quantizer; |
52 | 0 | delete index_ivf; |
53 | 0 | delete vt; |
54 | 0 | } |
55 | 0 | } |
56 | | |
57 | | namespace { |
58 | | |
59 | | struct VTransformedVectors : TransformedVectors { |
60 | | VTransformedVectors(const VectorTransform* vt, idx_t n, const float* x) |
61 | 0 | : TransformedVectors(x, vt ? vt->apply(n, x) : x) {} |
62 | | }; |
63 | | |
64 | | struct SubsampledVectors : TransformedVectors { |
65 | | SubsampledVectors(int d, idx_t* n, idx_t max_n, const float* x) |
66 | 0 | : TransformedVectors( |
67 | 0 | x, |
68 | 0 | fvecs_maybe_subsample(d, (size_t*)n, max_n, x, true)) {} |
69 | | }; |
70 | | |
71 | | } // anonymous namespace |
72 | | |
73 | 0 | void IndexIVFIndependentQuantizer::add(idx_t n, const float* x) { |
74 | 0 | std::vector<float> D(n); |
75 | 0 | std::vector<idx_t> I(n); |
76 | 0 | quantizer->search(n, x, 1, D.data(), I.data()); |
77 | |
|
78 | 0 | VTransformedVectors tv(vt, n, x); |
79 | |
|
80 | 0 | index_ivf->add_core(n, tv.x, nullptr, I.data()); |
81 | 0 | } |
82 | | |
83 | | void IndexIVFIndependentQuantizer::search( |
84 | | idx_t n, |
85 | | const float* x, |
86 | | idx_t k, |
87 | | float* distances, |
88 | | idx_t* labels, |
89 | 0 | const SearchParameters* params) const { |
90 | 0 | FAISS_THROW_IF_NOT_MSG(!params, "search parameters not supported"); |
91 | 0 | int nprobe = index_ivf->nprobe; |
92 | 0 | std::vector<float> D(n * nprobe); |
93 | 0 | std::vector<idx_t> I(n * nprobe); |
94 | 0 | quantizer->search(n, x, nprobe, D.data(), I.data()); |
95 | |
|
96 | 0 | VTransformedVectors tv(vt, n, x); |
97 | |
|
98 | 0 | index_ivf->search_preassigned( |
99 | 0 | n, tv.x, k, I.data(), D.data(), distances, labels, false); |
100 | 0 | } |
101 | | |
102 | 0 | void IndexIVFIndependentQuantizer::reset() { |
103 | 0 | index_ivf->reset(); |
104 | 0 | ntotal = 0; |
105 | 0 | } |
106 | | |
107 | 0 | void IndexIVFIndependentQuantizer::train(idx_t n, const float* x) { |
108 | | // quantizer training |
109 | 0 | size_t nlist = index_ivf->nlist; |
110 | 0 | Level1Quantizer l1(quantizer, nlist); |
111 | 0 | l1.train_q1(n, x, verbose, metric_type); |
112 | | |
113 | | // train the VectorTransform |
114 | 0 | if (vt && !vt->is_trained) { |
115 | 0 | if (verbose) { |
116 | 0 | printf("IndexIVFIndependentQuantizer: train the VectorTransform\n"); |
117 | 0 | } |
118 | 0 | vt->train(n, x); |
119 | 0 | } |
120 | | |
121 | | // get the centroids from the quantizer, transform them and |
122 | | // add them to the index_ivf's quantizer |
123 | 0 | if (verbose) { |
124 | 0 | printf("IndexIVFIndependentQuantizer: extract the main quantizer centroids\n"); |
125 | 0 | } |
126 | 0 | std::vector<float> centroids(nlist * d); |
127 | 0 | quantizer->reconstruct_n(0, nlist, centroids.data()); |
128 | 0 | VTransformedVectors tcent(vt, nlist, centroids.data()); |
129 | |
|
130 | 0 | if (verbose) { |
131 | 0 | printf("IndexIVFIndependentQuantizer: add centroids to the secondary quantizer\n"); |
132 | 0 | } |
133 | 0 | if (!index_ivf->quantizer->is_trained) { |
134 | 0 | index_ivf->quantizer->train(nlist, tcent.x); |
135 | 0 | } |
136 | 0 | index_ivf->quantizer->add(nlist, tcent.x); |
137 | | |
138 | | // train the payload |
139 | | |
140 | | // optional subsampling |
141 | 0 | idx_t max_nt = index_ivf->train_encoder_num_vectors(); |
142 | 0 | if (max_nt <= 0) { |
143 | 0 | max_nt = (size_t)1 << 35; |
144 | 0 | } |
145 | 0 | SubsampledVectors sv(index_ivf->d, &n, max_nt, x); |
146 | | |
147 | | // transform subsampled vectors |
148 | 0 | VTransformedVectors tv(vt, n, sv.x); |
149 | |
|
150 | 0 | if (verbose) { |
151 | 0 | printf("IndexIVFIndependentQuantizer: train encoder\n"); |
152 | 0 | } |
153 | |
|
154 | 0 | if (index_ivf->by_residual) { |
155 | | // assign with quantizer |
156 | 0 | std::vector<idx_t> assign(n); |
157 | 0 | quantizer->assign(n, sv.x, assign.data()); |
158 | | |
159 | | // compute residual with IVF quantizer |
160 | 0 | std::vector<float> residuals(n * index_ivf->d); |
161 | 0 | index_ivf->quantizer->compute_residual_n( |
162 | 0 | n, tv.x, residuals.data(), assign.data()); |
163 | |
|
164 | 0 | index_ivf->train_encoder(n, residuals.data(), assign.data()); |
165 | 0 | } else { |
166 | 0 | index_ivf->train_encoder(n, tv.x, nullptr); |
167 | 0 | } |
168 | 0 | index_ivf->is_trained = true; |
169 | 0 | is_trained = true; |
170 | 0 | } |
171 | | |
172 | | } // namespace faiss |