/root/doris/contrib/faiss/faiss/IndexPQ.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #ifndef FAISS_INDEX_PQ_H |
9 | | #define FAISS_INDEX_PQ_H |
10 | | |
11 | | #include <stdint.h> |
12 | | |
13 | | #include <vector> |
14 | | |
15 | | #include <faiss/IndexFlatCodes.h> |
16 | | #include <faiss/impl/PolysemousTraining.h> |
17 | | #include <faiss/impl/ProductQuantizer.h> |
18 | | #include <faiss/impl/platform_macros.h> |
19 | | |
20 | | namespace faiss { |
21 | | |
22 | | /** Index based on a product quantizer. Stored vectors are |
23 | | * approximated by PQ codes. */ |
24 | | struct IndexPQ : IndexFlatCodes { |
25 | | /// The product quantizer used to encode the vectors |
26 | | ProductQuantizer pq; |
27 | | |
28 | | /** Constructor. |
29 | | * |
30 | | * @param d dimensionality of the input vectors |
31 | | * @param M number of subquantizers |
32 | | * @param nbits number of bit per subvector index |
33 | | */ |
34 | | IndexPQ(int d, size_t M, size_t nbits, MetricType metric = METRIC_L2); |
35 | | |
36 | | IndexPQ(); |
37 | | |
38 | | void train(idx_t n, const float* x) override; |
39 | | |
40 | | void search( |
41 | | idx_t n, |
42 | | const float* x, |
43 | | idx_t k, |
44 | | float* distances, |
45 | | idx_t* labels, |
46 | | const SearchParameters* params = nullptr) const override; |
47 | | |
48 | | /* The standalone codec interface */ |
49 | | void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override; |
50 | | |
51 | | void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; |
52 | | |
53 | | FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override; |
54 | | |
55 | | /****************************************************** |
56 | | * Polysemous codes implementation |
57 | | ******************************************************/ |
58 | | bool do_polysemous_training; ///< false = standard PQ |
59 | | |
60 | | /// parameters used for the polysemous training |
61 | | PolysemousTraining polysemous_training; |
62 | | |
63 | | /// how to perform the search in search_core |
64 | | enum Search_type_t { |
65 | | ST_PQ, ///< asymmetric product quantizer (default) |
66 | | ST_HE, ///< Hamming distance on codes |
67 | | ST_generalized_HE, ///< nb of same codes |
68 | | ST_SDC, ///< symmetric product quantizer (SDC) |
69 | | ST_polysemous, ///< HE filter (using ht) + PQ combination |
70 | | ST_polysemous_generalize, ///< Filter on generalized Hamming |
71 | | }; |
72 | | |
73 | | Search_type_t search_type; |
74 | | |
75 | | // just encode the sign of the components, instead of using the PQ encoder |
76 | | // used only for the queries |
77 | | bool encode_signs; |
78 | | |
79 | | /// Hamming threshold used for polysemy |
80 | | int polysemous_ht; |
81 | | |
82 | | // actual polysemous search |
83 | | void search_core_polysemous( |
84 | | idx_t n, |
85 | | const float* x, |
86 | | idx_t k, |
87 | | float* distances, |
88 | | idx_t* labels, |
89 | | int polysemous_ht, |
90 | | bool generalized_hamming) const; |
91 | | |
92 | | /// prepare query for a polysemous search, but instead of |
93 | | /// computing the result, just get the histogram of Hamming |
94 | | /// distances. May be computed on a provided dataset if xb != NULL |
95 | | /// @param dist_histogram (M * nbits + 1) |
96 | | void hamming_distance_histogram( |
97 | | idx_t n, |
98 | | const float* x, |
99 | | idx_t nb, |
100 | | const float* xb, |
101 | | int64_t* dist_histogram); |
102 | | |
103 | | /** compute pairwise distances between queries and database |
104 | | * |
105 | | * @param n nb of query vectors |
106 | | * @param x query vector, size n * d |
107 | | * @param dis output distances, size n * ntotal |
108 | | */ |
109 | | void hamming_distance_table(idx_t n, const float* x, int32_t* dis) const; |
110 | | }; |
111 | | |
112 | | /// override search parameters from the class |
113 | | struct SearchParametersPQ : SearchParameters { |
114 | | IndexPQ::Search_type_t search_type; |
115 | | int polysemous_ht; |
116 | | }; |
117 | | |
118 | | /// statistics are robust to internal threading, but not if |
119 | | /// IndexPQ::search is called by multiple threads |
120 | | struct IndexPQStats { |
121 | | size_t nq; // nb of queries run |
122 | | size_t ncode; // nb of codes visited |
123 | | |
124 | | size_t n_hamming_pass; // nb of passed Hamming distance tests (for polysemy) |
125 | | |
126 | 8 | IndexPQStats() { |
127 | 8 | reset(); |
128 | 8 | } |
129 | | void reset(); |
130 | | }; |
131 | | |
132 | | FAISS_API extern IndexPQStats indexPQ_stats; |
133 | | |
134 | | /** Quantizer where centroids are virtual: they are the Cartesian |
135 | | * product of sub-centroids. */ |
136 | | struct MultiIndexQuantizer : Index { |
137 | | ProductQuantizer pq; |
138 | | |
139 | | MultiIndexQuantizer( |
140 | | int d, ///< dimension of the input vectors |
141 | | size_t M, ///< number of subquantizers |
142 | | size_t nbits); ///< number of bit per subvector index |
143 | | |
144 | | void train(idx_t n, const float* x) override; |
145 | | |
146 | | void search( |
147 | | idx_t n, |
148 | | const float* x, |
149 | | idx_t k, |
150 | | float* distances, |
151 | | idx_t* labels, |
152 | | const SearchParameters* params = nullptr) const override; |
153 | | |
154 | | /// add and reset will crash at runtime |
155 | | void add(idx_t n, const float* x) override; |
156 | | void reset() override; |
157 | | |
158 | 0 | MultiIndexQuantizer() {} |
159 | | |
160 | | void reconstruct(idx_t key, float* recons) const override; |
161 | | }; |
162 | | |
163 | | // block size used in MultiIndexQuantizer::search |
164 | | FAISS_API extern int multi_index_quantizer_search_bs; |
165 | | |
166 | | /** MultiIndexQuantizer where the PQ assignmnet is performed by sub-indexes |
167 | | */ |
168 | | struct MultiIndexQuantizer2 : MultiIndexQuantizer { |
169 | | /// M Indexes on d / M dimensions |
170 | | std::vector<Index*> assign_indexes; |
171 | | bool own_fields; |
172 | | |
173 | | MultiIndexQuantizer2(int d, size_t M, size_t nbits, Index** indexes); |
174 | | |
175 | | MultiIndexQuantizer2( |
176 | | int d, |
177 | | size_t nbits, |
178 | | Index* assign_index_0, |
179 | | Index* assign_index_1); |
180 | | |
181 | | void train(idx_t n, const float* x) override; |
182 | | |
183 | | void search( |
184 | | idx_t n, |
185 | | const float* x, |
186 | | idx_t k, |
187 | | float* distances, |
188 | | idx_t* labels, |
189 | | const SearchParameters* params = nullptr) const override; |
190 | | }; |
191 | | |
192 | | } // namespace faiss |
193 | | |
194 | | #endif |