/root/doris/contrib/faiss/faiss/impl/LocalSearchQuantizer.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/LocalSearchQuantizer.h> |
9 | | |
10 | | #include <cstddef> |
11 | | #include <cstdio> |
12 | | #include <cstring> |
13 | | #include <memory> |
14 | | #include <random> |
15 | | |
16 | | #include <algorithm> |
17 | | |
18 | | #include <faiss/impl/AuxIndexStructures.h> |
19 | | #include <faiss/impl/FaissAssert.h> |
20 | | #include <faiss/utils/distances.h> |
21 | | #include <faiss/utils/hamming.h> // BitstringWriter |
22 | | #include <faiss/utils/utils.h> |
23 | | |
24 | | #include <faiss/utils/approx_topk/approx_topk.h> |
25 | | |
26 | | // this is needed for prefetching |
27 | | #include <faiss/impl/platform_macros.h> |
28 | | |
29 | | #ifdef __AVX2__ |
30 | | #include <xmmintrin.h> |
31 | | #endif |
32 | | |
33 | | extern "C" { |
34 | | // LU decomposition of a general matrix |
35 | | void sgetrf_( |
36 | | FINTEGER* m, |
37 | | FINTEGER* n, |
38 | | float* a, |
39 | | FINTEGER* lda, |
40 | | FINTEGER* ipiv, |
41 | | FINTEGER* info); |
42 | | |
43 | | // generate inverse of a matrix given its LU decomposition |
44 | | void sgetri_( |
45 | | FINTEGER* n, |
46 | | float* a, |
47 | | FINTEGER* lda, |
48 | | FINTEGER* ipiv, |
49 | | float* work, |
50 | | FINTEGER* lwork, |
51 | | FINTEGER* info); |
52 | | |
53 | | // general matrix multiplication |
54 | | int sgemm_( |
55 | | const char* transa, |
56 | | const char* transb, |
57 | | FINTEGER* m, |
58 | | FINTEGER* n, |
59 | | FINTEGER* k, |
60 | | const float* alpha, |
61 | | const float* a, |
62 | | FINTEGER* lda, |
63 | | const float* b, |
64 | | FINTEGER* ldb, |
65 | | float* beta, |
66 | | float* c, |
67 | | FINTEGER* ldc); |
68 | | |
69 | | // LU decomposition of a general matrix |
70 | | void dgetrf_( |
71 | | FINTEGER* m, |
72 | | FINTEGER* n, |
73 | | double* a, |
74 | | FINTEGER* lda, |
75 | | FINTEGER* ipiv, |
76 | | FINTEGER* info); |
77 | | |
78 | | // generate inverse of a matrix given its LU decomposition |
79 | | void dgetri_( |
80 | | FINTEGER* n, |
81 | | double* a, |
82 | | FINTEGER* lda, |
83 | | FINTEGER* ipiv, |
84 | | double* work, |
85 | | FINTEGER* lwork, |
86 | | FINTEGER* info); |
87 | | |
88 | | // general matrix multiplication |
89 | | int dgemm_( |
90 | | const char* transa, |
91 | | const char* transb, |
92 | | FINTEGER* m, |
93 | | FINTEGER* n, |
94 | | FINTEGER* k, |
95 | | const double* alpha, |
96 | | const double* a, |
97 | | FINTEGER* lda, |
98 | | const double* b, |
99 | | FINTEGER* ldb, |
100 | | double* beta, |
101 | | double* c, |
102 | | FINTEGER* ldc); |
103 | | } |
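These are Fortran (BLAS/LAPACK) entry points: every scalar is passed by pointer and every matrix is read in column-major order. A minimal sketch of calling sgemm_ directly (a hypothetical demo function, not part of this file):

    void sgemm_identity_demo() {
        FINTEGER m = 2, n = 2, k = 2;
        float A[] = {1, 3, 2, 4}; // column-major [[1, 2], [3, 4]]
        float B[] = {1, 0, 0, 1}; // 2x2 identity
        float C[4];
        float alpha = 1.0f, beta = 0.0f;
        sgemm_("N", "N", &m, &n, &k, &alpha, A, &m, B, &k, &beta, C, &m);
        // C now equals A: multiplying by the identity is a no-op.
    }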
104 | | |
105 | | namespace { |
106 | | |
107 | 0 | void fmat_inverse(float* a, FINTEGER n) { |
108 | 0 | FINTEGER info; |
109 | 0 | FINTEGER lwork = n * n; |
110 | 0 | std::vector<FINTEGER> ipiv(n); |
111 | 0 | std::vector<float> workspace(lwork); |
112 | |
113 | 0 | sgetrf_(&n, &n, a, &n, ipiv.data(), &info); |
114 | 0 | FAISS_THROW_IF_NOT(info == 0); |
115 | 0 | sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info); |
116 | 0 | FAISS_THROW_IF_NOT(info == 0); |
117 | 0 | } |
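A quick sanity check of this helper on a 2x2 matrix (hypothetical test code). In-place inversion also works for the row-major callers below because inv(A^T) == inv(A)^T, so the layout question cancels out:

    void fmat_inverse_demo() {
        float a[] = {4, 7, 2, 6}; // column-major [[4, 2], [7, 6]], det = 10
        fmat_inverse(a, 2);       // a becomes {0.6, -0.7, -0.2, 0.4}
    }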
118 | | |
119 | | // c can overlap with a and b |
120 | 0 | void dfvec_add(size_t d, const double* a, const float* b, double* c) { |
121 | 0 | for (size_t i = 0; i < d; i++) { |
122 | 0 | c[i] = a[i] + b[i]; |
123 | 0 | } |
124 | 0 | } |
125 | | |
126 | 0 | void dmat_inverse(double* a, FINTEGER n) { |
127 | 0 | FINTEGER info; |
128 | 0 | FINTEGER lwork = n * n; |
129 | 0 | std::vector<FINTEGER> ipiv(n); |
130 | 0 | std::vector<double> workspace(lwork); |
131 | |
132 | 0 | dgetrf_(&n, &n, a, &n, ipiv.data(), &info); |
133 | 0 | FAISS_THROW_IF_NOT(info == 0); |
134 | 0 | dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info); |
135 | 0 | FAISS_THROW_IF_NOT(info == 0); |
136 | 0 | } |
137 | | |
138 | | void random_int32( |
139 | | std::vector<int32_t>& x, |
140 | | int32_t min, |
141 | | int32_t max, |
142 | 0 | std::mt19937& gen) { |
143 | 0 | std::uniform_int_distribution<int32_t> distrib(min, max); |
144 | 0 | for (size_t i = 0; i < x.size(); i++) { |
145 | 0 | x[i] = distrib(gen); |
146 | 0 | } |
147 | 0 | } |
148 | | |
149 | | } // anonymous namespace |
150 | | |
151 | | namespace faiss { |
152 | | |
153 | | lsq::LSQTimer lsq_timer; |
154 | | using lsq::LSQTimerScope; |
155 | | |
156 | | LocalSearchQuantizer::LocalSearchQuantizer( |
157 | | size_t d, |
158 | | size_t M, |
159 | | size_t nbits, |
160 | | Search_type_t search_type) |
161 | 0 | : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) { |
162 | 0 | K = (1 << nbits); |
163 | 0 | std::srand(random_seed); |
164 | 0 | } |
165 | | |
166 | 0 | LocalSearchQuantizer::~LocalSearchQuantizer() { |
167 | 0 | delete icm_encoder_factory; |
168 | 0 | } |
169 | | |
170 | 0 | LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {} |
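Typical end-to-end usage, for illustration only (random training data; the values of d, M, nbits and n are arbitrary):

    #include <faiss/impl/LocalSearchQuantizer.h>
    #include <cstdint>
    #include <random>
    #include <vector>

    void lsq_usage_sketch() {
        size_t d = 32, M = 4, nbits = 8, n = 2000;
        std::vector<float> x(n * d);
        std::mt19937 rng(123);
        std::normal_distribution<float> dist;
        for (auto& v : x) {
            v = dist(rng);
        }

        faiss::LocalSearchQuantizer lsq(d, M, nbits);
        lsq.train(n, x.data()); // learns M codebooks of K = 2^nbits entries

        std::vector<uint8_t> codes(n * lsq.code_size);
        lsq.compute_codes(x.data(), codes.data(), n); // runs ICM encoding

        std::vector<float> recons(n * d);
        lsq.decode(codes.data(), recons.data(), n); // approximate reconstruction
    }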
171 | | |
172 | 0 | void LocalSearchQuantizer::train(size_t n, const float* x) { |
173 | 0 | FAISS_THROW_IF_NOT(K == (1 << nbits[0])); |
174 | 0 | nperts = std::min(nperts, M); |
175 | |
176 | 0 | lsq_timer.reset(); |
177 | 0 | LSQTimerScope scope(&lsq_timer, "train"); |
178 | 0 | if (verbose) { |
179 | 0 | printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n", |
180 | 0 | M, |
181 | 0 | n, |
182 | 0 | d); |
183 | 0 | } |
184 | | |
185 | | // allocate memory for codebooks, size [M, K, d] |
186 | 0 | codebooks.resize(M * K * d); |
187 | | |
188 | | // randomly initialize codes |
189 | 0 | std::mt19937 gen(random_seed); |
190 | 0 | std::vector<int32_t> codes(n * M); // [n, M] |
191 | 0 | random_int32(codes, 0, K - 1, gen); |
192 | | |
193 | | // compute the standard deviation of each dimension |
194 | 0 | std::vector<float> stddev(d, 0); |
195 | |
196 | 0 | #pragma omp parallel for |
197 | 0 | for (int64_t i = 0; i < d; i++) { |
198 | 0 | float mean = 0; |
199 | 0 | for (size_t j = 0; j < n; j++) { |
200 | 0 | mean += x[j * d + i]; |
201 | 0 | } |
202 | 0 | mean = mean / n; |
203 | |
204 | 0 | float sum = 0; |
205 | 0 | for (size_t j = 0; j < n; j++) { |
206 | 0 | float xi = x[j * d + i] - mean; |
207 | 0 | sum += xi * xi; |
208 | 0 | } |
209 | 0 | stddev[i] = sqrtf(sum / n); |
210 | 0 | } |
211 | |
212 | 0 | if (verbose) { |
213 | 0 | float obj = evaluate(codes.data(), x, n); |
214 | 0 | printf("Before training: obj = %lf\n", obj); |
215 | 0 | } |
216 | |
217 | 0 | for (size_t i = 0; i < train_iters; i++) { |
218 | | // 1. update codebooks given x and codes |
219 | | // 2. add perturbation to codebooks (SR-D) |
220 | | // 3. refine codes given x and codebooks using icm |
221 | | |
222 | | // update codebooks |
223 | 0 | update_codebooks(x, codes.data(), n); |
224 | |
225 | 0 | if (verbose) { |
226 | 0 | float obj = evaluate(codes.data(), x, n); |
227 | 0 | printf("iter %zd:\n", i); |
228 | 0 | printf("\tafter updating codebooks: obj = %lf\n", obj); |
229 | 0 | } |
230 | | |
231 | | // SR-D: perturb codebooks |
232 | 0 | float T = pow((1.0f - (i + 1.0f) / train_iters), p); |
233 | 0 | perturb_codebooks(T, stddev, gen); |
234 | |
235 | 0 | if (verbose) { |
236 | 0 | float obj = evaluate(codes.data(), x, n); |
237 | 0 | printf("\tafter perturbing codebooks: obj = %lf\n", obj); |
238 | 0 | } |
239 | | |
240 | | // refine codes |
241 | 0 | icm_encode(codes.data(), x, n, train_ils_iters, gen); |
242 | |
243 | 0 | if (verbose) { |
244 | 0 | float obj = evaluate(codes.data(), x, n); |
245 | 0 | printf("\tafter updating codes: obj = %lf\n", obj); |
246 | 0 | } |
247 | 0 | } |
248 | |
249 | 0 | is_trained = true; |
250 | 0 | { |
251 | 0 | std::vector<float> x_recons(n * d); |
252 | 0 | std::vector<float> norms(n); |
253 | 0 | decode_unpacked(codes.data(), x_recons.data(), n); |
254 | 0 | fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n); |
255 | |
256 | 0 | train_norm(n, norms.data()); |
257 | 0 | } |
258 | |
259 | 0 | if (verbose) { |
260 | 0 | float obj = evaluate(codes.data(), x, n); |
261 | 0 | scope.finish(); |
262 | 0 | printf("After training: obj = %lf\n", obj); |
263 | |
264 | 0 | printf("Time statistic:\n"); |
265 | 0 | for (const auto& it : lsq_timer.t) { |
266 | 0 | printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000); |
267 | 0 | } |
268 | 0 | } |
269 | 0 | } |
270 | | |
271 | | void LocalSearchQuantizer::perturb_codebooks( |
272 | | float T, |
273 | | const std::vector<float>& stddev, |
274 | 0 | std::mt19937& gen) { |
275 | 0 | LSQTimerScope scope(&lsq_timer, "perturb_codebooks"); |
276 | |
277 | 0 | std::vector<std::normal_distribution<float>> distribs; |
278 | 0 | for (size_t i = 0; i < d; i++) { |
279 | 0 | distribs.emplace_back(0.0f, stddev[i]); |
280 | 0 | } |
281 | |
282 | 0 | for (size_t m = 0; m < M; m++) { |
283 | 0 | for (size_t k = 0; k < K; k++) { |
284 | 0 | for (size_t i = 0; i < d; i++) { |
285 | 0 | codebooks[m * K * d + k * d + i] += T * distribs[i](gen) / M; |
286 | 0 | } |
287 | 0 | } |
288 | 0 | } |
289 | 0 | } |
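The temperature T handed in from train() follows the SR-D annealing schedule T = (1 - (i + 1) / train_iters)^p, so the injected noise shrinks to exactly zero on the last iteration. A worked example of the decay (hypothetical values p = 0.5, train_iters = 4):

    float p = 0.5f;
    size_t train_iters = 4;
    for (size_t i = 0; i < train_iters; i++) {
        float T = powf(1.0f - (i + 1.0f) / train_iters, p);
        // i = 0, 1, 2, 3  ->  T = 0.866, 0.707, 0.500, 0.000
    }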
290 | | |
291 | | void LocalSearchQuantizer::compute_codes_add_centroids( |
292 | | const float* x, |
293 | | uint8_t* codes_out, |
294 | | size_t n, |
295 | 0 | const float* centroids) const { |
296 | 0 | FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet."); |
297 | | |
298 | 0 | lsq_timer.reset(); |
299 | 0 | LSQTimerScope scope(&lsq_timer, "encode"); |
300 | 0 | if (verbose) { |
301 | 0 | printf("Encoding %zd vectors...\n", n); |
302 | 0 | } |
303 | |
304 | 0 | std::vector<int32_t> codes(n * M); |
305 | 0 | std::mt19937 gen(random_seed); |
306 | 0 | random_int32(codes, 0, K - 1, gen); |
307 | |
308 | 0 | icm_encode(codes.data(), x, n, encode_ils_iters, gen); |
309 | 0 | pack_codes(n, codes.data(), codes_out, -1, nullptr, centroids); |
310 | |
311 | 0 | if (verbose) { |
312 | 0 | scope.finish(); |
313 | 0 | printf("Time statistic:\n"); |
314 | 0 | for (const auto& it : lsq_timer.t) { |
315 | 0 | printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000); |
316 | 0 | } |
317 | 0 | } |
318 | 0 | } |
319 | | |
320 | | /** update codebooks given x and codes |
321 | | * |
322 | | * Let B denote the sparse matrix of codes, size [n, M * K]. |
323 | | * Let C denote the codebooks, size [M * K, d]. |
324 | | * Let X denote the training vectors, size [n, d]. |
325 | | * |
326 | | * objective function: |
327 | | * L = (X - BC)^2 |
328 | | * |
329 | | * To minimize L, we have: |
330 | | * C = (B'B)^(-1)B'X |
331 | | * where ' denotes the transpose |
332 | | * |
333 | | * Add a regularization term to make B'B invertible: |
334 | | * C = (B'B + lambd * I)^(-1)B'X |
335 | | */ |
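One way to read the lambd term is as a ridge penalty on C; the closed form then drops out of the normal equations (a sketch in LaTeX, writing lambda for lambd):

    L(C) = \|X - BC\|_F^2 + \lambda \|C\|_F^2
    \nabla_C L = -2 B^\top (X - BC) + 2 \lambda C = 0
    \implies (B^\top B + \lambda I)\, C = B^\top X
    \implies C = (B^\top B + \lambda I)^{-1} B^\top X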
336 | | void LocalSearchQuantizer::update_codebooks( |
337 | | const float* x, |
338 | | const int32_t* codes, |
339 | 0 | size_t n) { |
340 | 0 | LSQTimerScope scope(&lsq_timer, "update_codebooks"); |
341 | |
342 | 0 | if (!update_codebooks_with_double) { |
343 | | // allocate memory |
344 | | // bb = B'B, bx = BX |
345 | 0 | std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K] |
346 | 0 | std::vector<float> bx(M * K * d, 0.0f); // [M * K, d] |
347 | | |
348 | | // compute B'B |
349 | 0 | for (size_t i = 0; i < n; i++) { |
350 | 0 | for (size_t m = 0; m < M; m++) { |
351 | 0 | int32_t code1 = codes[i * M + m]; |
352 | 0 | int32_t idx1 = m * K + code1; |
353 | 0 | bb[idx1 * M * K + idx1] += 1; |
354 | |
355 | 0 | for (size_t m2 = m + 1; m2 < M; m2++) { |
356 | 0 | int32_t code2 = codes[i * M + m2]; |
357 | 0 | int32_t idx2 = m2 * K + code2; |
358 | 0 | bb[idx1 * M * K + idx2] += 1; |
359 | 0 | bb[idx2 * M * K + idx1] += 1; |
360 | 0 | } |
361 | 0 | } |
362 | 0 | } |
363 | | |
364 | | // add a regularization term to B'B |
365 | 0 | for (int64_t i = 0; i < M * K; i++) { |
366 | 0 | bb[i * (M * K) + i] += lambd; |
367 | 0 | } |
368 | | |
369 | | // compute (B'B)^(-1) |
370 | 0 | fmat_inverse(bb.data(), M * K); // [M*K, M*K] |
371 | | |
372 | | // compute BX |
373 | 0 | for (size_t i = 0; i < n; i++) { |
374 | 0 | for (size_t m = 0; m < M; m++) { |
375 | 0 | int32_t code = codes[i * M + m]; |
376 | 0 | float* data = bx.data() + (m * K + code) * d; |
377 | 0 | fvec_add(d, data, x + i * d, data); |
378 | 0 | } |
379 | 0 | } |
380 | | |
381 | | // compute C = (B'B)^(-1) @ BX |
382 | | // |
383 | | // NOTE: LAPACK uses column-major order |
384 | | // out = alpha * op(A) * op(B) + beta * C |
385 | 0 | FINTEGER nrows_A = d; |
386 | 0 | FINTEGER ncols_A = M * K; |
387 | |
388 | 0 | FINTEGER nrows_B = M * K; |
389 | 0 | FINTEGER ncols_B = M * K; |
390 | |
391 | 0 | float alpha = 1.0f; |
392 | 0 | float beta = 0.0f; |
393 | 0 | sgemm_("Not Transposed", |
394 | 0 | "Not Transposed", |
395 | 0 | &nrows_A, // nrows of op(A) |
396 | 0 | &ncols_B, // ncols of op(B) |
397 | 0 | &ncols_A, // ncols of op(A) |
398 | 0 | &alpha, |
399 | 0 | bx.data(), |
400 | 0 | &nrows_A, // nrows of A |
401 | 0 | bb.data(), |
402 | 0 | &nrows_B, // nrows of B |
403 | 0 | &beta, |
404 | 0 | codebooks.data(), |
405 | 0 | &nrows_A); // nrows of output |
406 | |
407 | 0 | } else { |
408 | | // allocate memory |
409 | | // bb = B'B, bx = BX |
410 | 0 | std::vector<double> bb(M * K * M * K, 0.0); // [M * K, M * K] |
411 | 0 | std::vector<double> bx(M * K * d, 0.0); // [M * K, d] |
412 | | |
413 | | // compute B'B |
414 | 0 | for (size_t i = 0; i < n; i++) { |
415 | 0 | for (size_t m = 0; m < M; m++) { |
416 | 0 | int32_t code1 = codes[i * M + m]; |
417 | 0 | int32_t idx1 = m * K + code1; |
418 | 0 | bb[idx1 * M * K + idx1] += 1; |
419 | |
420 | 0 | for (size_t m2 = m + 1; m2 < M; m2++) { |
421 | 0 | int32_t code2 = codes[i * M + m2]; |
422 | 0 | int32_t idx2 = m2 * K + code2; |
423 | 0 | bb[idx1 * M * K + idx2] += 1; |
424 | 0 | bb[idx2 * M * K + idx1] += 1; |
425 | 0 | } |
426 | 0 | } |
427 | 0 | } |
428 | | |
429 | | // add a regularization term to B'B |
430 | 0 | for (int64_t i = 0; i < M * K; i++) { |
431 | 0 | bb[i * (M * K) + i] += lambd; |
432 | 0 | } |
433 | | |
434 | | // compute (B'B)^(-1) |
435 | 0 | dmat_inverse(bb.data(), M * K); // [M*K, M*K] |
436 | | |
437 | | // compute BX |
438 | 0 | for (size_t i = 0; i < n; i++) { |
439 | 0 | for (size_t m = 0; m < M; m++) { |
440 | 0 | int32_t code = codes[i * M + m]; |
441 | 0 | double* data = bx.data() + (m * K + code) * d; |
442 | 0 | dfvec_add(d, data, x + i * d, data); |
443 | 0 | } |
444 | 0 | } |
445 | | |
446 | | // compute C = (B'B)^(-1) @ BX |
447 | | // |
448 | | // NOTE: LAPACK uses column-major order |
449 | | // out = alpha * op(A) * op(B) + beta * C |
450 | 0 | FINTEGER nrows_A = d; |
451 | 0 | FINTEGER ncols_A = M * K; |
452 | |
453 | 0 | FINTEGER nrows_B = M * K; |
454 | 0 | FINTEGER ncols_B = M * K; |
455 | |
456 | 0 | std::vector<double> d_codebooks(M * K * d); |
457 | |
458 | 0 | double alpha = 1.0; |
459 | 0 | double beta = 0.0; |
460 | 0 | dgemm_("Not Transposed", |
461 | 0 | "Not Transposed", |
462 | 0 | &nrows_A, // nrows of op(A) |
463 | 0 | &ncols_B, // ncols of op(B) |
464 | 0 | &ncols_A, // ncols of op(A) |
465 | 0 | &alpha, |
466 | 0 | bx.data(), |
467 | 0 | &nrows_A, // nrows of A |
468 | 0 | bb.data(), |
469 | 0 | &nrows_B, // nrows of B |
470 | 0 | &beta, |
471 | 0 | d_codebooks.data(), |
472 | 0 | &nrows_A); // nrows of output |
473 | |
474 | 0 | for (size_t i = 0; i < M * K * d; i++) { |
475 | 0 | codebooks[i] = (float)d_codebooks[i]; |
476 | 0 | } |
477 | 0 | } |
478 | 0 | } |
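Both gemm calls above rely on the fact that a row-major [r, c] buffer, reread in column-major order, is the transposed matrix, so no explicit transposition is needed. Spelled out for the float branch (names match the code; bb is symmetric because the inverse of the symmetric B'B + lambd * I is itself symmetric):

    // Column-major views of the row-major buffers:
    //   bx ([M*K, d] row-major)    -> bx^T, shape d x (M*K)
    //   bb ([M*K, M*K], symmetric) -> bb^T == bb
    // So sgemm_ computes, in column-major terms,
    //   out = bx^T * bb, shape d x (M*K),
    // whose row-major reading is (bx^T * bb)^T = bb * bx
    //   = (B'B + lambd*I)^(-1) B'X,
    // landing directly in the [M, K, d] codebooks layout.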
479 | | |
480 | | /** encode using iterative conditional mode |
481 | | * |
482 | | * iterative conditional mode: |
483 | | * For every subcode ci (i = 1, ..., M) of a vector, we fix the other |
484 | | * subcodes cj (j != i) and then find the value of ci that |
485 | | * minimizes the objective function. |
486 | | * |
487 | | * objective function: |
488 | | * L = (X - \sum cj)^2, j = 1, ..., M |
489 | | * L = X^2 - 2X * \sum cj + (\sum cj)^2 |
490 | | * |
491 | | * X^2 is negligible since it is the same for all possible values |
492 | | * k of the m-th subcode. |
493 | | * |
494 | | * 2X * \sum cj is the unary term |
495 | | * (\sum cj)^2 is the binary term |
496 | | * These two terms can be precomputed and stored in a lookup table. |
497 | | */ |
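Expanded, with cj the codeword currently selected in codebook j, the decomposition filled into the tables is (||x||^2 is constant with respect to the codes):

    \|x - \sum_{j=1}^{M} c_j\|^2
        = \|x\|^2
        + \sum_{j=1}^{M} \bigl( -2 \langle x, c_j \rangle + \|c_j\|^2 \bigr)  % unary
        + \sum_{j \ne j'} \langle c_j, c_{j'} \rangle                         % binary

which matches compute_unary_terms below (a gemm with alpha = -2 plus codeword norms) and compute_binary_terms (inner products scaled by 2, covering each unordered pair twice).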
498 | | void LocalSearchQuantizer::icm_encode( |
499 | | int32_t* codes, |
500 | | const float* x, |
501 | | size_t n, |
502 | | size_t ils_iters, |
503 | 0 | std::mt19937& gen) const { |
504 | 0 | LSQTimerScope scope(&lsq_timer, "icm_encode"); |
505 | |
506 | 0 | auto factory = icm_encoder_factory; |
507 | 0 | std::unique_ptr<lsq::IcmEncoder> icm_encoder; |
508 | 0 | if (factory == nullptr) { |
509 | 0 | icm_encoder.reset(lsq::IcmEncoderFactory().get(this)); |
510 | 0 | } else { |
511 | 0 | icm_encoder.reset(factory->get(this)); |
512 | 0 | } |
513 | | |
514 | | // precompute binary terms for all chunks |
515 | 0 | icm_encoder->set_binary_term(); |
516 | |
517 | 0 | const size_t n_chunks = (n + chunk_size - 1) / chunk_size; |
518 | 0 | for (size_t i = 0; i < n_chunks; i++) { |
519 | 0 | size_t ni = std::min(chunk_size, n - i * chunk_size); |
520 | |
521 | 0 | if (verbose) { |
522 | 0 | printf("\r\ticm encoding %zd/%zd ...", i * chunk_size + ni, n); |
523 | 0 | fflush(stdout); |
524 | 0 | if (i == n_chunks - 1 || i == 0) { |
525 | 0 | printf("\n"); |
526 | 0 | } |
527 | 0 | } |
528 | |
529 | 0 | const float* xi = x + i * chunk_size * d; |
530 | 0 | int32_t* codesi = codes + i * chunk_size * M; |
531 | 0 | icm_encoder->verbose = (verbose && i == 0); |
532 | 0 | icm_encoder->encode(codesi, xi, gen, ni, ils_iters); |
533 | 0 | } |
534 | 0 | } |
535 | | |
536 | | void LocalSearchQuantizer::icm_encode_impl( |
537 | | int32_t* codes, |
538 | | const float* x, |
539 | | const float* binaries, |
540 | | std::mt19937& gen, |
541 | | size_t n, |
542 | | size_t ils_iters, |
543 | 0 | bool verbose) const { |
544 | 0 | std::vector<float> unaries(n * M * K); // [M, n, K] |
545 | 0 | compute_unary_terms(x, unaries.data(), n); |
546 | |
547 | 0 | std::vector<int32_t> best_codes; |
548 | 0 | best_codes.assign(codes, codes + n * M); |
549 | |
550 | 0 | std::vector<float> best_objs(n, 0.0f); |
551 | 0 | evaluate(codes, x, n, best_objs.data()); |
552 | |
553 | 0 | FAISS_THROW_IF_NOT(nperts <= M); |
554 | 0 | for (size_t iter1 = 0; iter1 < ils_iters; iter1++) { |
555 | | // add perturbation to codes |
556 | 0 | perturb_codes(codes, n, gen); |
557 | |
558 | 0 | icm_encode_step(codes, unaries.data(), binaries, n, icm_iters); |
559 | |
560 | 0 | std::vector<float> icm_objs(n, 0.0f); |
561 | 0 | evaluate(codes, x, n, icm_objs.data()); |
562 | 0 | size_t n_betters = 0; |
563 | 0 | float mean_obj = 0.0f; |
564 | | |
565 | | // select the best code for every vector xi |
566 | 0 | #pragma omp parallel for reduction(+ : n_betters, mean_obj) |
567 | 0 | for (int64_t i = 0; i < n; i++) { |
568 | 0 | if (icm_objs[i] < best_objs[i]) { |
569 | 0 | best_objs[i] = icm_objs[i]; |
570 | 0 | memcpy(best_codes.data() + i * M, |
571 | 0 | codes + i * M, |
572 | 0 | sizeof(int32_t) * M); |
573 | 0 | n_betters += 1; |
574 | 0 | } |
575 | 0 | mean_obj += best_objs[i]; |
576 | 0 | } |
577 | 0 | mean_obj /= n; |
578 | |
579 | 0 | memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M); |
580 | |
581 | 0 | if (verbose) { |
582 | 0 | printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n", |
583 | 0 | iter1, |
584 | 0 | mean_obj, |
585 | 0 | n_betters, |
586 | 0 | n); |
587 | 0 | } |
588 | 0 | } // loop ils_iters |
589 | 0 | } |
590 | | |
591 | | void LocalSearchQuantizer::icm_encode_step( |
592 | | int32_t* codes, |
593 | | const float* unaries, |
594 | | const float* binaries, |
595 | | size_t n, |
596 | 0 | size_t n_iters) const { |
597 | 0 | FAISS_THROW_IF_NOT(M != 0 && K != 0); |
598 | 0 | FAISS_THROW_IF_NOT(binaries != nullptr); |
599 | | |
600 | 0 | #pragma omp parallel for schedule(dynamic) |
601 | 0 | for (int64_t i = 0; i < n; i++) { |
602 | 0 | std::vector<float> objs(K); |
603 | |
604 | 0 | for (size_t iter = 0; iter < n_iters; iter++) { |
605 | | // condition on the m-th subcode |
606 | 0 | for (size_t m = 0; m < M; m++) { |
607 | | // copy |
608 | 0 | auto u = unaries + m * n * K + i * K; |
609 | 0 | for (size_t code = 0; code < K; code++) { |
610 | 0 | objs[code] = u[code]; |
611 | 0 | } |
612 | | |
613 | | // compute objective function by adding unary |
614 | | // and binary terms together |
615 | 0 | for (size_t other_m = 0; other_m < M; other_m++) { |
616 | 0 | if (other_m == m) { |
617 | 0 | continue; |
618 | 0 | } |
619 | | |
620 | 0 | #ifdef __AVX2__ |
621 | | // TODO: add platform-independent, compiler-independent |
622 | | // prefetch utilities. |
623 | 0 | if (other_m + 1 < M) { |
624 | | // do a single prefetch |
625 | 0 | int32_t code2 = codes[i * M + other_m + 1]; |
626 | | // for (int32_t code = 0; code < K; code += 64) { |
627 | 0 | int32_t code = 0; |
628 | 0 | { |
629 | 0 | size_t binary_idx = (other_m + 1) * M * K * K + |
630 | 0 | m * K * K + code2 * K + code; |
631 | 0 | _mm_prefetch( |
632 | 0 | (const char*)(binaries + binary_idx), |
633 | 0 | _MM_HINT_T0); |
634 | 0 | } |
635 | 0 | } |
636 | 0 | #endif |
637 | |
638 | 0 | for (int32_t code = 0; code < K; code++) { |
639 | 0 | int32_t code2 = codes[i * M + other_m]; |
640 | 0 | size_t binary_idx = other_m * M * K * K + m * K * K + |
641 | 0 | code2 * K + code; |
642 | | // binaries[m, other_m, code, code2]. |
643 | | // It is symmetric over (m <-> other_m) |
644 | | // and (code <-> code2). |
645 | | // So, replace the op with |
646 | | // binaries[other_m, m, code2, code]. |
647 | 0 | objs[code] += binaries[binary_idx]; |
648 | 0 | } |
649 | 0 | } |
650 | | |
651 | | // find the optimal value of the m-th subcode |
652 | 0 | float best_obj = HUGE_VALF; |
653 | 0 | int32_t best_code = 0; |
654 | | |
655 | | // find it using SIMD. The following operation amounts to |
656 | | // a search for the smallest element in objs |
657 | 0 | using C = CMax<float, int>; |
658 | 0 | HeapWithBuckets<C, 16, 1>::addn( |
659 | 0 | K, objs.data(), 1, &best_obj, &best_code); |
660 | | |
661 | | // done |
662 | 0 | codes[i * M + m] = best_code; |
663 | |
664 | 0 | } // loop M |
665 | 0 | } |
666 | 0 | } |
667 | 0 | } |
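For reference, the HeapWithBuckets call above is the approximate top-k SIMD path specialized to k = 1; a plain scalar equivalent of that argmin step would be (a sketch, using the same variables as the loop above):

    float best_obj = HUGE_VALF;
    int32_t best_code = 0;
    for (int32_t code = 0; code < (int32_t)K; code++) {
        if (objs[code] < best_obj) { // keep the smallest objective
            best_obj = objs[code];
            best_code = code;
        }
    }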
668 | | void LocalSearchQuantizer::perturb_codes( |
669 | | int32_t* codes, |
670 | | size_t n, |
671 | 0 | std::mt19937& gen) const { |
672 | 0 | LSQTimerScope scope(&lsq_timer, "perturb_codes"); |
673 | |
674 | 0 | std::uniform_int_distribution<size_t> m_distrib(0, M - 1); |
675 | 0 | std::uniform_int_distribution<int32_t> k_distrib(0, K - 1); |
676 | |
677 | 0 | for (size_t i = 0; i < n; i++) { |
678 | 0 | for (size_t j = 0; j < nperts; j++) { |
679 | 0 | size_t m = m_distrib(gen); |
680 | 0 | codes[i * M + m] = k_distrib(gen); |
681 | 0 | } |
682 | 0 | } |
683 | 0 | } |
684 | | |
685 | 0 | void LocalSearchQuantizer::compute_binary_terms(float* binaries) const { |
686 | 0 | LSQTimerScope scope(&lsq_timer, "compute_binary_terms"); |
687 | |
688 | 0 | #pragma omp parallel for |
689 | 0 | for (int64_t m12 = 0; m12 < M * M; m12++) { |
690 | 0 | size_t m1 = m12 / M; |
691 | 0 | size_t m2 = m12 % M; |
692 | |
693 | 0 | for (size_t code1 = 0; code1 < K; code1++) { |
694 | 0 | for (size_t code2 = 0; code2 < K; code2++) { |
695 | 0 | const float* c1 = codebooks.data() + m1 * K * d + code1 * d; |
696 | 0 | const float* c2 = codebooks.data() + m2 * K * d + code2 * d; |
697 | 0 | float ip = fvec_inner_product(c1, c2, d); |
698 | | // binaries[m1, m2, code1, code2] = ip * 2 |
699 | 0 | binaries[m1 * M * K * K + m2 * K * K + code1 * K + code2] = |
700 | 0 | ip * 2; |
701 | 0 | } |
702 | 0 | } |
703 | 0 | } |
704 | 0 | } |
705 | | |
706 | | void LocalSearchQuantizer::compute_unary_terms( |
707 | | const float* x, |
708 | | float* unaries, // [M, n, K] |
709 | 0 | size_t n) const { |
710 | 0 | LSQTimerScope scope(&lsq_timer, "compute_unary_terms"); |
711 | | |
712 | | // compute x * codebook^T for each codebook |
713 | | // |
714 | | // NOTE: LAPACK uses column-major order |
715 | | // out = alpha * op(A) * op(B) + beta * C |
716 | |
717 | 0 | for (size_t m = 0; m < M; m++) { |
718 | 0 | FINTEGER nrows_A = K; |
719 | 0 | FINTEGER ncols_A = d; |
720 | |
721 | 0 | FINTEGER nrows_B = d; |
722 | 0 | FINTEGER ncols_B = n; |
723 | |
724 | 0 | float alpha = -2.0f; |
725 | 0 | float beta = 0.0f; |
726 | 0 | sgemm_("Transposed", |
727 | 0 | "Not Transposed", |
728 | 0 | &nrows_A, // nrows of op(A) |
729 | 0 | &ncols_B, // ncols of op(B) |
730 | 0 | &ncols_A, // ncols of op(A) |
731 | 0 | &alpha, |
732 | 0 | codebooks.data() + m * K * d, |
733 | 0 | &ncols_A, // nrows of A |
734 | 0 | x, |
735 | 0 | &nrows_B, // nrows of B |
736 | 0 | &beta, |
737 | 0 | unaries + m * n * K, |
738 | 0 | &nrows_A); // nrows of output |
739 | 0 | } |
740 | |
741 | 0 | std::vector<float> norms(M * K); |
742 | 0 | fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K); |
743 | |
744 | 0 | #pragma omp parallel for |
745 | 0 | for (int64_t i = 0; i < n; i++) { |
746 | 0 | for (size_t m = 0; m < M; m++) { |
747 | 0 | float* u = unaries + m * n * K + i * K; |
748 | 0 | fvec_add(K, u, norms.data() + m * K, u); |
749 | 0 | } |
750 | 0 | } |
751 | 0 | } |
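Ignoring the BLAS batching, each entry produced above is just the code-dependent part of the squared distance. A scalar reference version (sketch; same [M, n, K] layout as the real code):

    for (size_t m = 0; m < M; m++) {
        for (size_t i = 0; i < n; i++) {
            for (size_t k = 0; k < K; k++) {
                const float* c = codebooks.data() + (m * K + k) * d;
                unaries[m * n * K + i * K + k] =
                        -2.0f * fvec_inner_product(x + i * d, c, d) +
                        fvec_norm_L2sqr(c, d);
            }
        }
    }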
752 | | |
753 | | float LocalSearchQuantizer::evaluate( |
754 | | const int32_t* codes, |
755 | | const float* x, |
756 | | size_t n, |
757 | 0 | float* objs) const { |
758 | 0 | LSQTimerScope scope(&lsq_timer, "evaluate"); |
759 | | |
760 | | // decode |
761 | 0 | std::vector<float> decoded_x(n * d, 0.0f); |
762 | 0 | float obj = 0.0f; |
763 | |
764 | 0 | #pragma omp parallel for reduction(+ : obj) |
765 | 0 | for (int64_t i = 0; i < n; i++) { |
766 | 0 | const auto code = codes + i * M; |
767 | 0 | const auto decoded_i = decoded_x.data() + i * d; |
768 | 0 | for (size_t m = 0; m < M; m++) { |
769 | | // c = codebooks[m, code[m]] |
770 | 0 | const auto c = codebooks.data() + m * K * d + code[m] * d; |
771 | 0 | fvec_add(d, decoded_i, c, decoded_i); |
772 | 0 | } |
773 | |
774 | 0 | float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d); |
775 | 0 | obj += err; |
776 | |
777 | 0 | if (objs) { |
778 | 0 | objs[i] = err; |
779 | 0 | } |
780 | 0 | } |
781 | |
782 | 0 | obj = obj / n; |
783 | 0 | return obj; |
784 | 0 | } |
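In other words, evaluate() returns the mean squared reconstruction error over the batch:

    \text{obj} = \frac{1}{n} \sum_{i=1}^{n}
        \Bigl\| x_i - \sum_{m=1}^{M} C_m[b_{i,m}] \Bigr\|_2^2

where C_m is the m-th codebook and b_{i,m} the code selected for vector i.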
785 | | |
786 | | namespace lsq { |
787 | | |
788 | | IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq) |
789 | 0 | : verbose(false), lsq(lsq) {} |
790 | | |
791 | 0 | void IcmEncoder::set_binary_term() { |
792 | 0 | auto M = lsq->M; |
793 | 0 | auto K = lsq->K; |
794 | 0 | binaries.resize(M * M * K * K); |
795 | 0 | lsq->compute_binary_terms(binaries.data()); |
796 | 0 | } |
797 | | |
798 | | void IcmEncoder::encode( |
799 | | int32_t* codes, |
800 | | const float* x, |
801 | | std::mt19937& gen, |
802 | | size_t n, |
803 | 0 | size_t ils_iters) const { |
804 | 0 | lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose); |
805 | 0 | } |
806 | | |
807 | 0 | double LSQTimer::get(const std::string& name) { |
808 | 0 | if (t.count(name) == 0) { |
809 | 0 | return 0.0; |
810 | 0 | } else { |
811 | 0 | return t[name]; |
812 | 0 | } |
813 | 0 | } |
814 | | |
815 | 0 | void LSQTimer::add(const std::string& name, double delta) { |
816 | 0 | if (t.count(name) == 0) { |
817 | 0 | t[name] = delta; |
818 | 0 | } else { |
819 | 0 | t[name] += delta; |
820 | 0 | } |
821 | 0 | } |
822 | | |
823 | 1 | void LSQTimer::reset() { |
824 | 1 | t.clear(); |
825 | 1 | } |
826 | | |
827 | | LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name) |
828 | 0 | : timer(timer), name(name), finished(false) { |
829 | 0 | t0 = getmillisecs(); |
830 | 0 | } |
831 | | |
832 | 0 | void LSQTimerScope::finish() { |
833 | 0 | if (!finished) { |
834 | 0 | auto delta = getmillisecs() - t0; |
835 | 0 | timer->add(name, delta); |
836 | 0 | finished = true; |
837 | 0 | } |
838 | 0 | } |
839 | | |
840 | 0 | LSQTimerScope::~LSQTimerScope() { |
841 | 0 | finish(); |
842 | 0 | } |
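Usage pattern for the RAII scope (times accumulate in milliseconds under the given key; "my_stage" is a hypothetical name):

    {
        LSQTimerScope scope(&lsq_timer, "my_stage");
        // ... timed work; finish() runs here at the latest, via the destructor
    }
    double seconds = lsq_timer.get("my_stage") / 1000;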
843 | | |
844 | | } // namespace lsq |
845 | | |
846 | | } // namespace faiss |