/root/doris/contrib/faiss/faiss/IndexIVFFastScan.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
10 | | #include <memory> |
11 | | |
12 | | #include <faiss/IndexIVF.h> |
13 | | #include <faiss/utils/AlignedTable.h> |
14 | | |
15 | | namespace faiss { |
16 | | |
17 | | struct NormTableScaler; /* forward declaration: this header only uses it as const NormTableScaler* */ |
18 | | struct SIMDResultHandlerToFloat; /* forward declaration: this header only uses it as a reference parameter */ |
19 | | struct Quantizer; /* forward declaration: this header only uses it as Quantizer* */ |
20 | | |
21 | | /** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now. |
22 | | * |
23 | | * The codes in the inverted lists are not stored sequentially but |
24 | | * grouped in blocks of size bbs. This makes it possible to very quickly |
25 | | * compute distances with SIMD instructions. |
26 | | * |
27 | | * Implementations (implem): |
28 | | * 0: auto-select implementation (default) |
29 | | * 1: orig's search, re-implemented |
30 | | * 2: orig's search, re-ordered by invlist |
31 | | * 10: optimized int16 search, collect results in heap, no qbs |
32 | | * 11: idem, collect results in reservoir |
33 | | * 12: optimized int16 search, collect results in heap, uses qbs |
34 | | * 13: idem, collect results in reservoir |
35 | | * 14: internally multithreaded implem over nq * nprobe |
36 | | * 15: same with reservoir |
37 | | * |
38 | | * For range search, only implem 10 and 12 are supported. |
39 | | * Add 100 to implem to force single-threaded scanning (the coarse quantizer |
40 | | * may still use multiple threads). |
41 | | */ |
42 | | |
43 | | struct IndexIVFFastScan : IndexIVF { |
44 | | // size of the kernel (codes are grouped in blocks of size bbs) |
45 | | int bbs = 0; // set at build time; init_fastscan() takes it as an argument |
46 | | |
47 | | size_t M = 0; // zero until assigned; init_fastscan() takes M as an argument |
48 | | size_t nbits = 0; // zero until assigned; init_fastscan() takes nbits as an argument |
49 | | size_t ksub = 0; // zero until assigned (presumably derived from nbits -- TODO confirm) |
50 | | |
51 | | // M rounded up to a multiple of 2 (zero until assigned) |
52 | | size_t M2 = 0; |
53 | | |
54 | | // search-time implementation selector (0 = auto-select) |
55 | | int implem = 0; |
56 | | // skip some parts of the computation (for timing) |
57 | | int skip = 0; |
58 | | |
59 | | // batching factors at search time (0 = default) |
60 | | int qbs = 0; |
61 | | size_t qbs2 = 0; |
62 | | |
63 | | // quantizer used to pack the codes (raw pointer, null by default) |
64 | | Quantizer* fine_quantizer = nullptr; |
65 | | |
66 | | IndexIVFFastScan( |
67 | | Index* quantizer, |
68 | | size_t d, |
69 | | size_t nlist, |
70 | | size_t code_size, |
71 | | MetricType metric = METRIC_L2); |
72 | | |
73 | | IndexIVFFastScan(); |
74 | | |
75 | | /// called by implementations to pass in the fast-scan parameters |
76 | | void init_fastscan( |
77 | | Quantizer* fine_quantizer, |
78 | | size_t M, |
79 | | size_t nbits, |
80 | | size_t nlist, |
81 | | MetricType metric, |
82 | | int bbs); |
83 | | |
84 | | // initialize the CodePacker in the InvertedLists |
85 | | void init_code_packer(); |
86 | | |
87 | | ~IndexIVFFastScan() override; |
88 | | |
89 | | /// orig's inverted lists (for debugging); null by default |
90 | | InvertedLists* orig_invlists = nullptr; |
91 | | |
92 | | void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; |
93 | | |
94 | | // prepare look-up tables |
95 | | |
96 | | virtual bool lookup_table_is_3d() const = 0; // pure virtual: subclasses must implement |
97 | | |
98 | | // compact way of conveying coarse quantization results |
99 | | struct CoarseQuantized { |
100 | | size_t nprobe = 0; |
101 | | const float* dis = nullptr; |
102 | | const idx_t* ids = nullptr; |
103 | | }; |
104 | | |
105 | | virtual void compute_LUT( |
106 | | size_t n, |
107 | | const float* x, |
108 | | const CoarseQuantized& cq, |
109 | | AlignedTable<float>& dis_tables, |
110 | | AlignedTable<float>& biases) const = 0; // pure virtual: subclasses must implement |
111 | | |
112 | | void compute_LUT_uint8( |
113 | | size_t n, |
114 | | const float* x, |
115 | | const CoarseQuantized& cq, |
116 | | AlignedTable<uint8_t>& dis_tables, |
117 | | AlignedTable<uint16_t>& biases, |
118 | | float* normalizers) const; |
119 | | |
120 | | void search( |
121 | | idx_t n, |
122 | | const float* x, |
123 | | idx_t k, |
124 | | float* distances, |
125 | | idx_t* labels, |
126 | | const SearchParameters* params = nullptr) const override; |
127 | | |
128 | | void search_preassigned( |
129 | | idx_t n, |
130 | | const float* x, |
131 | | idx_t k, |
132 | | const idx_t* assign, |
133 | | const float* centroid_dis, |
134 | | float* distances, |
135 | | idx_t* labels, |
136 | | bool store_pairs, |
137 | | const IVFSearchParameters* params = nullptr, |
138 | | IndexIVFStats* stats = nullptr) const override; |
139 | | |
140 | | void range_search( |
141 | | idx_t n, |
142 | | const float* x, |
143 | | float radius, |
144 | | RangeSearchResult* result, |
145 | | const SearchParameters* params = nullptr) const override; |
146 | | |
147 | | // internal search functions |
148 | | |
149 | | // dispatch to implementations and parallelize |
150 | | void search_dispatch_implem( |
151 | | idx_t n, |
152 | | const float* x, |
153 | | idx_t k, |
154 | | float* distances, |
155 | | idx_t* labels, |
156 | | const CoarseQuantized& cq, |
157 | | const NormTableScaler* scaler, |
158 | | const IVFSearchParameters* params = nullptr) const; |
159 | | |
160 | | void range_search_dispatch_implem( |
161 | | idx_t n, |
162 | | const float* x, |
163 | | float radius, |
164 | | RangeSearchResult& rres, |
165 | | const CoarseQuantized& cq_in, |
166 | | const NormTableScaler* scaler, |
167 | | const IVFSearchParameters* params = nullptr) const; |
168 | | |
169 | | // implem 1 and 2 are just for verification |
170 | | template <class C> |
171 | | void search_implem_1( |
172 | | idx_t n, |
173 | | const float* x, |
174 | | idx_t k, |
175 | | float* distances, |
176 | | idx_t* labels, |
177 | | const CoarseQuantized& cq, |
178 | | const NormTableScaler* scaler, |
179 | | const IVFSearchParameters* params = nullptr) const; |
180 | | |
181 | | template <class C> |
182 | | void search_implem_2( |
183 | | idx_t n, |
184 | | const float* x, |
185 | | idx_t k, |
186 | | float* distances, |
187 | | idx_t* labels, |
188 | | const CoarseQuantized& cq, |
189 | | const NormTableScaler* scaler, |
190 | | const IVFSearchParameters* params = nullptr) const; |
191 | | |
192 | | // implem 10 and 12 are not multithreaded internally, so they |
193 | | // export search stats (see the ndis_out / nlist_out pointer parameters) |
194 | | void search_implem_10( |
195 | | idx_t n, |
196 | | const float* x, |
197 | | SIMDResultHandlerToFloat& handler, |
198 | | const CoarseQuantized& cq, |
199 | | size_t* ndis_out, |
200 | | size_t* nlist_out, |
201 | | const NormTableScaler* scaler, |
202 | | const IVFSearchParameters* params = nullptr) const; |
203 | | |
204 | | void search_implem_12( |
205 | | idx_t n, |
206 | | const float* x, |
207 | | SIMDResultHandlerToFloat& handler, |
208 | | const CoarseQuantized& cq, |
209 | | size_t* ndis_out, |
210 | | size_t* nlist_out, |
211 | | const NormTableScaler* scaler, |
212 | | const IVFSearchParameters* params = nullptr) const; |
213 | | |
214 | | // implem 14 is multithreaded internally across nprobes and queries |
215 | | void search_implem_14( |
216 | | idx_t n, |
217 | | const float* x, |
218 | | idx_t k, |
219 | | float* distances, |
220 | | idx_t* labels, |
221 | | const CoarseQuantized& cq, |
222 | | int impl, |
223 | | const NormTableScaler* scaler, |
224 | | const IVFSearchParameters* params = nullptr) const; |
225 | | |
226 | | // reconstruct vectors from packed invlists |
227 | | void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons) |
228 | | const override; |
229 | | |
230 | | CodePacker* get_CodePacker() const override; |
231 | | |
232 | | // reconstruct orig invlists (for debugging) |
233 | | void reconstruct_orig_invlists(); |
234 | | |
235 | | /** Decode a set of vectors. |
236 | | * |
237 | | * NOTE: The codes in the IndexFastScan object are non-contiguous. |
238 | | * But this method requires a contiguous representation. |
239 | | * |
240 | | * @param n number of vectors |
241 | | * @param bytes input encoded vectors, size n * code_size |
242 | | * @param x output vectors, size n * d |
243 | | */ |
244 | | void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; |
245 | | }; |
246 | | |
247 | | struct IVFFastScanStats { |
248 | | uint64_t times[10]; // raw counters, read back scaled by Mcy_at() |
249 | | uint64_t t_compute_distance_tables, t_round; |
250 | | uint64_t t_copy_pack, t_scan, t_to_flat; |
251 | | uint64_t reservoir_times[4]; // raw counters, read back scaled by Mcy_reservoir_at() |
252 | | double t_aq_encode; |
253 | | double t_aq_norm_encode; |
254 | | |
255 | 0 | double Mcy_at(int i) const { // times[i] divided by 1e6; i is not range-checked |
256 | 0 | return times[i] / (1000 * 1000.0); |
257 | 0 | } |
258 | | |
259 | 0 | double Mcy_reservoir_at(int i) const { // reservoir_times[i] divided by 1e6; i is not range-checked |
260 | 0 | return reservoir_times[i] / (1000 * 1000.0); |
261 | 0 | } |
262 | 1 | IVFFastScanStats() { // every field starts zeroed |
263 | 1 | reset(); |
264 | 1 | } |
265 | 1 | void reset() { // zero every byte; all members are plain integers or doubles |
266 | 1 | memset(this, 0, sizeof(*this)); // NOTE(review): memset lives in <cstring>, which this header does not include directly |
267 | 1 | } |
268 | | }; |
269 | | |
270 | | FAISS_API extern IVFFastScanStats IVFFastScan_stats; /* declaration only: the object is defined outside this header */ |
271 | | |
272 | | } // namespace faiss |