/root/doris/contrib/faiss/faiss/IndexFlat.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/IndexFlat.h> |
11 | | |
12 | | #include <faiss/impl/AuxIndexStructures.h> |
13 | | #include <faiss/impl/FaissAssert.h> |
14 | | #include <faiss/utils/Heap.h> |
15 | | #include <faiss/utils/distances.h> |
16 | | #include <faiss/utils/extra_distances.h> |
17 | | #include <faiss/utils/prefetch.h> |
18 | | #include <faiss/utils/sorting.h> |
19 | | #include <cstring> |
20 | | |
21 | | namespace faiss { |
22 | | |
23 | | IndexFlat::IndexFlat(idx_t d, MetricType metric) |
24 | 138 | : IndexFlatCodes(sizeof(float) * d, d, metric) {} |
25 | | |
26 | | void IndexFlat::search( |
27 | | idx_t n, |
28 | | const float* x, |
29 | | idx_t k, |
30 | | float* distances, |
31 | | idx_t* labels, |
32 | 0 | const SearchParameters* params) const { |
33 | 0 | IDSelector* sel = params ? params->sel : nullptr; |
34 | 0 | FAISS_THROW_IF_NOT(k > 0); |
35 | | |
36 | | // we see the distances and labels as heaps |
37 | 0 | if (metric_type == METRIC_INNER_PRODUCT) { |
38 | 0 | float_minheap_array_t res = {size_t(n), size_t(k), labels, distances}; |
39 | 0 | knn_inner_product(x, get_xb(), d, n, ntotal, &res, sel); |
40 | 0 | } else if (metric_type == METRIC_L2) { |
41 | 0 | float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances}; |
42 | 0 | knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel); |
43 | 0 | } else { |
44 | 0 | FAISS_THROW_IF_NOT(!sel); // TODO implement with selector |
45 | 0 | knn_extra_metrics( |
46 | 0 | x, |
47 | 0 | get_xb(), |
48 | 0 | d, |
49 | 0 | n, |
50 | 0 | ntotal, |
51 | 0 | metric_type, |
52 | 0 | metric_arg, |
53 | 0 | k, |
54 | 0 | distances, |
55 | 0 | labels); |
56 | 0 | } |
57 | 0 | } |
58 | | |
59 | | void IndexFlat::range_search( |
60 | | idx_t n, |
61 | | const float* x, |
62 | | float radius, |
63 | | RangeSearchResult* result, |
64 | 3 | const SearchParameters* params) const { |
65 | 3 | IDSelector* sel = params ? params->sel : nullptr; |
66 | | |
67 | 3 | switch (metric_type) { |
68 | 3 | case METRIC_INNER_PRODUCT: |
69 | 3 | range_search_inner_product( |
70 | 3 | x, get_xb(), d, n, ntotal, radius, result, sel); |
71 | 3 | break; |
72 | 0 | case METRIC_L2: |
73 | 0 | range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result, sel); |
74 | 0 | break; |
75 | 0 | default: |
76 | 0 | FAISS_THROW_MSG("metric type not supported"); |
77 | 3 | } |
78 | 3 | } |
79 | | |
80 | | void IndexFlat::compute_distance_subset( |
81 | | idx_t n, |
82 | | const float* x, |
83 | | idx_t k, |
84 | | float* distances, |
85 | 0 | const idx_t* labels) const { |
86 | 0 | switch (metric_type) { |
87 | 0 | case METRIC_INNER_PRODUCT: |
88 | 0 | fvec_inner_products_by_idx(distances, x, get_xb(), labels, d, n, k); |
89 | 0 | break; |
90 | 0 | case METRIC_L2: |
91 | 0 | fvec_L2sqr_by_idx(distances, x, get_xb(), labels, d, n, k); |
92 | 0 | break; |
93 | 0 | default: |
94 | 0 | FAISS_THROW_MSG("metric type not supported"); |
95 | 0 | } |
96 | 0 | } |
97 | | |
98 | | namespace { |
99 | | |
100 | | struct FlatL2Dis : FlatCodesDistanceComputer { |
101 | | size_t d; |
102 | | idx_t nb; |
103 | | const float* q; |
104 | | const float* b; |
105 | | size_t ndis; |
106 | | |
107 | 985k | float distance_to_code(const uint8_t* code) final { |
108 | 985k | ndis++; |
109 | 985k | return fvec_L2sqr(q, (float*)code, d); |
110 | 985k | } |
111 | | |
112 | 6.17M | float symmetric_dis(idx_t i, idx_t j) override { |
113 | 6.17M | return fvec_L2sqr(b + j * d, b + i * d, d); |
114 | 6.17M | } |
115 | | |
116 | | explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr) |
117 | 19.0k | : FlatCodesDistanceComputer( |
118 | 19.0k | storage.codes.data(), |
119 | 19.0k | storage.code_size), |
120 | 19.0k | d(storage.d), |
121 | 19.0k | nb(storage.ntotal), |
122 | 19.0k | q(q), |
123 | 19.0k | b(storage.get_xb()), |
124 | 19.0k | ndis(0) {} |
125 | | |
126 | 20.8k | void set_query(const float* x) override { |
127 | 20.8k | q = x; |
128 | 20.8k | } |
129 | | |
130 | | // compute four distances |
131 | | void distances_batch_4( |
132 | | const idx_t idx0, |
133 | | const idx_t idx1, |
134 | | const idx_t idx2, |
135 | | const idx_t idx3, |
136 | | float& dis0, |
137 | | float& dis1, |
138 | | float& dis2, |
139 | 1.30M | float& dis3) final override { |
140 | 1.30M | ndis += 4; |
141 | | |
142 | | // compute first, assign next |
143 | 1.30M | const float* __restrict y0 = |
144 | 1.30M | reinterpret_cast<const float*>(codes + idx0 * code_size); |
145 | 1.30M | const float* __restrict y1 = |
146 | 1.30M | reinterpret_cast<const float*>(codes + idx1 * code_size); |
147 | 1.30M | const float* __restrict y2 = |
148 | 1.30M | reinterpret_cast<const float*>(codes + idx2 * code_size); |
149 | 1.30M | const float* __restrict y3 = |
150 | 1.30M | reinterpret_cast<const float*>(codes + idx3 * code_size); |
151 | | |
152 | 1.30M | float dp0 = 0; |
153 | 1.30M | float dp1 = 0; |
154 | 1.30M | float dp2 = 0; |
155 | 1.30M | float dp3 = 0; |
156 | 1.30M | fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3); |
157 | 1.30M | dis0 = dp0; |
158 | 1.30M | dis1 = dp1; |
159 | 1.30M | dis2 = dp2; |
160 | 1.30M | dis3 = dp3; |
161 | 1.30M | } |
162 | | }; |
163 | | |
164 | | struct FlatIPDis : FlatCodesDistanceComputer { |
165 | | size_t d; |
166 | | idx_t nb; |
167 | | const float* q; |
168 | | const float* b; |
169 | | size_t ndis; |
170 | | |
171 | 1.09M | float symmetric_dis(idx_t i, idx_t j) final override { |
172 | 1.09M | return fvec_inner_product(b + j * d, b + i * d, d); |
173 | 1.09M | } |
174 | | |
175 | 196k | float distance_to_code(const uint8_t* code) final override { |
176 | 196k | ndis++; |
177 | 196k | return fvec_inner_product(q, (const float*)code, d); |
178 | 196k | } |
179 | | |
180 | | explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr) |
181 | 1.17k | : FlatCodesDistanceComputer( |
182 | 1.17k | storage.codes.data(), |
183 | 1.17k | storage.code_size), |
184 | 1.17k | d(storage.d), |
185 | 1.17k | nb(storage.ntotal), |
186 | 1.17k | q(q), |
187 | 1.17k | b(storage.get_xb()), |
188 | 1.17k | ndis(0) {} |
189 | | |
190 | 6.48k | void set_query(const float* x) override { |
191 | 6.48k | q = x; |
192 | 6.48k | } |
193 | | |
194 | | // compute four distances |
195 | | void distances_batch_4( |
196 | | const idx_t idx0, |
197 | | const idx_t idx1, |
198 | | const idx_t idx2, |
199 | | const idx_t idx3, |
200 | | float& dis0, |
201 | | float& dis1, |
202 | | float& dis2, |
203 | 271k | float& dis3) final override { |
204 | 271k | ndis += 4; |
205 | | |
206 | | // compute first, assign next |
207 | 271k | const float* __restrict y0 = |
208 | 271k | reinterpret_cast<const float*>(codes + idx0 * code_size); |
209 | 271k | const float* __restrict y1 = |
210 | 271k | reinterpret_cast<const float*>(codes + idx1 * code_size); |
211 | 271k | const float* __restrict y2 = |
212 | 271k | reinterpret_cast<const float*>(codes + idx2 * code_size); |
213 | 271k | const float* __restrict y3 = |
214 | 271k | reinterpret_cast<const float*>(codes + idx3 * code_size); |
215 | | |
216 | 271k | float dp0 = 0; |
217 | 271k | float dp1 = 0; |
218 | 271k | float dp2 = 0; |
219 | 271k | float dp3 = 0; |
220 | 271k | fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3); |
221 | 271k | dis0 = dp0; |
222 | 271k | dis1 = dp1; |
223 | 271k | dis2 = dp2; |
224 | 271k | dis3 = dp3; |
225 | 271k | } |
226 | | }; |
227 | | |
228 | | } // namespace |
229 | | |
230 | 20.2k | FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const { |
231 | 20.2k | if (metric_type == METRIC_L2) { |
232 | 19.0k | return new FlatL2Dis(*this); |
233 | 19.0k | } else if (metric_type == METRIC_INNER_PRODUCT) { |
234 | 1.17k | return new FlatIPDis(*this); |
235 | 1.17k | } else { |
236 | 0 | return get_extra_distance_computer( |
237 | 0 | d, metric_type, metric_arg, ntotal, get_xb()); |
238 | 0 | } |
239 | 20.2k | } |
240 | | |
241 | 0 | void IndexFlat::reconstruct(idx_t key, float* recons) const { |
242 | 0 | memcpy(recons, &(codes[key * code_size]), code_size); |
243 | 0 | } |
244 | | |
245 | 18.8k | void IndexFlat::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { |
246 | 18.8k | if (n > 0) { |
247 | 18.8k | memcpy(bytes, x, sizeof(float) * d * n); |
248 | 18.8k | } |
249 | 18.8k | } |
250 | | |
251 | 0 | void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { |
252 | 0 | if (n > 0) { |
253 | 0 | memcpy(x, bytes, sizeof(float) * d * n); |
254 | 0 | } |
255 | 0 | } |
256 | | |
257 | | /*************************************************** |
258 | | * IndexFlatL2 |
259 | | ***************************************************/ |
260 | | |
261 | | namespace { |
262 | | struct FlatL2WithNormsDis : FlatCodesDistanceComputer { |
263 | | size_t d; |
264 | | idx_t nb; |
265 | | const float* q; |
266 | | const float* b; |
267 | | size_t ndis; |
268 | | |
269 | | const float* l2norms; |
270 | | float query_l2norm; |
271 | | |
272 | 0 | float distance_to_code(const uint8_t* code) final override { |
273 | 0 | ndis++; |
274 | 0 | return fvec_L2sqr(q, (float*)code, d); |
275 | 0 | } |
276 | | |
277 | 0 | float operator()(const idx_t i) final override { |
278 | 0 | const float* __restrict y = |
279 | 0 | reinterpret_cast<const float*>(codes + i * code_size); |
280 | |
|
281 | 0 | prefetch_L2(l2norms + i); |
282 | 0 | const float dp0 = fvec_inner_product(q, y, d); |
283 | 0 | return query_l2norm + l2norms[i] - 2 * dp0; |
284 | 0 | } |
285 | | |
286 | 0 | float symmetric_dis(idx_t i, idx_t j) final override { |
287 | 0 | const float* __restrict yi = |
288 | 0 | reinterpret_cast<const float*>(codes + i * code_size); |
289 | 0 | const float* __restrict yj = |
290 | 0 | reinterpret_cast<const float*>(codes + j * code_size); |
291 | |
|
292 | 0 | prefetch_L2(l2norms + i); |
293 | 0 | prefetch_L2(l2norms + j); |
294 | 0 | const float dp0 = fvec_inner_product(yi, yj, d); |
295 | 0 | return l2norms[i] + l2norms[j] - 2 * dp0; |
296 | 0 | } |
297 | | |
298 | | explicit FlatL2WithNormsDis( |
299 | | const IndexFlatL2& storage, |
300 | | const float* q = nullptr) |
301 | 0 | : FlatCodesDistanceComputer( |
302 | 0 | storage.codes.data(), |
303 | 0 | storage.code_size), |
304 | 0 | d(storage.d), |
305 | 0 | nb(storage.ntotal), |
306 | 0 | q(q), |
307 | 0 | b(storage.get_xb()), |
308 | 0 | ndis(0), |
309 | 0 | l2norms(storage.cached_l2norms.data()), |
310 | 0 | query_l2norm(0) {} |
311 | | |
312 | 0 | void set_query(const float* x) override { |
313 | 0 | q = x; |
314 | 0 | query_l2norm = fvec_norm_L2sqr(q, d); |
315 | 0 | } |
316 | | |
317 | | // compute four distances |
318 | | void distances_batch_4( |
319 | | const idx_t idx0, |
320 | | const idx_t idx1, |
321 | | const idx_t idx2, |
322 | | const idx_t idx3, |
323 | | float& dis0, |
324 | | float& dis1, |
325 | | float& dis2, |
326 | 0 | float& dis3) final override { |
327 | 0 | ndis += 4; |
328 | | |
329 | | // compute first, assign next |
330 | 0 | const float* __restrict y0 = |
331 | 0 | reinterpret_cast<const float*>(codes + idx0 * code_size); |
332 | 0 | const float* __restrict y1 = |
333 | 0 | reinterpret_cast<const float*>(codes + idx1 * code_size); |
334 | 0 | const float* __restrict y2 = |
335 | 0 | reinterpret_cast<const float*>(codes + idx2 * code_size); |
336 | 0 | const float* __restrict y3 = |
337 | 0 | reinterpret_cast<const float*>(codes + idx3 * code_size); |
338 | |
|
339 | 0 | prefetch_L2(l2norms + idx0); |
340 | 0 | prefetch_L2(l2norms + idx1); |
341 | 0 | prefetch_L2(l2norms + idx2); |
342 | 0 | prefetch_L2(l2norms + idx3); |
343 | |
|
344 | 0 | float dp0 = 0; |
345 | 0 | float dp1 = 0; |
346 | 0 | float dp2 = 0; |
347 | 0 | float dp3 = 0; |
348 | 0 | fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3); |
349 | 0 | dis0 = query_l2norm + l2norms[idx0] - 2 * dp0; |
350 | 0 | dis1 = query_l2norm + l2norms[idx1] - 2 * dp1; |
351 | 0 | dis2 = query_l2norm + l2norms[idx2] - 2 * dp2; |
352 | 0 | dis3 = query_l2norm + l2norms[idx3] - 2 * dp3; |
353 | 0 | } |
354 | | }; |
355 | | |
356 | | } // namespace |
357 | | |
358 | 0 | void IndexFlatL2::sync_l2norms() { |
359 | 0 | cached_l2norms.resize(ntotal); |
360 | 0 | fvec_norms_L2sqr( |
361 | 0 | cached_l2norms.data(), |
362 | 0 | reinterpret_cast<const float*>(codes.data()), |
363 | 0 | d, |
364 | 0 | ntotal); |
365 | 0 | } |
366 | | |
367 | 0 | void IndexFlatL2::clear_l2norms() { |
368 | 0 | cached_l2norms.clear(); |
369 | 0 | cached_l2norms.shrink_to_fit(); |
370 | 0 | } |
371 | | |
372 | 19.0k | FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const { |
373 | 19.0k | if (metric_type == METRIC_L2) { |
374 | 19.0k | if (!cached_l2norms.empty()) { |
375 | 0 | return new FlatL2WithNormsDis(*this); |
376 | 0 | } |
377 | 19.0k | } |
378 | | |
379 | 19.0k | return IndexFlat::get_FlatCodesDistanceComputer(); |
380 | 19.0k | } |
381 | | |
382 | | /*************************************************** |
383 | | * IndexFlat1D |
384 | | ***************************************************/ |
385 | | |
386 | | IndexFlat1D::IndexFlat1D(bool continuous_update) |
387 | 0 | : IndexFlatL2(1), continuous_update(continuous_update) {} |
388 | | |
389 | | /// if not continuous_update, call this between the last add and |
390 | | /// the first search |
391 | 0 | void IndexFlat1D::update_permutation() { |
392 | 0 | perm.resize(ntotal); |
393 | 0 | if (ntotal < 1000000) { |
394 | 0 | fvec_argsort(ntotal, get_xb(), (size_t*)perm.data()); |
395 | 0 | } else { |
396 | 0 | fvec_argsort_parallel(ntotal, get_xb(), (size_t*)perm.data()); |
397 | 0 | } |
398 | 0 | } |
399 | | |
400 | 0 | void IndexFlat1D::add(idx_t n, const float* x) { |
401 | 0 | IndexFlatL2::add(n, x); |
402 | 0 | if (continuous_update) |
403 | 0 | update_permutation(); |
404 | 0 | } |
405 | | |
406 | 0 | void IndexFlat1D::reset() { |
407 | 0 | IndexFlatL2::reset(); |
408 | 0 | perm.clear(); |
409 | 0 | } |
410 | | |
411 | | void IndexFlat1D::search( |
412 | | idx_t n, |
413 | | const float* x, |
414 | | idx_t k, |
415 | | float* distances, |
416 | | idx_t* labels, |
417 | 0 | const SearchParameters* params) const { |
418 | 0 | FAISS_THROW_IF_NOT_MSG( |
419 | 0 | !params, "search params not supported for this index"); |
420 | 0 | FAISS_THROW_IF_NOT(k > 0); |
421 | 0 | FAISS_THROW_IF_NOT_MSG( |
422 | 0 | perm.size() == ntotal, "Call update_permutation before search"); |
423 | 0 | const float* xb = get_xb(); |
424 | |
|
425 | 0 | #pragma omp parallel for if (n > 10000) |
426 | 0 | for (idx_t i = 0; i < n; i++) { |
427 | 0 | float q = x[i]; // query |
428 | 0 | float* D = distances + i * k; |
429 | 0 | idx_t* I = labels + i * k; |
430 | | |
431 | | // binary search |
432 | 0 | idx_t i0 = 0, i1 = ntotal; |
433 | 0 | idx_t wp = 0; |
434 | |
|
435 | 0 | if (ntotal == 0) { |
436 | 0 | for (idx_t j = 0; j < k; j++) { |
437 | 0 | I[j] = -1; |
438 | 0 | D[j] = HUGE_VAL; |
439 | 0 | } |
440 | 0 | goto done; |
441 | 0 | } |
442 | | |
443 | 0 | if (xb[perm[i0]] > q) { |
444 | 0 | i1 = 0; |
445 | 0 | goto finish_right; |
446 | 0 | } |
447 | | |
448 | 0 | if (xb[perm[i1 - 1]] <= q) { |
449 | 0 | i0 = i1 - 1; |
450 | 0 | goto finish_left; |
451 | 0 | } |
452 | | |
453 | 0 | while (i0 + 1 < i1) { |
454 | 0 | idx_t imed = (i0 + i1) / 2; |
455 | 0 | if (xb[perm[imed]] <= q) |
456 | 0 | i0 = imed; |
457 | 0 | else |
458 | 0 | i1 = imed; |
459 | 0 | } |
460 | | |
461 | | // query is between xb[perm[i0]] and xb[perm[i1]] |
462 | | // expand to nearest neighs |
463 | |
|
464 | 0 | while (wp < k) { |
465 | 0 | float xleft = xb[perm[i0]]; |
466 | 0 | float xright = xb[perm[i1]]; |
467 | |
|
468 | 0 | if (q - xleft < xright - q) { |
469 | 0 | D[wp] = q - xleft; |
470 | 0 | I[wp] = perm[i0]; |
471 | 0 | i0--; |
472 | 0 | wp++; |
473 | 0 | if (i0 < 0) { |
474 | 0 | goto finish_right; |
475 | 0 | } |
476 | 0 | } else { |
477 | 0 | D[wp] = xright - q; |
478 | 0 | I[wp] = perm[i1]; |
479 | 0 | i1++; |
480 | 0 | wp++; |
481 | 0 | if (i1 >= ntotal) { |
482 | 0 | goto finish_left; |
483 | 0 | } |
484 | 0 | } |
485 | 0 | } |
486 | 0 | goto done; |
487 | | |
488 | 0 | finish_right: |
489 | | // grow to the right from i1 |
490 | 0 | while (wp < k) { |
491 | 0 | if (i1 < ntotal) { |
492 | 0 | D[wp] = xb[perm[i1]] - q; |
493 | 0 | I[wp] = perm[i1]; |
494 | 0 | i1++; |
495 | 0 | } else { |
496 | 0 | D[wp] = std::numeric_limits<float>::infinity(); |
497 | 0 | I[wp] = -1; |
498 | 0 | } |
499 | 0 | wp++; |
500 | 0 | } |
501 | 0 | goto done; |
502 | | |
503 | 0 | finish_left: |
504 | | // grow to the left from i0 |
505 | 0 | while (wp < k) { |
506 | 0 | if (i0 >= 0) { |
507 | 0 | D[wp] = q - xb[perm[i0]]; |
508 | 0 | I[wp] = perm[i0]; |
509 | 0 | i0--; |
510 | 0 | } else { |
511 | 0 | D[wp] = std::numeric_limits<float>::infinity(); |
512 | 0 | I[wp] = -1; |
513 | 0 | } |
514 | 0 | wp++; |
515 | 0 | } |
516 | 0 | done:; |
517 | 0 | } |
518 | 0 | } |
519 | | |
520 | | } // namespace faiss |