/root/doris/contrib/faiss/faiss/IndexIVF.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/IndexIVF.h> |
11 | | |
12 | | #include <omp.h> |
13 | | #include <cstdint> |
14 | | #include <memory> |
15 | | #include <mutex> |
16 | | |
17 | | #include <algorithm> |
18 | | #include <cinttypes> |
19 | | #include <cstdio> |
20 | | #include <limits> |
21 | | |
22 | | #include <faiss/utils/hamming.h> |
23 | | #include <faiss/utils/utils.h> |
24 | | |
25 | | #include <faiss/IndexFlat.h> |
26 | | #include <faiss/impl/AuxIndexStructures.h> |
27 | | #include <faiss/impl/CodePacker.h> |
28 | | #include <faiss/impl/FaissAssert.h> |
29 | | #include <faiss/impl/IDSelector.h> |
30 | | |
31 | | namespace faiss { |
32 | | |
33 | | using ScopedIds = InvertedLists::ScopedIds; |
34 | | using ScopedCodes = InvertedLists::ScopedCodes; |
35 | | |
36 | | /***************************************** |
37 | | * Level1Quantizer implementation |
38 | | ******************************************/ |
39 | | |
40 | | Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist) |
41 | 0 | : quantizer(quantizer), nlist(nlist) { |
42 | | // here we set a low # iterations because this is typically used |
43 | | // for large clusterings (nb this is not used for the MultiIndex, |
44 | | // for which quantizer_trains_alone = true) |
45 | 0 | cp.niter = 10; |
46 | 0 | } |
47 | | |
48 | 0 | Level1Quantizer::Level1Quantizer() = default; |
49 | | |
50 | 0 | Level1Quantizer::~Level1Quantizer() { |
51 | 0 | if (own_fields) { |
52 | 0 | delete quantizer; |
53 | 0 | } |
54 | 0 | } |
55 | | |
56 | | void Level1Quantizer::train_q1( |
57 | | size_t n, |
58 | | const float* x, |
59 | | bool verbose, |
60 | 0 | MetricType metric_type) { |
61 | 0 | size_t d = quantizer->d; |
62 | 0 | if (quantizer->is_trained && (quantizer->ntotal == nlist)) { |
63 | 0 | if (verbose) |
64 | 0 | printf("IVF quantizer does not need training.\n"); |
65 | 0 | } else if (quantizer_trains_alone == 1) { |
66 | 0 | if (verbose) |
67 | 0 | printf("IVF quantizer trains alone...\n"); |
68 | 0 | quantizer->verbose = verbose; |
69 | 0 | quantizer->train(n, x); |
70 | 0 | FAISS_THROW_IF_NOT_MSG( |
71 | 0 | quantizer->ntotal == nlist, |
72 | 0 | "nlist not consistent with quantizer size"); |
73 | 0 | } else if (quantizer_trains_alone == 0) { |
74 | 0 | if (verbose) |
75 | 0 | printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d); |
76 | |
|
77 | 0 | Clustering clus(d, nlist, cp); |
78 | 0 | quantizer->reset(); |
79 | 0 | if (clustering_index) { |
80 | 0 | clus.train(n, x, *clustering_index); |
81 | 0 | quantizer->add(nlist, clus.centroids.data()); |
82 | 0 | } else { |
83 | 0 | clus.train(n, x, *quantizer); |
84 | 0 | } |
85 | 0 | quantizer->is_trained = true; |
86 | 0 | } else if (quantizer_trains_alone == 2) { |
87 | 0 | if (verbose) { |
88 | 0 | printf("Training L2 quantizer on %zd vectors in %zdD%s\n", |
89 | 0 | n, |
90 | 0 | d, |
91 | 0 | clustering_index ? "(user provided index)" : ""); |
92 | 0 | } |
93 | | // also accept spherical centroids because in that case |
94 | | // L2 and IP are equivalent |
95 | 0 | FAISS_THROW_IF_NOT( |
96 | 0 | metric_type == METRIC_L2 || |
97 | 0 | (metric_type == METRIC_INNER_PRODUCT && cp.spherical)); |
98 | | |
99 | 0 | Clustering clus(d, nlist, cp); |
100 | 0 | if (!clustering_index) { |
101 | 0 | IndexFlatL2 assigner(d); |
102 | 0 | clus.train(n, x, assigner); |
103 | 0 | } else { |
104 | 0 | clus.train(n, x, *clustering_index); |
105 | 0 | } |
106 | 0 | if (verbose) { |
107 | 0 | printf("Adding centroids to quantizer\n"); |
108 | 0 | } |
109 | 0 | if (!quantizer->is_trained) { |
110 | 0 | if (verbose) { |
111 | 0 | printf("But training it first on centroids table...\n"); |
112 | 0 | } |
113 | 0 | quantizer->train(nlist, clus.centroids.data()); |
114 | 0 | } |
115 | 0 | quantizer->add(nlist, clus.centroids.data()); |
116 | 0 | } |
117 | 0 | } |
118 | | |
119 | 0 | size_t Level1Quantizer::coarse_code_size() const { |
120 | 0 | size_t nl = nlist - 1; |
121 | 0 | size_t nbyte = 0; |
122 | 0 | while (nl > 0) { |
123 | 0 | nbyte++; |
124 | 0 | nl >>= 8; |
125 | 0 | } |
126 | 0 | return nbyte; |
127 | 0 | } |
128 | | |
129 | 0 | void Level1Quantizer::encode_listno(idx_t list_no, uint8_t* code) const { |
130 | | // little endian |
131 | 0 | size_t nl = nlist - 1; |
132 | 0 | while (nl > 0) { |
133 | 0 | *code++ = list_no & 0xff; |
134 | 0 | list_no >>= 8; |
135 | 0 | nl >>= 8; |
136 | 0 | } |
137 | 0 | } |
138 | | |
139 | 0 | idx_t Level1Quantizer::decode_listno(const uint8_t* code) const { |
140 | 0 | size_t nl = nlist - 1; |
141 | 0 | int64_t list_no = 0; |
142 | 0 | int nbit = 0; |
143 | 0 | while (nl > 0) { |
144 | 0 | list_no |= int64_t(*code++) << nbit; |
145 | 0 | nbit += 8; |
146 | 0 | nl >>= 8; |
147 | 0 | } |
148 | 0 | FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist); |
149 | 0 | return list_no; |
150 | 0 | } |
151 | | |
152 | | /***************************************** |
153 | | * IndexIVF implementation |
154 | | ******************************************/ |
155 | | |
156 | | IndexIVF::IndexIVF( |
157 | | Index* quantizer, |
158 | | size_t d, |
159 | | size_t nlist, |
160 | | size_t code_size, |
161 | | MetricType metric) |
162 | 0 | : Index(d, metric), |
163 | 0 | IndexIVFInterface(quantizer, nlist), |
164 | 0 | invlists(new ArrayInvertedLists(nlist, code_size)), |
165 | 0 | own_invlists(true), |
166 | 0 | code_size(code_size) { |
167 | 0 | FAISS_THROW_IF_NOT(d == quantizer->d); |
168 | 0 | is_trained = quantizer->is_trained && (quantizer->ntotal == nlist); |
169 | | // Spherical by default if the metric is inner_product |
170 | 0 | if (metric_type == METRIC_INNER_PRODUCT) { |
171 | 0 | cp.spherical = true; |
172 | 0 | } |
173 | 0 | } |
174 | | |
175 | 0 | IndexIVF::IndexIVF() = default; |
176 | | |
177 | 0 | void IndexIVF::add(idx_t n, const float* x) { |
178 | 0 | add_with_ids(n, x, nullptr); |
179 | 0 | } |
180 | | |
181 | 0 | void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) { |
182 | 0 | std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]); |
183 | 0 | quantizer->assign(n, x, coarse_idx.get()); |
184 | 0 | add_core(n, x, xids, coarse_idx.get()); |
185 | 0 | } |
186 | | |
187 | 0 | void IndexIVF::add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids) { |
188 | 0 | size_t coarse_size = coarse_code_size(); |
189 | 0 | DirectMapAdd dm_adder(direct_map, n, xids); |
190 | |
|
191 | 0 | for (idx_t i = 0; i < n; i++) { |
192 | 0 | const uint8_t* code = codes + (code_size + coarse_size) * i; |
193 | 0 | idx_t list_no = decode_listno(code); |
194 | 0 | idx_t id = xids ? xids[i] : ntotal + i; |
195 | 0 | size_t ofs = invlists->add_entry(list_no, id, code + coarse_size); |
196 | 0 | dm_adder.add(i, list_no, ofs); |
197 | 0 | } |
198 | 0 | ntotal += n; |
199 | 0 | } |
200 | | |
201 | | void IndexIVF::add_core( |
202 | | idx_t n, |
203 | | const float* x, |
204 | | const idx_t* xids, |
205 | | const idx_t* coarse_idx, |
206 | 0 | void* inverted_list_context) { |
207 | | // do some blocking to avoid excessive allocs |
208 | 0 | idx_t bs = 65536; |
209 | 0 | if (n > bs) { |
210 | 0 | for (idx_t i0 = 0; i0 < n; i0 += bs) { |
211 | 0 | idx_t i1 = std::min(n, i0 + bs); |
212 | 0 | if (verbose) { |
213 | 0 | printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n", |
214 | 0 | i0, |
215 | 0 | i1); |
216 | 0 | } |
217 | 0 | add_core( |
218 | 0 | i1 - i0, |
219 | 0 | x + i0 * d, |
220 | 0 | xids ? xids + i0 : nullptr, |
221 | 0 | coarse_idx + i0, |
222 | 0 | inverted_list_context); |
223 | 0 | } |
224 | 0 | return; |
225 | 0 | } |
226 | 0 | FAISS_THROW_IF_NOT(coarse_idx); |
227 | 0 | FAISS_THROW_IF_NOT(is_trained); |
228 | 0 | direct_map.check_can_add(xids); |
229 | |
|
230 | 0 | size_t nadd = 0, nminus1 = 0; |
231 | |
|
232 | 0 | for (size_t i = 0; i < n; i++) { |
233 | 0 | if (coarse_idx[i] < 0) |
234 | 0 | nminus1++; |
235 | 0 | } |
236 | |
|
237 | 0 | std::unique_ptr<uint8_t[]> flat_codes(new uint8_t[n * code_size]); |
238 | 0 | encode_vectors(n, x, coarse_idx, flat_codes.get()); |
239 | |
|
240 | 0 | DirectMapAdd dm_adder(direct_map, n, xids); |
241 | |
|
242 | 0 | #pragma omp parallel reduction(+ : nadd) |
243 | 0 | { |
244 | 0 | int nt = omp_get_num_threads(); |
245 | 0 | int rank = omp_get_thread_num(); |
246 | | |
247 | | // each thread takes care of a subset of lists |
248 | 0 | for (size_t i = 0; i < n; i++) { |
249 | 0 | idx_t list_no = coarse_idx[i]; |
250 | 0 | if (list_no >= 0 && list_no % nt == rank) { |
251 | 0 | idx_t id = xids ? xids[i] : ntotal + i; |
252 | 0 | size_t ofs = invlists->add_entry( |
253 | 0 | list_no, |
254 | 0 | id, |
255 | 0 | flat_codes.get() + i * code_size, |
256 | 0 | inverted_list_context); |
257 | |
|
258 | 0 | dm_adder.add(i, list_no, ofs); |
259 | |
|
260 | 0 | nadd++; |
261 | 0 | } else if (rank == 0 && list_no == -1) { |
262 | 0 | dm_adder.add(i, -1, 0); |
263 | 0 | } |
264 | 0 | } |
265 | 0 | } |
266 | |
|
267 | 0 | if (verbose) { |
268 | 0 | printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n", |
269 | 0 | nadd, |
270 | 0 | n, |
271 | 0 | nminus1); |
272 | 0 | } |
273 | |
|
274 | 0 | ntotal += n; |
275 | 0 | } |
276 | | |
277 | 0 | void IndexIVF::make_direct_map(bool b) { |
278 | 0 | if (b) { |
279 | 0 | direct_map.set_type(DirectMap::Array, invlists, ntotal); |
280 | 0 | } else { |
281 | 0 | direct_map.set_type(DirectMap::NoMap, invlists, ntotal); |
282 | 0 | } |
283 | 0 | } |
284 | | |
285 | 0 | void IndexIVF::set_direct_map_type(DirectMap::Type type) { |
286 | 0 | direct_map.set_type(type, invlists, ntotal); |
287 | 0 | } |
288 | | |
289 | | /** It is a sad fact of software that a conceptually simple function like this |
290 | | * becomes very complex when you factor in several ways of parallelizing + |
291 | | * interrupt/error handling + collecting stats + min/max collection. The |
292 | | * codepath that is used 95% of time is the one for parallel_mode = 0 */ |
293 | | void IndexIVF::search( |
294 | | idx_t n, |
295 | | const float* x, |
296 | | idx_t k, |
297 | | float* distances, |
298 | | idx_t* labels, |
299 | 0 | const SearchParameters* params_in) const { |
300 | 0 | FAISS_THROW_IF_NOT(k > 0); |
301 | 0 | const IVFSearchParameters* params = nullptr; |
302 | 0 | if (params_in) { |
303 | 0 | params = dynamic_cast<const IVFSearchParameters*>(params_in); |
304 | 0 | FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type"); |
305 | 0 | } |
306 | 0 | const size_t nprobe = |
307 | 0 | std::min(nlist, params ? params->nprobe : this->nprobe); |
308 | 0 | FAISS_THROW_IF_NOT(nprobe > 0); |
309 | | |
310 | | // search function for a subset of queries |
311 | 0 | auto sub_search_func = [this, k, nprobe, params]( |
312 | 0 | idx_t n, |
313 | 0 | const float* x, |
314 | 0 | float* distances, |
315 | 0 | idx_t* labels, |
316 | 0 | IndexIVFStats* ivf_stats) { |
317 | 0 | std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]); |
318 | 0 | std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]); |
319 | |
|
320 | 0 | double t0 = getmillisecs(); |
321 | 0 | quantizer->search( |
322 | 0 | n, |
323 | 0 | x, |
324 | 0 | nprobe, |
325 | 0 | coarse_dis.get(), |
326 | 0 | idx.get(), |
327 | 0 | params ? params->quantizer_params : nullptr); |
328 | |
|
329 | 0 | double t1 = getmillisecs(); |
330 | 0 | invlists->prefetch_lists(idx.get(), n * nprobe); |
331 | |
|
332 | 0 | search_preassigned( |
333 | 0 | n, |
334 | 0 | x, |
335 | 0 | k, |
336 | 0 | idx.get(), |
337 | 0 | coarse_dis.get(), |
338 | 0 | distances, |
339 | 0 | labels, |
340 | 0 | false, |
341 | 0 | params, |
342 | 0 | ivf_stats); |
343 | 0 | double t2 = getmillisecs(); |
344 | 0 | ivf_stats->quantization_time += t1 - t0; |
345 | 0 | ivf_stats->search_time += t2 - t0; |
346 | 0 | }; |
347 | |
|
348 | 0 | if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) { |
349 | 0 | int nt = std::min(omp_get_max_threads(), int(n)); |
350 | 0 | std::vector<IndexIVFStats> stats(nt); |
351 | 0 | std::mutex exception_mutex; |
352 | 0 | std::string exception_string; |
353 | |
|
354 | 0 | #pragma omp parallel for if (nt > 1) |
355 | 0 | for (idx_t slice = 0; slice < nt; slice++) { |
356 | 0 | IndexIVFStats local_stats; |
357 | 0 | idx_t i0 = n * slice / nt; |
358 | 0 | idx_t i1 = n * (slice + 1) / nt; |
359 | 0 | if (i1 > i0) { |
360 | 0 | try { |
361 | 0 | sub_search_func( |
362 | 0 | i1 - i0, |
363 | 0 | x + i0 * d, |
364 | 0 | distances + i0 * k, |
365 | 0 | labels + i0 * k, |
366 | 0 | &stats[slice]); |
367 | 0 | } catch (const std::exception& e) { |
368 | 0 | std::lock_guard<std::mutex> lock(exception_mutex); |
369 | 0 | exception_string = e.what(); |
370 | 0 | } |
371 | 0 | } |
372 | 0 | } |
373 | |
|
374 | 0 | if (!exception_string.empty()) { |
375 | 0 | FAISS_THROW_MSG(exception_string.c_str()); |
376 | 0 | } |
377 | | |
378 | | // collect stats |
379 | 0 | for (idx_t slice = 0; slice < nt; slice++) { |
380 | 0 | indexIVF_stats.add(stats[slice]); |
381 | 0 | } |
382 | 0 | } else { |
383 | | // handle parallelization at level below (or don't run in parallel at |
384 | | // all) |
385 | 0 | sub_search_func(n, x, distances, labels, &indexIVF_stats); |
386 | 0 | } |
387 | 0 | } |
388 | | |
389 | | void IndexIVF::search_preassigned( |
390 | | idx_t n, |
391 | | const float* x, |
392 | | idx_t k, |
393 | | const idx_t* keys, |
394 | | const float* coarse_dis, |
395 | | float* distances, |
396 | | idx_t* labels, |
397 | | bool store_pairs, |
398 | | const IVFSearchParameters* params, |
399 | 0 | IndexIVFStats* ivf_stats) const { |
400 | 0 | FAISS_THROW_IF_NOT(k > 0); |
401 | | |
402 | 0 | idx_t nprobe = params ? params->nprobe : this->nprobe; |
403 | 0 | nprobe = std::min((idx_t)nlist, nprobe); |
404 | 0 | FAISS_THROW_IF_NOT(nprobe > 0); |
405 | | |
406 | 0 | const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max(); |
407 | 0 | idx_t max_codes = params ? params->max_codes : this->max_codes; |
408 | 0 | IDSelector* sel = params ? params->sel : nullptr; |
409 | 0 | const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel); |
410 | 0 | if (selr) { |
411 | 0 | if (selr->assume_sorted) { |
412 | 0 | sel = nullptr; // use special IDSelectorRange processing |
413 | 0 | } else { |
414 | 0 | selr = nullptr; // use generic processing |
415 | 0 | } |
416 | 0 | } |
417 | |
|
418 | 0 | FAISS_THROW_IF_NOT_MSG( |
419 | 0 | !(sel && store_pairs), |
420 | 0 | "selector and store_pairs cannot be combined"); |
421 | | |
422 | 0 | FAISS_THROW_IF_NOT_MSG( |
423 | 0 | !invlists->use_iterator || (max_codes == 0 && store_pairs == false), |
424 | 0 | "iterable inverted lists don't support max_codes and store_pairs"); |
425 | | |
426 | 0 | size_t nlistv = 0, ndis = 0, nheap = 0; |
427 | |
|
428 | 0 | using HeapForIP = CMin<float, idx_t>; |
429 | 0 | using HeapForL2 = CMax<float, idx_t>; |
430 | |
|
431 | 0 | bool interrupt = false; |
432 | 0 | std::mutex exception_mutex; |
433 | 0 | std::string exception_string; |
434 | |
|
435 | 0 | int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT; |
436 | 0 | bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT); |
437 | |
|
438 | 0 | FAISS_THROW_IF_NOT_MSG( |
439 | 0 | max_codes == 0 || pmode == 0 || pmode == 3, |
440 | 0 | "max_codes supported only for parallel_mode = 0 or 3"); |
441 | | |
442 | 0 | if (max_codes == 0) { |
443 | 0 | max_codes = unlimited_list_size; |
444 | 0 | } |
445 | |
|
446 | 0 | [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 && |
447 | 0 | (pmode == 0 ? false |
448 | 0 | : pmode == 3 ? n > 1 |
449 | 0 | : pmode == 1 ? nprobe > 1 |
450 | 0 | : nprobe * n > 1); |
451 | |
|
452 | 0 | void* inverted_list_context = |
453 | 0 | params ? params->inverted_list_context : nullptr; |
454 | |
|
455 | 0 | #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap) |
456 | 0 | { |
457 | 0 | std::unique_ptr<InvertedListScanner> scanner( |
458 | 0 | get_InvertedListScanner(store_pairs, sel, params)); |
459 | | |
460 | | /***************************************************** |
461 | | * Depending on parallel_mode, there are two possible ways |
462 | | * to organize the search. Here we define local functions |
463 | | * that are in common between the two |
464 | | ******************************************************/ |
465 | | |
466 | | // initialize + reorder a result heap |
467 | |
|
468 | 0 | auto init_result = [&](float* simi, idx_t* idxi) { |
469 | 0 | if (!do_heap_init) |
470 | 0 | return; |
471 | 0 | if (metric_type == METRIC_INNER_PRODUCT) { |
472 | 0 | heap_heapify<HeapForIP>(k, simi, idxi); |
473 | 0 | } else { |
474 | 0 | heap_heapify<HeapForL2>(k, simi, idxi); |
475 | 0 | } |
476 | 0 | }; |
477 | |
|
478 | 0 | auto add_local_results = [&](const float* local_dis, |
479 | 0 | const idx_t* local_idx, |
480 | 0 | float* simi, |
481 | 0 | idx_t* idxi) { |
482 | 0 | if (metric_type == METRIC_INNER_PRODUCT) { |
483 | 0 | heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k); |
484 | 0 | } else { |
485 | 0 | heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k); |
486 | 0 | } |
487 | 0 | }; |
488 | |
|
489 | 0 | auto reorder_result = [&](float* simi, idx_t* idxi) { |
490 | 0 | if (!do_heap_init) |
491 | 0 | return; |
492 | 0 | if (metric_type == METRIC_INNER_PRODUCT) { |
493 | 0 | heap_reorder<HeapForIP>(k, simi, idxi); |
494 | 0 | } else { |
495 | 0 | heap_reorder<HeapForL2>(k, simi, idxi); |
496 | 0 | } |
497 | 0 | }; |
498 | | |
499 | | // single list scan using the current scanner (with query |
500 | | // set porperly) and storing results in simi and idxi |
501 | 0 | auto scan_one_list = [&](idx_t key, |
502 | 0 | float coarse_dis_i, |
503 | 0 | float* simi, |
504 | 0 | idx_t* idxi, |
505 | 0 | idx_t list_size_max) { |
506 | 0 | if (key < 0) { |
507 | | // not enough centroids for multiprobe |
508 | 0 | return (size_t)0; |
509 | 0 | } |
510 | 0 | FAISS_THROW_IF_NOT_FMT( |
511 | 0 | key < (idx_t)nlist, |
512 | 0 | "Invalid key=%" PRId64 " nlist=%zd\n", |
513 | 0 | key, |
514 | 0 | nlist); |
515 | | |
516 | | // don't waste time on empty lists |
517 | 0 | if (invlists->is_empty(key, inverted_list_context)) { |
518 | 0 | return (size_t)0; |
519 | 0 | } |
520 | | |
521 | 0 | scanner->set_list(key, coarse_dis_i); |
522 | |
|
523 | 0 | nlistv++; |
524 | |
|
525 | 0 | try { |
526 | 0 | if (invlists->use_iterator) { |
527 | 0 | size_t list_size = 0; |
528 | |
|
529 | 0 | std::unique_ptr<InvertedListsIterator> it( |
530 | 0 | invlists->get_iterator(key, inverted_list_context)); |
531 | |
|
532 | 0 | nheap += scanner->iterate_codes( |
533 | 0 | it.get(), simi, idxi, k, list_size); |
534 | |
|
535 | 0 | return list_size; |
536 | 0 | } else { |
537 | 0 | size_t list_size = invlists->list_size(key); |
538 | 0 | if (list_size > list_size_max) { |
539 | 0 | list_size = list_size_max; |
540 | 0 | } |
541 | |
|
542 | 0 | InvertedLists::ScopedCodes scodes(invlists, key); |
543 | 0 | const uint8_t* codes = scodes.get(); |
544 | |
|
545 | 0 | std::unique_ptr<InvertedLists::ScopedIds> sids; |
546 | 0 | const idx_t* ids = nullptr; |
547 | |
|
548 | 0 | if (!store_pairs) { |
549 | 0 | sids = std::make_unique<InvertedLists::ScopedIds>( |
550 | 0 | invlists, key); |
551 | 0 | ids = sids->get(); |
552 | 0 | } |
553 | |
|
554 | 0 | if (selr) { // IDSelectorRange |
555 | | // restrict search to a section of the inverted list |
556 | 0 | size_t jmin, jmax; |
557 | 0 | selr->find_sorted_ids_bounds( |
558 | 0 | list_size, ids, &jmin, &jmax); |
559 | 0 | list_size = jmax - jmin; |
560 | 0 | if (list_size == 0) { |
561 | 0 | return (size_t)0; |
562 | 0 | } |
563 | 0 | codes += jmin * code_size; |
564 | 0 | ids += jmin; |
565 | 0 | } |
566 | | |
567 | 0 | nheap += scanner->scan_codes( |
568 | 0 | list_size, codes, ids, simi, idxi, k); |
569 | |
|
570 | 0 | return list_size; |
571 | 0 | } |
572 | 0 | } catch (const std::exception& e) { |
573 | 0 | std::lock_guard<std::mutex> lock(exception_mutex); |
574 | 0 | exception_string = |
575 | 0 | demangle_cpp_symbol(typeid(e).name()) + " " + e.what(); |
576 | 0 | interrupt = true; |
577 | 0 | return size_t(0); |
578 | 0 | } |
579 | 0 | }; |
580 | | |
581 | | /**************************************************** |
582 | | * Actual loops, depending on parallel_mode |
583 | | ****************************************************/ |
584 | |
|
585 | 0 | if (pmode == 0 || pmode == 3) { |
586 | 0 | #pragma omp for |
587 | 0 | for (idx_t i = 0; i < n; i++) { |
588 | 0 | if (interrupt) { |
589 | 0 | continue; |
590 | 0 | } |
591 | | |
592 | | // loop over queries |
593 | 0 | scanner->set_query(x + i * d); |
594 | 0 | float* simi = distances + i * k; |
595 | 0 | idx_t* idxi = labels + i * k; |
596 | |
|
597 | 0 | init_result(simi, idxi); |
598 | |
|
599 | 0 | idx_t nscan = 0; |
600 | | |
601 | | // loop over probes |
602 | 0 | for (size_t ik = 0; ik < nprobe; ik++) { |
603 | 0 | nscan += scan_one_list( |
604 | 0 | keys[i * nprobe + ik], |
605 | 0 | coarse_dis[i * nprobe + ik], |
606 | 0 | simi, |
607 | 0 | idxi, |
608 | 0 | max_codes - nscan); |
609 | 0 | if (nscan >= max_codes) { |
610 | 0 | break; |
611 | 0 | } |
612 | 0 | } |
613 | |
|
614 | 0 | ndis += nscan; |
615 | 0 | reorder_result(simi, idxi); |
616 | |
|
617 | 0 | if (InterruptCallback::is_interrupted()) { |
618 | 0 | interrupt = true; |
619 | 0 | } |
620 | |
|
621 | 0 | } // parallel for |
622 | 0 | } else if (pmode == 1) { |
623 | 0 | std::vector<idx_t> local_idx(k); |
624 | 0 | std::vector<float> local_dis(k); |
625 | |
|
626 | 0 | for (size_t i = 0; i < n; i++) { |
627 | 0 | scanner->set_query(x + i * d); |
628 | 0 | init_result(local_dis.data(), local_idx.data()); |
629 | |
|
630 | 0 | #pragma omp for schedule(dynamic) |
631 | 0 | for (idx_t ik = 0; ik < nprobe; ik++) { |
632 | 0 | ndis += scan_one_list( |
633 | 0 | keys[i * nprobe + ik], |
634 | 0 | coarse_dis[i * nprobe + ik], |
635 | 0 | local_dis.data(), |
636 | 0 | local_idx.data(), |
637 | 0 | unlimited_list_size); |
638 | | |
639 | | // can't do the test on max_codes |
640 | 0 | } |
641 | | // merge thread-local results |
642 | |
|
643 | 0 | float* simi = distances + i * k; |
644 | 0 | idx_t* idxi = labels + i * k; |
645 | 0 | #pragma omp single |
646 | 0 | init_result(simi, idxi); |
647 | |
|
648 | 0 | #pragma omp barrier |
649 | 0 | #pragma omp critical |
650 | 0 | { |
651 | 0 | add_local_results( |
652 | 0 | local_dis.data(), local_idx.data(), simi, idxi); |
653 | 0 | } |
654 | 0 | #pragma omp barrier |
655 | 0 | #pragma omp single |
656 | 0 | reorder_result(simi, idxi); |
657 | 0 | } |
658 | 0 | } else if (pmode == 2) { |
659 | 0 | std::vector<idx_t> local_idx(k); |
660 | 0 | std::vector<float> local_dis(k); |
661 | |
|
662 | 0 | #pragma omp single |
663 | 0 | for (int64_t i = 0; i < n; i++) { |
664 | 0 | init_result(distances + i * k, labels + i * k); |
665 | 0 | } |
666 | |
|
667 | 0 | #pragma omp for schedule(dynamic) |
668 | 0 | for (int64_t ij = 0; ij < n * nprobe; ij++) { |
669 | 0 | size_t i = ij / nprobe; |
670 | |
|
671 | 0 | scanner->set_query(x + i * d); |
672 | 0 | init_result(local_dis.data(), local_idx.data()); |
673 | 0 | ndis += scan_one_list( |
674 | 0 | keys[ij], |
675 | 0 | coarse_dis[ij], |
676 | 0 | local_dis.data(), |
677 | 0 | local_idx.data(), |
678 | 0 | unlimited_list_size); |
679 | 0 | #pragma omp critical |
680 | 0 | { |
681 | 0 | add_local_results( |
682 | 0 | local_dis.data(), |
683 | 0 | local_idx.data(), |
684 | 0 | distances + i * k, |
685 | 0 | labels + i * k); |
686 | 0 | } |
687 | 0 | } |
688 | 0 | #pragma omp single |
689 | 0 | for (int64_t i = 0; i < n; i++) { |
690 | 0 | reorder_result(distances + i * k, labels + i * k); |
691 | 0 | } |
692 | 0 | } else { |
693 | 0 | FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode); |
694 | 0 | } |
695 | 0 | } // parallel section |
696 | |
|
697 | 0 | if (interrupt) { |
698 | 0 | if (!exception_string.empty()) { |
699 | 0 | FAISS_THROW_FMT( |
700 | 0 | "search interrupted with: %s", exception_string.c_str()); |
701 | 0 | } else { |
702 | 0 | FAISS_THROW_MSG("computation interrupted"); |
703 | 0 | } |
704 | 0 | } |
705 | | |
706 | 0 | if (ivf_stats == nullptr) { |
707 | 0 | ivf_stats = &indexIVF_stats; |
708 | 0 | } |
709 | 0 | ivf_stats->nq += n; |
710 | 0 | ivf_stats->nlist += nlistv; |
711 | 0 | ivf_stats->ndis += ndis; |
712 | 0 | ivf_stats->nheap_updates += nheap; |
713 | 0 | } |
714 | | |
715 | | void IndexIVF::range_search( |
716 | | idx_t nx, |
717 | | const float* x, |
718 | | float radius, |
719 | | RangeSearchResult* result, |
720 | 0 | const SearchParameters* params_in) const { |
721 | 0 | const IVFSearchParameters* params = nullptr; |
722 | 0 | const SearchParameters* quantizer_params = nullptr; |
723 | 0 | if (params_in) { |
724 | 0 | params = dynamic_cast<const IVFSearchParameters*>(params_in); |
725 | 0 | FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type"); |
726 | 0 | quantizer_params = params->quantizer_params; |
727 | 0 | } |
728 | 0 | const size_t nprobe = |
729 | 0 | std::min(nlist, params ? params->nprobe : this->nprobe); |
730 | 0 | std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]); |
731 | 0 | std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]); |
732 | |
|
733 | 0 | double t0 = getmillisecs(); |
734 | 0 | quantizer->search( |
735 | 0 | nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params); |
736 | 0 | indexIVF_stats.quantization_time += getmillisecs() - t0; |
737 | |
|
738 | 0 | t0 = getmillisecs(); |
739 | 0 | invlists->prefetch_lists(keys.get(), nx * nprobe); |
740 | |
|
741 | 0 | range_search_preassigned( |
742 | 0 | nx, |
743 | 0 | x, |
744 | 0 | radius, |
745 | 0 | keys.get(), |
746 | 0 | coarse_dis.get(), |
747 | 0 | result, |
748 | 0 | false, |
749 | 0 | params, |
750 | 0 | &indexIVF_stats); |
751 | |
|
752 | 0 | indexIVF_stats.search_time += getmillisecs() - t0; |
753 | 0 | } |
754 | | |
755 | | void IndexIVF::range_search_preassigned( |
756 | | idx_t nx, |
757 | | const float* x, |
758 | | float radius, |
759 | | const idx_t* keys, |
760 | | const float* coarse_dis, |
761 | | RangeSearchResult* result, |
762 | | bool store_pairs, |
763 | | const IVFSearchParameters* params, |
764 | 0 | IndexIVFStats* stats) const { |
765 | 0 | idx_t nprobe = params ? params->nprobe : this->nprobe; |
766 | 0 | nprobe = std::min((idx_t)nlist, nprobe); |
767 | 0 | FAISS_THROW_IF_NOT(nprobe > 0); |
768 | | |
769 | 0 | idx_t max_codes = params ? params->max_codes : this->max_codes; |
770 | 0 | IDSelector* sel = params ? params->sel : nullptr; |
771 | |
|
772 | 0 | FAISS_THROW_IF_NOT_MSG( |
773 | 0 | !invlists->use_iterator || (max_codes == 0 && store_pairs == false), |
774 | 0 | "iterable inverted lists don't support max_codes and store_pairs"); |
775 | | |
776 | 0 | size_t nlistv = 0, ndis = 0; |
777 | |
|
778 | 0 | bool interrupt = false; |
779 | 0 | std::mutex exception_mutex; |
780 | 0 | std::string exception_string; |
781 | |
|
782 | 0 | std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads()); |
783 | |
|
784 | 0 | int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT; |
785 | | // don't start parallel section if single query |
786 | 0 | [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 && |
787 | 0 | (pmode == 3 ? false |
788 | 0 | : pmode == 0 ? nx > 1 |
789 | 0 | : pmode == 1 ? nprobe > 1 |
790 | 0 | : nprobe * nx > 1); |
791 | |
|
792 | 0 | void* inverted_list_context = |
793 | 0 | params ? params->inverted_list_context : nullptr; |
794 | |
|
795 | 0 | #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis) |
796 | 0 | { |
797 | 0 | RangeSearchPartialResult pres(result); |
798 | 0 | std::unique_ptr<InvertedListScanner> scanner( |
799 | 0 | get_InvertedListScanner(store_pairs, sel, params)); |
800 | 0 | FAISS_THROW_IF_NOT(scanner.get()); |
801 | 0 | all_pres[omp_get_thread_num()] = &pres; |
802 | | |
803 | | // prepare the list scanning function |
804 | |
|
805 | 0 | auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) { |
806 | 0 | idx_t key = keys[i * nprobe + ik]; /* select the list */ |
807 | 0 | if (key < 0) |
808 | 0 | return; |
809 | 0 | FAISS_THROW_IF_NOT_FMT( |
810 | 0 | key < (idx_t)nlist, |
811 | 0 | "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n", |
812 | 0 | key, |
813 | 0 | ik, |
814 | 0 | nlist); |
815 | | |
816 | 0 | if (invlists->is_empty(key, inverted_list_context)) { |
817 | 0 | return; |
818 | 0 | } |
819 | | |
820 | 0 | try { |
821 | 0 | size_t list_size = 0; |
822 | 0 | scanner->set_list(key, coarse_dis[i * nprobe + ik]); |
823 | 0 | if (invlists->use_iterator) { |
824 | 0 | std::unique_ptr<InvertedListsIterator> it( |
825 | 0 | invlists->get_iterator(key, inverted_list_context)); |
826 | |
|
827 | 0 | scanner->iterate_codes_range( |
828 | 0 | it.get(), radius, qres, list_size); |
829 | 0 | } else { |
830 | 0 | InvertedLists::ScopedCodes scodes(invlists, key); |
831 | 0 | InvertedLists::ScopedIds ids(invlists, key); |
832 | 0 | list_size = invlists->list_size(key); |
833 | |
|
834 | 0 | scanner->scan_codes_range( |
835 | 0 | list_size, scodes.get(), ids.get(), radius, qres); |
836 | 0 | } |
837 | 0 | nlistv++; |
838 | 0 | ndis += list_size; |
839 | 0 | } catch (const std::exception& e) { |
840 | 0 | std::lock_guard<std::mutex> lock(exception_mutex); |
841 | 0 | exception_string = |
842 | 0 | demangle_cpp_symbol(typeid(e).name()) + " " + e.what(); |
843 | 0 | interrupt = true; |
844 | 0 | } |
845 | 0 | }; |
846 | |
|
847 | 0 | if (parallel_mode == 0) { |
848 | 0 | #pragma omp for |
849 | 0 | for (idx_t i = 0; i < nx; i++) { |
850 | 0 | scanner->set_query(x + i * d); |
851 | |
|
852 | 0 | RangeQueryResult& qres = pres.new_result(i); |
853 | |
|
854 | 0 | for (size_t ik = 0; ik < nprobe; ik++) { |
855 | 0 | scan_list_func(i, ik, qres); |
856 | 0 | } |
857 | 0 | } |
858 | |
|
859 | 0 | } else if (parallel_mode == 1) { |
860 | 0 | for (size_t i = 0; i < nx; i++) { |
861 | 0 | scanner->set_query(x + i * d); |
862 | |
|
863 | 0 | RangeQueryResult& qres = pres.new_result(i); |
864 | |
|
865 | 0 | #pragma omp for schedule(dynamic) |
866 | 0 | for (int64_t ik = 0; ik < nprobe; ik++) { |
867 | 0 | scan_list_func(i, ik, qres); |
868 | 0 | } |
869 | 0 | } |
870 | 0 | } else if (parallel_mode == 2) { |
871 | 0 | RangeQueryResult* qres = nullptr; |
872 | |
|
873 | 0 | #pragma omp for schedule(dynamic) |
874 | 0 | for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) { |
875 | 0 | idx_t i = iik / (idx_t)nprobe; |
876 | 0 | idx_t ik = iik % (idx_t)nprobe; |
877 | 0 | if (qres == nullptr || qres->qno != i) { |
878 | 0 | qres = &pres.new_result(i); |
879 | 0 | scanner->set_query(x + i * d); |
880 | 0 | } |
881 | 0 | scan_list_func(i, ik, *qres); |
882 | 0 | } |
883 | 0 | } else { |
884 | 0 | FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode); |
885 | 0 | } |
886 | 0 | if (parallel_mode == 0) { |
887 | 0 | pres.finalize(); |
888 | 0 | } else { |
889 | 0 | #pragma omp barrier |
890 | 0 | #pragma omp single |
891 | 0 | RangeSearchPartialResult::merge(all_pres, false); |
892 | 0 | #pragma omp barrier |
893 | 0 | } |
894 | 0 | } |
895 | |
|
896 | 0 | if (interrupt) { |
897 | 0 | if (!exception_string.empty()) { |
898 | 0 | FAISS_THROW_FMT( |
899 | 0 | "search interrupted with: %s", exception_string.c_str()); |
900 | 0 | } else { |
901 | 0 | FAISS_THROW_MSG("computation interrupted"); |
902 | 0 | } |
903 | 0 | } |
904 | | |
905 | 0 | if (stats == nullptr) { |
906 | 0 | stats = &indexIVF_stats; |
907 | 0 | } |
908 | 0 | stats->nq += nx; |
909 | 0 | stats->nlist += nlistv; |
910 | 0 | stats->ndis += ndis; |
911 | 0 | } |
912 | | |
913 | | InvertedListScanner* IndexIVF::get_InvertedListScanner( |
914 | | bool /*store_pairs*/, |
915 | | const IDSelector* /* sel */, |
916 | 0 | const IVFSearchParameters* /* params */) const { |
917 | 0 | FAISS_THROW_MSG("get_InvertedListScanner not implemented"); |
918 | 0 | } |
919 | | |
920 | 0 | void IndexIVF::reconstruct(idx_t key, float* recons) const { |
921 | 0 | idx_t lo = direct_map.get(key); |
922 | 0 | reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons); |
923 | 0 | } |
924 | | |
925 | 0 | void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const { |
926 | 0 | FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal)); |
927 | | |
928 | 0 | for (idx_t list_no = 0; list_no < nlist; list_no++) { |
929 | 0 | size_t list_size = invlists->list_size(list_no); |
930 | 0 | ScopedIds idlist(invlists, list_no); |
931 | |
|
932 | 0 | for (idx_t offset = 0; offset < list_size; offset++) { |
933 | 0 | idx_t id = idlist[offset]; |
934 | 0 | if (!(id >= i0 && id < i0 + ni)) { |
935 | 0 | continue; |
936 | 0 | } |
937 | | |
938 | 0 | float* reconstructed = recons + (id - i0) * d; |
939 | 0 | reconstruct_from_offset(list_no, offset, reconstructed); |
940 | 0 | } |
941 | 0 | } |
942 | 0 | } |
943 | | |
944 | 0 | bool IndexIVF::check_ids_sorted() const { |
945 | 0 | size_t nflip = 0; |
946 | |
|
947 | 0 | for (size_t i = 0; i < nlist; i++) { |
948 | 0 | size_t list_size = invlists->list_size(i); |
949 | 0 | InvertedLists::ScopedIds ids(invlists, i); |
950 | 0 | for (size_t j = 0; j + 1 < list_size; j++) { |
951 | 0 | if (ids[j + 1] < ids[j]) { |
952 | 0 | nflip++; |
953 | 0 | } |
954 | 0 | } |
955 | 0 | } |
956 | 0 | return nflip == 0; |
957 | 0 | } |
958 | | |
959 | | /* standalone codec interface */ |
960 | 0 | size_t IndexIVF::sa_code_size() const { |
961 | 0 | size_t coarse_size = coarse_code_size(); |
962 | 0 | return code_size + coarse_size; |
963 | 0 | } |
964 | | |
965 | 0 | void IndexIVF::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { |
966 | 0 | FAISS_THROW_IF_NOT(is_trained); |
967 | 0 | std::unique_ptr<int64_t[]> idx(new int64_t[n]); |
968 | 0 | quantizer->assign(n, x, idx.get()); |
969 | 0 | encode_vectors(n, x, idx.get(), bytes, true); |
970 | 0 | } |
971 | | |
972 | | void IndexIVF::search_and_reconstruct( |
973 | | idx_t n, |
974 | | const float* x, |
975 | | idx_t k, |
976 | | float* distances, |
977 | | idx_t* labels, |
978 | | float* recons, |
979 | 0 | const SearchParameters* params_in) const { |
980 | 0 | const IVFSearchParameters* params = nullptr; |
981 | 0 | if (params_in) { |
982 | 0 | params = dynamic_cast<const IVFSearchParameters*>(params_in); |
983 | 0 | FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type"); |
984 | 0 | } |
985 | 0 | const size_t nprobe = |
986 | 0 | std::min(nlist, params ? params->nprobe : this->nprobe); |
987 | 0 | FAISS_THROW_IF_NOT(nprobe > 0); |
988 | | |
989 | 0 | std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]); |
990 | 0 | std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]); |
991 | |
|
992 | 0 | quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get()); |
993 | |
|
994 | 0 | invlists->prefetch_lists(idx.get(), n * nprobe); |
995 | | |
996 | | // search_preassigned() with `store_pairs` enabled to obtain the list_no |
997 | | // and offset into `codes` for reconstruction |
998 | 0 | search_preassigned( |
999 | 0 | n, |
1000 | 0 | x, |
1001 | 0 | k, |
1002 | 0 | idx.get(), |
1003 | 0 | coarse_dis.get(), |
1004 | 0 | distances, |
1005 | 0 | labels, |
1006 | 0 | true /* store_pairs */, |
1007 | 0 | params); |
1008 | 0 | #pragma omp parallel for if (n * k > 1000) |
1009 | 0 | for (idx_t ij = 0; ij < n * k; ij++) { |
1010 | 0 | idx_t key = labels[ij]; |
1011 | 0 | float* reconstructed = recons + ij * d; |
1012 | 0 | if (key < 0) { |
1013 | | // Fill with NaNs |
1014 | 0 | memset(reconstructed, -1, sizeof(*reconstructed) * d); |
1015 | 0 | } else { |
1016 | 0 | int list_no = lo_listno(key); |
1017 | 0 | int offset = lo_offset(key); |
1018 | | |
1019 | | // Update label to the actual id |
1020 | 0 | labels[ij] = invlists->get_single_id(list_no, offset); |
1021 | |
|
1022 | 0 | reconstruct_from_offset(list_no, offset, reconstructed); |
1023 | 0 | } |
1024 | 0 | } |
1025 | 0 | } |
1026 | | |
1027 | | void IndexIVF::search_and_return_codes( |
1028 | | idx_t n, |
1029 | | const float* x, |
1030 | | idx_t k, |
1031 | | float* distances, |
1032 | | idx_t* labels, |
1033 | | uint8_t* codes, |
1034 | | bool include_listno, |
1035 | 0 | const SearchParameters* params_in) const { |
1036 | 0 | const IVFSearchParameters* params = nullptr; |
1037 | 0 | if (params_in) { |
1038 | 0 | params = dynamic_cast<const IVFSearchParameters*>(params_in); |
1039 | 0 | FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type"); |
1040 | 0 | } |
1041 | 0 | const size_t nprobe = |
1042 | 0 | std::min(nlist, params ? params->nprobe : this->nprobe); |
1043 | 0 | FAISS_THROW_IF_NOT(nprobe > 0); |
1044 | | |
1045 | 0 | std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]); |
1046 | 0 | std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]); |
1047 | |
|
1048 | 0 | quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get()); |
1049 | |
|
1050 | 0 | invlists->prefetch_lists(idx.get(), n * nprobe); |
1051 | | |
1052 | | // search_preassigned() with `store_pairs` enabled to obtain the list_no |
1053 | | // and offset into `codes` for reconstruction |
1054 | 0 | search_preassigned( |
1055 | 0 | n, |
1056 | 0 | x, |
1057 | 0 | k, |
1058 | 0 | idx.get(), |
1059 | 0 | coarse_dis.get(), |
1060 | 0 | distances, |
1061 | 0 | labels, |
1062 | 0 | true /* store_pairs */, |
1063 | 0 | params); |
1064 | |
|
1065 | 0 | size_t code_size_1 = code_size; |
1066 | 0 | if (include_listno) { |
1067 | 0 | code_size_1 += coarse_code_size(); |
1068 | 0 | } |
1069 | |
|
1070 | 0 | #pragma omp parallel for if (n * k > 1000) |
1071 | 0 | for (idx_t ij = 0; ij < n * k; ij++) { |
1072 | 0 | idx_t key = labels[ij]; |
1073 | 0 | uint8_t* code1 = codes + ij * code_size_1; |
1074 | |
|
1075 | 0 | if (key < 0) { |
1076 | | // Fill with 0xff |
1077 | 0 | memset(code1, -1, code_size_1); |
1078 | 0 | } else { |
1079 | 0 | int list_no = lo_listno(key); |
1080 | 0 | int offset = lo_offset(key); |
1081 | 0 | const uint8_t* cc = invlists->get_single_code(list_no, offset); |
1082 | |
|
1083 | 0 | labels[ij] = invlists->get_single_id(list_no, offset); |
1084 | |
|
1085 | 0 | if (include_listno) { |
1086 | 0 | encode_listno(list_no, code1); |
1087 | 0 | code1 += code_size_1 - code_size; |
1088 | 0 | } |
1089 | 0 | memcpy(code1, cc, code_size); |
1090 | 0 | } |
1091 | 0 | } |
1092 | 0 | } |
1093 | | |
1094 | | void IndexIVF::reconstruct_from_offset( |
1095 | | int64_t /*list_no*/, |
1096 | | int64_t /*offset*/, |
1097 | 0 | float* /*recons*/) const { |
1098 | 0 | FAISS_THROW_MSG("reconstruct_from_offset not implemented"); |
1099 | 0 | } |
1100 | | |
1101 | 0 | void IndexIVF::reset() { |
1102 | 0 | direct_map.clear(); |
1103 | 0 | invlists->reset(); |
1104 | 0 | ntotal = 0; |
1105 | 0 | } |
1106 | | |
1107 | 0 | size_t IndexIVF::remove_ids(const IDSelector& sel) { |
1108 | 0 | size_t nremove = direct_map.remove_ids(sel, invlists); |
1109 | 0 | ntotal -= nremove; |
1110 | 0 | return nremove; |
1111 | 0 | } |
1112 | | |
1113 | 0 | void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) { |
1114 | 0 | if (direct_map.type == DirectMap::Hashtable) { |
1115 | | // just remove then add |
1116 | 0 | IDSelectorArray sel(n, new_ids); |
1117 | 0 | size_t nremove = remove_ids(sel); |
1118 | 0 | FAISS_THROW_IF_NOT_MSG( |
1119 | 0 | nremove == n, "did not find all entries to remove"); |
1120 | 0 | add_with_ids(n, x, new_ids); |
1121 | 0 | return; |
1122 | 0 | } |
1123 | | |
1124 | 0 | FAISS_THROW_IF_NOT(direct_map.type == DirectMap::Array); |
1125 | | // here it is more tricky because we don't want to introduce holes |
1126 | | // in continuous range of ids |
1127 | | |
1128 | 0 | FAISS_THROW_IF_NOT(is_trained); |
1129 | 0 | std::vector<idx_t> assign(n); |
1130 | 0 | quantizer->assign(n, x, assign.data()); |
1131 | |
|
1132 | 0 | std::vector<uint8_t> flat_codes(n * code_size); |
1133 | 0 | encode_vectors(n, x, assign.data(), flat_codes.data()); |
1134 | |
|
1135 | 0 | direct_map.update_codes( |
1136 | 0 | invlists, n, new_ids, assign.data(), flat_codes.data()); |
1137 | 0 | } |
1138 | | |
1139 | 0 | void IndexIVF::train(idx_t n, const float* x) { |
1140 | 0 | if (verbose) { |
1141 | 0 | printf("Training level-1 quantizer\n"); |
1142 | 0 | } |
1143 | |
|
1144 | 0 | train_q1(n, x, verbose, metric_type); |
1145 | |
|
1146 | 0 | if (verbose) { |
1147 | 0 | printf("Training IVF residual\n"); |
1148 | 0 | } |
1149 | | |
1150 | | // optional subsampling |
1151 | 0 | idx_t max_nt = train_encoder_num_vectors(); |
1152 | 0 | if (max_nt <= 0) { |
1153 | 0 | max_nt = (size_t)1 << 35; |
1154 | 0 | } |
1155 | |
|
1156 | 0 | TransformedVectors tv( |
1157 | 0 | x, fvecs_maybe_subsample(d, (size_t*)&n, max_nt, x, verbose)); |
1158 | |
|
1159 | 0 | if (by_residual) { |
1160 | 0 | std::vector<idx_t> assign(n); |
1161 | 0 | quantizer->assign(n, tv.x, assign.data()); |
1162 | |
|
1163 | 0 | std::vector<float> residuals(n * d); |
1164 | 0 | quantizer->compute_residual_n(n, tv.x, residuals.data(), assign.data()); |
1165 | |
|
1166 | 0 | train_encoder(n, residuals.data(), assign.data()); |
1167 | 0 | } else { |
1168 | 0 | train_encoder(n, tv.x, nullptr); |
1169 | 0 | } |
1170 | |
|
1171 | 0 | is_trained = true; |
1172 | 0 | } |
1173 | | |
1174 | 0 | idx_t IndexIVF::train_encoder_num_vectors() const { |
1175 | 0 | return 0; |
1176 | 0 | } |
1177 | | |
1178 | | void IndexIVF::train_encoder( |
1179 | | idx_t /*n*/, |
1180 | | const float* /*x*/, |
1181 | 0 | const idx_t* assign) { |
1182 | | // does nothing by default |
1183 | 0 | if (verbose) { |
1184 | 0 | printf("IndexIVF: no residual training\n"); |
1185 | 0 | } |
1186 | 0 | } |
1187 | | |
1188 | | bool check_compatible_for_merge_expensive_check = true; |
1189 | | |
1190 | 0 | void IndexIVF::check_compatible_for_merge(const Index& otherIndex) const { |
1191 | | // minimal sanity checks |
1192 | 0 | const IndexIVF* other = dynamic_cast<const IndexIVF*>(&otherIndex); |
1193 | 0 | FAISS_THROW_IF_NOT(other); |
1194 | 0 | FAISS_THROW_IF_NOT(other->d == d); |
1195 | 0 | FAISS_THROW_IF_NOT(other->nlist == nlist); |
1196 | 0 | FAISS_THROW_IF_NOT(quantizer->ntotal == other->quantizer->ntotal); |
1197 | 0 | FAISS_THROW_IF_NOT(other->code_size == code_size); |
1198 | 0 | FAISS_THROW_IF_NOT_MSG( |
1199 | 0 | typeid(*this) == typeid(*other), |
1200 | 0 | "can only merge indexes of the same type"); |
1201 | 0 | FAISS_THROW_IF_NOT_MSG( |
1202 | 0 | this->direct_map.no() && other->direct_map.no(), |
1203 | 0 | "merge direct_map not implemented"); |
1204 | | |
1205 | 0 | if (check_compatible_for_merge_expensive_check) { |
1206 | 0 | std::vector<float> v(d), v2(d); |
1207 | 0 | for (size_t i = 0; i < nlist; i++) { |
1208 | 0 | quantizer->reconstruct(i, v.data()); |
1209 | 0 | other->quantizer->reconstruct(i, v2.data()); |
1210 | 0 | FAISS_THROW_IF_NOT_MSG( |
1211 | 0 | v == v2, "coarse quantizers should be the same"); |
1212 | 0 | } |
1213 | 0 | } |
1214 | 0 | } |
1215 | | |
1216 | 0 | void IndexIVF::merge_from(Index& otherIndex, idx_t add_id) { |
1217 | 0 | check_compatible_for_merge(otherIndex); |
1218 | 0 | IndexIVF* other = static_cast<IndexIVF*>(&otherIndex); |
1219 | 0 | invlists->merge_from(other->invlists, add_id); |
1220 | |
|
1221 | 0 | ntotal += other->ntotal; |
1222 | 0 | other->ntotal = 0; |
1223 | 0 | } |
1224 | | |
1225 | 0 | CodePacker* IndexIVF::get_CodePacker() const { |
1226 | 0 | return new CodePackerFlat(code_size); |
1227 | 0 | } |
1228 | | |
1229 | 0 | void IndexIVF::replace_invlists(InvertedLists* il, bool own) { |
1230 | 0 | if (own_invlists) { |
1231 | 0 | delete invlists; |
1232 | 0 | invlists = nullptr; |
1233 | 0 | } |
1234 | | // FAISS_THROW_IF_NOT (ntotal == 0); |
1235 | 0 | if (il) { |
1236 | 0 | FAISS_THROW_IF_NOT(il->nlist == nlist); |
1237 | 0 | FAISS_THROW_IF_NOT( |
1238 | 0 | il->code_size == code_size || |
1239 | 0 | il->code_size == InvertedLists::INVALID_CODE_SIZE); |
1240 | 0 | } |
1241 | 0 | invlists = il; |
1242 | 0 | own_invlists = own; |
1243 | 0 | } |
1244 | | |
1245 | | void IndexIVF::copy_subset_to( |
1246 | | IndexIVF& other, |
1247 | | InvertedLists::subset_type_t subset_type, |
1248 | | idx_t a1, |
1249 | 0 | idx_t a2) const { |
1250 | 0 | other.ntotal += |
1251 | 0 | invlists->copy_subset_to(*other.invlists, subset_type, a1, a2); |
1252 | 0 | } |
1253 | | |
1254 | 0 | IndexIVF::~IndexIVF() { |
1255 | 0 | if (own_invlists) { |
1256 | 0 | delete invlists; |
1257 | 0 | } |
1258 | 0 | } |
1259 | | |
1260 | | /************************************************************************* |
1261 | | * IndexIVFStats |
1262 | | *************************************************************************/ |
1263 | | |
1264 | 8 | void IndexIVFStats::reset() { |
1265 | 8 | memset((void*)this, 0, sizeof(*this)); |
1266 | 8 | } |
1267 | | |
1268 | 0 | void IndexIVFStats::add(const IndexIVFStats& other) { |
1269 | 0 | nq += other.nq; |
1270 | 0 | nlist += other.nlist; |
1271 | 0 | ndis += other.ndis; |
1272 | 0 | nheap_updates += other.nheap_updates; |
1273 | 0 | quantization_time += other.quantization_time; |
1274 | 0 | search_time += other.search_time; |
1275 | 0 | } |
1276 | | |
1277 | | IndexIVFStats indexIVF_stats; |
1278 | | |
1279 | | /************************************************************************* |
1280 | | * InvertedListScanner |
1281 | | *************************************************************************/ |
1282 | | |
1283 | | size_t InvertedListScanner::scan_codes( |
1284 | | size_t list_size, |
1285 | | const uint8_t* codes, |
1286 | | const idx_t* ids, |
1287 | | float* simi, |
1288 | | idx_t* idxi, |
1289 | 0 | size_t k) const { |
1290 | 0 | size_t nup = 0; |
1291 | |
|
1292 | 0 | if (!keep_max) { |
1293 | 0 | for (size_t j = 0; j < list_size; j++) { |
1294 | 0 | if (sel != nullptr) { |
1295 | 0 | int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; |
1296 | 0 | if (!sel->is_member(id)) { |
1297 | 0 | codes += code_size; |
1298 | 0 | continue; |
1299 | 0 | } |
1300 | 0 | } |
1301 | | |
1302 | 0 | float dis = distance_to_code(codes); |
1303 | 0 | if (dis < simi[0]) { |
1304 | 0 | int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; |
1305 | 0 | maxheap_replace_top(k, simi, idxi, dis, id); |
1306 | 0 | nup++; |
1307 | 0 | } |
1308 | 0 | codes += code_size; |
1309 | 0 | } |
1310 | 0 | } else { |
1311 | 0 | for (size_t j = 0; j < list_size; j++) { |
1312 | 0 | if (sel != nullptr) { |
1313 | 0 | int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; |
1314 | 0 | if (!sel->is_member(id)) { |
1315 | 0 | codes += code_size; |
1316 | 0 | continue; |
1317 | 0 | } |
1318 | 0 | } |
1319 | | |
1320 | 0 | float dis = distance_to_code(codes); |
1321 | 0 | if (dis > simi[0]) { |
1322 | 0 | int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; |
1323 | 0 | minheap_replace_top(k, simi, idxi, dis, id); |
1324 | 0 | nup++; |
1325 | 0 | } |
1326 | 0 | codes += code_size; |
1327 | 0 | } |
1328 | 0 | } |
1329 | 0 | return nup; |
1330 | 0 | } |
1331 | | |
1332 | | size_t InvertedListScanner::iterate_codes( |
1333 | | InvertedListsIterator* it, |
1334 | | float* simi, |
1335 | | idx_t* idxi, |
1336 | | size_t k, |
1337 | 0 | size_t& list_size) const { |
1338 | 0 | size_t nup = 0; |
1339 | 0 | list_size = 0; |
1340 | |
|
1341 | 0 | if (!keep_max) { |
1342 | 0 | for (; it->is_available(); it->next()) { |
1343 | 0 | auto id_and_codes = it->get_id_and_codes(); |
1344 | 0 | float dis = distance_to_code(id_and_codes.second); |
1345 | 0 | if (dis < simi[0]) { |
1346 | 0 | maxheap_replace_top(k, simi, idxi, dis, id_and_codes.first); |
1347 | 0 | nup++; |
1348 | 0 | } |
1349 | 0 | list_size++; |
1350 | 0 | } |
1351 | 0 | } else { |
1352 | 0 | for (; it->is_available(); it->next()) { |
1353 | 0 | auto id_and_codes = it->get_id_and_codes(); |
1354 | 0 | float dis = distance_to_code(id_and_codes.second); |
1355 | 0 | if (dis > simi[0]) { |
1356 | 0 | minheap_replace_top(k, simi, idxi, dis, id_and_codes.first); |
1357 | 0 | nup++; |
1358 | 0 | } |
1359 | 0 | list_size++; |
1360 | 0 | } |
1361 | 0 | } |
1362 | 0 | return nup; |
1363 | 0 | } |
1364 | | |
1365 | | void InvertedListScanner::scan_codes_range( |
1366 | | size_t list_size, |
1367 | | const uint8_t* codes, |
1368 | | const idx_t* ids, |
1369 | | float radius, |
1370 | 0 | RangeQueryResult& res) const { |
1371 | 0 | for (size_t j = 0; j < list_size; j++) { |
1372 | 0 | float dis = distance_to_code(codes); |
1373 | 0 | bool keep = !keep_max |
1374 | 0 | ? dis < radius |
1375 | 0 | : dis > radius; // TODO templatize to remove this test |
1376 | 0 | if (keep) { |
1377 | 0 | int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; |
1378 | 0 | res.add(dis, id); |
1379 | 0 | } |
1380 | 0 | codes += code_size; |
1381 | 0 | } |
1382 | 0 | } |
1383 | | |
1384 | | void InvertedListScanner::iterate_codes_range( |
1385 | | InvertedListsIterator* it, |
1386 | | float radius, |
1387 | | RangeQueryResult& res, |
1388 | 0 | size_t& list_size) const { |
1389 | 0 | list_size = 0; |
1390 | 0 | for (; it->is_available(); it->next()) { |
1391 | 0 | auto id_and_codes = it->get_id_and_codes(); |
1392 | 0 | float dis = distance_to_code(id_and_codes.second); |
1393 | 0 | bool keep = !keep_max |
1394 | 0 | ? dis < radius |
1395 | 0 | : dis > radius; // TODO templatize to remove this test |
1396 | 0 | if (keep) { |
1397 | 0 | res.add(dis, id_and_codes.first); |
1398 | 0 | } |
1399 | 0 | list_size++; |
1400 | 0 | } |
1401 | 0 | } |
1402 | | |
1403 | | } // namespace faiss |