/root/doris/contrib/faiss/faiss/IndexNSG.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/IndexNSG.h> |
11 | | |
12 | | #include <cinttypes> |
13 | | #include <memory> |
14 | | |
15 | | #include <faiss/IndexFlat.h> |
16 | | #include <faiss/IndexNNDescent.h> |
17 | | #include <faiss/impl/AuxIndexStructures.h> |
18 | | #include <faiss/impl/FaissAssert.h> |
19 | | #include <faiss/utils/distances.h> |
20 | | |
21 | | namespace faiss { |
22 | | |
23 | | using namespace nsg; |
24 | | |
25 | | /************************************************************** |
26 | | * IndexNSG implementation |
27 | | **************************************************************/ |
28 | | |
/// Construct an empty NSG index over vectors of dimension d; R is the
/// maximum out-degree of each node in the final NSG graph.
IndexNSG::IndexNSG(int d, int R, MetricType metric) : Index(d, metric), nsg(R) {
    // NN-descent search-pool size must exceed the knn-graph degree GK
    nndescent_L = GK + 50;
}
32 | | |
/// Wrap an existing storage index (which holds the actual vectors);
/// dimension and metric are inherited from it. build_type 1 selects
/// NN-descent for knn-graph construction. Ownership of `storage` is only
/// taken if the caller sets own_fields.
IndexNSG::IndexNSG(Index* storage, int R)
        : Index(storage->d, storage->metric_type),
          nsg(R),
          storage(storage),
          build_type(1) {
    // NN-descent search-pool size must exceed the knn-graph degree GK
    nndescent_L = GK + 50;
}
40 | | |
IndexNSG::~IndexNSG() {
    // storage is owned only when own_fields is set (e.g. by the
    // IndexNSGFlat / IndexNSGPQ / IndexNSGSQ convenience constructors)
    if (own_fields) {
        delete storage;
    }
}
46 | | |
47 | 0 | void IndexNSG::train(idx_t n, const float* x) { |
48 | 0 | FAISS_THROW_IF_NOT_MSG( |
49 | 0 | storage, |
50 | 0 | "Please use IndexNSGFlat (or variants) instead of IndexNSG directly"); |
51 | | // nsg structure does not require training |
52 | 0 | storage->train(n, x); |
53 | 0 | is_trained = true; |
54 | 0 | } |
55 | | |
/// Search the NSG graph for the k nearest neighbors of each of the n
/// queries in x. Results go to distances/labels (row i holds query i).
/// Queries are processed in batches sized by the interrupt-callback
/// period hint so long searches remain interruptible.
void IndexNSG::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNSGFlat (or variants) instead of IndexNSG directly");

    int L = std::max(nsg.search_L, (int)k); // in case of search L = -1

    // cost estimate d * L drives how many queries run between interrupt checks
    idx_t check_period = InterruptCallback::get_period_hint(d * L);

    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
        idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel
        {
            // per-thread visited marker and distance computer: neither is
            // shareable across threads
            VisitedTable vt(ntotal);

            std::unique_ptr<DistanceComputer> dis(
                    storage_distance_computer(storage));

#pragma omp for
            for (idx_t i = i0; i < i1; i++) {
                idx_t* idxi = labels + i * k;
                float* simi = distances + i * k;
                dis->set_query(x + i * d);

                nsg.search(*dis, k, idxi, simi, vt);

                // reset the table cheaply for the next query of this thread
                vt.advance();
            }
        }
        InterruptCallback::check();
    }

    if (is_similarity_metric(metric_type)) {
        // we need to revert the negated distances: the graph search works
        // on distances (smaller = better), so similarities were negated
        for (size_t i = 0; i < k * n; i++) {
            distances[i] = -distances[i];
        }
    }
}
103 | | |
104 | 0 | void IndexNSG::build(idx_t n, const float* x, idx_t* knn_graph, int GK_2) { |
105 | 0 | FAISS_THROW_IF_NOT_MSG( |
106 | 0 | storage, |
107 | 0 | "Please use IndexNSGFlat (or variants) instead of IndexNSG directly"); |
108 | 0 | FAISS_THROW_IF_NOT_MSG( |
109 | 0 | !is_built && ntotal == 0, "The IndexNSG is already built"); |
110 | | |
111 | 0 | storage->add(n, x); |
112 | 0 | ntotal = storage->ntotal; |
113 | | |
114 | | // check the knn graph |
115 | 0 | check_knn_graph(knn_graph, n, GK_2); |
116 | |
|
117 | 0 | const nsg::Graph<idx_t> knng(knn_graph, n, GK_2); |
118 | 0 | nsg.build(storage, n, knng, verbose); |
119 | 0 | is_built = true; |
120 | 0 | } |
121 | | |
/// Add n vectors and build the NSG over them in one shot. NSG does not
/// support incremental addition: the index must be empty. A knn graph of
/// degree GK is first constructed, either by brute-force assignment on
/// the storage index (build_type 0) or by NN-descent (build_type 1),
/// then handed to nsg.build().
void IndexNSG::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            storage,
            "Please use IndexNSGFlat (or variants) "
            "instead of IndexNSG directly");
    FAISS_THROW_IF_NOT(is_trained);

    FAISS_THROW_IF_NOT_MSG(
            !is_built && ntotal == 0,
            "NSG does not support incremental addition");

    std::vector<idx_t> knng; // row-major n x GK neighbor ids
    if (verbose) {
        printf("IndexNSG::add %zd vectors\n", size_t(n));
    }

    if (build_type == 0) { // build with brute force search

        if (verbose) {
            printf(" Build knn graph with brute force search on storage index\n");
        }

        storage->add(n, x);
        ntotal = storage->ntotal;
        FAISS_THROW_IF_NOT(ntotal == n);

        // ask for GK + 1 neighbors: the query point itself is expected
        // to appear among them and is stripped below
        knng.resize(ntotal * (GK + 1));
        storage->assign(ntotal, x, knng.data(), GK + 1);

        // Remove itself
        // - For metric distance, we just need to remove the first neighbor
        //   (the point itself is always ranked first)
        // - But for non-metric, e.g. inner product, we need to check
        //   each neighbor
        if (storage->metric_type == METRIC_INNER_PRODUCT) {
            for (idx_t i = 0; i < ntotal; i++) {
                int count = 0;
                // compact row i in place from stride GK+1 down to GK,
                // skipping the self id wherever it occurs
                for (int j = 0; j < GK + 1; j++) {
                    idx_t id = knng[i * (GK + 1) + j];
                    if (id != i) {
                        knng[i * GK + count] = id;
                        count += 1;
                    }
                    if (count == GK) {
                        break;
                    }
                }
            }
        } else {
            // metric case: drop the first entry of each row (self) and
            // shift the remaining GK ids to the packed stride-GK layout
            for (idx_t i = 0; i < ntotal; i++) {
                memmove(knng.data() + i * GK,
                        knng.data() + i * (GK + 1) + 1,
                        GK * sizeof(idx_t));
            }
        }

    } else if (build_type == 1) { // build with NNDescent
        IndexNNDescent index(storage, GK);
        index.nndescent.S = nndescent_S;
        index.nndescent.R = nndescent_R;
        // the pool size L must be at least GK + 50 for a usable graph
        index.nndescent.L = std::max(nndescent_L, GK + 50);
        index.nndescent.iter = nndescent_iter;
        index.verbose = verbose;

        if (verbose) {
            printf(" Build knn graph with NNdescent S=%d R=%d L=%d niter=%d\n",
                   index.nndescent.S,
                   index.nndescent.R,
                   index.nndescent.L,
                   index.nndescent.iter);
        }

        // prevent IndexNNDescent from deleting the storage it borrows
        index.own_fields = false;

        index.add(n, x);

        // storage->add is already implicitly called in IndexNNDescent::add
        ntotal = storage->ntotal;
        FAISS_THROW_IF_NOT(ntotal == n);

        knng.resize(ntotal * GK);

        // widen NN-descent's int neighbor ids to idx_t
        const int* knn_graph = index.nndescent.final_graph.data();
#pragma omp parallel for
        for (idx_t i = 0; i < ntotal * GK; i++) {
            knng[i] = knn_graph[i];
        }
    } else {
        FAISS_THROW_MSG("build_type should be 0 or 1");
    }

    if (verbose) {
        printf(" Check the knn graph\n");
    }

    // check the knn graph
    check_knn_graph(knng.data(), n, GK);

    if (verbose) {
        printf(" nsg building\n");
    }

    const nsg::Graph<idx_t> knn_graph(knng.data(), n, GK);
    nsg.build(storage, n, knn_graph, verbose);
    is_built = true;
}
229 | | |
/// Remove all vectors and the graph; the index can be rebuilt afterwards.
void IndexNSG::reset() {
    nsg.reset();
    storage->reset();
    ntotal = 0;
    is_built = false;
}
236 | | |
/// Reconstruct the stored vector with id `key` into recons (d floats);
/// delegated entirely to the storage index.
void IndexNSG::reconstruct(idx_t key, float* recons) const {
    storage->reconstruct(key, recons);
}
240 | | |
241 | 0 | void IndexNSG::check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const { |
242 | 0 | idx_t total_count = 0; |
243 | |
|
244 | 0 | #pragma omp parallel for reduction(+ : total_count) |
245 | 0 | for (idx_t i = 0; i < n; i++) { |
246 | 0 | int count = 0; |
247 | 0 | for (int j = 0; j < K; j++) { |
248 | 0 | idx_t id = knn_graph[i * K + j]; |
249 | 0 | if (id < 0 || id >= n || id == i) { |
250 | 0 | count += 1; |
251 | 0 | } |
252 | 0 | } |
253 | 0 | total_count += count; |
254 | 0 | } |
255 | |
|
256 | 0 | if (total_count > 0) { |
257 | 0 | fprintf(stderr, |
258 | 0 | "WARNING: the input knn graph " |
259 | 0 | "has %" PRId64 " invalid entries\n", |
260 | 0 | total_count); |
261 | 0 | } |
262 | 0 | FAISS_THROW_IF_NOT_MSG( |
263 | 0 | total_count < n / 10, |
264 | 0 | "There are too much invalid entries in the knn graph. " |
265 | 0 | "It may be an invalid knn graph."); |
266 | 0 | } |
267 | | |
268 | | /************************************************************** |
269 | | * IndexNSGFlat implementation |
270 | | **************************************************************/ |
271 | | |
/// Empty flat NSG index (used e.g. when deserializing); flat storage
/// needs no training.
IndexNSGFlat::IndexNSGFlat() {
    is_trained = true;
}
275 | | |
276 | | IndexNSGFlat::IndexNSGFlat(int d, int R, MetricType metric) |
277 | 0 | : IndexNSG(new IndexFlat(d, metric), R) { |
278 | 0 | own_fields = true; |
279 | 0 | is_trained = true; |
280 | 0 | } |
281 | | |
282 | | /************************************************************** |
283 | | * IndexNSGPQ implementation |
284 | | **************************************************************/ |
285 | | |
// Empty PQ-backed NSG index (used e.g. when deserializing).
IndexNSGPQ::IndexNSGPQ() = default;
287 | | |
288 | | IndexNSGPQ::IndexNSGPQ(int d, int pq_m, int M, int pq_nbits) |
289 | 0 | : IndexNSG(new IndexPQ(d, pq_m, pq_nbits), M) { |
290 | 0 | own_fields = true; |
291 | 0 | is_trained = false; |
292 | 0 | } |
293 | | |
294 | 0 | void IndexNSGPQ::train(idx_t n, const float* x) { |
295 | 0 | IndexNSG::train(n, x); |
296 | 0 | (dynamic_cast<IndexPQ*>(storage))->pq.compute_sdc_table(); |
297 | 0 | } |
298 | | |
299 | | /************************************************************** |
300 | | * IndexNSGSQ implementation |
301 | | **************************************************************/ |
302 | | |
303 | | IndexNSGSQ::IndexNSGSQ( |
304 | | int d, |
305 | | ScalarQuantizer::QuantizerType qtype, |
306 | | int M, |
307 | | MetricType metric) |
308 | 0 | : IndexNSG(new IndexScalarQuantizer(d, qtype, metric), M) { |
309 | 0 | is_trained = this->storage->is_trained; |
310 | 0 | own_fields = true; |
311 | 0 | } |
312 | | |
// Empty SQ-backed NSG index (used e.g. when deserializing).
IndexNSGSQ::IndexNSGSQ() = default;
314 | | |
315 | | } // namespace faiss |