Coverage Report

Created: 2025-10-24 07:40

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexFlat.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/IndexFlat.h>
11
12
#include <faiss/impl/AuxIndexStructures.h>
13
#include <faiss/impl/FaissAssert.h>
14
#include <faiss/utils/Heap.h>
15
#include <faiss/utils/distances.h>
16
#include <faiss/utils/extra_distances.h>
17
#include <faiss/utils/prefetch.h>
18
#include <faiss/utils/sorting.h>
19
#include <cstring>
20
21
namespace faiss {
22
23
/// Build an empty flat (brute-force) index of dimension `d`.
/// Each stored code is the raw float vector itself, so the code size handed
/// to IndexFlatCodes is `d * sizeof(float)` bytes.
IndexFlat::IndexFlat(idx_t d, MetricType metric)
        : IndexFlatCodes(d * sizeof(float), d, metric) {}
25
26
void IndexFlat::search(
27
        idx_t n,
28
        const float* x,
29
        idx_t k,
30
        float* distances,
31
        idx_t* labels,
32
0
        const SearchParameters* params) const {
33
0
    IDSelector* sel = params ? params->sel : nullptr;
34
0
    FAISS_THROW_IF_NOT(k > 0);
35
36
    // we see the distances and labels as heaps
37
0
    if (metric_type == METRIC_INNER_PRODUCT) {
38
0
        float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
39
0
        knn_inner_product(x, get_xb(), d, n, ntotal, &res, sel);
40
0
    } else if (metric_type == METRIC_L2) {
41
0
        float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
42
0
        knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
43
0
    } else {
44
0
        FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
45
0
        knn_extra_metrics(
46
0
                x,
47
0
                get_xb(),
48
0
                d,
49
0
                n,
50
0
                ntotal,
51
0
                metric_type,
52
0
                metric_arg,
53
0
                k,
54
0
                distances,
55
0
                labels);
56
0
    }
57
0
}
58
59
void IndexFlat::range_search(
60
        idx_t n,
61
        const float* x,
62
        float radius,
63
        RangeSearchResult* result,
64
3
        const SearchParameters* params) const {
65
3
    IDSelector* sel = params ? params->sel : nullptr;
66
67
3
    switch (metric_type) {
68
3
        case METRIC_INNER_PRODUCT:
69
3
            range_search_inner_product(
70
3
                    x, get_xb(), d, n, ntotal, radius, result, sel);
71
3
            break;
72
0
        case METRIC_L2:
73
0
            range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result, sel);
74
0
            break;
75
0
        default:
76
0
            FAISS_THROW_MSG("metric type not supported");
77
3
    }
78
3
}
79
80
void IndexFlat::compute_distance_subset(
81
        idx_t n,
82
        const float* x,
83
        idx_t k,
84
        float* distances,
85
0
        const idx_t* labels) const {
86
0
    switch (metric_type) {
87
0
        case METRIC_INNER_PRODUCT:
88
0
            fvec_inner_products_by_idx(distances, x, get_xb(), labels, d, n, k);
89
0
            break;
90
0
        case METRIC_L2:
91
0
            fvec_L2sqr_by_idx(distances, x, get_xb(), labels, d, n, k);
92
0
            break;
93
0
        default:
94
0
            FAISS_THROW_MSG("metric type not supported");
95
0
    }
96
0
}
97
98
namespace {
99
100
struct FlatL2Dis : FlatCodesDistanceComputer {
101
    size_t d;
102
    idx_t nb;
103
    const float* q;
104
    const float* b;
105
    size_t ndis;
106
107
985k
    float distance_to_code(const uint8_t* code) final {
108
985k
        ndis++;
109
985k
        return fvec_L2sqr(q, (float*)code, d);
110
985k
    }
111
112
6.17M
    float symmetric_dis(idx_t i, idx_t j) override {
113
6.17M
        return fvec_L2sqr(b + j * d, b + i * d, d);
114
6.17M
    }
115
116
    explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
117
19.0k
            : FlatCodesDistanceComputer(
118
19.0k
                      storage.codes.data(),
119
19.0k
                      storage.code_size),
120
19.0k
              d(storage.d),
121
19.0k
              nb(storage.ntotal),
122
19.0k
              q(q),
123
19.0k
              b(storage.get_xb()),
124
19.0k
              ndis(0) {}
125
126
20.8k
    void set_query(const float* x) override {
127
20.8k
        q = x;
128
20.8k
    }
129
130
    // compute four distances
131
    void distances_batch_4(
132
            const idx_t idx0,
133
            const idx_t idx1,
134
            const idx_t idx2,
135
            const idx_t idx3,
136
            float& dis0,
137
            float& dis1,
138
            float& dis2,
139
1.30M
            float& dis3) final override {
140
1.30M
        ndis += 4;
141
142
        // compute first, assign next
143
1.30M
        const float* __restrict y0 =
144
1.30M
                reinterpret_cast<const float*>(codes + idx0 * code_size);
145
1.30M
        const float* __restrict y1 =
146
1.30M
                reinterpret_cast<const float*>(codes + idx1 * code_size);
147
1.30M
        const float* __restrict y2 =
148
1.30M
                reinterpret_cast<const float*>(codes + idx2 * code_size);
149
1.30M
        const float* __restrict y3 =
150
1.30M
                reinterpret_cast<const float*>(codes + idx3 * code_size);
151
152
1.30M
        float dp0 = 0;
153
1.30M
        float dp1 = 0;
154
1.30M
        float dp2 = 0;
155
1.30M
        float dp3 = 0;
156
1.30M
        fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
157
1.30M
        dis0 = dp0;
158
1.30M
        dis1 = dp1;
159
1.30M
        dis2 = dp2;
160
1.30M
        dis3 = dp3;
161
1.30M
    }
162
};
163
164
struct FlatIPDis : FlatCodesDistanceComputer {
165
    size_t d;
166
    idx_t nb;
167
    const float* q;
168
    const float* b;
169
    size_t ndis;
170
171
1.09M
    float symmetric_dis(idx_t i, idx_t j) final override {
172
1.09M
        return fvec_inner_product(b + j * d, b + i * d, d);
173
1.09M
    }
174
175
196k
    float distance_to_code(const uint8_t* code) final override {
176
196k
        ndis++;
177
196k
        return fvec_inner_product(q, (const float*)code, d);
178
196k
    }
179
180
    explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
181
1.17k
            : FlatCodesDistanceComputer(
182
1.17k
                      storage.codes.data(),
183
1.17k
                      storage.code_size),
184
1.17k
              d(storage.d),
185
1.17k
              nb(storage.ntotal),
186
1.17k
              q(q),
187
1.17k
              b(storage.get_xb()),
188
1.17k
              ndis(0) {}
189
190
6.48k
    void set_query(const float* x) override {
191
6.48k
        q = x;
192
6.48k
    }
193
194
    // compute four distances
195
    void distances_batch_4(
196
            const idx_t idx0,
197
            const idx_t idx1,
198
            const idx_t idx2,
199
            const idx_t idx3,
200
            float& dis0,
201
            float& dis1,
202
            float& dis2,
203
271k
            float& dis3) final override {
204
271k
        ndis += 4;
205
206
        // compute first, assign next
207
271k
        const float* __restrict y0 =
208
271k
                reinterpret_cast<const float*>(codes + idx0 * code_size);
209
271k
        const float* __restrict y1 =
210
271k
                reinterpret_cast<const float*>(codes + idx1 * code_size);
211
271k
        const float* __restrict y2 =
212
271k
                reinterpret_cast<const float*>(codes + idx2 * code_size);
213
271k
        const float* __restrict y3 =
214
271k
                reinterpret_cast<const float*>(codes + idx3 * code_size);
215
216
271k
        float dp0 = 0;
217
271k
        float dp1 = 0;
218
271k
        float dp2 = 0;
219
271k
        float dp3 = 0;
220
271k
        fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
221
271k
        dis0 = dp0;
222
271k
        dis1 = dp1;
223
271k
        dis2 = dp2;
224
271k
        dis3 = dp3;
225
271k
    }
226
};
227
228
} // namespace
229
230
20.2k
/// Create a distance computer matching this index's metric.
///
/// L2 and inner product get the FlatL2Dis / FlatIPDis computers allocated
/// with `new` (the caller presumably owns and deletes the result — confirm
/// against callers); every other metric is delegated to
/// get_extra_distance_computer.
FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
    switch (metric_type) {
        case METRIC_L2:
            return new FlatL2Dis(*this);
        case METRIC_INNER_PRODUCT:
            return new FlatIPDis(*this);
        default:
            return get_extra_distance_computer(
                    d, metric_type, metric_arg, ntotal, get_xb());
    }
}
240
241
0
/// Copy stored vector `key` (code_size bytes, i.e. d floats) into `recons`.
///
/// Throws (via FAISS_THROW_IF_NOT_MSG) when `key` is outside [0, ntotal).
/// `recons` must have room for d floats.
void IndexFlat::reconstruct(idx_t key, float* recons) const {
    // Without this check an out-of-range key reads past the end of `codes`.
    FAISS_THROW_IF_NOT_MSG(
            key >= 0 && key < ntotal, "reconstruct: key out of range");
    memcpy(recons, codes.data() + key * code_size, code_size);
}
244
245
18.8k
/// Encode `n` vectors: for a flat index the code is the raw float data, so
/// this copies `n * d` floats from `x` into `bytes`. Does nothing when n <= 0.
void IndexFlat::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
    if (n <= 0) {
        return;
    }
    const size_t nbytes = sizeof(float) * d * n;
    memcpy(bytes, x, nbytes);
}
250
251
0
/// Decode `n` codes: the inverse of sa_encode, copying `n * d` floats from
/// `bytes` back into `x`. Does nothing when n <= 0.
void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
    if (n <= 0) {
        return;
    }
    const size_t nbytes = sizeof(float) * d * n;
    memcpy(x, bytes, nbytes);
}
256
257
/***************************************************
258
 * IndexFlatL2
259
 ***************************************************/
260
261
namespace {
262
struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
263
    size_t d;
264
    idx_t nb;
265
    const float* q;
266
    const float* b;
267
    size_t ndis;
268
269
    const float* l2norms;
270
    float query_l2norm;
271
272
0
    float distance_to_code(const uint8_t* code) final override {
273
0
        ndis++;
274
0
        return fvec_L2sqr(q, (float*)code, d);
275
0
    }
276
277
0
    float operator()(const idx_t i) final override {
278
0
        const float* __restrict y =
279
0
                reinterpret_cast<const float*>(codes + i * code_size);
280
281
0
        prefetch_L2(l2norms + i);
282
0
        const float dp0 = fvec_inner_product(q, y, d);
283
0
        return query_l2norm + l2norms[i] - 2 * dp0;
284
0
    }
285
286
0
    float symmetric_dis(idx_t i, idx_t j) final override {
287
0
        const float* __restrict yi =
288
0
                reinterpret_cast<const float*>(codes + i * code_size);
289
0
        const float* __restrict yj =
290
0
                reinterpret_cast<const float*>(codes + j * code_size);
291
292
0
        prefetch_L2(l2norms + i);
293
0
        prefetch_L2(l2norms + j);
294
0
        const float dp0 = fvec_inner_product(yi, yj, d);
295
0
        return l2norms[i] + l2norms[j] - 2 * dp0;
296
0
    }
297
298
    explicit FlatL2WithNormsDis(
299
            const IndexFlatL2& storage,
300
            const float* q = nullptr)
301
0
            : FlatCodesDistanceComputer(
302
0
                      storage.codes.data(),
303
0
                      storage.code_size),
304
0
              d(storage.d),
305
0
              nb(storage.ntotal),
306
0
              q(q),
307
0
              b(storage.get_xb()),
308
0
              ndis(0),
309
0
              l2norms(storage.cached_l2norms.data()),
310
0
              query_l2norm(0) {}
311
312
0
    void set_query(const float* x) override {
313
0
        q = x;
314
0
        query_l2norm = fvec_norm_L2sqr(q, d);
315
0
    }
316
317
    // compute four distances
318
    void distances_batch_4(
319
            const idx_t idx0,
320
            const idx_t idx1,
321
            const idx_t idx2,
322
            const idx_t idx3,
323
            float& dis0,
324
            float& dis1,
325
            float& dis2,
326
0
            float& dis3) final override {
327
0
        ndis += 4;
328
329
        // compute first, assign next
330
0
        const float* __restrict y0 =
331
0
                reinterpret_cast<const float*>(codes + idx0 * code_size);
332
0
        const float* __restrict y1 =
333
0
                reinterpret_cast<const float*>(codes + idx1 * code_size);
334
0
        const float* __restrict y2 =
335
0
                reinterpret_cast<const float*>(codes + idx2 * code_size);
336
0
        const float* __restrict y3 =
337
0
                reinterpret_cast<const float*>(codes + idx3 * code_size);
338
339
0
        prefetch_L2(l2norms + idx0);
340
0
        prefetch_L2(l2norms + idx1);
341
0
        prefetch_L2(l2norms + idx2);
342
0
        prefetch_L2(l2norms + idx3);
343
344
0
        float dp0 = 0;
345
0
        float dp1 = 0;
346
0
        float dp2 = 0;
347
0
        float dp3 = 0;
348
0
        fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
349
0
        dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
350
0
        dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
351
0
        dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
352
0
        dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
353
0
    }
354
};
355
356
} // namespace
357
358
0
void IndexFlatL2::sync_l2norms() {
359
0
    cached_l2norms.resize(ntotal);
360
0
    fvec_norms_L2sqr(
361
0
            cached_l2norms.data(),
362
0
            reinterpret_cast<const float*>(codes.data()),
363
0
            d,
364
0
            ntotal);
365
0
}
366
367
0
void IndexFlatL2::clear_l2norms() {
368
0
    cached_l2norms.clear();
369
0
    cached_l2norms.shrink_to_fit();
370
0
}
371
372
19.0k
/// Create a distance computer for this index.
///
/// With METRIC_L2 and a norm cache that covers exactly the current `ntotal`
/// vectors, return the norm-based FlatL2WithNormsDis. Otherwise defer to
/// IndexFlat::get_FlatCodesDistanceComputer().
FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
    // FlatL2WithNormsDis indexes cached_l2norms by database id without
    // bounds checks. If vectors were added after sync_l2norms(), a non-empty
    // but shorter cache would be read out of bounds, so only use the cache
    // when its size matches ntotal and fall back to direct L2 otherwise.
    const bool norms_usable = !cached_l2norms.empty() &&
            cached_l2norms.size() == static_cast<size_t>(ntotal);
    if (metric_type == METRIC_L2 && norms_usable) {
        return new FlatL2WithNormsDis(*this);
    }

    return IndexFlat::get_FlatCodesDistanceComputer();
}
381
382
/***************************************************
383
 * IndexFlat1D
384
 ***************************************************/
385
386
/// One-dimensional flat index: an IndexFlatL2 with d = 1 plus a sorted
/// permutation `perm` used by search().
/// @param continuous_update  when true, add() refreshes `perm` after every
///                           insertion; otherwise update_permutation() must
///                           be called before searching.
IndexFlat1D::IndexFlat1D(bool continuous_update)
        : IndexFlatL2(1),
          continuous_update(continuous_update) {}
388
389
/// if not continuous_update, call this between the last add and
390
/// the first search
391
0
void IndexFlat1D::update_permutation() {
392
0
    perm.resize(ntotal);
393
0
    if (ntotal < 1000000) {
394
0
        fvec_argsort(ntotal, get_xb(), (size_t*)perm.data());
395
0
    } else {
396
0
        fvec_argsort_parallel(ntotal, get_xb(), (size_t*)perm.data());
397
0
    }
398
0
}
399
400
0
/// Append `n` one-dimensional values. The sorted permutation is refreshed
/// immediately only when continuous_update is set.
void IndexFlat1D::add(idx_t n, const float* x) {
    IndexFlatL2::add(n, x);
    if (!continuous_update) {
        // Caller is expected to call update_permutation() before search().
        return;
    }
    update_permutation();
}
405
406
0
void IndexFlat1D::reset() {
407
0
    IndexFlatL2::reset();
408
0
    perm.clear();
409
0
}
410
411
void IndexFlat1D::search(
412
        idx_t n,
413
        const float* x,
414
        idx_t k,
415
        float* distances,
416
        idx_t* labels,
417
0
        const SearchParameters* params) const {
418
0
    FAISS_THROW_IF_NOT_MSG(
419
0
            !params, "search params not supported for this index");
420
0
    FAISS_THROW_IF_NOT(k > 0);
421
0
    FAISS_THROW_IF_NOT_MSG(
422
0
            perm.size() == ntotal, "Call update_permutation before search");
423
0
    const float* xb = get_xb();
424
425
0
#pragma omp parallel for if (n > 10000)
426
0
    for (idx_t i = 0; i < n; i++) {
427
0
        float q = x[i]; // query
428
0
        float* D = distances + i * k;
429
0
        idx_t* I = labels + i * k;
430
431
        // binary search
432
0
        idx_t i0 = 0, i1 = ntotal;
433
0
        idx_t wp = 0;
434
435
0
        if (ntotal == 0) {
436
0
            for (idx_t j = 0; j < k; j++) {
437
0
                I[j] = -1;
438
0
                D[j] = HUGE_VAL;
439
0
            }
440
0
            goto done;
441
0
        }
442
443
0
        if (xb[perm[i0]] > q) {
444
0
            i1 = 0;
445
0
            goto finish_right;
446
0
        }
447
448
0
        if (xb[perm[i1 - 1]] <= q) {
449
0
            i0 = i1 - 1;
450
0
            goto finish_left;
451
0
        }
452
453
0
        while (i0 + 1 < i1) {
454
0
            idx_t imed = (i0 + i1) / 2;
455
0
            if (xb[perm[imed]] <= q)
456
0
                i0 = imed;
457
0
            else
458
0
                i1 = imed;
459
0
        }
460
461
        // query is between xb[perm[i0]] and xb[perm[i1]]
462
        // expand to nearest neighs
463
464
0
        while (wp < k) {
465
0
            float xleft = xb[perm[i0]];
466
0
            float xright = xb[perm[i1]];
467
468
0
            if (q - xleft < xright - q) {
469
0
                D[wp] = q - xleft;
470
0
                I[wp] = perm[i0];
471
0
                i0--;
472
0
                wp++;
473
0
                if (i0 < 0) {
474
0
                    goto finish_right;
475
0
                }
476
0
            } else {
477
0
                D[wp] = xright - q;
478
0
                I[wp] = perm[i1];
479
0
                i1++;
480
0
                wp++;
481
0
                if (i1 >= ntotal) {
482
0
                    goto finish_left;
483
0
                }
484
0
            }
485
0
        }
486
0
        goto done;
487
488
0
    finish_right:
489
        // grow to the right from i1
490
0
        while (wp < k) {
491
0
            if (i1 < ntotal) {
492
0
                D[wp] = xb[perm[i1]] - q;
493
0
                I[wp] = perm[i1];
494
0
                i1++;
495
0
            } else {
496
0
                D[wp] = std::numeric_limits<float>::infinity();
497
0
                I[wp] = -1;
498
0
            }
499
0
            wp++;
500
0
        }
501
0
        goto done;
502
503
0
    finish_left:
504
        // grow to the left from i0
505
0
        while (wp < k) {
506
0
            if (i0 >= 0) {
507
0
                D[wp] = q - xb[perm[i0]];
508
0
                I[wp] = perm[i0];
509
0
                i0--;
510
0
            } else {
511
0
                D[wp] = std::numeric_limits<float>::infinity();
512
0
                I[wp] = -1;
513
0
            }
514
0
            wp++;
515
0
        }
516
0
    done:;
517
0
    }
518
0
}
519
520
} // namespace faiss