Coverage Report

Created: 2025-11-01 13:43

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexIVF.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/IndexIVF.h>
11
12
#include <omp.h>
13
#include <cstdint>
14
#include <memory>
15
#include <mutex>
16
17
#include <algorithm>
18
#include <cinttypes>
19
#include <cstdio>
20
#include <limits>
21
22
#include <faiss/utils/hamming.h>
23
#include <faiss/utils/utils.h>
24
25
#include <faiss/IndexFlat.h>
26
#include <faiss/impl/AuxIndexStructures.h>
27
#include <faiss/impl/CodePacker.h>
28
#include <faiss/impl/FaissAssert.h>
29
#include <faiss/impl/IDSelector.h>
30
31
namespace faiss {
32
33
using ScopedIds = InvertedLists::ScopedIds;
34
using ScopedCodes = InvertedLists::ScopedCodes;
35
36
/*****************************************
37
 * Level1Quantizer implementation
38
 ******************************************/
39
40
Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
41
0
        : quantizer(quantizer), nlist(nlist) {
42
    // here we set a low # iterations because this is typically used
43
    // for large clusterings (nb this is not used for the MultiIndex,
44
    // for which quantizer_trains_alone = true)
45
0
    cp.niter = 10;
46
0
}
47
48
0
Level1Quantizer::Level1Quantizer() = default;
49
50
0
Level1Quantizer::~Level1Quantizer() {
51
0
    if (own_fields) {
52
0
        delete quantizer;
53
0
    }
54
0
}
55
56
void Level1Quantizer::train_q1(
57
        size_t n,
58
        const float* x,
59
        bool verbose,
60
0
        MetricType metric_type) {
61
0
    size_t d = quantizer->d;
62
0
    if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
63
0
        if (verbose)
64
0
            printf("IVF quantizer does not need training.\n");
65
0
    } else if (quantizer_trains_alone == 1) {
66
0
        if (verbose)
67
0
            printf("IVF quantizer trains alone...\n");
68
0
        quantizer->verbose = verbose;
69
0
        quantizer->train(n, x);
70
0
        FAISS_THROW_IF_NOT_MSG(
71
0
                quantizer->ntotal == nlist,
72
0
                "nlist not consistent with quantizer size");
73
0
    } else if (quantizer_trains_alone == 0) {
74
0
        if (verbose)
75
0
            printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
76
77
0
        Clustering clus(d, nlist, cp);
78
0
        quantizer->reset();
79
0
        if (clustering_index) {
80
0
            clus.train(n, x, *clustering_index);
81
0
            quantizer->add(nlist, clus.centroids.data());
82
0
        } else {
83
0
            clus.train(n, x, *quantizer);
84
0
        }
85
0
        quantizer->is_trained = true;
86
0
    } else if (quantizer_trains_alone == 2) {
87
0
        if (verbose) {
88
0
            printf("Training L2 quantizer on %zd vectors in %zdD%s\n",
89
0
                   n,
90
0
                   d,
91
0
                   clustering_index ? "(user provided index)" : "");
92
0
        }
93
        // also accept spherical centroids because in that case
94
        // L2 and IP are equivalent
95
0
        FAISS_THROW_IF_NOT(
96
0
                metric_type == METRIC_L2 ||
97
0
                (metric_type == METRIC_INNER_PRODUCT && cp.spherical));
98
99
0
        Clustering clus(d, nlist, cp);
100
0
        if (!clustering_index) {
101
0
            IndexFlatL2 assigner(d);
102
0
            clus.train(n, x, assigner);
103
0
        } else {
104
0
            clus.train(n, x, *clustering_index);
105
0
        }
106
0
        if (verbose) {
107
0
            printf("Adding centroids to quantizer\n");
108
0
        }
109
0
        if (!quantizer->is_trained) {
110
0
            if (verbose) {
111
0
                printf("But training it first on centroids table...\n");
112
0
            }
113
0
            quantizer->train(nlist, clus.centroids.data());
114
0
        }
115
0
        quantizer->add(nlist, clus.centroids.data());
116
0
    }
117
0
}
118
119
0
size_t Level1Quantizer::coarse_code_size() const {
120
0
    size_t nl = nlist - 1;
121
0
    size_t nbyte = 0;
122
0
    while (nl > 0) {
123
0
        nbyte++;
124
0
        nl >>= 8;
125
0
    }
126
0
    return nbyte;
127
0
}
128
129
0
void Level1Quantizer::encode_listno(idx_t list_no, uint8_t* code) const {
130
    // little endian
131
0
    size_t nl = nlist - 1;
132
0
    while (nl > 0) {
133
0
        *code++ = list_no & 0xff;
134
0
        list_no >>= 8;
135
0
        nl >>= 8;
136
0
    }
137
0
}
138
139
0
idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
140
0
    size_t nl = nlist - 1;
141
0
    int64_t list_no = 0;
142
0
    int nbit = 0;
143
0
    while (nl > 0) {
144
0
        list_no |= int64_t(*code++) << nbit;
145
0
        nbit += 8;
146
0
        nl >>= 8;
147
0
    }
148
0
    FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
149
0
    return list_no;
150
0
}
151
152
/*****************************************
153
 * IndexIVF implementation
154
 ******************************************/
155
156
IndexIVF::IndexIVF(
157
        Index* quantizer,
158
        size_t d,
159
        size_t nlist,
160
        size_t code_size,
161
        MetricType metric)
162
0
        : Index(d, metric),
163
0
          IndexIVFInterface(quantizer, nlist),
164
0
          invlists(new ArrayInvertedLists(nlist, code_size)),
165
0
          own_invlists(true),
166
0
          code_size(code_size) {
167
0
    FAISS_THROW_IF_NOT(d == quantizer->d);
168
0
    is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
169
    // Spherical by default if the metric is inner_product
170
0
    if (metric_type == METRIC_INNER_PRODUCT) {
171
0
        cp.spherical = true;
172
0
    }
173
0
}
174
175
0
IndexIVF::IndexIVF() = default;
176
177
0
void IndexIVF::add(idx_t n, const float* x) {
178
0
    add_with_ids(n, x, nullptr);
179
0
}
180
181
0
void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
182
0
    std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
183
0
    quantizer->assign(n, x, coarse_idx.get());
184
0
    add_core(n, x, xids, coarse_idx.get());
185
0
}
186
187
0
void IndexIVF::add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids) {
188
0
    size_t coarse_size = coarse_code_size();
189
0
    DirectMapAdd dm_adder(direct_map, n, xids);
190
191
0
    for (idx_t i = 0; i < n; i++) {
192
0
        const uint8_t* code = codes + (code_size + coarse_size) * i;
193
0
        idx_t list_no = decode_listno(code);
194
0
        idx_t id = xids ? xids[i] : ntotal + i;
195
0
        size_t ofs = invlists->add_entry(list_no, id, code + coarse_size);
196
0
        dm_adder.add(i, list_no, ofs);
197
0
    }
198
0
    ntotal += n;
199
0
}
200
201
void IndexIVF::add_core(
202
        idx_t n,
203
        const float* x,
204
        const idx_t* xids,
205
        const idx_t* coarse_idx,
206
0
        void* inverted_list_context) {
207
    // do some blocking to avoid excessive allocs
208
0
    idx_t bs = 65536;
209
0
    if (n > bs) {
210
0
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
211
0
            idx_t i1 = std::min(n, i0 + bs);
212
0
            if (verbose) {
213
0
                printf("   IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n",
214
0
                       i0,
215
0
                       i1);
216
0
            }
217
0
            add_core(
218
0
                    i1 - i0,
219
0
                    x + i0 * d,
220
0
                    xids ? xids + i0 : nullptr,
221
0
                    coarse_idx + i0,
222
0
                    inverted_list_context);
223
0
        }
224
0
        return;
225
0
    }
226
0
    FAISS_THROW_IF_NOT(coarse_idx);
227
0
    FAISS_THROW_IF_NOT(is_trained);
228
0
    direct_map.check_can_add(xids);
229
230
0
    size_t nadd = 0, nminus1 = 0;
231
232
0
    for (size_t i = 0; i < n; i++) {
233
0
        if (coarse_idx[i] < 0)
234
0
            nminus1++;
235
0
    }
236
237
0
    std::unique_ptr<uint8_t[]> flat_codes(new uint8_t[n * code_size]);
238
0
    encode_vectors(n, x, coarse_idx, flat_codes.get());
239
240
0
    DirectMapAdd dm_adder(direct_map, n, xids);
241
242
0
#pragma omp parallel reduction(+ : nadd)
243
0
    {
244
0
        int nt = omp_get_num_threads();
245
0
        int rank = omp_get_thread_num();
246
247
        // each thread takes care of a subset of lists
248
0
        for (size_t i = 0; i < n; i++) {
249
0
            idx_t list_no = coarse_idx[i];
250
0
            if (list_no >= 0 && list_no % nt == rank) {
251
0
                idx_t id = xids ? xids[i] : ntotal + i;
252
0
                size_t ofs = invlists->add_entry(
253
0
                        list_no,
254
0
                        id,
255
0
                        flat_codes.get() + i * code_size,
256
0
                        inverted_list_context);
257
258
0
                dm_adder.add(i, list_no, ofs);
259
260
0
                nadd++;
261
0
            } else if (rank == 0 && list_no == -1) {
262
0
                dm_adder.add(i, -1, 0);
263
0
            }
264
0
        }
265
0
    }
266
267
0
    if (verbose) {
268
0
        printf("    added %zd / %" PRId64 " vectors (%zd -1s)\n",
269
0
               nadd,
270
0
               n,
271
0
               nminus1);
272
0
    }
273
274
0
    ntotal += n;
275
0
}
276
277
0
void IndexIVF::make_direct_map(bool b) {
278
0
    if (b) {
279
0
        direct_map.set_type(DirectMap::Array, invlists, ntotal);
280
0
    } else {
281
0
        direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
282
0
    }
283
0
}
284
285
0
void IndexIVF::set_direct_map_type(DirectMap::Type type) {
286
0
    direct_map.set_type(type, invlists, ntotal);
287
0
}
288
289
/** It is a sad fact of software that a conceptually simple function like this
290
 * becomes very complex when you factor in several ways of parallelizing +
291
 * interrupt/error handling + collecting stats + min/max collection. The
292
 * codepath that is used 95% of time is the one for parallel_mode = 0 */
293
void IndexIVF::search(
294
        idx_t n,
295
        const float* x,
296
        idx_t k,
297
        float* distances,
298
        idx_t* labels,
299
0
        const SearchParameters* params_in) const {
300
0
    FAISS_THROW_IF_NOT(k > 0);
301
0
    const IVFSearchParameters* params = nullptr;
302
0
    if (params_in) {
303
0
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
304
0
        FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
305
0
    }
306
0
    const size_t nprobe =
307
0
            std::min(nlist, params ? params->nprobe : this->nprobe);
308
0
    FAISS_THROW_IF_NOT(nprobe > 0);
309
310
    // search function for a subset of queries
311
0
    auto sub_search_func = [this, k, nprobe, params](
312
0
                                   idx_t n,
313
0
                                   const float* x,
314
0
                                   float* distances,
315
0
                                   idx_t* labels,
316
0
                                   IndexIVFStats* ivf_stats) {
317
0
        std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
318
0
        std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
319
320
0
        double t0 = getmillisecs();
321
0
        quantizer->search(
322
0
                n,
323
0
                x,
324
0
                nprobe,
325
0
                coarse_dis.get(),
326
0
                idx.get(),
327
0
                params ? params->quantizer_params : nullptr);
328
329
0
        double t1 = getmillisecs();
330
0
        invlists->prefetch_lists(idx.get(), n * nprobe);
331
332
0
        search_preassigned(
333
0
                n,
334
0
                x,
335
0
                k,
336
0
                idx.get(),
337
0
                coarse_dis.get(),
338
0
                distances,
339
0
                labels,
340
0
                false,
341
0
                params,
342
0
                ivf_stats);
343
0
        double t2 = getmillisecs();
344
0
        ivf_stats->quantization_time += t1 - t0;
345
0
        ivf_stats->search_time += t2 - t0;
346
0
    };
347
348
0
    if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
349
0
        int nt = std::min(omp_get_max_threads(), int(n));
350
0
        std::vector<IndexIVFStats> stats(nt);
351
0
        std::mutex exception_mutex;
352
0
        std::string exception_string;
353
354
0
#pragma omp parallel for if (nt > 1)
355
0
        for (idx_t slice = 0; slice < nt; slice++) {
356
0
            IndexIVFStats local_stats;
357
0
            idx_t i0 = n * slice / nt;
358
0
            idx_t i1 = n * (slice + 1) / nt;
359
0
            if (i1 > i0) {
360
0
                try {
361
0
                    sub_search_func(
362
0
                            i1 - i0,
363
0
                            x + i0 * d,
364
0
                            distances + i0 * k,
365
0
                            labels + i0 * k,
366
0
                            &stats[slice]);
367
0
                } catch (const std::exception& e) {
368
0
                    std::lock_guard<std::mutex> lock(exception_mutex);
369
0
                    exception_string = e.what();
370
0
                }
371
0
            }
372
0
        }
373
374
0
        if (!exception_string.empty()) {
375
0
            FAISS_THROW_MSG(exception_string.c_str());
376
0
        }
377
378
        // collect stats
379
0
        for (idx_t slice = 0; slice < nt; slice++) {
380
0
            indexIVF_stats.add(stats[slice]);
381
0
        }
382
0
    } else {
383
        // handle parallelization at level below (or don't run in parallel at
384
        // all)
385
0
        sub_search_func(n, x, distances, labels, &indexIVF_stats);
386
0
    }
387
0
}
388
389
void IndexIVF::search_preassigned(
390
        idx_t n,
391
        const float* x,
392
        idx_t k,
393
        const idx_t* keys,
394
        const float* coarse_dis,
395
        float* distances,
396
        idx_t* labels,
397
        bool store_pairs,
398
        const IVFSearchParameters* params,
399
0
        IndexIVFStats* ivf_stats) const {
400
0
    FAISS_THROW_IF_NOT(k > 0);
401
402
0
    idx_t nprobe = params ? params->nprobe : this->nprobe;
403
0
    nprobe = std::min((idx_t)nlist, nprobe);
404
0
    FAISS_THROW_IF_NOT(nprobe > 0);
405
406
0
    const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
407
0
    idx_t max_codes = params ? params->max_codes : this->max_codes;
408
0
    IDSelector* sel = params ? params->sel : nullptr;
409
0
    const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
410
0
    if (selr) {
411
0
        if (selr->assume_sorted) {
412
0
            sel = nullptr; // use special IDSelectorRange processing
413
0
        } else {
414
0
            selr = nullptr; // use generic processing
415
0
        }
416
0
    }
417
418
0
    FAISS_THROW_IF_NOT_MSG(
419
0
            !(sel && store_pairs),
420
0
            "selector and store_pairs cannot be combined");
421
422
0
    FAISS_THROW_IF_NOT_MSG(
423
0
            !invlists->use_iterator || (max_codes == 0 && store_pairs == false),
424
0
            "iterable inverted lists don't support max_codes and store_pairs");
425
426
0
    size_t nlistv = 0, ndis = 0, nheap = 0;
427
428
0
    using HeapForIP = CMin<float, idx_t>;
429
0
    using HeapForL2 = CMax<float, idx_t>;
430
431
0
    bool interrupt = false;
432
0
    std::mutex exception_mutex;
433
0
    std::string exception_string;
434
435
0
    int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
436
0
    bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
437
438
0
    FAISS_THROW_IF_NOT_MSG(
439
0
            max_codes == 0 || pmode == 0 || pmode == 3,
440
0
            "max_codes supported only for parallel_mode = 0 or 3");
441
442
0
    if (max_codes == 0) {
443
0
        max_codes = unlimited_list_size;
444
0
    }
445
446
0
    [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
447
0
            (pmode == 0           ? false
448
0
                     : pmode == 3 ? n > 1
449
0
                     : pmode == 1 ? nprobe > 1
450
0
                                  : nprobe * n > 1);
451
452
0
    void* inverted_list_context =
453
0
            params ? params->inverted_list_context : nullptr;
454
455
0
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
456
0
    {
457
0
        std::unique_ptr<InvertedListScanner> scanner(
458
0
                get_InvertedListScanner(store_pairs, sel, params));
459
460
        /*****************************************************
461
         * Depending on parallel_mode, there are two possible ways
462
         * to organize the search. Here we define local functions
463
         * that are in common between the two
464
         ******************************************************/
465
466
        // initialize + reorder a result heap
467
468
0
        auto init_result = [&](float* simi, idx_t* idxi) {
469
0
            if (!do_heap_init)
470
0
                return;
471
0
            if (metric_type == METRIC_INNER_PRODUCT) {
472
0
                heap_heapify<HeapForIP>(k, simi, idxi);
473
0
            } else {
474
0
                heap_heapify<HeapForL2>(k, simi, idxi);
475
0
            }
476
0
        };
477
478
0
        auto add_local_results = [&](const float* local_dis,
479
0
                                     const idx_t* local_idx,
480
0
                                     float* simi,
481
0
                                     idx_t* idxi) {
482
0
            if (metric_type == METRIC_INNER_PRODUCT) {
483
0
                heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
484
0
            } else {
485
0
                heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
486
0
            }
487
0
        };
488
489
0
        auto reorder_result = [&](float* simi, idx_t* idxi) {
490
0
            if (!do_heap_init)
491
0
                return;
492
0
            if (metric_type == METRIC_INNER_PRODUCT) {
493
0
                heap_reorder<HeapForIP>(k, simi, idxi);
494
0
            } else {
495
0
                heap_reorder<HeapForL2>(k, simi, idxi);
496
0
            }
497
0
        };
498
499
        // single list scan using the current scanner (with query
500
        // set porperly) and storing results in simi and idxi
501
0
        auto scan_one_list = [&](idx_t key,
502
0
                                 float coarse_dis_i,
503
0
                                 float* simi,
504
0
                                 idx_t* idxi,
505
0
                                 idx_t list_size_max) {
506
0
            if (key < 0) {
507
                // not enough centroids for multiprobe
508
0
                return (size_t)0;
509
0
            }
510
0
            FAISS_THROW_IF_NOT_FMT(
511
0
                    key < (idx_t)nlist,
512
0
                    "Invalid key=%" PRId64 " nlist=%zd\n",
513
0
                    key,
514
0
                    nlist);
515
516
            // don't waste time on empty lists
517
0
            if (invlists->is_empty(key, inverted_list_context)) {
518
0
                return (size_t)0;
519
0
            }
520
521
0
            scanner->set_list(key, coarse_dis_i);
522
523
0
            nlistv++;
524
525
0
            try {
526
0
                if (invlists->use_iterator) {
527
0
                    size_t list_size = 0;
528
529
0
                    std::unique_ptr<InvertedListsIterator> it(
530
0
                            invlists->get_iterator(key, inverted_list_context));
531
532
0
                    nheap += scanner->iterate_codes(
533
0
                            it.get(), simi, idxi, k, list_size);
534
535
0
                    return list_size;
536
0
                } else {
537
0
                    size_t list_size = invlists->list_size(key);
538
0
                    if (list_size > list_size_max) {
539
0
                        list_size = list_size_max;
540
0
                    }
541
542
0
                    InvertedLists::ScopedCodes scodes(invlists, key);
543
0
                    const uint8_t* codes = scodes.get();
544
545
0
                    std::unique_ptr<InvertedLists::ScopedIds> sids;
546
0
                    const idx_t* ids = nullptr;
547
548
0
                    if (!store_pairs) {
549
0
                        sids = std::make_unique<InvertedLists::ScopedIds>(
550
0
                                invlists, key);
551
0
                        ids = sids->get();
552
0
                    }
553
554
0
                    if (selr) { // IDSelectorRange
555
                        // restrict search to a section of the inverted list
556
0
                        size_t jmin, jmax;
557
0
                        selr->find_sorted_ids_bounds(
558
0
                                list_size, ids, &jmin, &jmax);
559
0
                        list_size = jmax - jmin;
560
0
                        if (list_size == 0) {
561
0
                            return (size_t)0;
562
0
                        }
563
0
                        codes += jmin * code_size;
564
0
                        ids += jmin;
565
0
                    }
566
567
0
                    nheap += scanner->scan_codes(
568
0
                            list_size, codes, ids, simi, idxi, k);
569
570
0
                    return list_size;
571
0
                }
572
0
            } catch (const std::exception& e) {
573
0
                std::lock_guard<std::mutex> lock(exception_mutex);
574
0
                exception_string =
575
0
                        demangle_cpp_symbol(typeid(e).name()) + "  " + e.what();
576
0
                interrupt = true;
577
0
                return size_t(0);
578
0
            }
579
0
        };
580
581
        /****************************************************
582
         * Actual loops, depending on parallel_mode
583
         ****************************************************/
584
585
0
        if (pmode == 0 || pmode == 3) {
586
0
#pragma omp for
587
0
            for (idx_t i = 0; i < n; i++) {
588
0
                if (interrupt) {
589
0
                    continue;
590
0
                }
591
592
                // loop over queries
593
0
                scanner->set_query(x + i * d);
594
0
                float* simi = distances + i * k;
595
0
                idx_t* idxi = labels + i * k;
596
597
0
                init_result(simi, idxi);
598
599
0
                idx_t nscan = 0;
600
601
                // loop over probes
602
0
                for (size_t ik = 0; ik < nprobe; ik++) {
603
0
                    nscan += scan_one_list(
604
0
                            keys[i * nprobe + ik],
605
0
                            coarse_dis[i * nprobe + ik],
606
0
                            simi,
607
0
                            idxi,
608
0
                            max_codes - nscan);
609
0
                    if (nscan >= max_codes) {
610
0
                        break;
611
0
                    }
612
0
                }
613
614
0
                ndis += nscan;
615
0
                reorder_result(simi, idxi);
616
617
0
                if (InterruptCallback::is_interrupted()) {
618
0
                    interrupt = true;
619
0
                }
620
621
0
            } // parallel for
622
0
        } else if (pmode == 1) {
623
0
            std::vector<idx_t> local_idx(k);
624
0
            std::vector<float> local_dis(k);
625
626
0
            for (size_t i = 0; i < n; i++) {
627
0
                scanner->set_query(x + i * d);
628
0
                init_result(local_dis.data(), local_idx.data());
629
630
0
#pragma omp for schedule(dynamic)
631
0
                for (idx_t ik = 0; ik < nprobe; ik++) {
632
0
                    ndis += scan_one_list(
633
0
                            keys[i * nprobe + ik],
634
0
                            coarse_dis[i * nprobe + ik],
635
0
                            local_dis.data(),
636
0
                            local_idx.data(),
637
0
                            unlimited_list_size);
638
639
                    // can't do the test on max_codes
640
0
                }
641
                // merge thread-local results
642
643
0
                float* simi = distances + i * k;
644
0
                idx_t* idxi = labels + i * k;
645
0
#pragma omp single
646
0
                init_result(simi, idxi);
647
648
0
#pragma omp barrier
649
0
#pragma omp critical
650
0
                {
651
0
                    add_local_results(
652
0
                            local_dis.data(), local_idx.data(), simi, idxi);
653
0
                }
654
0
#pragma omp barrier
655
0
#pragma omp single
656
0
                reorder_result(simi, idxi);
657
0
            }
658
0
        } else if (pmode == 2) {
659
0
            std::vector<idx_t> local_idx(k);
660
0
            std::vector<float> local_dis(k);
661
662
0
#pragma omp single
663
0
            for (int64_t i = 0; i < n; i++) {
664
0
                init_result(distances + i * k, labels + i * k);
665
0
            }
666
667
0
#pragma omp for schedule(dynamic)
668
0
            for (int64_t ij = 0; ij < n * nprobe; ij++) {
669
0
                size_t i = ij / nprobe;
670
671
0
                scanner->set_query(x + i * d);
672
0
                init_result(local_dis.data(), local_idx.data());
673
0
                ndis += scan_one_list(
674
0
                        keys[ij],
675
0
                        coarse_dis[ij],
676
0
                        local_dis.data(),
677
0
                        local_idx.data(),
678
0
                        unlimited_list_size);
679
0
#pragma omp critical
680
0
                {
681
0
                    add_local_results(
682
0
                            local_dis.data(),
683
0
                            local_idx.data(),
684
0
                            distances + i * k,
685
0
                            labels + i * k);
686
0
                }
687
0
            }
688
0
#pragma omp single
689
0
            for (int64_t i = 0; i < n; i++) {
690
0
                reorder_result(distances + i * k, labels + i * k);
691
0
            }
692
0
        } else {
693
0
            FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
694
0
        }
695
0
    } // parallel section
696
697
0
    if (interrupt) {
698
0
        if (!exception_string.empty()) {
699
0
            FAISS_THROW_FMT(
700
0
                    "search interrupted with: %s", exception_string.c_str());
701
0
        } else {
702
0
            FAISS_THROW_MSG("computation interrupted");
703
0
        }
704
0
    }
705
706
0
    if (ivf_stats == nullptr) {
707
0
        ivf_stats = &indexIVF_stats;
708
0
    }
709
0
    ivf_stats->nq += n;
710
0
    ivf_stats->nlist += nlistv;
711
0
    ivf_stats->ndis += ndis;
712
0
    ivf_stats->nheap_updates += nheap;
713
0
}
714
715
void IndexIVF::range_search(
716
        idx_t nx,
717
        const float* x,
718
        float radius,
719
        RangeSearchResult* result,
720
0
        const SearchParameters* params_in) const {
721
0
    const IVFSearchParameters* params = nullptr;
722
0
    const SearchParameters* quantizer_params = nullptr;
723
0
    if (params_in) {
724
0
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
725
0
        FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
726
0
        quantizer_params = params->quantizer_params;
727
0
    }
728
0
    const size_t nprobe =
729
0
            std::min(nlist, params ? params->nprobe : this->nprobe);
730
0
    std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
731
0
    std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
732
733
0
    double t0 = getmillisecs();
734
0
    quantizer->search(
735
0
            nx, x, nprobe, coarse_dis.get(), keys.get(), quantizer_params);
736
0
    indexIVF_stats.quantization_time += getmillisecs() - t0;
737
738
0
    t0 = getmillisecs();
739
0
    invlists->prefetch_lists(keys.get(), nx * nprobe);
740
741
0
    range_search_preassigned(
742
0
            nx,
743
0
            x,
744
0
            radius,
745
0
            keys.get(),
746
0
            coarse_dis.get(),
747
0
            result,
748
0
            false,
749
0
            params,
750
0
            &indexIVF_stats);
751
752
0
    indexIVF_stats.search_time += getmillisecs() - t0;
753
0
}
754
755
void IndexIVF::range_search_preassigned(
756
        idx_t nx,
757
        const float* x,
758
        float radius,
759
        const idx_t* keys,
760
        const float* coarse_dis,
761
        RangeSearchResult* result,
762
        bool store_pairs,
763
        const IVFSearchParameters* params,
764
0
        IndexIVFStats* stats) const {
765
0
    idx_t nprobe = params ? params->nprobe : this->nprobe;
766
0
    nprobe = std::min((idx_t)nlist, nprobe);
767
0
    FAISS_THROW_IF_NOT(nprobe > 0);
768
769
0
    idx_t max_codes = params ? params->max_codes : this->max_codes;
770
0
    IDSelector* sel = params ? params->sel : nullptr;
771
772
0
    FAISS_THROW_IF_NOT_MSG(
773
0
            !invlists->use_iterator || (max_codes == 0 && store_pairs == false),
774
0
            "iterable inverted lists don't support max_codes and store_pairs");
775
776
0
    size_t nlistv = 0, ndis = 0;
777
778
0
    bool interrupt = false;
779
0
    std::mutex exception_mutex;
780
0
    std::string exception_string;
781
782
0
    std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
783
784
0
    int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
785
    // don't start parallel section if single query
786
0
    [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
787
0
            (pmode == 3           ? false
788
0
                     : pmode == 0 ? nx > 1
789
0
                     : pmode == 1 ? nprobe > 1
790
0
                                  : nprobe * nx > 1);
791
792
0
    void* inverted_list_context =
793
0
            params ? params->inverted_list_context : nullptr;
794
795
0
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
796
0
    {
797
0
        RangeSearchPartialResult pres(result);
798
0
        std::unique_ptr<InvertedListScanner> scanner(
799
0
                get_InvertedListScanner(store_pairs, sel, params));
800
0
        FAISS_THROW_IF_NOT(scanner.get());
801
0
        all_pres[omp_get_thread_num()] = &pres;
802
803
        // prepare the list scanning function
804
805
0
        auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
806
0
            idx_t key = keys[i * nprobe + ik]; /* select the list  */
807
0
            if (key < 0)
808
0
                return;
809
0
            FAISS_THROW_IF_NOT_FMT(
810
0
                    key < (idx_t)nlist,
811
0
                    "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
812
0
                    key,
813
0
                    ik,
814
0
                    nlist);
815
816
0
            if (invlists->is_empty(key, inverted_list_context)) {
817
0
                return;
818
0
            }
819
820
0
            try {
821
0
                size_t list_size = 0;
822
0
                scanner->set_list(key, coarse_dis[i * nprobe + ik]);
823
0
                if (invlists->use_iterator) {
824
0
                    std::unique_ptr<InvertedListsIterator> it(
825
0
                            invlists->get_iterator(key, inverted_list_context));
826
827
0
                    scanner->iterate_codes_range(
828
0
                            it.get(), radius, qres, list_size);
829
0
                } else {
830
0
                    InvertedLists::ScopedCodes scodes(invlists, key);
831
0
                    InvertedLists::ScopedIds ids(invlists, key);
832
0
                    list_size = invlists->list_size(key);
833
834
0
                    scanner->scan_codes_range(
835
0
                            list_size, scodes.get(), ids.get(), radius, qres);
836
0
                }
837
0
                nlistv++;
838
0
                ndis += list_size;
839
0
            } catch (const std::exception& e) {
840
0
                std::lock_guard<std::mutex> lock(exception_mutex);
841
0
                exception_string =
842
0
                        demangle_cpp_symbol(typeid(e).name()) + "  " + e.what();
843
0
                interrupt = true;
844
0
            }
845
0
        };
846
847
0
        if (parallel_mode == 0) {
848
0
#pragma omp for
849
0
            for (idx_t i = 0; i < nx; i++) {
850
0
                scanner->set_query(x + i * d);
851
852
0
                RangeQueryResult& qres = pres.new_result(i);
853
854
0
                for (size_t ik = 0; ik < nprobe; ik++) {
855
0
                    scan_list_func(i, ik, qres);
856
0
                }
857
0
            }
858
859
0
        } else if (parallel_mode == 1) {
860
0
            for (size_t i = 0; i < nx; i++) {
861
0
                scanner->set_query(x + i * d);
862
863
0
                RangeQueryResult& qres = pres.new_result(i);
864
865
0
#pragma omp for schedule(dynamic)
866
0
                for (int64_t ik = 0; ik < nprobe; ik++) {
867
0
                    scan_list_func(i, ik, qres);
868
0
                }
869
0
            }
870
0
        } else if (parallel_mode == 2) {
871
0
            RangeQueryResult* qres = nullptr;
872
873
0
#pragma omp for schedule(dynamic)
874
0
            for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
875
0
                idx_t i = iik / (idx_t)nprobe;
876
0
                idx_t ik = iik % (idx_t)nprobe;
877
0
                if (qres == nullptr || qres->qno != i) {
878
0
                    qres = &pres.new_result(i);
879
0
                    scanner->set_query(x + i * d);
880
0
                }
881
0
                scan_list_func(i, ik, *qres);
882
0
            }
883
0
        } else {
884
0
            FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
885
0
        }
886
0
        if (parallel_mode == 0) {
887
0
            pres.finalize();
888
0
        } else {
889
0
#pragma omp barrier
890
0
#pragma omp single
891
0
            RangeSearchPartialResult::merge(all_pres, false);
892
0
#pragma omp barrier
893
0
        }
894
0
    }
895
896
0
    if (interrupt) {
897
0
        if (!exception_string.empty()) {
898
0
            FAISS_THROW_FMT(
899
0
                    "search interrupted with: %s", exception_string.c_str());
900
0
        } else {
901
0
            FAISS_THROW_MSG("computation interrupted");
902
0
        }
903
0
    }
904
905
0
    if (stats == nullptr) {
906
0
        stats = &indexIVF_stats;
907
0
    }
908
0
    stats->nq += nx;
909
0
    stats->nlist += nlistv;
910
0
    stats->ndis += ndis;
911
0
}
912
913
// This implementation provides no scanner: it ignores all three arguments
// and unconditionally throws.
InvertedListScanner* IndexIVF::get_InvertedListScanner(
        bool /* store_pairs */,
        const IDSelector* /* sel */,
        const IVFSearchParameters* /* params */) const {
    FAISS_THROW_MSG("get_InvertedListScanner not implemented");
}
919
920
0
// Reconstruct the vector stored under id `key` into `recons`.
// The direct map yields a packed location that is split into a list number
// and an offset, which are then handed to reconstruct_from_offset().
void IndexIVF::reconstruct(idx_t key, float* recons) const {
    const idx_t packed_location = direct_map.get(key);
    const auto list_index = lo_listno(packed_location);
    const auto entry_offset = lo_offset(packed_location);
    reconstruct_from_offset(list_index, entry_offset, recons);
}
924
925
0
// Reconstruct every stored vector whose id lies in [i0, i0 + ni).
// The vector with id `id` is written at recons + (id - i0) * d.
// Throws unless ni == 0 or the requested range fits in [0, ntotal].
void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
    FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));

    const idx_t id_end = i0 + ni;

    // walk all inverted lists and pick out the ids that fall in the range
    for (size_t list_index = 0; list_index < nlist; ++list_index) {
        const size_t n_entries = invlists->list_size(list_index);
        ScopedIds list_ids(invlists, list_index);

        for (size_t pos = 0; pos < n_entries; ++pos) {
            const idx_t id = list_ids[pos];
            if (id < i0 || id >= id_end) {
                continue;
            }
            reconstruct_from_offset(list_index, pos, recons + (id - i0) * d);
        }
    }
}
943
944
0
bool IndexIVF::check_ids_sorted() const {
945
0
    size_t nflip = 0;
946
947
0
    for (size_t i = 0; i < nlist; i++) {
948
0
        size_t list_size = invlists->list_size(i);
949
0
        InvertedLists::ScopedIds ids(invlists, i);
950
0
        for (size_t j = 0; j + 1 < list_size; j++) {
951
0
            if (ids[j + 1] < ids[j]) {
952
0
                nflip++;
953
0
            }
954
0
        }
955
0
    }
956
0
    return nflip == 0;
957
0
}
958
959
/* standalone codec interface */
960
0
size_t IndexIVF::sa_code_size() const {
961
0
    size_t coarse_size = coarse_code_size();
962
0
    return code_size + coarse_size;
963
0
}
964
965
0
// Standalone encoding of n vectors from `x` into `bytes`.
// Each vector is first assigned by the coarse quantizer, then
// encode_vectors() is invoked with its last argument set to true.
// Throws if the index is not trained.
void IndexIVF::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
    FAISS_THROW_IF_NOT(is_trained);

    std::unique_ptr<int64_t[]> assignment_buf(new int64_t[n]);
    int64_t* const assigned = assignment_buf.get();

    quantizer->assign(n, x, assigned);
    encode_vectors(n, x, assigned, bytes, true);
}
971
972
void IndexIVF::search_and_reconstruct(
973
        idx_t n,
974
        const float* x,
975
        idx_t k,
976
        float* distances,
977
        idx_t* labels,
978
        float* recons,
979
0
        const SearchParameters* params_in) const {
980
0
    const IVFSearchParameters* params = nullptr;
981
0
    if (params_in) {
982
0
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
983
0
        FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
984
0
    }
985
0
    const size_t nprobe =
986
0
            std::min(nlist, params ? params->nprobe : this->nprobe);
987
0
    FAISS_THROW_IF_NOT(nprobe > 0);
988
989
0
    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
990
0
    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
991
992
0
    quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
993
994
0
    invlists->prefetch_lists(idx.get(), n * nprobe);
995
996
    // search_preassigned() with `store_pairs` enabled to obtain the list_no
997
    // and offset into `codes` for reconstruction
998
0
    search_preassigned(
999
0
            n,
1000
0
            x,
1001
0
            k,
1002
0
            idx.get(),
1003
0
            coarse_dis.get(),
1004
0
            distances,
1005
0
            labels,
1006
0
            true /* store_pairs */,
1007
0
            params);
1008
0
#pragma omp parallel for if (n * k > 1000)
1009
0
    for (idx_t ij = 0; ij < n * k; ij++) {
1010
0
        idx_t key = labels[ij];
1011
0
        float* reconstructed = recons + ij * d;
1012
0
        if (key < 0) {
1013
            // Fill with NaNs
1014
0
            memset(reconstructed, -1, sizeof(*reconstructed) * d);
1015
0
        } else {
1016
0
            int list_no = lo_listno(key);
1017
0
            int offset = lo_offset(key);
1018
1019
            // Update label to the actual id
1020
0
            labels[ij] = invlists->get_single_id(list_no, offset);
1021
1022
0
            reconstruct_from_offset(list_no, offset, reconstructed);
1023
0
        }
1024
0
    }
1025
0
}
1026
1027
void IndexIVF::search_and_return_codes(
1028
        idx_t n,
1029
        const float* x,
1030
        idx_t k,
1031
        float* distances,
1032
        idx_t* labels,
1033
        uint8_t* codes,
1034
        bool include_listno,
1035
0
        const SearchParameters* params_in) const {
1036
0
    const IVFSearchParameters* params = nullptr;
1037
0
    if (params_in) {
1038
0
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
1039
0
        FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
1040
0
    }
1041
0
    const size_t nprobe =
1042
0
            std::min(nlist, params ? params->nprobe : this->nprobe);
1043
0
    FAISS_THROW_IF_NOT(nprobe > 0);
1044
1045
0
    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
1046
0
    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
1047
1048
0
    quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
1049
1050
0
    invlists->prefetch_lists(idx.get(), n * nprobe);
1051
1052
    // search_preassigned() with `store_pairs` enabled to obtain the list_no
1053
    // and offset into `codes` for reconstruction
1054
0
    search_preassigned(
1055
0
            n,
1056
0
            x,
1057
0
            k,
1058
0
            idx.get(),
1059
0
            coarse_dis.get(),
1060
0
            distances,
1061
0
            labels,
1062
0
            true /* store_pairs */,
1063
0
            params);
1064
1065
0
    size_t code_size_1 = code_size;
1066
0
    if (include_listno) {
1067
0
        code_size_1 += coarse_code_size();
1068
0
    }
1069
1070
0
#pragma omp parallel for if (n * k > 1000)
1071
0
    for (idx_t ij = 0; ij < n * k; ij++) {
1072
0
        idx_t key = labels[ij];
1073
0
        uint8_t* code1 = codes + ij * code_size_1;
1074
1075
0
        if (key < 0) {
1076
            // Fill with 0xff
1077
0
            memset(code1, -1, code_size_1);
1078
0
        } else {
1079
0
            int list_no = lo_listno(key);
1080
0
            int offset = lo_offset(key);
1081
0
            const uint8_t* cc = invlists->get_single_code(list_no, offset);
1082
1083
0
            labels[ij] = invlists->get_single_id(list_no, offset);
1084
1085
0
            if (include_listno) {
1086
0
                encode_listno(list_no, code1);
1087
0
                code1 += code_size_1 - code_size;
1088
0
            }
1089
0
            memcpy(code1, cc, code_size);
1090
0
        }
1091
0
    }
1092
0
}
1093
1094
void IndexIVF::reconstruct_from_offset(
1095
        int64_t /*list_no*/,
1096
        int64_t /*offset*/,
1097
0
        float* /*recons*/) const {
1098
0
    FAISS_THROW_MSG("reconstruct_from_offset not implemented");
1099
0
}
1100
1101
0
void IndexIVF::reset() {
1102
0
    direct_map.clear();
1103
0
    invlists->reset();
1104
0
    ntotal = 0;
1105
0
}
1106
1107
0
size_t IndexIVF::remove_ids(const IDSelector& sel) {
1108
0
    size_t nremove = direct_map.remove_ids(sel, invlists);
1109
0
    ntotal -= nremove;
1110
0
    return nremove;
1111
0
}
1112
1113
0
void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
1114
0
    if (direct_map.type == DirectMap::Hashtable) {
1115
        // just remove then add
1116
0
        IDSelectorArray sel(n, new_ids);
1117
0
        size_t nremove = remove_ids(sel);
1118
0
        FAISS_THROW_IF_NOT_MSG(
1119
0
                nremove == n, "did not find all entries to remove");
1120
0
        add_with_ids(n, x, new_ids);
1121
0
        return;
1122
0
    }
1123
1124
0
    FAISS_THROW_IF_NOT(direct_map.type == DirectMap::Array);
1125
    // here it is more tricky because we don't want to introduce holes
1126
    // in continuous range of ids
1127
1128
0
    FAISS_THROW_IF_NOT(is_trained);
1129
0
    std::vector<idx_t> assign(n);
1130
0
    quantizer->assign(n, x, assign.data());
1131
1132
0
    std::vector<uint8_t> flat_codes(n * code_size);
1133
0
    encode_vectors(n, x, assign.data(), flat_codes.data());
1134
1135
0
    direct_map.update_codes(
1136
0
            invlists, n, new_ids, assign.data(), flat_codes.data());
1137
0
}
1138
1139
0
// Train the index on n vectors from `x`: first the level-1 quantizer, then
// the encoder, on a possibly subsampled set. With by_residual the encoder
// sees the residuals w.r.t. the quantizer assignment, otherwise the
// vectors themselves. Sets is_trained on completion.
void IndexIVF::train(idx_t n, const float* x) {
    if (verbose) {
        printf("Training level-1 quantizer\n");
    }
    train_q1(n, x, verbose, metric_type);

    if (verbose) {
        printf("Training IVF residual\n");
    }

    // optional subsampling; a non-positive limit is replaced by 2^35
    idx_t max_nt = train_encoder_num_vectors();
    if (max_nt <= 0) {
        max_nt = (size_t)1 << 35;
    }

    // n is passed by address, so the call may change it
    const float* sampled =
            fvecs_maybe_subsample(d, (size_t*)&n, max_nt, x, verbose);
    TransformedVectors tv(x, sampled);

    if (!by_residual) {
        train_encoder(n, tv.x, nullptr);
    } else {
        std::vector<idx_t> list_nos(n);
        quantizer->assign(n, tv.x, list_nos.data());

        std::vector<float> residual_buf(n * d);
        quantizer->compute_residual_n(
                n, tv.x, residual_buf.data(), list_nos.data());

        train_encoder(n, residual_buf.data(), list_nos.data());
    }

    is_trained = true;
}
1173
1174
0
// Number of vectors to train the encoder on. This implementation returns
// 0, which train() replaces with its own upper bound.
idx_t IndexIVF::train_encoder_num_vectors() const {
    return 0;
}
1177
1178
void IndexIVF::train_encoder(
1179
        idx_t /*n*/,
1180
        const float* /*x*/,
1181
0
        const idx_t* assign) {
1182
    // does nothing by default
1183
0
    if (verbose) {
1184
0
        printf("IndexIVF: no residual training\n");
1185
0
    }
1186
0
}
1187
1188
// When true (the default), IndexIVF::check_compatible_for_merge also
// reconstructs and compares every entry of the two coarse quantizers.
bool check_compatible_for_merge_expensive_check = true;
1189
1190
0
void IndexIVF::check_compatible_for_merge(const Index& otherIndex) const {
1191
    // minimal sanity checks
1192
0
    const IndexIVF* other = dynamic_cast<const IndexIVF*>(&otherIndex);
1193
0
    FAISS_THROW_IF_NOT(other);
1194
0
    FAISS_THROW_IF_NOT(other->d == d);
1195
0
    FAISS_THROW_IF_NOT(other->nlist == nlist);
1196
0
    FAISS_THROW_IF_NOT(quantizer->ntotal == other->quantizer->ntotal);
1197
0
    FAISS_THROW_IF_NOT(other->code_size == code_size);
1198
0
    FAISS_THROW_IF_NOT_MSG(
1199
0
            typeid(*this) == typeid(*other),
1200
0
            "can only merge indexes of the same type");
1201
0
    FAISS_THROW_IF_NOT_MSG(
1202
0
            this->direct_map.no() && other->direct_map.no(),
1203
0
            "merge direct_map not implemented");
1204
1205
0
    if (check_compatible_for_merge_expensive_check) {
1206
0
        std::vector<float> v(d), v2(d);
1207
0
        for (size_t i = 0; i < nlist; i++) {
1208
0
            quantizer->reconstruct(i, v.data());
1209
0
            other->quantizer->reconstruct(i, v2.data());
1210
0
            FAISS_THROW_IF_NOT_MSG(
1211
0
                    v == v2, "coarse quantizers should be the same");
1212
0
        }
1213
0
    }
1214
0
}
1215
1216
0
// Merge the inverted lists of `otherIndex` into this index, passing
// `add_id` through to InvertedLists::merge_from. Compatibility is checked
// first (throws on mismatch). Afterwards the other index's ntotal has been
// added to this one's and is set to 0.
void IndexIVF::merge_from(Index& otherIndex, idx_t add_id) {
    check_compatible_for_merge(otherIndex);

    // the cast is safe: the compatibility check verified the dynamic type
    IndexIVF& source = static_cast<IndexIVF&>(otherIndex);
    invlists->merge_from(source.invlists, add_id);

    ntotal += source.ntotal;
    source.ntotal = 0;
}
1224
1225
0
// Returns a newly allocated CodePackerFlat built for this index's
// code_size; the object is created with `new` and handed to the caller.
CodePacker* IndexIVF::get_CodePacker() const {
    CodePacker* packer = new CodePackerFlat(code_size);
    return packer;
}
1228
1229
0
// Install `il` as the inverted lists of this index.
//
//  il   new inverted lists, may be nullptr; when non-null its nlist must
//       match and its code_size must match or be INVALID_CODE_SIZE
//       (throws otherwise)
//  own  whether the index takes ownership of `il`
//
// The previous lists are deleted if they were owned. Validation runs
// before anything is released, so a failed check leaves the index
// unchanged instead of with a null `invlists`. Passing the pointer that is
// already installed only updates the ownership flag; deleting it first
// would leave `invlists` dangling.
void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
    // FAISS_THROW_IF_NOT (ntotal == 0);
    if (il) {
        FAISS_THROW_IF_NOT(il->nlist == nlist);
        FAISS_THROW_IF_NOT(
                il->code_size == code_size ||
                il->code_size == InvertedLists::INVALID_CODE_SIZE);
    }
    if (own_invlists && invlists != il) {
        delete invlists;
    }
    invlists = il;
    own_invlists = own;
}
1244
1245
void IndexIVF::copy_subset_to(
1246
        IndexIVF& other,
1247
        InvertedLists::subset_type_t subset_type,
1248
        idx_t a1,
1249
0
        idx_t a2) const {
1250
0
    other.ntotal +=
1251
0
            invlists->copy_subset_to(*other.invlists, subset_type, a1, a2);
1252
0
}
1253
1254
0
// Frees the inverted lists, but only when the index owns them.
IndexIVF::~IndexIVF() {
    if (!own_invlists) {
        return;
    }
    delete invlists;
}
1259
1260
/*************************************************************************
1261
 * IndexIVFStats
1262
 *************************************************************************/
1263
1264
8
void IndexIVFStats::reset() {
1265
8
    memset((void*)this, 0, sizeof(*this));
1266
8
}
1267
1268
0
void IndexIVFStats::add(const IndexIVFStats& other) {
1269
0
    nq += other.nq;
1270
0
    nlist += other.nlist;
1271
0
    ndis += other.ndis;
1272
0
    nheap_updates += other.nheap_updates;
1273
0
    quantization_time += other.quantization_time;
1274
0
    search_time += other.search_time;
1275
0
}
1276
1277
// Global statistics object; search code in this file falls back to it
// when the caller supplies a null stats pointer.
IndexIVFStats indexIVF_stats;
1278
1279
/*************************************************************************
1280
 * InvertedListScanner
1281
 *************************************************************************/
1282
1283
size_t InvertedListScanner::scan_codes(
1284
        size_t list_size,
1285
        const uint8_t* codes,
1286
        const idx_t* ids,
1287
        float* simi,
1288
        idx_t* idxi,
1289
0
        size_t k) const {
1290
0
    size_t nup = 0;
1291
1292
0
    if (!keep_max) {
1293
0
        for (size_t j = 0; j < list_size; j++) {
1294
0
            if (sel != nullptr) {
1295
0
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1296
0
                if (!sel->is_member(id)) {
1297
0
                    codes += code_size;
1298
0
                    continue;
1299
0
                }
1300
0
            }
1301
1302
0
            float dis = distance_to_code(codes);
1303
0
            if (dis < simi[0]) {
1304
0
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1305
0
                maxheap_replace_top(k, simi, idxi, dis, id);
1306
0
                nup++;
1307
0
            }
1308
0
            codes += code_size;
1309
0
        }
1310
0
    } else {
1311
0
        for (size_t j = 0; j < list_size; j++) {
1312
0
            if (sel != nullptr) {
1313
0
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1314
0
                if (!sel->is_member(id)) {
1315
0
                    codes += code_size;
1316
0
                    continue;
1317
0
                }
1318
0
            }
1319
1320
0
            float dis = distance_to_code(codes);
1321
0
            if (dis > simi[0]) {
1322
0
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1323
0
                minheap_replace_top(k, simi, idxi, dis, id);
1324
0
                nup++;
1325
0
            }
1326
0
            codes += code_size;
1327
0
        }
1328
0
    }
1329
0
    return nup;
1330
0
}
1331
1332
size_t InvertedListScanner::iterate_codes(
1333
        InvertedListsIterator* it,
1334
        float* simi,
1335
        idx_t* idxi,
1336
        size_t k,
1337
0
        size_t& list_size) const {
1338
0
    size_t nup = 0;
1339
0
    list_size = 0;
1340
1341
0
    if (!keep_max) {
1342
0
        for (; it->is_available(); it->next()) {
1343
0
            auto id_and_codes = it->get_id_and_codes();
1344
0
            float dis = distance_to_code(id_and_codes.second);
1345
0
            if (dis < simi[0]) {
1346
0
                maxheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
1347
0
                nup++;
1348
0
            }
1349
0
            list_size++;
1350
0
        }
1351
0
    } else {
1352
0
        for (; it->is_available(); it->next()) {
1353
0
            auto id_and_codes = it->get_id_and_codes();
1354
0
            float dis = distance_to_code(id_and_codes.second);
1355
0
            if (dis > simi[0]) {
1356
0
                minheap_replace_top(k, simi, idxi, dis, id_and_codes.first);
1357
0
                nup++;
1358
0
            }
1359
0
            list_size++;
1360
0
        }
1361
0
    }
1362
0
    return nup;
1363
0
}
1364
1365
void InvertedListScanner::scan_codes_range(
1366
        size_t list_size,
1367
        const uint8_t* codes,
1368
        const idx_t* ids,
1369
        float radius,
1370
0
        RangeQueryResult& res) const {
1371
0
    for (size_t j = 0; j < list_size; j++) {
1372
0
        float dis = distance_to_code(codes);
1373
0
        bool keep = !keep_max
1374
0
                ? dis < radius
1375
0
                : dis > radius; // TODO templatize to remove this test
1376
0
        if (keep) {
1377
0
            int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
1378
0
            res.add(dis, id);
1379
0
        }
1380
0
        codes += code_size;
1381
0
    }
1382
0
}
1383
1384
void InvertedListScanner::iterate_codes_range(
1385
        InvertedListsIterator* it,
1386
        float radius,
1387
        RangeQueryResult& res,
1388
0
        size_t& list_size) const {
1389
0
    list_size = 0;
1390
0
    for (; it->is_available(); it->next()) {
1391
0
        auto id_and_codes = it->get_id_and_codes();
1392
0
        float dis = distance_to_code(id_and_codes.second);
1393
0
        bool keep = !keep_max
1394
0
                ? dis < radius
1395
0
                : dis > radius; // TODO templatize to remove this test
1396
0
        if (keep) {
1397
0
            res.add(dis, id_and_codes.first);
1398
0
        }
1399
0
        list_size++;
1400
0
    }
1401
0
}
1402
1403
} // namespace faiss