/root/doris/contrib/faiss/faiss/IndexIVFPQFastScan.cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/IndexIVFPQFastScan.h>

#include <cassert>
#include <cinttypes>
#include <cstdio>
#include <cstring> // memcpy / memmove / memset used below

#include <memory>

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/simdlib.h>

#include <faiss/invlists/BlockInvertedLists.h>

#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/impl/simd_result_handlers.h>

namespace faiss {

using namespace simd_result_handlers;

// Round a up to the nearest multiple of b.
inline size_t roundup(size_t a, size_t b) {
    return (a + b - 1) / b * b;
}
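
// Worked example of how roundup is used for list packing below: with the
// default block size bbs = 32, an inverted list holding 40 codes is padded to
// roundup(40, 32) = 64 slots, so the packed storage always covers whole
// 32-entry blocks.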

IndexIVFPQFastScan::IndexIVFPQFastScan(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits,
        MetricType metric,
        int bbs)
        : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
    by_residual = false; // set to false by default because it's faster

    init_fastscan(&pq, M, nbits, nlist, metric, bbs);
}
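
// A minimal construction sketch for the constructor above (hypothetical
// variable names, assuming the defaults declared in the header; not part of
// this file). The fast-scan kernels require a 4-bit PQ:
//
//   faiss::IndexFlatL2 coarse(d);
//   faiss::IndexIVFPQFastScan index(
//           &coarse, d, /*nlist=*/1024, /*M=*/d / 2, /*nbits=*/4);
//   index.train(nt, xt);
//   index.add(nb, xb);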

IndexIVFPQFastScan::IndexIVFPQFastScan() {
    by_residual = false;
    bbs = 0;
    M2 = 0;
}

// Build a fast-scan index from an existing IndexIVFPQ, which must use a
// 4-bit PQ: the flat codes of each inverted list are repacked into bbs-sized
// blocks.
IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
        : IndexIVFFastScan(
                  orig.quantizer,
                  orig.d,
                  orig.nlist,
                  orig.pq.code_size,
                  orig.metric_type),
          pq(orig.pq) {
    FAISS_THROW_IF_NOT(orig.pq.nbits == 4);

    init_fastscan(
            &pq, orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);

    by_residual = orig.by_residual;
    ntotal = orig.ntotal;
    is_trained = orig.is_trained;
    nprobe = orig.nprobe;

    precomputed_table.resize(orig.precomputed_table.size());

    if (precomputed_table.nbytes() > 0) {
        memcpy(precomputed_table.get(),
               orig.precomputed_table.data(),
               precomputed_table.nbytes());
    }

#pragma omp parallel for if (nlist > 100)
    for (idx_t i = 0; i < nlist; i++) {
        size_t nb = orig.invlists->list_size(i);
        size_t nb2 = roundup(nb, bbs);
        AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
        pq4_pack_codes(
                InvertedLists::ScopedCodes(orig.invlists, i).get(),
                nb,
                M,
                nb2,
                bbs,
                M2,
                tmp.get());
        invlists->add_entries(
                i,
                nb,
                InvertedLists::ScopedIds(orig.invlists, i).get(),
                tmp.get());
    }

    orig_invlists = orig.invlists;
}
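
// A minimal conversion sketch (hypothetical names, not part of this file),
// assuming `orig` is a trained IndexIVFPQ whose pq.nbits == 4:
//
//   faiss::IndexIVFPQFastScan fast(orig, /*bbs=*/32);
//
// The new index shares orig's coarse quantizer and keeps a pointer to the
// original inverted lists in orig_invlists.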

/*********************************************************
 * Training
 *********************************************************/

void IndexIVFPQFastScan::train_encoder(
        idx_t n,
        const float* x,
        const idx_t* assign) {
    pq.verbose = verbose;
    pq.train(n, x);

    if (by_residual && metric_type == METRIC_L2) {
        precompute_table();
    }
}

idx_t IndexIVFPQFastScan::train_encoder_num_vectors() const {
    return pq.cp.max_points_per_centroid * pq.ksub;
}

void IndexIVFPQFastScan::precompute_table() {
    initialize_IVFPQ_precomputed_table(
            use_precomputed_table,
            quantizer,
            pq,
            precomputed_table,
            by_residual,
            verbose);
}

/*********************************************************
 * Code management functions
 *********************************************************/

void IndexIVFPQFastScan::encode_vectors(
        idx_t n,
        const float* x,
        const idx_t* list_nos,
        uint8_t* codes,
        bool include_listnos) const {
    if (by_residual) {
        AlignedTable<float> residuals(n * d);
        for (size_t i = 0; i < n; i++) {
            if (list_nos[i] < 0) {
                // vector not assigned to any list: encode a zero residual
                memset(residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
            } else {
                quantizer->compute_residual(
                        x + i * d, residuals.data() + i * d, list_nos[i]);
            }
        }
        pq.compute_codes(residuals.data(), codes, n);
    } else {
        pq.compute_codes(x, codes, n);
    }

    if (include_listnos) {
        size_t coarse_size = coarse_code_size();
        // expand the codes in place from the back, so that entries that have
        // not been moved yet are not overwritten
        for (idx_t i = n - 1; i >= 0; i--) {
            uint8_t* code = codes + i * (coarse_size + code_size);
            memmove(code + coarse_size, codes + i * code_size, code_size);
            encode_listno(list_nos[i], code);
        }
    }
}
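
// Layout written per vector when include_listnos is true:
//   [coarse code: coarse_code_size() bytes][PQ code: code_size bytes]
// With include_listnos false, only the tightly packed PQ codes are written.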

/*********************************************************
 * Look-Up Table functions
 *********************************************************/

// Compute c[i] = a[i] + bf * b[i] for n floats, 8 lanes at a time. All
// pointers must be SIMD-aligned and n must be a multiple of 8.
void fvec_madd_simd(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    assert(is_aligned_pointer(a));
    assert(is_aligned_pointer(b));
    assert(is_aligned_pointer(c));
    assert(n % 8 == 0);
    simd8float32 bf8(bf);
    n /= 8;
    for (size_t i = 0; i < n; i++) {
        simd8float32 ai(a);
        simd8float32 bi(b);

        simd8float32 ci = fmadd(bf8, bi, ai);
        ci.store(c);
        c += 8;
        a += 8;
        b += 8;
    }
}
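
// Scalar equivalent of the loop above (illustration only, not part of the
// library):
//
//   for (size_t i = 0; i < n; i++) {
//       c[i] = a[i] + bf * b[i];
//   }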

// The look-up tables are 3D (n x nprobe x M*ksub, one table per probed list)
// only for residual L2 search; all other configurations use one 2D table per
// query.
bool IndexIVFPQFastScan::lookup_table_is_3d() const {
    return by_residual && metric_type == METRIC_L2;
}

// Compute the distance look-up tables (dis_tables) and per-probe biases for
// n queries, given the coarse quantization result cq.
void IndexIVFPQFastScan::compute_LUT(
        size_t n,
        const float* x,
        const CoarseQuantized& cq,
        AlignedTable<float>& dis_tables,
        AlignedTable<float>& biases) const {
    size_t dim12 = pq.ksub * pq.M;
    size_t d = pq.d;
    size_t nprobe = this->nprobe;

    if (by_residual) {
        if (metric_type == METRIC_L2) {
            dis_tables.resize(n * nprobe * dim12);

            if (use_precomputed_table == 1) {
                biases.resize(n * nprobe);
                memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe);

                AlignedTable<float> ip_table(n * dim12);
                pq.compute_inner_prod_tables(n, x, ip_table.get());

#pragma omp parallel for if (n * nprobe > 8000)
                for (idx_t ij = 0; ij < n * nprobe; ij++) {
                    idx_t i = ij / nprobe;
                    float* tab = dis_tables.get() + ij * dim12;
                    idx_t cij = cq.ids[ij];

                    if (cij >= 0) {
                        fvec_madd_simd(
                                dim12,
                                precomputed_table.get() + cij * dim12,
                                -2,
                                ip_table.get() + i * dim12,
                                tab);
                    } else {
                        // fill with NaNs (0xff bytes) so that these entries
                        // are ignored during LUT quantization
                        memset(tab, -1, sizeof(float) * dim12);
                    }
                }

            } else {
                std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
                biases.resize(n * nprobe);
                memset(biases.get(), 0, sizeof(float) * n * nprobe);

#pragma omp parallel for if (n * nprobe > 8000)
                for (idx_t ij = 0; ij < n * nprobe; ij++) {
                    idx_t i = ij / nprobe;
                    float* xij = &xrel[ij * d];
                    idx_t cij = cq.ids[ij];

                    if (cij >= 0) {
                        quantizer->compute_residual(x + i * d, xij, cij);
                    } else {
                        // fill with NaNs (0xff bytes) so that the resulting
                        // table entries are ignored
                        memset(xij, -1, sizeof(float) * d);
                    }
                }

                pq.compute_distance_tables(
                        n * nprobe, xrel.get(), dis_tables.get());
            }

        } else if (metric_type == METRIC_INNER_PRODUCT) {
            dis_tables.resize(n * dim12);
            pq.compute_inner_prod_tables(n, x, dis_tables.get());

            biases.resize(n * nprobe);
            memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe);
        } else {
            FAISS_THROW_FMT("metric %d not supported", metric_type);
        }

    } else {
        dis_tables.resize(n * dim12);
        if (metric_type == METRIC_L2) {
            pq.compute_distance_tables(n, x, dis_tables.get());
        } else if (metric_type == METRIC_INNER_PRODUCT) {
            pq.compute_inner_prod_tables(n, x, dis_tables.get());
        } else {
            FAISS_THROW_FMT("metric %d not supported", metric_type);
        }
    }
}
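
// Shapes produced above: residual L2 search fills dis_tables with
// n * nprobe * (M * ksub) entries plus n * nprobe biases; residual inner
// product uses n * (M * ksub) tables with the coarse distances as biases;
// without residuals, a single n * (M * ksub) table per batch and no biases.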

} // namespace faiss