Coverage Report

Created: 2025-11-28 19:49

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/HNSW.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/impl/HNSW.h>
9
10
#include <cstddef>
11
12
#include <faiss/impl/AuxIndexStructures.h>
13
#include <faiss/impl/DistanceComputer.h>
14
#include <faiss/impl/IDSelector.h>
15
#include <faiss/impl/ResultHandler.h>
16
#include <faiss/utils/prefetch.h>
17
18
#include <faiss/impl/platform_macros.h>
19
20
#ifdef __AVX2__
21
#include <immintrin.h>
22
23
#include <limits>
24
#include <type_traits>
25
#endif
26
27
namespace faiss {
28
29
/**************************************************************
30
 * HNSW structure implementation
31
 **************************************************************/
32
33
28.2k
int HNSW::nb_neighbors(int layer_no) const {
34
28.2k
    FAISS_THROW_IF_NOT(layer_no + 1 < cum_nneighbor_per_level.size());
35
28.2k
    return cum_nneighbor_per_level[layer_no + 1] -
36
28.2k
            cum_nneighbor_per_level[layer_no];
37
28.2k
}
38
39
0
void HNSW::set_nb_neighbors(int level_no, int n) {
40
0
    FAISS_THROW_IF_NOT(levels.size() == 0);
41
0
    int cur_n = nb_neighbors(level_no);
42
0
    for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
43
0
        cum_nneighbor_per_level[i] += n - cur_n;
44
0
    }
45
0
}
46
47
4.74M
int HNSW::cum_nb_neighbors(int layer_no) const {
48
4.74M
    return cum_nneighbor_per_level[layer_no];
49
4.74M
}
50
51
void HNSW::neighbor_range(idx_t no, int layer_no, size_t* begin, size_t* end)
52
2.36M
        const {
53
2.36M
    size_t o = offsets[no];
54
2.36M
    *begin = o + cum_nb_neighbors(layer_no);
55
2.36M
    *end = o + cum_nb_neighbors(layer_no + 1);
56
2.36M
}
57
58
124
HNSW::HNSW(int M) : rng(12345) {
59
124
    set_default_probas(M, 1.0 / log(M));
60
124
    offsets.push_back(0);
61
124
}
62
63
27.0k
int HNSW::random_level() {
64
27.0k
    double f = rng.rand_float();
65
    // could be a bit faster with bissection
66
28.4k
    for (int level = 0; level < assign_probas.size(); level++) {
67
28.4k
        if (f < assign_probas[level]) {
68
27.0k
            return level;
69
27.0k
        }
70
1.39k
        f -= assign_probas[level];
71
1.39k
    }
72
    // happens with exponentially low probability
73
0
    return assign_probas.size() - 1;
74
27.0k
}
75
76
124
void HNSW::set_default_probas(int M, float levelMult) {
77
124
    int nn = 0;
78
124
    cum_nneighbor_per_level.push_back(0);
79
993
    for (int level = 0;; level++) {
80
993
        float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
81
993
        if (proba < 1e-9)
82
124
            break;
83
869
        assign_probas.push_back(proba);
84
869
        nn += level == 0 ? M * 2 : M;
85
869
        cum_nneighbor_per_level.push_back(nn);
86
869
    }
87
124
}
88
89
0
void HNSW::clear_neighbor_tables(int level) {
90
0
    for (int i = 0; i < levels.size(); i++) {
91
0
        size_t begin, end;
92
0
        neighbor_range(i, level, &begin, &end);
93
0
        for (size_t j = begin; j < end; j++) {
94
0
            neighbors[j] = -1;
95
0
        }
96
0
    }
97
0
}
98
99
0
void HNSW::reset() {
100
0
    max_level = -1;
101
0
    entry_point = -1;
102
0
    offsets.clear();
103
0
    offsets.push_back(0);
104
0
    levels.clear();
105
0
    neighbors.clear();
106
0
}
107
108
0
void HNSW::print_neighbor_stats(int level) const {
109
0
    FAISS_THROW_IF_NOT(level < cum_nneighbor_per_level.size());
110
0
    printf("stats on level %d, max %d neighbors per vertex:\n",
111
0
           level,
112
0
           nb_neighbors(level));
113
0
    size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
114
0
#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
115
0
        reduction(+ : tot_reciprocal) reduction(+ : n_node)
116
0
    for (int i = 0; i < levels.size(); i++) {
117
0
        if (levels[i] > level) {
118
0
            n_node++;
119
0
            size_t begin, end;
120
0
            neighbor_range(i, level, &begin, &end);
121
0
            std::unordered_set<int> neighset;
122
0
            for (size_t j = begin; j < end; j++) {
123
0
                if (neighbors[j] < 0)
124
0
                    break;
125
0
                neighset.insert(neighbors[j]);
126
0
            }
127
0
            int n_neigh = neighset.size();
128
0
            int n_common = 0;
129
0
            int n_reciprocal = 0;
130
0
            for (size_t j = begin; j < end; j++) {
131
0
                storage_idx_t i2 = neighbors[j];
132
0
                if (i2 < 0)
133
0
                    break;
134
0
                FAISS_ASSERT(i2 != i);
135
0
                size_t begin2, end2;
136
0
                neighbor_range(i2, level, &begin2, &end2);
137
0
                for (size_t j2 = begin2; j2 < end2; j2++) {
138
0
                    storage_idx_t i3 = neighbors[j2];
139
0
                    if (i3 < 0)
140
0
                        break;
141
0
                    if (i3 == i) {
142
0
                        n_reciprocal++;
143
0
                        continue;
144
0
                    }
145
0
                    if (neighset.count(i3)) {
146
0
                        neighset.erase(i3);
147
0
                        n_common++;
148
0
                    }
149
0
                }
150
0
            }
151
0
            tot_neigh += n_neigh;
152
0
            tot_common += n_common;
153
0
            tot_reciprocal += n_reciprocal;
154
0
        }
155
0
    }
156
0
    float normalizer = n_node;
157
0
    printf("   nb of nodes at that level %zd\n", n_node);
158
0
    printf("   neighbors per node: %.2f (%zd)\n",
159
0
           tot_neigh / normalizer,
160
0
           tot_neigh);
161
0
    printf("   nb of reciprocal neighbors: %.2f\n",
162
0
           tot_reciprocal / normalizer);
163
0
    printf("   nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
164
0
           tot_common / normalizer,
165
0
           tot_common);
166
0
}
167
168
0
void HNSW::fill_with_random_links(size_t n) {
169
0
    int max_level_2 = prepare_level_tab(n);
170
0
    RandomGenerator rng2(456);
171
172
0
    for (int level = max_level_2 - 1; level >= 0; --level) {
173
0
        std::vector<int> elts;
174
0
        for (int i = 0; i < n; i++) {
175
0
            if (levels[i] > level) {
176
0
                elts.push_back(i);
177
0
            }
178
0
        }
179
0
        printf("linking %zd elements in level %d\n", elts.size(), level);
180
181
0
        if (elts.size() == 1)
182
0
            continue;
183
184
0
        for (int ii = 0; ii < elts.size(); ii++) {
185
0
            int i = elts[ii];
186
0
            size_t begin, end;
187
0
            neighbor_range(i, 0, &begin, &end);
188
0
            for (size_t j = begin; j < end; j++) {
189
0
                int other = 0;
190
0
                do {
191
0
                    other = elts[rng2.rand_int(elts.size())];
192
0
                } while (other == i);
193
194
0
                neighbors[j] = other;
195
0
            }
196
0
        }
197
0
    }
198
0
}
199
200
18.7k
int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
201
18.7k
    size_t n0 = offsets.size() - 1;
202
203
18.7k
    if (preset_levels) {
204
0
        FAISS_ASSERT(n0 + n == levels.size());
205
18.7k
    } else {
206
18.7k
        FAISS_ASSERT(n0 == levels.size());
207
45.7k
        for (int i = 0; i < n; i++) {
208
27.0k
            int pt_level = random_level();
209
27.0k
            levels.push_back(pt_level + 1);
210
27.0k
        }
211
18.7k
    }
212
213
18.7k
    int max_level_2 = 0;
214
45.7k
    for (int i = 0; i < n; i++) {
215
27.0k
        int pt_level = levels[i + n0] - 1;
216
27.0k
        if (pt_level > max_level_2)
217
1.21k
            max_level_2 = pt_level;
218
27.0k
        offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
219
27.0k
    }
220
18.7k
    neighbors.resize(offsets.back(), -1);
221
222
18.7k
    return max_level_2;
223
18.7k
}
224
225
/** Enumerate vertices from nearest to farthest from query, keep a
226
 * neighbor only if there is no previous neighbor that is closer to
227
 * that vertex than the query.
228
 */
229
void HNSW::shrink_neighbor_list(
230
        DistanceComputer& qdis,
231
        std::priority_queue<NodeDistFarther>& input,
232
        std::vector<NodeDistFarther>& output,
233
        int max_size,
234
38.0k
        bool keep_max_size_level0) {
235
    // This prevents number of neighbors at
236
    // level 0 from being shrunk to less than 2 * M.
237
    // This is essential in making sure
238
    // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
239
38.0k
    std::vector<NodeDistFarther> outsiders;
240
241
1.33M
    while (input.size() > 0) {
242
1.31M
        NodeDistFarther v1 = input.top();
243
1.31M
        input.pop();
244
1.31M
        float dist_v1_q = v1.d;
245
246
1.31M
        bool good = true;
247
7.26M
        for (NodeDistFarther v2 : output) {
248
7.26M
            float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
249
250
7.26M
            if (dist_v1_v2 < dist_v1_q) {
251
748k
                good = false;
252
748k
                break;
253
748k
            }
254
7.26M
        }
255
256
1.31M
        if (good) {
257
564k
            output.push_back(v1);
258
564k
            if (output.size() >= max_size) {
259
11.8k
                return;
260
11.8k
            }
261
748k
        } else if (keep_max_size_level0) {
262
0
            outsiders.push_back(v1);
263
0
        }
264
1.31M
    }
265
26.2k
    size_t idx = 0;
266
26.2k
    while (keep_max_size_level0 && (output.size() < max_size) &&
267
26.2k
           (idx < outsiders.size())) {
268
0
        output.push_back(outsiders[idx++]);
269
0
    }
270
26.2k
}
271
272
namespace {
273
274
using storage_idx_t = HNSW::storage_idx_t;
275
using NodeDistCloser = HNSW::NodeDistCloser;
276
using NodeDistFarther = HNSW::NodeDistFarther;
277
278
/**************************************************************
279
 * Addition subroutines
280
 **************************************************************/
281
282
/// remove neighbors from the list to make it smaller than max_size
283
void shrink_neighbor_list(
284
        DistanceComputer& qdis,
285
        std::priority_queue<NodeDistCloser>& resultSet1,
286
        int max_size,
287
51.6k
        bool keep_max_size_level0 = false) {
288
51.6k
    if (resultSet1.size() < max_size) {
289
13.5k
        return;
290
13.5k
    }
291
38.0k
    std::priority_queue<NodeDistFarther> resultSet;
292
38.0k
    std::vector<NodeDistFarther> returnlist;
293
294
1.36M
    while (resultSet1.size() > 0) {
295
1.32M
        resultSet.emplace(resultSet1.top().d, resultSet1.top().id);
296
1.32M
        resultSet1.pop();
297
1.32M
    }
298
299
38.0k
    HNSW::shrink_neighbor_list(
300
38.0k
            qdis, resultSet, returnlist, max_size, keep_max_size_level0);
301
302
565k
    for (NodeDistFarther curen2 : returnlist) {
303
565k
        resultSet1.emplace(curen2.d, curen2.id);
304
565k
    }
305
38.0k
}
306
307
/// add a link between two elements, possibly shrinking the list
308
/// of links to make room for it.
309
void add_link(
310
        HNSW& hnsw,
311
        DistanceComputer& qdis,
312
        storage_idx_t src,
313
        storage_idx_t dest,
314
        int level,
315
1.25M
        bool keep_max_size_level0 = false) {
316
1.25M
    size_t begin, end;
317
1.25M
    hnsw.neighbor_range(src, level, &begin, &end);
318
1.25M
    if (hnsw.neighbors[end - 1] == -1) {
319
        // there is enough room, find a slot to add it
320
1.22M
        size_t i = end;
321
63.0M
        while (i > begin) {
322
63.0M
            if (hnsw.neighbors[i - 1] != -1)
323
1.19M
                break;
324
61.8M
            i--;
325
61.8M
        }
326
1.22M
        hnsw.neighbors[i] = dest;
327
1.22M
        return;
328
1.22M
    }
329
330
    // otherwise we let them fight out which to keep
331
332
    // copy to resultSet...
333
30.3k
    std::priority_queue<NodeDistCloser> resultSet;
334
30.3k
    resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
335
760k
    for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
336
730k
        storage_idx_t neigh = hnsw.neighbors[i];
337
730k
        resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
338
730k
    }
339
340
30.3k
    shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
341
342
    // ...and back
343
30.3k
    size_t i = begin;
344
433k
    while (resultSet.size()) {
345
403k
        hnsw.neighbors[i++] = resultSet.top().id;
346
403k
        resultSet.pop();
347
403k
    }
348
    // they may have shrunk more than just by 1 element
349
358k
    while (i < end) {
350
328k
        hnsw.neighbors[i++] = -1;
351
328k
    }
352
30.3k
}
353
354
} // namespace
355
356
/// search neighbors on a single level, starting from an entry point
357
void search_neighbors_to_add(
358
        HNSW& hnsw,
359
        DistanceComputer& qdis,
360
        std::priority_queue<NodeDistCloser>& results,
361
        int entry_point,
362
        float d_entry_point,
363
        int level,
364
        VisitedTable& vt,
365
28.2k
        bool reference_version) {
366
    // top is nearest candidate
367
28.2k
    std::priority_queue<NodeDistFarther> candidates;
368
369
28.2k
    NodeDistFarther ev(d_entry_point, entry_point);
370
28.2k
    candidates.push(ev);
371
28.2k
    results.emplace(d_entry_point, entry_point);
372
28.2k
    vt.set(entry_point);
373
374
1.08M
    while (!candidates.empty()) {
375
        // get nearest
376
1.07M
        const NodeDistFarther& currEv = candidates.top();
377
378
1.07M
        if (currEv.d > results.top().d) {
379
24.2k
            break;
380
24.2k
        }
381
1.05M
        int currNode = currEv.id;
382
1.05M
        candidates.pop();
383
384
        // loop over neighbors
385
1.05M
        size_t begin, end;
386
1.05M
        hnsw.neighbor_range(currNode, level, &begin, &end);
387
388
        // The reference version is not used, but kept here because:
389
        // 1. It is easier to switch back if the optimized version has a problem
390
        // 2. It serves as a starting point for new optimizations
391
        // 3. It helps understand the code
392
        // 4. It ensures the reference version is still compilable if the
393
        // optimized version changes
394
        // The reference and the optimized versions' results are compared in
395
        // test_hnsw.cpp
396
1.05M
        if (reference_version) {
397
            // a reference version
398
0
            for (size_t i = begin; i < end; i++) {
399
0
                storage_idx_t nodeId = hnsw.neighbors[i];
400
0
                if (nodeId < 0)
401
0
                    break;
402
0
                if (vt.get(nodeId))
403
0
                    continue;
404
0
                vt.set(nodeId);
405
406
0
                float dis = qdis(nodeId);
407
0
                NodeDistFarther evE1(dis, nodeId);
408
409
0
                if (results.size() < hnsw.efConstruction ||
410
0
                    results.top().d > dis) {
411
0
                    results.emplace(dis, nodeId);
412
0
                    candidates.emplace(dis, nodeId);
413
0
                    if (results.size() > hnsw.efConstruction) {
414
0
                        results.pop();
415
0
                    }
416
0
                }
417
0
            }
418
1.05M
        } else {
419
            // a faster version
420
421
            // the following version processes 4 neighbors at a time
422
1.05M
            auto update_with_candidate = [&](const storage_idx_t idx,
423
7.34M
                                             const float dis) {
424
7.34M
                if (results.size() < hnsw.efConstruction ||
425
7.34M
                    results.top().d > dis) {
426
2.47M
                    results.emplace(dis, idx);
427
2.47M
                    candidates.emplace(dis, idx);
428
2.47M
                    if (results.size() > hnsw.efConstruction) {
429
1.46M
                        results.pop();
430
1.46M
                    }
431
2.47M
                }
432
7.34M
            };
433
434
1.05M
            int n_buffered = 0;
435
1.05M
            storage_idx_t buffered_ids[4];
436
437
33.2M
            for (size_t j = begin; j < end; j++) {
438
33.1M
                storage_idx_t nodeId = hnsw.neighbors[j];
439
33.1M
                if (nodeId < 0)
440
963k
                    break;
441
32.1M
                if (vt.get(nodeId)) {
442
25.2M
                    continue;
443
25.2M
                }
444
6.94M
                vt.set(nodeId);
445
446
6.94M
                buffered_ids[n_buffered] = nodeId;
447
6.94M
                n_buffered += 1;
448
449
6.94M
                if (n_buffered == 4) {
450
1.57M
                    float dis[4];
451
1.57M
                    qdis.distances_batch_4(
452
1.57M
                            buffered_ids[0],
453
1.57M
                            buffered_ids[1],
454
1.57M
                            buffered_ids[2],
455
1.57M
                            buffered_ids[3],
456
1.57M
                            dis[0],
457
1.57M
                            dis[1],
458
1.57M
                            dis[2],
459
1.57M
                            dis[3]);
460
461
7.86M
                    for (size_t id4 = 0; id4 < 4; id4++) {
462
6.28M
                        update_with_candidate(buffered_ids[id4], dis[id4]);
463
6.28M
                    }
464
465
1.57M
                    n_buffered = 0;
466
1.57M
                }
467
6.94M
            }
468
469
            // process leftovers
470
2.12M
            for (size_t icnt = 0; icnt < n_buffered; icnt++) {
471
1.07M
                float dis = qdis(buffered_ids[icnt]);
472
1.07M
                update_with_candidate(buffered_ids[icnt], dis);
473
1.07M
            }
474
1.05M
        }
475
1.05M
    }
476
477
28.2k
    vt.advance();
478
28.2k
}
479
480
/// Finds neighbors and builds links with them, starting from an entry
481
/// point. The own neighbor list is assumed to be locked.
482
void HNSW::add_links_starting_from(
483
        DistanceComputer& ptdis,
484
        storage_idx_t pt_id,
485
        storage_idx_t nearest,
486
        float d_nearest,
487
        int level,
488
        omp_lock_t* locks,
489
        VisitedTable& vt,
490
28.2k
        bool keep_max_size_level0) {
491
28.2k
    std::priority_queue<NodeDistCloser> link_targets;
492
493
28.2k
    search_neighbors_to_add(
494
28.2k
            *this, ptdis, link_targets, nearest, d_nearest, level, vt);
495
496
    // but we can afford only this many neighbors
497
28.2k
    int M = nb_neighbors(level);
498
499
28.2k
    ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
500
501
28.2k
    std::vector<storage_idx_t> neighbors_to_add;
502
28.2k
    neighbors_to_add.reserve(link_targets.size());
503
658k
    while (!link_targets.empty()) {
504
630k
        storage_idx_t other_id = link_targets.top().id;
505
630k
        add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
506
630k
        neighbors_to_add.push_back(other_id);
507
630k
        link_targets.pop();
508
630k
    }
509
510
28.2k
    omp_unset_lock(&locks[pt_id]);
511
630k
    for (storage_idx_t other_id : neighbors_to_add) {
512
630k
        omp_set_lock(&locks[other_id]);
513
630k
        add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
514
630k
        omp_unset_lock(&locks[other_id]);
515
630k
    }
516
28.2k
    omp_set_lock(&locks[pt_id]);
517
28.2k
}
518
519
/**************************************************************
520
 * Building, parallel
521
 **************************************************************/
522
523
void HNSW::add_with_locks(
524
        DistanceComputer& ptdis,
525
        int pt_level,
526
        int pt_id,
527
        std::vector<omp_lock_t>& locks,
528
        VisitedTable& vt,
529
27.0k
        bool keep_max_size_level0) {
530
    //  greedy search on upper levels
531
532
27.0k
    storage_idx_t nearest;
533
27.0k
#pragma omp critical
534
27.0k
    {
535
27.0k
        nearest = entry_point;
536
537
27.0k
        if (nearest == -1) {
538
82
            max_level = pt_level;
539
82
            entry_point = pt_id;
540
82
        }
541
27.0k
    }
542
543
27.0k
    if (nearest < 0) {
544
82
        return;
545
82
    }
546
547
26.9k
    omp_set_lock(&locks[pt_id]);
548
549
26.9k
    int level = max_level; // level at which we start adding neighbors
550
26.9k
    float d_nearest = ptdis(nearest);
551
552
59.6k
    for (; level > pt_level; level--) {
553
32.7k
        greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
554
32.7k
    }
555
556
55.1k
    for (; level >= 0; level--) {
557
28.2k
        add_links_starting_from(
558
28.2k
                ptdis,
559
28.2k
                pt_id,
560
28.2k
                nearest,
561
28.2k
                d_nearest,
562
28.2k
                level,
563
28.2k
                locks.data(),
564
28.2k
                vt,
565
28.2k
                keep_max_size_level0);
566
28.2k
    }
567
568
26.9k
    omp_unset_lock(&locks[pt_id]);
569
570
26.9k
    if (pt_level > max_level) {
571
56
        max_level = pt_level;
572
56
        entry_point = pt_id;
573
56
    }
574
26.9k
}
575
576
/**************************************************************
577
 * Searching
578
 **************************************************************/
579
580
using MinimaxHeap = HNSW::MinimaxHeap;
581
using Node = HNSW::Node;
582
using C = HNSW::C;
583
/** Do a BFS on the candidates list */
584
int search_from_candidates(
585
        const HNSW& hnsw,
586
        DistanceComputer& qdis,
587
        ResultHandler<C>& res,
588
        MinimaxHeap& candidates,
589
        VisitedTable& vt,
590
        HNSWStats& stats,
591
        int level,
592
        int nres_in,
593
97
        const SearchParameters* params) {
594
97
    int nres = nres_in;
595
97
    int ndis = 0;
596
597
    // can be overridden by search params
598
97
    bool do_dis_check = hnsw.check_relative_distance;
599
97
    int efSearch = hnsw.efSearch;
600
97
    const IDSelector* sel = nullptr;
601
97
    if (params) {
602
87
        if (const SearchParametersHNSW* hnsw_params =
603
87
                    dynamic_cast<const SearchParametersHNSW*>(params)) {
604
87
            do_dis_check = hnsw_params->check_relative_distance;
605
87
            efSearch = hnsw_params->efSearch;
606
87
        }
607
87
        sel = params->sel;
608
87
    }
609
610
97
    C::T threshold = res.threshold;
611
194
    for (int i = 0; i < candidates.size(); i++) {
612
97
        idx_t v1 = candidates.ids[i];
613
97
        float d = candidates.dis[i];
614
97
        FAISS_ASSERT(v1 >= 0);
615
97
        if (!sel || sel->is_member(v1)) {
616
94
            if (d < threshold) {
617
81
                if (res.add_result(d, v1)) {
618
32
                    threshold = res.threshold;
619
32
                }
620
81
            }
621
94
        }
622
97
        vt.set(v1);
623
97
    }
624
625
97
    int nstep = 0;
626
627
2.77k
    while (candidates.size() > 0) {
628
2.68k
        float d0 = 0;
629
2.68k
        int v0 = candidates.pop_min(&d0);
630
631
2.68k
        if (do_dis_check) {
632
            // tricky stopping condition: there are more that ef
633
            // distances that are processed already that are smaller
634
            // than d0
635
636
2.68k
            int n_dis_below = candidates.count_below(d0);
637
2.68k
            if (n_dis_below >= efSearch) {
638
13
                break;
639
13
            }
640
2.68k
        }
641
642
2.67k
        size_t begin, end;
643
2.67k
        hnsw.neighbor_range(v0, level, &begin, &end);
644
645
        // a faster version: reference version in unit test test_hnsw.cpp
646
        // the following version processes 4 neighbors at a time
647
2.67k
        size_t jmax = begin;
648
121k
        for (size_t j = begin; j < end; j++) {
649
121k
            int v1 = hnsw.neighbors[j];
650
121k
            if (v1 < 0)
651
2.54k
                break;
652
653
118k
            prefetch_L2(vt.visited.data() + v1);
654
118k
            jmax += 1;
655
118k
        }
656
657
2.67k
        int counter = 0;
658
2.67k
        size_t saved_j[4];
659
660
2.67k
        threshold = res.threshold;
661
662
17.4k
        auto add_to_heap = [&](const size_t idx, const float dis) {
663
17.4k
            if (!sel || sel->is_member(idx)) {
664
17.1k
                if (dis < threshold) {
665
8.73k
                    if (res.add_result(dis, idx)) {
666
4.78k
                        threshold = res.threshold;
667
4.78k
                        nres += 1;
668
4.78k
                    }
669
8.73k
                }
670
17.1k
            }
671
17.4k
            candidates.push(idx, dis);
672
17.4k
        };
673
674
121k
        for (size_t j = begin; j < jmax; j++) {
675
118k
            int v1 = hnsw.neighbors[j];
676
677
118k
            bool vget = vt.get(v1);
678
118k
            vt.set(v1);
679
118k
            saved_j[counter] = v1;
680
118k
            counter += vget ? 0 : 1;
681
682
118k
            if (counter == 4) {
683
3.81k
                float dis[4];
684
3.81k
                qdis.distances_batch_4(
685
3.81k
                        saved_j[0],
686
3.81k
                        saved_j[1],
687
3.81k
                        saved_j[2],
688
3.81k
                        saved_j[3],
689
3.81k
                        dis[0],
690
3.81k
                        dis[1],
691
3.81k
                        dis[2],
692
3.81k
                        dis[3]);
693
694
19.0k
                for (size_t id4 = 0; id4 < 4; id4++) {
695
15.2k
                    add_to_heap(saved_j[id4], dis[id4]);
696
15.2k
                }
697
698
3.81k
                ndis += 4;
699
700
3.81k
                counter = 0;
701
3.81k
            }
702
118k
        }
703
704
4.84k
        for (size_t icnt = 0; icnt < counter; icnt++) {
705
2.16k
            float dis = qdis(saved_j[icnt]);
706
2.16k
            add_to_heap(saved_j[icnt], dis);
707
708
2.16k
            ndis += 1;
709
2.16k
        }
710
711
2.67k
        nstep++;
712
2.67k
        if (!do_dis_check && nstep > efSearch) {
713
0
            break;
714
0
        }
715
2.67k
    }
716
717
97
    if (level == 0) {
718
97
        stats.n1++;
719
97
        if (candidates.size() == 0) {
720
84
            stats.n2++;
721
84
        }
722
97
        stats.ndis += ndis;
723
97
        stats.nhops += nstep;
724
97
    }
725
726
97
    return nres;
727
97
}
728
729
std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
730
        const HNSW& hnsw,
731
        const Node& node,
732
        DistanceComputer& qdis,
733
        int ef,
734
        VisitedTable* vt,
735
0
        HNSWStats& stats) {
736
0
    int ndis = 0;
737
0
    std::priority_queue<Node> top_candidates;
738
0
    std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
739
740
0
    top_candidates.push(node);
741
0
    candidates.push(node);
742
743
0
    vt->set(node.second);
744
745
0
    while (!candidates.empty()) {
746
0
        float d0;
747
0
        storage_idx_t v0;
748
0
        std::tie(d0, v0) = candidates.top();
749
750
0
        if (d0 > top_candidates.top().first) {
751
0
            break;
752
0
        }
753
754
0
        candidates.pop();
755
756
0
        size_t begin, end;
757
0
        hnsw.neighbor_range(v0, 0, &begin, &end);
758
759
        // a faster version: reference version in unit test test_hnsw.cpp
760
        // the following version processes 4 neighbors at a time
761
0
        size_t jmax = begin;
762
0
        for (size_t j = begin; j < end; j++) {
763
0
            int v1 = hnsw.neighbors[j];
764
0
            if (v1 < 0)
765
0
                break;
766
767
0
            prefetch_L2(vt->visited.data() + v1);
768
0
            jmax += 1;
769
0
        }
770
771
0
        int counter = 0;
772
0
        size_t saved_j[4];
773
774
0
        auto add_to_heap = [&](const size_t idx, const float dis) {
775
0
            if (top_candidates.top().first > dis ||
776
0
                top_candidates.size() < ef) {
777
0
                candidates.emplace(dis, idx);
778
0
                top_candidates.emplace(dis, idx);
779
780
0
                if (top_candidates.size() > ef) {
781
0
                    top_candidates.pop();
782
0
                }
783
0
            }
784
0
        };
785
786
0
        for (size_t j = begin; j < jmax; j++) {
787
0
            int v1 = hnsw.neighbors[j];
788
789
0
            bool vget = vt->get(v1);
790
0
            vt->set(v1);
791
0
            saved_j[counter] = v1;
792
0
            counter += vget ? 0 : 1;
793
794
0
            if (counter == 4) {
795
0
                float dis[4];
796
0
                qdis.distances_batch_4(
797
0
                        saved_j[0],
798
0
                        saved_j[1],
799
0
                        saved_j[2],
800
0
                        saved_j[3],
801
0
                        dis[0],
802
0
                        dis[1],
803
0
                        dis[2],
804
0
                        dis[3]);
805
806
0
                for (size_t id4 = 0; id4 < 4; id4++) {
807
0
                    add_to_heap(saved_j[id4], dis[id4]);
808
0
                }
809
810
0
                ndis += 4;
811
812
0
                counter = 0;
813
0
            }
814
0
        }
815
816
0
        for (size_t icnt = 0; icnt < counter; icnt++) {
817
0
            float dis = qdis(saved_j[icnt]);
818
0
            add_to_heap(saved_j[icnt], dis);
819
820
0
            ndis += 1;
821
0
        }
822
823
0
        stats.nhops += 1;
824
0
    }
825
826
0
    ++stats.n1;
827
0
    if (candidates.size() == 0) {
828
0
        ++stats.n2;
829
0
    }
830
0
    stats.ndis += ndis;
831
832
0
    return top_candidates;
833
0
}
834
835
/// greedily update a nearest vector at a given level
836
HNSWStats greedy_update_nearest(
837
        const HNSW& hnsw,
838
        DistanceComputer& qdis,
839
        int level,
840
        storage_idx_t& nearest,
841
32.8k
        float& d_nearest) {
842
32.8k
    HNSWStats stats;
843
844
59.6k
    for (;;) {
845
59.6k
        storage_idx_t prev_nearest = nearest;
846
847
59.6k
        size_t begin, end;
848
59.6k
        hnsw.neighbor_range(nearest, level, &begin, &end);
849
850
59.6k
        size_t ndis = 0;
851
852
        // a faster version: reference version in unit test test_hnsw.cpp
853
        // the following version processes 4 neighbors at a time
854
59.6k
        auto update_with_candidate = [&](const storage_idx_t idx,
855
364k
                                         const float dis) {
856
364k
            if (dis < d_nearest) {
857
45.5k
                nearest = idx;
858
45.5k
                d_nearest = dis;
859
45.5k
            }
860
364k
        };
861
862
59.6k
        int n_buffered = 0;
863
59.6k
        storage_idx_t buffered_ids[4];
864
865
424k
        for (size_t j = begin; j < end; j++) {
866
419k
            storage_idx_t v = hnsw.neighbors[j];
867
419k
            if (v < 0)
868
55.2k
                break;
869
364k
            ndis += 1;
870
871
364k
            buffered_ids[n_buffered] = v;
872
364k
            n_buffered += 1;
873
874
364k
            if (n_buffered == 4) {
875
70.7k
                float dis[4];
876
70.7k
                qdis.distances_batch_4(
877
70.7k
                        buffered_ids[0],
878
70.7k
                        buffered_ids[1],
879
70.7k
                        buffered_ids[2],
880
70.7k
                        buffered_ids[3],
881
70.7k
                        dis[0],
882
70.7k
                        dis[1],
883
70.7k
                        dis[2],
884
70.7k
                        dis[3]);
885
886
353k
                for (size_t id4 = 0; id4 < 4; id4++) {
887
283k
                    update_with_candidate(buffered_ids[id4], dis[id4]);
888
283k
                }
889
890
70.7k
                n_buffered = 0;
891
70.7k
            }
892
364k
        }
893
894
        // process leftovers
895
141k
        for (size_t icnt = 0; icnt < n_buffered; icnt++) {
896
81.6k
            float dis = qdis(buffered_ids[icnt]);
897
81.6k
            update_with_candidate(buffered_ids[icnt], dis);
898
81.6k
        }
899
900
        // update stats
901
59.6k
        stats.ndis += ndis;
902
59.6k
        stats.nhops += 1;
903
904
59.6k
        if (nearest == prev_nearest) {
905
32.8k
            return stats;
906
32.8k
        }
907
59.6k
    }
908
32.8k
}
909
910
namespace {
911
using MinimaxHeap = HNSW::MinimaxHeap;
912
using Node = HNSW::Node;
913
using C = HNSW::C;
914
915
// just used as a lower bound for the minmaxheap, but it is set for heap search
916
97
int extract_k_from_ResultHandler(ResultHandler<C>& res) {
917
97
    using RH = HeapBlockResultHandler<C>;
918
97
    if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
919
32
        return hres->k;
920
32
    }
921
65
    return 1;
922
97
}
923
924
} // namespace
925
926
/// Search the graph for the query bound to `qdis`.
///
/// Levels max_level..1 are descended greedily from `entry_point`; level 0 is
/// then explored from the point reached, either with a bounded MinimaxHeap
/// (search_from_candidates) or with search_from_candidate_unbounded.
///
/// @param qdis    distance computer, invoked as qdis(id)
/// @param res     receives the results
/// @param vt      visited table handed to the level-0 search; advance() is
///                called on it before returning from a non-empty graph
/// @param params  optional; if it is a SearchParametersHNSW, its efSearch and
///                bounded_queue replace the values stored in this object.
///                It is forwarded only to the bounded-queue search.
/// @return statistics accumulated over all levels (default-constructed when
///         the graph has no entry point)
HNSWStats HNSW::search(
        DistanceComputer& qdis,
        ResultHandler<C>& res,
        VisitedTable& vt,
        const SearchParameters* params) const {
    HNSWStats stats;

    // no entry point: nothing to search, vt is left untouched
    if (entry_point == -1) {
        return stats;
    }

    const int k = extract_k_from_ResultHandler(res);

    // Start from the settings stored in the index; HNSW-specific search
    // parameters take precedence. dynamic_cast of a null pointer is null,
    // so a missing `params` needs no separate check.
    bool use_bounded_queue = this->search_bounded_queue;
    int ef_search = this->efSearch;
    const auto* hnsw_params = dynamic_cast<const SearchParametersHNSW*>(params);
    if (hnsw_params != nullptr) {
        use_bounded_queue = hnsw_params->bounded_queue;
        ef_search = hnsw_params->efSearch;
    }

    //  greedy search on upper levels
    storage_idx_t nearest = entry_point;
    float d_nearest = qdis(nearest);

    int level = max_level;
    while (level >= 1) {
        HNSWStats level_stats =
                greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
        stats.combine(level_stats);
        --level;
    }

    const int ef = std::max(ef_search, k);

    if (!use_bounded_queue) {
        std::priority_queue<Node> top_candidates =
                search_from_candidate_unbounded(
                        *this, Node(d_nearest, nearest), qdis, ef, &vt, stats);

        // drop entries from the top of the queue until at most k remain
        while (top_candidates.size() > k) {
            top_candidates.pop();
        }

        // hand the remaining entries to the result handler
        for (; !top_candidates.empty(); top_candidates.pop()) {
            float d;
            storage_idx_t label;
            std::tie(d, label) = top_candidates.top();
            res.add_result(d, label);
        }
    } else {
        // bounded queue: this is the most common branch
        MinimaxHeap candidates(ef);
        candidates.push(nearest, d_nearest);

        search_from_candidates(
                *this, qdis, res, candidates, vt, stats, 0, 0, params);
    }

    vt.advance();

    return stats;
}
987
988
void HNSW::search_level_0(
989
        DistanceComputer& qdis,
990
        ResultHandler<C>& res,
991
        idx_t nprobe,
992
        const storage_idx_t* nearest_i,
993
        const float* nearest_d,
994
        int search_type,
995
        HNSWStats& search_stats,
996
        VisitedTable& vt,
997
0
        const SearchParameters* params) const {
998
0
    const HNSW& hnsw = *this;
999
1000
0
    auto efSearch = hnsw.efSearch;
1001
0
    if (params) {
1002
0
        if (const SearchParametersHNSW* hnsw_params =
1003
0
                    dynamic_cast<const SearchParametersHNSW*>(params)) {
1004
0
            efSearch = hnsw_params->efSearch;
1005
0
        }
1006
0
    }
1007
1008
0
    int k = extract_k_from_ResultHandler(res);
1009
1010
0
    if (search_type == 1) {
1011
0
        int nres = 0;
1012
1013
0
        for (int j = 0; j < nprobe; j++) {
1014
0
            storage_idx_t cj = nearest_i[j];
1015
1016
0
            if (cj < 0)
1017
0
                break;
1018
1019
0
            if (vt.get(cj))
1020
0
                continue;
1021
1022
0
            int candidates_size = std::max(efSearch, k);
1023
0
            MinimaxHeap candidates(candidates_size);
1024
1025
0
            candidates.push(cj, nearest_d[j]);
1026
1027
0
            nres = search_from_candidates(
1028
0
                    hnsw,
1029
0
                    qdis,
1030
0
                    res,
1031
0
                    candidates,
1032
0
                    vt,
1033
0
                    search_stats,
1034
0
                    0,
1035
0
                    nres,
1036
0
                    params);
1037
0
            nres = std::min(nres, candidates_size);
1038
0
        }
1039
0
    } else if (search_type == 2) {
1040
0
        int candidates_size = std::max(efSearch, int(k));
1041
0
        candidates_size = std::max(candidates_size, int(nprobe));
1042
1043
0
        MinimaxHeap candidates(candidates_size);
1044
0
        for (int j = 0; j < nprobe; j++) {
1045
0
            storage_idx_t cj = nearest_i[j];
1046
1047
0
            if (cj < 0)
1048
0
                break;
1049
0
            candidates.push(cj, nearest_d[j]);
1050
0
        }
1051
1052
0
        search_from_candidates(
1053
0
                hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
1054
0
    }
1055
0
}
1056
1057
0
void HNSW::permute_entries(const idx_t* map) {
1058
    // remap levels
1059
0
    storage_idx_t ntotal = levels.size();
1060
0
    std::vector<storage_idx_t> imap(ntotal); // inverse mapping
1061
    // map: new index -> old index
1062
    // imap: old index -> new index
1063
0
    for (int i = 0; i < ntotal; i++) {
1064
0
        assert(map[i] >= 0 && map[i] < ntotal);
1065
0
        imap[map[i]] = i;
1066
0
    }
1067
0
    if (entry_point != -1) {
1068
0
        entry_point = imap[entry_point];
1069
0
    }
1070
0
    std::vector<int> new_levels(ntotal);
1071
0
    std::vector<size_t> new_offsets(ntotal + 1);
1072
0
    std::vector<storage_idx_t> new_neighbors(neighbors.size());
1073
0
    size_t no = 0;
1074
0
    for (int i = 0; i < ntotal; i++) {
1075
0
        storage_idx_t o = map[i]; // corresponding "old" index
1076
0
        new_levels[i] = levels[o];
1077
0
        for (size_t j = offsets[o]; j < offsets[o + 1]; j++) {
1078
0
            storage_idx_t neigh = neighbors[j];
1079
0
            new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh;
1080
0
        }
1081
0
        new_offsets[i + 1] = no;
1082
0
    }
1083
0
    assert(new_offsets[ntotal] == offsets[ntotal]);
1084
    // swap everyone
1085
0
    std::swap(levels, new_levels);
1086
0
    std::swap(offsets, new_offsets);
1087
0
    neighbors = std::move(new_neighbors);
1088
0
}
1089
1090
/**************************************************************
1091
 * MinimaxHeap
1092
 **************************************************************/
1093
1094
17.5k
/// Insert entry (i, v) into the heap.
///
/// While fewer than n slots are in use the entry is simply added. Once the
/// heap is full (k == n) the entry is compared with the heap top dis[0]:
/// it is dropped when v >= dis[0], otherwise the top is evicted first.
/// nvalid is kept in sync: it is decremented for an evicted top whose id is
/// not -1 and incremented for the inserted entry.
void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
    const bool heap_is_full = (k == n);
    if (heap_is_full) {
        if (v >= dis[0]) {
            // not better than the current top: ignore
            return;
        }
        // the evicted top only counts as valid if its id was not freed
        nvalid -= (ids[0] != -1) ? 1 : 0;
        faiss::heap_pop<HC>(k, dis.data(), ids.data());
        --k;
    }
    ++k;
    faiss::heap_push<HC>(k, dis.data(), ids.data(), v, i);
    ++nvalid;
}
1106
1107
0
float HNSW::MinimaxHeap::max() const {
1108
0
    return dis[0];
1109
0
}
1110
1111
3.06k
int HNSW::MinimaxHeap::size() const {
1112
3.06k
    return nvalid;
1113
3.06k
}
1114
1115
0
void HNSW::MinimaxHeap::clear() {
1116
0
    nvalid = k = 0;
1117
0
}
1118
1119
#ifdef __AVX512F__
1120
1121
// AVX-512 implementation of pop_min.
//
// Scans slots [0, k) of ids/dis, 16 lanes at a time, and pops the live entry
// with the smallest distance; among equal distances the entry with the
// largest slot index wins. A slot counts as live when its id is not negative.
//
// On success the slot's id is overwritten with -1, nvalid is decremented and
// the removed id is returned; if vmin_out is non-null it receives the
// entry's distance. Returns -1, without modifying anything, when no live
// slot was selected.
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1122
    assert(k > 0);
1123
    static_assert(
1124
            std::is_same<storage_idx_t, int32_t>::value,
1125
            "This code expects storage_idx_t to be int32_t");
1126
1127
    int32_t min_idx = -1;
1128
    float min_dis = std::numeric_limits<float>::infinity();
1129
1130
    // Per-lane running state: best slot index seen so far (-1 = none yet)
    // and its distance (+inf = none yet).
    __m512i min_indices = _mm512_set1_epi32(-1);
1131
    __m512 min_distances =
1132
            _mm512_set1_ps(std::numeric_limits<float>::infinity());
1133
    // Slot index carried by each lane for the chunk being processed.
    __m512i current_indices = _mm512_setr_epi32(
1134
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1135
    __m512i offset = _mm512_set1_epi32(16);
1136
1137
    // The following loop tracks the rightmost index with the min distance.
1138
    // -1 index values are ignored.
1139
    const int k16 = (k / 16) * 16;
1140
    for (size_t iii = 0; iii < k16; iii += 16) {
1141
        __m512i indices =
1142
                _mm512_loadu_si512((const __m512i*)(ids.data() + iii));
1143
        __m512 distances = _mm512_loadu_ps(dis.data() + iii);
1144
1145
        // This mask filters out -1 values among indices.
        // (bit set where 0 > id, i.e. for every negative id)
1146
        __mmask16 m1mask =
1147
                _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1148
1149
        // Bit set where the running minimum is strictly smaller than the
        // new distance; an equal distance therefore replaces the older
        // (smaller) slot index.
        __mmask16 dmask =
1150
                _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1151
        __mmask16 finalmask = m1mask | dmask;
1152
1153
        // A set bit in finalmask keeps the running value of that lane,
        // a clear bit takes the value from the current chunk.
        const __m512i min_indices_new = _mm512_mask_blend_epi32(
1154
                finalmask, current_indices, min_indices);
1155
        const __m512 min_distances_new =
1156
                _mm512_mask_blend_ps(finalmask, distances, min_distances);
1157
1158
        min_indices = min_indices_new;
1159
        min_distances = min_distances_new;
1160
1161
        current_indices = _mm512_add_epi32(current_indices, offset);
1162
    }
1163
1164
    // leftovers
1165
    if (k16 != k) {
1166
        // low (k - k16) bits set: one bit per remaining slot
        const __mmask16 kmask = (1 << (k - k16)) - 1;
1167
1168
        // Lanes outside kmask are filled with id -1 (and distance 0), so
        // the negative-id filter below discards them.
        __m512i indices = _mm512_mask_loadu_epi32(
1169
                _mm512_set1_epi32(-1), kmask, ids.data() + k16);
1170
        __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
1171
1172
        // This mask filters out -1 values among indices.
1173
        __mmask16 m1mask =
1174
                _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1175
1176
        __mmask16 dmask =
1177
                _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1178
        __mmask16 finalmask = m1mask | dmask;
1179
1180
        const __m512i min_indices_new = _mm512_mask_blend_epi32(
1181
                finalmask, current_indices, min_indices);
1182
        const __m512 min_distances_new =
1183
                _mm512_mask_blend_ps(finalmask, distances, min_distances);
1184
1185
        min_indices = min_indices_new;
1186
        min_distances = min_distances_new;
1187
    }
1188
1189
    // grab min distance
1190
    min_dis = _mm512_reduce_min_ps(min_distances);
1191
    // blend
    // (select the lanes whose running distance equals the global minimum)
1192
    __mmask16 mindmask =
1193
            _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
1194
    // pick the max one
    // (largest slot index among the selected lanes)
1195
    min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
1196
1197
    // -1 means no lane ever accepted a live slot
    if (min_idx == -1) {
1198
        return -1;
1199
    }
1200
1201
    if (vmin_out) {
1202
        *vmin_out = min_dis;
1203
    }
1204
    int ret = ids[min_idx];
1205
    // mark the slot as free; its distance stays in dis
    ids[min_idx] = -1;
1206
    --nvalid;
1207
    return ret;
1208
}
1209
1210
#elif __AVX2__
1211
1212
2.68k
// AVX2 implementation of pop_min.
//
// Scans slots [0, k) of ids/dis and pops the live entry with the smallest
// distance; among equal distances the entry with the largest slot index
// wins. The first (k / 8) * 8 slots are processed 8 lanes at a time (a slot
// is skipped there when its id is negative), the remaining slots in a scalar
// loop (a slot is skipped there when its id is -1).
//
// On success the slot's id is overwritten with -1, nvalid is decremented and
// the removed id is returned; if vmin_out is non-null it receives the
// entry's distance. Returns -1, without modifying anything, when no slot was
// selected.
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1213
2.68k
    assert(k > 0);
1214
2.68k
    static_assert(
1215
2.68k
            std::is_same<storage_idx_t, int32_t>::value,
1216
2.68k
            "This code expects storage_idx_t to be int32_t");
1217
1218
2.68k
    int32_t min_idx = -1;
1219
2.68k
    float min_dis = std::numeric_limits<float>::infinity();
1220
1221
2.68k
    // shared by the vector loop and the scalar tail loop
    size_t iii = 0;
1222
1223
2.68k
    // Per-lane running state: best slot index seen so far (-1 = none yet)
    // and its distance (+inf = none yet).
    __m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
1224
2.68k
    __m256 min_distances =
1225
2.68k
            _mm256_set1_ps(std::numeric_limits<float>::infinity());
1226
2.68k
    // Slot index carried by each lane for the chunk being processed.
    __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1227
2.68k
    __m256i offset = _mm256_set1_epi32(8);
1228
1229
    // The baseline version is available in non-AVX2 branch.
1230
1231
    // The following loop tracks the rightmost index with the min distance.
1232
    // -1 index values are ignored.
1233
2.68k
    const int k8 = (k / 8) * 8;
1234
52.3k
    for (; iii < k8; iii += 8) {
1235
49.6k
        __m256i indices =
1236
49.6k
                _mm256_loadu_si256((const __m256i*)(ids.data() + iii));
1237
49.6k
        __m256 distances = _mm256_loadu_ps(dis.data() + iii);
1238
1239
        // This mask filters out -1 values among indices.
        // (lane set where 0 > id, i.e. for every negative id)
1240
49.6k
        __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
1241
1242
49.6k
        // Lane set where the running minimum is strictly smaller than the
        // new distance; an equal distance therefore replaces the older
        // (smaller) slot index.
        __m256i dmask = _mm256_castps_si256(
1243
49.6k
                _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
1244
49.6k
        __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
1245
1246
49.6k
        // Where finalmask is set the running value (second operand) is
        // kept, elsewhere the value from the current chunk is taken.
        const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
1247
49.6k
                _mm256_castsi256_ps(current_indices),
1248
49.6k
                _mm256_castsi256_ps(min_indices),
1249
49.6k
                finalmask));
1250
1251
49.6k
        const __m256 min_distances_new =
1252
49.6k
                _mm256_blendv_ps(distances, min_distances, finalmask);
1253
1254
49.6k
        min_indices = min_indices_new;
1255
49.6k
        min_distances = min_distances_new;
1256
1257
49.6k
        current_indices = _mm256_add_epi32(current_indices, offset);
1258
49.6k
    }
1259
1260
    // Vectorizing is doable, but is not practical
    // Horizontal reduction over the 8 lanes in scalar code: smallest
    // distance, ties resolved towards the larger slot index.
1261
2.68k
    int32_t vidx8[8];
1262
2.68k
    float vdis8[8];
1263
2.68k
    _mm256_storeu_ps(vdis8, min_distances);
1264
2.68k
    _mm256_storeu_si256((__m256i*)vidx8, min_indices);
1265
1266
24.2k
    for (size_t j = 0; j < 8; j++) {
1267
21.5k
        if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
1268
7.78k
            min_idx = vidx8[j];
1269
7.78k
            min_dis = vdis8[j];
1270
7.78k
        }
1271
21.5k
    }
1272
1273
    // process last values. Vectorizing is doable, but is not practical
    // `<=` keeps the tie-break towards the larger slot index: these slots
    // come after every slot handled by the vector loop.
1274
7.96k
    for (; iii < k; iii++) {
1275
5.27k
        if (ids[iii] != -1 && dis[iii] <= min_dis) {
1276
254
            min_dis = dis[iii];
1277
254
            min_idx = iii;
1278
254
        }
1279
5.27k
    }
1280
1281
2.68k
    // -1 means no slot was accepted by either loop
    if (min_idx == -1) {
1282
0
        return -1;
1283
0
    }
1284
1285
2.68k
    if (vmin_out) {
1286
2.68k
        *vmin_out = min_dis;
1287
2.68k
    }
1288
2.68k
    int ret = ids[min_idx];
1289
2.68k
    // mark the slot as free; its distance stays in dis
    ids[min_idx] = -1;
1290
2.68k
    --nvalid;
1291
2.68k
    return ret;
1292
2.68k
}
1293
1294
#else
1295
1296
// baseline non-vectorized version
1297
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1298
    assert(k > 0);
1299
    // returns min. This is an O(n) operation
1300
    int i = k - 1;
1301
    while (i >= 0) {
1302
        if (ids[i] != -1) {
1303
            break;
1304
        }
1305
        i--;
1306
    }
1307
    if (i == -1) {
1308
        return -1;
1309
    }
1310
    int imin = i;
1311
    float vmin = dis[i];
1312
    i--;
1313
    while (i >= 0) {
1314
        if (ids[i] != -1 && dis[i] < vmin) {
1315
            vmin = dis[i];
1316
            imin = i;
1317
        }
1318
        i--;
1319
    }
1320
    if (vmin_out) {
1321
        *vmin_out = vmin;
1322
    }
1323
    int ret = ids[imin];
1324
    ids[imin] = -1;
1325
    --nvalid;
1326
1327
    return ret;
1328
}
1329
#endif
1330
1331
2.68k
int HNSW::MinimaxHeap::count_below(float thresh) {
1332
2.68k
    int n_below = 0;
1333
405k
    for (int i = 0; i < k; i++) {
1334
402k
        if (dis[i] < thresh) {
1335
79.2k
            n_below++;
1336
79.2k
        }
1337
402k
    }
1338
1339
2.68k
    return n_below;
1340
2.68k
}
1341
1342
} // namespace faiss