/root/doris/contrib/faiss/faiss/IndexPQ.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexPQ.h> |
9 | | |
10 | | #include <cinttypes> |
11 | | #include <cmath> |
12 | | #include <cstddef> |
13 | | #include <cstdio> |
14 | | #include <cstring> |
15 | | |
16 | | #include <algorithm> |
17 | | #include <memory> |
18 | | |
19 | | #include <faiss/impl/DistanceComputer.h> |
20 | | #include <faiss/impl/FaissAssert.h> |
21 | | #include <faiss/utils/hamming.h> |
22 | | |
23 | | #include <faiss/impl/code_distance/code_distance.h> |
24 | | |
25 | | namespace faiss { |
26 | | |
27 | | /********************************************************* |
28 | | * IndexPQ implementation |
29 | | ********************************************************/ |
30 | | |
31 | | IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric) |
32 | 0 | : IndexFlatCodes(0, d, metric), pq(d, M, nbits) { |
33 | 0 | is_trained = false; |
34 | 0 | do_polysemous_training = false; |
35 | 0 | polysemous_ht = nbits * M + 1; |
36 | 0 | search_type = ST_PQ; |
37 | 0 | encode_signs = false; |
38 | 0 | code_size = pq.code_size; |
39 | 0 | } |
40 | | |
41 | 0 | IndexPQ::IndexPQ() { |
42 | 0 | metric_type = METRIC_L2; |
43 | 0 | is_trained = false; |
44 | 0 | do_polysemous_training = false; |
45 | 0 | polysemous_ht = pq.nbits * pq.M + 1; |
46 | 0 | search_type = ST_PQ; |
47 | 0 | encode_signs = false; |
48 | 0 | } |
49 | | |
50 | 0 | void IndexPQ::train(idx_t n, const float* x) { |
51 | 0 | if (!do_polysemous_training) { // standard training |
52 | 0 | pq.train(n, x); |
53 | 0 | } else { |
54 | 0 | idx_t ntrain_perm = polysemous_training.ntrain_permutation; |
55 | |
|
56 | 0 | if (ntrain_perm > n / 4) |
57 | 0 | ntrain_perm = n / 4; |
58 | 0 | if (verbose) { |
59 | 0 | printf("PQ training on %" PRId64 " points, remains %" PRId64 |
60 | 0 | " points: " |
61 | 0 | "training polysemous on %s\n", |
62 | 0 | n - ntrain_perm, |
63 | 0 | ntrain_perm, |
64 | 0 | ntrain_perm == 0 ? "centroids" : "these"); |
65 | 0 | } |
66 | 0 | pq.train(n - ntrain_perm, x); |
67 | |
|
68 | 0 | polysemous_training.optimize_pq_for_hamming( |
69 | 0 | pq, ntrain_perm, x + (n - ntrain_perm) * d); |
70 | 0 | } |
71 | 0 | is_trained = true; |
72 | 0 | } |
73 | | |
74 | | namespace { |
75 | | |
76 | | template <class PQDecoder> |
77 | | struct PQDistanceComputer : FlatCodesDistanceComputer { |
78 | | size_t d; |
79 | | MetricType metric; |
80 | | idx_t nb; |
81 | | const ProductQuantizer& pq; |
82 | | const float* sdc; |
83 | | std::vector<float> precomputed_table; |
84 | | size_t ndis; |
85 | | |
86 | 0 | float distance_to_code(const uint8_t* code) final { |
87 | 0 | ndis++; |
88 | |
|
89 | 0 | float dis = distance_single_code<PQDecoder>( |
90 | 0 | pq.M, pq.nbits, precomputed_table.data(), code); |
91 | 0 | return dis; |
92 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_10PQDecoder8EE16distance_to_codeEPKh Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_11PQDecoder16EE16distance_to_codeEPKh Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_16PQDecoderGenericEE16distance_to_codeEPKh |
93 | | |
94 | 0 | float symmetric_dis(idx_t i, idx_t j) override { |
95 | 0 | FAISS_THROW_IF_NOT(sdc); |
96 | 0 | const float* sdci = sdc; |
97 | 0 | float accu = 0; |
98 | 0 | PQDecoder codei(codes + i * code_size, pq.nbits); |
99 | 0 | PQDecoder codej(codes + j * code_size, pq.nbits); |
100 | |
|
101 | 0 | for (int l = 0; l < pq.M; l++) { |
102 | 0 | accu += sdci[codei.decode() + (codej.decode() << codei.nbits)]; |
103 | 0 | sdci += uint64_t(1) << (2 * codei.nbits); |
104 | 0 | } |
105 | 0 | ndis++; |
106 | 0 | return accu; |
107 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_10PQDecoder8EE13symmetric_disEll Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_11PQDecoder16EE13symmetric_disEll Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_16PQDecoderGenericEE13symmetric_disEll |
108 | | |
109 | | explicit PQDistanceComputer(const IndexPQ& storage) |
110 | 0 | : FlatCodesDistanceComputer( |
111 | 0 | storage.codes.data(), |
112 | 0 | storage.code_size), |
113 | 0 | pq(storage.pq) { |
114 | 0 | precomputed_table.resize(pq.M * pq.ksub); |
115 | 0 | nb = storage.ntotal; |
116 | 0 | d = storage.d; |
117 | 0 | metric = storage.metric_type; |
118 | 0 | if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) { |
119 | 0 | sdc = pq.sdc_table.data(); |
120 | 0 | } else { |
121 | 0 | sdc = nullptr; |
122 | 0 | } |
123 | 0 | ndis = 0; |
124 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_10PQDecoder8EEC2ERKNS_7IndexPQE Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_11PQDecoder16EEC2ERKNS_7IndexPQE Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_16PQDecoderGenericEEC2ERKNS_7IndexPQE |
125 | | |
126 | 0 | void set_query(const float* x) override { |
127 | 0 | if (metric == METRIC_L2) { |
128 | 0 | pq.compute_distance_table(x, precomputed_table.data()); |
129 | 0 | } else { |
130 | 0 | pq.compute_inner_prod_table(x, precomputed_table.data()); |
131 | 0 | } |
132 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_10PQDecoder8EE9set_queryEPKf Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_11PQDecoder16EE9set_queryEPKf Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_118PQDistanceComputerINS_16PQDecoderGenericEE9set_queryEPKf |
133 | | }; |
134 | | |
135 | | } // namespace |
136 | | |
137 | 0 | FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const { |
138 | 0 | if (pq.nbits == 8) { |
139 | 0 | return new PQDistanceComputer<PQDecoder8>(*this); |
140 | 0 | } else if (pq.nbits == 16) { |
141 | 0 | return new PQDistanceComputer<PQDecoder16>(*this); |
142 | 0 | } else { |
143 | 0 | return new PQDistanceComputer<PQDecoderGeneric>(*this); |
144 | 0 | } |
145 | 0 | } |
146 | | |
147 | | /***************************************** |
148 | | * IndexPQ polysemous search routines |
149 | | ******************************************/ |
150 | | |
151 | | void IndexPQ::search( |
152 | | idx_t n, |
153 | | const float* x, |
154 | | idx_t k, |
155 | | float* distances, |
156 | | idx_t* labels, |
157 | 0 | const SearchParameters* iparams) const { |
158 | 0 | FAISS_THROW_IF_NOT(k > 0); |
159 | 0 | FAISS_THROW_IF_NOT(is_trained); |
160 | | |
161 | 0 | const SearchParametersPQ* params = nullptr; |
162 | 0 | Search_type_t param_search_type = this->search_type; |
163 | |
|
164 | 0 | if (iparams) { |
165 | 0 | params = dynamic_cast<const SearchParametersPQ*>(iparams); |
166 | 0 | FAISS_THROW_IF_NOT_MSG(params, "invalid search params"); |
167 | 0 | FAISS_THROW_IF_NOT_MSG(!params->sel, "selector not supported"); |
168 | 0 | param_search_type = params->search_type; |
169 | 0 | } |
170 | | |
171 | 0 | if (param_search_type == ST_PQ) { // Simple PQ search |
172 | |
|
173 | 0 | if (metric_type == METRIC_L2) { |
174 | 0 | float_maxheap_array_t res = { |
175 | 0 | size_t(n), size_t(k), labels, distances}; |
176 | 0 | pq.search(x, n, codes.data(), ntotal, &res, true); |
177 | 0 | } else { |
178 | 0 | float_minheap_array_t res = { |
179 | 0 | size_t(n), size_t(k), labels, distances}; |
180 | 0 | pq.search_ip(x, n, codes.data(), ntotal, &res, true); |
181 | 0 | } |
182 | 0 | indexPQ_stats.nq += n; |
183 | 0 | indexPQ_stats.ncode += n * ntotal; |
184 | |
|
185 | 0 | } else if ( |
186 | 0 | param_search_type == ST_polysemous || |
187 | 0 | param_search_type == ST_polysemous_generalize) { |
188 | 0 | FAISS_THROW_IF_NOT(metric_type == METRIC_L2); |
189 | 0 | int param_polysemous_ht = |
190 | 0 | params ? params->polysemous_ht : this->polysemous_ht; |
191 | 0 | search_core_polysemous( |
192 | 0 | n, |
193 | 0 | x, |
194 | 0 | k, |
195 | 0 | distances, |
196 | 0 | labels, |
197 | 0 | param_polysemous_ht, |
198 | 0 | param_search_type == ST_polysemous_generalize); |
199 | |
|
200 | 0 | } else { // code-to-code distances |
201 | |
|
202 | 0 | std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]); |
203 | |
|
204 | 0 | if (!encode_signs) { |
205 | 0 | pq.compute_codes(x, q_codes.get(), n); |
206 | 0 | } else { |
207 | 0 | FAISS_THROW_IF_NOT(d == pq.nbits * pq.M); |
208 | 0 | memset(q_codes.get(), 0, n * pq.code_size); |
209 | 0 | for (size_t i = 0; i < n; i++) { |
210 | 0 | const float* xi = x + i * d; |
211 | 0 | uint8_t* code = q_codes.get() + i * pq.code_size; |
212 | 0 | for (int j = 0; j < d; j++) |
213 | 0 | if (xi[j] > 0) |
214 | 0 | code[j >> 3] |= 1 << (j & 7); |
215 | 0 | } |
216 | 0 | } |
217 | | |
218 | 0 | if (param_search_type == ST_SDC) { |
219 | 0 | float_maxheap_array_t res = { |
220 | 0 | size_t(n), size_t(k), labels, distances}; |
221 | |
|
222 | 0 | pq.search_sdc(q_codes.get(), n, codes.data(), ntotal, &res, true); |
223 | |
|
224 | 0 | } else { |
225 | 0 | std::unique_ptr<int[]> idistances(new int[n * k]); |
226 | |
|
227 | 0 | int_maxheap_array_t res = { |
228 | 0 | size_t(n), size_t(k), labels, idistances.get()}; |
229 | |
|
230 | 0 | if (param_search_type == ST_HE) { |
231 | 0 | hammings_knn_hc( |
232 | 0 | &res, |
233 | 0 | q_codes.get(), |
234 | 0 | codes.data(), |
235 | 0 | ntotal, |
236 | 0 | pq.code_size, |
237 | 0 | true); |
238 | |
|
239 | 0 | } else if (param_search_type == ST_generalized_HE) { |
240 | 0 | generalized_hammings_knn_hc( |
241 | 0 | &res, |
242 | 0 | q_codes.get(), |
243 | 0 | codes.data(), |
244 | 0 | ntotal, |
245 | 0 | pq.code_size, |
246 | 0 | true); |
247 | 0 | } |
248 | | |
249 | | // convert distances to floats |
250 | 0 | for (int i = 0; i < k * n; i++) |
251 | 0 | distances[i] = idistances[i]; |
252 | 0 | } |
253 | |
|
254 | 0 | indexPQ_stats.nq += n; |
255 | 0 | indexPQ_stats.ncode += n * ntotal; |
256 | 0 | } |
257 | 0 | } |
258 | | |
259 | 9 | void IndexPQStats::reset() { |
260 | 9 | nq = ncode = n_hamming_pass = 0; |
261 | 9 | } |
262 | | |
263 | | IndexPQStats indexPQ_stats; |
264 | | |
265 | | namespace { |
266 | | |
267 | | template <class HammingComputer> |
268 | | size_t polysemous_inner_loop( |
269 | | const IndexPQ* index, |
270 | | const float* dis_table_qi, |
271 | | const uint8_t* q_code, |
272 | | size_t k, |
273 | | float* heap_dis, |
274 | | int64_t* heap_ids, |
275 | 0 | int ht) { |
276 | 0 | int M = index->pq.M; |
277 | 0 | int code_size = index->pq.code_size; |
278 | 0 | int ksub = index->pq.ksub; |
279 | 0 | size_t ntotal = index->ntotal; |
280 | |
|
281 | 0 | const uint8_t* b_code = index->codes.data(); |
282 | |
|
283 | 0 | size_t n_pass_i = 0; |
284 | |
|
285 | 0 | HammingComputer hc(q_code, code_size); |
286 | |
|
287 | 0 | for (int64_t bi = 0; bi < ntotal; bi++) { |
288 | 0 | int hd = hc.hamming(b_code); |
289 | |
|
290 | 0 | if (hd < ht) { |
291 | 0 | n_pass_i++; |
292 | |
|
293 | 0 | float dis = 0; |
294 | 0 | const float* dis_table = dis_table_qi; |
295 | 0 | for (int m = 0; m < M; m++) { |
296 | 0 | dis += dis_table[b_code[m]]; |
297 | 0 | dis_table += ksub; |
298 | 0 | } |
299 | |
|
300 | 0 | if (dis < heap_dis[0]) { |
301 | 0 | maxheap_replace_top(k, heap_dis, heap_ids, dis, bi); |
302 | 0 | } |
303 | 0 | } |
304 | 0 | b_code += code_size; |
305 | 0 | } |
306 | 0 | return n_pass_i; |
307 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_16HammingComputer4EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_16HammingComputer8EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_17HammingComputer16EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_17HammingComputer20EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_17HammingComputer32EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_17HammingComputer64EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_22HammingComputerDefaultEEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_19GenHammingComputer8EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_20GenHammingComputer16EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_20GenHammingComputer32EEEmPKNS_7IndexPQEPKfPKhmPfPli Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_121polysemous_inner_loopINS_20GenHammingComputerM8EEEmPKNS_7IndexPQEPKfPKhmPfPli |
308 | | |
309 | | struct Run_polysemous_inner_loop { |
310 | | using T = size_t; |
311 | | template <class HammingComputer, class... Types> |
312 | 0 | size_t f(Types... args) { |
313 | 0 | return polysemous_inner_loop<HammingComputer>(args...); |
314 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_125Run_polysemous_inner_loop1fINS_16HammingComputer4EJPKNS_7IndexPQEPKfPKhlPfPliEEEmDpT0_ Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_125Run_polysemous_inner_loop1fINS_16HammingComputer8EJPKNS_7IndexPQEPKfPKhlPfPliEEEmDpT0_ Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_125Run_polysemous_inner_loop1fINS_17HammingComputer16EJPKNS_7IndexPQEPKfPKhlPfPliEEEmDpT0_ Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_125Run_polysemous_inner_loop1fINS_17HammingComputer20EJPKNS_7IndexPQEPKfPKhlPfPliEEEmDpT0_ Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_125Run_polysemous_inner_loop1fINS_17HammingComputer32EJPKNS_7IndexPQEPKfPKhlPfPliEEEmDpT0_ Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_125Run_polysemous_inner_loop1fINS_17HammingComputer64EJPKNS_7IndexPQEPKfPKhlPfPliEEEmDpT0_ Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_125Run_polysemous_inner_loop1fINS_22HammingComputerDefaultEJPKNS_7IndexPQEPKfPKhlPfPliEEEmDpT0_ |
315 | | }; |
316 | | |
317 | | } // anonymous namespace |
318 | | |
319 | | void IndexPQ::search_core_polysemous( |
320 | | idx_t n, |
321 | | const float* x, |
322 | | idx_t k, |
323 | | float* distances, |
324 | | idx_t* labels, |
325 | | int param_polysemous_ht, |
326 | 0 | bool generalized_hamming) const { |
327 | 0 | FAISS_THROW_IF_NOT(k > 0); |
328 | 0 | FAISS_THROW_IF_NOT(pq.nbits == 8); |
329 | | |
330 | 0 | if (param_polysemous_ht == 0) { |
331 | 0 | param_polysemous_ht = pq.nbits * pq.M + 1; |
332 | 0 | } |
333 | | |
334 | | // PQ distance tables |
335 | 0 | std::unique_ptr<float[]> dis_tables(new float[n * pq.ksub * pq.M]); |
336 | 0 | pq.compute_distance_tables(n, x, dis_tables.get()); |
337 | | |
338 | | // Hamming embedding queries |
339 | 0 | std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]); |
340 | |
|
341 | 0 | if (false) { |
342 | 0 | pq.compute_codes(x, q_codes.get(), n); |
343 | 0 | } else { |
344 | 0 | #pragma omp parallel for |
345 | 0 | for (idx_t qi = 0; qi < n; qi++) { |
346 | 0 | pq.compute_code_from_distance_table( |
347 | 0 | dis_tables.get() + qi * pq.M * pq.ksub, |
348 | 0 | q_codes.get() + qi * pq.code_size); |
349 | 0 | } |
350 | 0 | } |
351 | |
|
352 | 0 | size_t n_pass = 0; |
353 | |
|
354 | 0 | int bad_code_size = 0; |
355 | |
|
356 | 0 | #pragma omp parallel for reduction(+ : n_pass, bad_code_size) |
357 | 0 | for (idx_t qi = 0; qi < n; qi++) { |
358 | 0 | const uint8_t* q_code = q_codes.get() + qi * pq.code_size; |
359 | |
|
360 | 0 | const float* dis_table_qi = dis_tables.get() + qi * pq.M * pq.ksub; |
361 | |
|
362 | 0 | int64_t* heap_ids = labels + qi * k; |
363 | 0 | float* heap_dis = distances + qi * k; |
364 | 0 | maxheap_heapify(k, heap_dis, heap_ids); |
365 | |
|
366 | 0 | if (!generalized_hamming) { |
367 | 0 | Run_polysemous_inner_loop r; |
368 | 0 | n_pass += dispatch_HammingComputer( |
369 | 0 | pq.code_size, |
370 | 0 | r, |
371 | 0 | this, |
372 | 0 | dis_table_qi, |
373 | 0 | q_code, |
374 | 0 | k, |
375 | 0 | heap_dis, |
376 | 0 | heap_ids, |
377 | 0 | param_polysemous_ht); |
378 | |
|
379 | 0 | } else { // generalized hamming |
380 | 0 | switch (pq.code_size) { |
381 | 0 | #define DISPATCH(cs) \ |
382 | 0 | case cs: \ |
383 | 0 | n_pass += polysemous_inner_loop<GenHammingComputer##cs>( \ |
384 | 0 | this, \ |
385 | 0 | dis_table_qi, \ |
386 | 0 | q_code, \ |
387 | 0 | k, \ |
388 | 0 | heap_dis, \ |
389 | 0 | heap_ids, \ |
390 | 0 | param_polysemous_ht); \ |
391 | 0 | break; |
392 | 0 | DISPATCH(8) |
393 | 0 | DISPATCH(16) |
394 | 0 | DISPATCH(32) |
395 | 0 | default: |
396 | 0 | if (pq.code_size % 8 == 0) { |
397 | 0 | n_pass += polysemous_inner_loop<GenHammingComputerM8>( |
398 | 0 | this, |
399 | 0 | dis_table_qi, |
400 | 0 | q_code, |
401 | 0 | k, |
402 | 0 | heap_dis, |
403 | 0 | heap_ids, |
404 | 0 | param_polysemous_ht); |
405 | 0 | } else { |
406 | 0 | bad_code_size++; |
407 | 0 | } |
408 | 0 | break; |
409 | 0 | #undef DISPATCH |
410 | 0 | } |
411 | 0 | } |
412 | 0 | maxheap_reorder(k, heap_dis, heap_ids); |
413 | 0 | } |
414 | |
|
415 | 0 | if (bad_code_size) { |
416 | 0 | FAISS_THROW_FMT( |
417 | 0 | "code size %zd not supported for polysemous", pq.code_size); |
418 | 0 | } |
419 | | |
420 | 0 | indexPQ_stats.nq += n; |
421 | 0 | indexPQ_stats.ncode += n * ntotal; |
422 | 0 | indexPQ_stats.n_hamming_pass += n_pass; |
423 | 0 | } |
424 | | |
425 | | /* The standalone codec interface (just remaps to the PQ functions) */ |
426 | | |
427 | 0 | void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { |
428 | 0 | pq.compute_codes(x, bytes, n); |
429 | 0 | } |
430 | | |
431 | 0 | void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { |
432 | 0 | pq.decode(bytes, x, n); |
433 | 0 | } |
434 | | |
435 | | /***************************************** |
436 | | * Stats of IndexPQ codes |
437 | | ******************************************/ |
438 | | |
439 | | void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis) |
440 | 0 | const { |
441 | 0 | std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]); |
442 | |
|
443 | 0 | pq.compute_codes(x, q_codes.get(), n); |
444 | |
|
445 | 0 | hammings(q_codes.get(), codes.data(), n, ntotal, pq.code_size, dis); |
446 | 0 | } |
447 | | |
448 | | void IndexPQ::hamming_distance_histogram( |
449 | | idx_t n, |
450 | | const float* x, |
451 | | idx_t nb, |
452 | | const float* xb, |
453 | 0 | int64_t* hist) { |
454 | 0 | FAISS_THROW_IF_NOT(metric_type == METRIC_L2); |
455 | 0 | FAISS_THROW_IF_NOT(pq.code_size % 8 == 0); |
456 | 0 | FAISS_THROW_IF_NOT(pq.nbits == 8); |
457 | | |
458 | | // Hamming embedding queries |
459 | 0 | std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]); |
460 | 0 | pq.compute_codes(x, q_codes.get(), n); |
461 | |
|
462 | 0 | uint8_t* b_codes; |
463 | 0 | std::unique_ptr<uint8_t[]> del_b_codes; |
464 | |
|
465 | 0 | if (xb) { |
466 | 0 | b_codes = new uint8_t[nb * pq.code_size]; |
467 | 0 | del_b_codes.reset(b_codes); |
468 | 0 | pq.compute_codes(xb, b_codes, nb); |
469 | 0 | } else { |
470 | 0 | nb = ntotal; |
471 | 0 | b_codes = codes.data(); |
472 | 0 | } |
473 | 0 | int nbits = pq.M * pq.nbits; |
474 | 0 | memset(hist, 0, sizeof(*hist) * (nbits + 1)); |
475 | 0 | size_t bs = 256; |
476 | |
|
477 | 0 | #pragma omp parallel |
478 | 0 | { |
479 | 0 | std::vector<int64_t> histi(nbits + 1); |
480 | 0 | std::unique_ptr<hamdis_t[]> distances(new hamdis_t[nb * bs]); |
481 | 0 | #pragma omp for |
482 | 0 | for (idx_t q0 = 0; q0 < n; q0 += bs) { |
483 | | // printf ("dis stats: %zd/%zd\n", q0, n); |
484 | 0 | size_t q1 = q0 + bs; |
485 | 0 | if (q1 > n) |
486 | 0 | q1 = n; |
487 | |
|
488 | 0 | hammings( |
489 | 0 | q_codes.get() + q0 * pq.code_size, |
490 | 0 | b_codes, |
491 | 0 | q1 - q0, |
492 | 0 | nb, |
493 | 0 | pq.code_size, |
494 | 0 | distances.get()); |
495 | |
|
496 | 0 | for (size_t i = 0; i < nb * (q1 - q0); i++) |
497 | 0 | histi[distances[i]]++; |
498 | 0 | } |
499 | 0 | #pragma omp critical |
500 | 0 | { |
501 | 0 | for (int i = 0; i <= nbits; i++) |
502 | 0 | hist[i] += histi[i]; |
503 | 0 | } |
504 | 0 | } |
505 | 0 | } |
506 | | |
507 | | /***************************************** |
508 | | * MultiIndexQuantizer |
509 | | ******************************************/ |
510 | | |
511 | | namespace { |
512 | | |
/// View over a table that is already sorted in ascending order: the r-th
/// smallest value is simply x[r], so no permutation has to be stored.
template <typename T>
struct PreSortedArray {
    const T* x; ///< borrowed pointer to the sorted values, set by init()
    int N;      ///< number of values

    explicit PreSortedArray(int N) : N(N) {}

    void init(const T* x_2) {
        x = x_2;
    }

    /// smallest value
    T get_0() {
        return *x;
    }

    /// gap between the n-th and (n-1)-th smallest values
    T get_diff(int n) {
        const T upper = x[n];
        const T lower = x[n - 1];
        return upper - lower;
    }

    /// rank -> index in x (identity, because x is pre-sorted)
    int get_ord(int n) {
        return n;
    }
};
537 | | |
/// Comparator that orders indices by the values they address in x, so an
/// index array can be argsorted without moving the values themselves.
template <typename T>
struct ArgSort {
    const T* x;

    bool operator()(size_t i, size_t j) {
        return x[j] > x[i];
    }
};
545 | | |
/** Array that maintains a permutation of its elements so that the
 * array's elements are sorted: after init(), perm[r] is the index in x of
 * the r-th smallest value. The full permutation is computed eagerly.
 */
template <typename T>
struct SortedArray {
    const T* x;
    int N;
    std::vector<int> perm;

    explicit SortedArray(int N) : N(N), perm(N) {}

    void init(const T* x_2) {
        x = x_2;
        for (int r = 0; r < N; r++) {
            perm[r] = r;
        }
        std::sort(perm.begin(), perm.end(), [x_2](int a, int b) {
            return x_2[a] < x_2[b];
        });
    }

    /// smallest value
    T get_0() {
        return x[perm.front()];
    }

    /// gap between the n-th and (n-1)-th smallest values
    T get_diff(int n) {
        return x[perm[n]] - x[perm[n - 1]];
    }

    /// rank -> index in x
    int get_ord(int n) {
        return perm[n];
    }
};
583 | | |
584 | | /** Array has n values. Sort the k first ones and copy the other ones |
585 | | * into elements k..n-1 |
586 | | */ |
587 | | template <class C> |
588 | | void partial_sort( |
589 | | int k, |
590 | | int n, |
591 | | const typename C::T* vals, |
592 | 0 | typename C::TI* perm) { |
593 | | // insert first k elts in heap |
594 | 0 | for (int i = 1; i < k; i++) { |
595 | 0 | indirect_heap_push<C>(i + 1, vals, perm, perm[i]); |
596 | 0 | } |
597 | | |
598 | | // insert next n - k elts in heap |
599 | 0 | for (int i = k; i < n; i++) { |
600 | 0 | typename C::TI id = perm[i]; |
601 | 0 | typename C::TI top = perm[0]; |
602 | |
|
603 | 0 | if (C::cmp(vals[top], vals[id])) { |
604 | 0 | indirect_heap_pop<C>(k, vals, perm); |
605 | 0 | indirect_heap_push<C>(k, vals, perm, id); |
606 | 0 | perm[i] = top; |
607 | 0 | } else { |
608 | | // nothing, elt at i is good where it is. |
609 | 0 | } |
610 | 0 | } |
611 | | |
612 | | // order the k first elements in heap |
613 | 0 | for (int i = k - 1; i > 0; i--) { |
614 | 0 | typename C::TI top = perm[0]; |
615 | 0 | indirect_heap_pop<C>(i + 1, vals, perm); |
616 | 0 | perm[i] = top; |
617 | 0 | } |
618 | 0 | } |
619 | | |
620 | | /** same as SortedArray, but only the k first elements are sorted */ |
621 | | template <typename T> |
622 | | struct SemiSortedArray { |
623 | | const T* x; |
624 | | int N; |
625 | | |
626 | | // type of the heap: CMax = sort ascending |
627 | | using HC = CMax<T, int>; |
628 | | std::vector<int> perm; |
629 | | |
630 | | int k; // k elements are sorted |
631 | | |
632 | | int initial_k, k_factor; |
633 | | |
634 | 0 | explicit SemiSortedArray(int N) { |
635 | 0 | this->N = N; |
636 | 0 | perm.resize(N); |
637 | 0 | perm.resize(N); |
638 | 0 | initial_k = 3; |
639 | 0 | k_factor = 4; |
640 | 0 | } |
641 | | |
642 | 0 | void init(const T* x_2) { |
643 | 0 | this->x = x_2; |
644 | 0 | for (int n = 0; n < N; n++) |
645 | 0 | perm[n] = n; |
646 | 0 | k = 0; |
647 | 0 | grow(initial_k); |
648 | 0 | } |
649 | | |
650 | | /// grow the sorted part of the array to size next_k |
651 | 0 | void grow(int next_k) { |
652 | 0 | if (next_k < N) { |
653 | 0 | partial_sort<HC>(next_k - k, N - k, x, &perm[k]); |
654 | 0 | k = next_k; |
655 | 0 | } else { // full sort of remainder of array |
656 | 0 | ArgSort<T> cmp = {x}; |
657 | 0 | std::sort(perm.begin() + k, perm.end(), cmp); |
658 | 0 | k = N; |
659 | 0 | } |
660 | 0 | } |
661 | | |
662 | | // get smallest value |
663 | 0 | T get_0() { |
664 | 0 | return x[perm[0]]; |
665 | 0 | } |
666 | | |
667 | | // get delta between n-smallest and n-1 -smallest |
668 | 0 | T get_diff(int n) { |
669 | 0 | if (n >= k) { |
670 | | // want to keep powers of 2 - 1 |
671 | 0 | int next_k = (k + 1) * k_factor - 1; |
672 | 0 | grow(next_k); |
673 | 0 | } |
674 | 0 | return x[perm[n]] - x[perm[n - 1]]; |
675 | 0 | } |
676 | | |
677 | | // remap orders counted from smallest to indices in array |
678 | 0 | int get_ord(int n) { |
679 | 0 | assert(n < k); |
680 | 0 | return perm[n]; |
681 | 0 | } |
682 | | }; |
683 | | |
684 | | /***************************************** |
685 | | * Find the k smallest sums of M terms, where each term is taken in a |
686 | | * table x of n values. |
687 | | * |
688 | | * A combination of terms is encoded as a scalar 0 <= t < n^M. The |
689 | | * combination t0 ... t(M-1) that correspond to the sum |
690 | | * |
691 | | * sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)] |
692 | | * |
693 | | * is encoded as |
694 | | * |
695 | | * t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1) |
696 | | * |
697 | | * MinSumK is an object rather than a function, so that storage can be |
698 | | * re-used over several computations with the same sizes. use_seen is |
699 | | * good when there may be ties in the x array and it is a concern if |
700 | | * occasionally several t's are returned. |
701 | | * |
702 | | * @param x size M * n, values to add up |
703 | | * @param k nb of results to retrieve |
704 | | * @param M nb of terms |
705 | | * @param n nb of distinct values |
706 | | * @param sums output, size k, sorted |
707 | | * @param terms output, size k, with encoding as above |
708 | | * |
709 | | ******************************************/ |
710 | | template <typename T, class SSA, bool use_seen> |
711 | | struct MinSumK { |
712 | | int K; ///< nb of sums to return |
713 | | int M; ///< nb of elements to sum up |
714 | | int nbit; ///< nb of bits to encode one entry |
715 | | int N; ///< nb of possible elements for each of the M terms |
716 | | |
717 | | /** the heap. |
718 | | * We use a heap to maintain a queue of sums, with the associated |
719 | | * terms involved in the sum. |
720 | | */ |
721 | | using HC = CMin<T, int64_t>; |
722 | | size_t heap_capacity, heap_size; |
723 | | T* bh_val; |
724 | | int64_t* bh_ids; |
725 | | |
726 | | std::vector<SSA> ssx; |
727 | | |
728 | | // all results get pushed several times. When there are ties, they |
729 | | // are popped interleaved with others, so it is not easy to |
730 | | // identify them. Therefore, this bit array just marks elements |
731 | | // that were seen before. |
732 | | std::vector<uint8_t> seen; |
733 | | |
734 | 0 | MinSumK(int K, int M, int nbit, int N) : K(K), M(M), nbit(nbit), N(N) { |
735 | 0 | heap_capacity = K * M; |
736 | 0 | assert(N <= (1 << nbit)); |
737 | | |
738 | | // we'll do k steps, each step pushes at most M vals |
739 | 0 | bh_val = new T[heap_capacity]; |
740 | 0 | bh_ids = new int64_t[heap_capacity]; |
741 | |
|
742 | 0 | if (use_seen) { |
743 | 0 | int64_t n_ids = weight(M); |
744 | 0 | seen.resize((n_ids + 7) / 8); |
745 | 0 | } |
746 | |
|
747 | 0 | for (int m = 0; m < M; m++) |
748 | 0 | ssx.push_back(SSA(N)); |
749 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_15SemiSortedArrayIfEELb0EEC2Eiiii Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_14PreSortedArrayIfEELb0EEC2Eiiii |
750 | | |
751 | 0 | int64_t weight(int i) { |
752 | 0 | return 1 << (i * nbit); |
753 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_15SemiSortedArrayIfEELb0EE6weightEi Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_14PreSortedArrayIfEELb0EE6weightEi |
754 | | |
755 | 0 | bool is_seen(int64_t i) { |
756 | 0 | return (seen[i >> 3] >> (i & 7)) & 1; |
757 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_15SemiSortedArrayIfEELb0EE7is_seenEl Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_14PreSortedArrayIfEELb0EE7is_seenEl |
758 | | |
759 | 0 | void mark_seen(int64_t i) { |
760 | 0 | if (use_seen) |
761 | 0 | seen[i >> 3] |= 1 << (i & 7); |
762 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_15SemiSortedArrayIfEELb0EE9mark_seenEl Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_14PreSortedArrayIfEELb0EE9mark_seenEl |
763 | | |
764 | 0 | void run(const T* x, int64_t ldx, T* sums, int64_t* terms) { |
765 | 0 | heap_size = 0; |
766 | |
|
767 | 0 | for (int m = 0; m < M; m++) { |
768 | 0 | ssx[m].init(x); |
769 | 0 | x += ldx; |
770 | 0 | } |
771 | |
|
772 | 0 | { // initial result: take min for all elements |
773 | 0 | T sum = 0; |
774 | 0 | terms[0] = 0; |
775 | 0 | mark_seen(0); |
776 | 0 | for (int m = 0; m < M; m++) { |
777 | 0 | sum += ssx[m].get_0(); |
778 | 0 | } |
779 | 0 | sums[0] = sum; |
780 | 0 | for (int m = 0; m < M; m++) { |
781 | 0 | heap_push<HC>( |
782 | 0 | ++heap_size, |
783 | 0 | bh_val, |
784 | 0 | bh_ids, |
785 | 0 | sum + ssx[m].get_diff(1), |
786 | 0 | weight(m)); |
787 | 0 | } |
788 | 0 | } |
789 | |
|
790 | 0 | for (int k = 1; k < K; k++) { |
791 | | // pop smallest value from heap |
792 | 0 | if (use_seen) { // skip already seen elements |
793 | 0 | while (is_seen(bh_ids[0])) { |
794 | 0 | assert(heap_size > 0); |
795 | 0 | heap_pop<HC>(heap_size--, bh_val, bh_ids); |
796 | 0 | } |
797 | 0 | } |
798 | 0 | assert(heap_size > 0); |
799 | | |
800 | 0 | T sum = sums[k] = bh_val[0]; |
801 | 0 | int64_t ti = terms[k] = bh_ids[0]; |
802 | |
|
803 | 0 | if (use_seen) { |
804 | 0 | mark_seen(ti); |
805 | 0 | heap_pop<HC>(heap_size--, bh_val, bh_ids); |
806 | 0 | } else { |
807 | 0 | do { |
808 | 0 | heap_pop<HC>(heap_size--, bh_val, bh_ids); |
809 | 0 | } while (heap_size > 0 && bh_ids[0] == ti); |
810 | 0 | } |
811 | | |
812 | | // enqueue followers |
813 | 0 | int64_t ii = ti; |
814 | 0 | for (int m = 0; m < M; m++) { |
815 | 0 | int64_t n = ii & (((int64_t)1 << nbit) - 1); |
816 | 0 | ii >>= nbit; |
817 | 0 | if (n + 1 >= N) |
818 | 0 | continue; |
819 | | |
820 | 0 | enqueue_follower(ti, m, n, sum); |
821 | 0 | } |
822 | 0 | } |
823 | | |
824 | | /* |
825 | | for (int k = 0; k < K; k++) |
826 | | for (int l = k + 1; l < K; l++) |
827 | | assert (terms[k] != terms[l]); |
828 | | */ |
829 | | |
830 | | // convert indices by applying permutation |
831 | 0 | for (int k = 0; k < K; k++) { |
832 | 0 | int64_t ii = terms[k]; |
833 | 0 | if (use_seen) { |
834 | | // clear seen for reuse at next loop |
835 | 0 | seen[ii >> 3] = 0; |
836 | 0 | } |
837 | 0 | int64_t ti = 0; |
838 | 0 | for (int m = 0; m < M; m++) { |
839 | 0 | int64_t n = ii & (((int64_t)1 << nbit) - 1); |
840 | 0 | ti += int64_t(ssx[m].get_ord(n)) << (nbit * m); |
841 | 0 | ii >>= nbit; |
842 | 0 | } |
843 | 0 | terms[k] = ti; |
844 | 0 | } |
845 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_15SemiSortedArrayIfEELb0EE3runEPKflPfPl Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_14PreSortedArrayIfEELb0EE3runEPKflPfPl |
846 | | |
847 | 0 | void enqueue_follower(int64_t ti, int m, int n, T sum) { |
848 | 0 | T next_sum = sum + ssx[m].get_diff(n + 1); |
849 | 0 | int64_t next_ti = ti + weight(m); |
850 | 0 | heap_push<HC>(++heap_size, bh_val, bh_ids, next_sum, next_ti); |
851 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_15SemiSortedArrayIfEELb0EE16enqueue_followerEliif Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_14PreSortedArrayIfEELb0EE16enqueue_followerEliif |
852 | | |
853 | 0 | ~MinSumK() { |
854 | 0 | delete[] bh_ids; |
855 | 0 | delete[] bh_val; |
856 | 0 | } Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_15SemiSortedArrayIfEELb0EED2Ev Unexecuted instantiation: IndexPQ.cpp:_ZN5faiss12_GLOBAL__N_17MinSumKIfNS0_14PreSortedArrayIfEELb0EED2Ev |
857 | | }; |
858 | | |
859 | | } // anonymous namespace |
860 | | |
861 | | MultiIndexQuantizer::MultiIndexQuantizer(int d, size_t M, size_t nbits) |
862 | 0 | : Index(d, METRIC_L2), pq(d, M, nbits) { |
863 | 0 | is_trained = false; |
864 | 0 | pq.verbose = verbose; |
865 | 0 | } |
866 | | |
867 | 0 | void MultiIndexQuantizer::train(idx_t n, const float* x) { |
868 | 0 | pq.verbose = verbose; |
869 | 0 | pq.train(n, x); |
870 | 0 | is_trained = true; |
871 | | // count virtual elements in index |
872 | 0 | ntotal = 1; |
873 | 0 | for (int m = 0; m < pq.M; m++) |
874 | 0 | ntotal *= pq.ksub; |
875 | 0 | } |
876 | | |
877 | | // block size used in MultiIndexQuantizer::search |
878 | | int multi_index_quantizer_search_bs = 32768; |
879 | | |
880 | | void MultiIndexQuantizer::search( |
881 | | idx_t n, |
882 | | const float* x, |
883 | | idx_t k, |
884 | | float* distances, |
885 | | idx_t* labels, |
886 | 0 | const SearchParameters* params) const { |
887 | 0 | FAISS_THROW_IF_NOT_MSG( |
888 | 0 | !params, "search params not supported for this index"); |
889 | 0 | if (n == 0) { |
890 | 0 | return; |
891 | 0 | } |
892 | 0 | FAISS_THROW_IF_NOT(k > 0); |
893 | | |
894 | | // the allocation just below can be severe... |
895 | 0 | idx_t bs = multi_index_quantizer_search_bs; |
896 | 0 | if (n > bs) { |
897 | 0 | for (idx_t i0 = 0; i0 < n; i0 += bs) { |
898 | 0 | idx_t i1 = std::min(i0 + bs, n); |
899 | 0 | if (verbose) { |
900 | 0 | printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64 |
901 | 0 | " / %" PRId64 "\n", |
902 | 0 | i0, |
903 | 0 | i1, |
904 | 0 | n); |
905 | 0 | } |
906 | 0 | search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k); |
907 | 0 | } |
908 | 0 | return; |
909 | 0 | } |
910 | | |
911 | 0 | std::unique_ptr<float[]> dis_tables(new float[n * pq.ksub * pq.M]); |
912 | |
|
913 | 0 | pq.compute_distance_tables(n, x, dis_tables.get()); |
914 | |
|
915 | 0 | if (k == 1) { |
916 | | // simple version that just finds the min in each table |
917 | |
|
918 | 0 | #pragma omp parallel for |
919 | 0 | for (int i = 0; i < n; i++) { |
920 | 0 | const float* dis_table = dis_tables.get() + i * pq.ksub * pq.M; |
921 | 0 | float dis = 0; |
922 | 0 | idx_t label = 0; |
923 | |
|
924 | 0 | for (int s = 0; s < pq.M; s++) { |
925 | 0 | float vmin = HUGE_VALF; |
926 | 0 | idx_t lmin = -1; |
927 | |
|
928 | 0 | for (idx_t j = 0; j < pq.ksub; j++) { |
929 | 0 | if (dis_table[j] < vmin) { |
930 | 0 | vmin = dis_table[j]; |
931 | 0 | lmin = j; |
932 | 0 | } |
933 | 0 | } |
934 | 0 | dis += vmin; |
935 | 0 | label |= lmin << (s * pq.nbits); |
936 | 0 | dis_table += pq.ksub; |
937 | 0 | } |
938 | |
|
939 | 0 | distances[i] = dis; |
940 | 0 | labels[i] = label; |
941 | 0 | } |
942 | |
|
943 | 0 | } else { |
944 | 0 | #pragma omp parallel if (n > 1) |
945 | 0 | { |
946 | 0 | MinSumK<float, SemiSortedArray<float>, false> msk( |
947 | 0 | k, pq.M, pq.nbits, pq.ksub); |
948 | 0 | #pragma omp for |
949 | 0 | for (int i = 0; i < n; i++) { |
950 | 0 | msk.run(dis_tables.get() + i * pq.ksub * pq.M, |
951 | 0 | pq.ksub, |
952 | 0 | distances + i * k, |
953 | 0 | labels + i * k); |
954 | 0 | } |
955 | 0 | } |
956 | 0 | } |
957 | 0 | } |
958 | | |
959 | 0 | void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const { |
960 | 0 | int64_t jj = key; |
961 | 0 | for (int m = 0; m < pq.M; m++) { |
962 | 0 | int64_t n = jj & (((int64_t)1 << pq.nbits) - 1); |
963 | 0 | jj >>= pq.nbits; |
964 | 0 | memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub); |
965 | 0 | recons += pq.dsub; |
966 | 0 | } |
967 | 0 | } |
968 | | |
969 | 0 | void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) { |
970 | 0 | FAISS_THROW_MSG( |
971 | 0 | "This index has virtual elements, " |
972 | 0 | "it does not support add"); |
973 | 0 | } |
974 | | |
975 | 0 | void MultiIndexQuantizer::reset() { |
976 | 0 | FAISS_THROW_MSG( |
977 | 0 | "This index has virtual elements, " |
978 | 0 | "it does not support reset"); |
979 | 0 | } |
980 | | |
981 | | /***************************************** |
982 | | * MultiIndexQuantizer2 |
983 | | ******************************************/ |
984 | | |
985 | | MultiIndexQuantizer2::MultiIndexQuantizer2( |
986 | | int d, |
987 | | size_t M, |
988 | | size_t nbits, |
989 | | Index** indexes) |
990 | 0 | : MultiIndexQuantizer(d, M, nbits) { |
991 | 0 | assign_indexes.resize(M); |
992 | 0 | for (int i = 0; i < M; i++) { |
993 | 0 | FAISS_THROW_IF_NOT_MSG( |
994 | 0 | indexes[i]->d == pq.dsub, |
995 | 0 | "Provided sub-index has incorrect size"); |
996 | 0 | assign_indexes[i] = indexes[i]; |
997 | 0 | } |
998 | 0 | own_fields = false; |
999 | 0 | } |
1000 | | |
1001 | | MultiIndexQuantizer2::MultiIndexQuantizer2( |
1002 | | int d, |
1003 | | size_t nbits, |
1004 | | Index* assign_index_0, |
1005 | | Index* assign_index_1) |
1006 | 0 | : MultiIndexQuantizer(d, 2, nbits) { |
1007 | 0 | FAISS_THROW_IF_NOT_MSG( |
1008 | 0 | assign_index_0->d == pq.dsub && assign_index_1->d == pq.dsub, |
1009 | 0 | "Provided sub-index has incorrect size"); |
1010 | 0 | assign_indexes.resize(2); |
1011 | 0 | assign_indexes[0] = assign_index_0; |
1012 | 0 | assign_indexes[1] = assign_index_1; |
1013 | 0 | own_fields = false; |
1014 | 0 | } |
1015 | | |
1016 | 0 | void MultiIndexQuantizer2::train(idx_t n, const float* x) { |
1017 | 0 | MultiIndexQuantizer::train(n, x); |
1018 | | // add centroids to sub-indexes |
1019 | 0 | for (int i = 0; i < pq.M; i++) { |
1020 | 0 | assign_indexes[i]->add(pq.ksub, pq.get_centroids(i, 0)); |
1021 | 0 | } |
1022 | 0 | } |
1023 | | |
1024 | | void MultiIndexQuantizer2::search( |
1025 | | idx_t n, |
1026 | | const float* x, |
1027 | | idx_t K, |
1028 | | float* distances, |
1029 | | idx_t* labels, |
1030 | 0 | const SearchParameters* params) const { |
1031 | 0 | FAISS_THROW_IF_NOT_MSG( |
1032 | 0 | !params, "search params not supported for this index"); |
1033 | | |
1034 | 0 | if (n == 0) { |
1035 | 0 | return; |
1036 | 0 | } |
1037 | | |
1038 | 0 | int k2 = std::min(K, int64_t(pq.ksub)); |
1039 | 0 | FAISS_THROW_IF_NOT(k2); |
1040 | | |
1041 | 0 | int64_t M = pq.M; |
1042 | 0 | int64_t dsub = pq.dsub, ksub = pq.ksub; |
1043 | | |
1044 | | // size (M, n, k2) |
1045 | 0 | std::vector<idx_t> sub_ids(n * M * k2); |
1046 | 0 | std::vector<float> sub_dis(n * M * k2); |
1047 | 0 | std::vector<float> xsub(n * dsub); |
1048 | |
|
1049 | 0 | for (int m = 0; m < M; m++) { |
1050 | 0 | float* xdest = xsub.data(); |
1051 | 0 | const float* xsrc = x + m * dsub; |
1052 | 0 | for (int j = 0; j < n; j++) { |
1053 | 0 | memcpy(xdest, xsrc, dsub * sizeof(xdest[0])); |
1054 | 0 | xsrc += d; |
1055 | 0 | xdest += dsub; |
1056 | 0 | } |
1057 | |
|
1058 | 0 | assign_indexes[m]->search( |
1059 | 0 | n, xsub.data(), k2, &sub_dis[k2 * n * m], &sub_ids[k2 * n * m]); |
1060 | 0 | } |
1061 | |
|
1062 | 0 | if (K == 1) { |
1063 | | // simple version that just finds the min in each table |
1064 | 0 | assert(k2 == 1); |
1065 | | |
1066 | 0 | for (int i = 0; i < n; i++) { |
1067 | 0 | float dis = 0; |
1068 | 0 | idx_t label = 0; |
1069 | |
|
1070 | 0 | for (int m = 0; m < M; m++) { |
1071 | 0 | float vmin = sub_dis[i + m * n]; |
1072 | 0 | idx_t lmin = sub_ids[i + m * n]; |
1073 | 0 | dis += vmin; |
1074 | 0 | label |= lmin << (m * pq.nbits); |
1075 | 0 | } |
1076 | 0 | distances[i] = dis; |
1077 | 0 | labels[i] = label; |
1078 | 0 | } |
1079 | |
|
1080 | 0 | } else { |
1081 | 0 | #pragma omp parallel if (n > 1) |
1082 | 0 | { |
1083 | 0 | MinSumK<float, PreSortedArray<float>, false> msk( |
1084 | 0 | K, pq.M, pq.nbits, k2); |
1085 | 0 | #pragma omp for |
1086 | 0 | for (int i = 0; i < n; i++) { |
1087 | 0 | idx_t* li = labels + i * K; |
1088 | 0 | msk.run(&sub_dis[i * k2], k2 * n, distances + i * K, li); |
1089 | | |
1090 | | // remap ids |
1091 | |
|
1092 | 0 | const idx_t* idmap0 = sub_ids.data() + i * k2; |
1093 | 0 | int64_t ld_idmap = k2 * n; |
1094 | 0 | int64_t mask1 = ksub - (int64_t)1; |
1095 | |
|
1096 | 0 | for (int k = 0; k < K; k++) { |
1097 | 0 | const idx_t* idmap = idmap0; |
1098 | 0 | int64_t vin = li[k]; |
1099 | 0 | int64_t vout = 0; |
1100 | 0 | int bs = 0; |
1101 | 0 | for (int m = 0; m < M; m++) { |
1102 | 0 | int64_t s = vin & mask1; |
1103 | 0 | vin >>= pq.nbits; |
1104 | 0 | vout |= idmap[s] << bs; |
1105 | 0 | bs += pq.nbits; |
1106 | 0 | idmap += ld_idmap; |
1107 | 0 | } |
1108 | 0 | li[k] = vout; |
1109 | 0 | } |
1110 | 0 | } |
1111 | 0 | } |
1112 | 0 | } |
1113 | 0 | } |
1114 | | |
1115 | | } // namespace faiss |