Coverage Report

Created: 2025-11-14 17:37

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/NSG.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/impl/NSG.h>
9
10
#include <algorithm>
11
#include <memory>
12
#include <mutex>
13
#include <stack>
14
15
#include <faiss/impl/DistanceComputer.h>
16
17
namespace faiss {
18
19
namespace {

/// Scoped lock over one of the per-node mutexes used while reverse links are
/// being added.
using LockGuard = std::lock_guard<std::mutex>;

/// Sentinel stored in unused neighbor slots. It has to stay negative so it can
/// never be mistaken for a real node id.
constexpr int EMPTY_ID{-1};

} // anonymous namespace
27
28
namespace nsg {
29
30
0
/// Create the distance computer used by the NSG routines for `storage`.
///
/// When the storage metric is a similarity, the storage's own computer is
/// wrapped in a NegativeDistanceComputer (presumably so that better matches
/// compare as smaller values -- confirm against that class). Otherwise the
/// storage's computer is returned unchanged.
///
/// @return a raw pointer; every caller in this file immediately hands it to a
///         std::unique_ptr.
DistanceComputer* storage_distance_computer(const Index* storage) {
    if (!is_similarity_metric(storage->metric_type)) {
        return storage->get_distance_computer();
    }
    return new NegativeDistanceComputer(storage->get_distance_computer());
}
37
38
/// One entry of the candidate pool maintained during a graph search.
struct Neighbor {
    int32_t id;     ///< node id
    float distance; ///< value returned by the distance computer for `id`
    bool flag;      ///< true until the node's neighbors have been expanded

    Neighbor() = default;

    Neighbor(int id_, float distance_, bool flag_)
            : id(id_), distance(distance_), flag(flag_) {}

    /// Candidates are ordered by increasing distance only.
    bool operator<(const Neighbor& rhs) const {
        return rhs.distance > distance;
    }
};
51
52
/// (id, distance) pair stored in the intermediate build graph and in the
/// candidate pools used for pruning.
struct Node {
    int32_t id;
    float distance;

    Node() = default;

    Node(int id_, float distance_) : id(id_), distance(distance_) {}

    /// Nodes are ordered by increasing distance.
    bool operator<(const Node& rhs) const {
        return rhs.distance > distance;
    }

    // Comparison against a bare id; it only exists to keep the compiler happy.
    bool operator<(int other_id) const {
        return other_id > id;
    }
};
68
69
0
/// Insert `nn` into the pool `addr[0..K)`, which is sorted by increasing
/// distance.
///
/// `addr` must provide room for K + 1 entries: existing entries are shifted
/// one slot to the right, or `nn` is stored in the spare slot at index K.
///
/// @param addr pointer to the first pool entry
/// @param K    number of sorted entries currently in the pool
/// @param nn   candidate to insert
/// @return the index at which `nn` was stored, or K + 1 when an entry with the
///         same id was found around the insertion point (nothing is inserted)
inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
    Neighbor* const pool_end = addr + K;

    // Better than the current head: shift everything and store in front.
    if (nn.distance < addr[0].distance) {
        std::copy_backward(addr, pool_end, pool_end + 1);
        addr[0] = nn;
        return 0;
    }

    // Worse than the current tail: store in the spare slot.
    if (nn.distance > addr[K - 1].distance) {
        *pool_end = nn;
        return K;
    }

    // Binary search: narrow [lo, hi] until the two bounds are adjacent.
    int lo = 0;
    int hi = K - 1;
    while (hi - lo > 1) {
        const int mid = (lo + hi) / 2;
        if (addr[mid].distance > nn.distance) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // Walk left over entries that are not closer than nn, looking for an
    // entry that already carries the same id.
    for (; lo > 0; --lo) {
        if (addr[lo].distance < nn.distance) {
            break;
        }
        if (addr[lo].id == nn.id) {
            return K + 1;
        }
    }
    if (addr[lo].id == nn.id || addr[hi].id == nn.id) {
        return K + 1;
    }

    // Open a gap at `hi` and store the new candidate there.
    std::copy_backward(addr + hi, pool_end, pool_end + 1);
    addr[hi] = nn;
    return hi;
}
107
108
} // namespace nsg
109
110
using namespace nsg;
111
112
0
/// Create an empty NSG with degree bound `R`.
///
/// L (pool size of the build-time searches) and C (cap on the number of
/// candidates examined while pruning) are derived from R. Both the member
/// generator and the C library generator receive fixed seeds.
NSG::NSG(int R) : R(R), rng(0x0903) {
    this->L = R + 32;
    this->C = R + 100;

    srand(0x1998);
}
117
118
void NSG::search(
119
        DistanceComputer& dis,
120
        int k,
121
        idx_t* I,
122
        float* D,
123
0
        VisitedTable& vt) const {
124
0
    FAISS_THROW_IF_NOT(is_built);
125
0
    FAISS_THROW_IF_NOT(final_graph);
126
127
0
    int pool_size = std::max(search_L, k);
128
0
    std::vector<Neighbor> retset;
129
0
    std::vector<Node> tmp;
130
0
    search_on_graph<false>(
131
0
            *final_graph, dis, vt, enterpoint, pool_size, retset, tmp);
132
133
0
    for (size_t i = 0; i < k; i++) {
134
0
        I[i] = retset[i].id;
135
0
        D[i] = retset[i].distance;
136
0
    }
137
0
}
138
139
void NSG::build(
140
        Index* storage,
141
        idx_t n,
142
        const nsg::Graph<idx_t>& knn_graph,
143
0
        bool verbose) {
144
0
    FAISS_THROW_IF_NOT(!is_built && ntotal == 0);
145
146
0
    if (verbose) {
147
0
        printf("NSG::build R=%d, L=%d, C=%d\n", R, L, C);
148
0
    }
149
150
0
    ntotal = n;
151
0
    init_graph(storage, knn_graph);
152
153
0
    std::vector<int> degrees(n, 0);
154
0
    {
155
0
        nsg::Graph<Node> tmp_graph(n, R);
156
157
0
        link(storage, knn_graph, tmp_graph, verbose);
158
159
0
        final_graph = std::make_shared<nsg::Graph<int>>(n, R);
160
0
        std::fill_n(final_graph->data, n * R, EMPTY_ID);
161
162
0
#pragma omp parallel for
163
0
        for (int i = 0; i < n; i++) {
164
0
            int cnt = 0;
165
0
            for (int j = 0; j < R; j++) {
166
0
                int id = tmp_graph.at(i, j).id;
167
0
                if (id != EMPTY_ID) {
168
0
                    final_graph->at(i, cnt) = id;
169
0
                    cnt += 1;
170
0
                }
171
0
                degrees[i] = cnt;
172
0
            }
173
0
        }
174
0
    }
175
176
0
    int num_attached = tree_grow(storage, degrees);
177
0
    check_graph();
178
0
    is_built = true;
179
180
0
    if (verbose) {
181
0
        int max = 0, min = 1e6;
182
0
        double avg = 0;
183
184
0
        for (int i = 0; i < n; i++) {
185
0
            int size = 0;
186
0
            while (size < R && final_graph->at(i, size) != EMPTY_ID) {
187
0
                size += 1;
188
0
            }
189
0
            max = std::max(size, max);
190
0
            min = std::min(size, min);
191
0
            avg += size;
192
0
        }
193
194
0
        avg = avg / n;
195
0
        printf("Degree Statistics: Max = %d, Min = %d, Avg = %lf\n",
196
0
               max,
197
0
               min,
198
0
               avg);
199
0
        printf("Attached nodes: %d\n", num_attached);
200
0
    }
201
0
}
202
203
0
void NSG::reset() {
204
0
    final_graph.reset();
205
0
    ntotal = 0;
206
0
    is_built = false;
207
0
}
208
209
0
/// Choose the entry point of the graph.
///
/// The mean of all stored vectors is computed and used as a query on
/// `knn_graph`, starting from a randomly drawn node; the best candidate found
/// becomes `enterpoint`.
void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
    const int dim = storage->d;
    const int num_vectors = storage->ntotal;

    std::unique_ptr<float[]> center(new float[dim]);
    std::unique_ptr<float[]> vec(new float[dim]);
    std::fill_n(center.get(), dim, 0.0f);

    // Accumulate every reconstructed vector, then divide by the count.
    for (int i = 0; i < num_vectors; ++i) {
        storage->reconstruct(i, vec.get());
        for (int j = 0; j < dim; ++j) {
            center[j] += vec[j];
        }
    }
    for (int j = 0; j < dim; ++j) {
        center[j] /= num_vectors;
    }

    std::vector<Neighbor> candidates;
    std::vector<Node> unused_fullset;

    // The search starts from a random node.
    const int start_node = rng.rand_int(num_vectors);
    std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
    dis->set_query(center.get());

    VisitedTable vt(ntotal);

    // Visited nodes are not collected here.
    search_on_graph<false>(
            knn_graph, *dis, vt, start_node, L, candidates, unused_fullset);

    // The closest candidate to the center becomes the entry point.
    enterpoint = candidates[0].id;
}
244
245
template <bool collect_fullset, class index_t>
246
void NSG::search_on_graph(
247
        const nsg::Graph<index_t>& graph,
248
        DistanceComputer& dis,
249
        VisitedTable& vt,
250
        int ep,
251
        int pool_size,
252
        std::vector<Neighbor>& retset,
253
0
        std::vector<Node>& fullset) const {
254
0
    RandomGenerator gen(0x1234);
255
0
    retset.resize(pool_size + 1);
256
0
    std::vector<int> init_ids(pool_size);
257
258
0
    int num_ids = 0;
259
0
    std::vector<index_t> neighbors(graph.K);
260
0
    size_t nneigh = graph.get_neighbors(ep, neighbors.data());
261
0
    for (int i = 0; i < init_ids.size() && i < nneigh; i++) {
262
0
        int id = (int)neighbors[i];
263
0
        if (id >= ntotal) {
264
0
            continue;
265
0
        }
266
267
0
        init_ids[i] = id;
268
0
        vt.set(id);
269
0
        num_ids += 1;
270
0
    }
271
272
0
    while (num_ids < pool_size) {
273
0
        int id = gen.rand_int(ntotal);
274
0
        if (vt.get(id)) {
275
0
            continue;
276
0
        }
277
278
0
        init_ids[num_ids] = id;
279
0
        num_ids++;
280
0
        vt.set(id);
281
0
    }
282
283
0
    for (int i = 0; i < init_ids.size(); i++) {
284
0
        int id = init_ids[i];
285
286
0
        float dist = dis(id);
287
0
        retset[i] = Neighbor(id, dist, true);
288
289
0
        if (collect_fullset) {
290
0
            fullset.emplace_back(retset[i].id, retset[i].distance);
291
0
        }
292
0
    }
293
294
0
    std::sort(retset.begin(), retset.begin() + pool_size);
295
296
0
    int k = 0;
297
0
    while (k < pool_size) {
298
0
        int updated_pos = pool_size;
299
300
0
        if (retset[k].flag) {
301
0
            retset[k].flag = false;
302
0
            int n = retset[k].id;
303
304
0
            size_t nneigh_for_n = graph.get_neighbors(n, neighbors.data());
305
0
            for (int m = 0; m < nneigh_for_n; m++) {
306
0
                int id = neighbors[m];
307
0
                if (id > ntotal || vt.get(id)) {
308
0
                    continue;
309
0
                }
310
0
                vt.set(id);
311
312
0
                float dist = dis(id);
313
0
                Neighbor nn(id, dist, true);
314
0
                if (collect_fullset) {
315
0
                    fullset.emplace_back(id, dist);
316
0
                }
317
318
0
                if (dist >= retset[pool_size - 1].distance) {
319
0
                    continue;
320
0
                }
321
322
0
                int r = insert_into_pool(retset.data(), pool_size, nn);
323
324
0
                updated_pos = std::min(updated_pos, r);
325
0
            }
326
0
        }
327
328
0
        k = (updated_pos <= k) ? updated_pos : (k + 1);
329
0
    }
330
0
}
Unexecuted instantiation: _ZNK5faiss3NSG15search_on_graphILb0EiEEvRKNS_3nsg5GraphIT0_EERNS_16DistanceComputerERNS_12VisitedTableEiiRSt6vectorINS2_8NeighborESaISD_EERSC_INS2_4NodeESaISH_EE
Unexecuted instantiation: _ZNK5faiss3NSG15search_on_graphILb0ElEEvRKNS_3nsg5GraphIT0_EERNS_16DistanceComputerERNS_12VisitedTableEiiRSt6vectorINS2_8NeighborESaISD_EERSC_INS2_4NodeESaISH_EE
Unexecuted instantiation: _ZNK5faiss3NSG15search_on_graphILb1ElEEvRKNS_3nsg5GraphIT0_EERNS_16DistanceComputerERNS_12VisitedTableEiiRSt6vectorINS2_8NeighborESaISD_EERSC_INS2_4NodeESaISH_EE
Unexecuted instantiation: _ZNK5faiss3NSG15search_on_graphILb1EiEEvRKNS_3nsg5GraphIT0_EERNS_16DistanceComputerERNS_12VisitedTableEiiRSt6vectorINS2_8NeighborESaISD_EERSC_INS2_4NodeESaISH_EE
331
332
void NSG::link(
333
        Index* storage,
334
        const nsg::Graph<idx_t>& knn_graph,
335
        nsg::Graph<Node>& graph,
336
0
        bool /* verbose */) {
337
0
#pragma omp parallel
338
0
    {
339
0
        std::unique_ptr<float[]> vec(new float[storage->d]);
340
341
0
        std::vector<Node> pool;
342
0
        std::vector<Neighbor> tmp;
343
344
0
        VisitedTable vt(ntotal);
345
0
        std::unique_ptr<DistanceComputer> dis(
346
0
                storage_distance_computer(storage));
347
348
0
#pragma omp for schedule(dynamic, 100)
349
0
        for (int i = 0; i < ntotal; i++) {
350
0
            storage->reconstruct(i, vec.get());
351
0
            dis->set_query(vec.get());
352
353
            // Collect the visited nodes into pool
354
0
            search_on_graph<true>(
355
0
                    knn_graph, *dis, vt, enterpoint, L, tmp, pool);
356
357
0
            sync_prune(i, pool, *dis, vt, knn_graph, graph);
358
359
0
            pool.clear();
360
0
            tmp.clear();
361
0
            vt.advance();
362
0
        }
363
0
    } // omp parallel
364
365
0
    std::vector<std::mutex> locks(ntotal);
366
0
#pragma omp parallel
367
0
    {
368
0
        std::unique_ptr<DistanceComputer> dis(
369
0
                storage_distance_computer(storage));
370
371
0
#pragma omp for schedule(dynamic, 100)
372
0
        for (int i = 0; i < ntotal; ++i) {
373
0
            add_reverse_links(i, locks, *dis, graph);
374
0
        }
375
0
    } // omp parallel
376
0
}
377
378
void NSG::sync_prune(
379
        int q,
380
        std::vector<Node>& pool,
381
        DistanceComputer& dis,
382
        VisitedTable& vt,
383
        const nsg::Graph<idx_t>& knn_graph,
384
0
        nsg::Graph<Node>& graph) {
385
0
    for (int i = 0; i < knn_graph.K; i++) {
386
0
        int id = knn_graph.at(q, i);
387
0
        if (id < 0 || id >= ntotal || vt.get(id)) {
388
0
            continue;
389
0
        }
390
391
0
        float dist = dis.symmetric_dis(q, id);
392
0
        pool.emplace_back(id, dist);
393
0
    }
394
395
0
    std::sort(pool.begin(), pool.end());
396
397
0
    std::vector<Node> result;
398
399
0
    int start = 0;
400
0
    if (pool[start].id == q) {
401
0
        start++;
402
0
    }
403
0
    result.push_back(pool[start]);
404
405
0
    while (result.size() < R && (++start) < pool.size() && start < C) {
406
0
        auto& p = pool[start];
407
0
        bool occlude = false;
408
0
        for (int t = 0; t < result.size(); t++) {
409
0
            if (p.id == result[t].id) {
410
0
                occlude = true;
411
0
                break;
412
0
            }
413
0
            float djk = dis.symmetric_dis(result[t].id, p.id);
414
0
            if (djk < p.distance /* dik */) {
415
0
                occlude = true;
416
0
                break;
417
0
            }
418
0
        }
419
0
        if (!occlude) {
420
0
            result.push_back(p);
421
0
        }
422
0
    }
423
424
0
    for (size_t i = 0; i < R; i++) {
425
0
        if (i < result.size()) {
426
0
            graph.at(q, i).id = result[i].id;
427
0
            graph.at(q, i).distance = result[i].distance;
428
0
        } else {
429
0
            graph.at(q, i).id = EMPTY_ID;
430
0
        }
431
0
    }
432
0
}
433
434
void NSG::add_reverse_links(
435
        int q,
436
        std::vector<std::mutex>& locks,
437
        DistanceComputer& dis,
438
0
        nsg::Graph<Node>& graph) {
439
0
    for (size_t i = 0; i < R; i++) {
440
0
        if (graph.at(q, i).id == EMPTY_ID) {
441
0
            break;
442
0
        }
443
444
0
        Node sn(q, graph.at(q, i).distance);
445
0
        int des = graph.at(q, i).id;
446
447
0
        std::vector<Node> tmp_pool;
448
0
        int dup = 0;
449
0
        {
450
0
            LockGuard guard(locks[des]);
451
0
            for (int j = 0; j < R; j++) {
452
0
                if (graph.at(des, j).id == EMPTY_ID) {
453
0
                    break;
454
0
                }
455
0
                if (q == graph.at(des, j).id) {
456
0
                    dup = 1;
457
0
                    break;
458
0
                }
459
0
                tmp_pool.push_back(graph.at(des, j));
460
0
            }
461
0
        }
462
463
0
        if (dup) {
464
0
            continue;
465
0
        }
466
467
0
        tmp_pool.push_back(sn);
468
0
        if (tmp_pool.size() > R) {
469
0
            std::vector<Node> result;
470
0
            int start = 0;
471
0
            std::sort(tmp_pool.begin(), tmp_pool.end());
472
0
            result.push_back(tmp_pool[start]);
473
474
0
            while (result.size() < R && (++start) < tmp_pool.size()) {
475
0
                auto& p = tmp_pool[start];
476
0
                bool occlude = false;
477
478
0
                for (int t = 0; t < result.size(); t++) {
479
0
                    if (p.id == result[t].id) {
480
0
                        occlude = true;
481
0
                        break;
482
0
                    }
483
0
                    float djk = dis.symmetric_dis(result[t].id, p.id);
484
0
                    if (djk < p.distance /* dik */) {
485
0
                        occlude = true;
486
0
                        break;
487
0
                    }
488
0
                }
489
490
0
                if (!occlude) {
491
0
                    result.push_back(p);
492
0
                }
493
0
            }
494
495
0
            {
496
0
                LockGuard guard(locks[des]);
497
0
                for (int t = 0; t < result.size(); t++) {
498
0
                    graph.at(des, t) = result[t];
499
0
                }
500
0
            }
501
502
0
        } else {
503
0
            LockGuard guard(locks[des]);
504
0
            for (int t = 0; t < R; t++) {
505
0
                if (graph.at(des, t).id == EMPTY_ID) {
506
0
                    graph.at(des, t) = sn;
507
0
                    break;
508
0
                }
509
0
            }
510
0
        }
511
0
    }
512
0
}
513
514
0
/// Make every node reachable from the entry point.
///
/// A depth-first traversal counts the reachable nodes. While fewer than
/// `ntotal` nodes have been reached, one unreached node is attached to the
/// graph and the traversal continues from the node returned by
/// attach_unlinked.
///
/// @return the number of attach operations performed
int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
    int root = enterpoint;
    VisitedTable vt(ntotal);
    VisitedTable vt2(ntotal);

    int num_attached = 0;
    int reached = dfs(vt, root, 0);
    while (reached < ntotal) {
        root = attach_unlinked(storage, vt, vt2, degrees);
        vt2.advance();
        num_attached += 1;

        reached = dfs(vt, root, reached);
    }

    return num_attached;
}
534
535
0
/// Iterative depth-first traversal of `final_graph` starting at `root`.
///
/// Every node reached for the first time is marked in `vt` and counted.
///
/// @param vt   visited table shared across successive traversals
/// @param root node to start from
/// @param cnt  number of nodes counted so far
/// @return `cnt` plus the number of newly visited nodes
int NSG::dfs(VisitedTable& vt, int root, int cnt) const {
    std::stack<int> path;
    path.push(root);

    if (!vt.get(root)) {
        ++cnt;
    }
    vt.set(root);

    while (!path.empty()) {
        const int current = path.top();

        // First neighbor of `current` that has not been visited yet.
        int unvisited = EMPTY_ID;
        for (int slot = 0; slot < R; ++slot) {
            const int candidate = final_graph->at(current, slot);
            if (candidate == EMPTY_ID || vt.get(candidate)) {
                continue;
            }
            unvisited = candidate;
            break;
        }

        if (unvisited != EMPTY_ID) {
            // Descend into the neighbor.
            vt.set(unvisited);
            path.push(unvisited);
            ++cnt;
        } else {
            // Nothing left below this node: backtrack.
            path.pop();
        }
    }

    return cnt;
}
571
572
int NSG::attach_unlinked(
573
        Index* storage,
574
        VisitedTable& vt,
575
        VisitedTable& vt2,
576
0
        std::vector<int>& degrees) {
577
    /* NOTE: This implementation is slightly different from the original paper.
578
     *
579
     * Instead of connecting the unlinked node to the nearest point in the
580
     * spanning tree which will increase the maximum degree of the graph and
581
     * also make the graph hard to maintain, this implementation links the
582
     * unlinked node to the nearest node of which the degree is smaller than R.
583
     * It will keep the degree of all nodes to be no more than `R`.
584
     */
585
586
    // find one unlinked node
587
0
    int id = EMPTY_ID;
588
0
    for (int i = 0; i < ntotal; i++) {
589
0
        if (!vt.get(i)) {
590
0
            id = i;
591
0
            break;
592
0
        }
593
0
    }
594
595
0
    if (id == EMPTY_ID) {
596
0
        return EMPTY_ID; // No Unlinked Node
597
0
    }
598
599
0
    std::vector<Neighbor> tmp;
600
0
    std::vector<Node> pool;
601
602
0
    std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
603
0
    std::unique_ptr<float[]> vec(new float[storage->d]);
604
605
0
    storage->reconstruct(id, vec.get());
606
0
    dis->set_query(vec.get());
607
608
    // Collect the visited nodes into pool
609
0
    search_on_graph<true>(
610
0
            *final_graph, *dis, vt2, enterpoint, search_L, tmp, pool);
611
612
0
    std::sort(pool.begin(), pool.end());
613
614
0
    int node;
615
0
    bool found = false;
616
0
    for (int i = 0; i < pool.size(); i++) {
617
0
        node = pool[i].id;
618
0
        if (degrees[node] < R && node != id) {
619
0
            found = true;
620
0
            break;
621
0
        }
622
0
    }
623
624
    // randomly choice annother node
625
0
    if (!found) {
626
0
        do {
627
0
            node = rng.rand_int(ntotal);
628
0
            if (vt.get(node) && degrees[node] < R && node != id) {
629
0
                found = true;
630
0
            }
631
0
        } while (!found);
632
0
    }
633
634
0
    int pos = degrees[node];
635
0
    final_graph->at(node, pos) = id; // replace
636
0
    degrees[node] += 1;
637
638
0
    return node;
639
0
}
640
641
0
void NSG::check_graph() const {
642
0
#pragma omp parallel for
643
0
    for (int i = 0; i < ntotal; i++) {
644
0
        for (int j = 0; j < R; j++) {
645
0
            int id = final_graph->at(i, j);
646
0
            FAISS_THROW_IF_NOT(id < ntotal && (id >= 0 || id == EMPTY_ID));
647
0
        }
648
0
    }
649
0
}
650
651
} // namespace faiss