Coverage Report

Created: 2025-10-16 13:21

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexIVFFastScan.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexIVFFastScan.h>
9
10
#include <cassert>
11
#include <cinttypes>
12
#include <cstdio>
13
#include <set>
14
15
#include <omp.h>
16
17
#include <memory>
18
19
#include <faiss/IndexIVFPQ.h>
20
#include <faiss/impl/AuxIndexStructures.h>
21
#include <faiss/impl/FaissAssert.h>
22
#include <faiss/impl/LookupTableScaler.h>
23
#include <faiss/impl/pq4_fast_scan.h>
24
#include <faiss/impl/simd_result_handlers.h>
25
#include <faiss/invlists/BlockInvertedLists.h>
26
#include <faiss/utils/hamming.h>
27
#include <faiss/utils/quantize_lut.h>
28
#include <faiss/utils/utils.h>
29
30
namespace faiss {
31
32
using namespace simd_result_handlers;
33
34
inline size_t roundup(size_t a, size_t b) {
35
    return (a + b - 1) / b * b;
36
}
37
38
IndexIVFFastScan::IndexIVFFastScan(
39
        Index* quantizer,
40
        size_t d,
41
        size_t nlist,
42
        size_t code_size,
43
        MetricType metric)
44
0
        : IndexIVF(quantizer, d, nlist, code_size, metric) {
45
    // unlike other indexes, we prefer no residuals for performance reasons.
46
0
    by_residual = false;
47
0
    FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
48
0
}
49
50
0
IndexIVFFastScan::IndexIVFFastScan() {
51
0
    bbs = 0;
52
0
    M2 = 0;
53
0
    is_trained = false;
54
0
    by_residual = false;
55
0
}
56
57
void IndexIVFFastScan::init_fastscan(
58
        Quantizer* fine_quantizer,
59
        size_t M,
60
        size_t nbits_init,
61
        size_t nlist,
62
        MetricType /* metric */,
63
0
        int bbs_2) {
64
0
    FAISS_THROW_IF_NOT(bbs_2 % 32 == 0);
65
0
    FAISS_THROW_IF_NOT(nbits_init == 4);
66
0
    FAISS_THROW_IF_NOT(fine_quantizer->d == d);
67
68
0
    this->fine_quantizer = fine_quantizer;
69
0
    this->M = M;
70
0
    this->nbits = nbits_init;
71
0
    this->bbs = bbs_2;
72
0
    ksub = (1 << nbits_init);
73
0
    M2 = roundup(M, 2);
74
0
    code_size = M2 / 2;
75
0
    FAISS_THROW_IF_NOT(code_size == fine_quantizer->code_size);
76
77
0
    is_trained = false;
78
0
    replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
79
0
}
80
81
0
void IndexIVFFastScan::init_code_packer() {
82
0
    auto bil = dynamic_cast<BlockInvertedLists*>(invlists);
83
0
    FAISS_THROW_IF_NOT(bil);
84
0
    delete bil->packer; // in case there was one before
85
0
    bil->packer = get_CodePacker();
86
0
}
87
88
0
IndexIVFFastScan::~IndexIVFFastScan() = default;
89
90
/*********************************************************
91
 * Code management functions
92
 *********************************************************/
93
94
void IndexIVFFastScan::add_with_ids(
95
        idx_t n,
96
        const float* x,
97
0
        const idx_t* xids) {
98
0
    FAISS_THROW_IF_NOT(is_trained);
99
100
    // do some blocking to avoid excessive allocs
101
0
    constexpr idx_t bs = 65536;
102
0
    if (n > bs) {
103
0
        double t0 = getmillisecs();
104
0
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
105
0
            idx_t i1 = std::min(n, i0 + bs);
106
0
            if (verbose) {
107
0
                double t1 = getmillisecs();
108
0
                double elapsed_time = (t1 - t0) / 1000;
109
0
                double total_time = 0;
110
0
                if (i0 != 0) {
111
0
                    total_time = elapsed_time / i0 * n;
112
0
                }
113
0
                size_t mem = get_mem_usage_kb() / (1 << 10);
114
115
0
                printf("IndexIVFFastScan::add_with_ids %zd/%zd, time %.2f/%.2f, RSS %zdMB\n",
116
0
                       size_t(i1),
117
0
                       size_t(n),
118
0
                       elapsed_time,
119
0
                       total_time,
120
0
                       mem);
121
0
            }
122
0
            add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
123
0
        }
124
0
        return;
125
0
    }
126
0
    InterruptCallback::check();
127
128
0
    direct_map.check_can_add(xids);
129
0
    std::unique_ptr<idx_t[]> idx(new idx_t[n]);
130
0
    quantizer->assign(n, x, idx.get());
131
132
0
    AlignedTable<uint8_t> flat_codes(n * code_size);
133
0
    encode_vectors(n, x, idx.get(), flat_codes.get());
134
135
0
    DirectMapAdd dm_adder(direct_map, n, xids);
136
0
    BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
137
0
    FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
138
139
    // prepare batches
140
0
    std::vector<idx_t> order(n);
141
0
    for (idx_t i = 0; i < n; i++) {
142
0
        order[i] = i;
143
0
    }
144
145
    // TODO should not need stable
146
0
    std::stable_sort(order.begin(), order.end(), [&idx](idx_t a, idx_t b) {
147
0
        return idx[a] < idx[b];
148
0
    });
149
150
    // TODO parallelize
151
0
    idx_t i0 = 0;
152
0
    while (i0 < n) {
153
0
        idx_t list_no = idx[order[i0]];
154
0
        idx_t i1 = i0 + 1;
155
0
        while (i1 < n && idx[order[i1]] == list_no) {
156
0
            i1++;
157
0
        }
158
159
0
        if (list_no == -1) {
160
0
            i0 = i1;
161
0
            continue;
162
0
        }
163
164
        // make linear array
165
0
        AlignedTable<uint8_t> list_codes((i1 - i0) * code_size);
166
0
        size_t list_size = bil->list_size(list_no);
167
168
0
        bil->resize(list_no, list_size + i1 - i0);
169
170
0
        for (idx_t i = i0; i < i1; i++) {
171
0
            size_t ofs = list_size + i - i0;
172
0
            idx_t id = xids ? xids[order[i]] : ntotal + order[i];
173
0
            dm_adder.add(order[i], list_no, ofs);
174
0
            bil->ids[list_no][ofs] = id;
175
0
            memcpy(list_codes.data() + (i - i0) * code_size,
176
0
                   flat_codes.data() + order[i] * code_size,
177
0
                   code_size);
178
0
        }
179
0
        pq4_pack_codes_range(
180
0
                list_codes.data(),
181
0
                M,
182
0
                list_size,
183
0
                list_size + i1 - i0,
184
0
                bbs,
185
0
                M2,
186
0
                bil->codes[list_no].data());
187
188
0
        i0 = i1;
189
0
    }
190
191
0
    ntotal += n;
192
0
}
193
194
0
CodePacker* IndexIVFFastScan::get_CodePacker() const {
195
0
    return new CodePackerPQ4(M, bbs);
196
0
}
197
198
/*********************************************************
199
 * search
200
 *********************************************************/
201
202
namespace {
203
204
template <class C, typename dis_t>
205
void estimators_from_tables_generic(
206
        const IndexIVFFastScan& index,
207
        const uint8_t* codes,
208
        size_t ncodes,
209
        const dis_t* dis_table,
210
        const int64_t* ids,
211
        float bias,
212
        size_t k,
213
        typename C::T* heap_dis,
214
        int64_t* heap_ids,
215
0
        const NormTableScaler* scaler) {
216
0
    using accu_t = typename C::T;
217
0
    size_t nscale = scaler ? scaler->nscale : 0;
218
0
    for (size_t j = 0; j < ncodes; ++j) {
219
0
        BitstringReader bsr(codes + j * index.code_size, index.code_size);
220
0
        accu_t dis = bias;
221
0
        const dis_t* __restrict dt = dis_table;
222
223
0
        for (size_t m = 0; m < index.M - nscale; m++) {
224
0
            uint64_t c = bsr.read(index.nbits);
225
0
            dis += dt[c];
226
0
            dt += index.ksub;
227
0
        }
228
229
0
        if (scaler) {
230
0
            for (size_t m = 0; m < nscale; m++) {
231
0
                uint64_t c = bsr.read(index.nbits);
232
0
                dis += scaler->scale_one(dt[c]);
233
0
                dt += index.ksub;
234
0
            }
235
0
        }
236
237
0
        if (C::cmp(heap_dis[0], dis)) {
238
0
            heap_pop<C>(k, heap_dis, heap_ids);
239
0
            heap_push<C>(k, heap_dis, heap_ids, dis, ids[j]);
240
0
        }
241
0
    }
242
0
}
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_130estimators_from_tables_genericINS_4CMaxIflEEfEEvRKNS_16IndexIVFFastScanEPKhmPKT0_PKlfmPNT_1TEPlPKNS_15NormTableScalerE
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_130estimators_from_tables_genericINS_4CMinIflEEfEEvRKNS_16IndexIVFFastScanEPKhmPKT0_PKlfmPNT_1TEPlPKNS_15NormTableScalerE
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_130estimators_from_tables_genericINS_4CMaxItlEEhEEvRKNS_16IndexIVFFastScanEPKhmPKT0_PKlfmPNT_1TEPlPKNS_15NormTableScalerE
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_130estimators_from_tables_genericINS_4CMinItlEEhEEvRKNS_16IndexIVFFastScanEPKhmPKT0_PKlfmPNT_1TEPlPKNS_15NormTableScalerE
243
244
using namespace quantize_lut;
245
246
} // anonymous namespace
247
248
/*********************************************************
249
 * Look-Up Table functions
250
 *********************************************************/
251
252
void IndexIVFFastScan::compute_LUT_uint8(
253
        size_t n,
254
        const float* x,
255
        const CoarseQuantized& cq,
256
        AlignedTable<uint8_t>& dis_tables,
257
        AlignedTable<uint16_t>& biases,
258
0
        float* normalizers) const {
259
0
    AlignedTable<float> dis_tables_float;
260
0
    AlignedTable<float> biases_float;
261
262
0
    compute_LUT(n, x, cq, dis_tables_float, biases_float);
263
0
    size_t nprobe = cq.nprobe;
264
0
    bool lut_is_3d = lookup_table_is_3d();
265
0
    size_t dim123 = ksub * M;
266
0
    size_t dim123_2 = ksub * M2;
267
0
    if (lut_is_3d) {
268
0
        dim123 *= nprobe;
269
0
        dim123_2 *= nprobe;
270
0
    }
271
0
    dis_tables.resize(n * dim123_2);
272
0
    if (biases_float.get()) {
273
0
        biases.resize(n * nprobe);
274
0
    }
275
276
    // OMP for MSVC requires i to have signed integral type
277
0
#pragma omp parallel for if (n > 100)
278
0
    for (int64_t i = 0; i < n; i++) {
279
0
        const float* t_in = dis_tables_float.get() + i * dim123;
280
0
        const float* b_in = nullptr;
281
0
        uint8_t* t_out = dis_tables.get() + i * dim123_2;
282
0
        uint16_t* b_out = nullptr;
283
0
        if (biases_float.get()) {
284
0
            b_in = biases_float.get() + i * nprobe;
285
0
            b_out = biases.get() + i * nprobe;
286
0
        }
287
288
0
        quantize_LUT_and_bias(
289
0
                nprobe,
290
0
                M,
291
0
                ksub,
292
0
                lut_is_3d,
293
0
                t_in,
294
0
                b_in,
295
0
                t_out,
296
0
                M2,
297
0
                b_out,
298
0
                normalizers + 2 * i,
299
0
                normalizers + 2 * i + 1);
300
0
    }
301
0
}
302
303
/*********************************************************
304
 * Search functions
305
 *********************************************************/
306
307
void IndexIVFFastScan::search(
308
        idx_t n,
309
        const float* x,
310
        idx_t k,
311
        float* distances,
312
        idx_t* labels,
313
0
        const SearchParameters* params_in) const {
314
0
    const IVFSearchParameters* params = nullptr;
315
0
    if (params_in) {
316
0
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
317
0
        FAISS_THROW_IF_NOT_MSG(
318
0
                params, "IndexIVFFastScan params have incorrect type");
319
0
    }
320
321
0
    search_preassigned(
322
0
            n, x, k, nullptr, nullptr, distances, labels, false, params);
323
0
}
324
325
void IndexIVFFastScan::search_preassigned(
326
        idx_t n,
327
        const float* x,
328
        idx_t k,
329
        const idx_t* assign,
330
        const float* centroid_dis,
331
        float* distances,
332
        idx_t* labels,
333
        bool store_pairs,
334
        const IVFSearchParameters* params,
335
0
        IndexIVFStats* stats) const {
336
0
    size_t nprobe = this->nprobe;
337
0
    if (params) {
338
0
        FAISS_THROW_IF_NOT(params->max_codes == 0);
339
0
        nprobe = params->nprobe;
340
0
    }
341
342
0
    FAISS_THROW_IF_NOT_MSG(
343
0
            !store_pairs, "store_pairs not supported for this index");
344
0
    FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
345
0
    FAISS_THROW_IF_NOT(k > 0);
346
347
0
    const CoarseQuantized cq = {nprobe, centroid_dis, assign};
348
0
    search_dispatch_implem(n, x, k, distances, labels, cq, nullptr, params);
349
0
}
350
351
void IndexIVFFastScan::range_search(
352
        idx_t n,
353
        const float* x,
354
        float radius,
355
        RangeSearchResult* result,
356
0
        const SearchParameters* params_in) const {
357
0
    size_t nprobe = this->nprobe;
358
0
    const IVFSearchParameters* params = nullptr;
359
0
    if (params_in) {
360
0
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
361
0
        FAISS_THROW_IF_NOT_MSG(
362
0
                params, "IndexIVFFastScan params have incorrect type");
363
0
        nprobe = params->nprobe;
364
0
    }
365
366
0
    const CoarseQuantized cq = {nprobe, nullptr, nullptr};
367
0
    range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params);
368
0
}
369
370
namespace {
371
372
template <class C>
373
ResultHandlerCompare<C, true>* make_knn_handler_fixC(
374
        int impl,
375
        idx_t n,
376
        idx_t k,
377
        float* distances,
378
        idx_t* labels,
379
0
        const IDSelector* sel) {
380
0
    using HeapHC = HeapHandler<C, true>;
381
0
    using ReservoirHC = ReservoirHandler<C, true>;
382
0
    using SingleResultHC = SingleResultHandler<C, true>;
383
384
0
    if (k == 1) {
385
0
        return new SingleResultHC(n, 0, distances, labels, sel);
386
0
    } else if (impl % 2 == 0) {
387
0
        return new HeapHC(n, 0, k, distances, labels, sel);
388
0
    } else /* if (impl % 2 == 1) */ {
389
0
        return new ReservoirHC(n, 0, k, 2 * k, distances, labels, sel);
390
0
    }
391
0
}
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_121make_knn_handler_fixCINS_4CMaxItlEEEEPNS_20simd_result_handlers20ResultHandlerCompareIT_Lb1EEEillPfPlPKNS_10IDSelectorE
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZN5faiss12_GLOBAL__N_121make_knn_handler_fixCINS_4CMinItlEEEEPNS_20simd_result_handlers20ResultHandlerCompareIT_Lb1EEEillPfPlPKNS_10IDSelectorE
392
393
SIMDResultHandlerToFloat* make_knn_handler(
394
        bool is_max,
395
        int impl,
396
        idx_t n,
397
        idx_t k,
398
        float* distances,
399
        idx_t* labels,
400
0
        const IDSelector* sel) {
401
0
    if (is_max) {
402
0
        return make_knn_handler_fixC<CMax<uint16_t, int64_t>>(
403
0
                impl, n, k, distances, labels, sel);
404
0
    } else {
405
0
        return make_knn_handler_fixC<CMin<uint16_t, int64_t>>(
406
0
                impl, n, k, distances, labels, sel);
407
0
    }
408
0
}
409
410
using CoarseQuantized = IndexIVFFastScan::CoarseQuantized;
411
412
struct CoarseQuantizedWithBuffer : CoarseQuantized {
413
    explicit CoarseQuantizedWithBuffer(const CoarseQuantized& cq)
414
0
            : CoarseQuantized(cq) {}
415
416
0
    bool done() const {
417
0
        return ids != nullptr;
418
0
    }
419
420
    std::vector<idx_t> ids_buffer;
421
    std::vector<float> dis_buffer;
422
423
    void quantize(
424
            const Index* quantizer,
425
            idx_t n,
426
            const float* x,
427
0
            const SearchParameters* quantizer_params) {
428
0
        dis_buffer.resize(nprobe * n);
429
0
        ids_buffer.resize(nprobe * n);
430
0
        quantizer->search(
431
0
                n,
432
0
                x,
433
0
                nprobe,
434
0
                dis_buffer.data(),
435
0
                ids_buffer.data(),
436
0
                quantizer_params);
437
0
        dis = dis_buffer.data();
438
0
        ids = ids_buffer.data();
439
0
    }
440
};
441
442
struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer {
443
    size_t i0, i1;
444
    CoarseQuantizedSlice(const CoarseQuantized& cq, size_t i0, size_t i1)
445
0
            : CoarseQuantizedWithBuffer(cq), i0(i0), i1(i1) {
446
0
        if (done()) {
447
0
            dis += nprobe * i0;
448
0
            ids += nprobe * i0;
449
0
        }
450
0
    }
451
452
    void quantize_slice(
453
            const Index* quantizer,
454
            const float* x,
455
0
            const SearchParameters* quantizer_params) {
456
0
        quantize(quantizer, i1 - i0, x + quantizer->d * i0, quantizer_params);
457
0
    }
458
};
459
460
int compute_search_nslice(
461
        const IndexIVFFastScan* index,
462
        size_t n,
463
0
        size_t nprobe) {
464
0
    int nslice;
465
0
    if (n <= omp_get_max_threads()) {
466
0
        nslice = n;
467
0
    } else if (index->lookup_table_is_3d()) {
468
        // make sure we don't make too big LUT tables
469
0
        size_t lut_size_per_query = index->M * index->ksub * nprobe *
470
0
                (sizeof(float) + sizeof(uint8_t));
471
472
0
        size_t max_lut_size = precomputed_table_max_bytes;
473
        // how many queries we can handle within mem budget
474
0
        size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1));
475
0
        nslice = roundup(
476
0
                std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads());
477
0
    } else {
478
        // LUTs unlikely to be a limiting factor
479
0
        nslice = omp_get_max_threads();
480
0
    }
481
0
    return nslice;
482
0
}
483
484
} // namespace
485
486
void IndexIVFFastScan::search_dispatch_implem(
487
        idx_t n,
488
        const float* x,
489
        idx_t k,
490
        float* distances,
491
        idx_t* labels,
492
        const CoarseQuantized& cq_in,
493
        const NormTableScaler* scaler,
494
0
        const IVFSearchParameters* params) const {
495
0
    const idx_t nprobe = params ? params->nprobe : this->nprobe;
496
0
    const IDSelector* sel = (params) ? params->sel : nullptr;
497
0
    const SearchParameters* quantizer_params =
498
0
            params ? params->quantizer_params : nullptr;
499
500
0
    bool is_max = !is_similarity_metric(metric_type);
501
0
    using RH = SIMDResultHandlerToFloat;
502
503
0
    if (n == 0) {
504
0
        return;
505
0
    }
506
507
    // actual implementation used
508
0
    int impl = implem;
509
510
0
    if (impl == 0) {
511
0
        if (bbs == 32) {
512
0
            impl = 12;
513
0
        } else {
514
0
            impl = 10;
515
0
        }
516
0
        if (k > 20) { // use reservoir rather than heap
517
0
            impl++;
518
0
        }
519
0
    }
520
521
0
    bool multiple_threads =
522
0
            n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
523
0
    if (impl >= 100) {
524
0
        multiple_threads = false;
525
0
        impl -= 100;
526
0
    }
527
528
0
    CoarseQuantizedWithBuffer cq(cq_in);
529
0
    cq.nprobe = nprobe;
530
531
0
    if (!cq.done() && !multiple_threads) {
532
        // we do the coarse quantization here except when search is
533
        // sliced over threads (then it is more efficient to have each thread do
534
        // its own coarse quantization)
535
0
        cq.quantize(quantizer, n, x, quantizer_params);
536
0
        invlists->prefetch_lists(cq.ids, n * cq.nprobe);
537
0
    }
538
539
0
    if (impl == 1) {
540
0
        if (is_max) {
541
0
            search_implem_1<CMax<float, int64_t>>(
542
0
                    n, x, k, distances, labels, cq, scaler, params);
543
0
        } else {
544
0
            search_implem_1<CMin<float, int64_t>>(
545
0
                    n, x, k, distances, labels, cq, scaler, params);
546
0
        }
547
0
    } else if (impl == 2) {
548
0
        if (is_max) {
549
0
            search_implem_2<CMax<uint16_t, int64_t>>(
550
0
                    n, x, k, distances, labels, cq, scaler, params);
551
0
        } else {
552
0
            search_implem_2<CMin<uint16_t, int64_t>>(
553
0
                    n, x, k, distances, labels, cq, scaler, params);
554
0
        }
555
0
    } else if (impl >= 10 && impl <= 15) {
556
0
        size_t ndis = 0, nlist_visited = 0;
557
558
0
        if (!multiple_threads) {
559
            // clang-format off
560
0
            if (impl == 12 || impl == 13) {
561
0
                std::unique_ptr<RH> handler(
562
0
                    make_knn_handler(
563
0
                        is_max, 
564
0
                        impl, 
565
0
                        n, 
566
0
                        k, 
567
0
                        distances, 
568
0
                        labels, sel
569
0
                    )
570
0
                );
571
0
                search_implem_12(
572
0
                        n, x, *handler.get(),
573
0
                        cq, &ndis, &nlist_visited, scaler, params);
574
0
            } else if (impl == 14 || impl == 15) {
575
0
                search_implem_14(
576
0
                        n, x, k, distances, labels,
577
0
                        cq, impl, scaler, params);
578
0
            } else {
579
0
                std::unique_ptr<RH> handler(
580
0
                    make_knn_handler(
581
0
                        is_max, 
582
0
                        impl, 
583
0
                        n, 
584
0
                        k, 
585
0
                        distances, 
586
0
                        labels,
587
0
                        sel
588
0
                    )
589
0
                );
590
0
                search_implem_10(
591
0
                        n, x, *handler.get(), cq,
592
0
                        &ndis, &nlist_visited, scaler, params);
593
0
            }
594
            // clang-format on
595
0
        } else {
596
            // explicitly slice over threads
597
0
            int nslice = compute_search_nslice(this, n, cq.nprobe);
598
0
            if (impl == 14 || impl == 15) {
599
                // this might require slicing if there are too
600
                // many queries (for now we keep this simple)
601
0
                search_implem_14(
602
0
                        n, x, k, distances, labels, cq, impl, scaler, params);
603
0
            } else {
604
0
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
605
0
                for (int slice = 0; slice < nslice; slice++) {
606
0
                    idx_t i0 = n * slice / nslice;
607
0
                    idx_t i1 = n * (slice + 1) / nslice;
608
0
                    float* dis_i = distances + i0 * k;
609
0
                    idx_t* lab_i = labels + i0 * k;
610
0
                    CoarseQuantizedSlice cq_i(cq, i0, i1);
611
0
                    if (!cq_i.done()) {
612
0
                        cq_i.quantize_slice(quantizer, x, quantizer_params);
613
0
                    }
614
0
                    std::unique_ptr<RH> handler(make_knn_handler(
615
0
                            is_max, impl, i1 - i0, k, dis_i, lab_i, sel));
616
                    // clang-format off
617
0
                    if (impl == 12 || impl == 13) {
618
0
                        search_implem_12(
619
0
                                i1 - i0, x + i0 * d, *handler.get(),
620
0
                                cq_i, &ndis, &nlist_visited, scaler, params);
621
0
                    } else {
622
0
                        search_implem_10(
623
0
                                i1 - i0, x + i0 * d, *handler.get(),
624
0
                                cq_i, &ndis, &nlist_visited, scaler, params);
625
0
                    }
626
                    // clang-format on
627
0
                }
628
0
            }
629
0
        }
630
0
        indexIVF_stats.nq += n;
631
0
        indexIVF_stats.ndis += ndis;
632
0
        indexIVF_stats.nlist += nlist_visited;
633
0
    } else {
634
0
        FAISS_THROW_FMT("implem %d does not exist", implem);
635
0
    }
636
0
}
637
638
void IndexIVFFastScan::range_search_dispatch_implem(
639
        idx_t n,
640
        const float* x,
641
        float radius,
642
        RangeSearchResult& rres,
643
        const CoarseQuantized& cq_in,
644
        const NormTableScaler* scaler,
645
0
        const IVFSearchParameters* params) const {
646
    // const idx_t nprobe = params ? params->nprobe : this->nprobe;
647
0
    const IDSelector* sel = (params) ? params->sel : nullptr;
648
0
    const SearchParameters* quantizer_params =
649
0
            params ? params->quantizer_params : nullptr;
650
651
0
    bool is_max = !is_similarity_metric(metric_type);
652
653
0
    if (n == 0) {
654
0
        return;
655
0
    }
656
657
    // actual implementation used
658
0
    int impl = implem;
659
660
0
    if (impl == 0) {
661
0
        if (bbs == 32) {
662
0
            impl = 12;
663
0
        } else {
664
0
            impl = 10;
665
0
        }
666
0
    }
667
668
0
    CoarseQuantizedWithBuffer cq(cq_in);
669
670
0
    bool multiple_threads =
671
0
            n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
672
0
    if (impl >= 100) {
673
0
        multiple_threads = false;
674
0
        impl -= 100;
675
0
    }
676
677
0
    if (!multiple_threads && !cq.done()) {
678
0
        cq.quantize(quantizer, n, x, quantizer_params);
679
0
        invlists->prefetch_lists(cq.ids, n * cq.nprobe);
680
0
    }
681
682
0
    size_t ndis = 0, nlist_visited = 0;
683
684
0
    if (!multiple_threads) { // single thread
685
0
        std::unique_ptr<SIMDResultHandlerToFloat> handler;
686
0
        if (is_max) {
687
0
            handler.reset(new RangeHandler<CMax<uint16_t, int64_t>, true>(
688
0
                    rres, radius, 0, sel));
689
0
        } else {
690
0
            handler.reset(new RangeHandler<CMin<uint16_t, int64_t>, true>(
691
0
                    rres, radius, 0, sel));
692
0
        }
693
0
        if (impl == 12) {
694
0
            search_implem_12(
695
0
                    n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
696
0
        } else if (impl == 10) {
697
0
            search_implem_10(
698
0
                    n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
699
0
        } else {
700
0
            FAISS_THROW_FMT("Range search implem %d not implemented", impl);
701
0
        }
702
0
    } else {
703
        // explicitly slice over threads
704
0
        int nslice = compute_search_nslice(this, n, cq.nprobe);
705
0
#pragma omp parallel
706
0
        {
707
0
            RangeSearchPartialResult pres(&rres);
708
709
0
#pragma omp for reduction(+ : ndis, nlist_visited)
710
0
            for (int slice = 0; slice < nslice; slice++) {
711
0
                idx_t i0 = n * slice / nslice;
712
0
                idx_t i1 = n * (slice + 1) / nslice;
713
0
                CoarseQuantizedSlice cq_i(cq, i0, i1);
714
0
                if (!cq_i.done()) {
715
0
                    cq_i.quantize_slice(quantizer, x, quantizer_params);
716
0
                }
717
0
                std::unique_ptr<SIMDResultHandlerToFloat> handler;
718
0
                if (is_max) {
719
0
                    handler.reset(new PartialRangeHandler<
720
0
                                  CMax<uint16_t, int64_t>,
721
0
                                  true>(pres, radius, 0, i0, i1, sel));
722
0
                } else {
723
0
                    handler.reset(new PartialRangeHandler<
724
0
                                  CMin<uint16_t, int64_t>,
725
0
                                  true>(pres, radius, 0, i0, i1, sel));
726
0
                }
727
728
0
                if (impl == 12 || impl == 13) {
729
0
                    search_implem_12(
730
0
                            i1 - i0,
731
0
                            x + i0 * d,
732
0
                            *handler.get(),
733
0
                            cq_i,
734
0
                            &ndis,
735
0
                            &nlist_visited,
736
0
                            scaler,
737
0
                            params);
738
0
                } else {
739
0
                    search_implem_10(
740
0
                            i1 - i0,
741
0
                            x + i0 * d,
742
0
                            *handler.get(),
743
0
                            cq_i,
744
0
                            &ndis,
745
0
                            &nlist_visited,
746
0
                            scaler,
747
0
                            params);
748
0
                }
749
0
            }
750
0
            pres.finalize();
751
0
        }
752
0
    }
753
754
0
    indexIVF_stats.nq += n;
755
0
    indexIVF_stats.ndis += ndis;
756
0
    indexIVF_stats.nlist += nlist_visited;
757
0
}
758
759
template <class C>
760
void IndexIVFFastScan::search_implem_1(
761
        idx_t n,
762
        const float* x,
763
        idx_t k,
764
        float* distances,
765
        idx_t* labels,
766
        const CoarseQuantized& cq,
767
        const NormTableScaler* scaler,
768
0
        const IVFSearchParameters* params) const {
769
0
    FAISS_THROW_IF_NOT(orig_invlists);
770
771
0
    size_t dim12 = ksub * M;
772
0
    AlignedTable<float> dis_tables;
773
0
    AlignedTable<float> biases;
774
775
0
    compute_LUT(n, x, cq, dis_tables, biases);
776
777
0
    bool single_LUT = !lookup_table_is_3d();
778
779
0
    size_t ndis = 0, nlist_visited = 0;
780
0
    size_t nprobe = cq.nprobe;
781
0
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
782
0
    for (idx_t i = 0; i < n; i++) {
783
0
        int64_t* heap_ids = labels + i * k;
784
0
        float* heap_dis = distances + i * k;
785
0
        heap_heapify<C>(k, heap_dis, heap_ids);
786
0
        float* LUT = nullptr;
787
788
0
        if (single_LUT) {
789
0
            LUT = dis_tables.get() + i * dim12;
790
0
        }
791
0
        for (idx_t j = 0; j < nprobe; j++) {
792
0
            if (!single_LUT) {
793
0
                LUT = dis_tables.get() + (i * nprobe + j) * dim12;
794
0
            }
795
0
            idx_t list_no = cq.ids[i * nprobe + j];
796
0
            if (list_no < 0)
797
0
                continue;
798
0
            size_t ls = orig_invlists->list_size(list_no);
799
0
            if (ls == 0)
800
0
                continue;
801
0
            InvertedLists::ScopedCodes codes(orig_invlists, list_no);
802
0
            InvertedLists::ScopedIds ids(orig_invlists, list_no);
803
804
0
            float bias = biases.get() ? biases[i * nprobe + j] : 0;
805
806
0
            estimators_from_tables_generic<C>(
807
0
                    *this,
808
0
                    codes.get(),
809
0
                    ls,
810
0
                    LUT,
811
0
                    ids.get(),
812
0
                    bias,
813
0
                    k,
814
0
                    heap_dis,
815
0
                    heap_ids,
816
0
                    scaler);
817
0
            nlist_visited++;
818
0
            ndis++;
819
0
        }
820
0
        heap_reorder<C>(k, heap_dis, heap_ids);
821
0
    }
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZNK5faiss16IndexIVFFastScan15search_implem_1INS_4CMaxIflEEEEvlPKflPfPlRKNS0_15CoarseQuantizedEPKNS_15NormTableScalerEPKNS_19SearchParametersIVFE.omp_outlined_debug__
Unexecuted instantiation: IndexIVFFastScan.cpp:_ZNK5faiss16IndexIVFFastScan15search_implem_1INS_4CMinIflEEEEvlPKflPfPlRKNS0_15CoarseQuantizedEPKNS_15NormTableScalerEPKNS_19SearchParametersIVFE.omp_outlined_debug__
822
0
    indexIVF_stats.nq += n;
823
0
    indexIVF_stats.ndis += ndis;
824
0
    indexIVF_stats.nlist += nlist_visited;
825
0
}
Unexecuted instantiation: _ZNK5faiss16IndexIVFFastScan15search_implem_1INS_4CMaxIflEEEEvlPKflPfPlRKNS0_15CoarseQuantizedEPKNS_15NormTableScalerEPKNS_19SearchParametersIVFE
Unexecuted instantiation: _ZNK5faiss16IndexIVFFastScan15search_implem_1INS_4CMinIflEEEEvlPKflPfPlRKNS0_15CoarseQuantizedEPKNS_15NormTableScalerEPKNS_19SearchParametersIVFE
826
827
/// Search implementation 2: generic (non-SIMD) accumulation over the
/// *original* (non-packed) inverted lists, with distances kept as
/// quantized uint16 values until the final conversion to float.
///
/// @tparam C     heap comparator (CMax for L2, CMin for inner product)
/// @param n      number of queries
/// @param x      query vectors, size n * d
/// @param k      number of results per query
/// @param distances  output distances, size n * k
/// @param labels     output ids, size n * k
/// @param cq     precomputed coarse quantization (ids + nprobe)
/// @param scaler optional scaler for norm tables
/// @param params unused here (coarse quantization already done)
template <class C>
void IndexIVFFastScan::search_implem_2(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const CoarseQuantized& cq,
        const NormTableScaler* scaler,
        const IVFSearchParameters* params) const {
    // this implementation scans the non-packed codes
    FAISS_THROW_IF_NOT(orig_invlists);

    // size of one lookup table (ksub entries per sub-quantizer)
    size_t dim12 = ksub * M2;
    AlignedTable<uint8_t> dis_tables;
    AlignedTable<uint16_t> biases;
    // per-query (a, b) pairs to map uint16 distances back to float
    std::unique_ptr<float[]> normalizers(new float[2 * n]);

    compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());

    // 2D tables: one LUT per query; 3D: one LUT per (query, probe)
    bool single_LUT = !lookup_table_is_3d();

    size_t ndis = 0, nlist_visited = 0;
    size_t nprobe = cq.nprobe;

#pragma omp parallel for reduction(+ : ndis, nlist_visited)
    for (idx_t i = 0; i < n; i++) {
        // per-query uint16 heap; ids are written directly into `labels`
        std::vector<uint16_t> tmp_dis(k);
        int64_t* heap_ids = labels + i * k;
        uint16_t* heap_dis = tmp_dis.data();
        heap_heapify<C>(k, heap_dis, heap_ids);
        const uint8_t* LUT = nullptr;

        if (single_LUT) {
            LUT = dis_tables.get() + i * dim12;
        }
        for (idx_t j = 0; j < nprobe; j++) {
            if (!single_LUT) {
                LUT = dis_tables.get() + (i * nprobe + j) * dim12;
            }
            idx_t list_no = cq.ids[i * nprobe + j];
            if (list_no < 0)
                continue;
            size_t ls = orig_invlists->list_size(list_no);
            if (ls == 0)
                continue;
            InvertedLists::ScopedCodes codes(orig_invlists, list_no);
            InvertedLists::ScopedIds ids(orig_invlists, list_no);

            // per-probe additive bias (e.g. coarse distance term), if any
            uint16_t bias = biases.get() ? biases[i * nprobe + j] : 0;

            estimators_from_tables_generic<C>(
                    *this,
                    codes.get(),
                    ls,
                    LUT,
                    ids.get(),
                    bias,
                    k,
                    heap_dis,
                    heap_ids,
                    scaler);

            nlist_visited++;
            ndis += ls;
        }
        heap_reorder<C>(k, heap_dis, heap_ids);
        // convert distances to float
        {
            // invert the quantization: float_dis = b + uint16_dis / a
            float one_a = 1 / normalizers[2 * i], b = normalizers[2 * i + 1];
            if (skip & 16) {
                // skip bit 16: normalizers were not filled, use identity
                one_a = 1;
                b = 0;
            }
            float* heap_dis_float = distances + i * k;
            for (int j = 0; j < k; j++) {
                heap_dis_float[j] = b + heap_dis[j] * one_a;
            }
        }
    }
    // stats update is outside the parallel region (reduction results)
    indexIVF_stats.nq += n;
    indexIVF_stats.ndis += ndis;
    indexIVF_stats.nlist += nlist_visited;
}
/// Search implementation 10: single-query SIMD accumulation.
/// Visits the probed inverted lists one (query, probe) pair at a time and
/// runs the pq4 accumulation kernel on the bbs-packed codes. Results are
/// collected by `handler`, which also converts quantized distances back to
/// float using `normalizers`.
///
/// @param handler   SIMD result handler; its q_map/dbias/ntotal/id_map
///                  fields are (re)pointed before each kernel call
/// @param ndis_out  output: number of inverted lists scanned
///                  (NOTE(review): counts lists, not individual distances)
/// @param nlist_out output: set to the index's total nlist
void IndexIVFFastScan::search_implem_10(
        idx_t n,
        const float* x,
        SIMDResultHandlerToFloat& handler,
        const CoarseQuantized& cq,
        size_t* ndis_out,
        size_t* nlist_out,
        const NormTableScaler* scaler,
        const IVFSearchParameters* params) const {
    size_t dim12 = ksub * M2;
    AlignedTable<uint8_t> dis_tables;
    AlignedTable<uint16_t> biases;
    std::unique_ptr<float[]> normalizers(new float[2 * n]);

    compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());

    bool single_LUT = !lookup_table_is_3d();

    size_t ndis = 0;
    // one-element query map: each kernel call processes a single query
    int qmap1[1];

    handler.q_map = qmap1;
    // skip bit 16 disables the uint16 -> float denormalization
    handler.begin(skip & 16 ? nullptr : normalizers.get());
    size_t nprobe = cq.nprobe;

    for (idx_t i = 0; i < n; i++) {
        const uint8_t* LUT = nullptr;
        qmap1[0] = i;

        if (single_LUT) {
            LUT = dis_tables.get() + i * dim12;
        }
        for (idx_t j = 0; j < nprobe; j++) {
            size_t ij = i * nprobe + j;
            if (!single_LUT) {
                LUT = dis_tables.get() + ij * dim12;
            }
            if (biases.get()) {
                handler.dbias = biases.get() + ij;
            }

            idx_t list_no = cq.ids[ij];
            if (list_no < 0) {
                continue;
            }
            size_t ls = invlists->list_size(list_no);
            if (ls == 0) {
                continue;
            }

            InvertedLists::ScopedCodes codes(invlists, list_no);
            InvertedLists::ScopedIds ids(invlists, list_no);

            handler.ntotal = ls;
            handler.id_map = ids.get();

            // scan the whole list, rounded up to a multiple of the block
            // size bbs (the packed storage is padded accordingly)
            pq4_accumulate_loop(
                    1,
                    roundup(ls, bbs),
                    bbs,
                    M2,
                    codes.get(),
                    LUT,
                    handler,
                    scaler);

            ndis++;
        }
    }

    handler.end();
    *ndis_out = ndis;
    *nlist_out = nlist;
}
/// Search implementation 12: multi-query SIMD accumulation, sequential.
/// Collects all (query, inverted list) visits, sorts them by list so that
/// queries probing the same list are processed together, repacks the
/// relevant LUTs in kernel order, and runs the qbs accumulation kernel
/// once per bundle. Single-threaded.
///
/// @param ndis_out  output: total number of (query, code) distance
///                  computations performed
/// @param nlist_out output: set to the index's total nlist
void IndexIVFFastScan::search_implem_12(
        idx_t n,
        const float* x,
        SIMDResultHandlerToFloat& handler,
        const CoarseQuantized& cq,
        size_t* ndis_out,
        size_t* nlist_out,
        const NormTableScaler* scaler,
        const IVFSearchParameters* params) const {
    if (n == 0) { // does not work well with reservoir
        return;
    }
    // the qbs kernels assume the default block size
    FAISS_THROW_IF_NOT(bbs == 32);

    size_t dim12 = ksub * M2;
    AlignedTable<uint8_t> dis_tables;
    AlignedTable<uint16_t> biases;
    std::unique_ptr<float[]> normalizers(new float[2 * n]);

    compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());

    handler.begin(skip & 16 ? nullptr : normalizers.get());

    // one pending visit of a query to an inverted list
    struct QC {
        int qno;     // sequence number of the query
        int list_no; // list to visit
        int rank;    // this is the rank'th result of the coarse quantizer
    };
    bool single_LUT = !lookup_table_is_3d();
    size_t nprobe = cq.nprobe;

    std::vector<QC> qcs;
    {
        int ij = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < nprobe; j++) {
                if (cq.ids[ij] >= 0) {
                    qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
                }
                ij++;
            }
        }
        // group visits to the same list together (stable order not needed)
        std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
            return a.list_no < b.list_no;
        });
    }

    // prepare the result handlers

    // max number of queries bundled into one kernel call
    int actual_qbs2 = this->qbs2 ? this->qbs2 : 11;

    std::vector<uint16_t> tmp_bias;
    if (biases.get()) {
        tmp_bias.resize(actual_qbs2);
        handler.dbias = tmp_bias.data();
    }

    size_t ndis = 0;

    size_t i0 = 0;
    // NOTE(review): these timers are never updated below, so the stats
    // increments at the end add 0; kept for interface parity
    uint64_t t_copy_pack = 0, t_scan = 0;
    while (i0 < qcs.size()) {
        // find all queries that access this inverted list
        int list_no = qcs[i0].list_no;
        size_t i1 = i0 + 1;

        // bundle at most actual_qbs2 queries per kernel call
        while (i1 < qcs.size() && i1 < i0 + actual_qbs2) {
            if (qcs[i1].list_no != list_no) {
                break;
            }
            i1++;
        }

        size_t list_size = invlists->list_size(list_no);

        if (list_size == 0) {
            i0 = i1;
            continue;
        }

        // re-organize LUTs and biases into the right order
        int nc = i1 - i0;

        std::vector<int> q_map(nc), lut_entries(nc);
        AlignedTable<uint8_t> LUT(nc * dim12);
        memset(LUT.get(), -1, nc * dim12);
        int qbs_for_list = pq4_preferred_qbs(nc);

        for (size_t i = i0; i < i1; i++) {
            const QC& qc = qcs[i];
            q_map[i - i0] = qc.qno;
            int ij = qc.qno * nprobe + qc.rank;
            // 2D tables are indexed by query, 3D by (query, probe)
            lut_entries[i - i0] = single_LUT ? qc.qno : ij;
            if (biases.get()) {
                tmp_bias[i - i0] = biases[ij];
            }
        }
        pq4_pack_LUT_qbs_q_map(
                qbs_for_list,
                M2,
                dis_tables.get(),
                lut_entries.data(),
                LUT.get());

        // access the inverted list

        ndis += (i1 - i0) * list_size;

        InvertedLists::ScopedCodes codes(invlists, list_no);
        InvertedLists::ScopedIds ids(invlists, list_no);

        // prepare the handler

        handler.ntotal = list_size;
        handler.q_map = q_map.data();
        handler.id_map = ids.get();

        pq4_accumulate_loop_qbs(
                qbs_for_list,
                list_size,
                M2,
                codes.get(),
                LUT.get(),
                handler,
                scaler);
        // prepare for next loop
        i0 = i1;
    }

    handler.end();

    // these stats are not thread-safe

    IVFFastScan_stats.t_copy_pack += t_copy_pack;
    IVFFastScan_stats.t_scan += t_scan;

    *ndis_out = ndis;
    *nlist_out = nlist;
}
/// Search implementation 14: multi-query SIMD accumulation, parallelized
/// over clusters (inverted lists). Like implem 12, visits are sorted by
/// list and bundled; bundles are then distributed over OpenMP threads.
/// Each thread accumulates into private heaps via its own result handler,
/// and the per-thread results are merged into the global output heaps in
/// a critical section.
///
/// @param impl   result-handler variant, forwarded to make_knn_handler
/// @param params optional; only params->sel (the id selector) is used
void IndexIVFFastScan::search_implem_14(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const CoarseQuantized& cq,
        int impl,
        const NormTableScaler* scaler,
        const IVFSearchParameters* params) const {
    if (n == 0) { // does not work well with reservoir
        return;
    }
    // the qbs kernels assume the default block size
    FAISS_THROW_IF_NOT(bbs == 32);

    const IDSelector* sel = params ? params->sel : nullptr;

    size_t dim12 = ksub * M2;
    AlignedTable<uint8_t> dis_tables;
    AlignedTable<uint16_t> biases;
    std::unique_ptr<float[]> normalizers(new float[2 * n]);

    compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());

    // one pending visit of a query to an inverted list
    struct QC {
        int qno;     // sequence number of the query
        int list_no; // list to visit
        int rank;    // this is the rank'th result of the coarse quantizer
    };
    bool single_LUT = !lookup_table_is_3d();
    size_t nprobe = cq.nprobe;

    std::vector<QC> qcs;
    {
        int ij = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < nprobe; j++) {
                if (cq.ids[ij] >= 0) {
                    qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
                }
                ij++;
            }
        }
        // group visits to the same list together
        std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
            return a.list_no < b.list_no;
        });
    }

    // one unit of parallel work: a contiguous run of QCs on the same list
    struct SE {
        size_t start; // start in the QC vector
        size_t end;   // end in the QC vector
        size_t list_size;
    };
    std::vector<SE> ses;
    size_t i0_l = 0;
    while (i0_l < qcs.size()) {
        // find all queries that access this inverted list
        int list_no = qcs[i0_l].list_no;
        size_t i1 = i0_l + 1;

        // bundle at most qbs2 queries per work unit
        while (i1 < qcs.size() && i1 < i0_l + qbs2) {
            if (qcs[i1].list_no != list_no) {
                break;
            }
            i1++;
        }

        size_t list_size = invlists->list_size(list_no);

        if (list_size == 0) {
            i0_l = i1;
            continue;
        }
        ses.push_back(SE{i0_l, i1, list_size});
        i0_l = i1;
    }

    // function to handle the global heap
    bool is_max = !is_similarity_metric(metric_type);
    using HeapForIP = CMin<float, idx_t>;
    using HeapForL2 = CMax<float, idx_t>;
    auto init_result = [&](float* simi, idx_t* idxi) {
        if (!is_max) {
            heap_heapify<HeapForIP>(k, simi, idxi);
        } else {
            heap_heapify<HeapForL2>(k, simi, idxi);
        }
    };

    // merge one query's thread-local top-k into the global heap
    auto add_local_results = [&](const float* local_dis,
                                 const idx_t* local_idx,
                                 float* simi,
                                 idx_t* idxi) {
        if (!is_max) {
            heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
        } else {
            heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
        }
    };

    auto reorder_result = [&](float* simi, idx_t* idxi) {
        if (!is_max) {
            heap_reorder<HeapForIP>(k, simi, idxi);
        } else {
            heap_reorder<HeapForL2>(k, simi, idxi);
        }
    };

    size_t ndis = 0;
    size_t nlist_visited = 0;

#pragma omp parallel reduction(+ : ndis, nlist_visited)
    {
        // storage for each thread: private top-k heaps for all n queries
        std::vector<idx_t> local_idx(k * n);
        std::vector<float> local_dis(k * n);

        // prepare the result handlers
        std::unique_ptr<SIMDResultHandlerToFloat> handler(make_knn_handler(
                is_max, impl, n, k, local_dis.data(), local_idx.data(), sel));
        handler->begin(normalizers.get());

        int actual_qbs2 = this->qbs2 ? this->qbs2 : 11;

        std::vector<uint16_t> tmp_bias;
        if (biases.get()) {
            tmp_bias.resize(actual_qbs2);
            handler->dbias = tmp_bias.data();
        }

        // queries this thread actually touched (used in the merge step)
        std::set<int> q_set;
        // NOTE(review): timers never updated, stats increments add 0
        uint64_t t_copy_pack = 0, t_scan = 0;
#pragma omp for schedule(dynamic)
        for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
            size_t i0 = ses[cluster].start;
            size_t i1 = ses[cluster].end;
            size_t list_size = ses[cluster].list_size;
            nlist_visited++;
            int list_no = qcs[i0].list_no;

            // re-organize LUTs and biases into the right order
            int nc = i1 - i0;

            std::vector<int> q_map(nc), lut_entries(nc);
            AlignedTable<uint8_t> LUT(nc * dim12);
            memset(LUT.get(), -1, nc * dim12);
            int qbs_for_list = pq4_preferred_qbs(nc);

            for (size_t i = i0; i < i1; i++) {
                const QC& qc = qcs[i];
                q_map[i - i0] = qc.qno;
                q_set.insert(qc.qno);
                int ij = qc.qno * nprobe + qc.rank;
                // 2D tables are indexed by query, 3D by (query, probe)
                lut_entries[i - i0] = single_LUT ? qc.qno : ij;
                if (biases.get()) {
                    tmp_bias[i - i0] = biases[ij];
                }
            }
            pq4_pack_LUT_qbs_q_map(
                    qbs_for_list,
                    M2,
                    dis_tables.get(),
                    lut_entries.data(),
                    LUT.get());

            // access the inverted list

            ndis += (i1 - i0) * list_size;

            InvertedLists::ScopedCodes codes(invlists, list_no);
            InvertedLists::ScopedIds ids(invlists, list_no);

            // prepare the handler

            handler->ntotal = list_size;
            handler->q_map = q_map.data();
            handler->id_map = ids.get();

            pq4_accumulate_loop_qbs(
                    qbs_for_list,
                    list_size,
                    M2,
                    codes.get(),
                    LUT.get(),
                    *handler.get(),
                    scaler);
        }

        // labels is in-place for HeapHC
        handler->end();

        // merge per-thread results
#pragma omp single
        {
            // we init the results as a heap
            for (idx_t i = 0; i < n; i++) {
                init_result(distances + i * k, labels + i * k);
            }
        }
#pragma omp barrier
#pragma omp critical
        {
            // write to global heap  #go over only the queries
            for (std::set<int>::iterator it = q_set.begin(); it != q_set.end();
                 ++it) {
                add_local_results(
                        local_dis.data() + *it * k,
                        local_idx.data() + *it * k,
                        distances + *it * k,
                        labels + *it * k);
            }

            IVFFastScan_stats.t_copy_pack += t_copy_pack;
            IVFFastScan_stats.t_scan += t_scan;
        }
#pragma omp barrier
#pragma omp single
        {
            // all merges are done (barrier above); sort each heap
            for (idx_t i = 0; i < n; i++) {
                reorder_result(distances + i * k, labels + i * k);
            }
        }
    }

    indexIVF_stats.nq += n;
    indexIVF_stats.ndis += ndis;
    indexIVF_stats.nlist += nlist_visited;
}
void IndexIVFFastScan::reconstruct_from_offset(
1356
        int64_t list_no,
1357
        int64_t offset,
1358
0
        float* recons) const {
1359
    // unpack codes
1360
0
    size_t coarse_size = coarse_code_size();
1361
0
    std::vector<uint8_t> code(coarse_size + code_size, 0);
1362
0
    encode_listno(list_no, code.data());
1363
0
    InvertedLists::ScopedCodes list_codes(invlists, list_no);
1364
0
    BitstringWriter bsw(code.data() + coarse_size, code_size);
1365
1366
0
    for (size_t m = 0; m < M; m++) {
1367
0
        uint8_t c =
1368
0
                pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
1369
0
        bsw.write(c, nbits);
1370
0
    }
1371
1372
0
    sa_decode(1, code.data(), recons);
1373
0
}
1374
1375
0
/// Rebuild `orig_invlists` (flat, non-packed codes) from the bbs-packed
/// `invlists`. Requires orig_invlists to exist and to be empty (only
/// list 0 is checked for emptiness).
void IndexIVFFastScan::reconstruct_orig_invlists() {
    FAISS_THROW_IF_NOT(orig_invlists != nullptr);
    // guard against appending to already-populated lists
    FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);

    // each iteration touches a distinct list_no, so no locking is needed
#pragma omp parallel for if (nlist > 100)
    for (idx_t list_no = 0; list_no < nlist; list_no++) {
        InvertedLists::ScopedCodes codes(invlists, list_no);
        InvertedLists::ScopedIds ids(invlists, list_no);
        size_t list_size = invlists->list_size(list_no);
        // scratch buffer for one flat code, reused across entries
        std::vector<uint8_t> code(code_size, 0);

        for (size_t offset = 0; offset < list_size; offset++) {
            // unpack codes
            BitstringWriter bsw(code.data(), code_size);
            for (size_t m = 0; m < M; m++) {
                uint8_t c =
                        pq4_get_packed_element(codes.get(), bbs, M2, offset, m);
                bsw.write(c, nbits);
            }

            // get id
            idx_t id = ids.get()[offset];

            orig_invlists->add_entry(list_no, id, code.data());
        }
    }
}
/// Decode n serialized codes (each: coarse listno bytes followed by the
/// fine code) into vectors x (size n * d). If the index encodes
/// residuals, the coarse centroid is reconstructed and added back.
///
/// @param codes  input codes, n * (coarse_code_size() + code_size) bytes
/// @param x      output vectors, size n * d
void IndexIVFFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
        const {
    size_t coarse_size = coarse_code_size();

#pragma omp parallel if (n > 1)
    {
        // per-thread scratch for the coarse centroid
        std::vector<float> residual(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* code = codes + i * (code_size + coarse_size);
            int64_t list_no = decode_listno(code);
            float* xi = x + i * d;
            // decode the fine part first...
            fine_quantizer->decode(code + coarse_size, xi, 1);
            if (by_residual) {
                // ...then add back the coarse centroid
                quantizer->reconstruct(list_no, residual.data());
                for (size_t j = 0; j < d; j++) {
                    xi[j] += residual[j];
                }
            }
        }
    }
}
// Global statistics object updated (not thread-safely) by the
// search_implem_* functions above.
IVFFastScanStats IVFFastScan_stats;
} // namespace faiss