/root/doris/contrib/faiss/faiss/impl/NNDescent.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/impl/NNDescent.h> |
11 | | |
12 | | #include <mutex> |
13 | | #include <string> |
14 | | |
15 | | #include <faiss/impl/AuxIndexStructures.h> |
16 | | #include <faiss/impl/DistanceComputer.h> |
17 | | |
18 | | namespace faiss { |
19 | | |
20 | | using LockGuard = std::lock_guard<std::mutex>; |
21 | | |
22 | | namespace nndescent { |
23 | | |
24 | | void gen_random(std::mt19937& rng, int* addr, const int size, const int N); |
25 | | |
26 | 0 | Nhood::Nhood(int l, int s, std::mt19937& rng, int N) { |
27 | 0 | M = s; |
28 | 0 | nn_new.resize(s * 2); |
29 | 0 | gen_random(rng, nn_new.data(), (int)nn_new.size(), N); |
30 | 0 | } |
31 | | |
32 | | /// Copy operator |
33 | 0 | Nhood& Nhood::operator=(const Nhood& other) { |
34 | 0 | M = other.M; |
35 | 0 | std::copy( |
36 | 0 | other.nn_new.begin(), |
37 | 0 | other.nn_new.end(), |
38 | 0 | std::back_inserter(nn_new)); |
39 | 0 | nn_new.reserve(other.nn_new.capacity()); |
40 | 0 | pool.reserve(other.pool.capacity()); |
41 | 0 | return *this; |
42 | 0 | } |
43 | | |
44 | | /// Copy constructor |
45 | 0 | Nhood::Nhood(const Nhood& other) { |
46 | 0 | M = other.M; |
47 | 0 | std::copy( |
48 | 0 | other.nn_new.begin(), |
49 | 0 | other.nn_new.end(), |
50 | 0 | std::back_inserter(nn_new)); |
51 | 0 | nn_new.reserve(other.nn_new.capacity()); |
52 | 0 | pool.reserve(other.pool.capacity()); |
53 | 0 | } |
54 | | |
55 | | /// Insert a point into the candidate pool |
56 | 0 | void Nhood::insert(int id, float dist) { |
57 | 0 | LockGuard guard(lock); |
58 | 0 | if (dist > pool.front().distance) |
59 | 0 | return; |
60 | 0 | for (int i = 0; i < pool.size(); i++) { |
61 | 0 | if (id == pool[i].id) |
62 | 0 | return; |
63 | 0 | } |
64 | 0 | if (pool.size() < pool.capacity()) { |
65 | 0 | pool.push_back(Neighbor(id, dist, true)); |
66 | 0 | std::push_heap(pool.begin(), pool.end()); |
67 | 0 | } else { |
68 | 0 | std::pop_heap(pool.begin(), pool.end()); |
69 | 0 | pool[pool.size() - 1] = Neighbor(id, dist, true); |
70 | 0 | std::push_heap(pool.begin(), pool.end()); |
71 | 0 | } |
72 | 0 | } |
73 | | |
74 | | /// In local join, two objects are compared only if at least |
75 | | /// one of them is new. |
76 | | template <typename C> |
77 | 0 | void Nhood::join(C callback) const { |
78 | 0 | for (int const i : nn_new) { |
79 | 0 | for (int const j : nn_new) { |
80 | 0 | if (i < j) { |
81 | 0 | callback(i, j); |
82 | 0 | } |
83 | 0 | } |
84 | 0 | for (int j : nn_old) { |
85 | 0 | callback(i, j); |
86 | 0 | } |
87 | 0 | } |
88 | 0 | } |
89 | | |
90 | 0 | void gen_random(std::mt19937& rng, int* addr, const int size, const int N) { |
91 | 0 | for (int i = 0; i < size; ++i) { |
92 | 0 | addr[i] = rng() % (N - size); |
93 | 0 | } |
94 | 0 | std::sort(addr, addr + size); |
95 | 0 | for (int i = 1; i < size; ++i) { |
96 | 0 | if (addr[i] <= addr[i - 1]) { |
97 | 0 | addr[i] = addr[i - 1] + 1; |
98 | 0 | } |
99 | 0 | } |
100 | 0 | int off = rng() % N; |
101 | 0 | for (int i = 0; i < size; ++i) { |
102 | 0 | addr[i] = (addr[i] + off) % N; |
103 | 0 | } |
104 | 0 | } |
105 | | |
106 | | // Insert a new point into the candidate pool in ascending order |
107 | 0 | int insert_into_pool(Neighbor* addr, int size, Neighbor nn) { |
108 | | // find the location to insert |
109 | 0 | int left = 0, right = size - 1; |
110 | 0 | if (addr[left].distance > nn.distance) { |
111 | 0 | memmove((char*)&addr[left + 1], &addr[left], size * sizeof(Neighbor)); |
112 | 0 | addr[left] = nn; |
113 | 0 | return left; |
114 | 0 | } |
115 | 0 | if (addr[right].distance < nn.distance) { |
116 | 0 | addr[size] = nn; |
117 | 0 | return size; |
118 | 0 | } |
119 | 0 | while (left < right - 1) { |
120 | 0 | int mid = (left + right) / 2; |
121 | 0 | if (addr[mid].distance > nn.distance) |
122 | 0 | right = mid; |
123 | 0 | else |
124 | 0 | left = mid; |
125 | 0 | } |
126 | | // check equal ID |
127 | |
|
128 | 0 | while (left > 0) { |
129 | 0 | if (addr[left].distance < nn.distance) |
130 | 0 | break; |
131 | 0 | if (addr[left].id == nn.id) |
132 | 0 | return size + 1; |
133 | 0 | left--; |
134 | 0 | } |
135 | 0 | if (addr[left].id == nn.id || addr[right].id == nn.id) |
136 | 0 | return size + 1; |
137 | 0 | memmove((char*)&addr[right + 1], |
138 | 0 | &addr[right], |
139 | 0 | (size - right) * sizeof(Neighbor)); |
140 | 0 | addr[right] = nn; |
141 | 0 | return right; |
142 | 0 | } |
143 | | |
144 | | } // namespace nndescent |
145 | | |
146 | | using namespace nndescent; |
147 | | |
148 | | constexpr int NUM_EVAL_POINTS = 100; |
149 | | |
150 | 0 | NNDescent::NNDescent(const int d, const int K) : K(K), d(d) { |
151 | 0 | L = K + 50; |
152 | 0 | } |
153 | | |
154 | 0 | NNDescent::~NNDescent() {} |
155 | | |
156 | 0 | void NNDescent::join(DistanceComputer& qdis) { |
157 | 0 | idx_t check_period = InterruptCallback::get_period_hint(d * search_L); |
158 | 0 | for (idx_t i0 = 0; i0 < (idx_t)ntotal; i0 += check_period) { |
159 | 0 | idx_t i1 = std::min(i0 + check_period, (idx_t)ntotal); |
160 | 0 | #pragma omp parallel for default(shared) schedule(dynamic, 100) |
161 | 0 | for (idx_t n = i0; n < i1; n++) { |
162 | 0 | graph[n].join([&](int i, int j) { |
163 | 0 | if (i != j) { |
164 | 0 | float dist = qdis.symmetric_dis(i, j); |
165 | 0 | graph[i].insert(j, dist); |
166 | 0 | graph[j].insert(i, dist); |
167 | 0 | } |
168 | 0 | }); |
169 | 0 | } |
170 | 0 | InterruptCallback::check(); |
171 | 0 | } |
172 | 0 | } |
173 | | |
174 | | /// Sample neighbors for each node to peform local join later |
175 | | /// Store them in nn_new and nn_old |
176 | 0 | void NNDescent::update() { |
177 | | // Step 1. |
178 | | // Clear all nn_new and nn_old |
179 | 0 | #pragma omp parallel for |
180 | 0 | for (int i = 0; i < ntotal; i++) { |
181 | 0 | std::vector<int>().swap(graph[i].nn_new); |
182 | 0 | std::vector<int>().swap(graph[i].nn_old); |
183 | 0 | } |
184 | | |
185 | | // Step 2. |
186 | | // Compute the number of neighbors which is new i.e. flag is true |
187 | | // in the candidate pool. This must not exceed the sample number S. |
188 | | // That means We only select S new neighbors. |
189 | 0 | #pragma omp parallel for |
190 | 0 | for (int n = 0; n < ntotal; ++n) { |
191 | 0 | auto& nn = graph[n]; |
192 | 0 | std::sort(nn.pool.begin(), nn.pool.end()); |
193 | |
|
194 | 0 | if (nn.pool.size() > L) |
195 | 0 | nn.pool.resize(L); |
196 | 0 | nn.pool.reserve(L); // keep the pool size be L |
197 | |
|
198 | 0 | int maxl = std::min(nn.M + S, (int)nn.pool.size()); |
199 | 0 | int c = 0; |
200 | 0 | int l = 0; |
201 | |
|
202 | 0 | while ((l < maxl) && (c < S)) { |
203 | 0 | if (nn.pool[l].flag) { |
204 | 0 | ++c; |
205 | 0 | } |
206 | 0 | ++l; |
207 | 0 | } |
208 | 0 | nn.M = l; |
209 | 0 | } |
210 | | |
211 | | // Step 3. |
212 | | // Find reverse links for each node |
213 | | // Randomly choose R reverse links. |
214 | 0 | #pragma omp parallel |
215 | 0 | { |
216 | 0 | std::mt19937 rng(random_seed * 5081 + omp_get_thread_num()); |
217 | 0 | #pragma omp for |
218 | 0 | for (int n = 0; n < ntotal; ++n) { |
219 | 0 | auto& node = graph[n]; |
220 | 0 | auto& nn_new = node.nn_new; |
221 | 0 | auto& nn_old = node.nn_old; |
222 | |
|
223 | 0 | for (int l = 0; l < node.M; ++l) { |
224 | 0 | auto& nn = node.pool[l]; |
225 | 0 | auto& other = graph[nn.id]; // the other side of the edge |
226 | |
|
227 | 0 | if (nn.flag) { // the node is inserted newly |
228 | | // push the neighbor into nn_new |
229 | 0 | nn_new.push_back(nn.id); |
230 | | // push itself into other.rnn_new if it is not in |
231 | | // the candidate pool of the other side |
232 | 0 | if (nn.distance > other.pool.back().distance) { |
233 | 0 | LockGuard guard(other.lock); |
234 | 0 | if (other.rnn_new.size() < R) { |
235 | 0 | other.rnn_new.push_back(n); |
236 | 0 | } else { |
237 | 0 | int pos = rng() % R; |
238 | 0 | other.rnn_new[pos] = n; |
239 | 0 | } |
240 | 0 | } |
241 | 0 | nn.flag = false; |
242 | |
|
243 | 0 | } else { // the node is old |
244 | | // push the neighbor into nn_old |
245 | 0 | nn_old.push_back(nn.id); |
246 | | // push itself into other.rnn_old if it is not in |
247 | | // the candidate pool of the other side |
248 | 0 | if (nn.distance > other.pool.back().distance) { |
249 | 0 | LockGuard guard(other.lock); |
250 | 0 | if (other.rnn_old.size() < R) { |
251 | 0 | other.rnn_old.push_back(n); |
252 | 0 | } else { |
253 | 0 | int pos = rng() % R; |
254 | 0 | other.rnn_old[pos] = n; |
255 | 0 | } |
256 | 0 | } |
257 | 0 | } |
258 | 0 | } |
259 | | // make heap to join later (in join() function) |
260 | 0 | std::make_heap(node.pool.begin(), node.pool.end()); |
261 | 0 | } |
262 | 0 | } |
263 | | |
264 | | // Step 4. |
265 | | // Combine the forward and the reverse links |
266 | | // R = 0 means no reverse links are used. |
267 | 0 | #pragma omp parallel for |
268 | 0 | for (int i = 0; i < ntotal; ++i) { |
269 | 0 | auto& nn_new = graph[i].nn_new; |
270 | 0 | auto& nn_old = graph[i].nn_old; |
271 | 0 | auto& rnn_new = graph[i].rnn_new; |
272 | 0 | auto& rnn_old = graph[i].rnn_old; |
273 | |
|
274 | 0 | nn_new.insert(nn_new.end(), rnn_new.begin(), rnn_new.end()); |
275 | 0 | nn_old.insert(nn_old.end(), rnn_old.begin(), rnn_old.end()); |
276 | 0 | if (nn_old.size() > R * 2) { |
277 | 0 | nn_old.resize(R * 2); |
278 | 0 | nn_old.reserve(R * 2); |
279 | 0 | } |
280 | |
|
281 | 0 | std::vector<int>().swap(graph[i].rnn_new); |
282 | 0 | std::vector<int>().swap(graph[i].rnn_old); |
283 | 0 | } |
284 | 0 | } |
285 | | |
286 | 0 | void NNDescent::nndescent(DistanceComputer& qdis, bool verbose) { |
287 | 0 | int num_eval_points = std::min(NUM_EVAL_POINTS, ntotal); |
288 | 0 | std::vector<int> eval_points(num_eval_points); |
289 | 0 | std::vector<std::vector<int>> acc_eval_set(num_eval_points); |
290 | 0 | std::mt19937 rng(random_seed * 6577 + omp_get_thread_num()); |
291 | 0 | gen_random(rng, eval_points.data(), eval_points.size(), ntotal); |
292 | 0 | generate_eval_set(qdis, eval_points, acc_eval_set, ntotal); |
293 | 0 | for (int it = 0; it < iter; it++) { |
294 | 0 | join(qdis); |
295 | 0 | update(); |
296 | |
|
297 | 0 | if (verbose) { |
298 | 0 | float recall = eval_recall(eval_points, acc_eval_set); |
299 | 0 | printf("Iter: %d, recall@%d: %lf\n", it, K, recall); |
300 | 0 | } |
301 | 0 | } |
302 | 0 | } |
303 | | |
304 | | /// Sample a small number of points to evaluate the quality of KNNG built |
305 | | void NNDescent::generate_eval_set( |
306 | | DistanceComputer& qdis, |
307 | | std::vector<int>& c, |
308 | | std::vector<std::vector<int>>& v, |
309 | 0 | int N) { |
310 | 0 | #pragma omp parallel for |
311 | 0 | for (int i = 0; i < c.size(); i++) { |
312 | 0 | std::vector<Neighbor> tmp; |
313 | 0 | for (int j = 0; j < N; j++) { |
314 | 0 | if (c[i] == j) { |
315 | 0 | continue; // skip itself |
316 | 0 | } |
317 | 0 | float dist = qdis.symmetric_dis(c[i], j); |
318 | 0 | tmp.push_back(Neighbor(j, dist, true)); |
319 | 0 | } |
320 | |
|
321 | 0 | std::partial_sort(tmp.begin(), tmp.begin() + K, tmp.end()); |
322 | 0 | for (int j = 0; j < K; j++) { |
323 | 0 | v[i].push_back(tmp[j].id); |
324 | 0 | } |
325 | 0 | } |
326 | 0 | } |
327 | | |
328 | | /// Evaluate the quality of KNNG built |
329 | | float NNDescent::eval_recall( |
330 | | std::vector<int>& eval_points, |
331 | 0 | std::vector<std::vector<int>>& acc_eval_set) { |
332 | 0 | float mean_acc = 0.0f; |
333 | 0 | for (size_t i = 0; i < eval_points.size(); i++) { |
334 | 0 | float acc = 0; |
335 | 0 | std::vector<Neighbor>& g = graph[eval_points[i]].pool; |
336 | 0 | std::vector<int>& v = acc_eval_set[i]; |
337 | 0 | for (size_t j = 0; j < g.size(); j++) { |
338 | 0 | for (size_t k = 0; k < v.size(); k++) { |
339 | 0 | if (g[j].id == v[k]) { |
340 | 0 | acc++; |
341 | 0 | break; |
342 | 0 | } |
343 | 0 | } |
344 | 0 | } |
345 | 0 | mean_acc += acc / v.size(); |
346 | 0 | } |
347 | 0 | return mean_acc / eval_points.size(); |
348 | 0 | } |
349 | | |
350 | | /// Initialize the KNN graph randomly |
351 | 0 | void NNDescent::init_graph(DistanceComputer& qdis) { |
352 | 0 | graph.reserve(ntotal); |
353 | 0 | { |
354 | 0 | std::mt19937 rng(random_seed * 6007); |
355 | 0 | for (int i = 0; i < ntotal; i++) { |
356 | 0 | graph.push_back(Nhood(L, S, rng, (int)ntotal)); |
357 | 0 | } |
358 | 0 | } |
359 | 0 | #pragma omp parallel |
360 | 0 | { |
361 | 0 | std::mt19937 rng(random_seed * 7741 + omp_get_thread_num()); |
362 | 0 | #pragma omp for |
363 | 0 | for (int i = 0; i < ntotal; i++) { |
364 | 0 | std::vector<int> tmp(S); |
365 | |
|
366 | 0 | gen_random(rng, tmp.data(), S, ntotal); |
367 | |
|
368 | 0 | for (int j = 0; j < S; j++) { |
369 | 0 | int id = tmp[j]; |
370 | 0 | if (id == i) { |
371 | 0 | continue; |
372 | 0 | } |
373 | 0 | float dist = qdis.symmetric_dis(i, id); |
374 | |
|
375 | 0 | graph[i].pool.push_back(Neighbor(id, dist, true)); |
376 | 0 | } |
377 | 0 | std::make_heap(graph[i].pool.begin(), graph[i].pool.end()); |
378 | 0 | graph[i].pool.reserve(L); |
379 | 0 | } |
380 | 0 | } |
381 | 0 | } |
382 | | |
383 | 0 | void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) { |
384 | 0 | FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build"); |
385 | 0 | FAISS_THROW_IF_NOT_FMT( |
386 | 0 | n > NUM_EVAL_POINTS, |
387 | 0 | "NNDescent.build cannot build a graph smaller than %d", |
388 | 0 | int(NUM_EVAL_POINTS)); |
389 | | |
390 | 0 | if (verbose) { |
391 | 0 | printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n", |
392 | 0 | K, |
393 | 0 | S, |
394 | 0 | R, |
395 | 0 | L, |
396 | 0 | iter); |
397 | 0 | } |
398 | |
|
399 | 0 | ntotal = n; |
400 | 0 | init_graph(qdis); |
401 | 0 | nndescent(qdis, verbose); |
402 | |
|
403 | 0 | final_graph.resize(uint64_t(ntotal) * K); |
404 | | |
405 | | // Store the neighbor link structure into final_graph |
406 | | // Clear the old graph |
407 | 0 | for (int i = 0; i < ntotal; i++) { |
408 | 0 | std::sort(graph[i].pool.begin(), graph[i].pool.end()); |
409 | 0 | for (int j = 0; j < K; j++) { |
410 | 0 | FAISS_ASSERT(graph[i].pool[j].id < ntotal); |
411 | 0 | final_graph[i * K + j] = graph[i].pool[j].id; |
412 | 0 | } |
413 | 0 | } |
414 | 0 | std::vector<Nhood>().swap(graph); |
415 | 0 | has_built = true; |
416 | |
|
417 | 0 | if (verbose) { |
418 | 0 | printf("Added %d points into the index\n", ntotal); |
419 | 0 | } |
420 | 0 | } |
421 | | |
422 | | void NNDescent::search( |
423 | | DistanceComputer& qdis, |
424 | | const int topk, |
425 | | idx_t* indices, |
426 | | float* dists, |
427 | 0 | VisitedTable& vt) const { |
428 | 0 | FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet."); |
429 | 0 | int L_2 = std::max(search_L, topk); |
430 | | |
431 | | // candidate pool, the K best items is the result. |
432 | 0 | std::vector<Neighbor> retset(L_2 + 1); |
433 | | |
434 | | // Randomly choose L_2 points to initialize the candidate pool |
435 | 0 | std::vector<int> init_ids(L_2); |
436 | 0 | std::mt19937 rng(random_seed); |
437 | |
|
438 | 0 | gen_random(rng, init_ids.data(), L_2, ntotal); |
439 | 0 | for (int i = 0; i < L_2; i++) { |
440 | 0 | int id = init_ids[i]; |
441 | 0 | float dist = qdis(id); |
442 | 0 | retset[i] = Neighbor(id, dist, true); |
443 | 0 | } |
444 | | |
445 | | // Maintain the candidate pool in ascending order |
446 | 0 | std::sort(retset.begin(), retset.begin() + L_2); |
447 | |
|
448 | 0 | int k = 0; |
449 | | |
450 | | // Stop until the smallest position updated is >= L_2 |
451 | 0 | while (k < L_2) { |
452 | 0 | int nk = L_2; |
453 | |
|
454 | 0 | if (retset[k].flag) { |
455 | 0 | retset[k].flag = false; |
456 | 0 | int n = retset[k].id; |
457 | |
|
458 | 0 | for (int m = 0; m < K; ++m) { |
459 | 0 | int id = final_graph[n * K + m]; |
460 | 0 | if (vt.get(id)) { |
461 | 0 | continue; |
462 | 0 | } |
463 | | |
464 | 0 | vt.set(id); |
465 | 0 | float dist = qdis(id); |
466 | 0 | if (dist >= retset[L_2 - 1].distance) { |
467 | 0 | continue; |
468 | 0 | } |
469 | | |
470 | 0 | Neighbor nn(id, dist, true); |
471 | 0 | int r = insert_into_pool(retset.data(), L_2, nn); |
472 | |
|
473 | 0 | if (r < nk) |
474 | 0 | nk = r; |
475 | 0 | } |
476 | 0 | } |
477 | 0 | if (nk <= k) { |
478 | 0 | k = nk; |
479 | 0 | } else { |
480 | 0 | ++k; |
481 | 0 | } |
482 | 0 | } |
483 | 0 | for (size_t i = 0; i < topk; i++) { |
484 | 0 | indices[i] = retset[i].id; |
485 | 0 | dists[i] = retset[i].distance; |
486 | 0 | } |
487 | |
|
488 | 0 | vt.advance(); |
489 | 0 | } |
490 | | |
491 | 0 | void NNDescent::reset() { |
492 | 0 | has_built = false; |
493 | 0 | ntotal = 0; |
494 | 0 | final_graph.resize(0); |
495 | 0 | } |
496 | | |
497 | | } // namespace faiss |