Coverage Report

Created: 2025-09-29 18:47

/root/doris/contrib/faiss/faiss/impl/LocalSearchQuantizer.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/impl/LocalSearchQuantizer.h>
9
10
#include <cstddef>
11
#include <cstdio>
12
#include <cstring>
13
#include <memory>
14
#include <random>
15
16
#include <algorithm>
17
18
#include <faiss/impl/AuxIndexStructures.h>
19
#include <faiss/impl/FaissAssert.h>
20
#include <faiss/utils/distances.h>
21
#include <faiss/utils/hamming.h> // BitstringWriter
22
#include <faiss/utils/utils.h>
23
24
#include <faiss/utils/approx_topk/approx_topk.h>
25
26
// this is needed for prefetching
27
#include <faiss/impl/platform_macros.h>
28
29
#ifdef __AVX2__
30
#include <xmmintrin.h>
31
#endif
32
33
extern "C" {
34
// LU decomposition of a general matrix
35
void sgetrf_(
36
        FINTEGER* m,
37
        FINTEGER* n,
38
        float* a,
39
        FINTEGER* lda,
40
        FINTEGER* ipiv,
41
        FINTEGER* info);
42
43
// generate inverse of a matrix given its LU decomposition
44
void sgetri_(
45
        FINTEGER* n,
46
        float* a,
47
        FINTEGER* lda,
48
        FINTEGER* ipiv,
49
        float* work,
50
        FINTEGER* lwork,
51
        FINTEGER* info);
52
53
// general matrix multiplication
54
int sgemm_(
55
        const char* transa,
56
        const char* transb,
57
        FINTEGER* m,
58
        FINTEGER* n,
59
        FINTEGER* k,
60
        const float* alpha,
61
        const float* a,
62
        FINTEGER* lda,
63
        const float* b,
64
        FINTEGER* ldb,
65
        float* beta,
66
        float* c,
67
        FINTEGER* ldc);
68
69
// LU decomposition of a general matrix
70
void dgetrf_(
71
        FINTEGER* m,
72
        FINTEGER* n,
73
        double* a,
74
        FINTEGER* lda,
75
        FINTEGER* ipiv,
76
        FINTEGER* info);
77
78
// generate inverse of a matrix given its LU decomposition
79
void dgetri_(
80
        FINTEGER* n,
81
        double* a,
82
        FINTEGER* lda,
83
        FINTEGER* ipiv,
84
        double* work,
85
        FINTEGER* lwork,
86
        FINTEGER* info);
87
88
// general matrix multiplication
89
int dgemm_(
90
        const char* transa,
91
        const char* transb,
92
        FINTEGER* m,
93
        FINTEGER* n,
94
        FINTEGER* k,
95
        const double* alpha,
96
        const double* a,
97
        FINTEGER* lda,
98
        const double* b,
99
        FINTEGER* ldb,
100
        double* beta,
101
        double* c,
102
        FINTEGER* ldc);
103
}
104
105
namespace {
106
107
0
void fmat_inverse(float* a, FINTEGER n) {
108
0
    FINTEGER info;
109
0
    FINTEGER lwork = n * n;
110
0
    std::vector<FINTEGER> ipiv(n);
111
0
    std::vector<float> workspace(lwork);
112
113
0
    sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
114
0
    FAISS_THROW_IF_NOT(info == 0);
115
0
    sgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
116
0
    FAISS_THROW_IF_NOT(info == 0);
117
0
}
118
119
// a, b and c can overlap
120
0
void dfvec_add(size_t d, const double* a, const float* b, double* c) {
121
0
    for (size_t i = 0; i < d; i++) {
122
0
        c[i] = a[i] + b[i];
123
0
    }
124
0
}
125
126
0
void dmat_inverse(double* a, FINTEGER n) {
127
0
    FINTEGER info;
128
0
    FINTEGER lwork = n * n;
129
0
    std::vector<FINTEGER> ipiv(n);
130
0
    std::vector<double> workspace(lwork);
131
132
0
    dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
133
0
    FAISS_THROW_IF_NOT(info == 0);
134
0
    dgetri_(&n, a, &n, ipiv.data(), workspace.data(), &lwork, &info);
135
0
    FAISS_THROW_IF_NOT(info == 0);
136
0
}
137
138
void random_int32(
139
        std::vector<int32_t>& x,
140
        int32_t min,
141
        int32_t max,
142
0
        std::mt19937& gen) {
143
0
    std::uniform_int_distribution<int32_t> distrib(min, max);
144
0
    for (size_t i = 0; i < x.size(); i++) {
145
0
        x[i] = distrib(gen);
146
0
    }
147
0
}
148
149
} // anonymous namespace
150
151
namespace faiss {
152
153
lsq::LSQTimer lsq_timer;
154
using lsq::LSQTimerScope;
155
156
LocalSearchQuantizer::LocalSearchQuantizer(
157
        size_t d,
158
        size_t M,
159
        size_t nbits,
160
        Search_type_t search_type)
161
0
        : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
162
0
    K = (1 << nbits);
163
0
    std::srand(random_seed);
164
0
}
165
166
0
LocalSearchQuantizer::~LocalSearchQuantizer() {
167
0
    delete icm_encoder_factory;
168
0
}
169
170
0
LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
171
172
0
void LocalSearchQuantizer::train(size_t n, const float* x) {
173
0
    FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
174
0
    nperts = std::min(nperts, M);
175
176
0
    lsq_timer.reset();
177
0
    LSQTimerScope scope(&lsq_timer, "train");
178
0
    if (verbose) {
179
0
        printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
180
0
               M,
181
0
               n,
182
0
               d);
183
0
    }
184
185
    // allocate memory for codebooks, size [M, K, d]
186
0
    codebooks.resize(M * K * d);
187
188
    // randomly initialize codes
189
0
    std::mt19937 gen(random_seed);
190
0
    std::vector<int32_t> codes(n * M); // [n, M]
191
0
    random_int32(codes, 0, K - 1, gen);
192
193
    // compute standard deviations of each dimension
194
0
    std::vector<float> stddev(d, 0);
195
196
0
#pragma omp parallel for
197
0
    for (int64_t i = 0; i < d; i++) {
198
0
        float mean = 0;
199
0
        for (size_t j = 0; j < n; j++) {
200
0
            mean += x[j * d + i];
201
0
        }
202
0
        mean = mean / n;
203
204
0
        float sum = 0;
205
0
        for (size_t j = 0; j < n; j++) {
206
0
            float xi = x[j * d + i] - mean;
207
0
            sum += xi * xi;
208
0
        }
209
0
        stddev[i] = sqrtf(sum / n);
210
0
    }
211
212
0
    if (verbose) {
213
0
        float obj = evaluate(codes.data(), x, n);
214
0
        printf("Before training: obj = %lf\n", obj);
215
0
    }
216
217
0
    for (size_t i = 0; i < train_iters; i++) {
218
        // 1. update codebooks given x and codes
219
        // 2. add perturbation to codebooks (SR-D)
220
        // 3. refine codes given x and codebooks using icm
221
222
        // update codebooks
223
0
        update_codebooks(x, codes.data(), n);
224
225
0
        if (verbose) {
226
0
            float obj = evaluate(codes.data(), x, n);
227
0
            printf("iter %zd:\n", i);
228
0
            printf("\tafter updating codebooks: obj = %lf\n", obj);
229
0
        }
230
231
        // SR-D: perturb codebooks
232
0
        float T = pow((1.0f - (i + 1.0f) / train_iters), p);
233
0
        perturb_codebooks(T, stddev, gen);
234
235
0
        if (verbose) {
236
0
            float obj = evaluate(codes.data(), x, n);
237
0
            printf("\tafter perturbing codebooks: obj = %lf\n", obj);
238
0
        }
239
240
        // refine codes
241
0
        icm_encode(codes.data(), x, n, train_ils_iters, gen);
242
243
0
        if (verbose) {
244
0
            float obj = evaluate(codes.data(), x, n);
245
0
            printf("\tafter updating codes: obj = %lf\n", obj);
246
0
        }
247
0
    }
248
249
0
    is_trained = true;
250
0
    {
251
0
        std::vector<float> x_recons(n * d);
252
0
        std::vector<float> norms(n);
253
0
        decode_unpacked(codes.data(), x_recons.data(), n);
254
0
        fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
255
256
0
        train_norm(n, norms.data());
257
0
    }
258
259
0
    if (verbose) {
260
0
        float obj = evaluate(codes.data(), x, n);
261
0
        scope.finish();
262
0
        printf("After training: obj = %lf\n", obj);
263
264
0
        printf("Time statistic:\n");
265
0
        for (const auto& it : lsq_timer.t) {
266
0
            printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
267
0
        }
268
0
    }
269
0
}
270
271
void LocalSearchQuantizer::perturb_codebooks(
272
        float T,
273
        const std::vector<float>& stddev,
274
0
        std::mt19937& gen) {
275
0
    LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
276
277
0
    std::vector<std::normal_distribution<float>> distribs;
278
0
    for (size_t i = 0; i < d; i++) {
279
0
        distribs.emplace_back(0.0f, stddev[i]);
280
0
    }
281
282
0
    for (size_t m = 0; m < M; m++) {
283
0
        for (size_t k = 0; k < K; k++) {
284
0
            for (size_t i = 0; i < d; i++) {
285
0
                codebooks[m * K * d + k * d + i] += T * distribs[i](gen) / M;
286
0
            }
287
0
        }
288
0
    }
289
0
}
290
291
void LocalSearchQuantizer::compute_codes_add_centroids(
292
        const float* x,
293
        uint8_t* codes_out,
294
        size_t n,
295
0
        const float* centroids) const {
296
0
    FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
297
298
0
    lsq_timer.reset();
299
0
    LSQTimerScope scope(&lsq_timer, "encode");
300
0
    if (verbose) {
301
0
        printf("Encoding %zd vectors...\n", n);
302
0
    }
303
304
0
    std::vector<int32_t> codes(n * M);
305
0
    std::mt19937 gen(random_seed);
306
0
    random_int32(codes, 0, K - 1, gen);
307
308
0
    icm_encode(codes.data(), x, n, encode_ils_iters, gen);
309
0
    pack_codes(n, codes.data(), codes_out, -1, nullptr, centroids);
310
311
0
    if (verbose) {
312
0
        scope.finish();
313
0
        printf("Time statistic:\n");
314
0
        for (const auto& it : lsq_timer.t) {
315
0
            printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
316
0
        }
317
0
    }
318
0
}
319
320
/** update codebooks given x and codes
321
 *
322
 * Let B denote the sparse matrix of codes, size [n, M * K].
323
 * Let C denote the codebooks, size [M * K, d].
324
 * Let X denote the training vectors, size [n, d]
325
 *
326
 * objective function:
327
 *     L = (X - BC)^2
328
 *
329
 * To minimize L, we have:
330
 *     C = (B'B)^(-1)B'X
331
 * where ' denotes the transpose
332
 *
333
 * Add a regularization term to make B'B invertible:
334
 *     C = (B'B + lambd * I)^(-1)B'X
335
 */
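
A short worked derivation of the closed form above, in the same notation (writing \lambda for lambd; reading the regularizer as a ridge penalty on C is an interpretation consistent with the comment, not something the source states):

    L(C) = \|X - BC\|_F^2 + \lambda \|C\|_F^2
    \nabla_C L = -2 B^\top (X - BC) + 2 \lambda C = 0
    \Rightarrow (B^\top B + \lambda I)\, C = B^\top X
    \Rightarrow C = (B^\top B + \lambda I)^{-1} B^\top X
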
336
void LocalSearchQuantizer::update_codebooks(
337
        const float* x,
338
        const int32_t* codes,
339
0
        size_t n) {
340
0
    LSQTimerScope scope(&lsq_timer, "update_codebooks");
341
342
0
    if (!update_codebooks_with_double) {
343
        // allocate memory
344
        // bb = B'B, bx = B'X
345
0
        std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
346
0
        std::vector<float> bx(M * K * d, 0.0f);     // [M * K, d]
347
348
        // compute B'B
349
0
        for (size_t i = 0; i < n; i++) {
350
0
            for (size_t m = 0; m < M; m++) {
351
0
                int32_t code1 = codes[i * M + m];
352
0
                int32_t idx1 = m * K + code1;
353
0
                bb[idx1 * M * K + idx1] += 1;
354
355
0
                for (size_t m2 = m + 1; m2 < M; m2++) {
356
0
                    int32_t code2 = codes[i * M + m2];
357
0
                    int32_t idx2 = m2 * K + code2;
358
0
                    bb[idx1 * M * K + idx2] += 1;
359
0
                    bb[idx2 * M * K + idx1] += 1;
360
0
                }
361
0
            }
362
0
        }
363
364
        // add a regularization term to B'B
365
0
        for (int64_t i = 0; i < M * K; i++) {
366
0
            bb[i * (M * K) + i] += lambd;
367
0
        }
368
369
        // compute (B'B)^(-1)
370
0
        fmat_inverse(bb.data(), M * K); // [M*K, M*K]
371
372
        // compute B'X
373
0
        for (size_t i = 0; i < n; i++) {
374
0
            for (size_t m = 0; m < M; m++) {
375
0
                int32_t code = codes[i * M + m];
376
0
                float* data = bx.data() + (m * K + code) * d;
377
0
                fvec_add(d, data, x + i * d, data);
378
0
            }
379
0
        }
380
381
        // compute C = (B'B)^(-1) @ B'X
382
        //
383
        // NOTE: LAPACK uses column-major order
384
        // out = alpha * op(A) * op(B) + beta * C
385
0
        FINTEGER nrows_A = d;
386
0
        FINTEGER ncols_A = M * K;
387
388
0
        FINTEGER nrows_B = M * K;
389
0
        FINTEGER ncols_B = M * K;
390
391
0
        float alpha = 1.0f;
392
0
        float beta = 0.0f;
393
0
        sgemm_("Not Transposed",
394
0
               "Not Transposed",
395
0
               &nrows_A, // nrows of op(A)
396
0
               &ncols_B, // ncols of op(B)
397
0
               &ncols_A, // ncols of op(A)
398
0
               &alpha,
399
0
               bx.data(),
400
0
               &nrows_A, // nrows of A
401
0
               bb.data(),
402
0
               &nrows_B, // nrows of B
403
0
               &beta,
404
0
               codebooks.data(),
405
0
               &nrows_A); // nrows of output
406
407
0
    } else {
408
        // allocate memory
409
        // bb = B'B, bx = B'X
410
0
        std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
411
0
        std::vector<double> bx(M * K * d, 0.0f);     // [M * K, d]
412
413
        // compute B'B
414
0
        for (size_t i = 0; i < n; i++) {
415
0
            for (size_t m = 0; m < M; m++) {
416
0
                int32_t code1 = codes[i * M + m];
417
0
                int32_t idx1 = m * K + code1;
418
0
                bb[idx1 * M * K + idx1] += 1;
419
420
0
                for (size_t m2 = m + 1; m2 < M; m2++) {
421
0
                    int32_t code2 = codes[i * M + m2];
422
0
                    int32_t idx2 = m2 * K + code2;
423
0
                    bb[idx1 * M * K + idx2] += 1;
424
0
                    bb[idx2 * M * K + idx1] += 1;
425
0
                }
426
0
            }
427
0
        }
428
429
        // add a regularization term to B'B
430
0
        for (int64_t i = 0; i < M * K; i++) {
431
0
            bb[i * (M * K) + i] += lambd;
432
0
        }
433
434
        // compute (B'B)^(-1)
435
0
        dmat_inverse(bb.data(), M * K); // [M*K, M*K]
436
437
        // compute B'X
438
0
        for (size_t i = 0; i < n; i++) {
439
0
            for (size_t m = 0; m < M; m++) {
440
0
                int32_t code = codes[i * M + m];
441
0
                double* data = bx.data() + (m * K + code) * d;
442
0
                dfvec_add(d, data, x + i * d, data);
443
0
            }
444
0
        }
445
446
        // compute C = (B'B)^(-1) @ B'X
447
        //
448
        // NOTE: LAPACK uses column-major order
449
        // out = alpha * op(A) * op(B) + beta * C
450
0
        FINTEGER nrows_A = d;
451
0
        FINTEGER ncols_A = M * K;
452
453
0
        FINTEGER nrows_B = M * K;
454
0
        FINTEGER ncols_B = M * K;
455
456
0
        std::vector<double> d_codebooks(M * K * d);
457
458
0
        double alpha = 1.0f;
459
0
        double beta = 0.0f;
460
0
        dgemm_("Not Transposed",
461
0
               "Not Transposed",
462
0
               &nrows_A, // nrows of op(A)
463
0
               &ncols_B, // ncols of op(B)
464
0
               &ncols_A, // ncols of op(A)
465
0
               &alpha,
466
0
               bx.data(),
467
0
               &nrows_A, // nrows of A
468
0
               bb.data(),
469
0
               &nrows_B, // nrows of B
470
0
               &beta,
471
0
               d_codebooks.data(),
472
0
               &nrows_A); // nrows of output
473
474
0
        for (size_t i = 0; i < M * K * d; i++) {
475
0
            codebooks[i] = (float)d_codebooks[i];
476
0
        }
477
0
    }
478
0
}
479
480
/** encode using iterative conditional mode
481
 *
482
 * iterative conditional mode:
483
 *     For every subcode ci (i = 1, ..., M) of a vector, we fix the other
484
 *     subcodes cj (j != i) and then find the value of ci that
485
 *     minimizes the objective function.
486
487
 * objective function:
488
 *     L = (X - \sum cj)^2, j = 1, ..., M
489
 *     L = X^2 - 2X * \sum cj + (\sum cj)^2
490
 *
491
 * X^2 is negligible since it is the same for all possible values
492
 * k of the m-th subcode.
493
 *
494
 * 2X * \sum cj is the unary term
495
 * (\sum cj)^2 is the binary term
496
 * These two terms can be precomputed and stored in a lookup table.
497
 */
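
Expanding the squared error for a single vector x makes the two lookup tables explicit; this annotation is consistent with how compute_unary_terms and compute_binary_terms below fill them (unary: -2\langle x, c_{m,k}\rangle + \|c_{m,k}\|^2, binary: 2\langle c_{m,k}, c_{m',k'}\rangle):

    \|x - \sum_m c_m\|^2
        = \|x\|^2 + \sum_m \bigl(\|c_m\|^2 - 2\langle x, c_m\rangle\bigr)
          + \sum_{m \neq m'} \langle c_m, c_{m'} \rangle

where c_m is the codeword selected for the m-th subquantizer; the \|x\|^2 term is constant over code choices and can be dropped.
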
498
void LocalSearchQuantizer::icm_encode(
499
        int32_t* codes,
500
        const float* x,
501
        size_t n,
502
        size_t ils_iters,
503
0
        std::mt19937& gen) const {
504
0
    LSQTimerScope scope(&lsq_timer, "icm_encode");
505
506
0
    auto factory = icm_encoder_factory;
507
0
    std::unique_ptr<lsq::IcmEncoder> icm_encoder;
508
0
    if (factory == nullptr) {
509
0
        icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
510
0
    } else {
511
0
        icm_encoder.reset(factory->get(this));
512
0
    }
513
514
    // precompute binary terms for all chunks
515
0
    icm_encoder->set_binary_term();
516
517
0
    const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
518
0
    for (size_t i = 0; i < n_chunks; i++) {
519
0
        size_t ni = std::min(chunk_size, n - i * chunk_size);
520
521
0
        if (verbose) {
522
0
            printf("\r\ticm encoding %zd/%zd ...", i * chunk_size + ni, n);
523
0
            fflush(stdout);
524
0
            if (i == n_chunks - 1 || i == 0) {
525
0
                printf("\n");
526
0
            }
527
0
        }
528
529
0
        const float* xi = x + i * chunk_size * d;
530
0
        int32_t* codesi = codes + i * chunk_size * M;
531
0
        icm_encoder->verbose = (verbose && i == 0);
532
0
        icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
533
0
    }
534
0
}
535
536
void LocalSearchQuantizer::icm_encode_impl(
537
        int32_t* codes,
538
        const float* x,
539
        const float* binaries,
540
        std::mt19937& gen,
541
        size_t n,
542
        size_t ils_iters,
543
0
        bool verbose) const {
544
0
    std::vector<float> unaries(n * M * K); // [M, n, K]
545
0
    compute_unary_terms(x, unaries.data(), n);
546
547
0
    std::vector<int32_t> best_codes;
548
0
    best_codes.assign(codes, codes + n * M);
549
550
0
    std::vector<float> best_objs(n, 0.0f);
551
0
    evaluate(codes, x, n, best_objs.data());
552
553
0
    FAISS_THROW_IF_NOT(nperts <= M);
554
0
    for (size_t iter1 = 0; iter1 < ils_iters; iter1++) {
555
        // add perturbation to codes
556
0
        perturb_codes(codes, n, gen);
557
558
0
        icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);
559
560
0
        std::vector<float> icm_objs(n, 0.0f);
561
0
        evaluate(codes, x, n, icm_objs.data());
562
0
        size_t n_betters = 0;
563
0
        float mean_obj = 0.0f;
564
565
        // select the best code for every vector xi
566
0
#pragma omp parallel for reduction(+ : n_betters, mean_obj)
567
0
        for (int64_t i = 0; i < n; i++) {
568
0
            if (icm_objs[i] < best_objs[i]) {
569
0
                best_objs[i] = icm_objs[i];
570
0
                memcpy(best_codes.data() + i * M,
571
0
                       codes + i * M,
572
0
                       sizeof(int32_t) * M);
573
0
                n_betters += 1;
574
0
            }
575
0
            mean_obj += best_objs[i];
576
0
        }
577
0
        mean_obj /= n;
578
579
0
        memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);
580
581
0
        if (verbose) {
582
0
            printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
583
0
                   iter1,
584
0
                   mean_obj,
585
0
                   n_betters,
586
0
                   n);
587
0
        }
588
0
    } // loop ils_iters
589
0
}
590
591
void LocalSearchQuantizer::icm_encode_step(
592
        int32_t* codes,
593
        const float* unaries,
594
        const float* binaries,
595
        size_t n,
596
0
        size_t n_iters) const {
597
0
    FAISS_THROW_IF_NOT(M != 0 && K != 0);
598
0
    FAISS_THROW_IF_NOT(binaries != nullptr);
599
600
0
#pragma omp parallel for schedule(dynamic)
601
0
    for (int64_t i = 0; i < n; i++) {
602
0
        std::vector<float> objs(K);
603
604
0
        for (size_t iter = 0; iter < n_iters; iter++) {
605
            // condition on the m-th subcode
606
0
            for (size_t m = 0; m < M; m++) {
607
                // copy
608
0
                auto u = unaries + m * n * K + i * K;
609
0
                for (size_t code = 0; code < K; code++) {
610
0
                    objs[code] = u[code];
611
0
                }
612
613
                // compute objective function by adding unary
614
                // and binary terms together
615
0
                for (size_t other_m = 0; other_m < M; other_m++) {
616
0
                    if (other_m == m) {
617
0
                        continue;
618
0
                    }
619
620
0
#ifdef __AVX2__
621
                    // TODO: add platform-independent compiler-independent
622
                    // prefetch utilities.
623
0
                    if (other_m + 1 < M) {
624
                        // do a single prefetch
625
0
                        int32_t code2 = codes[i * M + other_m + 1];
626
                        // for (int32_t code = 0; code < K; code += 64) {
627
0
                        int32_t code = 0;
628
0
                        {
629
0
                            size_t binary_idx = (other_m + 1) * M * K * K +
630
0
                                    m * K * K + code2 * K + code;
631
0
                            _mm_prefetch(
632
0
                                    (const char*)(binaries + binary_idx),
633
0
                                    _MM_HINT_T0);
634
0
                        }
635
0
                    }
636
0
#endif
637
638
0
                    for (int32_t code = 0; code < K; code++) {
639
0
                        int32_t code2 = codes[i * M + other_m];
640
0
                        size_t binary_idx = other_m * M * K * K + m * K * K +
641
0
                                code2 * K + code;
642
                        // binaries[m, other_m, code, code2].
643
                        // It is symmetric over (m <-> other_m)
644
                        //   and (code <-> code2).
645
                        // So, replace the op with
646
                        //   binaries[other_m, m, code2, code].
647
0
                        objs[code] += binaries[binary_idx];
648
0
                    }
649
0
                }
650
651
                // find the optimal value of the m-th subcode
652
0
                float best_obj = HUGE_VALF;
653
0
                int32_t best_code = 0;
654
655
                // find it using SIMD. The following operation is similar
656
                // to searching for the smallest element in objs
657
0
                using C = CMax<float, int>;
658
0
                HeapWithBuckets<C, 16, 1>::addn(
659
0
                        K, objs.data(), 1, &best_obj, &best_code);
660
661
                // done
662
0
                codes[i * M + m] = best_code;
663
664
0
            } // loop M
665
0
        }
666
0
    }
667
0
}
668
void LocalSearchQuantizer::perturb_codes(
669
        int32_t* codes,
670
        size_t n,
671
0
        std::mt19937& gen) const {
672
0
    LSQTimerScope scope(&lsq_timer, "perturb_codes");
673
674
0
    std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
675
0
    std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
676
677
0
    for (size_t i = 0; i < n; i++) {
678
0
        for (size_t j = 0; j < nperts; j++) {
679
0
            size_t m = m_distrib(gen);
680
0
            codes[i * M + m] = k_distrib(gen);
681
0
        }
682
0
    }
683
0
}
684
685
0
void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
686
0
    LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
687
688
0
#pragma omp parallel for
689
0
    for (int64_t m12 = 0; m12 < M * M; m12++) {
690
0
        size_t m1 = m12 / M;
691
0
        size_t m2 = m12 % M;
692
693
0
        for (size_t code1 = 0; code1 < K; code1++) {
694
0
            for (size_t code2 = 0; code2 < K; code2++) {
695
0
                const float* c1 = codebooks.data() + m1 * K * d + code1 * d;
696
0
                const float* c2 = codebooks.data() + m2 * K * d + code2 * d;
697
0
                float ip = fvec_inner_product(c1, c2, d);
698
                // binaries[m1, m2, code1, code2] = ip * 2
699
0
                binaries[m1 * M * K * K + m2 * K * K + code1 * K + code2] =
700
0
                        ip * 2;
701
0
            }
702
0
        }
703
0
    }
704
0
}
705
706
void LocalSearchQuantizer::compute_unary_terms(
707
        const float* x,
708
        float* unaries, // [M, n, K]
709
0
        size_t n) const {
710
0
    LSQTimerScope scope(&lsq_timer, "compute_unary_terms");
711
712
    // compute x * codebook^T for each codebook
713
    //
714
    // NOTE: LAPACK uses column-major order
715
    // out = alpha * op(A) * op(B) + beta * C
716
717
0
    for (size_t m = 0; m < M; m++) {
718
0
        FINTEGER nrows_A = K;
719
0
        FINTEGER ncols_A = d;
720
721
0
        FINTEGER nrows_B = d;
722
0
        FINTEGER ncols_B = n;
723
724
0
        float alpha = -2.0f;
725
0
        float beta = 0.0f;
726
0
        sgemm_("Transposed",
727
0
               "Not Transposed",
728
0
               &nrows_A, // nrows of op(A)
729
0
               &ncols_B, // ncols of op(B)
730
0
               &ncols_A, // ncols of op(A)
731
0
               &alpha,
732
0
               codebooks.data() + m * K * d,
733
0
               &ncols_A, // nrows of A
734
0
               x,
735
0
               &nrows_B, // nrows of B
736
0
               &beta,
737
0
               unaries + m * n * K,
738
0
               &nrows_A); // nrows of output
739
0
    }
740
741
0
    std::vector<float> norms(M * K);
742
0
    fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);
743
744
0
#pragma omp parallel for
745
0
    for (int64_t i = 0; i < n; i++) {
746
0
        for (size_t m = 0; m < M; m++) {
747
0
            float* u = unaries + m * n * K + i * K;
748
0
            fvec_add(K, u, norms.data() + m * K, u);
749
0
        }
750
0
    }
751
0
}
752
753
float LocalSearchQuantizer::evaluate(
754
        const int32_t* codes,
755
        const float* x,
756
        size_t n,
757
0
        float* objs) const {
758
0
    LSQTimerScope scope(&lsq_timer, "evaluate");
759
760
    // decode
761
0
    std::vector<float> decoded_x(n * d, 0.0f);
762
0
    float obj = 0.0f;
763
764
0
#pragma omp parallel for reduction(+ : obj)
765
0
    for (int64_t i = 0; i < n; i++) {
766
0
        const auto code = codes + i * M;
767
0
        const auto decoded_i = decoded_x.data() + i * d;
768
0
        for (size_t m = 0; m < M; m++) {
769
            // c = codebooks[m, code[m]]
770
0
            const auto c = codebooks.data() + m * K * d + code[m] * d;
771
0
            fvec_add(d, decoded_i, c, decoded_i);
772
0
        }
773
774
0
        float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
775
0
        obj += err;
776
777
0
        if (objs) {
778
0
            objs[i] = err;
779
0
        }
780
0
    }
781
782
0
    obj = obj / n;
783
0
    return obj;
784
0
}
785
786
namespace lsq {
787
788
IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
789
0
        : verbose(false), lsq(lsq) {}
790
791
0
void IcmEncoder::set_binary_term() {
792
0
    auto M = lsq->M;
793
0
    auto K = lsq->K;
794
0
    binaries.resize(M * M * K * K);
795
0
    lsq->compute_binary_terms(binaries.data());
796
0
}
797
798
void IcmEncoder::encode(
799
        int32_t* codes,
800
        const float* x,
801
        std::mt19937& gen,
802
        size_t n,
803
0
        size_t ils_iters) const {
804
0
    lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
805
0
}
806
807
0
double LSQTimer::get(const std::string& name) {
808
0
    if (t.count(name) == 0) {
809
0
        return 0.0;
810
0
    } else {
811
0
        return t[name];
812
0
    }
813
0
}
814
815
0
void LSQTimer::add(const std::string& name, double delta) {
816
0
    if (t.count(name) == 0) {
817
0
        t[name] = delta;
818
0
    } else {
819
0
        t[name] += delta;
820
0
    }
821
0
}
822
823
1
void LSQTimer::reset() {
824
1
    t.clear();
825
1
}
826
827
LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
828
0
        : timer(timer), name(name), finished(false) {
829
0
    t0 = getmillisecs();
830
0
}
831
832
0
void LSQTimerScope::finish() {
833
0
    if (!finished) {
834
0
        auto delta = getmillisecs() - t0;
835
0
        timer->add(name, delta);
836
0
        finished = true;
837
0
    }
838
0
}
839
840
0
LSQTimerScope::~LSQTimerScope() {
841
0
    finish();
842
0
}
843
844
} // namespace lsq
845
846
} // namespace faiss
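
For context, a minimal driver sketch showing how the quantizer covered in this report is typically exercised through the public faiss API. The dimensions, training-set size and random data below are illustrative assumptions, not taken from the report.

#include <faiss/impl/LocalSearchQuantizer.h>

#include <cstddef>
#include <random>
#include <vector>

int main() {
    // Illustrative sizes: 32-d vectors, M = 4 subcodes of nbits = 8 bits each.
    size_t d = 32, M = 4, nbits = 8;
    size_t n = 1000; // number of training vectors (assumed)

    // Random training data as a stand-in for real vectors.
    std::vector<float> x(n * d);
    std::mt19937 gen(123);
    std::normal_distribution<float> dis;
    for (auto& v : x) {
        v = dis(gen);
    }

    faiss::LocalSearchQuantizer lsq(d, M, nbits);
    lsq.verbose = true;
    lsq.train(n, x.data()); // runs the update_codebooks / SR-D / ICM loop above

    // Encode the vectors, then reconstruct them from the codes.
    std::vector<uint8_t> codes(n * lsq.code_size);
    lsq.compute_codes(x.data(), codes.data(), n);

    std::vector<float> recons(n * d);
    lsq.decode(codes.data(), recons.data(), n);
    return 0;
}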