contrib/faiss/faiss/impl/ProductAdditiveQuantizer.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/ProductAdditiveQuantizer.h> |
9 | | |
10 | | #include <cstddef> |
11 | | #include <cstdio> |
12 | | #include <cstring> |
13 | | #include <memory> |
14 | | |
15 | | #include <algorithm> |
16 | | |
17 | | #include <faiss/clone_index.h> |
18 | | #include <faiss/impl/AuxIndexStructures.h> |
19 | | #include <faiss/impl/FaissAssert.h> |
20 | | #include <faiss/utils/distances.h> |
21 | | #include <faiss/utils/hamming.h> |
22 | | |
extern "C" {

// general matrix multiplication
// Fortran BLAS sgemm: C := alpha * op(A) * op(B) + beta * C.
// All scalars are passed by pointer (Fortran calling convention).
// FINTEGER is presumably the project's Fortran integer type — defined in a
// header not visible here; TODO confirm.
int sgemm_(
        const char* transa,
        const char* transb,
        FINTEGER* m,
        FINTEGER* n,
        FINTEGER* k,
        const float* alpha,
        const float* a,
        FINTEGER* lda,
        const float* b,
        FINTEGER* ldb,
        float* beta,
        float* c,
        FINTEGER* ldc);
}
41 | | |
42 | | namespace faiss { |
43 | | |
/** Construct a product additive quantizer from a list of sub-quantizers.
 *
 * @param d            total dimensionality of the input vectors
 * @param aqs          sub-quantizers; init() clones them, so the caller
 *                     keeps ownership of the originals
 * @param search_type  AdditiveQuantizer search type (see base class)
 */
ProductAdditiveQuantizer::ProductAdditiveQuantizer(
        size_t d,
        const std::vector<AdditiveQuantizer*>& aqs,
        Search_type_t search_type) {
    init(d, aqs, search_type);
}
50 | | |
// Default constructor: empty quantizer with no sub-quantizers; the
// search_type comes from the delegated constructor's default argument
// (declared in the header).
ProductAdditiveQuantizer::ProductAdditiveQuantizer()
        : ProductAdditiveQuantizer(0, {}) {}
53 | | |
54 | | void ProductAdditiveQuantizer::init( |
55 | | size_t d, |
56 | | const std::vector<AdditiveQuantizer*>& aqs, |
57 | 0 | Search_type_t search_type) { |
58 | | // AdditiveQuantizer constructor |
59 | 0 | this->d = d; |
60 | 0 | this->search_type = search_type; |
61 | 0 | M = 0; |
62 | 0 | for (const auto& q : aqs) { |
63 | 0 | M += q->M; |
64 | 0 | nbits.insert(nbits.end(), q->nbits.begin(), q->nbits.end()); |
65 | 0 | } |
66 | 0 | set_derived_values(); |
67 | | |
68 | | // ProductAdditiveQuantizer |
69 | 0 | nsplits = aqs.size(); |
70 | |
|
71 | 0 | FAISS_THROW_IF_NOT(quantizers.empty()); |
72 | 0 | for (const auto& q : aqs) { |
73 | 0 | auto aq = dynamic_cast<AdditiveQuantizer*>(clone_Quantizer(q)); |
74 | 0 | quantizers.push_back(aq); |
75 | 0 | } |
76 | 0 | } |
77 | | |
78 | 0 | ProductAdditiveQuantizer::~ProductAdditiveQuantizer() { |
79 | 0 | for (auto& q : quantizers) { |
80 | 0 | delete q; |
81 | 0 | } |
82 | 0 | } |
83 | | |
84 | 0 | AdditiveQuantizer* ProductAdditiveQuantizer::subquantizer(size_t s) const { |
85 | 0 | return quantizers[s]; |
86 | 0 | } |
87 | | |
// Train each sub-quantizer independently on its slice of the input,
// concatenate their codebooks into this->codebooks, then train the norm
// codec on the reconstruction norms. No-op if already trained.
void ProductAdditiveQuantizer::train(size_t n, const float* x) {
    if (is_trained) {
        return;
    }

    // copy the subvectors into contiguous memory
    size_t offset_d = 0;
    std::vector<float> xt;
    for (size_t s = 0; s < nsplits; s++) {
        auto q = quantizers[s];
        // gather the s-th slice of every input vector into xt, size (n, q->d)
        xt.resize(q->d * n);

#pragma omp parallel for if (n > 1000)
        for (idx_t i = 0; i < n; i++) {
            memcpy(xt.data() + i * q->d,
                   x + i * d + offset_d,
                   q->d * sizeof(*x));
        }

        q->train(n, xt.data());
        offset_d += q->d;
    }

    // compute codebook size
    size_t codebook_size = 0;
    for (const auto& q : quantizers) {
        codebook_size += q->total_codebook_size * q->d;
    }

    // copy codebook from sub-quantizers
    codebooks.resize(codebook_size); // size (M * ksub, dsub)
    float* cb = codebooks.data();
    for (size_t s = 0; s < nsplits; s++) {
        auto q = quantizers[s];
        size_t sub_codebook_size = q->total_codebook_size * q->d;
        memcpy(cb, q->codebooks.data(), sub_codebook_size * sizeof(float));
        cb += sub_codebook_size;
    }

    // must be set before encoding the training set below
    is_trained = true;

    // train norm: encode/decode the training set and fit the norm codec on
    // the squared L2 norms of the reconstructions
    std::vector<int32_t> codes(n * M);
    compute_unpacked_codes(x, codes.data(), n);
    std::vector<float> x_recons(n * d);
    std::vector<float> norms(n);
    decode_unpacked(codes.data(), x_recons.data(), n);
    fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
    train_norm(n, norms.data());
}
138 | | |
139 | | void ProductAdditiveQuantizer::compute_codes_add_centroids( |
140 | | const float* x, |
141 | | uint8_t* codes_out, |
142 | | size_t n, |
143 | 0 | const float* centroids) const { |
144 | | // size (n, M) |
145 | 0 | std::vector<int32_t> unpacked_codes(n * M); |
146 | 0 | compute_unpacked_codes(x, unpacked_codes.data(), n, centroids); |
147 | | |
148 | | // pack |
149 | 0 | pack_codes(n, unpacked_codes.data(), codes_out, -1, nullptr, centroids); |
150 | 0 | } |
151 | | |
// Encode n input vectors into one int32 index per sub-codebook, written as
// an (n, M) row-major array in unpacked_codes.
// NOTE(review): `centroids` is not referenced in this function body —
// confirm whether it should be forwarded to the sub-quantizers.
void ProductAdditiveQuantizer::compute_unpacked_codes(
        const float* x,
        int32_t* unpacked_codes,
        size_t n,
        const float* centroids) const {
    /// TODO: actually we do not need to unpack and pack
    size_t offset_d = 0, offset_m = 0;
    std::vector<float> xsub;    // slice of x handled by the current split
    std::vector<uint8_t> codes; // packed codes of the current split

    for (size_t s = 0; s < nsplits; s++) {
        const auto q = quantizers[s];
        xsub.resize(n * q->d);
        codes.resize(n * q->code_size);

        // gather the split's slice of every input vector
#pragma omp parallel for if (n > 1000)
        for (idx_t i = 0; i < n; i++) {
            memcpy(xsub.data() + i * q->d,
                   x + i * d + offset_d,
                   q->d * sizeof(float));
        }

        q->compute_codes(xsub.data(), codes.data(), n);

        // unpack the bit-packed codes into int32 entries
#pragma omp parallel for if (n > 1000)
        for (idx_t i = 0; i < n; i++) {
            uint8_t* code = codes.data() + i * q->code_size;
            BitstringReader bsr(code, q->code_size);

            // unpacked_codes[i][s][m] = codes[i][m]
            for (size_t m = 0; m < q->M; m++) {
                unpacked_codes[i * M + offset_m + m] = bsr.read(q->nbits[m]);
            }
        }

        offset_d += q->d;
        offset_m += q->M;
    }
}
192 | | |
// Decode n vectors from unpacked codes (one int32 per sub-codebook).
// codes: (n, ld_codes) row-major with M valid entries per row; ld_codes of
// -1 means rows are exactly M wide. x: output, size (n, d).
void ProductAdditiveQuantizer::decode_unpacked(
        const int32_t* codes,
        float* x,
        size_t n,
        int64_t ld_codes) const {
    FAISS_THROW_IF_NOT_MSG(
            is_trained, "The product additive quantizer is not trained yet.");

    if (ld_codes == -1) {
        ld_codes = M;
    }

    // product additive quantizer decoding
#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        const int32_t* codesi = codes + i * ld_codes;

        size_t offset_m = 0, offset_d = 0;
        for (size_t s = 0; s < nsplits; s++) {
            const auto q = quantizers[s];
            float* xi = x + i * d + offset_d;

            // accumulate the selected centroid of every sub-codebook of
            // sub-quantizer s into the output slice xi
            for (int m = 0; m < q->M; m++) {
                int idx = codesi[offset_m + m];
                const float* c = codebooks.data() +
                        q->d * (codebook_offsets[offset_m + m] + idx);
                if (m == 0) {
                    // first codebook of the split: overwrite instead of add
                    memcpy(xi, c, sizeof(*x) * q->d);
                } else {
                    fvec_add(q->d, xi, c, xi);
                }
            }

            offset_m += q->M;
            offset_d += q->d;
        }
    }
}
231 | | |
// Decode n bit-packed codes (code_size bytes each) into vectors x of size
// (n, d). Same accumulation scheme as decode_unpacked, but the per-codebook
// indices are read directly from the packed bit stream.
void ProductAdditiveQuantizer::decode(const uint8_t* codes, float* x, size_t n)
        const {
    FAISS_THROW_IF_NOT_MSG(
            is_trained, "The product additive quantizer is not trained yet.");

#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        BitstringReader bsr(codes + i * code_size, code_size);

        size_t offset_m = 0, offset_d = 0;
        for (size_t s = 0; s < nsplits; s++) {
            const auto q = quantizers[s];
            float* xi = x + i * d + offset_d;

            // accumulate the selected centroid of every sub-codebook of
            // sub-quantizer s into the output slice xi
            for (int m = 0; m < q->M; m++) {
                int idx = bsr.read(q->nbits[m]);
                const float* c = codebooks.data() +
                        q->d * (codebook_offsets[offset_m + m] + idx);
                if (m == 0) {
                    // first codebook of the split: overwrite instead of add
                    memcpy(xi, c, sizeof(*x) * q->d);
                } else {
                    fvec_add(q->d, xi, c, xi);
                }
            }

            offset_m += q->M;
            offset_d += q->d;
        }
    }
}
262 | | |
// Compute look-up tables of inner products between the queries and every
// centroid, one sgemm call per sub-quantizer:
//   LUT[i, j] = alpha * <slice of xq[i], centroid j>
void ProductAdditiveQuantizer::compute_LUT(
        size_t n,
        const float* xq,
        float* LUT,
        float alpha,
        long ld_lut) const {
    // codebooks: size (M * ksub, dsub)
    // xq: size (n, d)
    // output LUT: size (n, M * ksub)

    FINTEGER nqi = n;
    // leading dimension of 'LUT' and 'xq'
    FINTEGER ld_LUT = ld_lut > 0 ? ld_lut : total_codebook_size;
    FINTEGER ld_xq = d;

    float zero = 0;
    size_t offset_d = 0;   // start of this split's slice within a query
    size_t offset_cb = 0;  // start of this split's codebooks (flat floats)
    size_t offset_lut = 0; // start of this split's entries in a LUT row

    for (size_t s = 0; s < nsplits; s++) {
        const auto q = quantizers[s];

        FINTEGER ncenti = q->total_codebook_size;
        FINTEGER ld_cb = q->d; // leading dimension of 'codebooks'

        auto codebooksi = codebooks.data() + offset_cb;
        auto xqi = xq + offset_d;
        auto LUTi = LUT + offset_lut;

        // LUTi := alpha * codebooksi^T * xqi in column-major terms;
        // BLAS only inspects the first character of the transpose flags
        sgemm_("Transposed",
               "Not transposed",
               &ncenti,
               &nqi,
               &ld_cb,
               &alpha,
               codebooksi,
               &ld_cb,
               xqi,
               &ld_xq,
               &zero,
               LUTi,
               &ld_LUT);

        offset_d += q->d;
        offset_cb += q->total_codebook_size * q->d;
        offset_lut += q->total_codebook_size;
    }
}
312 | | |
313 | | /************************************* |
314 | | * Product Local Search Quantizer |
315 | | ************************************/ |
316 | | |
317 | | ProductLocalSearchQuantizer::ProductLocalSearchQuantizer( |
318 | | size_t d, |
319 | | size_t nsplits, |
320 | | size_t Msub, |
321 | | size_t nbits, |
322 | 0 | Search_type_t search_type) { |
323 | 0 | std::vector<AdditiveQuantizer*> aqs; |
324 | |
|
325 | 0 | if (nsplits > 0) { |
326 | 0 | FAISS_THROW_IF_NOT(d % nsplits == 0); |
327 | 0 | size_t dsub = d / nsplits; |
328 | |
|
329 | 0 | for (size_t i = 0; i < nsplits; i++) { |
330 | 0 | auto lsq = |
331 | 0 | new LocalSearchQuantizer(dsub, Msub, nbits, ST_decompress); |
332 | 0 | aqs.push_back(lsq); |
333 | 0 | } |
334 | 0 | } |
335 | 0 | init(d, aqs, search_type); |
336 | 0 | for (auto& q : aqs) { |
337 | 0 | delete q; |
338 | 0 | } |
339 | 0 | } |
340 | | |
// Default constructor: empty quantizer (0 splits); the search_type comes
// from the delegated constructor's default argument (declared in the
// header).
ProductLocalSearchQuantizer::ProductLocalSearchQuantizer()
        : ProductLocalSearchQuantizer(0, 0, 0, 0) {}
343 | | |
344 | | /************************************* |
345 | | * Product Residual Quantizer |
346 | | ************************************/ |
347 | | |
348 | | ProductResidualQuantizer::ProductResidualQuantizer( |
349 | | size_t d, |
350 | | size_t nsplits, |
351 | | size_t Msub, |
352 | | size_t nbits, |
353 | 0 | Search_type_t search_type) { |
354 | 0 | std::vector<AdditiveQuantizer*> aqs; |
355 | |
|
356 | 0 | if (nsplits > 0) { |
357 | 0 | FAISS_THROW_IF_NOT(d % nsplits == 0); |
358 | 0 | size_t dsub = d / nsplits; |
359 | |
|
360 | 0 | for (size_t i = 0; i < nsplits; i++) { |
361 | 0 | auto rq = new ResidualQuantizer(dsub, Msub, nbits, ST_decompress); |
362 | 0 | aqs.push_back(rq); |
363 | 0 | } |
364 | 0 | } |
365 | 0 | init(d, aqs, search_type); |
366 | 0 | for (auto& q : aqs) { |
367 | 0 | delete q; |
368 | 0 | } |
369 | 0 | } |
370 | | |
// Default constructor: empty quantizer (0 splits); the search_type comes
// from the delegated constructor's default argument (declared in the
// header).
ProductResidualQuantizer::ProductResidualQuantizer()
        : ProductResidualQuantizer(0, 0, 0, 0) {}
373 | | |
374 | | } // namespace faiss |