/root/doris/contrib/faiss/faiss/IndexFlat.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #ifndef INDEX_FLAT_H |
11 | | #define INDEX_FLAT_H |
12 | | |
13 | | #include <vector> |
14 | | |
15 | | #include <faiss/IndexFlatCodes.h> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | | /** Index that stores the full vectors and performs exhaustive search */ |
20 | | struct IndexFlat : IndexFlatCodes { |
21 | | explicit IndexFlat( |
22 | | idx_t d, ///< dimensionality of the input vectors |
23 | | MetricType metric = METRIC_L2); |
24 | | |
25 | | void search( |
26 | | idx_t n, |
27 | | const float* x, |
28 | | idx_t k, |
29 | | float* distances, |
30 | | idx_t* labels, |
31 | | const SearchParameters* params = nullptr) const override; |
32 | | |
33 | | void range_search( |
34 | | idx_t n, |
35 | | const float* x, |
36 | | float radius, |
37 | | RangeSearchResult* result, |
38 | | const SearchParameters* params = nullptr) const override; |
39 | | |
40 | | void reconstruct(idx_t key, float* recons) const override; |
41 | | |
42 | | /** compute distance with a subset of vectors |
43 | | * |
44 | | * @param x query vectors, size n * d |
45 | | * @param labels indices of the vectors that should be compared |
46 | | * for each query vector, size n * k |
47 | | * @param distances |
48 | | * corresponding output distances, size n * k |
49 | | */ |
50 | | void compute_distance_subset( |
51 | | idx_t n, |
52 | | const float* x, |
53 | | idx_t k, |
54 | | float* distances, |
55 | | const idx_t* labels) const; |
56 | | |
57 | | // get pointer to the floating point data |
58 | 0 | float* get_xb() { |
59 | 0 | return (float*)codes.data(); |
60 | 0 | } |
61 | 18.0k | const float* get_xb() const { |
62 | 18.0k | return (const float*)codes.data(); |
63 | 18.0k | } |
64 | | |
65 | 99 | IndexFlat() {} |
66 | | |
67 | | FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override; |
68 | | |
69 | | /* The stanadlone codec interface (just memcopies in this case) */ |
70 | | void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override; |
71 | | |
72 | | void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; |
73 | | }; |
74 | | |
75 | | struct IndexFlatIP : IndexFlat { |
76 | 3 | explicit IndexFlatIP(idx_t d) : IndexFlat(d, METRIC_INNER_PRODUCT) {} |
77 | 17 | IndexFlatIP() {} |
78 | | }; |
79 | | |
80 | | struct IndexFlatL2 : IndexFlat { |
81 | | // Special cache for L2 norms. |
82 | | // If this cache is set, then get_distance_computer() returns |
83 | | // a special version that computes the distance using dot products |
84 | | // and l2 norms. |
85 | | std::vector<float> cached_l2norms; |
86 | | |
87 | | /** |
88 | | * @param d dimensionality of the input vectors |
89 | | */ |
90 | 104 | explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {} |
91 | 82 | IndexFlatL2() {} |
92 | | |
93 | | // override for l2 norms cache. |
94 | | FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override; |
95 | | |
96 | | // compute L2 norms |
97 | | void sync_l2norms(); |
98 | | // clear L2 norms |
99 | | void clear_l2norms(); |
100 | | }; |
101 | | |
102 | | /// optimized version for 1D "vectors". |
103 | | struct IndexFlat1D : IndexFlatL2 { |
104 | | bool continuous_update = true; ///< is the permutation updated continuously? |
105 | | |
106 | | std::vector<idx_t> perm; ///< sorted database indices |
107 | | |
108 | | explicit IndexFlat1D(bool continuous_update = true); |
109 | | |
110 | | /// if not continuous_update, call this between the last add and |
111 | | /// the first search |
112 | | void update_permutation(); |
113 | | |
114 | | void add(idx_t n, const float* x) override; |
115 | | |
116 | | void reset() override; |
117 | | |
118 | | /// Warn: the distances returned are L1 not L2 |
119 | | void search( |
120 | | idx_t n, |
121 | | const float* x, |
122 | | idx_t k, |
123 | | float* distances, |
124 | | idx_t* labels, |
125 | | const SearchParameters* params = nullptr) const override; |
126 | | }; |
127 | | |
128 | | } // namespace faiss |
129 | | |
130 | | #endif |