/root/doris/contrib/faiss/faiss/impl/NSG.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/NSG.h> |
9 | | |
10 | | #include <algorithm> |
11 | | #include <memory> |
12 | | #include <mutex> |
13 | | #include <stack> |
14 | | |
15 | | #include <faiss/impl/DistanceComputer.h> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | | namespace { |
20 | | |
21 | | using LockGuard = std::lock_guard<std::mutex>; |
22 | | |
23 | | // Sentinel id marking an unused neighbor slot; it must be negative so that it can never equal a valid node id |
24 | | constexpr int EMPTY_ID = -1; |
25 | | |
26 | | } // anonymous namespace |
27 | | |
28 | | namespace nsg { |
29 | | |
30 | 0 | DistanceComputer* storage_distance_computer(const Index* storage) { |
31 | 0 | if (is_similarity_metric(storage->metric_type)) { |
32 | 0 | return new NegativeDistanceComputer(storage->get_distance_computer()); |
33 | 0 | } else { |
34 | 0 | return storage->get_distance_computer(); |
35 | 0 | } |
36 | 0 | } |
37 | | |
/// Entry of the candidate pool used by the graph search: a node id, its
/// distance to the current query and a flag. The search sets the flag to
/// true when the entry is added and clears it once the entry is expanded.
struct Neighbor {
    int32_t id;
    float distance;
    bool flag;

    Neighbor() = default;
    Neighbor(int node_id, float dist, bool pending)
            : id(node_id), distance(dist), flag(pending) {}

    /// Entries are ordered by distance only; id and flag are ignored.
    inline bool operator<(const Neighbor& rhs) const {
        return rhs.distance > distance;
    }
};
51 | | |
/// A node id paired with a distance value.
struct Node {
    int32_t id;
    float distance;

    Node() = default;
    Node(int node_id, float dist) : id(node_id), distance(dist) {}

    /// Nodes are ordered by distance only.
    inline bool operator<(const Node& rhs) const {
        return rhs.distance > distance;
    }

    // to keep the compiler happy: comparison against a plain id
    inline bool operator<(int rhs) const {
        return rhs > id;
    }
};
68 | | |
69 | 0 | inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) { |
70 | | // find the location to insert |
71 | 0 | int left = 0, right = K - 1; |
72 | 0 | if (addr[left].distance > nn.distance) { |
73 | 0 | memmove(&addr[left + 1], &addr[left], K * sizeof(Neighbor)); |
74 | 0 | addr[left] = nn; |
75 | 0 | return left; |
76 | 0 | } |
77 | 0 | if (addr[right].distance < nn.distance) { |
78 | 0 | addr[K] = nn; |
79 | 0 | return K; |
80 | 0 | } |
81 | 0 | while (left < right - 1) { |
82 | 0 | int mid = (left + right) / 2; |
83 | 0 | if (addr[mid].distance > nn.distance) { |
84 | 0 | right = mid; |
85 | 0 | } else { |
86 | 0 | left = mid; |
87 | 0 | } |
88 | 0 | } |
89 | | // check equal ID |
90 | |
|
91 | 0 | while (left > 0) { |
92 | 0 | if (addr[left].distance < nn.distance) { |
93 | 0 | break; |
94 | 0 | } |
95 | 0 | if (addr[left].id == nn.id) { |
96 | 0 | return K + 1; |
97 | 0 | } |
98 | 0 | left--; |
99 | 0 | } |
100 | 0 | if (addr[left].id == nn.id || addr[right].id == nn.id) { |
101 | 0 | return K + 1; |
102 | 0 | } |
103 | 0 | memmove(&addr[right + 1], &addr[right], (K - right) * sizeof(Neighbor)); |
104 | 0 | addr[right] = nn; |
105 | 0 | return right; |
106 | 0 | } |
107 | | |
108 | | } // namespace nsg |
109 | | |
110 | | using namespace nsg; |
111 | | |
112 | 0 | NSG::NSG(int R) : R(R), rng(0x0903) { |
113 | 0 | L = R + 32; |
114 | 0 | C = R + 100; |
115 | 0 | srand(0x1998); |
116 | 0 | } |
117 | | |
118 | | void NSG::search( |
119 | | DistanceComputer& dis, |
120 | | int k, |
121 | | idx_t* I, |
122 | | float* D, |
123 | 0 | VisitedTable& vt) const { |
124 | 0 | FAISS_THROW_IF_NOT(is_built); |
125 | 0 | FAISS_THROW_IF_NOT(final_graph); |
126 | | |
127 | 0 | int pool_size = std::max(search_L, k); |
128 | 0 | std::vector<Neighbor> retset; |
129 | 0 | std::vector<Node> tmp; |
130 | 0 | search_on_graph<false>( |
131 | 0 | *final_graph, dis, vt, enterpoint, pool_size, retset, tmp); |
132 | |
|
133 | 0 | for (size_t i = 0; i < k; i++) { |
134 | 0 | I[i] = retset[i].id; |
135 | 0 | D[i] = retset[i].distance; |
136 | 0 | } |
137 | 0 | } |
138 | | |
139 | | void NSG::build( |
140 | | Index* storage, |
141 | | idx_t n, |
142 | | const nsg::Graph<idx_t>& knn_graph, |
143 | 0 | bool verbose) { |
144 | 0 | FAISS_THROW_IF_NOT(!is_built && ntotal == 0); |
145 | | |
146 | 0 | if (verbose) { |
147 | 0 | printf("NSG::build R=%d, L=%d, C=%d\n", R, L, C); |
148 | 0 | } |
149 | |
|
150 | 0 | ntotal = n; |
151 | 0 | init_graph(storage, knn_graph); |
152 | |
|
153 | 0 | std::vector<int> degrees(n, 0); |
154 | 0 | { |
155 | 0 | nsg::Graph<Node> tmp_graph(n, R); |
156 | |
|
157 | 0 | link(storage, knn_graph, tmp_graph, verbose); |
158 | |
|
159 | 0 | final_graph = std::make_shared<nsg::Graph<int>>(n, R); |
160 | 0 | std::fill_n(final_graph->data, n * R, EMPTY_ID); |
161 | |
|
162 | 0 | #pragma omp parallel for |
163 | 0 | for (int i = 0; i < n; i++) { |
164 | 0 | int cnt = 0; |
165 | 0 | for (int j = 0; j < R; j++) { |
166 | 0 | int id = tmp_graph.at(i, j).id; |
167 | 0 | if (id != EMPTY_ID) { |
168 | 0 | final_graph->at(i, cnt) = id; |
169 | 0 | cnt += 1; |
170 | 0 | } |
171 | 0 | degrees[i] = cnt; |
172 | 0 | } |
173 | 0 | } |
174 | 0 | } |
175 | |
|
176 | 0 | int num_attached = tree_grow(storage, degrees); |
177 | 0 | check_graph(); |
178 | 0 | is_built = true; |
179 | |
|
180 | 0 | if (verbose) { |
181 | 0 | int max = 0, min = 1e6; |
182 | 0 | double avg = 0; |
183 | |
|
184 | 0 | for (int i = 0; i < n; i++) { |
185 | 0 | int size = 0; |
186 | 0 | while (size < R && final_graph->at(i, size) != EMPTY_ID) { |
187 | 0 | size += 1; |
188 | 0 | } |
189 | 0 | max = std::max(size, max); |
190 | 0 | min = std::min(size, min); |
191 | 0 | avg += size; |
192 | 0 | } |
193 | |
|
194 | 0 | avg = avg / n; |
195 | 0 | printf("Degree Statistics: Max = %d, Min = %d, Avg = %lf\n", |
196 | 0 | max, |
197 | 0 | min, |
198 | 0 | avg); |
199 | 0 | printf("Attached nodes: %d\n", num_attached); |
200 | 0 | } |
201 | 0 | } |
202 | | |
203 | 0 | void NSG::reset() { |
204 | 0 | final_graph.reset(); |
205 | 0 | ntotal = 0; |
206 | 0 | is_built = false; |
207 | 0 | } |
208 | | |
209 | 0 | void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) { |
210 | 0 | int d = storage->d; |
211 | 0 | int n = storage->ntotal; |
212 | |
|
213 | 0 | std::unique_ptr<float[]> center(new float[d]); |
214 | 0 | std::unique_ptr<float[]> tmp(new float[d]); |
215 | 0 | std::fill_n(center.get(), d, 0.0f); |
216 | |
|
217 | 0 | for (int i = 0; i < n; i++) { |
218 | 0 | storage->reconstruct(i, tmp.get()); |
219 | 0 | for (int j = 0; j < d; j++) { |
220 | 0 | center[j] += tmp[j]; |
221 | 0 | } |
222 | 0 | } |
223 | |
|
224 | 0 | for (int i = 0; i < d; i++) { |
225 | 0 | center[i] /= n; |
226 | 0 | } |
227 | |
|
228 | 0 | std::vector<Neighbor> retset; |
229 | 0 | std::vector<Node> tmpset; |
230 | | |
231 | | // random initialize navigating point |
232 | 0 | int ep = rng.rand_int(n); |
233 | 0 | std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage)); |
234 | |
|
235 | 0 | dis->set_query(center.get()); |
236 | 0 | VisitedTable vt(ntotal); |
237 | | |
238 | | // Do not collect the visited nodes |
239 | 0 | search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset); |
240 | | |
241 | | // set enterpoint |
242 | 0 | enterpoint = retset[0].id; |
243 | 0 | } |
244 | | |
245 | | template <bool collect_fullset, class index_t> |
246 | | void NSG::search_on_graph( |
247 | | const nsg::Graph<index_t>& graph, |
248 | | DistanceComputer& dis, |
249 | | VisitedTable& vt, |
250 | | int ep, |
251 | | int pool_size, |
252 | | std::vector<Neighbor>& retset, |
253 | 0 | std::vector<Node>& fullset) const { |
254 | 0 | RandomGenerator gen(0x1234); |
255 | 0 | retset.resize(pool_size + 1); |
256 | 0 | std::vector<int> init_ids(pool_size); |
257 | |
|
258 | 0 | int num_ids = 0; |
259 | 0 | std::vector<index_t> neighbors(graph.K); |
260 | 0 | size_t nneigh = graph.get_neighbors(ep, neighbors.data()); |
261 | 0 | for (int i = 0; i < init_ids.size() && i < nneigh; i++) { |
262 | 0 | int id = (int)neighbors[i]; |
263 | 0 | if (id >= ntotal) { |
264 | 0 | continue; |
265 | 0 | } |
266 | | |
267 | 0 | init_ids[i] = id; |
268 | 0 | vt.set(id); |
269 | 0 | num_ids += 1; |
270 | 0 | } |
271 | |
|
272 | 0 | while (num_ids < pool_size) { |
273 | 0 | int id = gen.rand_int(ntotal); |
274 | 0 | if (vt.get(id)) { |
275 | 0 | continue; |
276 | 0 | } |
277 | | |
278 | 0 | init_ids[num_ids] = id; |
279 | 0 | num_ids++; |
280 | 0 | vt.set(id); |
281 | 0 | } |
282 | |
|
283 | 0 | for (int i = 0; i < init_ids.size(); i++) { |
284 | 0 | int id = init_ids[i]; |
285 | |
|
286 | 0 | float dist = dis(id); |
287 | 0 | retset[i] = Neighbor(id, dist, true); |
288 | |
|
289 | 0 | if (collect_fullset) { |
290 | 0 | fullset.emplace_back(retset[i].id, retset[i].distance); |
291 | 0 | } |
292 | 0 | } |
293 | |
|
294 | 0 | std::sort(retset.begin(), retset.begin() + pool_size); |
295 | |
|
296 | 0 | int k = 0; |
297 | 0 | while (k < pool_size) { |
298 | 0 | int updated_pos = pool_size; |
299 | |
|
300 | 0 | if (retset[k].flag) { |
301 | 0 | retset[k].flag = false; |
302 | 0 | int n = retset[k].id; |
303 | |
|
304 | 0 | size_t nneigh_for_n = graph.get_neighbors(n, neighbors.data()); |
305 | 0 | for (int m = 0; m < nneigh_for_n; m++) { |
306 | 0 | int id = neighbors[m]; |
307 | 0 | if (id > ntotal || vt.get(id)) { |
308 | 0 | continue; |
309 | 0 | } |
310 | 0 | vt.set(id); |
311 | |
|
312 | 0 | float dist = dis(id); |
313 | 0 | Neighbor nn(id, dist, true); |
314 | 0 | if (collect_fullset) { |
315 | 0 | fullset.emplace_back(id, dist); |
316 | 0 | } |
317 | |
|
318 | 0 | if (dist >= retset[pool_size - 1].distance) { |
319 | 0 | continue; |
320 | 0 | } |
321 | | |
322 | 0 | int r = insert_into_pool(retset.data(), pool_size, nn); |
323 | |
|
324 | 0 | updated_pos = std::min(updated_pos, r); |
325 | 0 | } |
326 | 0 | } |
327 | |
|
328 | 0 | k = (updated_pos <= k) ? updated_pos : (k + 1); |
329 | 0 | } |
330 | 0 | } Unexecuted instantiation: _ZNK5faiss3NSG15search_on_graphILb0EiEEvRKNS_3nsg5GraphIT0_EERNS_16DistanceComputerERNS_12VisitedTableEiiRSt6vectorINS2_8NeighborESaISD_EERSC_INS2_4NodeESaISH_EE Unexecuted instantiation: _ZNK5faiss3NSG15search_on_graphILb0ElEEvRKNS_3nsg5GraphIT0_EERNS_16DistanceComputerERNS_12VisitedTableEiiRSt6vectorINS2_8NeighborESaISD_EERSC_INS2_4NodeESaISH_EE Unexecuted instantiation: _ZNK5faiss3NSG15search_on_graphILb1ElEEvRKNS_3nsg5GraphIT0_EERNS_16DistanceComputerERNS_12VisitedTableEiiRSt6vectorINS2_8NeighborESaISD_EERSC_INS2_4NodeESaISH_EE Unexecuted instantiation: _ZNK5faiss3NSG15search_on_graphILb1EiEEvRKNS_3nsg5GraphIT0_EERNS_16DistanceComputerERNS_12VisitedTableEiiRSt6vectorINS2_8NeighborESaISD_EERSC_INS2_4NodeESaISH_EE |
331 | | |
332 | | void NSG::link( |
333 | | Index* storage, |
334 | | const nsg::Graph<idx_t>& knn_graph, |
335 | | nsg::Graph<Node>& graph, |
336 | 0 | bool /* verbose */) { |
337 | 0 | #pragma omp parallel |
338 | 0 | { |
339 | 0 | std::unique_ptr<float[]> vec(new float[storage->d]); |
340 | |
|
341 | 0 | std::vector<Node> pool; |
342 | 0 | std::vector<Neighbor> tmp; |
343 | |
|
344 | 0 | VisitedTable vt(ntotal); |
345 | 0 | std::unique_ptr<DistanceComputer> dis( |
346 | 0 | storage_distance_computer(storage)); |
347 | |
|
348 | 0 | #pragma omp for schedule(dynamic, 100) |
349 | 0 | for (int i = 0; i < ntotal; i++) { |
350 | 0 | storage->reconstruct(i, vec.get()); |
351 | 0 | dis->set_query(vec.get()); |
352 | | |
353 | | // Collect the visited nodes into pool |
354 | 0 | search_on_graph<true>( |
355 | 0 | knn_graph, *dis, vt, enterpoint, L, tmp, pool); |
356 | |
|
357 | 0 | sync_prune(i, pool, *dis, vt, knn_graph, graph); |
358 | |
|
359 | 0 | pool.clear(); |
360 | 0 | tmp.clear(); |
361 | 0 | vt.advance(); |
362 | 0 | } |
363 | 0 | } // omp parallel |
364 | |
|
365 | 0 | std::vector<std::mutex> locks(ntotal); |
366 | 0 | #pragma omp parallel |
367 | 0 | { |
368 | 0 | std::unique_ptr<DistanceComputer> dis( |
369 | 0 | storage_distance_computer(storage)); |
370 | |
|
371 | 0 | #pragma omp for schedule(dynamic, 100) |
372 | 0 | for (int i = 0; i < ntotal; ++i) { |
373 | 0 | add_reverse_links(i, locks, *dis, graph); |
374 | 0 | } |
375 | 0 | } // omp parallel |
376 | 0 | } |
377 | | |
378 | | void NSG::sync_prune( |
379 | | int q, |
380 | | std::vector<Node>& pool, |
381 | | DistanceComputer& dis, |
382 | | VisitedTable& vt, |
383 | | const nsg::Graph<idx_t>& knn_graph, |
384 | 0 | nsg::Graph<Node>& graph) { |
385 | 0 | for (int i = 0; i < knn_graph.K; i++) { |
386 | 0 | int id = knn_graph.at(q, i); |
387 | 0 | if (id < 0 || id >= ntotal || vt.get(id)) { |
388 | 0 | continue; |
389 | 0 | } |
390 | | |
391 | 0 | float dist = dis.symmetric_dis(q, id); |
392 | 0 | pool.emplace_back(id, dist); |
393 | 0 | } |
394 | |
|
395 | 0 | std::sort(pool.begin(), pool.end()); |
396 | |
|
397 | 0 | std::vector<Node> result; |
398 | |
|
399 | 0 | int start = 0; |
400 | 0 | if (pool[start].id == q) { |
401 | 0 | start++; |
402 | 0 | } |
403 | 0 | result.push_back(pool[start]); |
404 | |
|
405 | 0 | while (result.size() < R && (++start) < pool.size() && start < C) { |
406 | 0 | auto& p = pool[start]; |
407 | 0 | bool occlude = false; |
408 | 0 | for (int t = 0; t < result.size(); t++) { |
409 | 0 | if (p.id == result[t].id) { |
410 | 0 | occlude = true; |
411 | 0 | break; |
412 | 0 | } |
413 | 0 | float djk = dis.symmetric_dis(result[t].id, p.id); |
414 | 0 | if (djk < p.distance /* dik */) { |
415 | 0 | occlude = true; |
416 | 0 | break; |
417 | 0 | } |
418 | 0 | } |
419 | 0 | if (!occlude) { |
420 | 0 | result.push_back(p); |
421 | 0 | } |
422 | 0 | } |
423 | |
|
424 | 0 | for (size_t i = 0; i < R; i++) { |
425 | 0 | if (i < result.size()) { |
426 | 0 | graph.at(q, i).id = result[i].id; |
427 | 0 | graph.at(q, i).distance = result[i].distance; |
428 | 0 | } else { |
429 | 0 | graph.at(q, i).id = EMPTY_ID; |
430 | 0 | } |
431 | 0 | } |
432 | 0 | } |
433 | | |
434 | | void NSG::add_reverse_links( |
435 | | int q, |
436 | | std::vector<std::mutex>& locks, |
437 | | DistanceComputer& dis, |
438 | 0 | nsg::Graph<Node>& graph) { |
439 | 0 | for (size_t i = 0; i < R; i++) { |
440 | 0 | if (graph.at(q, i).id == EMPTY_ID) { |
441 | 0 | break; |
442 | 0 | } |
443 | | |
444 | 0 | Node sn(q, graph.at(q, i).distance); |
445 | 0 | int des = graph.at(q, i).id; |
446 | |
|
447 | 0 | std::vector<Node> tmp_pool; |
448 | 0 | int dup = 0; |
449 | 0 | { |
450 | 0 | LockGuard guard(locks[des]); |
451 | 0 | for (int j = 0; j < R; j++) { |
452 | 0 | if (graph.at(des, j).id == EMPTY_ID) { |
453 | 0 | break; |
454 | 0 | } |
455 | 0 | if (q == graph.at(des, j).id) { |
456 | 0 | dup = 1; |
457 | 0 | break; |
458 | 0 | } |
459 | 0 | tmp_pool.push_back(graph.at(des, j)); |
460 | 0 | } |
461 | 0 | } |
462 | |
|
463 | 0 | if (dup) { |
464 | 0 | continue; |
465 | 0 | } |
466 | | |
467 | 0 | tmp_pool.push_back(sn); |
468 | 0 | if (tmp_pool.size() > R) { |
469 | 0 | std::vector<Node> result; |
470 | 0 | int start = 0; |
471 | 0 | std::sort(tmp_pool.begin(), tmp_pool.end()); |
472 | 0 | result.push_back(tmp_pool[start]); |
473 | |
|
474 | 0 | while (result.size() < R && (++start) < tmp_pool.size()) { |
475 | 0 | auto& p = tmp_pool[start]; |
476 | 0 | bool occlude = false; |
477 | |
|
478 | 0 | for (int t = 0; t < result.size(); t++) { |
479 | 0 | if (p.id == result[t].id) { |
480 | 0 | occlude = true; |
481 | 0 | break; |
482 | 0 | } |
483 | 0 | float djk = dis.symmetric_dis(result[t].id, p.id); |
484 | 0 | if (djk < p.distance /* dik */) { |
485 | 0 | occlude = true; |
486 | 0 | break; |
487 | 0 | } |
488 | 0 | } |
489 | |
|
490 | 0 | if (!occlude) { |
491 | 0 | result.push_back(p); |
492 | 0 | } |
493 | 0 | } |
494 | |
|
495 | 0 | { |
496 | 0 | LockGuard guard(locks[des]); |
497 | 0 | for (int t = 0; t < result.size(); t++) { |
498 | 0 | graph.at(des, t) = result[t]; |
499 | 0 | } |
500 | 0 | } |
501 | |
|
502 | 0 | } else { |
503 | 0 | LockGuard guard(locks[des]); |
504 | 0 | for (int t = 0; t < R; t++) { |
505 | 0 | if (graph.at(des, t).id == EMPTY_ID) { |
506 | 0 | graph.at(des, t) = sn; |
507 | 0 | break; |
508 | 0 | } |
509 | 0 | } |
510 | 0 | } |
511 | 0 | } |
512 | 0 | } |
513 | | |
514 | 0 | int NSG::tree_grow(Index* storage, std::vector<int>& degrees) { |
515 | 0 | int root = enterpoint; |
516 | 0 | VisitedTable vt(ntotal); |
517 | 0 | VisitedTable vt2(ntotal); |
518 | |
|
519 | 0 | int num_attached = 0; |
520 | 0 | int cnt = 0; |
521 | 0 | while (true) { |
522 | 0 | cnt = dfs(vt, root, cnt); |
523 | 0 | if (cnt >= ntotal) { |
524 | 0 | break; |
525 | 0 | } |
526 | | |
527 | 0 | root = attach_unlinked(storage, vt, vt2, degrees); |
528 | 0 | vt2.advance(); |
529 | 0 | num_attached += 1; |
530 | 0 | } |
531 | |
|
532 | 0 | return num_attached; |
533 | 0 | } |
534 | | |
535 | 0 | int NSG::dfs(VisitedTable& vt, int root, int cnt) const { |
536 | 0 | int node = root; |
537 | 0 | std::stack<int> stack; |
538 | 0 | stack.push(root); |
539 | |
|
540 | 0 | if (!vt.get(root)) { |
541 | 0 | cnt++; |
542 | 0 | } |
543 | 0 | vt.set(root); |
544 | |
|
545 | 0 | while (!stack.empty()) { |
546 | 0 | int next = EMPTY_ID; |
547 | 0 | for (int i = 0; i < R; i++) { |
548 | 0 | int id = final_graph->at(node, i); |
549 | 0 | if (id != EMPTY_ID && !vt.get(id)) { |
550 | 0 | next = id; |
551 | 0 | break; |
552 | 0 | } |
553 | 0 | } |
554 | |
|
555 | 0 | if (next == EMPTY_ID) { |
556 | 0 | stack.pop(); |
557 | 0 | if (stack.empty()) { |
558 | 0 | break; |
559 | 0 | } |
560 | 0 | node = stack.top(); |
561 | 0 | continue; |
562 | 0 | } |
563 | 0 | node = next; |
564 | 0 | vt.set(node); |
565 | 0 | stack.push(node); |
566 | 0 | cnt++; |
567 | 0 | } |
568 | |
|
569 | 0 | return cnt; |
570 | 0 | } |
571 | | |
572 | | int NSG::attach_unlinked( |
573 | | Index* storage, |
574 | | VisitedTable& vt, |
575 | | VisitedTable& vt2, |
576 | 0 | std::vector<int>& degrees) { |
577 | | /* NOTE: This implementation is slightly different from the original paper. |
578 | | * |
579 | | * Instead of connecting the unlinked node to the nearest point in the |
580 | | * spanning tree which will increase the maximum degree of the graph and |
581 | | * also make the graph hard to maintain, this implementation links the |
582 | | * unlinked node to the nearest node of which the degree is smaller than R. |
583 | | * It will keep the degree of all nodes to be no more than `R`. |
584 | | */ |
585 | | |
586 | | // find one unlinked node |
587 | 0 | int id = EMPTY_ID; |
588 | 0 | for (int i = 0; i < ntotal; i++) { |
589 | 0 | if (!vt.get(i)) { |
590 | 0 | id = i; |
591 | 0 | break; |
592 | 0 | } |
593 | 0 | } |
594 | |
|
595 | 0 | if (id == EMPTY_ID) { |
596 | 0 | return EMPTY_ID; // No Unlinked Node |
597 | 0 | } |
598 | | |
599 | 0 | std::vector<Neighbor> tmp; |
600 | 0 | std::vector<Node> pool; |
601 | |
|
602 | 0 | std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage)); |
603 | 0 | std::unique_ptr<float[]> vec(new float[storage->d]); |
604 | |
|
605 | 0 | storage->reconstruct(id, vec.get()); |
606 | 0 | dis->set_query(vec.get()); |
607 | | |
608 | | // Collect the visited nodes into pool |
609 | 0 | search_on_graph<true>( |
610 | 0 | *final_graph, *dis, vt2, enterpoint, search_L, tmp, pool); |
611 | |
|
612 | 0 | std::sort(pool.begin(), pool.end()); |
613 | |
|
614 | 0 | int node; |
615 | 0 | bool found = false; |
616 | 0 | for (int i = 0; i < pool.size(); i++) { |
617 | 0 | node = pool[i].id; |
618 | 0 | if (degrees[node] < R && node != id) { |
619 | 0 | found = true; |
620 | 0 | break; |
621 | 0 | } |
622 | 0 | } |
623 | | |
624 | | // randomly choice annother node |
625 | 0 | if (!found) { |
626 | 0 | do { |
627 | 0 | node = rng.rand_int(ntotal); |
628 | 0 | if (vt.get(node) && degrees[node] < R && node != id) { |
629 | 0 | found = true; |
630 | 0 | } |
631 | 0 | } while (!found); |
632 | 0 | } |
633 | |
|
634 | 0 | int pos = degrees[node]; |
635 | 0 | final_graph->at(node, pos) = id; // replace |
636 | 0 | degrees[node] += 1; |
637 | |
|
638 | 0 | return node; |
639 | 0 | } |
640 | | |
641 | 0 | void NSG::check_graph() const { |
642 | 0 | #pragma omp parallel for |
643 | 0 | for (int i = 0; i < ntotal; i++) { |
644 | 0 | for (int j = 0; j < R; j++) { |
645 | 0 | int id = final_graph->at(i, j); |
646 | 0 | FAISS_THROW_IF_NOT(id < ntotal && (id >= 0 || id == EMPTY_ID)); |
647 | 0 | } |
648 | 0 | } |
649 | 0 | } |
650 | | |
651 | | } // namespace faiss |