/root/doris/contrib/faiss/faiss/IndexNNDescent.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/IndexNNDescent.h> |
11 | | |
12 | | #include <omp.h> |
13 | | |
14 | | #include <cinttypes> |
15 | | #include <cstdio> |
16 | | #include <cstdlib> |
17 | | |
18 | | #include <queue> |
19 | | #include <unordered_set> |
20 | | |
21 | | #ifdef __SSE__ |
22 | | #endif |
23 | | |
24 | | #include <faiss/IndexFlat.h> |
25 | | #include <faiss/impl/AuxIndexStructures.h> |
26 | | #include <faiss/impl/FaissAssert.h> |
27 | | #include <faiss/utils/Heap.h> |
28 | | #include <faiss/utils/distances.h> |
29 | | #include <faiss/utils/random.h> |
30 | | |
31 | | extern "C" { |
32 | | |
33 | | /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ |
34 | | |
35 | | int sgemm_( |
36 | | const char* transa, |
37 | | const char* transb, |
38 | | FINTEGER* m, |
39 | | FINTEGER* n, |
40 | | FINTEGER* k, |
41 | | const float* alpha, |
42 | | const float* a, |
43 | | FINTEGER* lda, |
44 | | const float* b, |
45 | | FINTEGER* ldb, |
46 | | float* beta, |
47 | | float* c, |
48 | | FINTEGER* ldc); |
49 | | } |
50 | | |
51 | | namespace faiss { |
52 | | |
53 | | using storage_idx_t = NNDescent::storage_idx_t; |
54 | | |
55 | | /************************************************************** |
56 | | * add / search blocks of descriptors |
57 | | **************************************************************/ |
58 | | |
59 | | namespace { |
60 | | |
61 | 0 | DistanceComputer* storage_distance_computer(const Index* storage) { |
62 | 0 | if (is_similarity_metric(storage->metric_type)) { |
63 | 0 | return new NegativeDistanceComputer(storage->get_distance_computer()); |
64 | 0 | } else { |
65 | 0 | return storage->get_distance_computer(); |
66 | 0 | } |
67 | 0 | } |
68 | | |
69 | | } // namespace |
70 | | |
71 | | /************************************************************** |
72 | | * IndexNNDescent implementation |
73 | | **************************************************************/ |
74 | | |
75 | | IndexNNDescent::IndexNNDescent(int d, int K, MetricType metric) |
76 | 0 | : Index(d, metric), |
77 | 0 | nndescent(d, K), |
78 | 0 | own_fields(false), |
79 | 0 | storage(nullptr) {} |
80 | | |
81 | | IndexNNDescent::IndexNNDescent(Index* storage, int K) |
82 | 0 | : Index(storage->d, storage->metric_type), |
83 | 0 | nndescent(storage->d, K), |
84 | 0 | own_fields(false), |
85 | 0 | storage(storage) {} |
86 | | |
87 | 0 | IndexNNDescent::~IndexNNDescent() { |
88 | 0 | if (own_fields) { |
89 | 0 | delete storage; |
90 | 0 | } |
91 | 0 | } |
92 | | |
// Train the underlying storage index (the NNDescent graph itself needs no
// training). Throws if no storage index is attached.
void IndexNNDescent::train(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNNDescentFlat (or variants) "
            "instead of IndexNNDescent directly");
    // nndescent structure does not require training
    storage->train(n, x);
    is_trained = true;
}
102 | | |
// Search the k nearest neighbors of the n query vectors x using the
// NNDescent graph. Results go to distances/labels (n * k entries each).
// Queries are processed in batches so InterruptCallback can be polled
// between batches; each OpenMP thread keeps its own VisitedTable and
// DistanceComputer.
void IndexNNDescent::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNNDescentFlat (or variants) "
            "instead of IndexNNDescent directly");
    if (verbose) {
        printf("Parameters: k=%" PRId64 ", search_L=%d\n",
               k,
               nndescent.search_L);
    }

    // Batch size chosen so interrupt checks happen at a reasonable rate
    // relative to per-query work (~d * search_L distance computations).
    idx_t check_period =
            InterruptCallback::get_period_hint(d * nndescent.search_L);

    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
        idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel
        {
            // Per-thread scratch: visited markers and distance computer.
            VisitedTable vt(ntotal);

            std::unique_ptr<DistanceComputer> dis(
                    storage_distance_computer(storage));

#pragma omp for
            for (idx_t i = i0; i < i1; i++) {
                idx_t* idxi = labels + i * k;
                float* simi = distances + i * k;
                dis->set_query(x + i * d);

                nndescent.search(*dis, k, idxi, simi, vt);
            }
        }
        InterruptCallback::check();
    }

    if (metric_type == METRIC_INNER_PRODUCT) {
        // we need to revert the negated distances
        // (storage_distance_computer negates similarity metrics)
        for (size_t i = 0; i < k * n; i++) {
            distances[i] = -distances[i];
        }
    }
}
154 | | |
155 | 0 | void IndexNNDescent::add(idx_t n, const float* x) { |
156 | 0 | FAISS_THROW_IF_NOT_MSG( |
157 | 0 | storage, |
158 | 0 | "Please use IndexNNDescentFlat (or variants) " |
159 | 0 | "instead of IndexNNDescent directly"); |
160 | 0 | FAISS_THROW_IF_NOT(is_trained); |
161 | | |
162 | 0 | if (ntotal != 0) { |
163 | 0 | fprintf(stderr, |
164 | 0 | "WARNING NNDescent doest not support dynamic insertions," |
165 | 0 | "multiple insertions would lead to re-building the index"); |
166 | 0 | } |
167 | |
|
168 | 0 | storage->add(n, x); |
169 | 0 | ntotal = storage->ntotal; |
170 | |
|
171 | 0 | std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage)); |
172 | 0 | nndescent.build(*dis, ntotal, verbose); |
173 | 0 | } |
174 | | |
175 | 0 | void IndexNNDescent::reset() { |
176 | 0 | nndescent.reset(); |
177 | 0 | storage->reset(); |
178 | 0 | ntotal = 0; |
179 | 0 | } |
180 | | |
// Reconstruct vector `key` into `recons` by delegating to the storage index.
void IndexNNDescent::reconstruct(idx_t key, float* recons) const {
    storage->reconstruct(key, recons);
}
184 | | |
185 | | /************************************************************** |
186 | | * IndexNNDescentFlat implementation |
187 | | **************************************************************/ |
188 | | |
// Default constructor: no storage attached yet (e.g. for deserialization);
// flat storage requires no training, so the index is trained from the start.
IndexNNDescentFlat::IndexNNDescentFlat() {
    is_trained = true;
}
192 | | |
193 | | IndexNNDescentFlat::IndexNNDescentFlat(int d, int M, MetricType metric) |
194 | 0 | : IndexNNDescent(new IndexFlat(d, metric), M) { |
195 | 0 | own_fields = true; |
196 | 0 | is_trained = true; |
197 | 0 | } |
198 | | |
199 | | } // namespace faiss |