/root/doris/contrib/faiss/faiss/IndexIVFRaBitQ.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexIVFRaBitQ.h> |
9 | | |
10 | | #include <omp.h> |
11 | | |
12 | | #include <cstddef> |
13 | | #include <cstdint> |
14 | | #include <memory> |
15 | | #include <vector> |
16 | | |
17 | | #include <faiss/impl/FaissAssert.h> |
18 | | #include <faiss/impl/RaBitQuantizer.h> |
19 | | |
20 | | namespace faiss { |
21 | | |
22 | | IndexIVFRaBitQ::IndexIVFRaBitQ( |
23 | | Index* quantizer, |
24 | | const size_t d, |
25 | | const size_t nlist, |
26 | | MetricType metric) |
27 | 0 | : IndexIVF(quantizer, d, nlist, 0, metric), rabitq(d, metric) { |
28 | 0 | code_size = rabitq.code_size; |
29 | 0 | invlists->code_size = code_size; |
30 | 0 | is_trained = false; |
31 | |
|
32 | 0 | by_residual = true; |
33 | 0 | } |
34 | | |
35 | 0 | IndexIVFRaBitQ::IndexIVFRaBitQ() { |
36 | 0 | by_residual = true; |
37 | 0 | } |
38 | | |
39 | | void IndexIVFRaBitQ::train_encoder( |
40 | | idx_t n, |
41 | | const float* x, |
42 | 0 | const idx_t* assign) { |
43 | 0 | rabitq.train(n, x); |
44 | 0 | } |
45 | | |
46 | | void IndexIVFRaBitQ::encode_vectors( |
47 | | idx_t n, |
48 | | const float* x, |
49 | | const idx_t* list_nos, |
50 | | uint8_t* codes, |
51 | 0 | bool include_listnos) const { |
52 | 0 | size_t coarse_size = include_listnos ? coarse_code_size() : 0; |
53 | 0 | memset(codes, 0, (code_size + coarse_size) * n); |
54 | |
|
55 | 0 | #pragma omp parallel if (n > 1000) |
56 | 0 | { |
57 | 0 | std::vector<float> centroid(d); |
58 | |
|
59 | 0 | #pragma omp for |
60 | 0 | for (idx_t i = 0; i < n; i++) { |
61 | 0 | int64_t list_no = list_nos[i]; |
62 | 0 | if (list_no >= 0) { |
63 | 0 | const float* xi = x + i * d; |
64 | 0 | uint8_t* code = codes + i * (code_size + coarse_size); |
65 | | |
66 | | // both by_residual and !by_residual lead to the same code |
67 | 0 | quantizer->reconstruct(list_no, centroid.data()); |
68 | 0 | rabitq.compute_codes_core( |
69 | 0 | xi, code + coarse_size, 1, centroid.data()); |
70 | |
|
71 | 0 | if (coarse_size) { |
72 | 0 | encode_listno(list_no, code); |
73 | 0 | } |
74 | 0 | } |
75 | 0 | } |
76 | 0 | } |
77 | 0 | } |
78 | | |
79 | | void IndexIVFRaBitQ::add_core( |
80 | | idx_t n, |
81 | | const float* x, |
82 | | const idx_t* xids, |
83 | | const idx_t* precomputed_idx, |
84 | 0 | void* inverted_list_context) { |
85 | 0 | FAISS_THROW_IF_NOT(is_trained); |
86 | | |
87 | 0 | DirectMapAdd dm_add(direct_map, n, xids); |
88 | |
|
89 | 0 | #pragma omp parallel |
90 | 0 | { |
91 | 0 | std::vector<uint8_t> one_code(code_size); |
92 | 0 | std::vector<float> centroid(d); |
93 | |
|
94 | 0 | int nt = omp_get_num_threads(); |
95 | 0 | int rank = omp_get_thread_num(); |
96 | | |
97 | | // each thread takes care of a subset of lists |
98 | 0 | for (size_t i = 0; i < n; i++) { |
99 | 0 | int64_t list_no = precomputed_idx[i]; |
100 | 0 | if (list_no >= 0 && list_no % nt == rank) { |
101 | 0 | int64_t id = xids ? xids[i] : ntotal + i; |
102 | |
|
103 | 0 | const float* xi = x + i * d; |
104 | | |
105 | | // both by_residual and !by_residual lead to the same code |
106 | 0 | quantizer->reconstruct(list_no, centroid.data()); |
107 | 0 | rabitq.compute_codes_core( |
108 | 0 | xi, one_code.data(), 1, centroid.data()); |
109 | |
|
110 | 0 | size_t ofs = invlists->add_entry( |
111 | 0 | list_no, id, one_code.data(), inverted_list_context); |
112 | |
|
113 | 0 | dm_add.add(i, list_no, ofs); |
114 | |
|
115 | 0 | } else if (rank == 0 && list_no == -1) { |
116 | 0 | dm_add.add(i, -1, 0); |
117 | 0 | } |
118 | 0 | } |
119 | 0 | } |
120 | |
|
121 | 0 | ntotal += n; |
122 | 0 | } |
123 | | |
124 | | struct RaBitInvertedListScanner : InvertedListScanner { |
125 | | const IndexIVFRaBitQ& ivf_rabitq; |
126 | | |
127 | | std::vector<float> reconstructed_centroid; |
128 | | std::vector<float> query_vector; |
129 | | |
130 | | std::unique_ptr<FlatCodesDistanceComputer> dc; |
131 | | |
132 | | uint8_t qb = 0; |
133 | | |
134 | | RaBitInvertedListScanner( |
135 | | const IndexIVFRaBitQ& ivf_rabitq_in, |
136 | | bool store_pairs = false, |
137 | | const IDSelector* sel = nullptr, |
138 | | uint8_t qb_in = 0) |
139 | 0 | : InvertedListScanner(store_pairs, sel), |
140 | 0 | ivf_rabitq{ivf_rabitq_in}, |
141 | 0 | qb{qb_in} { |
142 | 0 | keep_max = is_similarity_metric(ivf_rabitq.metric_type); |
143 | 0 | code_size = ivf_rabitq.code_size; |
144 | 0 | } |
145 | | |
146 | | /// from now on we handle this query. |
147 | 0 | void set_query(const float* query_vector_in) override { |
148 | 0 | query_vector.assign(query_vector_in, query_vector_in + ivf_rabitq.d); |
149 | |
|
150 | 0 | internal_try_setup_dc(); |
151 | 0 | } |
152 | | |
153 | | /// following codes come from this inverted list |
154 | 0 | void set_list(idx_t list_no, float coarse_dis) override { |
155 | 0 | this->list_no = list_no; |
156 | |
|
157 | 0 | reconstructed_centroid.resize(ivf_rabitq.d); |
158 | 0 | ivf_rabitq.quantizer->reconstruct( |
159 | 0 | list_no, reconstructed_centroid.data()); |
160 | |
|
161 | 0 | internal_try_setup_dc(); |
162 | 0 | } |
163 | | |
164 | | /// compute a single query-to-code distance |
165 | 0 | float distance_to_code(const uint8_t* code) const override { |
166 | 0 | return dc->distance_to_code(code); |
167 | 0 | } |
168 | | |
169 | 0 | void internal_try_setup_dc() { |
170 | 0 | if (!query_vector.empty() && !reconstructed_centroid.empty()) { |
171 | | // both query_vector and centroid are available! |
172 | | // set up DistanceComputer |
173 | 0 | dc.reset(ivf_rabitq.rabitq.get_distance_computer( |
174 | 0 | qb, reconstructed_centroid.data())); |
175 | |
|
176 | 0 | dc->set_query(query_vector.data()); |
177 | 0 | } |
178 | 0 | } |
179 | | }; |
180 | | |
181 | | InvertedListScanner* IndexIVFRaBitQ::get_InvertedListScanner( |
182 | | bool store_pairs, |
183 | | const IDSelector* sel, |
184 | 0 | const IVFSearchParameters* search_params_in) const { |
185 | 0 | uint8_t used_qb = qb; |
186 | 0 | if (auto params = dynamic_cast<const IVFRaBitQSearchParameters*>( |
187 | 0 | search_params_in)) { |
188 | 0 | used_qb = params->qb; |
189 | 0 | } |
190 | |
|
191 | 0 | return new RaBitInvertedListScanner(*this, store_pairs, sel, used_qb); |
192 | 0 | } |
193 | | |
194 | | void IndexIVFRaBitQ::reconstruct_from_offset( |
195 | | int64_t list_no, |
196 | | int64_t offset, |
197 | 0 | float* recons) const { |
198 | 0 | const uint8_t* code = invlists->get_single_code(list_no, offset); |
199 | |
|
200 | 0 | std::vector<float> centroid(d); |
201 | 0 | quantizer->reconstruct(list_no, centroid.data()); |
202 | |
|
203 | 0 | rabitq.decode_core(code, recons, 1, centroid.data()); |
204 | 0 | } |
205 | | |
206 | 0 | void IndexIVFRaBitQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const { |
207 | 0 | size_t coarse_size = coarse_code_size(); |
208 | |
|
209 | 0 | #pragma omp parallel |
210 | 0 | { |
211 | 0 | std::vector<float> centroid(d); |
212 | |
|
213 | 0 | #pragma omp for |
214 | 0 | for (idx_t i = 0; i < n; i++) { |
215 | 0 | const uint8_t* code = codes + i * (code_size + coarse_size); |
216 | 0 | int64_t list_no = decode_listno(code); |
217 | 0 | float* xi = x + i * d; |
218 | |
|
219 | 0 | quantizer->reconstruct(list_no, centroid.data()); |
220 | 0 | rabitq.decode_core(code + coarse_size, xi, 1, centroid.data()); |
221 | 0 | } |
222 | 0 | } |
223 | 0 | } |
224 | | |
225 | | struct IVFRaBitDistanceComputer : DistanceComputer { |
226 | | const float* q = nullptr; |
227 | | const IndexIVFRaBitQ* parent = nullptr; |
228 | | |
229 | | void set_query(const float* x) override; |
230 | | |
231 | | float operator()(idx_t i) override; |
232 | | |
233 | | float symmetric_dis(idx_t i, idx_t j) override; |
234 | | }; |
235 | | |
236 | 0 | void IVFRaBitDistanceComputer::set_query(const float* x) { |
237 | 0 | q = x; |
238 | 0 | } |
239 | | |
240 | 0 | float IVFRaBitDistanceComputer::operator()(idx_t i) { |
241 | | // find the appropriate list |
242 | 0 | idx_t lo = parent->direct_map.get(i); |
243 | 0 | uint64_t list_no = lo_listno(lo); |
244 | 0 | uint64_t offset = lo_offset(lo); |
245 | |
|
246 | 0 | const uint8_t* code = parent->invlists->get_single_code(list_no, offset); |
247 | | |
248 | | // ok, we know the appropriate cluster that we need |
249 | 0 | std::vector<float> centroid(parent->d); |
250 | 0 | parent->quantizer->reconstruct(list_no, centroid.data()); |
251 | | |
252 | | // compute the distance |
253 | 0 | float distance = 0; |
254 | |
|
255 | 0 | std::unique_ptr<FlatCodesDistanceComputer> dc( |
256 | 0 | parent->rabitq.get_distance_computer(parent->qb, centroid.data())); |
257 | 0 | dc->set_query(q); |
258 | 0 | distance = dc->distance_to_code(code); |
259 | | |
260 | | // deallocate |
261 | 0 | parent->invlists->release_codes(list_no, code); |
262 | | |
263 | | // done |
264 | 0 | return distance; |
265 | 0 | } |
266 | | |
267 | 0 | float IVFRaBitDistanceComputer::symmetric_dis(idx_t i, idx_t j) { |
268 | 0 | FAISS_THROW_MSG("Not implemented"); |
269 | 0 | } |
270 | | |
271 | 0 | DistanceComputer* IndexIVFRaBitQ::get_distance_computer() const { |
272 | 0 | IVFRaBitDistanceComputer* dc = new IVFRaBitDistanceComputer; |
273 | 0 | dc->parent = this; |
274 | 0 | return dc; |
275 | 0 | } |
276 | | |
277 | | } // namespace faiss |