/root/doris/contrib/faiss/faiss/impl/ResidualQuantizer.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/ResidualQuantizer.h> |
9 | | |
10 | | #include <algorithm> |
11 | | #include <cmath> |
12 | | #include <cstddef> |
13 | | #include <cstdio> |
14 | | #include <cstring> |
15 | | #include <memory> |
16 | | |
17 | | #include <faiss/IndexFlat.h> |
18 | | #include <faiss/VectorTransform.h> |
19 | | #include <faiss/impl/FaissAssert.h> |
20 | | #include <faiss/impl/residual_quantizer_encode_steps.h> |
21 | | #include <faiss/utils/distances.h> |
22 | | #include <faiss/utils/hamming.h> |
23 | | #include <faiss/utils/utils.h> |
24 | | |
25 | | extern "C" { |
26 | | |
27 | | // general matrix multiplication |
28 | | int sgemm_( |
29 | | const char* transa, |
30 | | const char* transb, |
31 | | FINTEGER* m, |
32 | | FINTEGER* n, |
33 | | FINTEGER* k, |
34 | | const float* alpha, |
35 | | const float* a, |
36 | | FINTEGER* lda, |
37 | | const float* b, |
38 | | FINTEGER* ldb, |
39 | | float* beta, |
40 | | float* c, |
41 | | FINTEGER* ldc); |
42 | | |
43 | | // http://www.netlib.org/clapack/old/single/sgels.c |
44 | | // solve least squares |
45 | | |
46 | | int sgelsd_( |
47 | | FINTEGER* m, |
48 | | FINTEGER* n, |
49 | | FINTEGER* nrhs, |
50 | | float* a, |
51 | | FINTEGER* lda, |
52 | | float* b, |
53 | | FINTEGER* ldb, |
54 | | float* s, |
55 | | float* rcond, |
56 | | FINTEGER* rank, |
57 | | float* work, |
58 | | FINTEGER* lwork, |
59 | | FINTEGER* iwork, |
60 | | FINTEGER* info); |
61 | | } |
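For readers not used to these Fortran-style BLAS/LAPACK declarations: every argument is passed by pointer and matrices are column-major. The sketch below is illustrative only (the wrapper name and array layout are assumptions, and FINTEGER is taken to be the BLAS integer type, int in most builds); it shows the calling convention Faiss code typically uses to drive sgemm_ on row-major C++ arrays.

    // Compute all inner products ip[i * nb + j] = <xq_i, xb_j> between
    // row-major queries xq (nq x d) and database vectors xb (nb x d).
    void inner_products_blas(
            size_t nq, size_t nb, size_t d,
            const float* xq, const float* xb, float* ip) {
        FINTEGER nbi = nb, nqi = nq, di = d;
        float one = 1.0f, zero = 0.0f;
        // BLAS sees column-major data while the arrays are row-major:
        // requesting xb^T * xq in column-major terms leaves the nq x nb
        // result laid out row-major in ip.
        sgemm_("Transposed", "Not transposed",
               &nbi, &nqi, &di,
               &one, xb, &di, xq, &di,
               &zero, ip, &nbi);
    }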
62 | | |
63 | | namespace faiss { |
64 | | |
65 | 0 | ResidualQuantizer::ResidualQuantizer() { |
66 | 0 | d = 0; |
67 | 0 | M = 0; |
68 | 0 | verbose = false; |
69 | 0 | } |
70 | | |
71 | | ResidualQuantizer::ResidualQuantizer( |
72 | | size_t d, |
73 | | const std::vector<size_t>& nbits, |
74 | | Search_type_t search_type) |
75 | 0 | : ResidualQuantizer() { |
76 | 0 | this->search_type = search_type; |
77 | 0 | this->d = d; |
78 | 0 | M = nbits.size(); |
79 | 0 | this->nbits = nbits; |
80 | 0 | set_derived_values(); |
81 | 0 | } |
82 | | |
83 | | ResidualQuantizer::ResidualQuantizer( |
84 | | size_t d, |
85 | | size_t M, |
86 | | size_t nbits, |
87 | | Search_type_t search_type) |
88 | 0 | : ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {} |
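A minimal construction sketch (illustrative, not part of the listed file): both forms below build the same quantizer for 64-dimensional vectors with M = 6 codebooks of 8 bits (256 centroids) each.

    faiss::ResidualQuantizer rq(64, 6, 8);
    faiss::ResidualQuantizer rq2(64, std::vector<size_t>{8, 8, 8, 8, 8, 8});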
89 | | |
90 | | void ResidualQuantizer::initialize_from( |
91 | | const ResidualQuantizer& other, |
92 | 0 | int skip_M) { |
93 | 0 | FAISS_THROW_IF_NOT(M + skip_M <= other.M); |
94 | 0 | FAISS_THROW_IF_NOT(skip_M >= 0); |
95 | | |
96 | 0 | Search_type_t this_search_type = search_type; |
97 | 0 | int this_M = M; |
98 | | |
99 | | // a first good approximation: override everything |
100 | 0 | *this = other; |
101 | | |
102 | | // adjust derived values |
103 | 0 | M = this_M; |
104 | 0 | search_type = this_search_type; |
105 | 0 | nbits.resize(M); |
106 | 0 | memcpy(nbits.data(), |
107 | 0 | other.nbits.data() + skip_M, |
108 | 0 | nbits.size() * sizeof(nbits[0])); |
109 | |
109 | |
110 | 0 | set_derived_values(); |
111 | | |
112 | | // resize codebooks if trained |
113 | 0 | if (codebooks.size() > 0) { |
114 | 0 | FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d); |
115 | 0 | codebooks.resize(total_codebook_size * d); |
116 | 0 | memcpy(codebooks.data(), |
117 | 0 | other.codebooks.data() + other.codebook_offsets[skip_M] * d, |
118 | 0 | codebooks.size() * sizeof(codebooks[0])); |
119 | | // TODO: norm_tabs? |
120 | 0 | } |
121 | 0 | } |
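A hedged usage sketch of initialize_from (the quantizer names are hypothetical): a trained 6-step quantizer seeds a 4-step one, dropping the first two codebooks; the assertion above requires M + skip_M <= other.M.

    faiss::ResidualQuantizer rq_full(64, 6, 8);     // assume this one is trained
    faiss::ResidualQuantizer rq_tail(64, 4, 8);
    rq_tail.initialize_from(rq_full, /*skip_M=*/2); // reuse codebooks 2..5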
122 | | |
123 | | /**************************************************************** |
124 | | * Training |
125 | | ****************************************************************/ |
126 | | |
127 | 0 | void ResidualQuantizer::train(size_t n, const float* x) { |
128 | 0 | codebooks.resize(d * codebook_offsets.back()); |
129 | |
130 | 0 | if (verbose) { |
131 | 0 | printf("Training ResidualQuantizer, with %zd steps on %zd %zdD vectors\n", |
132 | 0 | M, |
133 | 0 | n, |
134 | 0 | size_t(d)); |
135 | 0 | } |
136 | |
137 | 0 | int cur_beam_size = 1; |
138 | 0 | std::vector<float> residuals(x, x + n * d); |
139 | 0 | std::vector<int32_t> codes; |
140 | 0 | std::vector<float> distances; |
141 | 0 | double t0 = getmillisecs(); |
142 | 0 | double clustering_time = 0; |
143 | |
144 | 0 | for (int m = 0; m < M; m++) { |
145 | 0 | int K = 1 << nbits[m]; |
146 | | |
147 | | // on which residuals to train |
148 | 0 | std::vector<float>& train_residuals = residuals; |
149 | 0 | std::vector<float> residuals1; |
150 | 0 | if (train_type & Train_top_beam) { |
151 | 0 | residuals1.resize(n * d); |
152 | 0 | for (size_t j = 0; j < n; j++) { |
153 | 0 | memcpy(residuals1.data() + j * d, |
154 | 0 | residuals.data() + j * d * cur_beam_size, |
155 | 0 | sizeof(residuals[0]) * d); |
156 | 0 | } |
157 | 0 | train_residuals = residuals1; |
158 | 0 | } |
159 | 0 | std::vector<float> codebooks; |
160 | 0 | float obj = 0; |
161 | |
162 | 0 | std::unique_ptr<Index> assign_index; |
163 | 0 | if (assign_index_factory) { |
164 | 0 | assign_index.reset((*assign_index_factory)(d)); |
165 | 0 | } else { |
166 | 0 | assign_index.reset(new IndexFlatL2(d)); |
167 | 0 | } |
168 | |
169 | 0 | double t1 = getmillisecs(); |
170 | |
171 | 0 | if (!(train_type & Train_progressive_dim)) { // regular kmeans |
172 | 0 | Clustering clus(d, K, cp); |
173 | 0 | clus.train( |
174 | 0 | train_residuals.size() / d, |
175 | 0 | train_residuals.data(), |
176 | 0 | *assign_index.get()); |
177 | 0 | codebooks.swap(clus.centroids); |
178 | 0 | assign_index->reset(); |
179 | 0 | obj = clus.iteration_stats.back().obj; |
180 | 0 | } else { // progressive dim clustering |
181 | 0 | ProgressiveDimClustering clus(d, K, cp); |
182 | 0 | ProgressiveDimIndexFactory default_fac; |
183 | 0 | clus.train( |
184 | 0 | train_residuals.size() / d, |
185 | 0 | train_residuals.data(), |
186 | 0 | assign_index_factory ? *assign_index_factory : default_fac); |
187 | 0 | codebooks.swap(clus.centroids); |
188 | 0 | obj = clus.iteration_stats.back().obj; |
189 | 0 | } |
190 | 0 | clustering_time += (getmillisecs() - t1) / 1000; |
191 | |
192 | 0 | memcpy(this->codebooks.data() + codebook_offsets[m] * d, |
193 | 0 | codebooks.data(), |
194 | 0 | codebooks.size() * sizeof(codebooks[0])); |
195 | | |
196 | | // quantize using the new codebooks |
197 | |
198 | 0 | int new_beam_size = std::min(cur_beam_size * K, max_beam_size); |
199 | 0 | std::vector<int32_t> new_codes(n * new_beam_size * (m + 1)); |
200 | 0 | std::vector<float> new_residuals(n * new_beam_size * d); |
201 | 0 | std::vector<float> new_distances(n * new_beam_size); |
202 | |
203 | 0 | size_t bs; |
204 | 0 | { // determine batch size |
205 | 0 | size_t mem = memory_per_point(); |
206 | 0 | if (n > 1 && mem * n > max_mem_distances) { |
207 | | // then split queries to reduce temp memory |
208 | 0 | bs = std::max(max_mem_distances / mem, size_t(1)); |
209 | 0 | } else { |
210 | 0 | bs = n; |
211 | 0 | } |
212 | 0 | } |
213 | |
214 | 0 | for (size_t i0 = 0; i0 < n; i0 += bs) { |
215 | 0 | size_t i1 = std::min(i0 + bs, n); |
216 | | |
217 | | /* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n", |
218 | | i0, i1, K, assign_index->ntotal); */ |
219 | |
220 | 0 | beam_search_encode_step( |
221 | 0 | d, |
222 | 0 | K, |
223 | 0 | codebooks.data(), |
224 | 0 | i1 - i0, |
225 | 0 | cur_beam_size, |
226 | 0 | residuals.data() + i0 * cur_beam_size * d, |
227 | 0 | m, |
228 | 0 | codes.data() + i0 * cur_beam_size * m, |
229 | 0 | new_beam_size, |
230 | 0 | new_codes.data() + i0 * new_beam_size * (m + 1), |
231 | 0 | new_residuals.data() + i0 * new_beam_size * d, |
232 | 0 | new_distances.data() + i0 * new_beam_size, |
233 | 0 | assign_index.get(), |
234 | 0 | approx_topk_mode); |
235 | 0 | } |
236 | 0 | codes.swap(new_codes); |
237 | 0 | residuals.swap(new_residuals); |
238 | 0 | distances.swap(new_distances); |
239 | |
240 | 0 | float sum_distances = 0; |
241 | 0 | for (int j = 0; j < distances.size(); j++) { |
242 | 0 | sum_distances += distances[j]; |
243 | 0 | } |
244 | |
245 | 0 | if (verbose) { |
246 | 0 | printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, " |
247 | 0 | "total distance %g, beam_size %d->%d (batch size %zd)\n", |
248 | 0 | (getmillisecs() - t0) / 1000, |
249 | 0 | clustering_time, |
250 | 0 | m, |
251 | 0 | int(nbits[m]), |
252 | 0 | obj, |
253 | 0 | sum_distances, |
254 | 0 | cur_beam_size, |
255 | 0 | new_beam_size, |
256 | 0 | bs); |
257 | 0 | } |
258 | 0 | cur_beam_size = new_beam_size; |
259 | 0 | } |
260 | |
261 | 0 | is_trained = true; |
262 | |
263 | 0 | if (train_type & Train_refine_codebook) { |
264 | 0 | for (int iter = 0; iter < niter_codebook_refine; iter++) { |
265 | 0 | if (verbose) { |
266 | 0 | printf("re-estimating the codebooks to minimize " |
267 | 0 | "quantization errors (iter %d).\n", |
268 | 0 | iter); |
269 | 0 | } |
270 | 0 | retrain_AQ_codebook(n, x); |
271 | 0 | } |
272 | 0 | } |
273 | | |
274 | | // find min and max norms |
275 | 0 | std::vector<float> norms(n); |
276 | |
277 | 0 | for (size_t i = 0; i < n; i++) { |
278 | 0 | norms[i] = fvec_L2sqr( |
279 | 0 | x + i * d, residuals.data() + i * cur_beam_size * d, d); |
280 | 0 | } |
281 | | |
282 | | // fvec_norms_L2sqr(norms.data(), x, d, n); |
283 | 0 | train_norm(n, norms.data()); |
284 | |
285 | 0 | if (!(train_type & Skip_codebook_tables)) { |
286 | 0 | compute_codebook_tables(); |
287 | 0 | } |
288 | 0 | } |
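An end-to-end training sketch under stated assumptions (xt is a contiguous n x d training matrix and load_training_set is a hypothetical helper), exercising the knobs used by the loop above:

    faiss::ResidualQuantizer rq(64, 6, 8);
    rq.verbose = true;       // print the per-step progress logged in train()
    rq.max_beam_size = 16;   // cap on the beam width grown at each step
    rq.train_type |= faiss::ResidualQuantizer::Train_refine_codebook;
    std::vector<float> xt = load_training_set();  // hypothetical helper
    rq.train(xt.size() / rq.d, xt.data());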
289 | | |
290 | 0 | float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) { |
291 | 0 | FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points"); |
292 | | |
293 | 0 | if (verbose) { |
294 | 0 | printf(" encoding %zd training vectors\n", n); |
295 | 0 | } |
296 | 0 | std::vector<uint8_t> codes(n * code_size); |
297 | 0 | compute_codes(x, codes.data(), n); |
298 | | |
299 | | // compute reconstruction error |
300 | 0 | float input_recons_error; |
301 | 0 | { |
302 | 0 | std::vector<float> x_recons(n * d); |
303 | 0 | decode(codes.data(), x_recons.data(), n); |
304 | 0 | input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d); |
305 | 0 | if (verbose) { |
306 | 0 | printf(" input quantization error %g\n", input_recons_error); |
307 | 0 | } |
308 | 0 | } |
309 | | |
310 | | // build matrix of the linear system |
311 | 0 | std::vector<float> C(n * total_codebook_size); |
312 | 0 | for (size_t i = 0; i < n; i++) { |
313 | 0 | BitstringReader bsr(codes.data() + i * code_size, code_size); |
314 | 0 | for (int m = 0; m < M; m++) { |
315 | 0 | int idx = bsr.read(nbits[m]); |
316 | 0 | C[i + (codebook_offsets[m] + idx) * n] = 1; |
317 | 0 | } |
318 | 0 | } |
319 | | |
320 | | // transpose training vectors |
321 | 0 | std::vector<float> xt(n * d); |
322 | |
323 | 0 | for (size_t i = 0; i < n; i++) { |
324 | 0 | for (size_t j = 0; j < d; j++) { |
325 | 0 | xt[j * n + i] = x[i * d + j]; |
326 | 0 | } |
327 | 0 | } |
328 | |
329 | 0 | { // solve least squares |
330 | 0 | FINTEGER lwork = -1; |
331 | 0 | FINTEGER di = d, ni = n, tcsi = total_codebook_size; |
332 | 0 | FINTEGER info = -1, rank = -1; |
333 | |
334 | 0 | float rcond = 1e-4; // this is an important parameter because the code |
335 | | // matrix can be rank deficient for small problems, |
336 | | // the default rcond=-1 does not work |
337 | 0 | float worksize; |
338 | 0 | std::vector<float> sing_vals(total_codebook_size); |
339 | 0 | FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an |
340 | | // upper bound |
341 | 0 | std::vector<FINTEGER> iwork( |
342 | 0 | 3 * total_codebook_size * nlvl + 11 * total_codebook_size); |
343 | | |
344 | | // worksize query |
345 | 0 | sgelsd_(&ni, |
346 | 0 | &tcsi, |
347 | 0 | &di, |
348 | 0 | C.data(), |
349 | 0 | &ni, |
350 | 0 | xt.data(), |
351 | 0 | &ni, |
352 | 0 | sing_vals.data(), |
353 | 0 | &rcond, |
354 | 0 | &rank, |
355 | 0 | &worksize, |
356 | 0 | &lwork, |
357 | 0 | iwork.data(), |
358 | 0 | &info); |
359 | 0 | FAISS_THROW_IF_NOT(info == 0); |
360 | | |
361 | 0 | lwork = worksize; |
362 | 0 | std::vector<float> work(lwork); |
363 | | // actual call |
364 | 0 | sgelsd_(&ni, |
365 | 0 | &tcsi, |
366 | 0 | &di, |
367 | 0 | C.data(), |
368 | 0 | &ni, |
369 | 0 | xt.data(), |
370 | 0 | &ni, |
371 | 0 | sing_vals.data(), |
372 | 0 | &rcond, |
373 | 0 | &rank, |
374 | 0 | work.data(), |
375 | 0 | &lwork, |
376 | 0 | iwork.data(), |
377 | 0 | &info); |
378 | 0 | FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info)); |
379 | 0 | if (verbose) { |
380 | 0 | printf(" sgelsd rank=%d/%d\n", |
381 | 0 | int(rank), |
382 | 0 | int(total_codebook_size)); |
383 | 0 | } |
384 | 0 | } |
385 | | |
386 | | // result is in xt, re-transpose to codebook |
387 | | |
388 | 0 | for (size_t i = 0; i < total_codebook_size; i++) { |
389 | 0 | for (size_t j = 0; j < d; j++) { |
390 | 0 | codebooks[i * d + j] = xt[j * n + i]; |
391 | 0 | FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j])); |
392 | 0 | } |
393 | 0 | } |
394 | | |
395 | 0 | float output_recons_error = 0; |
396 | 0 | for (size_t j = 0; j < d; j++) { |
397 | 0 | output_recons_error += fvec_norm_L2sqr( |
398 | 0 | xt.data() + total_codebook_size + n * j, |
399 | 0 | n - total_codebook_size); |
400 | 0 | } |
401 | 0 | if (verbose) { |
402 | 0 | printf(" output quantization error %g\n", output_recons_error); |
403 | 0 | } |
404 | 0 | return output_recons_error; |
405 | 0 | } |
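To make the linear system above concrete: for fixed codes, retrain_AQ_codebook solves min over B of ||C*B - X||^2 jointly over all codebook entries, where row i of the n x total_codebook_size matrix C holds M ones marking the centroids selected for x_i, X stacks the training vectors, and sgelsd_ solves with each of the d coordinates as one right-hand side (rcond guards against rank deficiency). A plain-loop evaluation of that objective, for understanding only (row-major layout, illustrative names):

    // ||C*B - X||^2 with C (n x K), B (K x d), X (n x d); K = total_codebook_size.
    float aq_reconstruction_error(
            size_t n, size_t d, size_t K,
            const float* C, const float* B, const float* X) {
        float err = 0.0f;
        for (size_t i = 0; i < n; i++) {
            for (size_t j = 0; j < d; j++) {
                float recons = 0.0f;
                for (size_t k = 0; k < K; k++) {
                    recons += C[i * K + k] * B[k * d + j];
                }
                float diff = recons - X[i * d + j];
                err += diff * diff;
            }
        }
        return err;
    }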
406 | | |
407 | 0 | size_t ResidualQuantizer::memory_per_point(int beam_size) const { |
408 | 0 | if (beam_size < 0) { |
409 | 0 | beam_size = max_beam_size; |
410 | 0 | } |
411 | 0 | size_t mem; |
412 | 0 | mem = beam_size * d * 2 * sizeof(float); // size for 2 beams at a time |
413 | 0 | mem += beam_size * beam_size * |
414 | 0 | (sizeof(float) + sizeof(idx_t)); // size for 1 beam search result |
415 | 0 | return mem; |
416 | 0 | } |
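A worked instance of this estimate (illustrative values, not defaults read from the header): with beam_size = 5, d = 128, 4-byte floats and Faiss's 8-byte idx_t,

    constexpr size_t beam = 5, dim = 128;
    constexpr size_t bytes_per_point =
            beam * dim * 2 * sizeof(float)   // two residual beams: 5120 bytes
            + beam * beam * (4 + 8);         // beam-search distances + ids: 300 bytes
    static_assert(bytes_per_point == 5420, "about 5.4 KB per point");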
417 | | |
418 | | /**************************************************************** |
419 | | * Encoding |
420 | | ****************************************************************/ |
421 | | |
422 | | using namespace rq_encode_steps; |
423 | | |
424 | | void ResidualQuantizer::compute_codes_add_centroids( |
425 | | const float* x, |
426 | | uint8_t* codes_out, |
427 | | size_t n, |
428 | 0 | const float* centroids) const { |
429 | 0 | FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet."); |
430 | | |
431 | | // |
432 | 0 | size_t mem = memory_per_point(); |
433 | |
434 | 0 | size_t bs = max_mem_distances / mem; |
435 | 0 | if (bs == 0) { |
436 | 0 | bs = 1; // otherwise we can't do much |
437 | 0 | } |
438 | | |
439 | | // prepare memory pools |
440 | 0 | ComputeCodesAddCentroidsLUT0MemoryPool pool0; |
441 | 0 | ComputeCodesAddCentroidsLUT1MemoryPool pool1; |
442 | |
443 | 0 | for (size_t i0 = 0; i0 < n; i0 += bs) { |
444 | 0 | size_t i1 = std::min(n, i0 + bs); |
445 | 0 | const float* cent = nullptr; |
446 | 0 | if (centroids != nullptr) { |
447 | 0 | cent = centroids + i0 * d; |
448 | 0 | } |
449 | |
450 | 0 | if (use_beam_LUT == 0) { |
451 | 0 | compute_codes_add_centroids_mp_lut0( |
452 | 0 | *this, |
453 | 0 | x + i0 * d, |
454 | 0 | codes_out + i0 * code_size, |
455 | 0 | i1 - i0, |
456 | 0 | cent, |
457 | 0 | pool0); |
458 | 0 | } else if (use_beam_LUT == 1) { |
459 | 0 | compute_codes_add_centroids_mp_lut1( |
460 | 0 | *this, |
461 | 0 | x + i0 * d, |
462 | 0 | codes_out + i0 * code_size, |
463 | 0 | i1 - i0, |
464 | 0 | cent, |
465 | 0 | pool1); |
466 | 0 | } |
467 | 0 | } |
468 | 0 | } |
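For context, a minimal encode/decode round trip through this entry point (rq is assumed trained; nb and xb are assumptions of the sketch). compute_codes() in the AdditiveQuantizer base forwards here with a null centroids pointer.

    size_t nb = 1000;
    std::vector<float> xb(nb * rq.d);              // vectors to encode, filled elsewhere
    std::vector<uint8_t> codes(nb * rq.code_size); // code_size packed bytes per vector
    rq.compute_codes(xb.data(), codes.data(), nb);
    std::vector<float> decoded(nb * rq.d);
    rq.decode(codes.data(), decoded.data(), nb);   // approximate reconstruction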
469 | | |
470 | | void ResidualQuantizer::refine_beam( |
471 | | size_t n, |
472 | | size_t beam_size, |
473 | | const float* x, |
474 | | int out_beam_size, |
475 | | int32_t* out_codes, |
476 | | float* out_residuals, |
477 | 0 | float* out_distances) const { |
478 | 0 | RefineBeamMemoryPool pool; |
479 | 0 | refine_beam_mp( |
480 | 0 | *this, |
481 | 0 | n, |
482 | 0 | beam_size, |
483 | 0 | x, |
484 | 0 | out_beam_size, |
485 | 0 | out_codes, |
486 | 0 | out_residuals, |
487 | 0 | out_distances, |
488 | 0 | pool); |
489 | 0 | } |
490 | | |
491 | | /******************************************************************* |
492 | | * Functions using the dot products between codebook entries |
493 | | *******************************************************************/ |
494 | | |
495 | | void ResidualQuantizer::refine_beam_LUT( |
496 | | size_t n, |
497 | | const float* query_norms, // size n |
498 | | const float* query_cp, // |
499 | | int out_beam_size, |
500 | | int32_t* out_codes, |
501 | 0 | float* out_distances) const { |
502 | 0 | RefineBeamLUTMemoryPool pool; |
503 | 0 | refine_beam_LUT_mp( |
504 | 0 | *this, |
505 | 0 | n, |
506 | 0 | query_norms, |
507 | 0 | query_cp, |
508 | 0 | out_beam_size, |
509 | 0 | out_codes, |
510 | 0 | out_distances, |
511 | 0 | pool); |
512 | 0 | } |
513 | | |
514 | | } // namespace faiss |