Coverage Report

Created: 2025-10-14 07:48

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexHNSW.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexHNSW.h>
9
10
#include <omp.h>
11
#include <cinttypes>
12
#include <cstdio>
13
#include <cstdlib>
14
#include <cstring>
15
16
#include <limits>
17
#include <memory>
18
#include <queue>
19
#include <random>
20
21
#include <cstdint>
22
23
#include <faiss/Index2Layer.h>
24
#include <faiss/IndexFlat.h>
25
#include <faiss/IndexIVFPQ.h>
26
#include <faiss/impl/AuxIndexStructures.h>
27
#include <faiss/impl/FaissAssert.h>
28
#include <faiss/impl/ResultHandler.h>
29
#include <faiss/utils/random.h>
30
#include <faiss/utils/sorting.h>
31
32
namespace faiss {
33
34
using MinimaxHeap = HNSW::MinimaxHeap;
35
using storage_idx_t = HNSW::storage_idx_t;
36
using NodeDistFarther = HNSW::NodeDistFarther;
37
38
HNSWStats hnsw_stats;
39
40
/**************************************************************
41
 * add / search blocks of descriptors
42
 **************************************************************/
43
44
namespace {
45
46
20.1k
DistanceComputer* storage_distance_computer(const Index* storage) {
47
20.1k
    if (is_similarity_metric(storage->metric_type)) {
48
1.47k
        return new NegativeDistanceComputer(storage->get_distance_computer());
49
18.6k
    } else {
50
18.6k
        return storage->get_distance_computer();
51
18.6k
    }
52
20.1k
}
53
54
void hnsw_add_vertices(
55
        IndexHNSW& index_hnsw,
56
        size_t n0,
57
        size_t n,
58
        const float* x,
59
        bool verbose,
60
18.8k
        bool preset_levels = false) {
61
18.8k
    size_t d = index_hnsw.d;
62
18.8k
    HNSW& hnsw = index_hnsw.hnsw;
63
18.8k
    size_t ntotal = n0 + n;
64
18.8k
    double t0 = getmillisecs();
65
18.8k
    if (verbose) {
66
0
        printf("hnsw_add_vertices: adding %zd elements on top of %zd "
67
0
               "(preset_levels=%d)\n",
68
0
               n,
69
0
               n0,
70
0
               int(preset_levels));
71
0
    }
72
73
18.8k
    if (n == 0) {
74
0
        return;
75
0
    }
76
77
18.8k
    int max_level = hnsw.prepare_level_tab(n, preset_levels);
78
79
18.8k
    if (verbose) {
80
0
        printf("  max_level = %d\n", max_level);
81
0
    }
82
83
18.8k
    std::vector<omp_lock_t> locks(ntotal);
84
7.50M
    for (int i = 0; i < ntotal; i++)
85
7.48M
        omp_init_lock(&locks[i]);
86
87
    // add vectors from highest to lowest level
88
18.8k
    std::vector<int> hist;
89
18.8k
    std::vector<int> order(n);
90
91
18.8k
    { // make buckets with vectors of the same level
92
93
        // build histogram
94
45.9k
        for (int i = 0; i < n; i++) {
95
27.1k
            storage_idx_t pt_id = i + n0;
96
27.1k
            int pt_level = hnsw.levels[pt_id] - 1;
97
47.0k
            while (pt_level >= hist.size())
98
19.8k
                hist.push_back(0);
99
27.1k
            hist[pt_level]++;
100
27.1k
        }
101
102
        // accumulate
103
18.8k
        std::vector<int> offsets(hist.size() + 1, 0);
104
19.8k
        for (int i = 0; i < hist.size() - 1; i++) {
105
1.08k
            offsets[i + 1] = offsets[i] + hist[i];
106
1.08k
        }
107
108
        // bucket sort
109
45.9k
        for (int i = 0; i < n; i++) {
110
27.1k
            storage_idx_t pt_id = i + n0;
111
27.1k
            int pt_level = hnsw.levels[pt_id] - 1;
112
27.1k
            order[offsets[pt_level]++] = pt_id;
113
27.1k
        }
114
18.8k
    }
115
116
18.8k
    idx_t check_period = InterruptCallback::get_period_hint(
117
18.8k
            max_level * index_hnsw.d * hnsw.efConstruction);
118
119
18.8k
    { // perform add
120
18.8k
        RandomGenerator rng2(789);
121
122
18.8k
        int i1 = n;
123
124
18.8k
        for (int pt_level = hist.size() - 1;
125
38.6k
             pt_level >= int(!index_hnsw.init_level0);
126
19.8k
             pt_level--) {
127
19.8k
            int i0 = i1 - hist[pt_level];
128
129
19.8k
            if (verbose) {
130
0
                printf("Adding %d elements at level %d\n", i1 - i0, pt_level);
131
0
            }
132
133
            // random permutation to get rid of dataset order bias
134
47.0k
            for (int j = i0; j < i1; j++)
135
27.1k
                std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
136
137
19.8k
            bool interrupt = false;
138
139
19.8k
#pragma omp parallel if (i1 > i0 + 100)
140
84.9k
            {
141
84.9k
                VisitedTable vt(ntotal);
142
143
84.9k
                std::unique_ptr<DistanceComputer> dis(
144
84.9k
                        storage_distance_computer(index_hnsw.storage));
145
84.9k
                int prev_display =
146
84.9k
                        verbose && omp_get_thread_num() == 0 ? 0 : -1;
147
84.9k
                size_t counter = 0;
148
149
                // here we should do schedule(dynamic) but this segfaults for
150
                // some versions of LLVM. The performance impact should not be
151
                // too large when (i1 - i0) / num_threads >> 1
152
84.9k
#pragma omp for schedule(static)
153
84.9k
                for (int i = i0; i < i1; i++) {
154
84.9k
                    storage_idx_t pt_id = order[i];
155
84.9k
                    dis->set_query(x + (pt_id - n0) * d);
156
157
                    // cannot break
158
84.9k
                    if (interrupt) {
159
84.9k
                        continue;
160
84.9k
                    }
161
162
84.9k
                    hnsw.add_with_locks(
163
84.9k
                            *dis,
164
84.9k
                            pt_level,
165
84.9k
                            pt_id,
166
84.9k
                            locks,
167
84.9k
                            vt,
168
84.9k
                            index_hnsw.keep_max_size_level0 && (pt_level == 0));
169
170
84.9k
                    if (prev_display >= 0 && i - i0 > prev_display + 10000) {
171
84.9k
                        prev_display = i - i0;
172
84.9k
                        printf("  %d / %d\r", i - i0, i1 - i0);
173
84.9k
                        fflush(stdout);
174
84.9k
                    }
175
84.9k
                    if (counter % check_period == 0) {
176
84.9k
                        if (InterruptCallback::is_interrupted()) {
177
84.9k
                            interrupt = true;
178
84.9k
                        }
179
84.9k
                    }
180
84.9k
                    counter++;
181
84.9k
                }
182
84.9k
            }
183
19.8k
            if (interrupt) {
184
0
                FAISS_THROW_MSG("computation interrupted");
185
0
            }
186
19.8k
            i1 = i0;
187
19.8k
        }
188
18.8k
        if (index_hnsw.init_level0) {
189
18.8k
            FAISS_ASSERT(i1 == 0);
190
18.8k
        } else {
191
0
            FAISS_ASSERT((i1 - hist[0]) == 0);
192
0
        }
193
18.8k
    }
194
18.8k
    if (verbose) {
195
0
        printf("Done in %.3f ms\n", getmillisecs() - t0);
196
0
    }
197
198
7.50M
    for (int i = 0; i < ntotal; i++) {
199
7.48M
        omp_destroy_lock(&locks[i]);
200
7.48M
    }
201
18.8k
}
202
203
} // namespace
204
205
/**************************************************************
206
 * IndexHNSW implementation
207
 **************************************************************/
208
209
IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
210
59
        : Index(d, metric), hnsw(M) {}
211
212
IndexHNSW::IndexHNSW(Index* storage, int M)
213
135
        : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {
214
135
    metric_arg = storage->metric_arg;
215
135
}
216
217
193
IndexHNSW::~IndexHNSW() {
218
193
    if (own_fields) {
219
192
        delete storage;
220
192
    }
221
193
}
222
223
32
void IndexHNSW::train(idx_t n, const float* x) {
224
32
    FAISS_THROW_IF_NOT_MSG(
225
32
            storage,
226
32
            "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
227
    // hnsw structure does not require training
228
32
    storage->train(n, x);
229
32
    is_trained = true;
230
32
}
231
232
namespace {
233
234
template <class BlockResultHandler>
235
void hnsw_search(
236
        const IndexHNSW* index,
237
        idx_t n,
238
        const float* x,
239
        BlockResultHandler& bres,
240
152
        const SearchParameters* params) {
241
152
    FAISS_THROW_IF_NOT_MSG(
242
152
            index->storage,
243
152
            "No storage index, please use IndexHNSWFlat (or variants) "
244
152
            "instead of IndexHNSW directly");
245
152
    const HNSW& hnsw = index->hnsw;
246
247
152
    int efSearch = hnsw.efSearch;
248
152
    if (params) {
249
142
        if (const SearchParametersHNSW* hnsw_params =
250
142
                    dynamic_cast<const SearchParametersHNSW*>(params)) {
251
142
            efSearch = hnsw_params->efSearch;
252
142
        }
253
142
    }
254
152
    size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
255
256
152
    idx_t check_period = InterruptCallback::get_period_hint(
257
152
            hnsw.max_level * index->d * efSearch);
258
259
304
    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
260
152
        idx_t i1 = std::min(i0 + check_period, n);
261
262
152
#pragma omp parallel if (i1 - i0 > 1)
263
455
        {
264
455
            VisitedTable vt(index->ntotal);
265
455
            typename BlockResultHandler::SingleResultHandler res(bres);
266
267
455
            std::unique_ptr<DistanceComputer> dis(
268
455
                    storage_distance_computer(index->storage));
269
270
455
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
271
455
            for (idx_t i = i0; i < i1; i++) {
272
455
                res.begin(i);
273
455
                dis->set_query(x + i * index->d);
274
275
455
                HNSWStats stats = hnsw.search(*dis, res, vt, params);
276
455
                n1 += stats.n1;
277
455
                n2 += stats.n2;
278
455
                ndis += stats.ndis;
279
455
                nhops += stats.nhops;
280
455
                res.end();
281
455
            }
282
455
        }
IndexHNSW.cpp:_ZN5faiss12_GLOBAL__N_111hnsw_searchINS_22HeapBlockResultHandlerINS_4CMaxIflEELb0EEEEEvPKNS_9IndexHNSWElPKfRT_PKNS_16SearchParametersE.omp_outlined_debug__
Line
Count
Source
263
173
        {
264
173
            VisitedTable vt(index->ntotal);
265
173
            typename BlockResultHandler::SingleResultHandler res(bres);
266
267
173
            std::unique_ptr<DistanceComputer> dis(
268
173
                    storage_distance_computer(index->storage));
269
270
173
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
271
173
            for (idx_t i = i0; i < i1; i++) {
272
173
                res.begin(i);
273
173
                dis->set_query(x + i * index->d);
274
275
173
                HNSWStats stats = hnsw.search(*dis, res, vt, params);
276
173
                n1 += stats.n1;
277
173
                n2 += stats.n2;
278
173
                ndis += stats.ndis;
279
173
                nhops += stats.nhops;
280
173
                res.end();
281
173
            }
282
173
        }
IndexHNSW.cpp:_ZN5faiss12_GLOBAL__N_111hnsw_searchINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb0EEEEEvPKNS_9IndexHNSWElPKfRT_PKNS_16SearchParametersE.omp_outlined_debug__
Line
Count
Source
263
282
        {
264
282
            VisitedTable vt(index->ntotal);
265
282
            typename BlockResultHandler::SingleResultHandler res(bres);
266
267
282
            std::unique_ptr<DistanceComputer> dis(
268
282
                    storage_distance_computer(index->storage));
269
270
282
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
271
282
            for (idx_t i = i0; i < i1; i++) {
272
282
                res.begin(i);
273
282
                dis->set_query(x + i * index->d);
274
275
282
                HNSWStats stats = hnsw.search(*dis, res, vt, params);
276
282
                n1 += stats.n1;
277
282
                n2 += stats.n2;
278
282
                ndis += stats.ndis;
279
282
                nhops += stats.nhops;
280
282
                res.end();
281
282
            }
282
282
        }
283
152
        InterruptCallback::check();
284
152
    }
285
286
152
    hnsw_stats.combine({n1, n2, ndis, nhops});
287
152
}
IndexHNSW.cpp:_ZN5faiss12_GLOBAL__N_111hnsw_searchINS_22HeapBlockResultHandlerINS_4CMaxIflEELb0EEEEEvPKNS_9IndexHNSWElPKfRT_PKNS_16SearchParametersE
Line
Count
Source
240
58
        const SearchParameters* params) {
241
58
    FAISS_THROW_IF_NOT_MSG(
242
58
            index->storage,
243
58
            "No storage index, please use IndexHNSWFlat (or variants) "
244
58
            "instead of IndexHNSW directly");
245
58
    const HNSW& hnsw = index->hnsw;
246
247
58
    int efSearch = hnsw.efSearch;
248
58
    if (params) {
249
53
        if (const SearchParametersHNSW* hnsw_params =
250
53
                    dynamic_cast<const SearchParametersHNSW*>(params)) {
251
53
            efSearch = hnsw_params->efSearch;
252
53
        }
253
53
    }
254
58
    size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
255
256
58
    idx_t check_period = InterruptCallback::get_period_hint(
257
58
            hnsw.max_level * index->d * efSearch);
258
259
116
    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
260
58
        idx_t i1 = std::min(i0 + check_period, n);
261
262
58
#pragma omp parallel if (i1 - i0 > 1)
263
58
        {
264
58
            VisitedTable vt(index->ntotal);
265
58
            typename BlockResultHandler::SingleResultHandler res(bres);
266
267
58
            std::unique_ptr<DistanceComputer> dis(
268
58
                    storage_distance_computer(index->storage));
269
270
58
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
271
58
            for (idx_t i = i0; i < i1; i++) {
272
58
                res.begin(i);
273
58
                dis->set_query(x + i * index->d);
274
275
58
                HNSWStats stats = hnsw.search(*dis, res, vt, params);
276
58
                n1 += stats.n1;
277
58
                n2 += stats.n2;
278
58
                ndis += stats.ndis;
279
58
                nhops += stats.nhops;
280
58
                res.end();
281
58
            }
282
58
        }
283
58
        InterruptCallback::check();
284
58
    }
285
286
58
    hnsw_stats.combine({n1, n2, ndis, nhops});
287
58
}
IndexHNSW.cpp:_ZN5faiss12_GLOBAL__N_111hnsw_searchINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb0EEEEEvPKNS_9IndexHNSWElPKfRT_PKNS_16SearchParametersE
Line
Count
Source
240
94
        const SearchParameters* params) {
241
94
    FAISS_THROW_IF_NOT_MSG(
242
94
            index->storage,
243
94
            "No storage index, please use IndexHNSWFlat (or variants) "
244
94
            "instead of IndexHNSW directly");
245
94
    const HNSW& hnsw = index->hnsw;
246
247
94
    int efSearch = hnsw.efSearch;
248
94
    if (params) {
249
89
        if (const SearchParametersHNSW* hnsw_params =
250
89
                    dynamic_cast<const SearchParametersHNSW*>(params)) {
251
89
            efSearch = hnsw_params->efSearch;
252
89
        }
253
89
    }
254
94
    size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
255
256
94
    idx_t check_period = InterruptCallback::get_period_hint(
257
94
            hnsw.max_level * index->d * efSearch);
258
259
188
    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
260
94
        idx_t i1 = std::min(i0 + check_period, n);
261
262
94
#pragma omp parallel if (i1 - i0 > 1)
263
94
        {
264
94
            VisitedTable vt(index->ntotal);
265
94
            typename BlockResultHandler::SingleResultHandler res(bres);
266
267
94
            std::unique_ptr<DistanceComputer> dis(
268
94
                    storage_distance_computer(index->storage));
269
270
94
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
271
94
            for (idx_t i = i0; i < i1; i++) {
272
94
                res.begin(i);
273
94
                dis->set_query(x + i * index->d);
274
275
94
                HNSWStats stats = hnsw.search(*dis, res, vt, params);
276
94
                n1 += stats.n1;
277
94
                n2 += stats.n2;
278
94
                ndis += stats.ndis;
279
94
                nhops += stats.nhops;
280
94
                res.end();
281
94
            }
282
94
        }
283
94
        InterruptCallback::check();
284
94
    }
285
286
94
    hnsw_stats.combine({n1, n2, ndis, nhops});
287
94
}
288
289
} // anonymous namespace
290
291
void IndexHNSW::search(
292
        idx_t n,
293
        const float* x,
294
        idx_t k,
295
        float* distances,
296
        idx_t* labels,
297
57
        const SearchParameters* params) const {
298
57
    FAISS_THROW_IF_NOT(k > 0);
299
300
57
    using RH = HeapBlockResultHandler<HNSW::C>;
301
57
    RH bres(n, distances, labels, k);
302
303
57
    hnsw_search(this, n, x, bres, params);
304
305
57
    if (is_similarity_metric(this->metric_type)) {
306
        // we need to revert the negated distances
307
196
        for (size_t i = 0; i < k * n; i++) {
308
177
            distances[i] = -distances[i];
309
177
        }
310
19
    }
311
57
}
312
313
void IndexHNSW::range_search(
314
        idx_t n,
315
        const float* x,
316
        float radius,
317
        RangeSearchResult* result,
318
94
        const SearchParameters* params) const {
319
94
    using RH = RangeSearchBlockResultHandler<HNSW::C>;
320
94
    RH bres(result, is_similarity_metric(metric_type) ? -radius : radius);
321
322
94
    hnsw_search(this, n, x, bres, params);
323
324
94
    if (is_similarity_metric(this->metric_type)) {
325
        // we need to revert the negated distances
326
1.51k
        for (size_t i = 0; i < result->lims[result->nq]; i++) {
327
1.48k
            result->distances[i] = -result->distances[i];
328
1.48k
        }
329
23
    }
330
94
}
331
332
18.8k
void IndexHNSW::add(idx_t n, const float* x) {
333
18.8k
    FAISS_THROW_IF_NOT_MSG(
334
18.8k
            storage,
335
18.8k
            "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
336
18.8k
    FAISS_THROW_IF_NOT(is_trained);
337
18.8k
    int n0 = ntotal;
338
18.8k
    storage->add(n, x);
339
18.8k
    ntotal = storage->ntotal;
340
341
18.8k
    hnsw_add_vertices(*this, n0, n, x, verbose, hnsw.levels.size() == ntotal);
342
18.8k
}
343
344
0
void IndexHNSW::reset() {
345
0
    hnsw.reset();
346
0
    storage->reset();
347
0
    ntotal = 0;
348
0
}
349
350
0
void IndexHNSW::reconstruct(idx_t key, float* recons) const {
351
0
    storage->reconstruct(key, recons);
352
0
}
353
354
/**************************************************************
355
 * This section of functions were used during the development of HNSW support.
356
 * They may be useful in the future but are dormant for now, and thus are not
357
 * unit tested at the moment.
358
 * shrink_level_0_neighbors
359
 * search_level_0
360
 * init_level_0_from_knngraph
361
 * init_level_0_from_entry_points
362
 * reorder_links
363
 * link_singletons
364
 **************************************************************/
365
0
void IndexHNSW::shrink_level_0_neighbors(int new_size) {
366
0
#pragma omp parallel
367
0
    {
368
0
        std::unique_ptr<DistanceComputer> dis(
369
0
                storage_distance_computer(storage));
370
371
0
#pragma omp for
372
0
        for (idx_t i = 0; i < ntotal; i++) {
373
0
            size_t begin, end;
374
0
            hnsw.neighbor_range(i, 0, &begin, &end);
375
376
0
            std::priority_queue<NodeDistFarther> initial_list;
377
378
0
            for (size_t j = begin; j < end; j++) {
379
0
                int v1 = hnsw.neighbors[j];
380
0
                if (v1 < 0)
381
0
                    break;
382
0
                initial_list.emplace(dis->symmetric_dis(i, v1), v1);
383
384
                // initial_list.emplace(qdis(v1), v1);
385
0
            }
386
387
0
            std::vector<NodeDistFarther> shrunk_list;
388
0
            HNSW::shrink_neighbor_list(
389
0
                    *dis, initial_list, shrunk_list, new_size);
390
391
0
            for (size_t j = begin; j < end; j++) {
392
0
                if (j - begin < shrunk_list.size())
393
0
                    hnsw.neighbors[j] = shrunk_list[j - begin].id;
394
0
                else
395
0
                    hnsw.neighbors[j] = -1;
396
0
            }
397
0
        }
398
0
    }
399
0
}
400
401
void IndexHNSW::search_level_0(
402
        idx_t n,
403
        const float* x,
404
        idx_t k,
405
        const storage_idx_t* nearest,
406
        const float* nearest_d,
407
        float* distances,
408
        idx_t* labels,
409
        int nprobe,
410
        int search_type,
411
0
        const SearchParameters* params) const {
412
0
    FAISS_THROW_IF_NOT(k > 0);
413
0
    FAISS_THROW_IF_NOT(nprobe > 0);
414
415
0
    storage_idx_t ntotal = hnsw.levels.size();
416
417
0
    using RH = HeapBlockResultHandler<HNSW::C>;
418
0
    RH bres(n, distances, labels, k);
419
420
0
#pragma omp parallel
421
0
    {
422
0
        std::unique_ptr<DistanceComputer> qdis(
423
0
                storage_distance_computer(storage));
424
0
        HNSWStats search_stats;
425
0
        VisitedTable vt(ntotal);
426
0
        RH::SingleResultHandler res(bres);
427
428
0
#pragma omp for
429
0
        for (idx_t i = 0; i < n; i++) {
430
0
            res.begin(i);
431
0
            qdis->set_query(x + i * d);
432
433
0
            hnsw.search_level_0(
434
0
                    *qdis.get(),
435
0
                    res,
436
0
                    nprobe,
437
0
                    nearest + i * nprobe,
438
0
                    nearest_d + i * nprobe,
439
0
                    search_type,
440
0
                    search_stats,
441
0
                    vt,
442
0
                    params);
443
0
            res.end();
444
0
            vt.advance();
445
0
        }
446
0
#pragma omp critical
447
0
        { hnsw_stats.combine(search_stats); }
448
0
    }
449
0
    if (is_similarity_metric(this->metric_type)) {
450
// we need to revert the negated distances
451
0
#pragma omp parallel for
452
0
        for (int64_t i = 0; i < k * n; i++) {
453
0
            distances[i] = -distances[i];
454
0
        }
455
0
    }
456
0
}
457
458
void IndexHNSW::init_level_0_from_knngraph(
459
        int k,
460
        const float* D,
461
0
        const idx_t* I) {
462
0
    int dest_size = hnsw.nb_neighbors(0);
463
464
0
#pragma omp parallel for
465
0
    for (idx_t i = 0; i < ntotal; i++) {
466
0
        DistanceComputer* qdis = storage_distance_computer(storage);
467
0
        std::vector<float> vec(d);
468
0
        storage->reconstruct(i, vec.data());
469
0
        qdis->set_query(vec.data());
470
471
0
        std::priority_queue<NodeDistFarther> initial_list;
472
473
0
        for (size_t j = 0; j < k; j++) {
474
0
            int v1 = I[i * k + j];
475
0
            if (v1 == i)
476
0
                continue;
477
0
            if (v1 < 0)
478
0
                break;
479
0
            initial_list.emplace(D[i * k + j], v1);
480
0
        }
481
482
0
        std::vector<NodeDistFarther> shrunk_list;
483
0
        HNSW::shrink_neighbor_list(*qdis, initial_list, shrunk_list, dest_size);
484
485
0
        size_t begin, end;
486
0
        hnsw.neighbor_range(i, 0, &begin, &end);
487
488
0
        for (size_t j = begin; j < end; j++) {
489
0
            if (j - begin < shrunk_list.size())
490
0
                hnsw.neighbors[j] = shrunk_list[j - begin].id;
491
0
            else
492
0
                hnsw.neighbors[j] = -1;
493
0
        }
494
0
    }
495
0
}
496
497
void IndexHNSW::init_level_0_from_entry_points(
498
        int n,
499
        const storage_idx_t* points,
500
0
        const storage_idx_t* nearests) {
501
0
    std::vector<omp_lock_t> locks(ntotal);
502
0
    for (int i = 0; i < ntotal; i++)
503
0
        omp_init_lock(&locks[i]);
504
505
0
#pragma omp parallel
506
0
    {
507
0
        VisitedTable vt(ntotal);
508
509
0
        std::unique_ptr<DistanceComputer> dis(
510
0
                storage_distance_computer(storage));
511
0
        std::vector<float> vec(storage->d);
512
513
0
#pragma omp for schedule(dynamic)
514
0
        for (int i = 0; i < n; i++) {
515
0
            storage_idx_t pt_id = points[i];
516
0
            storage_idx_t nearest = nearests[i];
517
0
            storage->reconstruct(pt_id, vec.data());
518
0
            dis->set_query(vec.data());
519
520
0
            hnsw.add_links_starting_from(
521
0
                    *dis, pt_id, nearest, (*dis)(nearest), 0, locks.data(), vt);
522
523
0
            if (verbose && i % 10000 == 0) {
524
0
                printf("  %d / %d\r", i, n);
525
0
                fflush(stdout);
526
0
            }
527
0
        }
528
0
    }
529
0
    if (verbose) {
530
0
        printf("\n");
531
0
    }
532
533
0
    for (int i = 0; i < ntotal; i++)
534
0
        omp_destroy_lock(&locks[i]);
535
0
}
536
537
0
void IndexHNSW::reorder_links() {
538
0
    int M = hnsw.nb_neighbors(0);
539
540
0
#pragma omp parallel
541
0
    {
542
0
        std::vector<float> distances(M);
543
0
        std::vector<size_t> order(M);
544
0
        std::vector<storage_idx_t> tmp(M);
545
0
        std::unique_ptr<DistanceComputer> dis(
546
0
                storage_distance_computer(storage));
547
548
0
#pragma omp for
549
0
        for (storage_idx_t i = 0; i < ntotal; i++) {
550
0
            size_t begin, end;
551
0
            hnsw.neighbor_range(i, 0, &begin, &end);
552
553
0
            for (size_t j = begin; j < end; j++) {
554
0
                storage_idx_t nj = hnsw.neighbors[j];
555
0
                if (nj < 0) {
556
0
                    end = j;
557
0
                    break;
558
0
                }
559
0
                distances[j - begin] = dis->symmetric_dis(i, nj);
560
0
                tmp[j - begin] = nj;
561
0
            }
562
563
0
            fvec_argsort(end - begin, distances.data(), order.data());
564
0
            for (size_t j = begin; j < end; j++) {
565
0
                hnsw.neighbors[j] = tmp[order[j - begin]];
566
0
            }
567
0
        }
568
0
    }
569
0
}
570
571
0
void IndexHNSW::link_singletons() {
572
0
    printf("search for singletons\n");
573
574
0
    std::vector<bool> seen(ntotal);
575
576
0
    for (size_t i = 0; i < ntotal; i++) {
577
0
        size_t begin, end;
578
0
        hnsw.neighbor_range(i, 0, &begin, &end);
579
0
        for (size_t j = begin; j < end; j++) {
580
0
            storage_idx_t ni = hnsw.neighbors[j];
581
0
            if (ni >= 0)
582
0
                seen[ni] = true;
583
0
        }
584
0
    }
585
586
0
    int n_sing = 0, n_sing_l1 = 0;
587
0
    std::vector<storage_idx_t> singletons;
588
0
    for (storage_idx_t i = 0; i < ntotal; i++) {
589
0
        if (!seen[i]) {
590
0
            singletons.push_back(i);
591
0
            n_sing++;
592
0
            if (hnsw.levels[i] > 1)
593
0
                n_sing_l1++;
594
0
        }
595
0
    }
596
597
0
    printf("  Found %d / %" PRId64 " singletons (%d appear in a level above)\n",
598
0
           n_sing,
599
0
           ntotal,
600
0
           n_sing_l1);
601
602
0
    std::vector<float> recons(singletons.size() * d);
603
0
    for (int i = 0; i < singletons.size(); i++) {
604
0
        FAISS_ASSERT(!"not implemented");
605
0
    }
606
0
}
607
608
0
void IndexHNSW::permute_entries(const idx_t* perm) {
609
0
    auto flat_storage = dynamic_cast<IndexFlatCodes*>(storage);
610
0
    FAISS_THROW_IF_NOT_MSG(
611
0
            flat_storage, "don't know how to permute this index");
612
0
    flat_storage->permute_entries(perm);
613
0
    hnsw.permute_entries(perm);
614
0
}
615
616
0
DistanceComputer* IndexHNSW::get_distance_computer() const {
617
0
    return storage->get_distance_computer();
618
0
}
619
620
/**************************************************************
621
 * IndexHNSWFlat implementation
622
 **************************************************************/
623
624
59
IndexHNSWFlat::IndexHNSWFlat() {
625
59
    is_trained = true;
626
59
}
627
628
IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
629
133
        : IndexHNSW(
630
133
                  (metric == METRIC_L2) ? new IndexFlatL2(d)
631
133
                                        : new IndexFlat(d, metric),
632
133
                  M) {
633
133
    own_fields = true;
634
133
    is_trained = true;
635
133
}
636
637
/**************************************************************
638
 * IndexHNSWPQ implementation
639
 **************************************************************/
640
641
0
IndexHNSWPQ::IndexHNSWPQ() = default;
642
643
IndexHNSWPQ::IndexHNSWPQ(
644
        int d,
645
        int pq_m,
646
        int M,
647
        int pq_nbits,
648
        MetricType metric)
649
0
        : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) {
650
0
    own_fields = true;
651
0
    is_trained = false;
652
0
}
653
654
0
void IndexHNSWPQ::train(idx_t n, const float* x) {
655
0
    IndexHNSW::train(n, x);
656
0
    (dynamic_cast<IndexPQ*>(storage))->pq.compute_sdc_table();
657
0
}
658
659
/**************************************************************
660
 * IndexHNSWSQ implementation
661
 **************************************************************/
662
663
IndexHNSWSQ::IndexHNSWSQ(
664
        int d,
665
        ScalarQuantizer::QuantizerType qtype,
666
        int M,
667
        MetricType metric)
668
2
        : IndexHNSW(new IndexScalarQuantizer(d, qtype, metric), M) {
669
2
    is_trained = this->storage->is_trained;
670
2
    own_fields = true;
671
2
}
672
673
0
IndexHNSWSQ::IndexHNSWSQ() = default;
674
675
/**************************************************************
676
 * IndexHNSW2Level implementation
677
 **************************************************************/
678
679
IndexHNSW2Level::IndexHNSW2Level(
680
        Index* quantizer,
681
        size_t nlist,
682
        int m_pq,
683
        int M)
684
0
        : IndexHNSW(new Index2Layer(quantizer, nlist, m_pq), M) {
685
0
    own_fields = true;
686
0
    is_trained = false;
687
0
}
688
689
0
IndexHNSW2Level::IndexHNSW2Level() = default;
690
691
namespace {
692
693
// same as search_from_candidates but uses v
694
// visno -> is in result list
695
// visno + 1 -> in result list + in candidates
696
int search_from_candidates_2(
697
        const HNSW& hnsw,
698
        DistanceComputer& qdis,
699
        int k,
700
        idx_t* I,
701
        float* D,
702
        MinimaxHeap& candidates,
703
        VisitedTable& vt,
704
        HNSWStats& stats,
705
        int level,
706
0
        int nres_in = 0) {
707
0
    int nres = nres_in;
708
0
    for (int i = 0; i < candidates.size(); i++) {
709
0
        idx_t v1 = candidates.ids[i];
710
0
        FAISS_ASSERT(v1 >= 0);
711
0
        vt.visited[v1] = vt.visno + 1;
712
0
    }
713
714
0
    int nstep = 0;
715
716
0
    while (candidates.size() > 0) {
717
0
        float d0 = 0;
718
0
        int v0 = candidates.pop_min(&d0);
719
720
0
        size_t begin, end;
721
0
        hnsw.neighbor_range(v0, level, &begin, &end);
722
723
0
        for (size_t j = begin; j < end; j++) {
724
0
            int v1 = hnsw.neighbors[j];
725
0
            if (v1 < 0)
726
0
                break;
727
0
            if (vt.visited[v1] == vt.visno + 1) {
728
                // nothing to do
729
0
            } else {
730
0
                float d = qdis(v1);
731
0
                candidates.push(v1, d);
732
733
                // never seen before --> add to heap
734
0
                if (vt.visited[v1] < vt.visno) {
735
0
                    if (nres < k) {
736
0
                        faiss::maxheap_push(++nres, D, I, d, v1);
737
0
                    } else if (d < D[0]) {
738
0
                        faiss::maxheap_replace_top(nres, D, I, d, v1);
739
0
                    }
740
0
                }
741
0
                vt.visited[v1] = vt.visno + 1;
742
0
            }
743
0
        }
744
745
0
        nstep++;
746
0
        if (nstep > hnsw.efSearch) {
747
0
            break;
748
0
        }
749
0
    }
750
751
0
    stats.n1++;
752
0
    if (candidates.size() == 0)
753
0
        stats.n2++;
754
755
0
    return nres;
756
0
}
757
758
} // namespace
759
760
void IndexHNSW2Level::search(
761
        idx_t n,
762
        const float* x,
763
        idx_t k,
764
        float* distances,
765
        idx_t* labels,
766
0
        const SearchParameters* params) const {
767
0
    FAISS_THROW_IF_NOT(k > 0);
768
0
    FAISS_THROW_IF_NOT_MSG(
769
0
            !params, "search params not supported for this index");
770
771
0
    if (dynamic_cast<const Index2Layer*>(storage)) {
772
0
        IndexHNSW::search(n, x, k, distances, labels);
773
774
0
    } else { // "mixed" search
775
0
        size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
776
777
0
        const IndexIVFPQ* index_ivfpq =
778
0
                dynamic_cast<const IndexIVFPQ*>(storage);
779
780
0
        int nprobe = index_ivfpq->nprobe;
781
782
0
        std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
783
0
        std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
784
785
0
        index_ivfpq->quantizer->search(
786
0
                n, x, nprobe, coarse_dis.get(), coarse_assign.get());
787
788
0
        index_ivfpq->search_preassigned(
789
0
                n,
790
0
                x,
791
0
                k,
792
0
                coarse_assign.get(),
793
0
                coarse_dis.get(),
794
0
                distances,
795
0
                labels,
796
0
                false);
797
798
0
#pragma omp parallel
799
0
        {
800
0
            VisitedTable vt(ntotal);
801
0
            std::unique_ptr<DistanceComputer> dis(
802
0
                    storage_distance_computer(storage));
803
804
0
            constexpr int candidates_size = 1;
805
0
            MinimaxHeap candidates(candidates_size);
806
807
0
#pragma omp for reduction(+ : n1, n2, ndis, nhops)
808
0
            for (idx_t i = 0; i < n; i++) {
809
0
                idx_t* idxi = labels + i * k;
810
0
                float* simi = distances + i * k;
811
0
                dis->set_query(x + i * d);
812
813
                // mark all inverted list elements as visited
814
815
0
                for (int j = 0; j < nprobe; j++) {
816
0
                    idx_t key = coarse_assign[j + i * nprobe];
817
0
                    if (key < 0)
818
0
                        break;
819
0
                    size_t list_length = index_ivfpq->get_list_size(key);
820
0
                    const idx_t* ids = index_ivfpq->invlists->get_ids(key);
821
822
0
                    for (int jj = 0; jj < list_length; jj++) {
823
0
                        vt.set(ids[jj]);
824
0
                    }
825
0
                }
826
827
0
                candidates.clear();
828
829
0
                for (int j = 0; j < k; j++) {
830
0
                    if (idxi[j] < 0)
831
0
                        break;
832
0
                    candidates.push(idxi[j], simi[j]);
833
0
                }
834
835
                // reorder from sorted to heap
836
0
                maxheap_heapify(k, simi, idxi, simi, idxi, k);
837
838
0
                HNSWStats search_stats;
839
0
                search_from_candidates_2(
840
0
                        hnsw,
841
0
                        *dis,
842
0
                        k,
843
0
                        idxi,
844
0
                        simi,
845
0
                        candidates,
846
0
                        vt,
847
0
                        search_stats,
848
0
                        0,
849
0
                        k);
850
0
                n1 += search_stats.n1;
851
0
                n2 += search_stats.n2;
852
0
                ndis += search_stats.ndis;
853
0
                nhops += search_stats.nhops;
854
855
0
                vt.advance();
856
0
                vt.advance();
857
858
0
                maxheap_reorder(k, simi, idxi);
859
0
            }
860
0
        }
861
862
0
        hnsw_stats.combine({n1, n2, ndis, nhops});
863
0
    }
864
0
}
865
866
0
void IndexHNSW2Level::flip_to_ivf() {
867
0
    Index2Layer* storage2l = dynamic_cast<Index2Layer*>(storage);
868
869
0
    FAISS_THROW_IF_NOT(storage2l);
870
871
0
    IndexIVFPQ* index_ivfpq = new IndexIVFPQ(
872
0
            storage2l->q1.quantizer,
873
0
            d,
874
0
            storage2l->q1.nlist,
875
0
            storage2l->pq.M,
876
0
            8);
877
0
    index_ivfpq->pq = storage2l->pq;
878
0
    index_ivfpq->is_trained = storage2l->is_trained;
879
0
    index_ivfpq->precompute_table();
880
0
    index_ivfpq->own_fields = storage2l->q1.own_fields;
881
0
    storage2l->transfer_to_IVFPQ(*index_ivfpq);
882
0
    index_ivfpq->make_direct_map(true);
883
884
0
    storage = index_ivfpq;
885
0
    delete storage2l;
886
0
}
887
888
/**************************************************************
889
 * IndexHNSWCagra implementation
890
 **************************************************************/
891
892
0
/// Default constructor: an empty index that requires no training phase.
IndexHNSWCagra::IndexHNSWCagra() {
    this->is_trained = true;
}
895
896
/** Build a CAGRA-compatible HNSW index over freshly allocated flat
 * storage. Only the L2 and inner-product metrics are accepted.
 *
 * @param d      vector dimensionality
 * @param M      HNSW connectivity parameter
 * @param metric METRIC_L2 or METRIC_INNER_PRODUCT
 */
IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric)
        : IndexHNSW(
                  metric == METRIC_L2
                          ? static_cast<IndexFlat*>(new IndexFlatL2(d))
                          : static_cast<IndexFlat*>(new IndexFlatIP(d)),
                  M) {
    FAISS_THROW_IF_NOT_MSG(
            metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
            "unsupported metric type for IndexHNSWCagra");
    // the flat storage created above belongs to this index
    own_fields = true;
    // flat storage needs no training
    is_trained = true;
    // CAGRA interop flags: fill level 0 on add and keep it at full degree
    keep_max_size_level0 = true;
    init_level0 = true;
}
910
911
0
// Append n vectors to the index via the regular HNSW insertion path.
// Insertion needs the full multi-level graph, so it is rejected when only
// the base level is kept (base_level_only).
void IndexHNSWCagra::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            !base_level_only,
            "Cannot add vectors when base_level_only is set to True");

    IndexHNSW::add(n, x);
}
918
919
// Search the index. With the full hierarchy present this is a standard
// HNSW search; with base_level_only (CAGRA graph, level 0 only) an entry
// point is chosen per query by sampling random vertices, then a level-0
// search runs from that entry point.
void IndexHNSWCagra::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    if (!base_level_only) {
        IndexHNSW::search(n, x, k, distances, labels, params);
    } else {
        // per-query best random entry point and its distance
        std::vector<storage_idx_t> nearest(n);
        std::vector<float> nearest_d(n);

        // NOTE(review): orphaned `omp for` — there is no enclosing
        // `omp parallel` region here, so this loop binds to a team of one
        // and runs sequentially; confirm whether that is intended.
#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            std::unique_ptr<DistanceComputer> dis(
                    storage_distance_computer(this->storage));
            dis->set_query(x + i * d);
            nearest[i] = -1;
            nearest_d[i] = std::numeric_limits<float>::max();

            // NOTE(review): a fresh random_device/mt19937 per query is
            // costly; also assumes ntotal > 0 (distrib(0, ntotal - 1)).
            std::random_device rd;
            std::mt19937 gen(rd());
            std::uniform_int_distribution<idx_t> distrib(0, this->ntotal - 1);

            // keep the best of num_base_level_search_entrypoints samples
            for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) {
                auto idx = distrib(gen);
                auto distance = (*dis)(idx);
                if (distance < nearest_d[i]) {
                    nearest[i] = idx;
                    nearest_d[i] = distance;
                }
            }
            FAISS_THROW_IF_NOT_MSG(
                    nearest[i] >= 0, "Could not find a valid entrypoint.");
        }

        // level-0 traversal from the sampled entry points
        search_level_0(
                n,
                x,
                k,
                nearest.data(),
                nearest_d.data(),
                distances,
                labels,
                1, // n_probes
                1, // search_type
                params);
    }
}
969
970
} // namespace faiss