/root/doris/contrib/faiss/faiss/IndexHNSW.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexHNSW.h> |
9 | | |
10 | | #include <omp.h> |
11 | | #include <cinttypes> |
12 | | #include <cstdio> |
13 | | #include <cstdlib> |
14 | | #include <cstring> |
15 | | |
16 | | #include <limits> |
17 | | #include <memory> |
18 | | #include <queue> |
19 | | #include <random> |
20 | | |
21 | | #include <cstdint> |
22 | | |
23 | | #include <faiss/Index2Layer.h> |
24 | | #include <faiss/IndexFlat.h> |
25 | | #include <faiss/IndexIVFPQ.h> |
26 | | #include <faiss/impl/AuxIndexStructures.h> |
27 | | #include <faiss/impl/FaissAssert.h> |
28 | | #include <faiss/impl/ResultHandler.h> |
29 | | #include <faiss/utils/random.h> |
30 | | #include <faiss/utils/sorting.h> |
31 | | |
32 | | namespace faiss { |
33 | | |
34 | | using MinimaxHeap = HNSW::MinimaxHeap; |
35 | | using storage_idx_t = HNSW::storage_idx_t; |
36 | | using NodeDistFarther = HNSW::NodeDistFarther; |
37 | | |
38 | | HNSWStats hnsw_stats; |
39 | | |
40 | | /************************************************************** |
41 | | * add / search blocks of descriptors |
42 | | **************************************************************/ |
43 | | |
44 | | namespace { |
45 | | |
46 | 20.1k | DistanceComputer* storage_distance_computer(const Index* storage) { |
47 | 20.1k | if (is_similarity_metric(storage->metric_type)) { |
48 | 1.47k | return new NegativeDistanceComputer(storage->get_distance_computer()); |
49 | 18.6k | } else { |
50 | 18.6k | return storage->get_distance_computer(); |
51 | 18.6k | } |
52 | 20.1k | } |
53 | | |
54 | | void hnsw_add_vertices( |
55 | | IndexHNSW& index_hnsw, |
56 | | size_t n0, |
57 | | size_t n, |
58 | | const float* x, |
59 | | bool verbose, |
60 | 18.8k | bool preset_levels = false) { |
61 | 18.8k | size_t d = index_hnsw.d; |
62 | 18.8k | HNSW& hnsw = index_hnsw.hnsw; |
63 | 18.8k | size_t ntotal = n0 + n; |
64 | 18.8k | double t0 = getmillisecs(); |
65 | 18.8k | if (verbose) { |
66 | 0 | printf("hnsw_add_vertices: adding %zd elements on top of %zd " |
67 | 0 | "(preset_levels=%d)\n", |
68 | 0 | n, |
69 | 0 | n0, |
70 | 0 | int(preset_levels)); |
71 | 0 | } |
72 | | |
73 | 18.8k | if (n == 0) { |
74 | 0 | return; |
75 | 0 | } |
76 | | |
77 | 18.8k | int max_level = hnsw.prepare_level_tab(n, preset_levels); |
78 | | |
79 | 18.8k | if (verbose) { |
80 | 0 | printf(" max_level = %d\n", max_level); |
81 | 0 | } |
82 | | |
83 | 18.8k | std::vector<omp_lock_t> locks(ntotal); |
84 | 7.50M | for (int i = 0; i < ntotal; i++) |
85 | 7.48M | omp_init_lock(&locks[i]); |
86 | | |
87 | | // add vectors from highest to lowest level |
88 | 18.8k | std::vector<int> hist; |
89 | 18.8k | std::vector<int> order(n); |
90 | | |
91 | 18.8k | { // make buckets with vectors of the same level |
92 | | |
93 | | // build histogram |
94 | 45.9k | for (int i = 0; i < n; i++) { |
95 | 27.1k | storage_idx_t pt_id = i + n0; |
96 | 27.1k | int pt_level = hnsw.levels[pt_id] - 1; |
97 | 47.0k | while (pt_level >= hist.size()) |
98 | 19.8k | hist.push_back(0); |
99 | 27.1k | hist[pt_level]++; |
100 | 27.1k | } |
101 | | |
102 | | // accumulate |
103 | 18.8k | std::vector<int> offsets(hist.size() + 1, 0); |
104 | 19.8k | for (int i = 0; i < hist.size() - 1; i++) { |
105 | 1.08k | offsets[i + 1] = offsets[i] + hist[i]; |
106 | 1.08k | } |
107 | | |
108 | | // bucket sort |
109 | 45.9k | for (int i = 0; i < n; i++) { |
110 | 27.1k | storage_idx_t pt_id = i + n0; |
111 | 27.1k | int pt_level = hnsw.levels[pt_id] - 1; |
112 | 27.1k | order[offsets[pt_level]++] = pt_id; |
113 | 27.1k | } |
114 | 18.8k | } |
115 | | |
116 | 18.8k | idx_t check_period = InterruptCallback::get_period_hint( |
117 | 18.8k | max_level * index_hnsw.d * hnsw.efConstruction); |
118 | | |
119 | 18.8k | { // perform add |
120 | 18.8k | RandomGenerator rng2(789); |
121 | | |
122 | 18.8k | int i1 = n; |
123 | | |
124 | 18.8k | for (int pt_level = hist.size() - 1; |
125 | 38.6k | pt_level >= int(!index_hnsw.init_level0); |
126 | 19.8k | pt_level--) { |
127 | 19.8k | int i0 = i1 - hist[pt_level]; |
128 | | |
129 | 19.8k | if (verbose) { |
130 | 0 | printf("Adding %d elements at level %d\n", i1 - i0, pt_level); |
131 | 0 | } |
132 | | |
133 | | // random permutation to get rid of dataset order bias |
134 | 47.0k | for (int j = i0; j < i1; j++) |
135 | 27.1k | std::swap(order[j], order[j + rng2.rand_int(i1 - j)]); |
136 | | |
137 | 19.8k | bool interrupt = false; |
138 | | |
139 | 19.8k | #pragma omp parallel if (i1 > i0 + 100) |
140 | 84.9k | { |
141 | 84.9k | VisitedTable vt(ntotal); |
142 | | |
143 | 84.9k | std::unique_ptr<DistanceComputer> dis( |
144 | 84.9k | storage_distance_computer(index_hnsw.storage)); |
145 | 84.9k | int prev_display = |
146 | 84.9k | verbose && omp_get_thread_num() == 0 ? 0 : -1; |
147 | 84.9k | size_t counter = 0; |
148 | | |
149 | | // here we should do schedule(dynamic) but this segfaults for |
150 | | // some versions of LLVM. The performance impact should not be |
151 | | // too large when (i1 - i0) / num_threads >> 1 |
152 | 84.9k | #pragma omp for schedule(static) |
153 | 84.9k | for (int i = i0; i < i1; i++) { |
154 | 84.9k | storage_idx_t pt_id = order[i]; |
155 | 84.9k | dis->set_query(x + (pt_id - n0) * d); |
156 | | |
157 | | // cannot break |
158 | 84.9k | if (interrupt) { |
159 | 84.9k | continue; |
160 | 84.9k | } |
161 | | |
162 | 84.9k | hnsw.add_with_locks( |
163 | 84.9k | *dis, |
164 | 84.9k | pt_level, |
165 | 84.9k | pt_id, |
166 | 84.9k | locks, |
167 | 84.9k | vt, |
168 | 84.9k | index_hnsw.keep_max_size_level0 && (pt_level == 0)); |
169 | | |
170 | 84.9k | if (prev_display >= 0 && i - i0 > prev_display + 10000) { |
171 | 84.9k | prev_display = i - i0; |
172 | 84.9k | printf(" %d / %d\r", i - i0, i1 - i0); |
173 | 84.9k | fflush(stdout); |
174 | 84.9k | } |
175 | 84.9k | if (counter % check_period == 0) { |
176 | 84.9k | if (InterruptCallback::is_interrupted()) { |
177 | 84.9k | interrupt = true; |
178 | 84.9k | } |
179 | 84.9k | } |
180 | 84.9k | counter++; |
181 | 84.9k | } |
182 | 84.9k | } |
183 | 19.8k | if (interrupt) { |
184 | 0 | FAISS_THROW_MSG("computation interrupted"); |
185 | 0 | } |
186 | 19.8k | i1 = i0; |
187 | 19.8k | } |
188 | 18.8k | if (index_hnsw.init_level0) { |
189 | 18.8k | FAISS_ASSERT(i1 == 0); |
190 | 18.8k | } else { |
191 | 0 | FAISS_ASSERT((i1 - hist[0]) == 0); |
192 | 0 | } |
193 | 18.8k | } |
194 | 18.8k | if (verbose) { |
195 | 0 | printf("Done in %.3f ms\n", getmillisecs() - t0); |
196 | 0 | } |
197 | | |
198 | 7.50M | for (int i = 0; i < ntotal; i++) { |
199 | 7.48M | omp_destroy_lock(&locks[i]); |
200 | 7.48M | } |
201 | 18.8k | } |
202 | | |
203 | | } // namespace |
204 | | |
205 | | /************************************************************** |
206 | | * IndexHNSW implementation |
207 | | **************************************************************/ |
208 | | |
209 | | IndexHNSW::IndexHNSW(int d, int M, MetricType metric) |
210 | 59 | : Index(d, metric), hnsw(M) {} |
211 | | |
212 | | IndexHNSW::IndexHNSW(Index* storage, int M) |
213 | 135 | : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) { |
214 | 135 | metric_arg = storage->metric_arg; |
215 | 135 | } |
216 | | |
217 | 193 | IndexHNSW::~IndexHNSW() { |
218 | 193 | if (own_fields) { |
219 | 192 | delete storage; |
220 | 192 | } |
221 | 193 | } |
222 | | |
223 | 32 | void IndexHNSW::train(idx_t n, const float* x) { |
224 | 32 | FAISS_THROW_IF_NOT_MSG( |
225 | 32 | storage, |
226 | 32 | "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly"); |
227 | | // hnsw structure does not require training |
228 | 32 | storage->train(n, x); |
229 | 32 | is_trained = true; |
230 | 32 | } |
231 | | |
232 | | namespace { |
233 | | |
234 | | template <class BlockResultHandler> |
235 | | void hnsw_search( |
236 | | const IndexHNSW* index, |
237 | | idx_t n, |
238 | | const float* x, |
239 | | BlockResultHandler& bres, |
240 | 152 | const SearchParameters* params) { |
241 | 152 | FAISS_THROW_IF_NOT_MSG( |
242 | 152 | index->storage, |
243 | 152 | "No storage index, please use IndexHNSWFlat (or variants) " |
244 | 152 | "instead of IndexHNSW directly"); |
245 | 152 | const HNSW& hnsw = index->hnsw; |
246 | | |
247 | 152 | int efSearch = hnsw.efSearch; |
248 | 152 | if (params) { |
249 | 142 | if (const SearchParametersHNSW* hnsw_params = |
250 | 142 | dynamic_cast<const SearchParametersHNSW*>(params)) { |
251 | 142 | efSearch = hnsw_params->efSearch; |
252 | 142 | } |
253 | 142 | } |
254 | 152 | size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; |
255 | | |
256 | 152 | idx_t check_period = InterruptCallback::get_period_hint( |
257 | 152 | hnsw.max_level * index->d * efSearch); |
258 | | |
259 | 304 | for (idx_t i0 = 0; i0 < n; i0 += check_period) { |
260 | 152 | idx_t i1 = std::min(i0 + check_period, n); |
261 | | |
262 | 152 | #pragma omp parallel if (i1 - i0 > 1) |
263 | 455 | { |
264 | 455 | VisitedTable vt(index->ntotal); |
265 | 455 | typename BlockResultHandler::SingleResultHandler res(bres); |
266 | | |
267 | 455 | std::unique_ptr<DistanceComputer> dis( |
268 | 455 | storage_distance_computer(index->storage)); |
269 | | |
270 | 455 | #pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) |
271 | 455 | for (idx_t i = i0; i < i1; i++) { |
272 | 455 | res.begin(i); |
273 | 455 | dis->set_query(x + i * index->d); |
274 | | |
275 | 455 | HNSWStats stats = hnsw.search(*dis, res, vt, params); |
276 | 455 | n1 += stats.n1; |
277 | 455 | n2 += stats.n2; |
278 | 455 | ndis += stats.ndis; |
279 | 455 | nhops += stats.nhops; |
280 | 455 | res.end(); |
281 | 455 | } |
282 | 455 | } IndexHNSW.cpp:_ZN5faiss12_GLOBAL__N_111hnsw_searchINS_22HeapBlockResultHandlerINS_4CMaxIflEELb0EEEEEvPKNS_9IndexHNSWElPKfRT_PKNS_16SearchParametersE.omp_outlined_debug__ Line | Count | Source | 263 | 173 | { | 264 | 173 | VisitedTable vt(index->ntotal); | 265 | 173 | typename BlockResultHandler::SingleResultHandler res(bres); | 266 | | | 267 | 173 | std::unique_ptr<DistanceComputer> dis( | 268 | 173 | storage_distance_computer(index->storage)); | 269 | | | 270 | 173 | #pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) | 271 | 173 | for (idx_t i = i0; i < i1; i++) { | 272 | 173 | res.begin(i); | 273 | 173 | dis->set_query(x + i * index->d); | 274 | | | 275 | 173 | HNSWStats stats = hnsw.search(*dis, res, vt, params); | 276 | 173 | n1 += stats.n1; | 277 | 173 | n2 += stats.n2; | 278 | 173 | ndis += stats.ndis; | 279 | 173 | nhops += stats.nhops; | 280 | 173 | res.end(); | 281 | 173 | } | 282 | 173 | } |
IndexHNSW.cpp:_ZN5faiss12_GLOBAL__N_111hnsw_searchINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb0EEEEEvPKNS_9IndexHNSWElPKfRT_PKNS_16SearchParametersE.omp_outlined_debug__ Line | Count | Source | 263 | 282 | { | 264 | 282 | VisitedTable vt(index->ntotal); | 265 | 282 | typename BlockResultHandler::SingleResultHandler res(bres); | 266 | | | 267 | 282 | std::unique_ptr<DistanceComputer> dis( | 268 | 282 | storage_distance_computer(index->storage)); | 269 | | | 270 | 282 | #pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) | 271 | 282 | for (idx_t i = i0; i < i1; i++) { | 272 | 282 | res.begin(i); | 273 | 282 | dis->set_query(x + i * index->d); | 274 | | | 275 | 282 | HNSWStats stats = hnsw.search(*dis, res, vt, params); | 276 | 282 | n1 += stats.n1; | 277 | 282 | n2 += stats.n2; | 278 | 282 | ndis += stats.ndis; | 279 | 282 | nhops += stats.nhops; | 280 | 282 | res.end(); | 281 | 282 | } | 282 | 282 | } |
|
283 | 152 | InterruptCallback::check(); |
284 | 152 | } |
285 | | |
286 | 152 | hnsw_stats.combine({n1, n2, ndis, nhops}); |
287 | 152 | } IndexHNSW.cpp:_ZN5faiss12_GLOBAL__N_111hnsw_searchINS_22HeapBlockResultHandlerINS_4CMaxIflEELb0EEEEEvPKNS_9IndexHNSWElPKfRT_PKNS_16SearchParametersE Line | Count | Source | 240 | 58 | const SearchParameters* params) { | 241 | 58 | FAISS_THROW_IF_NOT_MSG( | 242 | 58 | index->storage, | 243 | 58 | "No storage index, please use IndexHNSWFlat (or variants) " | 244 | 58 | "instead of IndexHNSW directly"); | 245 | 58 | const HNSW& hnsw = index->hnsw; | 246 | | | 247 | 58 | int efSearch = hnsw.efSearch; | 248 | 58 | if (params) { | 249 | 53 | if (const SearchParametersHNSW* hnsw_params = | 250 | 53 | dynamic_cast<const SearchParametersHNSW*>(params)) { | 251 | 53 | efSearch = hnsw_params->efSearch; | 252 | 53 | } | 253 | 53 | } | 254 | 58 | size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; | 255 | | | 256 | 58 | idx_t check_period = InterruptCallback::get_period_hint( | 257 | 58 | hnsw.max_level * index->d * efSearch); | 258 | | | 259 | 116 | for (idx_t i0 = 0; i0 < n; i0 += check_period) { | 260 | 58 | idx_t i1 = std::min(i0 + check_period, n); | 261 | | | 262 | 58 | #pragma omp parallel if (i1 - i0 > 1) | 263 | 58 | { | 264 | 58 | VisitedTable vt(index->ntotal); | 265 | 58 | typename BlockResultHandler::SingleResultHandler res(bres); | 266 | | | 267 | 58 | std::unique_ptr<DistanceComputer> dis( | 268 | 58 | storage_distance_computer(index->storage)); | 269 | | | 270 | 58 | #pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) | 271 | 58 | for (idx_t i = i0; i < i1; i++) { | 272 | 58 | res.begin(i); | 273 | 58 | dis->set_query(x + i * index->d); | 274 | | | 275 | 58 | HNSWStats stats = hnsw.search(*dis, res, vt, params); | 276 | 58 | n1 += stats.n1; | 277 | 58 | n2 += stats.n2; | 278 | 58 | ndis += stats.ndis; | 279 | 58 | nhops += stats.nhops; | 280 | 58 | res.end(); | 281 | 58 | } | 282 | 58 | } | 283 | 58 | InterruptCallback::check(); | 284 | 58 | } | 285 | | | 286 | 58 | hnsw_stats.combine({n1, n2, ndis, nhops}); | 287 | 58 | } |
IndexHNSW.cpp:_ZN5faiss12_GLOBAL__N_111hnsw_searchINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb0EEEEEvPKNS_9IndexHNSWElPKfRT_PKNS_16SearchParametersE Line | Count | Source | 240 | 94 | const SearchParameters* params) { | 241 | 94 | FAISS_THROW_IF_NOT_MSG( | 242 | 94 | index->storage, | 243 | 94 | "No storage index, please use IndexHNSWFlat (or variants) " | 244 | 94 | "instead of IndexHNSW directly"); | 245 | 94 | const HNSW& hnsw = index->hnsw; | 246 | | | 247 | 94 | int efSearch = hnsw.efSearch; | 248 | 94 | if (params) { | 249 | 89 | if (const SearchParametersHNSW* hnsw_params = | 250 | 89 | dynamic_cast<const SearchParametersHNSW*>(params)) { | 251 | 89 | efSearch = hnsw_params->efSearch; | 252 | 89 | } | 253 | 89 | } | 254 | 94 | size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; | 255 | | | 256 | 94 | idx_t check_period = InterruptCallback::get_period_hint( | 257 | 94 | hnsw.max_level * index->d * efSearch); | 258 | | | 259 | 188 | for (idx_t i0 = 0; i0 < n; i0 += check_period) { | 260 | 94 | idx_t i1 = std::min(i0 + check_period, n); | 261 | | | 262 | 94 | #pragma omp parallel if (i1 - i0 > 1) | 263 | 94 | { | 264 | 94 | VisitedTable vt(index->ntotal); | 265 | 94 | typename BlockResultHandler::SingleResultHandler res(bres); | 266 | | | 267 | 94 | std::unique_ptr<DistanceComputer> dis( | 268 | 94 | storage_distance_computer(index->storage)); | 269 | | | 270 | 94 | #pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) | 271 | 94 | for (idx_t i = i0; i < i1; i++) { | 272 | 94 | res.begin(i); | 273 | 94 | dis->set_query(x + i * index->d); | 274 | | | 275 | 94 | HNSWStats stats = hnsw.search(*dis, res, vt, params); | 276 | 94 | n1 += stats.n1; | 277 | 94 | n2 += stats.n2; | 278 | 94 | ndis += stats.ndis; | 279 | 94 | nhops += stats.nhops; | 280 | 94 | res.end(); | 281 | 94 | } | 282 | 94 | } | 283 | 94 | InterruptCallback::check(); | 284 | 94 | } | 285 | | | 286 | 94 | hnsw_stats.combine({n1, n2, ndis, nhops}); | 287 | 94 | } |
|
288 | | |
289 | | } // anonymous namespace |
290 | | |
291 | | void IndexHNSW::search( |
292 | | idx_t n, |
293 | | const float* x, |
294 | | idx_t k, |
295 | | float* distances, |
296 | | idx_t* labels, |
297 | 57 | const SearchParameters* params) const { |
298 | 57 | FAISS_THROW_IF_NOT(k > 0); |
299 | | |
300 | 57 | using RH = HeapBlockResultHandler<HNSW::C>; |
301 | 57 | RH bres(n, distances, labels, k); |
302 | | |
303 | 57 | hnsw_search(this, n, x, bres, params); |
304 | | |
305 | 57 | if (is_similarity_metric(this->metric_type)) { |
306 | | // we need to revert the negated distances |
307 | 196 | for (size_t i = 0; i < k * n; i++) { |
308 | 177 | distances[i] = -distances[i]; |
309 | 177 | } |
310 | 19 | } |
311 | 57 | } |
312 | | |
313 | | void IndexHNSW::range_search( |
314 | | idx_t n, |
315 | | const float* x, |
316 | | float radius, |
317 | | RangeSearchResult* result, |
318 | 94 | const SearchParameters* params) const { |
319 | 94 | using RH = RangeSearchBlockResultHandler<HNSW::C>; |
320 | 94 | RH bres(result, is_similarity_metric(metric_type) ? -radius : radius); |
321 | | |
322 | 94 | hnsw_search(this, n, x, bres, params); |
323 | | |
324 | 94 | if (is_similarity_metric(this->metric_type)) { |
325 | | // we need to revert the negated distances |
326 | 1.51k | for (size_t i = 0; i < result->lims[result->nq]; i++) { |
327 | 1.48k | result->distances[i] = -result->distances[i]; |
328 | 1.48k | } |
329 | 23 | } |
330 | 94 | } |
331 | | |
332 | 18.8k | void IndexHNSW::add(idx_t n, const float* x) { |
333 | 18.8k | FAISS_THROW_IF_NOT_MSG( |
334 | 18.8k | storage, |
335 | 18.8k | "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly"); |
336 | 18.8k | FAISS_THROW_IF_NOT(is_trained); |
337 | 18.8k | int n0 = ntotal; |
338 | 18.8k | storage->add(n, x); |
339 | 18.8k | ntotal = storage->ntotal; |
340 | | |
341 | 18.8k | hnsw_add_vertices(*this, n0, n, x, verbose, hnsw.levels.size() == ntotal); |
342 | 18.8k | } |
343 | | |
344 | 0 | void IndexHNSW::reset() { |
345 | 0 | hnsw.reset(); |
346 | 0 | storage->reset(); |
347 | 0 | ntotal = 0; |
348 | 0 | } |
349 | | |
350 | 0 | void IndexHNSW::reconstruct(idx_t key, float* recons) const { |
351 | 0 | storage->reconstruct(key, recons); |
352 | 0 | } |
353 | | |
354 | | /************************************************************** |
355 | | * This section of functions were used during the development of HNSW support. |
356 | | * They may be useful in the future but are dormant for now, and thus are not |
357 | | * unit tested at the moment. |
358 | | * shrink_level_0_neighbors |
359 | | * search_level_0 |
360 | | * init_level_0_from_knngraph |
361 | | * init_level_0_from_entry_points |
362 | | * reorder_links |
363 | | * link_singletons |
364 | | **************************************************************/ |
365 | 0 | void IndexHNSW::shrink_level_0_neighbors(int new_size) { |
366 | 0 | #pragma omp parallel |
367 | 0 | { |
368 | 0 | std::unique_ptr<DistanceComputer> dis( |
369 | 0 | storage_distance_computer(storage)); |
370 | |
|
371 | 0 | #pragma omp for |
372 | 0 | for (idx_t i = 0; i < ntotal; i++) { |
373 | 0 | size_t begin, end; |
374 | 0 | hnsw.neighbor_range(i, 0, &begin, &end); |
375 | |
|
376 | 0 | std::priority_queue<NodeDistFarther> initial_list; |
377 | |
|
378 | 0 | for (size_t j = begin; j < end; j++) { |
379 | 0 | int v1 = hnsw.neighbors[j]; |
380 | 0 | if (v1 < 0) |
381 | 0 | break; |
382 | 0 | initial_list.emplace(dis->symmetric_dis(i, v1), v1); |
383 | | |
384 | | // initial_list.emplace(qdis(v1), v1); |
385 | 0 | } |
386 | |
|
387 | 0 | std::vector<NodeDistFarther> shrunk_list; |
388 | 0 | HNSW::shrink_neighbor_list( |
389 | 0 | *dis, initial_list, shrunk_list, new_size); |
390 | |
|
391 | 0 | for (size_t j = begin; j < end; j++) { |
392 | 0 | if (j - begin < shrunk_list.size()) |
393 | 0 | hnsw.neighbors[j] = shrunk_list[j - begin].id; |
394 | 0 | else |
395 | 0 | hnsw.neighbors[j] = -1; |
396 | 0 | } |
397 | 0 | } |
398 | 0 | } |
399 | 0 | } |
400 | | |
401 | | void IndexHNSW::search_level_0( |
402 | | idx_t n, |
403 | | const float* x, |
404 | | idx_t k, |
405 | | const storage_idx_t* nearest, |
406 | | const float* nearest_d, |
407 | | float* distances, |
408 | | idx_t* labels, |
409 | | int nprobe, |
410 | | int search_type, |
411 | 0 | const SearchParameters* params) const { |
412 | 0 | FAISS_THROW_IF_NOT(k > 0); |
413 | 0 | FAISS_THROW_IF_NOT(nprobe > 0); |
414 | | |
415 | 0 | storage_idx_t ntotal = hnsw.levels.size(); |
416 | |
|
417 | 0 | using RH = HeapBlockResultHandler<HNSW::C>; |
418 | 0 | RH bres(n, distances, labels, k); |
419 | |
|
420 | 0 | #pragma omp parallel |
421 | 0 | { |
422 | 0 | std::unique_ptr<DistanceComputer> qdis( |
423 | 0 | storage_distance_computer(storage)); |
424 | 0 | HNSWStats search_stats; |
425 | 0 | VisitedTable vt(ntotal); |
426 | 0 | RH::SingleResultHandler res(bres); |
427 | |
|
428 | 0 | #pragma omp for |
429 | 0 | for (idx_t i = 0; i < n; i++) { |
430 | 0 | res.begin(i); |
431 | 0 | qdis->set_query(x + i * d); |
432 | |
|
433 | 0 | hnsw.search_level_0( |
434 | 0 | *qdis.get(), |
435 | 0 | res, |
436 | 0 | nprobe, |
437 | 0 | nearest + i * nprobe, |
438 | 0 | nearest_d + i * nprobe, |
439 | 0 | search_type, |
440 | 0 | search_stats, |
441 | 0 | vt, |
442 | 0 | params); |
443 | 0 | res.end(); |
444 | 0 | vt.advance(); |
445 | 0 | } |
446 | 0 | #pragma omp critical |
447 | 0 | { hnsw_stats.combine(search_stats); } |
448 | 0 | } |
449 | 0 | if (is_similarity_metric(this->metric_type)) { |
450 | | // we need to revert the negated distances |
451 | 0 | #pragma omp parallel for |
452 | 0 | for (int64_t i = 0; i < k * n; i++) { |
453 | 0 | distances[i] = -distances[i]; |
454 | 0 | } |
455 | 0 | } |
456 | 0 | } |
457 | | |
458 | | void IndexHNSW::init_level_0_from_knngraph( |
459 | | int k, |
460 | | const float* D, |
461 | 0 | const idx_t* I) { |
462 | 0 | int dest_size = hnsw.nb_neighbors(0); |
463 | |
|
464 | 0 | #pragma omp parallel for |
465 | 0 | for (idx_t i = 0; i < ntotal; i++) { |
466 | 0 | DistanceComputer* qdis = storage_distance_computer(storage); |
467 | 0 | std::vector<float> vec(d); |
468 | 0 | storage->reconstruct(i, vec.data()); |
469 | 0 | qdis->set_query(vec.data()); |
470 | |
|
471 | 0 | std::priority_queue<NodeDistFarther> initial_list; |
472 | |
|
473 | 0 | for (size_t j = 0; j < k; j++) { |
474 | 0 | int v1 = I[i * k + j]; |
475 | 0 | if (v1 == i) |
476 | 0 | continue; |
477 | 0 | if (v1 < 0) |
478 | 0 | break; |
479 | 0 | initial_list.emplace(D[i * k + j], v1); |
480 | 0 | } |
481 | |
|
482 | 0 | std::vector<NodeDistFarther> shrunk_list; |
483 | 0 | HNSW::shrink_neighbor_list(*qdis, initial_list, shrunk_list, dest_size); |
484 | |
|
485 | 0 | size_t begin, end; |
486 | 0 | hnsw.neighbor_range(i, 0, &begin, &end); |
487 | |
|
488 | 0 | for (size_t j = begin; j < end; j++) { |
489 | 0 | if (j - begin < shrunk_list.size()) |
490 | 0 | hnsw.neighbors[j] = shrunk_list[j - begin].id; |
491 | 0 | else |
492 | 0 | hnsw.neighbors[j] = -1; |
493 | 0 | } |
494 | 0 | } |
495 | 0 | } |
496 | | |
497 | | void IndexHNSW::init_level_0_from_entry_points( |
498 | | int n, |
499 | | const storage_idx_t* points, |
500 | 0 | const storage_idx_t* nearests) { |
501 | 0 | std::vector<omp_lock_t> locks(ntotal); |
502 | 0 | for (int i = 0; i < ntotal; i++) |
503 | 0 | omp_init_lock(&locks[i]); |
504 | |
|
505 | 0 | #pragma omp parallel |
506 | 0 | { |
507 | 0 | VisitedTable vt(ntotal); |
508 | |
|
509 | 0 | std::unique_ptr<DistanceComputer> dis( |
510 | 0 | storage_distance_computer(storage)); |
511 | 0 | std::vector<float> vec(storage->d); |
512 | |
|
513 | 0 | #pragma omp for schedule(dynamic) |
514 | 0 | for (int i = 0; i < n; i++) { |
515 | 0 | storage_idx_t pt_id = points[i]; |
516 | 0 | storage_idx_t nearest = nearests[i]; |
517 | 0 | storage->reconstruct(pt_id, vec.data()); |
518 | 0 | dis->set_query(vec.data()); |
519 | |
|
520 | 0 | hnsw.add_links_starting_from( |
521 | 0 | *dis, pt_id, nearest, (*dis)(nearest), 0, locks.data(), vt); |
522 | |
|
523 | 0 | if (verbose && i % 10000 == 0) { |
524 | 0 | printf(" %d / %d\r", i, n); |
525 | 0 | fflush(stdout); |
526 | 0 | } |
527 | 0 | } |
528 | 0 | } |
529 | 0 | if (verbose) { |
530 | 0 | printf("\n"); |
531 | 0 | } |
532 | |
|
533 | 0 | for (int i = 0; i < ntotal; i++) |
534 | 0 | omp_destroy_lock(&locks[i]); |
535 | 0 | } |
536 | | |
537 | 0 | void IndexHNSW::reorder_links() { |
538 | 0 | int M = hnsw.nb_neighbors(0); |
539 | |
|
540 | 0 | #pragma omp parallel |
541 | 0 | { |
542 | 0 | std::vector<float> distances(M); |
543 | 0 | std::vector<size_t> order(M); |
544 | 0 | std::vector<storage_idx_t> tmp(M); |
545 | 0 | std::unique_ptr<DistanceComputer> dis( |
546 | 0 | storage_distance_computer(storage)); |
547 | |
|
548 | 0 | #pragma omp for |
549 | 0 | for (storage_idx_t i = 0; i < ntotal; i++) { |
550 | 0 | size_t begin, end; |
551 | 0 | hnsw.neighbor_range(i, 0, &begin, &end); |
552 | |
|
553 | 0 | for (size_t j = begin; j < end; j++) { |
554 | 0 | storage_idx_t nj = hnsw.neighbors[j]; |
555 | 0 | if (nj < 0) { |
556 | 0 | end = j; |
557 | 0 | break; |
558 | 0 | } |
559 | 0 | distances[j - begin] = dis->symmetric_dis(i, nj); |
560 | 0 | tmp[j - begin] = nj; |
561 | 0 | } |
562 | |
|
563 | 0 | fvec_argsort(end - begin, distances.data(), order.data()); |
564 | 0 | for (size_t j = begin; j < end; j++) { |
565 | 0 | hnsw.neighbors[j] = tmp[order[j - begin]]; |
566 | 0 | } |
567 | 0 | } |
568 | 0 | } |
569 | 0 | } |
570 | | |
571 | 0 | void IndexHNSW::link_singletons() { |
572 | 0 | printf("search for singletons\n"); |
573 | |
|
574 | 0 | std::vector<bool> seen(ntotal); |
575 | |
|
576 | 0 | for (size_t i = 0; i < ntotal; i++) { |
577 | 0 | size_t begin, end; |
578 | 0 | hnsw.neighbor_range(i, 0, &begin, &end); |
579 | 0 | for (size_t j = begin; j < end; j++) { |
580 | 0 | storage_idx_t ni = hnsw.neighbors[j]; |
581 | 0 | if (ni >= 0) |
582 | 0 | seen[ni] = true; |
583 | 0 | } |
584 | 0 | } |
585 | |
|
586 | 0 | int n_sing = 0, n_sing_l1 = 0; |
587 | 0 | std::vector<storage_idx_t> singletons; |
588 | 0 | for (storage_idx_t i = 0; i < ntotal; i++) { |
589 | 0 | if (!seen[i]) { |
590 | 0 | singletons.push_back(i); |
591 | 0 | n_sing++; |
592 | 0 | if (hnsw.levels[i] > 1) |
593 | 0 | n_sing_l1++; |
594 | 0 | } |
595 | 0 | } |
596 | |
|
597 | 0 | printf(" Found %d / %" PRId64 " singletons (%d appear in a level above)\n", |
598 | 0 | n_sing, |
599 | 0 | ntotal, |
600 | 0 | n_sing_l1); |
601 | |
|
602 | 0 | std::vector<float> recons(singletons.size() * d); |
603 | 0 | for (int i = 0; i < singletons.size(); i++) { |
604 | 0 | FAISS_ASSERT(!"not implemented"); |
605 | 0 | } |
606 | 0 | } |
607 | | |
608 | 0 | void IndexHNSW::permute_entries(const idx_t* perm) { |
609 | 0 | auto flat_storage = dynamic_cast<IndexFlatCodes*>(storage); |
610 | 0 | FAISS_THROW_IF_NOT_MSG( |
611 | 0 | flat_storage, "don't know how to permute this index"); |
612 | 0 | flat_storage->permute_entries(perm); |
613 | 0 | hnsw.permute_entries(perm); |
614 | 0 | } |
615 | | |
616 | 0 | DistanceComputer* IndexHNSW::get_distance_computer() const { |
617 | 0 | return storage->get_distance_computer(); |
618 | 0 | } |
619 | | |
620 | | /************************************************************** |
621 | | * IndexHNSWFlat implementation |
622 | | **************************************************************/ |
623 | | |
624 | 59 | IndexHNSWFlat::IndexHNSWFlat() { |
625 | 59 | is_trained = true; |
626 | 59 | } |
627 | | |
628 | | IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric) |
629 | 133 | : IndexHNSW( |
630 | 133 | (metric == METRIC_L2) ? new IndexFlatL2(d) |
631 | 133 | : new IndexFlat(d, metric), |
632 | 133 | M) { |
633 | 133 | own_fields = true; |
634 | 133 | is_trained = true; |
635 | 133 | } |
636 | | |
637 | | /************************************************************** |
638 | | * IndexHNSWPQ implementation |
639 | | **************************************************************/ |
640 | | |
641 | 0 | IndexHNSWPQ::IndexHNSWPQ() = default; |
642 | | |
643 | | IndexHNSWPQ::IndexHNSWPQ( |
644 | | int d, |
645 | | int pq_m, |
646 | | int M, |
647 | | int pq_nbits, |
648 | | MetricType metric) |
649 | 0 | : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) { |
650 | 0 | own_fields = true; |
651 | 0 | is_trained = false; |
652 | 0 | } |
653 | | |
654 | 0 | void IndexHNSWPQ::train(idx_t n, const float* x) { |
655 | 0 | IndexHNSW::train(n, x); |
656 | 0 | (dynamic_cast<IndexPQ*>(storage))->pq.compute_sdc_table(); |
657 | 0 | } |
658 | | |
659 | | /************************************************************** |
660 | | * IndexHNSWSQ implementation |
661 | | **************************************************************/ |
662 | | |
663 | | IndexHNSWSQ::IndexHNSWSQ( |
664 | | int d, |
665 | | ScalarQuantizer::QuantizerType qtype, |
666 | | int M, |
667 | | MetricType metric) |
668 | 2 | : IndexHNSW(new IndexScalarQuantizer(d, qtype, metric), M) { |
669 | 2 | is_trained = this->storage->is_trained; |
670 | 2 | own_fields = true; |
671 | 2 | } |
672 | | |
673 | 0 | IndexHNSWSQ::IndexHNSWSQ() = default; |
674 | | |
675 | | /************************************************************** |
676 | | * IndexHNSW2Level implementation |
677 | | **************************************************************/ |
678 | | |
679 | | IndexHNSW2Level::IndexHNSW2Level( |
680 | | Index* quantizer, |
681 | | size_t nlist, |
682 | | int m_pq, |
683 | | int M) |
684 | 0 | : IndexHNSW(new Index2Layer(quantizer, nlist, m_pq), M) { |
685 | 0 | own_fields = true; |
686 | 0 | is_trained = false; |
687 | 0 | } |
688 | | |
689 | 0 | IndexHNSW2Level::IndexHNSW2Level() = default; |
690 | | |
691 | | namespace { |
692 | | |
693 | | // same as search_from_candidates but uses v |
694 | | // visno -> is in result list |
695 | | // visno + 1 -> in result list + in candidates |
696 | | int search_from_candidates_2( |
697 | | const HNSW& hnsw, |
698 | | DistanceComputer& qdis, |
699 | | int k, |
700 | | idx_t* I, |
701 | | float* D, |
702 | | MinimaxHeap& candidates, |
703 | | VisitedTable& vt, |
704 | | HNSWStats& stats, |
705 | | int level, |
706 | 0 | int nres_in = 0) { |
707 | 0 | int nres = nres_in; |
708 | 0 | for (int i = 0; i < candidates.size(); i++) { |
709 | 0 | idx_t v1 = candidates.ids[i]; |
710 | 0 | FAISS_ASSERT(v1 >= 0); |
711 | 0 | vt.visited[v1] = vt.visno + 1; |
712 | 0 | } |
713 | | |
714 | 0 | int nstep = 0; |
715 | |
|
716 | 0 | while (candidates.size() > 0) { |
717 | 0 | float d0 = 0; |
718 | 0 | int v0 = candidates.pop_min(&d0); |
719 | |
|
720 | 0 | size_t begin, end; |
721 | 0 | hnsw.neighbor_range(v0, level, &begin, &end); |
722 | |
|
723 | 0 | for (size_t j = begin; j < end; j++) { |
724 | 0 | int v1 = hnsw.neighbors[j]; |
725 | 0 | if (v1 < 0) |
726 | 0 | break; |
727 | 0 | if (vt.visited[v1] == vt.visno + 1) { |
728 | | // nothing to do |
729 | 0 | } else { |
730 | 0 | float d = qdis(v1); |
731 | 0 | candidates.push(v1, d); |
732 | | |
733 | | // never seen before --> add to heap |
734 | 0 | if (vt.visited[v1] < vt.visno) { |
735 | 0 | if (nres < k) { |
736 | 0 | faiss::maxheap_push(++nres, D, I, d, v1); |
737 | 0 | } else if (d < D[0]) { |
738 | 0 | faiss::maxheap_replace_top(nres, D, I, d, v1); |
739 | 0 | } |
740 | 0 | } |
741 | 0 | vt.visited[v1] = vt.visno + 1; |
742 | 0 | } |
743 | 0 | } |
744 | |
|
745 | 0 | nstep++; |
746 | 0 | if (nstep > hnsw.efSearch) { |
747 | 0 | break; |
748 | 0 | } |
749 | 0 | } |
750 | |
|
751 | 0 | stats.n1++; |
752 | 0 | if (candidates.size() == 0) |
753 | 0 | stats.n2++; |
754 | |
|
755 | 0 | return nres; |
756 | 0 | } |
757 | | |
758 | | } // namespace |
759 | | |
760 | | void IndexHNSW2Level::search( |
761 | | idx_t n, |
762 | | const float* x, |
763 | | idx_t k, |
764 | | float* distances, |
765 | | idx_t* labels, |
766 | 0 | const SearchParameters* params) const { |
767 | 0 | FAISS_THROW_IF_NOT(k > 0); |
768 | 0 | FAISS_THROW_IF_NOT_MSG( |
769 | 0 | !params, "search params not supported for this index"); |
770 | | |
771 | 0 | if (dynamic_cast<const Index2Layer*>(storage)) { |
772 | 0 | IndexHNSW::search(n, x, k, distances, labels); |
773 | |
|
774 | 0 | } else { // "mixed" search |
775 | 0 | size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; |
776 | |
|
777 | 0 | const IndexIVFPQ* index_ivfpq = |
778 | 0 | dynamic_cast<const IndexIVFPQ*>(storage); |
779 | |
|
780 | 0 | int nprobe = index_ivfpq->nprobe; |
781 | |
|
782 | 0 | std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]); |
783 | 0 | std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]); |
784 | |
|
785 | 0 | index_ivfpq->quantizer->search( |
786 | 0 | n, x, nprobe, coarse_dis.get(), coarse_assign.get()); |
787 | |
|
788 | 0 | index_ivfpq->search_preassigned( |
789 | 0 | n, |
790 | 0 | x, |
791 | 0 | k, |
792 | 0 | coarse_assign.get(), |
793 | 0 | coarse_dis.get(), |
794 | 0 | distances, |
795 | 0 | labels, |
796 | 0 | false); |
797 | |
|
798 | 0 | #pragma omp parallel |
799 | 0 | { |
800 | 0 | VisitedTable vt(ntotal); |
801 | 0 | std::unique_ptr<DistanceComputer> dis( |
802 | 0 | storage_distance_computer(storage)); |
803 | |
|
804 | 0 | constexpr int candidates_size = 1; |
805 | 0 | MinimaxHeap candidates(candidates_size); |
806 | |
|
807 | 0 | #pragma omp for reduction(+ : n1, n2, ndis, nhops) |
808 | 0 | for (idx_t i = 0; i < n; i++) { |
809 | 0 | idx_t* idxi = labels + i * k; |
810 | 0 | float* simi = distances + i * k; |
811 | 0 | dis->set_query(x + i * d); |
812 | | |
813 | | // mark all inverted list elements as visited |
814 | |
|
815 | 0 | for (int j = 0; j < nprobe; j++) { |
816 | 0 | idx_t key = coarse_assign[j + i * nprobe]; |
817 | 0 | if (key < 0) |
818 | 0 | break; |
819 | 0 | size_t list_length = index_ivfpq->get_list_size(key); |
820 | 0 | const idx_t* ids = index_ivfpq->invlists->get_ids(key); |
821 | |
|
822 | 0 | for (int jj = 0; jj < list_length; jj++) { |
823 | 0 | vt.set(ids[jj]); |
824 | 0 | } |
825 | 0 | } |
826 | |
|
827 | 0 | candidates.clear(); |
828 | |
|
829 | 0 | for (int j = 0; j < k; j++) { |
830 | 0 | if (idxi[j] < 0) |
831 | 0 | break; |
832 | 0 | candidates.push(idxi[j], simi[j]); |
833 | 0 | } |
834 | | |
835 | | // reorder from sorted to heap |
836 | 0 | maxheap_heapify(k, simi, idxi, simi, idxi, k); |
837 | |
|
838 | 0 | HNSWStats search_stats; |
839 | 0 | search_from_candidates_2( |
840 | 0 | hnsw, |
841 | 0 | *dis, |
842 | 0 | k, |
843 | 0 | idxi, |
844 | 0 | simi, |
845 | 0 | candidates, |
846 | 0 | vt, |
847 | 0 | search_stats, |
848 | 0 | 0, |
849 | 0 | k); |
850 | 0 | n1 += search_stats.n1; |
851 | 0 | n2 += search_stats.n2; |
852 | 0 | ndis += search_stats.ndis; |
853 | 0 | nhops += search_stats.nhops; |
854 | |
|
855 | 0 | vt.advance(); |
856 | 0 | vt.advance(); |
857 | |
|
858 | 0 | maxheap_reorder(k, simi, idxi); |
859 | 0 | } |
860 | 0 | } |
861 | |
|
862 | 0 | hnsw_stats.combine({n1, n2, ndis, nhops}); |
863 | 0 | } |
864 | 0 | } |
865 | | |
866 | 0 | void IndexHNSW2Level::flip_to_ivf() { |
867 | 0 | Index2Layer* storage2l = dynamic_cast<Index2Layer*>(storage); |
868 | |
|
869 | 0 | FAISS_THROW_IF_NOT(storage2l); |
870 | | |
871 | 0 | IndexIVFPQ* index_ivfpq = new IndexIVFPQ( |
872 | 0 | storage2l->q1.quantizer, |
873 | 0 | d, |
874 | 0 | storage2l->q1.nlist, |
875 | 0 | storage2l->pq.M, |
876 | 0 | 8); |
877 | 0 | index_ivfpq->pq = storage2l->pq; |
878 | 0 | index_ivfpq->is_trained = storage2l->is_trained; |
879 | 0 | index_ivfpq->precompute_table(); |
880 | 0 | index_ivfpq->own_fields = storage2l->q1.own_fields; |
881 | 0 | storage2l->transfer_to_IVFPQ(*index_ivfpq); |
882 | 0 | index_ivfpq->make_direct_map(true); |
883 | |
|
884 | 0 | storage = index_ivfpq; |
885 | 0 | delete storage2l; |
886 | 0 | } |
887 | | |
888 | | /************************************************************** |
889 | | * IndexHNSWCagra implementation |
890 | | **************************************************************/ |
891 | | |
892 | 0 | IndexHNSWCagra::IndexHNSWCagra() { |
893 | 0 | is_trained = true; |
894 | 0 | } |
895 | | |
896 | | IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric) |
897 | 0 | : IndexHNSW( |
898 | 0 | (metric == METRIC_L2) |
899 | 0 | ? static_cast<IndexFlat*>(new IndexFlatL2(d)) |
900 | 0 | : static_cast<IndexFlat*>(new IndexFlatIP(d)), |
901 | 0 | M) { |
902 | 0 | FAISS_THROW_IF_NOT_MSG( |
903 | 0 | ((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)), |
904 | 0 | "unsupported metric type for IndexHNSWCagra"); |
905 | 0 | own_fields = true; |
906 | 0 | is_trained = true; |
907 | 0 | init_level0 = true; |
908 | 0 | keep_max_size_level0 = true; |
909 | 0 | } |
910 | | |
/** Add n vectors of dimension d to the index.
 *
 * Forbidden when base_level_only is set: in that mode only the level-0
 * graph exists, and IndexHNSW::add would need the upper levels.
 *
 * @param n number of vectors to add
 * @param x vectors, size n * d
 */
void IndexHNSWCagra::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(
            !base_level_only,
            "Cannot add vectors when base_level_only is set to True");

    IndexHNSW::add(n, x);
}
918 | | |
/** Search the index.
 *
 * If base_level_only is false, this is a plain IndexHNSW::search. In
 * base-level-only mode there is no upper hierarchy to descend, so for
 * each query an entry point is picked by sampling
 * num_base_level_search_entrypoints random vertices and keeping the
 * closest, then search_level_0 runs the level-0 graph traversal.
 *
 * @param n          number of queries
 * @param x          query vectors, size n * d
 * @param k          number of results per query
 * @param distances  output distances, size n * k
 * @param labels     output labels, size n * k
 * @param params     forwarded to the underlying search
 */
void IndexHNSWCagra::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    if (!base_level_only) {
        IndexHNSW::search(n, x, k, distances, labels, params);
    } else {
        // per-query best entry point and its distance
        std::vector<storage_idx_t> nearest(n);
        std::vector<float> nearest_d(n);

        // NOTE(review): this `omp for` has no enclosing `omp parallel`
        // here, so unless the caller is already inside a parallel
        // region it executes sequentially — confirm whether
        // `parallel for` was intended.
#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            std::unique_ptr<DistanceComputer> dis(
                    storage_distance_computer(this->storage));
            dis->set_query(x + i * d);
            nearest[i] = -1;
            nearest_d[i] = std::numeric_limits<float>::max();

            // NOTE(review): a fresh random_device/mt19937 is built per
            // query; random_device construction can be costly — verify
            // whether hoisting (with per-thread state) is acceptable.
            std::random_device rd;
            std::mt19937 gen(rd());
            std::uniform_int_distribution<idx_t> distrib(0, this->ntotal - 1);

            // sample candidate entry points and keep the closest one
            for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) {
                auto idx = distrib(gen);
                auto distance = (*dis)(idx);
                if (distance < nearest_d[i]) {
                    nearest[i] = idx;
                    nearest_d[i] = distance;
                }
            }
            FAISS_THROW_IF_NOT_MSG(
                    nearest[i] >= 0, "Could not find a valid entrypoint.");
        }

        // level-0 graph traversal from the chosen entry points
        search_level_0(
                n,
                x,
                k,
                nearest.data(),
                nearest_d.data(),
                distances,
                labels,
                1, // n_probes
                1, // search_type
                params);
    }
}
969 | | |
970 | | } // namespace faiss |