Coverage Report

Created: 2025-09-30 11:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexNNDescent.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/IndexNNDescent.h>
11
12
#include <omp.h>
13
14
#include <cinttypes>
15
#include <cstdio>
16
#include <cstdlib>
17
18
#include <queue>
19
#include <unordered_set>
20
21
#ifdef __SSE__
22
#endif
23
24
#include <faiss/IndexFlat.h>
25
#include <faiss/impl/AuxIndexStructures.h>
26
#include <faiss/impl/FaissAssert.h>
27
#include <faiss/utils/Heap.h>
28
#include <faiss/utils/distances.h>
29
#include <faiss/utils/random.h>
30
31
extern "C" {
32
33
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
34
35
int sgemm_(
36
        const char* transa,
37
        const char* transb,
38
        FINTEGER* m,
39
        FINTEGER* n,
40
        FINTEGER* k,
41
        const float* alpha,
42
        const float* a,
43
        FINTEGER* lda,
44
        const float* b,
45
        FINTEGER* ldb,
46
        float* beta,
47
        float* c,
48
        FINTEGER* ldc);
49
}
50
51
namespace faiss {
52
53
using storage_idx_t = NNDescent::storage_idx_t;
54
55
/**************************************************************
56
 * add / search blocks of descriptors
57
 **************************************************************/
58
59
namespace {
60
61
0
DistanceComputer* storage_distance_computer(const Index* storage) {
62
0
    if (is_similarity_metric(storage->metric_type)) {
63
0
        return new NegativeDistanceComputer(storage->get_distance_computer());
64
0
    } else {
65
0
        return storage->get_distance_computer();
66
0
    }
67
0
}
68
69
} // namespace
70
71
/**************************************************************
72
 * IndexNNDescent implementation
73
 **************************************************************/
74
75
IndexNNDescent::IndexNNDescent(int d, int K, MetricType metric)
76
0
        : Index(d, metric),
77
0
          nndescent(d, K),
78
0
          own_fields(false),
79
0
          storage(nullptr) {}
80
81
IndexNNDescent::IndexNNDescent(Index* storage, int K)
82
0
        : Index(storage->d, storage->metric_type),
83
0
          nndescent(storage->d, K),
84
0
          own_fields(false),
85
0
          storage(storage) {}
86
87
0
// Release the storage index only when this object owns it
// (IndexNNDescentFlat sets own_fields for the IndexFlat it allocates).
IndexNNDescent::~IndexNNDescent() {
    if (!own_fields) {
        return;
    }
    delete storage;
}
92
93
0
// Train the index on `n` vectors `x`. The NNDescent graph has no trainable
// state, so only the underlying storage is trained; afterwards the index is
// marked as trained. Throws if no storage index is attached.
void IndexNNDescent::train(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNNDescentFlat (or variants) "
            "instead of IndexNNDescent directly");
    Index* const backing = storage;
    backing->train(n, x);
    this->is_trained = true;
}
102
103
void IndexNNDescent::search(
104
        idx_t n,
105
        const float* x,
106
        idx_t k,
107
        float* distances,
108
        idx_t* labels,
109
0
        const SearchParameters* params) const {
110
0
    FAISS_THROW_IF_NOT_MSG(
111
0
            !params, "search params not supported for this index");
112
0
    FAISS_THROW_IF_NOT_MSG(
113
0
            storage,
114
0
            "Please use IndexNNDescentFlat (or variants) "
115
0
            "instead of IndexNNDescent directly");
116
0
    if (verbose) {
117
0
        printf("Parameters: k=%" PRId64 ", search_L=%d\n",
118
0
               k,
119
0
               nndescent.search_L);
120
0
    }
121
122
0
    idx_t check_period =
123
0
            InterruptCallback::get_period_hint(d * nndescent.search_L);
124
125
0
    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
126
0
        idx_t i1 = std::min(i0 + check_period, n);
127
128
0
#pragma omp parallel
129
0
        {
130
0
            VisitedTable vt(ntotal);
131
132
0
            std::unique_ptr<DistanceComputer> dis(
133
0
                    storage_distance_computer(storage));
134
135
0
#pragma omp for
136
0
            for (idx_t i = i0; i < i1; i++) {
137
0
                idx_t* idxi = labels + i * k;
138
0
                float* simi = distances + i * k;
139
0
                dis->set_query(x + i * d);
140
141
0
                nndescent.search(*dis, k, idxi, simi, vt);
142
0
            }
143
0
        }
144
0
        InterruptCallback::check();
145
0
    }
146
147
0
    if (metric_type == METRIC_INNER_PRODUCT) {
148
        // we need to revert the negated distances
149
0
        for (size_t i = 0; i < k * n; i++) {
150
0
            distances[i] = -distances[i];
151
0
        }
152
0
    }
153
0
}
154
155
0
// Append `n` vectors `x` to the storage and (re)build the NNDescent graph.
//
// The graph is built by nndescent.build() over all storage->ntotal vectors,
// so calling add() a second time rebuilds the graph over old + new data
// (a warning is printed in that case). Throws if no storage is attached or
// the index is not trained.
void IndexNNDescent::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNNDescentFlat (or variants) "
            "instead of IndexNNDescent directly");
    FAISS_THROW_IF_NOT(is_trained);

    if (ntotal != 0) {
        // Adjacent string literals are concatenated verbatim, so the
        // separating space must be explicit; end the line with '\n'.
        fprintf(stderr,
                "WARNING NNDescent does not support dynamic insertions, "
                "multiple insertions would lead to re-building the index\n");
    }

    storage->add(n, x);
    ntotal = storage->ntotal;

    std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
    nndescent.build(*dis, ntotal, verbose);
}
174
175
0
void IndexNNDescent::reset() {
176
0
    nndescent.reset();
177
0
    storage->reset();
178
0
    ntotal = 0;
179
0
}
180
181
0
// Copy the stored vector with id `key` into `recons` by delegating to the
// storage index. Throws (like train/add/search) instead of dereferencing a
// null storage pointer when no storage index is attached.
void IndexNNDescent::reconstruct(idx_t key, float* recons) const {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNNDescentFlat (or variants) "
            "instead of IndexNNDescent directly");
    storage->reconstruct(key, recons);
}
184
185
/**************************************************************
186
 * IndexNNDescentFlat implementation
187
 **************************************************************/
188
189
0
// Default constructor. Flat storage needs no training, so the index is
// flagged as trained immediately.
// NOTE(review): no storage is attached here; presumably it is set later
// (e.g. during deserialization) — confirm against callers.
IndexNNDescentFlat::IndexNNDescentFlat() {
    this->is_trained = true;
}
192
193
// Build an NNDescent index backed by a freshly allocated IndexFlat of
// dimension `d` and metric `metric`, with graph degree `M`.
IndexNNDescentFlat::IndexNNDescentFlat(int d, int M, MetricType metric)
        : IndexNNDescent(new IndexFlat(d, metric), M) {
    // Flat storage requires no training.
    is_trained = true;
    // The IndexFlat above was allocated here, so this index must free it;
    // ~IndexNNDescent deletes storage when own_fields is set.
    own_fields = true;
}
198
199
} // namespace faiss