/root/doris/contrib/faiss/faiss/impl/HNSW.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/HNSW.h> |
9 | | |
10 | | #include <cstddef> |
11 | | |
12 | | #include <faiss/impl/AuxIndexStructures.h> |
13 | | #include <faiss/impl/DistanceComputer.h> |
14 | | #include <faiss/impl/IDSelector.h> |
15 | | #include <faiss/impl/ResultHandler.h> |
16 | | #include <faiss/utils/prefetch.h> |
17 | | |
18 | | #include <faiss/impl/platform_macros.h> |
19 | | |
20 | | #ifdef __AVX2__ |
21 | | #include <immintrin.h> |
22 | | |
23 | | #include <limits> |
24 | | #include <type_traits> |
25 | | #endif |
26 | | |
27 | | namespace faiss { |
28 | | |
29 | | /************************************************************** |
30 | | * HNSW structure implementation |
31 | | **************************************************************/ |
32 | | |
33 | 28.2k | int HNSW::nb_neighbors(int layer_no) const { // Number of neighbor slots reserved for layer_no.
34 | 28.2k | FAISS_THROW_IF_NOT(layer_no + 1 < cum_nneighbor_per_level.size()); // layer_no + 1 must be a valid index into the cumulative table
35 | 28.2k | return cum_nneighbor_per_level[layer_no + 1] - // difference of two consecutive cumulative counts
36 | 28.2k | cum_nneighbor_per_level[layer_no];
37 | 28.2k | }
38 | | |
39 | 0 | void HNSW::set_nb_neighbors(int level_no, int n) { // Sets the slot count of level_no to n by adjusting the cumulative table.
40 | 0 | FAISS_THROW_IF_NOT(levels.size() == 0); // only permitted while `levels` is empty
41 | 0 | int cur_n = nb_neighbors(level_no);
42 | 0 | for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) { // every cumulative entry after level_no shifts by the same delta
43 | 0 | cum_nneighbor_per_level[i] += n - cur_n;
44 | 0 | }
45 | 0 | }
46 | | |
47 | 4.74M | int HNSW::cum_nb_neighbors(int layer_no) const { // Cumulative neighbor-slot count stored at index layer_no.
48 | 4.74M | return cum_nneighbor_per_level[layer_no]; // no bounds check is performed here
49 | 4.74M | }
50 | | |
51 | | void HNSW::neighbor_range(idx_t no, int layer_no, size_t* begin, size_t* end) // Writes to *begin / *end the slot range of vertex `no` at layer_no; callers in this file iterate it as [begin, end) over `neighbors`.
52 | 2.36M | const {
53 | 2.36M | size_t o = offsets[no]; // start of this vertex's storage
54 | 2.36M | *begin = o + cum_nb_neighbors(layer_no);
55 | 2.36M | *end = o + cum_nb_neighbors(layer_no + 1);
56 | 2.36M | }
57 | | |
58 | 124 | HNSW::HNSW(int M) : rng(12345) { // Builds an empty structure; rng is constructed with the fixed value 12345.
59 | 124 | set_default_probas(M, 1.0 / log(M)); // NOTE(review): 1.0 / log(M) is not finite for M == 1 -- confirm callers never pass M <= 1
60 | 124 | offsets.push_back(0); // offsets always holds one more entry than there are vertices
61 | 124 | }
62 | | |
63 | 27.0k | int HNSW::random_level() { // Draws a level index by walking assign_probas with a single random float.
64 | 27.0k | double f = rng.rand_float();
65 | | // could be a bit faster with bisection
66 | 28.4k | for (int level = 0; level < assign_probas.size(); level++) {
67 | 28.4k | if (f < assign_probas[level]) {
68 | 27.0k | return level;
69 | 27.0k | }
70 | 1.39k | f -= assign_probas[level]; // subtract this level's mass and try the next level
71 | 1.39k | }
72 | | // happens with exponentially low probability
73 | 0 | return assign_probas.size() - 1; // fall back to the last level of the table
74 | 27.0k | }
75 | | |
76 | 124 | void HNSW::set_default_probas(int M, float levelMult) { // Appends per-level assignment probabilities and cumulative neighbor counts.
77 | 124 | int nn = 0; // running total of neighbor slots
78 | 124 | cum_nneighbor_per_level.push_back(0);
79 | 993 | for (int level = 0;; level++) {
80 | 993 | float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
81 | 993 | if (proba < 1e-9) // stop once a level's probability is negligible
82 | 124 | break;
83 | 869 | assign_probas.push_back(proba);
84 | 869 | nn += level == 0 ? M * 2 : M; // level 0 gets 2 * M slots, every other level M
85 | 869 | cum_nneighbor_per_level.push_back(nn);
86 | 869 | }
87 | 124 | }
88 | | |
89 | 0 | void HNSW::clear_neighbor_tables(int level) { // Sets every neighbor slot of `level` to -1 for each entry of `levels`.
90 | 0 | for (int i = 0; i < levels.size(); i++) {
91 | 0 | size_t begin, end;
92 | 0 | neighbor_range(i, level, &begin, &end); // NOTE(review): no levels[i] > level check here -- confirm every vertex has storage for this level
93 | 0 | for (size_t j = begin; j < end; j++) {
94 | 0 | neighbors[j] = -1; // -1 marks an empty slot
95 | 0 | }
96 | 0 | }
97 | 0 | }
98 | | |
99 | 0 | void HNSW::reset() { // Returns the graph to its empty state; probability and cumulative tables are not touched.
100 | 0 | max_level = -1;
101 | 0 | entry_point = -1; // -1 means there is no entry point
102 | 0 | offsets.clear();
103 | 0 | offsets.push_back(0); // same initial state the constructor sets up
104 | 0 | levels.clear();
105 | 0 | neighbors.clear();
106 | 0 | }
107 | | |
108 | 0 | void HNSW::print_neighbor_stats(int level) const { // Prints link statistics for one level to stdout.
109 | 0 | FAISS_THROW_IF_NOT(level < cum_nneighbor_per_level.size());
110 | 0 | printf("stats on level %d, max %d neighbors per vertex:\n",
111 | 0 | level,
112 | 0 | nb_neighbors(level));
113 | 0 | size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; // all four are OpenMP reduction variables below
114 | 0 | #pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
115 | 0 | reduction(+ : tot_reciprocal) reduction(+ : n_node)
116 | 0 | for (int i = 0; i < levels.size(); i++) {
117 | 0 | if (levels[i] > level) { // only vertices with levels[i] > level are counted
118 | 0 | n_node++;
119 | 0 | size_t begin, end;
120 | 0 | neighbor_range(i, level, &begin, &end);
121 | 0 | std::unordered_set<int> neighset; // distinct neighbor ids of vertex i at this level
122 | 0 | for (size_t j = begin; j < end; j++) {
123 | 0 | if (neighbors[j] < 0) // a negative id ends the scan of this list
124 | 0 | break;
125 | 0 | neighset.insert(neighbors[j]);
126 | 0 | }
127 | 0 | int n_neigh = neighset.size();
128 | 0 | int n_common = 0;
129 | 0 | int n_reciprocal = 0;
130 | 0 | for (size_t j = begin; j < end; j++) { // second pass: inspect the neighbor lists of i's neighbors
131 | 0 | storage_idx_t i2 = neighbors[j];
132 | 0 | if (i2 < 0)
133 | 0 | break;
134 | 0 | FAISS_ASSERT(i2 != i); // a vertex must not list itself
135 | 0 | size_t begin2, end2;
136 | 0 | neighbor_range(i2, level, &begin2, &end2);
137 | 0 | for (size_t j2 = begin2; j2 < end2; j2++) {
138 | 0 | storage_idx_t i3 = neighbors[j2];
139 | 0 | if (i3 < 0)
140 | 0 | break;
141 | 0 | if (i3 == i) { // i2 links back to i
142 | 0 | n_reciprocal++;
143 | 0 | continue;
144 | 0 | }
145 | 0 | if (neighset.count(i3)) {
146 | 0 | neighset.erase(i3); // erase so each shared neighbor is counted once
147 | 0 | n_common++;
148 | 0 | }
149 | 0 | }
150 | 0 | }
151 | 0 | tot_neigh += n_neigh;
152 | 0 | tot_common += n_common;
153 | 0 | tot_reciprocal += n_reciprocal;
154 | 0 | }
155 | 0 | }
156 | 0 | float normalizer = n_node; // NOTE(review): n_node is 0 when no vertex has levels[i] > level, so the averages below become 0/0 -- confirm acceptable
157 | 0 | printf(" nb of nodes at that level %zd\n", n_node); // NOTE(review): %zd is paired with size_t arguments in these printf calls; %zu is the size_t conversion
158 | 0 | printf(" neighbors per node: %.2f (%zd)\n",
159 | 0 | tot_neigh / normalizer,
160 | 0 | tot_neigh);
161 | 0 | printf(" nb of reciprocal neighbors: %.2f\n",
162 | 0 | tot_reciprocal / normalizer);
163 | 0 | printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
164 | 0 | tot_common / normalizer,
165 | 0 | tot_common);
166 | 0 | }
167 | | |
168 | 0 | void HNSW::fill_with_random_links(size_t n) { // Allocates n vertices and fills neighbor slots with randomly drawn vertex ids.
169 | 0 | int max_level_2 = prepare_level_tab(n);
170 | 0 | RandomGenerator rng2(456); // separate generator constructed with the fixed value 456
171 | 

172 | 0 | for (int level = max_level_2 - 1; level >= 0; --level) { // NOTE(review): starts one below the value returned by prepare_level_tab -- confirm the top level is skipped on purpose
173 | 0 | std::vector<int> elts; // vertices with levels[i] > level
174 | 0 | for (int i = 0; i < n; i++) {
175 | 0 | if (levels[i] > level) {
176 | 0 | elts.push_back(i);
177 | 0 | }
178 | 0 | }
179 | 0 | printf("linking %zd elements in level %d\n", elts.size(), level);
180 | 

181 | 0 | if (elts.size() == 1) // a lone vertex has nobody to link to
182 | 0 | continue;
183 | |
184 | 0 | for (int ii = 0; ii < elts.size(); ii++) {
185 | 0 | int i = elts[ii];
186 | 0 | size_t begin, end;
187 | 0 | neighbor_range(i, 0, &begin, &end); // NOTE(review): range is taken for level 0 on every iteration although the loop variable is `level` -- confirm intended
188 | 0 | for (size_t j = begin; j < end; j++) {
189 | 0 | int other = 0;
190 | 0 | do { // redraw until the pick differs from i (no self links)
191 | 0 | other = elts[rng2.rand_int(elts.size())];
192 | 0 | } while (other == i);
193 | 

194 | 0 | neighbors[j] = other;
195 | 0 | }
196 | 0 | }
197 | 0 | }
198 | 0 | }
199 | | |
200 | 18.7k | int HNSW::prepare_level_tab(size_t n, bool preset_levels) { // Extends levels/offsets/neighbors for n new vertices; returns the largest (levels[i] - 1) among them, or 0.
201 | 18.7k | size_t n0 = offsets.size() - 1; // number of vertices already present
202 | |
203 | 18.7k | if (preset_levels) {
204 | 0 | FAISS_ASSERT(n0 + n == levels.size()); // caller has already appended the n levels
205 | 18.7k | } else {
206 | 18.7k | FAISS_ASSERT(n0 == levels.size());
207 | 45.7k | for (int i = 0; i < n; i++) {
208 | 27.0k | int pt_level = random_level();
209 | 27.0k | levels.push_back(pt_level + 1); // levels stores the drawn level index plus one
210 | 27.0k | }
211 | 18.7k | }
212 | |
213 | 18.7k | int max_level_2 = 0;
214 | 45.7k | for (int i = 0; i < n; i++) {
215 | 27.0k | int pt_level = levels[i + n0] - 1;
216 | 27.0k | if (pt_level > max_level_2)
217 | 1.21k | max_level_2 = pt_level;
218 | 27.0k | offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1)); // reserve slots for levels 0..pt_level of this vertex
219 | 27.0k | }
220 | 18.7k | neighbors.resize(offsets.back(), -1); // new slots start out empty (-1)
221 | |
222 | 18.7k | return max_level_2;
223 | 18.7k | }
224 | | |
225 | | /** Enumerate vertices from nearest to farthest from query, keep a
226 | | * neighbor only if there is no previous neighbor that is closer to
227 | | * that vertex than the query.
228 | | */
229 | | void HNSW::shrink_neighbor_list( // Drains `input`, appends kept entries to `output`, stops once output holds max_size entries.
230 | | DistanceComputer& qdis,
231 | | std::priority_queue<NodeDistFarther>& input,
232 | | std::vector<NodeDistFarther>& output,
233 | | int max_size,
234 | 38.0k | bool keep_max_size_level0) {
235 | | // This prevents number of neighbors at
236 | | // level 0 from being shrunk to less than 2 * M.
237 | | // This is essential in making sure
238 | | // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
239 | 38.0k | std::vector<NodeDistFarther> outsiders; // rejected entries, only collected when keep_max_size_level0 is set
240 | |
241 | 1.33M | while (input.size() > 0) {
242 | 1.31M | NodeDistFarther v1 = input.top();
243 | 1.31M | input.pop();
244 | 1.31M | float dist_v1_q = v1.d; // distance of v1 as stored in the queue entry
245 | |
246 | 1.31M | bool good = true;
247 | 7.26M | for (NodeDistFarther v2 : output) { // compare v1 against every entry kept so far
248 | 7.26M | float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
249 | |
250 | 7.26M | if (dist_v1_v2 < dist_v1_q) { // v1 is closer to a kept entry than its stored distance: reject
251 | 748k | good = false;
252 | 748k | break;
253 | 748k | }
254 | 7.26M | }
255 | |
256 | 1.31M | if (good) {
257 | 564k | output.push_back(v1);
258 | 564k | if (output.size() >= max_size) { // full: return without draining the rest of input
259 | 11.8k | return;
260 | 11.8k | }
261 | 748k | } else if (keep_max_size_level0) {
262 | 0 | outsiders.push_back(v1);
263 | 0 | }
264 | 1.31M | }
265 | 26.2k | size_t idx = 0;
266 | 26.2k | while (keep_max_size_level0 && (output.size() < max_size) && // top up with rejected entries, in rejection order
267 | 26.2k | (idx < outsiders.size())) {
268 | 0 | output.push_back(outsiders[idx++]);
269 | 0 | }
270 | 26.2k | }
271 | | |
272 | | namespace { |
273 | | |
274 | | using storage_idx_t = HNSW::storage_idx_t; |
275 | | using NodeDistCloser = HNSW::NodeDistCloser; |
276 | | using NodeDistFarther = HNSW::NodeDistFarther; |
277 | | |
278 | | /************************************************************** |
279 | | * Addition subroutines |
280 | | **************************************************************/ |
281 | | |
282 | | /// remove neighbors from the list to make it smaller than max_size
283 | | void shrink_neighbor_list( // Queue-to-queue wrapper around HNSW::shrink_neighbor_list; resultSet1 is rewritten in place.
284 | | DistanceComputer& qdis,
285 | | std::priority_queue<NodeDistCloser>& resultSet1,
286 | | int max_size,
287 | 51.6k | bool keep_max_size_level0 = false) {
288 | 51.6k | if (resultSet1.size() < max_size) { // already below max_size: left untouched
289 | 13.5k | return;
290 | 13.5k | }
291 | 38.0k | std::priority_queue<NodeDistFarther> resultSet;
292 | 38.0k | std::vector<NodeDistFarther> returnlist;
293 | |
294 | 1.36M | while (resultSet1.size() > 0) { // move every entry into a NodeDistFarther queue, emptying resultSet1
295 | 1.32M | resultSet.emplace(resultSet1.top().d, resultSet1.top().id);
296 | 1.32M | resultSet1.pop();
297 | 1.32M | }
298 | |
299 | 38.0k | HNSW::shrink_neighbor_list(
300 | 38.0k | qdis, resultSet, returnlist, max_size, keep_max_size_level0);
301 | |
302 | 565k | for (NodeDistFarther curen2 : returnlist) { // refill resultSet1 with the kept entries only
303 | 565k | resultSet1.emplace(curen2.d, curen2.id);
304 | 565k | }
305 | 38.0k | }
306 | | |
307 | | /// add a link between two elements, possibly shrinking the list
308 | | /// of links to make room for it.
309 | | void add_link( // Adds dest to the neighbor list of src at `level`.
310 | | HNSW& hnsw,
311 | | DistanceComputer& qdis,
312 | | storage_idx_t src,
313 | | storage_idx_t dest,
314 | | int level,
315 | 1.25M | bool keep_max_size_level0 = false) {
316 | 1.25M | size_t begin, end;
317 | 1.25M | hnsw.neighbor_range(src, level, &begin, &end);
318 | 1.25M | if (hnsw.neighbors[end - 1] == -1) { // last slot empty means the list is not full
319 | | // there is enough room, find a slot to add it
320 | 1.22M | size_t i = end;
321 | 63.0M | while (i > begin) { // walk back over the trailing -1 slots
322 | 63.0M | if (hnsw.neighbors[i - 1] != -1)
323 | 1.19M | break;
324 | 61.8M | i--;
325 | 61.8M | }
326 | 1.22M | hnsw.neighbors[i] = dest; // i is the first slot of the trailing run of -1
327 | 1.22M | return;
328 | 1.22M | }
329 | |
330 | | // otherwise we let them fight out which to keep
331 | |
332 | | // copy to resultSet...
333 | 30.3k | std::priority_queue<NodeDistCloser> resultSet;
334 | 30.3k | resultSet.emplace(qdis.symmetric_dis(src, dest), dest); // the new candidate competes with the existing neighbors
335 | 760k | for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
336 | 730k | storage_idx_t neigh = hnsw.neighbors[i];
337 | 730k | resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
338 | 730k | }
339 | |
340 | 30.3k | shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0); // cap at the number of slots of this list
341 | |
342 | | // ...and back
343 | 30.3k | size_t i = begin;
344 | 433k | while (resultSet.size()) {
345 | 403k | hnsw.neighbors[i++] = resultSet.top().id;
346 | 403k | resultSet.pop();
347 | 403k | }
348 | | // they may have shrunk more than just by 1 element
349 | 358k | while (i < end) { // mark the remaining slots empty
350 | 328k | hnsw.neighbors[i++] = -1;
351 | 328k | }
352 | 30.3k | }
353 | | |
354 | | } // namespace |
355 | | |
356 | | /// search neighbors on a single level, starting from an entry point
357 | | void search_neighbors_to_add( // Fills `results` with candidates found at `level`; its size is capped at hnsw.efConstruction during the search.
358 | | HNSW& hnsw,
359 | | DistanceComputer& qdis,
360 | | std::priority_queue<NodeDistCloser>& results,
361 | | int entry_point,
362 | | float d_entry_point,
363 | | int level,
364 | | VisitedTable& vt,
365 | 28.2k | bool reference_version) {
366 | | // top is nearest candidate
367 | 28.2k | std::priority_queue<NodeDistFarther> candidates;
368 | |
369 | 28.2k | NodeDistFarther ev(d_entry_point, entry_point);
370 | 28.2k | candidates.push(ev);
371 | 28.2k | results.emplace(d_entry_point, entry_point); // the entry point is itself a result
372 | 28.2k | vt.set(entry_point);
373 | |
374 | 1.08M | while (!candidates.empty()) {
375 | | // get nearest
376 | 1.07M | const NodeDistFarther& currEv = candidates.top();
377 | |
378 | 1.07M | if (currEv.d > results.top().d) { // stop when the nearest candidate is farther than results.top() (presumably the farthest kept result -- confirm NodeDistCloser ordering)
379 | 24.2k | break;
380 | 24.2k | }
381 | 1.05M | int currNode = currEv.id; // copy the id out before pop() invalidates the reference
382 | 1.05M | candidates.pop();
383 | |
384 | | // loop over neighbors
385 | 1.05M | size_t begin, end;
386 | 1.05M | hnsw.neighbor_range(currNode, level, &begin, &end);
387 | |
388 | | // The reference version is not used, but kept here because:
389 | | // 1. It is easier to switch back if the optimized version has a problem
390 | | // 2. It serves as a starting point for new optimizations
391 | | // 3. It helps understand the code
392 | | // 4. It ensures the reference version is still compilable if the
393 | | // optimized version changes
394 | | // The reference and the optimized versions' results are compared in
395 | | // test_hnsw.cpp
396 | 1.05M | if (reference_version) {
397 | | // a reference version
398 | 0 | for (size_t i = begin; i < end; i++) {
399 | 0 | storage_idx_t nodeId = hnsw.neighbors[i];
400 | 0 | if (nodeId < 0) // a negative id ends the neighbor list
401 | 0 | break;
402 | 0 | if (vt.get(nodeId)) // already visited
403 | 0 | continue;
404 | 0 | vt.set(nodeId);
405 | 

406 | 0 | float dis = qdis(nodeId);
407 | 0 | NodeDistFarther evE1(dis, nodeId);
408 | 

409 | 0 | if (results.size() < hnsw.efConstruction ||
410 | 0 | results.top().d > dis) {
411 | 0 | results.emplace(dis, nodeId);
412 | 0 | candidates.emplace(dis, nodeId);
413 | 0 | if (results.size() > hnsw.efConstruction) {
414 | 0 | results.pop();
415 | 0 | }
416 | 0 | }
417 | 0 | }
418 | 1.05M | } else {
419 | | // a faster version
420 | |
421 | | // the following version processes 4 neighbors at a time
422 | 1.05M | auto update_with_candidate = [&](const storage_idx_t idx, // same acceptance rule as the reference branch above
423 | 7.34M | const float dis) {
424 | 7.34M | if (results.size() < hnsw.efConstruction ||
425 | 7.34M | results.top().d > dis) {
426 | 2.47M | results.emplace(dis, idx);
427 | 2.47M | candidates.emplace(dis, idx);
428 | 2.47M | if (results.size() > hnsw.efConstruction) { // keep results at no more than efConstruction entries
429 | 1.46M | results.pop();
430 | 1.46M | }
431 | 2.47M | }
432 | 7.34M | };
433 | |
434 | 1.05M | int n_buffered = 0;
435 | 1.05M | storage_idx_t buffered_ids[4]; // unvisited neighbor ids waiting for a batched distance call
436 | |
437 | 33.2M | for (size_t j = begin; j < end; j++) {
438 | 33.1M | storage_idx_t nodeId = hnsw.neighbors[j];
439 | 33.1M | if (nodeId < 0)
440 | 963k | break;
441 | 32.1M | if (vt.get(nodeId)) {
442 | 25.2M | continue;
443 | 25.2M | }
444 | 6.94M | vt.set(nodeId);
445 | |
446 | 6.94M | buffered_ids[n_buffered] = nodeId;
447 | 6.94M | n_buffered += 1;
448 | |
449 | 6.94M | if (n_buffered == 4) { // buffer full: compute the four distances in one call
450 | 1.57M | float dis[4];
451 | 1.57M | qdis.distances_batch_4(
452 | 1.57M | buffered_ids[0],
453 | 1.57M | buffered_ids[1],
454 | 1.57M | buffered_ids[2],
455 | 1.57M | buffered_ids[3],
456 | 1.57M | dis[0],
457 | 1.57M | dis[1],
458 | 1.57M | dis[2],
459 | 1.57M | dis[3]);
460 | |
461 | 7.86M | for (size_t id4 = 0; id4 < 4; id4++) {
462 | 6.28M | update_with_candidate(buffered_ids[id4], dis[id4]);
463 | 6.28M | }
464 | |
465 | 1.57M | n_buffered = 0;
466 | 1.57M | }
467 | 6.94M | }
468 | |
469 | | // process leftovers
470 | 2.12M | for (size_t icnt = 0; icnt < n_buffered; icnt++) { // fewer than 4 ids remain: one distance call each
471 | 1.07M | float dis = qdis(buffered_ids[icnt]);
472 | 1.07M | update_with_candidate(buffered_ids[icnt], dis);
473 | 1.07M | }
474 | 1.05M | }
475 | 1.05M | }
476 | |
477 | 28.2k | vt.advance(); // called once per search, after the loop
478 | 28.2k | }
479 | | |
480 | | /// Finds neighbors and builds links with them, starting from an entry
481 | | /// point. The own neighbor list is assumed to be locked.
482 | | void HNSW::add_links_starting_from( // Links pt_id with selected neighbors at `level`, in both directions.
483 | | DistanceComputer& ptdis,
484 | | storage_idx_t pt_id,
485 | | storage_idx_t nearest,
486 | | float d_nearest,
487 | | int level,
488 | | omp_lock_t* locks,
489 | | VisitedTable& vt,
490 | 28.2k | bool keep_max_size_level0) {
491 | 28.2k | std::priority_queue<NodeDistCloser> link_targets;
492 | |
493 | 28.2k | search_neighbors_to_add(
494 | 28.2k | *this, ptdis, link_targets, nearest, d_nearest, level, vt);
495 | |
496 | | // but we can afford only this many neighbors
497 | 28.2k | int M = nb_neighbors(level);
498 | |
499 | 28.2k | ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
500 | |
501 | 28.2k | std::vector<storage_idx_t> neighbors_to_add; // remembered so the reverse links can be added afterwards
502 | 28.2k | neighbors_to_add.reserve(link_targets.size());
503 | 658k | while (!link_targets.empty()) { // forward links pt_id -> other_id
504 | 630k | storage_idx_t other_id = link_targets.top().id;
505 | 630k | add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
506 | 630k | neighbors_to_add.push_back(other_id);
507 | 630k | link_targets.pop();
508 | 630k | }
509 | |
510 | 28.2k | omp_unset_lock(&locks[pt_id]); // pt_id's lock is released while neighbor locks are taken one at a time
511 | 630k | for (storage_idx_t other_id : neighbors_to_add) { // reverse links other_id -> pt_id
512 | 630k | omp_set_lock(&locks[other_id]);
513 | 630k | add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
514 | 630k | omp_unset_lock(&locks[other_id]);
515 | 630k | }
516 | 28.2k | omp_set_lock(&locks[pt_id]); // re-acquired so the function returns with pt_id's lock held again
517 | 28.2k | }
518 | | |
519 | | /************************************************************** |
520 | | * Building, parallel |
521 | | **************************************************************/ |
522 | | |
523 | | void HNSW::add_with_locks( // Inserts vertex pt_id with top level pt_level, using per-vertex OpenMP locks.
524 | | DistanceComputer& ptdis,
525 | | int pt_level,
526 | | int pt_id,
527 | | std::vector<omp_lock_t>& locks,
528 | | VisitedTable& vt,
529 | 27.0k | bool keep_max_size_level0) {
530 | | // greedy search on upper levels
531 | |
532 | 27.0k | storage_idx_t nearest;
533 | 27.0k | #pragma omp critical
534 | 27.0k | {
535 | 27.0k | nearest = entry_point;
536 | |
537 | 27.0k | if (nearest == -1) { // graph is empty: this vertex becomes the entry point
538 | 82 | max_level = pt_level;
539 | 82 | entry_point = pt_id;
540 | 82 | }
541 | 27.0k | }
542 | |
543 | 27.0k | if (nearest < 0) { // first vertex has nothing to link to
544 | 82 | return;
545 | 82 | }
546 | |
547 | 26.9k | omp_set_lock(&locks[pt_id]);
548 | |
549 | 26.9k | int level = max_level; // level at which we start adding neighbors
550 | 26.9k | float d_nearest = ptdis(nearest);
551 | |
552 | 59.6k | for (; level > pt_level; level--) { // levels above pt_level: only move `nearest` closer, no links added
553 | 32.7k | greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
554 | 32.7k | }
555 | |
556 | 55.1k | for (; level >= 0; level--) { // remaining levels down to 0: add links
557 | 28.2k | add_links_starting_from(
558 | 28.2k | ptdis,
559 | 28.2k | pt_id,
560 | 28.2k | nearest,
561 | 28.2k | d_nearest,
562 | 28.2k | level,
563 | 28.2k | locks.data(),
564 | 28.2k | vt,
565 | 28.2k | keep_max_size_level0);
566 | 28.2k | }
567 | |
568 | 26.9k | omp_unset_lock(&locks[pt_id]);
569 | |
570 | 26.9k | if (pt_level > max_level) { // NOTE(review): this update of max_level / entry_point is outside the omp critical section above -- confirm it is safe
571 | 56 | max_level = pt_level;
572 | 56 | entry_point = pt_id;
573 | 56 | }
574 | 26.9k | }
575 | | |
576 | | /************************************************************** |
577 | | * Searching |
578 | | **************************************************************/ |
579 | | |
580 | | using MinimaxHeap = HNSW::MinimaxHeap; |
581 | | using Node = HNSW::Node; |
582 | | using C = HNSW::C; |
583 | | /** Do a BFS on the candidates list */
584 | | int search_from_candidates( // Expands `candidates` at `level`, reporting hits to `res`; returns nres_in plus the number of accepted results added during expansion.
585 | | const HNSW& hnsw,
586 | | DistanceComputer& qdis,
587 | | ResultHandler<C>& res,
588 | | MinimaxHeap& candidates,
589 | | VisitedTable& vt,
590 | | HNSWStats& stats,
591 | | int level,
592 | | int nres_in,
593 | 97 | const SearchParameters* params) {
594 | 97 | int nres = nres_in;
595 | 97 | int ndis = 0; // number of distance computations in this call
596 | |
597 | | // can be overridden by search params
598 | 97 | bool do_dis_check = hnsw.check_relative_distance;
599 | 97 | int efSearch = hnsw.efSearch;
600 | 97 | const IDSelector* sel = nullptr;
601 | 97 | if (params) {
602 | 87 | if (const SearchParametersHNSW* hnsw_params = // overrides apply only when params is a SearchParametersHNSW
603 | 87 | dynamic_cast<const SearchParametersHNSW*>(params)) {
604 | 87 | do_dis_check = hnsw_params->check_relative_distance;
605 | 87 | efSearch = hnsw_params->efSearch;
606 | 87 | }
607 | 87 | sel = params->sel; // the selector is taken from any SearchParameters
608 | 87 | }
609 | |
610 | 97 | C::T threshold = res.threshold;
611 | 194 | for (int i = 0; i < candidates.size(); i++) { // report the initial candidates and mark them visited
612 | 97 | idx_t v1 = candidates.ids[i];
613 | 97 | float d = candidates.dis[i];
614 | 97 | FAISS_ASSERT(v1 >= 0);
615 | 97 | if (!sel || sel->is_member(v1)) {
616 | 94 | if (d < threshold) {
617 | 81 | if (res.add_result(d, v1)) { // nres is not incremented for these initial results
618 | 32 | threshold = res.threshold;
619 | 32 | }
620 | 81 | }
621 | 94 | }
622 | 97 | vt.set(v1);
623 | 97 | }
624 | |
625 | 97 | int nstep = 0; // number of expanded candidates
626 | |
627 | 2.77k | while (candidates.size() > 0) {
628 | 2.68k | float d0 = 0;
629 | 2.68k | int v0 = candidates.pop_min(&d0);
630 | |
631 | 2.68k | if (do_dis_check) {
632 | | // tricky stopping condition: there are more than ef
633 | | // distances that are processed already that are smaller
634 | | // than d0
635 | |
636 | 2.68k | int n_dis_below = candidates.count_below(d0);
637 | 2.68k | if (n_dis_below >= efSearch) {
638 | 13 | break;
639 | 13 | }
640 | 2.68k | }
641 | |
642 | 2.67k | size_t begin, end;
643 | 2.67k | hnsw.neighbor_range(v0, level, &begin, &end);
644 | |
645 | | // a faster version: reference version in unit test test_hnsw.cpp
646 | | // the following version processes 4 neighbors at a time
647 | 2.67k | size_t jmax = begin; // after this loop, [begin, jmax) are the slots before the first negative id
648 | 121k | for (size_t j = begin; j < end; j++) {
649 | 121k | int v1 = hnsw.neighbors[j];
650 | 121k | if (v1 < 0)
651 | 2.54k | break;
652 | |
653 | 118k | prefetch_L2(vt.visited.data() + v1); // prefetch the visited entry read in the next loop
654 | 118k | jmax += 1;
655 | 118k | }
656 | |
657 | 2.67k | int counter = 0;
658 | 2.67k | size_t saved_j[4]; // holds neighbor ids (not slot indices) awaiting a distance computation
659 | |
660 | 2.67k | threshold = res.threshold;
661 | |
662 | 17.4k | auto add_to_heap = [&](const size_t idx, const float dis) { // report idx to res if selected and below threshold; always push it as a candidate
663 | 17.4k | if (!sel || sel->is_member(idx)) {
664 | 17.1k | if (dis < threshold) {
665 | 8.73k | if (res.add_result(dis, idx)) {
666 | 4.78k | threshold = res.threshold;
667 | 4.78k | nres += 1;
668 | 4.78k | }
669 | 8.73k | }
670 | 17.1k | }
671 | 17.4k | candidates.push(idx, dis);
672 | 17.4k | };
673 | |
674 | 121k | for (size_t j = begin; j < jmax; j++) {
675 | 118k | int v1 = hnsw.neighbors[j];
676 | |
677 | 118k | bool vget = vt.get(v1);
678 | 118k | vt.set(v1);
679 | 118k | saved_j[counter] = v1; // always written; the slot is reused by the next id when v1 was already visited
680 | 118k | counter += vget ? 0 : 1; // advance only for ids not visited before
681 | |
682 | 118k | if (counter == 4) {
683 | 3.81k | float dis[4];
684 | 3.81k | qdis.distances_batch_4(
685 | 3.81k | saved_j[0],
686 | 3.81k | saved_j[1],
687 | 3.81k | saved_j[2],
688 | 3.81k | saved_j[3],
689 | 3.81k | dis[0],
690 | 3.81k | dis[1],
691 | 3.81k | dis[2],
692 | 3.81k | dis[3]);
693 | |
694 | 19.0k | for (size_t id4 = 0; id4 < 4; id4++) {
695 | 15.2k | add_to_heap(saved_j[id4], dis[id4]);
696 | 15.2k | }
697 | |
698 | 3.81k | ndis += 4;
699 | |
700 | 3.81k | counter = 0;
701 | 3.81k | }
702 | 118k | }
703 | |
704 | 4.84k | for (size_t icnt = 0; icnt < counter; icnt++) { // leftovers: fewer than 4 ids, one distance call each
705 | 2.16k | float dis = qdis(saved_j[icnt]);
706 | 2.16k | add_to_heap(saved_j[icnt], dis);
707 | |
708 | 2.16k | ndis += 1;
709 | 2.16k | }
710 | |
711 | 2.67k | nstep++;
712 | 2.67k | if (!do_dis_check && nstep > efSearch) { // without the distance check, the number of steps is bounded by efSearch instead
713 | 0 | break;
714 | 0 | }
715 | 2.67k | }
716 | |
717 | 97 | if (level == 0) { // stats are only recorded for level-0 searches
718 | 97 | stats.n1++;
719 | 97 | if (candidates.size() == 0) { // n2 counts searches that ended with an empty candidate heap
720 | 84 | stats.n2++;
721 | 84 | }
722 | 97 | stats.ndis += ndis;
723 | 97 | stats.nhops += nstep;
724 | 97 | }
725 | |
726 | 97 | return nres;
727 | 97 | }
728 | | |
729 | | std::priority_queue<HNSW::Node> search_from_candidate_unbounded( // Level-0 search from `node` with an unbounded candidate queue; returns up to `ef` kept results.
730 | | const HNSW& hnsw,
731 | | const Node& node,
732 | | DistanceComputer& qdis,
733 | | int ef,
734 | | VisitedTable* vt,
735 | 0 | HNSWStats& stats) {
736 | 0 | int ndis = 0;
737 | 0 | std::priority_queue<Node> top_candidates; // default comparator: largest Node on top
738 | 0 | std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates; // std::greater: smallest Node on top
739 | 

740 | 0 | top_candidates.push(node);
741 | 0 | candidates.push(node);
742 | 

743 | 0 | vt->set(node.second); // node.second is used as the vertex id
744 | 

745 | 0 | while (!candidates.empty()) {
746 | 0 | float d0;
747 | 0 | storage_idx_t v0;
748 | 0 | std::tie(d0, v0) = candidates.top();
749 | 

750 | 0 | if (d0 > top_candidates.top().first) { // best remaining candidate is worse than the worst kept result
751 | 0 | break;
752 | 0 | }
753 | |
754 | 0 | candidates.pop();
755 | 

756 | 0 | size_t begin, end;
757 | 0 | hnsw.neighbor_range(v0, 0, &begin, &end); // always level 0
758 | |
759 | | // a faster version: reference version in unit test test_hnsw.cpp
760 | | // the following version processes 4 neighbors at a time
761 | 0 | size_t jmax = begin; // after this loop, [begin, jmax) are the slots before the first negative id
762 | 0 | for (size_t j = begin; j < end; j++) {
763 | 0 | int v1 = hnsw.neighbors[j];
764 | 0 | if (v1 < 0)
765 | 0 | break;
766 | |
767 | 0 | prefetch_L2(vt->visited.data() + v1);
768 | 0 | jmax += 1;
769 | 0 | }
770 | 

771 | 0 | int counter = 0;
772 | 0 | size_t saved_j[4]; // holds neighbor ids awaiting a distance computation
773 | 

774 | 0 | auto add_to_heap = [&](const size_t idx, const float dis) { // keep idx if it beats the current worst or fewer than ef results are held
775 | 0 | if (top_candidates.top().first > dis ||
776 | 0 | top_candidates.size() < ef) {
777 | 0 | candidates.emplace(dis, idx);
778 | 0 | top_candidates.emplace(dis, idx);
779 | 

780 | 0 | if (top_candidates.size() > ef) { // drop the current top to stay within ef
781 | 0 | top_candidates.pop();
782 | 0 | }
783 | 0 | }
784 | 0 | };
785 | 

786 | 0 | for (size_t j = begin; j < jmax; j++) {
787 | 0 | int v1 = hnsw.neighbors[j];
788 | 

789 | 0 | bool vget = vt->get(v1);
790 | 0 | vt->set(v1);
791 | 0 | saved_j[counter] = v1; // always written; the slot is reused by the next id when v1 was already visited
792 | 0 | counter += vget ? 0 : 1;
793 | 

794 | 0 | if (counter == 4) {
795 | 0 | float dis[4];
796 | 0 | qdis.distances_batch_4(
797 | 0 | saved_j[0],
798 | 0 | saved_j[1],
799 | 0 | saved_j[2],
800 | 0 | saved_j[3],
801 | 0 | dis[0],
802 | 0 | dis[1],
803 | 0 | dis[2],
804 | 0 | dis[3]);
805 | 

806 | 0 | for (size_t id4 = 0; id4 < 4; id4++) {
807 | 0 | add_to_heap(saved_j[id4], dis[id4]);
808 | 0 | }
809 | 

810 | 0 | ndis += 4;
811 | 

812 | 0 | counter = 0;
813 | 0 | }
814 | 0 | }
815 | 

816 | 0 | for (size_t icnt = 0; icnt < counter; icnt++) { // leftovers: fewer than 4 ids, one distance call each
817 | 0 | float dis = qdis(saved_j[icnt]);
818 | 0 | add_to_heap(saved_j[icnt], dis);
819 | 

820 | 0 | ndis += 1;
821 | 0 | }
822 | 

823 | 0 | stats.nhops += 1; // one hop per expanded candidate
824 | 0 | }
825 | 

826 | 0 | ++stats.n1;
827 | 0 | if (candidates.size() == 0) { // n2 counts searches that ended with an empty candidate queue
828 | 0 | ++stats.n2;
829 | 0 | }
830 | 0 | stats.ndis += ndis;
831 | 

832 | 0 | return top_candidates;
833 | 0 | }
834 | | |
835 | | /// greedily update a nearest vector at a given level
836 | | HNSWStats greedy_update_nearest( // Moves nearest / d_nearest (in-out) to closer neighbors until a pass finds none; returns the stats of this descent.
837 | | const HNSW& hnsw,
838 | | DistanceComputer& qdis,
839 | | int level,
840 | | storage_idx_t& nearest,
841 | 32.8k | float& d_nearest) {
842 | 32.8k | HNSWStats stats;
843 | |
844 | 59.6k | for (;;) { // each iteration scans the neighbor list of the current nearest
845 | 59.6k | storage_idx_t prev_nearest = nearest;
846 | |
847 | 59.6k | size_t begin, end;
848 | 59.6k | hnsw.neighbor_range(nearest, level, &begin, &end);
849 | |
850 | 59.6k | size_t ndis = 0; // neighbors examined in this pass
851 | |
852 | | // a faster version: reference version in unit test test_hnsw.cpp
853 | | // the following version processes 4 neighbors at a time
854 | 59.6k | auto update_with_candidate = [&](const storage_idx_t idx, // adopt idx when it is strictly closer than the current nearest
855 | 364k | const float dis) {
856 | 364k | if (dis < d_nearest) {
857 | 45.5k | nearest = idx;
858 | 45.5k | d_nearest = dis;
859 | 45.5k | }
860 | 364k | };
861 | |
862 | 59.6k | int n_buffered = 0;
863 | 59.6k | storage_idx_t buffered_ids[4]; // neighbor ids waiting for a batched distance call
864 | |
865 | 424k | for (size_t j = begin; j < end; j++) {
866 | 419k | storage_idx_t v = hnsw.neighbors[j];
867 | 419k | if (v < 0) // a negative id ends the neighbor list
868 | 55.2k | break;
869 | 364k | ndis += 1;
870 | |
871 | 364k | buffered_ids[n_buffered] = v;
872 | 364k | n_buffered += 1;
873 | |
874 | 364k | if (n_buffered == 4) { // buffer full: compute the four distances in one call
875 | 70.7k | float dis[4];
876 | 70.7k | qdis.distances_batch_4(
877 | 70.7k | buffered_ids[0],
878 | 70.7k | buffered_ids[1],
879 | 70.7k | buffered_ids[2],
880 | 70.7k | buffered_ids[3],
881 | 70.7k | dis[0],
882 | 70.7k | dis[1],
883 | 70.7k | dis[2],
884 | 70.7k | dis[3]);
885 | |
886 | 353k | for (size_t id4 = 0; id4 < 4; id4++) {
887 | 283k | update_with_candidate(buffered_ids[id4], dis[id4]);
888 | 283k | }
889 | |
890 | 70.7k | n_buffered = 0;
891 | 70.7k | }
892 | 364k | }
893 | |
894 | | // process leftovers
895 | 141k | for (size_t icnt = 0; icnt < n_buffered; icnt++) {
896 | 81.6k | float dis = qdis(buffered_ids[icnt]);
897 | 81.6k | update_with_candidate(buffered_ids[icnt], dis);
898 | 81.6k | }
899 | |
900 | | // update stats
901 | 59.6k | stats.ndis += ndis;
902 | 59.6k | stats.nhops += 1;
903 | |
904 | 59.6k | if (nearest == prev_nearest) { // no neighbor was closer: finished
905 | 32.8k | return stats;
906 | 32.8k | }
907 | 59.6k | }
908 | 32.8k | }
909 | | |
910 | | namespace { |
911 | | using MinimaxHeap = HNSW::MinimaxHeap; |
912 | | using Node = HNSW::Node; |
913 | | using C = HNSW::C; |
914 | | |
915 | | // just used as a lower bound for the minmaxheap, but it is set for heap search
916 | 97 | int extract_k_from_ResultHandler(ResultHandler<C>& res) { // Returns k for a heap single-result handler, otherwise 1.
917 | 97 | using RH = HeapBlockResultHandler<C>;
918 | 97 | if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) { // only this handler type is inspected for k
919 | 32 | return hres->k;
920 | 32 | }
921 | 65 | return 1; // any other handler type
922 | 97 | }
923 | | |
924 | | } // namespace |
925 | | |
926 | | HNSWStats HNSW::search( // Searches the graph for one query, reporting hits to `res`; returns the accumulated stats.
927 | | DistanceComputer& qdis,
928 | | ResultHandler<C>& res,
929 | | VisitedTable& vt,
930 | 110 | const SearchParameters* params) const {
931 | 110 | HNSWStats stats;
932 | 110 | if (entry_point == -1) { // no entry point: return default stats without touching res
933 | 13 | return stats;
934 | 13 | }
935 | 97 | int k = extract_k_from_ResultHandler(res);
936 | |
937 | 97 | bool bounded_queue = this->search_bounded_queue;
938 | 97 | int efSearch = this->efSearch;
939 | 97 | if (params) {
940 | 87 | if (const SearchParametersHNSW* hnsw_params = // overrides apply only when params is a SearchParametersHNSW
941 | 87 | dynamic_cast<const SearchParametersHNSW*>(params)) {
942 | 87 | bounded_queue = hnsw_params->bounded_queue;
943 | 87 | efSearch = hnsw_params->efSearch;
944 | 87 | }
945 | 87 | }
946 | |
947 | | // greedy search on upper levels
948 | 97 | storage_idx_t nearest = entry_point;
949 | 97 | float d_nearest = qdis(nearest);
950 | |
951 | 213 | for (int level = max_level; level >= 1; level--) { // descend from max_level to level 1; level 0 is handled below
952 | 116 | HNSWStats local_stats =
953 | 116 | greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
954 | 116 | stats.combine(local_stats);
955 | 116 | }
956 | |
957 | 97 | int ef = std::max(efSearch, k); // never smaller than k
958 | 97 | if (bounded_queue) { // this is the most common branch
959 | 97 | MinimaxHeap candidates(ef);
960 | |
961 | 97 | candidates.push(nearest, d_nearest);
962 | |
963 | 97 | search_from_candidates( // level 0, nres_in 0; the returned count is not used
964 | 97 | *this, qdis, res, candidates, vt, stats, 0, 0, params);
965 | 97 | } else {
966 | 0 | std::priority_queue<Node> top_candidates =
967 | 0 | search_from_candidate_unbounded(
968 | 0 | *this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
969 | 

970 | 0 | while (top_candidates.size() > k) { // trim to at most k results
971 | 0 | top_candidates.pop();
972 | 0 | }
973 | 

974 | 0 | while (!top_candidates.empty()) { // hand each remaining result to res
975 | 0 | float d;
976 | 0 | storage_idx_t label;
977 | 0 | std::tie(d, label) = top_candidates.top();
978 | 0 | res.add_result(d, label);
979 | 0 | top_candidates.pop();
980 | 0 | }
981 | 0 | }
982 | |
983 | 97 | vt.advance(); // called once per search, after either branch
984 | |
985 | 97 | return stats;
986 | 110 | }
987 | | |
988 | | void HNSW::search_level_0( |
989 | | DistanceComputer& qdis, |
990 | | ResultHandler<C>& res, |
991 | | idx_t nprobe, |
992 | | const storage_idx_t* nearest_i, |
993 | | const float* nearest_d, |
994 | | int search_type, |
995 | | HNSWStats& search_stats, |
996 | | VisitedTable& vt, |
997 | 0 | const SearchParameters* params) const { |
998 | 0 | const HNSW& hnsw = *this; |
999 | |
|
1000 | 0 | auto efSearch = hnsw.efSearch; |
1001 | 0 | if (params) { |
1002 | 0 | if (const SearchParametersHNSW* hnsw_params = |
1003 | 0 | dynamic_cast<const SearchParametersHNSW*>(params)) { |
1004 | 0 | efSearch = hnsw_params->efSearch; |
1005 | 0 | } |
1006 | 0 | } |
1007 | |
|
1008 | 0 | int k = extract_k_from_ResultHandler(res); |
1009 | |
|
1010 | 0 | if (search_type == 1) { |
1011 | 0 | int nres = 0; |
1012 | |
|
1013 | 0 | for (int j = 0; j < nprobe; j++) { |
1014 | 0 | storage_idx_t cj = nearest_i[j]; |
1015 | |
|
1016 | 0 | if (cj < 0) |
1017 | 0 | break; |
1018 | | |
1019 | 0 | if (vt.get(cj)) |
1020 | 0 | continue; |
1021 | | |
1022 | 0 | int candidates_size = std::max(efSearch, k); |
1023 | 0 | MinimaxHeap candidates(candidates_size); |
1024 | |
|
1025 | 0 | candidates.push(cj, nearest_d[j]); |
1026 | |
|
1027 | 0 | nres = search_from_candidates( |
1028 | 0 | hnsw, |
1029 | 0 | qdis, |
1030 | 0 | res, |
1031 | 0 | candidates, |
1032 | 0 | vt, |
1033 | 0 | search_stats, |
1034 | 0 | 0, |
1035 | 0 | nres, |
1036 | 0 | params); |
1037 | 0 | nres = std::min(nres, candidates_size); |
1038 | 0 | } |
1039 | 0 | } else if (search_type == 2) { |
1040 | 0 | int candidates_size = std::max(efSearch, int(k)); |
1041 | 0 | candidates_size = std::max(candidates_size, int(nprobe)); |
1042 | |
|
1043 | 0 | MinimaxHeap candidates(candidates_size); |
1044 | 0 | for (int j = 0; j < nprobe; j++) { |
1045 | 0 | storage_idx_t cj = nearest_i[j]; |
1046 | |
|
1047 | 0 | if (cj < 0) |
1048 | 0 | break; |
1049 | 0 | candidates.push(cj, nearest_d[j]); |
1050 | 0 | } |
1051 | |
|
1052 | 0 | search_from_candidates( |
1053 | 0 | hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params); |
1054 | 0 | } |
1055 | 0 | } |
1056 | | |
1057 | 0 | void HNSW::permute_entries(const idx_t* map) { |
1058 | | // remap levels |
1059 | 0 | storage_idx_t ntotal = levels.size(); |
1060 | 0 | std::vector<storage_idx_t> imap(ntotal); // inverse mapping |
1061 | | // map: new index -> old index |
1062 | | // imap: old index -> new index |
1063 | 0 | for (int i = 0; i < ntotal; i++) { |
1064 | 0 | assert(map[i] >= 0 && map[i] < ntotal); |
1065 | 0 | imap[map[i]] = i; |
1066 | 0 | } |
1067 | 0 | if (entry_point != -1) { |
1068 | 0 | entry_point = imap[entry_point]; |
1069 | 0 | } |
1070 | 0 | std::vector<int> new_levels(ntotal); |
1071 | 0 | std::vector<size_t> new_offsets(ntotal + 1); |
1072 | 0 | std::vector<storage_idx_t> new_neighbors(neighbors.size()); |
1073 | 0 | size_t no = 0; |
1074 | 0 | for (int i = 0; i < ntotal; i++) { |
1075 | 0 | storage_idx_t o = map[i]; // corresponding "old" index |
1076 | 0 | new_levels[i] = levels[o]; |
1077 | 0 | for (size_t j = offsets[o]; j < offsets[o + 1]; j++) { |
1078 | 0 | storage_idx_t neigh = neighbors[j]; |
1079 | 0 | new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh; |
1080 | 0 | } |
1081 | 0 | new_offsets[i + 1] = no; |
1082 | 0 | } |
1083 | 0 | assert(new_offsets[ntotal] == offsets[ntotal]); |
1084 | | // swap everyone |
1085 | 0 | std::swap(levels, new_levels); |
1086 | 0 | std::swap(offsets, new_offsets); |
1087 | 0 | neighbors = std::move(new_neighbors); |
1088 | 0 | } |
1089 | | |
1090 | | /************************************************************** |
1091 | | * MinimaxHeap |
1092 | | **************************************************************/ |
1093 | | |
1094 | 17.5k | void HNSW::MinimaxHeap::push(storage_idx_t i, float v) { |
1095 | 17.5k | if (k == n) { |
1096 | 11.5k | if (v >= dis[0]) |
1097 | 9.16k | return; |
1098 | 2.41k | if (ids[0] != -1) { |
1099 | 2.36k | --nvalid; |
1100 | 2.36k | } |
1101 | 2.41k | faiss::heap_pop<HC>(k--, dis.data(), ids.data()); |
1102 | 2.41k | } |
1103 | 8.37k | faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i); |
1104 | 8.37k | ++nvalid; |
1105 | 8.37k | } |
1106 | | |
1107 | 0 | float HNSW::MinimaxHeap::max() const { |
1108 | 0 | return dis[0]; |
1109 | 0 | } |
1110 | | |
1111 | 3.06k | int HNSW::MinimaxHeap::size() const { |
1112 | 3.06k | return nvalid; |
1113 | 3.06k | } |
1114 | | |
1115 | 0 | void HNSW::MinimaxHeap::clear() { |
1116 | 0 | nvalid = k = 0; |
1117 | 0 | } |
1118 | | |
1119 | | #ifdef __AVX512F__ |
1120 | | |
1121 | | int HNSW::MinimaxHeap::pop_min(float* vmin_out) { |
1122 | | assert(k > 0); |
1123 | | static_assert( |
1124 | | std::is_same<storage_idx_t, int32_t>::value, |
1125 | | "This code expects storage_idx_t to be int32_t"); |
1126 | | |
1127 | | int32_t min_idx = -1; |
1128 | | float min_dis = std::numeric_limits<float>::infinity(); |
1129 | | |
1130 | | __m512i min_indices = _mm512_set1_epi32(-1); |
1131 | | __m512 min_distances = |
1132 | | _mm512_set1_ps(std::numeric_limits<float>::infinity()); |
1133 | | __m512i current_indices = _mm512_setr_epi32( |
1134 | | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); |
1135 | | __m512i offset = _mm512_set1_epi32(16); |
1136 | | |
1137 | | // The following loop tracks the rightmost index with the min distance. |
1138 | | // -1 index values are ignored. |
1139 | | const int k16 = (k / 16) * 16; |
1140 | | for (size_t iii = 0; iii < k16; iii += 16) { |
1141 | | __m512i indices = |
1142 | | _mm512_loadu_si512((const __m512i*)(ids.data() + iii)); |
1143 | | __m512 distances = _mm512_loadu_ps(dis.data() + iii); |
1144 | | |
1145 | | // This mask filters out -1 values among indices. |
1146 | | __mmask16 m1mask = |
1147 | | _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices); |
1148 | | |
1149 | | __mmask16 dmask = |
1150 | | _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); |
1151 | | __mmask16 finalmask = m1mask | dmask; |
1152 | | |
1153 | | const __m512i min_indices_new = _mm512_mask_blend_epi32( |
1154 | | finalmask, current_indices, min_indices); |
1155 | | const __m512 min_distances_new = |
1156 | | _mm512_mask_blend_ps(finalmask, distances, min_distances); |
1157 | | |
1158 | | min_indices = min_indices_new; |
1159 | | min_distances = min_distances_new; |
1160 | | |
1161 | | current_indices = _mm512_add_epi32(current_indices, offset); |
1162 | | } |
1163 | | |
1164 | | // leftovers |
1165 | | if (k16 != k) { |
1166 | | const __mmask16 kmask = (1 << (k - k16)) - 1; |
1167 | | |
1168 | | __m512i indices = _mm512_mask_loadu_epi32( |
1169 | | _mm512_set1_epi32(-1), kmask, ids.data() + k16); |
1170 | | __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16); |
1171 | | |
1172 | | // This mask filters out -1 values among indices. |
1173 | | __mmask16 m1mask = |
1174 | | _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices); |
1175 | | |
1176 | | __mmask16 dmask = |
1177 | | _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); |
1178 | | __mmask16 finalmask = m1mask | dmask; |
1179 | | |
1180 | | const __m512i min_indices_new = _mm512_mask_blend_epi32( |
1181 | | finalmask, current_indices, min_indices); |
1182 | | const __m512 min_distances_new = |
1183 | | _mm512_mask_blend_ps(finalmask, distances, min_distances); |
1184 | | |
1185 | | min_indices = min_indices_new; |
1186 | | min_distances = min_distances_new; |
1187 | | } |
1188 | | |
1189 | | // grab min distance |
1190 | | min_dis = _mm512_reduce_min_ps(min_distances); |
1191 | | // blend |
1192 | | __mmask16 mindmask = |
1193 | | _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis)); |
1194 | | // pick the max one |
1195 | | min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices); |
1196 | | |
1197 | | if (min_idx == -1) { |
1198 | | return -1; |
1199 | | } |
1200 | | |
1201 | | if (vmin_out) { |
1202 | | *vmin_out = min_dis; |
1203 | | } |
1204 | | int ret = ids[min_idx]; |
1205 | | ids[min_idx] = -1; |
1206 | | --nvalid; |
1207 | | return ret; |
1208 | | } |
1209 | | |
1210 | | #elif __AVX2__ |
1211 | | |
1212 | 2.68k | int HNSW::MinimaxHeap::pop_min(float* vmin_out) { |
1213 | 2.68k | assert(k > 0); |
1214 | 2.68k | static_assert( |
1215 | 2.68k | std::is_same<storage_idx_t, int32_t>::value, |
1216 | 2.68k | "This code expects storage_idx_t to be int32_t"); |
1217 | | |
1218 | 2.68k | int32_t min_idx = -1; |
1219 | 2.68k | float min_dis = std::numeric_limits<float>::infinity(); |
1220 | | |
1221 | 2.68k | size_t iii = 0; |
1222 | | |
1223 | 2.68k | __m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1); |
1224 | 2.68k | __m256 min_distances = |
1225 | 2.68k | _mm256_set1_ps(std::numeric_limits<float>::infinity()); |
1226 | 2.68k | __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); |
1227 | 2.68k | __m256i offset = _mm256_set1_epi32(8); |
1228 | | |
1229 | | // The baseline version is available in non-AVX2 branch. |
1230 | | |
1231 | | // The following loop tracks the rightmost index with the min distance. |
1232 | | // -1 index values are ignored. |
1233 | 2.68k | const int k8 = (k / 8) * 8; |
1234 | 52.3k | for (; iii < k8; iii += 8) { |
1235 | 49.6k | __m256i indices = |
1236 | 49.6k | _mm256_loadu_si256((const __m256i*)(ids.data() + iii)); |
1237 | 49.6k | __m256 distances = _mm256_loadu_ps(dis.data() + iii); |
1238 | | |
1239 | | // This mask filters out -1 values among indices. |
1240 | 49.6k | __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices); |
1241 | | |
1242 | 49.6k | __m256i dmask = _mm256_castps_si256( |
1243 | 49.6k | _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS)); |
1244 | 49.6k | __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask)); |
1245 | | |
1246 | 49.6k | const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps( |
1247 | 49.6k | _mm256_castsi256_ps(current_indices), |
1248 | 49.6k | _mm256_castsi256_ps(min_indices), |
1249 | 49.6k | finalmask)); |
1250 | | |
1251 | 49.6k | const __m256 min_distances_new = |
1252 | 49.6k | _mm256_blendv_ps(distances, min_distances, finalmask); |
1253 | | |
1254 | 49.6k | min_indices = min_indices_new; |
1255 | 49.6k | min_distances = min_distances_new; |
1256 | | |
1257 | 49.6k | current_indices = _mm256_add_epi32(current_indices, offset); |
1258 | 49.6k | } |
1259 | | |
1260 | | // Vectorizing is doable, but is not practical |
1261 | 2.68k | int32_t vidx8[8]; |
1262 | 2.68k | float vdis8[8]; |
1263 | 2.68k | _mm256_storeu_ps(vdis8, min_distances); |
1264 | 2.68k | _mm256_storeu_si256((__m256i*)vidx8, min_indices); |
1265 | | |
1266 | 24.2k | for (size_t j = 0; j < 8; j++) { |
1267 | 21.5k | if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) { |
1268 | 7.78k | min_idx = vidx8[j]; |
1269 | 7.78k | min_dis = vdis8[j]; |
1270 | 7.78k | } |
1271 | 21.5k | } |
1272 | | |
1273 | | // process last values. Vectorizing is doable, but is not practical |
1274 | 7.96k | for (; iii < k; iii++) { |
1275 | 5.27k | if (ids[iii] != -1 && dis[iii] <= min_dis) { |
1276 | 254 | min_dis = dis[iii]; |
1277 | 254 | min_idx = iii; |
1278 | 254 | } |
1279 | 5.27k | } |
1280 | | |
1281 | 2.68k | if (min_idx == -1) { |
1282 | 0 | return -1; |
1283 | 0 | } |
1284 | | |
1285 | 2.68k | if (vmin_out) { |
1286 | 2.68k | *vmin_out = min_dis; |
1287 | 2.68k | } |
1288 | 2.68k | int ret = ids[min_idx]; |
1289 | 2.68k | ids[min_idx] = -1; |
1290 | 2.68k | --nvalid; |
1291 | 2.68k | return ret; |
1292 | 2.68k | } |
1293 | | |
1294 | | #else |
1295 | | |
1296 | | // baseline non-vectorized version |
1297 | | int HNSW::MinimaxHeap::pop_min(float* vmin_out) { |
1298 | | assert(k > 0); |
1299 | | // returns min. This is an O(n) operation |
1300 | | int i = k - 1; |
1301 | | while (i >= 0) { |
1302 | | if (ids[i] != -1) { |
1303 | | break; |
1304 | | } |
1305 | | i--; |
1306 | | } |
1307 | | if (i == -1) { |
1308 | | return -1; |
1309 | | } |
1310 | | int imin = i; |
1311 | | float vmin = dis[i]; |
1312 | | i--; |
1313 | | while (i >= 0) { |
1314 | | if (ids[i] != -1 && dis[i] < vmin) { |
1315 | | vmin = dis[i]; |
1316 | | imin = i; |
1317 | | } |
1318 | | i--; |
1319 | | } |
1320 | | if (vmin_out) { |
1321 | | *vmin_out = vmin; |
1322 | | } |
1323 | | int ret = ids[imin]; |
1324 | | ids[imin] = -1; |
1325 | | --nvalid; |
1326 | | |
1327 | | return ret; |
1328 | | } |
1329 | | #endif |
1330 | | |
1331 | 2.68k | int HNSW::MinimaxHeap::count_below(float thresh) { |
1332 | 2.68k | int n_below = 0; |
1333 | 405k | for (int i = 0; i < k; i++) { |
1334 | 402k | if (dis[i] < thresh) { |
1335 | 79.2k | n_below++; |
1336 | 79.2k | } |
1337 | 402k | } |
1338 | | |
1339 | 2.68k | return n_below; |
1340 | 2.68k | } |
1341 | | |
1342 | | } // namespace faiss |