Coverage Report

Created: 2026-03-16 08:10

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/VectorTransform.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/VectorTransform.h>
11
12
#include <cinttypes>
13
#include <cmath>
14
#include <cstdio>
15
#include <cstring>
16
#include <memory>
17
18
#include <faiss/IndexPQ.h>
19
#include <faiss/impl/FaissAssert.h>
20
#include <faiss/utils/distances.h>
21
#include <faiss/utils/random.h>
22
#include <faiss/utils/utils.h>
23
24
using namespace faiss;
25
26
extern "C" {
27
28
// this is to keep the clang syntax checker happy
29
#ifndef FINTEGER
30
#define FINTEGER int
31
#endif
32
33
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
34
35
int sgemm_(
36
        const char* transa,
37
        const char* transb,
38
        FINTEGER* m,
39
        FINTEGER* n,
40
        FINTEGER* k,
41
        const float* alpha,
42
        const float* a,
43
        FINTEGER* lda,
44
        const float* b,
45
        FINTEGER* ldb,
46
        float* beta,
47
        float* c,
48
        FINTEGER* ldc);
49
50
int dgemm_(
51
        const char* transa,
52
        const char* transb,
53
        FINTEGER* m,
54
        FINTEGER* n,
55
        FINTEGER* k,
56
        const double* alpha,
57
        const double* a,
58
        FINTEGER* lda,
59
        const double* b,
60
        FINTEGER* ldb,
61
        double* beta,
62
        double* c,
63
        FINTEGER* ldc);
64
65
int ssyrk_(
66
        const char* uplo,
67
        const char* trans,
68
        FINTEGER* n,
69
        FINTEGER* k,
70
        float* alpha,
71
        float* a,
72
        FINTEGER* lda,
73
        float* beta,
74
        float* c,
75
        FINTEGER* ldc);
76
77
/* Lapack functions from http://www.netlib.org/clapack/old/single/ */
78
79
int ssyev_(
80
        const char* jobz,
81
        const char* uplo,
82
        FINTEGER* n,
83
        float* a,
84
        FINTEGER* lda,
85
        float* w,
86
        float* work,
87
        FINTEGER* lwork,
88
        FINTEGER* info);
89
90
int dsyev_(
91
        const char* jobz,
92
        const char* uplo,
93
        FINTEGER* n,
94
        double* a,
95
        FINTEGER* lda,
96
        double* w,
97
        double* work,
98
        FINTEGER* lwork,
99
        FINTEGER* info);
100
101
int sgesvd_(
102
        const char* jobu,
103
        const char* jobvt,
104
        FINTEGER* m,
105
        FINTEGER* n,
106
        float* a,
107
        FINTEGER* lda,
108
        float* s,
109
        float* u,
110
        FINTEGER* ldu,
111
        float* vt,
112
        FINTEGER* ldvt,
113
        float* work,
114
        FINTEGER* lwork,
115
        FINTEGER* info);
116
117
int dgesvd_(
118
        const char* jobu,
119
        const char* jobvt,
120
        FINTEGER* m,
121
        FINTEGER* n,
122
        double* a,
123
        FINTEGER* lda,
124
        double* s,
125
        double* u,
126
        FINTEGER* ldu,
127
        double* vt,
128
        FINTEGER* ldvt,
129
        double* work,
130
        FINTEGER* lwork,
131
        FINTEGER* info);
132
}
133
134
/*********************************************
135
 * VectorTransform
136
 *********************************************/
137
138
0
float* VectorTransform::apply(idx_t n, const float* x) const {
139
0
    float* xt = new float[n * d_out];
140
0
    apply_noalloc(n, x, xt);
141
0
    return xt;
142
0
}
143
144
0
void VectorTransform::train(idx_t, const float*) {
145
    // does nothing by default
146
0
}
147
148
0
void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
149
0
    FAISS_THROW_MSG("reverse transform not implemented");
150
0
}
151
152
0
void VectorTransform::check_identical(const VectorTransform& other) const {
153
0
    FAISS_THROW_IF_NOT(other.d_in == d_in && other.d_in == d_in);
154
0
}
155
156
/*********************************************
157
 * LinearTransform
158
 *********************************************/
159
160
/// both d_in > d_out and d_out < d_in are supported
161
LinearTransform::LinearTransform(int d_in, int d_out, bool have_bias)
162
0
        : VectorTransform(d_in, d_out),
163
0
          have_bias(have_bias),
164
0
          is_orthonormal(false),
165
0
          verbose(false) {
166
0
    is_trained = false; // will be trained when A and b are initialized
167
0
}
168
169
0
void LinearTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
170
0
    FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
171
172
0
    float c_factor;
173
0
    if (have_bias) {
174
0
        FAISS_THROW_IF_NOT_MSG(b.size() == d_out, "Bias not initialized");
175
0
        float* xi = xt;
176
0
        for (int i = 0; i < n; i++)
177
0
            for (int j = 0; j < d_out; j++)
178
0
                *xi++ = b[j];
179
0
        c_factor = 1.0;
180
0
    } else {
181
0
        c_factor = 0.0;
182
0
    }
183
184
0
    FAISS_THROW_IF_NOT_MSG(
185
0
            A.size() == d_out * d_in, "Transformation matrix not initialized");
186
187
0
    float one = 1;
188
0
    FINTEGER nbiti = d_out, ni = n, di = d_in;
189
0
    sgemm_("Transposed",
190
0
           "Not transposed",
191
0
           &nbiti,
192
0
           &ni,
193
0
           &di,
194
0
           &one,
195
0
           A.data(),
196
0
           &di,
197
0
           x,
198
0
           &di,
199
0
           &c_factor,
200
0
           xt,
201
0
           &nbiti);
202
0
}
203
204
void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
205
0
        const {
206
0
    if (have_bias) { // allocate buffer to store bias-corrected data
207
0
        float* y_new = new float[n * d_out];
208
0
        const float* yr = y;
209
0
        float* yw = y_new;
210
0
        for (idx_t i = 0; i < n; i++) {
211
0
            for (int j = 0; j < d_out; j++) {
212
0
                *yw++ = *yr++ - b[j];
213
0
            }
214
0
        }
215
0
        y = y_new;
216
0
    }
217
218
0
    {
219
0
        FINTEGER dii = d_in, doi = d_out, ni = n;
220
0
        float one = 1.0, zero = 0.0;
221
0
        sgemm_("Not",
222
0
               "Not",
223
0
               &dii,
224
0
               &ni,
225
0
               &doi,
226
0
               &one,
227
0
               A.data(),
228
0
               &dii,
229
0
               y,
230
0
               &doi,
231
0
               &zero,
232
0
               x,
233
0
               &dii);
234
0
    }
235
236
0
    if (have_bias)
237
0
        delete[] y;
238
0
}
239
240
0
void LinearTransform::set_is_orthonormal() {
241
0
    if (d_out > d_in) {
242
        // not clear what we should do in this case
243
0
        is_orthonormal = false;
244
0
        return;
245
0
    }
246
0
    if (d_out == 0) { // borderline case, unnormalized matrix
247
0
        is_orthonormal = true;
248
0
        return;
249
0
    }
250
251
0
    double eps = 4e-5;
252
0
    FAISS_ASSERT(A.size() >= d_out * d_in);
253
0
    {
254
0
        std::vector<float> ATA(d_out * d_out);
255
0
        FINTEGER dii = d_in, doi = d_out;
256
0
        float one = 1.0, zero = 0.0;
257
258
0
        sgemm_("Transposed",
259
0
               "Not",
260
0
               &doi,
261
0
               &doi,
262
0
               &dii,
263
0
               &one,
264
0
               A.data(),
265
0
               &dii,
266
0
               A.data(),
267
0
               &dii,
268
0
               &zero,
269
0
               ATA.data(),
270
0
               &doi);
271
272
0
        is_orthonormal = true;
273
0
        for (long i = 0; i < d_out; i++) {
274
0
            for (long j = 0; j < d_out; j++) {
275
0
                float v = ATA[i + j * d_out];
276
0
                if (i == j)
277
0
                    v -= 1;
278
0
                if (fabs(v) > eps) {
279
0
                    is_orthonormal = false;
280
0
                }
281
0
            }
282
0
        }
283
0
    }
284
0
}
285
286
void LinearTransform::reverse_transform(idx_t n, const float* xt, float* x)
287
0
        const {
288
0
    if (is_orthonormal) {
289
0
        transform_transpose(n, xt, x);
290
0
    } else {
291
0
        FAISS_THROW_MSG(
292
0
                "reverse transform not implemented for non-orthonormal matrices");
293
0
    }
294
0
}
295
296
void LinearTransform::print_if_verbose(
297
        const char* name,
298
        const std::vector<double>& mat,
299
        int n,
300
0
        int d) const {
301
0
    if (!verbose)
302
0
        return;
303
0
    printf("matrix %s: %d*%d [\n", name, n, d);
304
0
    FAISS_THROW_IF_NOT(mat.size() >= n * d);
305
0
    for (int i = 0; i < n; i++) {
306
0
        for (int j = 0; j < d; j++) {
307
0
            printf("%10.5g ", mat[i * d + j]);
308
0
        }
309
0
        printf("\n");
310
0
    }
311
0
    printf("]\n");
312
0
}
313
314
0
void LinearTransform::check_identical(const VectorTransform& other_in) const {
315
0
    VectorTransform::check_identical(other_in);
316
0
    auto other = dynamic_cast<const LinearTransform*>(&other_in);
317
0
    FAISS_THROW_IF_NOT(other);
318
0
    FAISS_THROW_IF_NOT(other->A == A && other->b == b);
319
0
}
320
321
/*********************************************
322
 * RandomRotationMatrix
323
 *********************************************/
324
325
0
void RandomRotationMatrix::init(int seed) {
326
0
    if (d_out <= d_in) {
327
0
        A.resize(d_out * d_in);
328
0
        float* q = A.data();
329
0
        float_randn(q, d_out * d_in, seed);
330
0
        matrix_qr(d_in, d_out, q);
331
0
    } else {
332
        // use tight-frame transformation
333
0
        A.resize(d_out * d_out);
334
0
        float* q = A.data();
335
0
        float_randn(q, d_out * d_out, seed);
336
0
        matrix_qr(d_out, d_out, q);
337
        // remove columns
338
0
        int i, j;
339
0
        for (i = 0; i < d_out; i++) {
340
0
            for (j = 0; j < d_in; j++) {
341
0
                q[i * d_in + j] = q[i * d_out + j];
342
0
            }
343
0
        }
344
0
        A.resize(d_in * d_out);
345
0
    }
346
0
    is_orthonormal = true;
347
0
    is_trained = true;
348
0
}
349
350
0
void RandomRotationMatrix::train(idx_t /*n*/, const float* /*x*/) {
351
    // initialize with some arbitrary seed
352
0
    init(12345);
353
0
}
354
355
/*********************************************
356
 * PCAMatrix
357
 *********************************************/
358
359
PCAMatrix::PCAMatrix(
360
        int d_in,
361
        int d_out,
362
        float eigen_power,
363
        bool random_rotation)
364
0
        : LinearTransform(d_in, d_out, true),
365
0
          eigen_power(eigen_power),
366
0
          random_rotation(random_rotation) {
367
0
    is_trained = false;
368
0
    max_points_per_d = 1000;
369
0
    balanced_bins = 0;
370
0
    epsilon = 0;
371
0
}
372
373
namespace {
374
375
/// Compute the eigenvalue decomposition of symmetric matrix cov,
376
/// dimensions d_in-by-d_in. Output eigenvectors in cov.
377
378
0
void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
379
0
    { // compute eigenvalues and vectors
380
0
        FINTEGER info = 0, lwork = -1, di = d_in;
381
0
        double workq;
382
383
0
        dsyev_("Vectors as well",
384
0
               "Upper",
385
0
               &di,
386
0
               cov,
387
0
               &di,
388
0
               eigenvalues,
389
0
               &workq,
390
0
               &lwork,
391
0
               &info);
392
0
        lwork = FINTEGER(workq);
393
0
        double* work = new double[lwork];
394
395
0
        dsyev_("Vectors as well",
396
0
               "Upper",
397
0
               &di,
398
0
               cov,
399
0
               &di,
400
0
               eigenvalues,
401
0
               work,
402
0
               &lwork,
403
0
               &info);
404
405
0
        delete[] work;
406
407
0
        if (info != 0) {
408
0
            fprintf(stderr,
409
0
                    "WARN ssyev info returns %d, "
410
0
                    "a very bad PCA matrix is learnt\n",
411
0
                    int(info));
412
            // do not throw exception, as the matrix could still be useful
413
0
        }
414
415
0
        if (verbose && d_in <= 10) {
416
0
            printf("info=%ld new eigvals=[", long(info));
417
0
            for (int j = 0; j < d_in; j++)
418
0
                printf("%g ", eigenvalues[j]);
419
0
            printf("]\n");
420
421
0
            double* ci = cov;
422
0
            printf("eigenvecs=\n");
423
0
            for (int i = 0; i < d_in; i++) {
424
0
                for (int j = 0; j < d_in; j++)
425
0
                    printf("%10.4g ", *ci++);
426
0
                printf("\n");
427
0
            }
428
0
        }
429
0
    }
430
431
    // revert order of eigenvectors & values
432
433
0
    for (int i = 0; i < d_in / 2; i++) {
434
0
        std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
435
0
        double* v1 = cov + i * d_in;
436
0
        double* v2 = cov + (d_in - 1 - i) * d_in;
437
0
        for (int j = 0; j < d_in; j++)
438
0
            std::swap(v1[j], v2[j]);
439
0
    }
440
0
}
441
442
} // namespace
443
444
0
void PCAMatrix::train(idx_t n, const float* x_in) {
445
0
    const float* x = fvecs_maybe_subsample(
446
0
            d_in, (size_t*)&n, max_points_per_d * d_in, x_in, verbose);
447
0
    TransformedVectors tv(x_in, x);
448
449
    // compute mean
450
0
    mean.clear();
451
0
    mean.resize(d_in, 0.0);
452
0
    if (have_bias) { // we may want to skip the bias
453
0
        const float* xi = x;
454
0
        for (int i = 0; i < n; i++) {
455
0
            for (int j = 0; j < d_in; j++)
456
0
                mean[j] += *xi++;
457
0
        }
458
0
        for (int j = 0; j < d_in; j++)
459
0
            mean[j] /= n;
460
0
    }
461
0
    if (verbose) {
462
0
        printf("mean=[");
463
0
        for (int j = 0; j < d_in; j++)
464
0
            printf("%g ", mean[j]);
465
0
        printf("]\n");
466
0
    }
467
468
0
    if (n >= d_in) {
469
        // compute covariance matrix, store it in PCA matrix
470
0
        PCAMat.resize(d_in * d_in);
471
0
        float* cov = PCAMat.data();
472
0
        { // initialize with  mean * mean^T term
473
0
            float* ci = cov;
474
0
            for (int i = 0; i < d_in; i++) {
475
0
                for (int j = 0; j < d_in; j++)
476
0
                    *ci++ = -n * mean[i] * mean[j];
477
0
            }
478
0
        }
479
0
        {
480
0
            FINTEGER di = d_in, ni = n;
481
0
            float one = 1.0;
482
0
            ssyrk_("Up",
483
0
                   "Non transposed",
484
0
                   &di,
485
0
                   &ni,
486
0
                   &one,
487
0
                   (float*)x,
488
0
                   &di,
489
0
                   &one,
490
0
                   cov,
491
0
                   &di);
492
0
        }
493
0
        if (verbose && d_in <= 10) {
494
0
            float* ci = cov;
495
0
            printf("cov=\n");
496
0
            for (int i = 0; i < d_in; i++) {
497
0
                for (int j = 0; j < d_in; j++)
498
0
                    printf("%10g ", *ci++);
499
0
                printf("\n");
500
0
            }
501
0
        }
502
503
0
        std::vector<double> covd(d_in * d_in);
504
0
        for (size_t i = 0; i < d_in * d_in; i++)
505
0
            covd[i] = cov[i];
506
507
0
        std::vector<double> eigenvaluesd(d_in);
508
509
0
        eig(d_in, covd.data(), eigenvaluesd.data(), verbose);
510
511
0
        for (size_t i = 0; i < d_in * d_in; i++)
512
0
            PCAMat[i] = covd[i];
513
0
        eigenvalues.resize(d_in);
514
515
0
        for (size_t i = 0; i < d_in; i++)
516
0
            eigenvalues[i] = eigenvaluesd[i];
517
518
0
    } else {
519
0
        std::vector<float> xc(n * d_in);
520
521
0
        for (size_t i = 0; i < n; i++)
522
0
            for (size_t j = 0; j < d_in; j++)
523
0
                xc[i * d_in + j] = x[i * d_in + j] - mean[j];
524
525
        // compute Gram matrix
526
0
        std::vector<float> gram(n * n);
527
0
        {
528
0
            FINTEGER di = d_in, ni = n;
529
0
            float one = 1.0, zero = 0.0;
530
0
            ssyrk_("Up",
531
0
                   "Transposed",
532
0
                   &ni,
533
0
                   &di,
534
0
                   &one,
535
0
                   xc.data(),
536
0
                   &di,
537
0
                   &zero,
538
0
                   gram.data(),
539
0
                   &ni);
540
0
        }
541
542
0
        if (verbose && d_in <= 10) {
543
0
            float* ci = gram.data();
544
0
            printf("gram=\n");
545
0
            for (int i = 0; i < n; i++) {
546
0
                for (int j = 0; j < n; j++)
547
0
                    printf("%10g ", *ci++);
548
0
                printf("\n");
549
0
            }
550
0
        }
551
552
0
        std::vector<double> gramd(n * n);
553
0
        for (size_t i = 0; i < n * n; i++)
554
0
            gramd[i] = gram[i];
555
556
0
        std::vector<double> eigenvaluesd(n);
557
558
        // eig will fill in only the n first eigenvals
559
560
0
        eig(n, gramd.data(), eigenvaluesd.data(), verbose);
561
562
0
        PCAMat.resize(d_in * n);
563
564
0
        for (size_t i = 0; i < n * n; i++)
565
0
            gram[i] = gramd[i];
566
567
0
        eigenvalues.resize(d_in);
568
        // fill in only the n first ones
569
0
        for (size_t i = 0; i < n; i++)
570
0
            eigenvalues[i] = eigenvaluesd[i];
571
572
0
        { // compute PCAMat = x' * v
573
0
            FINTEGER di = d_in, ni = n;
574
0
            float one = 1.0;
575
576
0
            sgemm_("Non",
577
0
                   "Non Trans",
578
0
                   &di,
579
0
                   &ni,
580
0
                   &ni,
581
0
                   &one,
582
0
                   xc.data(),
583
0
                   &di,
584
0
                   gram.data(),
585
0
                   &ni,
586
0
                   &one,
587
0
                   PCAMat.data(),
588
0
                   &di);
589
0
        }
590
591
0
        if (verbose && d_in <= 10) {
592
0
            float* ci = PCAMat.data();
593
0
            printf("PCAMat=\n");
594
0
            for (int i = 0; i < n; i++) {
595
0
                for (int j = 0; j < d_in; j++)
596
0
                    printf("%10g ", *ci++);
597
0
                printf("\n");
598
0
            }
599
0
        }
600
0
        fvec_renorm_L2(d_in, n, PCAMat.data());
601
0
    }
602
603
0
    prepare_Ab();
604
0
    is_trained = true;
605
0
}
606
607
0
void PCAMatrix::copy_from(const PCAMatrix& other) {
608
0
    FAISS_THROW_IF_NOT(other.is_trained);
609
0
    mean = other.mean;
610
0
    eigenvalues = other.eigenvalues;
611
0
    PCAMat = other.PCAMat;
612
0
    prepare_Ab();
613
0
    is_trained = true;
614
0
}
615
616
0
void PCAMatrix::prepare_Ab() {
617
0
    FAISS_THROW_IF_NOT_FMT(
618
0
            d_out * d_in <= PCAMat.size(),
619
0
            "PCA matrix cannot output %d dimensions from %d ",
620
0
            d_out,
621
0
            d_in);
622
623
0
    if (!random_rotation) {
624
0
        A = PCAMat;
625
0
        A.resize(d_out * d_in); // strip off useless dimensions
626
627
        // first scale the components
628
0
        if (eigen_power != 0) {
629
0
            float* ai = A.data();
630
0
            for (int i = 0; i < d_out; i++) {
631
0
                float factor = pow(eigenvalues[i] + epsilon, eigen_power);
632
0
                for (int j = 0; j < d_in; j++)
633
0
                    *ai++ *= factor;
634
0
            }
635
0
        }
636
637
0
        if (balanced_bins != 0) {
638
0
            FAISS_THROW_IF_NOT(d_out % balanced_bins == 0);
639
0
            int dsub = d_out / balanced_bins;
640
0
            std::vector<float> Ain;
641
0
            std::swap(A, Ain);
642
0
            A.resize(d_out * d_in);
643
644
0
            std::vector<float> accu(balanced_bins);
645
0
            std::vector<int> counter(balanced_bins);
646
647
            // greedy assignment
648
0
            for (int i = 0; i < d_out; i++) {
649
                // find best bin
650
0
                int best_j = -1;
651
0
                float min_w = 1e30;
652
0
                for (int j = 0; j < balanced_bins; j++) {
653
0
                    if (counter[j] < dsub && accu[j] < min_w) {
654
0
                        min_w = accu[j];
655
0
                        best_j = j;
656
0
                    }
657
0
                }
658
0
                int row_dst = best_j * dsub + counter[best_j];
659
0
                accu[best_j] += eigenvalues[i];
660
0
                counter[best_j]++;
661
0
                memcpy(&A[row_dst * d_in], &Ain[i * d_in], d_in * sizeof(A[0]));
662
0
            }
663
664
0
            if (verbose) {
665
0
                printf("  bin accu=[");
666
0
                for (int i = 0; i < balanced_bins; i++)
667
0
                    printf("%g ", accu[i]);
668
0
                printf("]\n");
669
0
            }
670
0
        }
671
672
0
    } else {
673
0
        FAISS_THROW_IF_NOT_MSG(
674
0
                balanced_bins == 0,
675
0
                "both balancing bins and applying a random rotation "
676
0
                "does not make sense");
677
0
        RandomRotationMatrix rr(d_out, d_out);
678
679
0
        rr.init(5);
680
681
        // apply scaling on the rotation matrix (right multiplication)
682
0
        if (eigen_power != 0) {
683
0
            for (int i = 0; i < d_out; i++) {
684
0
                float factor = pow(eigenvalues[i], eigen_power);
685
0
                for (int j = 0; j < d_out; j++)
686
0
                    rr.A[j * d_out + i] *= factor;
687
0
            }
688
0
        }
689
690
0
        A.resize(d_in * d_out);
691
0
        {
692
0
            FINTEGER dii = d_in, doo = d_out;
693
0
            float one = 1.0, zero = 0.0;
694
695
0
            sgemm_("Not",
696
0
                   "Not",
697
0
                   &dii,
698
0
                   &doo,
699
0
                   &doo,
700
0
                   &one,
701
0
                   PCAMat.data(),
702
0
                   &dii,
703
0
                   rr.A.data(),
704
0
                   &doo,
705
0
                   &zero,
706
0
                   A.data(),
707
0
                   &dii);
708
0
        }
709
0
    }
710
711
0
    b.clear();
712
0
    b.resize(d_out);
713
714
0
    for (int i = 0; i < d_out; i++) {
715
0
        float accu = 0;
716
0
        for (int j = 0; j < d_in; j++)
717
0
            accu -= mean[j] * A[j + i * d_in];
718
0
        b[i] = accu;
719
0
    }
720
721
0
    is_orthonormal = eigen_power == 0;
722
0
}
723
724
/*********************************************
725
 * ITQMatrix
726
 *********************************************/
727
728
ITQMatrix::ITQMatrix(int d)
729
0
        : LinearTransform(d, d, false), max_iter(50), seed(123) {}
730
731
/** translated from fbcode/deeplearning/catalyzer/catalyzer/quantizers.py */
732
0
void ITQMatrix::train(idx_t n, const float* xf) {
733
0
    size_t d = d_in;
734
0
    std::vector<double> rotation(d * d);
735
736
0
    if (init_rotation.size() == d * d) {
737
0
        memcpy(rotation.data(),
738
0
               init_rotation.data(),
739
0
               d * d * sizeof(rotation[0]));
740
0
    } else {
741
0
        RandomRotationMatrix rrot(d, d);
742
0
        rrot.init(seed);
743
0
        for (size_t i = 0; i < d * d; i++) {
744
0
            rotation[i] = rrot.A[i];
745
0
        }
746
0
    }
747
748
0
    std::vector<double> x(n * d);
749
750
0
    for (size_t i = 0; i < n * d; i++) {
751
0
        x[i] = xf[i];
752
0
    }
753
754
0
    std::vector<double> rotated_x(n * d), cov_mat(d * d);
755
0
    std::vector<double> u(d * d), vt(d * d), singvals(d);
756
757
0
    for (int i = 0; i < max_iter; i++) {
758
0
        print_if_verbose("rotation", rotation, d, d);
759
0
        { // rotated_data = np.dot(training_data, rotation)
760
0
            FINTEGER di = d, ni = n;
761
0
            double one = 1, zero = 0;
762
0
            dgemm_("N",
763
0
                   "N",
764
0
                   &di,
765
0
                   &ni,
766
0
                   &di,
767
0
                   &one,
768
0
                   rotation.data(),
769
0
                   &di,
770
0
                   x.data(),
771
0
                   &di,
772
0
                   &zero,
773
0
                   rotated_x.data(),
774
0
                   &di);
775
0
        }
776
0
        print_if_verbose("rotated_x", rotated_x, n, d);
777
        // binarize
778
0
        for (size_t j = 0; j < n * d; j++) {
779
0
            rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
780
0
        }
781
        // covariance matrix
782
0
        { // rotated_data = np.dot(training_data, rotation)
783
0
            FINTEGER di = d, ni = n;
784
0
            double one = 1, zero = 0;
785
0
            dgemm_("N",
786
0
                   "T",
787
0
                   &di,
788
0
                   &di,
789
0
                   &ni,
790
0
                   &one,
791
0
                   rotated_x.data(),
792
0
                   &di,
793
0
                   x.data(),
794
0
                   &di,
795
0
                   &zero,
796
0
                   cov_mat.data(),
797
0
                   &di);
798
0
        }
799
0
        print_if_verbose("cov_mat", cov_mat, d, d);
800
        // SVD
801
0
        {
802
0
            FINTEGER di = d;
803
0
            FINTEGER lwork = -1, info;
804
0
            double lwork1;
805
806
            // workspace query
807
0
            dgesvd_("A",
808
0
                    "A",
809
0
                    &di,
810
0
                    &di,
811
0
                    cov_mat.data(),
812
0
                    &di,
813
0
                    singvals.data(),
814
0
                    u.data(),
815
0
                    &di,
816
0
                    vt.data(),
817
0
                    &di,
818
0
                    &lwork1,
819
0
                    &lwork,
820
0
                    &info);
821
822
0
            FAISS_THROW_IF_NOT(info == 0);
823
0
            lwork = size_t(lwork1);
824
0
            std::vector<double> work(lwork);
825
0
            dgesvd_("A",
826
0
                    "A",
827
0
                    &di,
828
0
                    &di,
829
0
                    cov_mat.data(),
830
0
                    &di,
831
0
                    singvals.data(),
832
0
                    u.data(),
833
0
                    &di,
834
0
                    vt.data(),
835
0
                    &di,
836
0
                    work.data(),
837
0
                    &lwork,
838
0
                    &info);
839
0
            FAISS_THROW_IF_NOT_FMT(info == 0, "sgesvd returned info=%d", info);
840
0
        }
841
0
        print_if_verbose("u", u, d, d);
842
0
        print_if_verbose("vt", vt, d, d);
843
        // update rotation
844
0
        {
845
0
            FINTEGER di = d;
846
0
            double one = 1, zero = 0;
847
0
            dgemm_("N",
848
0
                   "T",
849
0
                   &di,
850
0
                   &di,
851
0
                   &di,
852
0
                   &one,
853
0
                   u.data(),
854
0
                   &di,
855
0
                   vt.data(),
856
0
                   &di,
857
0
                   &zero,
858
0
                   rotation.data(),
859
0
                   &di);
860
0
        }
861
0
        print_if_verbose("final rot", rotation, d, d);
862
0
    }
863
0
    A.resize(d * d);
864
0
    for (size_t i = 0; i < d; i++) {
865
0
        for (size_t j = 0; j < d; j++) {
866
0
            A[i + d * j] = rotation[j + d * i];
867
0
        }
868
0
    }
869
0
    is_trained = true;
870
0
}
871
872
ITQTransform::ITQTransform(int d_in, int d_out, bool do_pca)
873
0
        : VectorTransform(d_in, d_out),
874
0
          do_pca(do_pca),
875
0
          itq(d_out),
876
0
          pca_then_itq(d_in, d_out, false) {
877
0
    if (!do_pca) {
878
0
        FAISS_THROW_IF_NOT(d_in == d_out);
879
0
    }
880
0
    max_train_per_dim = 10;
881
0
    is_trained = false;
882
0
}
883
884
0
void ITQTransform::train(idx_t n, const float* x_in) {
885
0
    FAISS_THROW_IF_NOT(!is_trained);
886
887
0
    size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
888
0
    const float* x =
889
0
            fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x_in);
890
0
    TransformedVectors tv(x_in, x);
891
892
0
    std::unique_ptr<float[]> x_norm(new float[n * d_in]);
893
0
    { // normalize
894
0
        int d = d_in;
895
896
0
        mean.resize(d, 0);
897
0
        for (idx_t i = 0; i < n; i++) {
898
0
            for (idx_t j = 0; j < d; j++) {
899
0
                mean[j] += x[i * d + j];
900
0
            }
901
0
        }
902
0
        for (idx_t j = 0; j < d; j++) {
903
0
            mean[j] /= n;
904
0
        }
905
0
        for (idx_t i = 0; i < n; i++) {
906
0
            for (idx_t j = 0; j < d; j++) {
907
0
                x_norm[i * d + j] = x[i * d + j] - mean[j];
908
0
            }
909
0
        }
910
0
        fvec_renorm_L2(d_in, n, x_norm.get());
911
0
    }
912
913
    // train PCA
914
915
0
    PCAMatrix pca(d_in, d_out);
916
0
    float* x_pca;
917
0
    std::unique_ptr<float[]> x_pca_del;
918
0
    if (do_pca) {
919
0
        pca.have_bias = false; // for consistency with reference implem
920
0
        pca.train(n, x_norm.get());
921
0
        x_pca = pca.apply(n, x_norm.get());
922
0
        x_pca_del.reset(x_pca);
923
0
    } else {
924
0
        x_pca = x_norm.get();
925
0
    }
926
927
    // train ITQ
928
0
    itq.train(n, x_pca);
929
930
    // merge PCA and ITQ
931
0
    if (do_pca) {
932
0
        FINTEGER di = d_out, dini = d_in;
933
0
        float one = 1, zero = 0;
934
0
        pca_then_itq.A.resize(d_in * d_out);
935
0
        sgemm_("N",
936
0
               "N",
937
0
               &dini,
938
0
               &di,
939
0
               &di,
940
0
               &one,
941
0
               pca.A.data(),
942
0
               &dini,
943
0
               itq.A.data(),
944
0
               &di,
945
0
               &zero,
946
0
               pca_then_itq.A.data(),
947
0
               &dini);
948
0
    } else {
949
0
        pca_then_itq.A = itq.A;
950
0
    }
951
0
    pca_then_itq.is_trained = true;
952
0
    is_trained = true;
953
0
}
954
955
0
void ITQTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
956
0
    FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
957
958
0
    std::unique_ptr<float[]> x_norm(new float[n * d_in]);
959
0
    { // normalize
960
0
        int d = d_in;
961
0
        for (idx_t i = 0; i < n; i++) {
962
0
            for (idx_t j = 0; j < d; j++) {
963
0
                x_norm[i * d + j] = x[i * d + j] - mean[j];
964
0
            }
965
0
        }
966
        // this is not really useful if we are going to binarize right
967
        // afterwards but OK
968
0
        fvec_renorm_L2(d_in, n, x_norm.get());
969
0
    }
970
971
0
    pca_then_itq.apply_noalloc(n, x_norm.get(), xt);
972
0
}
973
974
0
void ITQTransform::check_identical(const VectorTransform& other_in) const {
975
0
    VectorTransform::check_identical(other_in);
976
0
    auto other = dynamic_cast<const ITQTransform*>(&other_in);
977
0
    FAISS_THROW_IF_NOT(other);
978
0
    pca_then_itq.check_identical(other->pca_then_itq);
979
0
    FAISS_THROW_IF_NOT(other->mean == mean);
980
0
}
981
982
/*********************************************
983
 * OPQMatrix
984
 *********************************************/
985
986
OPQMatrix::OPQMatrix(int d, int M, int d2)
987
0
        : LinearTransform(d, d2 == -1 ? d : d2, false), M(M) {
988
0
    is_trained = false;
989
    // OPQ is quite expensive to train, so set this right.
990
0
    max_train_points = 256 * 256;
991
0
}
992
993
0
// Train the OPQ rotation: alternate between (a) training a product
// quantizer on the rotated data and (b) updating the rotation via the
// SVD of the cross-correlation between the training data and the PQ
// reconstructions (the classical OPQ alternating optimization).
// x_in: n training vectors of dimension d_in (may be subsampled down
// to max_train_points).
void OPQMatrix::train(idx_t n, const float* x_in) {
    // NOTE(review): casts idx_t* to size_t* — assumes both are 64-bit
    // on all supported platforms; confirm.
    const float* x = fvecs_maybe_subsample(
            d_in, (size_t*)&n, max_train_points, x_in, verbose);
    // frees x at scope exit if the subsampling allocated a copy
    TransformedVectors tv(x_in, x);

    // To support d_out > d_in, we pad input vectors with 0s to d_out
    size_t d = d_out <= d_in ? d_in : d_out;
    size_t d2 = d_out;

#if 0
    // what this test shows: the only way of getting bit-exact
    // reproducible results with sgeqrf and sgesvd seems to be forcing
    // single-threading.
    { // test repro
        std::vector<float> r (d * d);
        float * rotation = r.data();
        float_randn (rotation, d * d, 1234);
        printf("CS0: %016lx\n",
               ivec_checksum (128*128, (int*)rotation));
        matrix_qr (d, d, rotation);
        printf("CS1: %016lx\n",
               ivec_checksum (128*128, (int*)rotation));
        return;
    }
#endif

    if (verbose) {
        printf("OPQMatrix::train: training an OPQ rotation matrix "
               "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
               M,
               n,
               d_in,
               d_out);
    }

    std::vector<float> xtrain(n * d);
    // center x: subtract the per-dimension mean; padding columns
    // (d_in..d) stay zero because xtrain is zero-initialized
    {
        std::vector<float> sum(d);
        const float* xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
                sum[j] += *xi++;
        }
        for (int i = 0; i < d; i++)
            sum[i] /= n;
        float* yi = xtrain.data();
        xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
                *yi++ = *xi++ - sum[j];
            // skip over the zero-padding columns of this row
            yi += d - d_in;
        }
    }
    float* rotation;

    if (A.size() == 0) {
        // no initial rotation provided: start from a random orthonormal
        // matrix (QR of a Gaussian matrix, fixed seed for repro)
        A.resize(d * d);
        rotation = A.data();
        if (verbose)
            printf("  OPQMatrix::train: making random %zd*%zd rotation\n",
                   d,
                   d);
        float_randn(rotation, d * d, 1234);
        matrix_qr(d, d, rotation);
        // we use only the d * d2 upper part of the matrix
        A.resize(d * d2);
    } else {
        // caller pre-seeded A: must already have the padded shape
        FAISS_THROW_IF_NOT(A.size() == d * d2);
        rotation = A.data();
    }

    // xproj: rotated training data; pq_recons: its PQ reconstruction;
    // xxr: d2*d cross-correlation (plus slack); tmp: SVD workspace
    // laid out as [u | vt | singular values]
    std::vector<float> xproj(d2 * n), pq_recons(d2 * n), xxr(d * n),
            tmp(d * d * 4);

    ProductQuantizer pq_default(d2, M, 8);
    // use the externally supplied PQ if any, otherwise the default one
    ProductQuantizer& pq_regular = pq ? *pq : pq_default;
    std::vector<uint8_t> codes(pq_regular.code_size * n);

    double t0 = getmillisecs();
    for (int iter = 0; iter < niter; iter++) {
        { // torch.mm(xtrain, rotation:t())
            FINTEGER di = d, d2i = d2, ni = n;
            float zero = 0, one = 1;
            // only the first character of the transpose flags matters
            // to BLAS ("T" / "N")
            sgemm_("Transposed",
                   "Not transposed",
                   &d2i,
                   &ni,
                   &di,
                   &one,
                   rotation,
                   &di,
                   xtrain.data(),
                   &di,
                   &zero,
                   xproj.data(),
                   &d2i);
        }

        pq_regular.cp.max_points_per_centroid = 1000;
        // the first iteration trains the PQ longer; later iterations
        // refine from a warm start
        pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
        pq_regular.verbose = verbose;
        pq_regular.train(n, xproj.data());

        if (verbose) {
            printf("    encode / decode\n");
        }
        if (pq_regular.assign_index) {
            pq_regular.compute_codes_with_assign_index(
                    xproj.data(), codes.data(), n);
        } else {
            pq_regular.compute_codes(xproj.data(), codes.data(), n);
        }
        pq_regular.decode(codes.data(), pq_recons.data(), n);

        // mean squared quantization error: the objective being minimized
        float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;

        if (verbose)
            printf("    Iteration %d (%d PQ iterations):"
                   "%.3f s, obj=%g\n",
                   iter,
                   pq_regular.cp.niter,
                   (getmillisecs() - t0) / 1000.0,
                   pq_err);

        {
            float *u = tmp.data(), *vt = &tmp[d * d];
            float* sing_val = &tmp[2 * d * d];
            FINTEGER di = d, d2i = d2, ni = n;
            float one = 1, zero = 0;

            if (verbose) {
                printf("    X * recons\n");
            }
            // torch.mm(xtrain:t(), pq_recons)
            sgemm_("Not",
                   "Transposed",
                   &d2i,
                   &di,
                   &ni,
                   &one,
                   pq_recons.data(),
                   &d2i,
                   xtrain.data(),
                   &di,
                   &zero,
                   xxr.data(),
                   &d2i);

            // lwork = -1 asks LAPACK for the optimal workspace size
            FINTEGER lwork = -1, info = -1;
            float worksz;
            // workspace query
            sgesvd_("All",
                    "All",
                    &d2i,
                    &di,
                    xxr.data(),
                    &d2i,
                    sing_val,
                    vt,
                    &d2i,
                    u,
                    &di,
                    &worksz,
                    &lwork,
                    &info);

            lwork = int(worksz);
            std::vector<float> work(lwork);
            // u and vt swapped
            sgesvd_("All",
                    "All",
                    &d2i,
                    &di,
                    xxr.data(),
                    &d2i,
                    sing_val,
                    vt,
                    &d2i,
                    u,
                    &di,
                    work.data(),
                    &lwork,
                    &info);

            // new rotation = V * U^t (the orthogonal Procrustes solution)
            sgemm_("Transposed",
                   "Transposed",
                   &di,
                   &d2i,
                   &d2i,
                   &one,
                   u,
                   &di,
                   vt,
                   &d2i,
                   &zero,
                   rotation,
                   &di);
        }
        // reuse the current centroids as starting point next iteration
        pq_regular.train_type = ProductQuantizer::Train_hot_start;
    }

    // revert A matrix: strip the zero-padding columns so A is d_in*d_out
    if (d > d_in) {
        for (long i = 0; i < d_out; i++)
            memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
        A.resize(d_in * d_out);
    }

    is_trained = true;
    is_orthonormal = true;
}
1205
1206
/*********************************************
1207
 * NormalizationTransform
1208
 *********************************************/
1209
1210
// d: vector dimension (input == output); norm: which norm to apply —
// only 2.0 (L2) is supported by apply_noalloc.
NormalizationTransform::NormalizationTransform(int d, float norm)
        : VectorTransform(d, d), norm(norm) {}
1212
1213
// Default constructor: dimensions and norm are set to the sentinel -1
// until properly initialized (presumably by deserialization — confirm).
NormalizationTransform::NormalizationTransform()
        : VectorTransform(-1, -1), norm(-1) {}
1215
1216
// Write the L2-normalized version of each of the n input vectors into
// xt. Any norm other than 2.0 throws (not implemented).
void NormalizationTransform::apply_noalloc(idx_t n, const float* x, float* xt)
        const {
    // guard clause: only the L2 case is supported
    if (!(norm == 2.0)) {
        FAISS_THROW_MSG("not implemented");
    }
    memcpy(xt, x, sizeof(x[0]) * n * d_in);
    fvec_renorm_L2(d_in, n, xt);
}
1225
1226
void NormalizationTransform::reverse_transform(
1227
        idx_t n,
1228
        const float* xt,
1229
0
        float* x) const {
1230
0
    memcpy(x, xt, sizeof(xt[0]) * n * d_in);
1231
0
}
1232
1233
void NormalizationTransform::check_identical(
1234
0
        const VectorTransform& other_in) const {
1235
0
    VectorTransform::check_identical(other_in);
1236
0
    auto other = dynamic_cast<const NormalizationTransform*>(&other_in);
1237
0
    FAISS_THROW_IF_NOT(other);
1238
0
    FAISS_THROW_IF_NOT(other->norm == norm);
1239
0
}
1240
1241
/*********************************************
1242
 * CenteringTransform
1243
 *********************************************/
1244
1245
0
// Centering preserves dimensionality (d_in == d_out == d). The mean is
// only available after train(), so the transform starts untrained.
CenteringTransform::CenteringTransform(int d) : VectorTransform(d, d) {
    is_trained = false;
}
1248
1249
0
// Compute the per-dimension mean of the n training vectors in x.
// Throws if n == 0. Safe to call more than once: each call recomputes
// the mean from scratch.
void CenteringTransform::train(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
    // assign (not resize): resize(d_in, 0) leaves pre-existing elements
    // untouched, so a second train() call would accumulate on top of
    // the previous mean. assign always resets all entries to 0.
    mean.assign(d_in, 0);
    for (idx_t i = 0; i < n; i++) {
        for (size_t j = 0; j < d_in; j++) {
            mean[j] += *x++;
        }
    }

    for (size_t j = 0; j < d_in; j++) {
        mean[j] /= n;
    }
    is_trained = true;
}
1263
1264
// Subtract the trained mean from each of the n input vectors, writing
// the centered vectors into xt.
void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
        const {
    FAISS_THROW_IF_NOT(is_trained);

    for (idx_t i = 0; i < n; i++) {
        const float* xi = x + i * d_in;
        float* xo = xt + i * d_in;
        for (size_t j = 0; j < d_in; j++) {
            xo[j] = xi[j] - mean[j];
        }
    }
}
1274
1275
// Exact inverse of apply_noalloc: add the trained mean back to each of
// the n centered vectors.
void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
        const {
    FAISS_THROW_IF_NOT(is_trained);

    for (idx_t i = 0; i < n; i++) {
        const float* xi = xt + i * d_in;
        float* xo = x + i * d_in;
        for (size_t j = 0; j < d_in; j++) {
            xo[j] = xi[j] + mean[j];
        }
    }
}
1285
1286
void CenteringTransform::check_identical(
1287
0
        const VectorTransform& other_in) const {
1288
0
    VectorTransform::check_identical(other_in);
1289
0
    auto other = dynamic_cast<const CenteringTransform*>(&other_in);
1290
0
    FAISS_THROW_IF_NOT(other);
1291
0
    FAISS_THROW_IF_NOT(other->mean == mean);
1292
0
}
1293
1294
/*********************************************
1295
 * RemapDimensionsTransform
1296
 *********************************************/
1297
1298
// Build an explicit dimension remapping: output dimension i takes its
// value from input dimension map_in[i]; an entry of -1 produces a
// constant 0. Each entry must be -1 or in [0, d_in).
RemapDimensionsTransform::RemapDimensionsTransform(
        int d_in,
        int d_out,
        const int* map_in)
        : VectorTransform(d_in, d_out) {
    map.assign(map_in, map_in + d_out);
    for (int i = 0; i < d_out; i++) {
        FAISS_THROW_IF_NOT(map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
    }
}
1309
1310
// Build a remapping without an explicit map. uniform == true spreads
// (d_in < d_out) or subsamples (d_in >= d_out) input dimensions at a
// regular stride across the output; uniform == false keeps the first
// min(d_in, d_out) dimensions and zero-fills the remainder.
RemapDimensionsTransform::RemapDimensionsTransform(
        int d_in,
        int d_out,
        bool uniform)
        : VectorTransform(d_in, d_out) {
    // unmapped output dimensions stay -1, i.e. constant 0
    map.resize(d_out, -1);

    if (!uniform) {
        // identity on the common prefix
        int prefix = d_in < d_out ? d_in : d_out;
        for (int i = 0; i < prefix; i++) {
            map[i] = i;
        }
        return;
    }

    if (d_in < d_out) {
        // scatter every input dimension to an evenly spaced output slot
        for (int i = 0; i < d_in; i++) {
            map[i * d_out / d_in] = i;
        }
    } else {
        // subsample input dimensions at a regular stride
        for (int i = 0; i < d_out; i++) {
            map[i] = i * d_in / d_out;
        }
    }
}
1332
1333
// Copy each of the n input vectors into the output layout given by
// map; output dimensions whose map entry is negative are set to 0.
void RemapDimensionsTransform::apply_noalloc(idx_t n, const float* x, float* xt)
        const {
    for (idx_t i = 0; i < n; i++) {
        const float* xi = x + i * d_in;
        float* xo = xt + i * d_out;
        for (int j = 0; j < d_out; j++) {
            int src = map[j];
            xo[j] = src < 0 ? 0 : xi[src];
        }
    }
}
1343
1344
void RemapDimensionsTransform::reverse_transform(
1345
        idx_t n,
1346
        const float* xt,
1347
0
        float* x) const {
1348
0
    memset(x, 0, sizeof(*x) * n * d_in);
1349
0
    for (idx_t i = 0; i < n; i++) {
1350
0
        for (int j = 0; j < d_out; j++) {
1351
0
            if (map[j] >= 0)
1352
0
                x[map[j]] = xt[j];
1353
0
        }
1354
0
        x += d_in;
1355
0
        xt += d_out;
1356
0
    }
1357
0
}
1358
1359
void RemapDimensionsTransform::check_identical(
1360
0
        const VectorTransform& other_in) const {
1361
0
    VectorTransform::check_identical(other_in);
1362
0
    auto other = dynamic_cast<const RemapDimensionsTransform*>(&other_in);
1363
0
    FAISS_THROW_IF_NOT(other);
1364
0
    FAISS_THROW_IF_NOT(other->map == map);
1365
0
}