Coverage Report

Created: 2025-09-02 13:40

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/Clustering.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/Clustering.h>
11
#include <faiss/VectorTransform.h>
12
#include <faiss/impl/AuxIndexStructures.h>
13
14
#include <chrono>
15
#include <cinttypes>
16
#include <cmath>
17
#include <cstdio>
18
#include <cstring>
19
20
#include <omp.h>
21
22
#include <faiss/IndexFlat.h>
23
#include <faiss/impl/FaissAssert.h>
24
#include <faiss/impl/kmeans1d.h>
25
#include <faiss/utils/distances.h>
26
#include <faiss/utils/random.h>
27
#include <faiss/utils/utils.h>
28
29
namespace faiss {
30
31
0
/// Clustering of d-dimensional vectors into k centroids, using the default
/// ClusteringParameters.
Clustering::Clustering(int d, int k)
        : d(d),
          k(k) {}
32
33
/// Clustering of d-dimensional vectors into k centroids; every tuning knob
/// is copied from `cp`.
Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
        : ClusteringParameters(cp),
          d(d),
          k(k) {}
35
36
0
void Clustering::post_process_centroids() {
37
0
    if (spherical) {
38
0
        fvec_renorm_L2(d, k, centroids.data());
39
0
    }
40
41
0
    if (int_centroids) {
42
0
        for (size_t i = 0; i < centroids.size(); i++)
43
0
            centroids[i] = roundf(centroids[i]);
44
0
    }
45
0
}
46
47
void Clustering::train(
48
        idx_t nx,
49
        const float* x_in,
50
        Index& index,
51
0
        const float* weights) {
52
0
    train_encoded(
53
0
            nx,
54
0
            reinterpret_cast<const uint8_t*>(x_in),
55
0
            nullptr,
56
0
            index,
57
0
            weights);
58
0
}
59
60
namespace {
61
62
0
/// Turn the user-facing `seed` into the seed actually given to the RNGs.
/// A non-negative seed is used verbatim. A negative seed selects a
/// time-dependent value read from the high-resolution clock.
uint64_t get_actual_rng_seed(const int seed) {
    if (seed < 0) {
        const auto ticks = std::chrono::high_resolution_clock::now()
                                   .time_since_epoch()
                                   .count();
        return static_cast<uint64_t>(ticks);
    }
    return static_cast<uint64_t>(seed);
}
69
70
/** Randomly select clus.k * clus.max_points_per_centroid training vectors,
 * and their weights if there are any.
 *
 * With clus.use_faster_subsampling, indices are drawn independently with a
 * SplitMix64 generator, so the same vector may be picked more than once.
 * Otherwise the first entries of a random permutation of [0, nx) are used.
 *
 * @param clus         clustering parameters (k, max_points_per_centroid,
 *                     seed, use_faster_subsampling, verbose)
 * @param nx           number of input vectors
 * @param x            input vectors, nx * line_size bytes
 * @param line_size    size in bytes of one input vector
 * @param weights      per-vector weights, size nx (or nullptr)
 * @param x_out        receives a new[]-allocated buffer with the sampled
 *                     vectors; the caller takes ownership
 * @param weights_out  receives a new[]-allocated buffer with the sampled
 *                     weights, or nullptr when `weights` is nullptr; the
 *                     caller takes ownership
 * @return             number of sampled vectors
 */
idx_t subsample_training_set(
        const Clustering& clus,
        idx_t nx,
        const uint8_t* x,
        size_t line_size,
        const float* weights,
        uint8_t** x_out,
        float** weights_out) {
    const idx_t new_nx = clus.k * clus.max_points_per_centroid;

    if (clus.verbose) {
        printf("Sampling a subset of %zd / %" PRId64 " for training\n",
               clus.k * clus.max_points_per_centroid,
               nx);
    }

    const uint64_t actual_seed = get_actual_rng_seed(clus.seed);

    std::vector<int> perm;
    if (clus.use_faster_subsampling) {
        // use subsampling with splitmix64 rng (sampling with replacement)
        SplitMix64RandomGenerator rng(actual_seed);
        perm.resize(new_nx);
        for (idx_t i = 0; i < new_nx; i++) {
            perm[i] = rng.rand_int(nx);
        }
    } else {
        // use subsampling with a default std rng: prefix of a permutation
        perm.resize(nx);
        rand_perm(perm.data(), nx, actual_seed);
    }

    // Keep ownership of both buffers until they are fully built, so that a
    // failing allocation of the second one does not leak the first (the
    // caller only takes ownership after a successful return).
    std::unique_ptr<uint8_t[]> x_new(new uint8_t[new_nx * line_size]);

    // might be worth omp-ing as well
    for (idx_t i = 0; i < new_nx; i++) {
        memcpy(x_new.get() + i * line_size, x + perm[i] * line_size, line_size);
    }

    std::unique_ptr<float[]> weights_new;
    if (weights) {
        weights_new.reset(new float[new_nx]);
        for (idx_t i = 0; i < new_nx; i++) {
            weights_new[i] = weights[perm[i]];
        }
    }

    *x_out = x_new.release();
    *weights_out = weights_new.release(); // nullptr when there are no weights
    return new_nx;
}
121
122
/** compute centroids as (weighted) sum of training points
123
 *
124
 * @param x            training vectors, size n * code_size (from codec)
125
 * @param codec        how to decode the vectors (if NULL then cast to float*)
126
 * @param weights      per-training vector weight, size n (or NULL)
127
 * @param assign       nearest centroid for each training vector, size n
128
 * @param k_frozen     do not update the k_frozen first centroids
129
 * @param centroids    centroid vectors (output only), size k * d
130
 * @param hassign      histogram of assignments per centroid (size k),
131
 *                     should be 0 on input
132
 *
133
 */
134
135
void compute_centroids(
136
        size_t d,
137
        size_t k,
138
        size_t n,
139
        size_t k_frozen,
140
        const uint8_t* x,
141
        const Index* codec,
142
        const int64_t* assign,
143
        const float* weights,
144
        float* hassign,
145
0
        float* centroids) {
146
0
    k -= k_frozen;
147
0
    centroids += k_frozen * d;
148
149
0
    memset(centroids, 0, sizeof(*centroids) * d * k);
150
151
0
    size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
152
153
0
#pragma omp parallel
154
0
    {
155
0
        int nt = omp_get_num_threads();
156
0
        int rank = omp_get_thread_num();
157
158
        // this thread is taking care of centroids c0:c1
159
0
        size_t c0 = (k * rank) / nt;
160
0
        size_t c1 = (k * (rank + 1)) / nt;
161
0
        std::vector<float> decode_buffer(d);
162
163
0
        for (size_t i = 0; i < n; i++) {
164
0
            int64_t ci = assign[i];
165
0
            assert(ci >= 0 && ci < k + k_frozen);
166
0
            ci -= k_frozen;
167
0
            if (ci >= c0 && ci < c1) {
168
0
                float* c = centroids + ci * d;
169
0
                const float* xi;
170
0
                if (!codec) {
171
0
                    xi = reinterpret_cast<const float*>(x + i * line_size);
172
0
                } else {
173
0
                    float* xif = decode_buffer.data();
174
0
                    codec->sa_decode(1, x + i * line_size, xif);
175
0
                    xi = xif;
176
0
                }
177
0
                if (weights) {
178
0
                    float w = weights[i];
179
0
                    hassign[ci] += w;
180
0
                    for (size_t j = 0; j < d; j++) {
181
0
                        c[j] += xi[j] * w;
182
0
                    }
183
0
                } else {
184
0
                    hassign[ci] += 1.0;
185
0
                    for (size_t j = 0; j < d; j++) {
186
0
                        c[j] += xi[j];
187
0
                    }
188
0
                }
189
0
            }
190
0
        }
191
0
    }
192
193
0
#pragma omp parallel for
194
0
    for (idx_t ci = 0; ci < k; ci++) {
195
0
        if (hassign[ci] == 0) {
196
0
            continue;
197
0
        }
198
0
        float norm = 1 / hassign[ci];
199
0
        float* c = centroids + ci * d;
200
0
        for (size_t j = 0; j < d; j++) {
201
0
            c[j] *= norm;
202
0
        }
203
0
    }
204
0
}
205
206
// Relative perturbation applied when splitting a cluster: 2^-10, which is
// the machine epsilon of float16. A constant rather than a macro, so that
// the name stays local to this anonymous namespace.
constexpr double EPS = 1 / 1024.;
208
209
/** Handle empty clusters by splitting larger ones.
210
 *
211
 * It works by slightly changing the centroids to make 2 clusters from
212
 * a single one. Takes the same arguments as compute_centroids.
213
 *
214
 * @return           nb of spliting operations (larger is worse)
215
 */
216
int split_clusters(
217
        size_t d,
218
        size_t k,
219
        size_t n,
220
        size_t k_frozen,
221
        float* hassign,
222
0
        float* centroids) {
223
0
    k -= k_frozen;
224
0
    centroids += k_frozen * d;
225
226
    /* Take care of void clusters */
227
0
    size_t nsplit = 0;
228
0
    RandomGenerator rng(1234);
229
0
    for (size_t ci = 0; ci < k; ci++) {
230
0
        if (hassign[ci] == 0) { /* need to redefine a centroid */
231
0
            size_t cj;
232
0
            for (cj = 0; true; cj = (cj + 1) % k) {
233
                /* probability to pick this cluster for split */
234
0
                float p = (hassign[cj] - 1.0) / (float)(n - k);
235
0
                float r = rng.rand_float();
236
0
                if (r < p) {
237
0
                    break; /* found our cluster to be split */
238
0
                }
239
0
            }
240
0
            memcpy(centroids + ci * d,
241
0
                   centroids + cj * d,
242
0
                   sizeof(*centroids) * d);
243
244
            /* small symmetric pertubation */
245
0
            for (size_t j = 0; j < d; j++) {
246
0
                if (j % 2 == 0) {
247
0
                    centroids[ci * d + j] *= 1 + EPS;
248
0
                    centroids[cj * d + j] *= 1 - EPS;
249
0
                } else {
250
0
                    centroids[ci * d + j] *= 1 - EPS;
251
0
                    centroids[cj * d + j] *= 1 + EPS;
252
0
                }
253
0
            }
254
255
            /* assume even split of the cluster */
256
0
            hassign[ci] = hassign[cj] / 2;
257
0
            hassign[cj] -= hassign[ci];
258
0
            nsplit++;
259
0
        }
260
0
    }
261
262
0
    return nsplit;
263
0
}
264
265
} // namespace
266
267
void Clustering::train_encoded(
268
        idx_t nx,
269
        const uint8_t* x_in,
270
        const Index* codec,
271
        Index& index,
272
0
        const float* weights) {
273
0
    FAISS_THROW_IF_NOT_FMT(
274
0
            nx >= k,
275
0
            "Number of training points (%" PRId64
276
0
            ") should be at least "
277
0
            "as large as number of clusters (%zd)",
278
0
            nx,
279
0
            k);
280
281
0
    FAISS_THROW_IF_NOT_FMT(
282
0
            (!codec || codec->d == d),
283
0
            "Codec dimension %d not the same as data dimension %d",
284
0
            int(codec->d),
285
0
            int(d));
286
287
0
    FAISS_THROW_IF_NOT_FMT(
288
0
            index.d == d,
289
0
            "Index dimension %d not the same as data dimension %d",
290
0
            int(index.d),
291
0
            int(d));
292
293
0
    double t0 = getmillisecs();
294
295
0
    if (!codec && check_input_data_for_NaNs) {
296
        // Check for NaNs in input data. Normally it is the user's
297
        // responsibility, but it may spare us some hard-to-debug
298
        // reports.
299
0
        const float* x = reinterpret_cast<const float*>(x_in);
300
0
        for (size_t i = 0; i < nx * d; i++) {
301
0
            FAISS_THROW_IF_NOT_MSG(
302
0
                    std::isfinite(x[i]), "input contains NaN's or Inf's");
303
0
        }
304
0
    }
305
306
0
    const uint8_t* x = x_in;
307
0
    std::unique_ptr<uint8_t[]> del1;
308
0
    std::unique_ptr<float[]> del3;
309
0
    size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
310
311
0
    if (nx > k * max_points_per_centroid) {
312
0
        uint8_t* x_new;
313
0
        float* weights_new;
314
0
        nx = subsample_training_set(
315
0
                *this, nx, x, line_size, weights, &x_new, &weights_new);
316
0
        del1.reset(x_new);
317
0
        x = x_new;
318
0
        del3.reset(weights_new);
319
0
        weights = weights_new;
320
0
    } else if (nx < k * min_points_per_centroid) {
321
0
        fprintf(stderr,
322
0
                "WARNING clustering %" PRId64
323
0
                " points to %zd centroids: "
324
0
                "please provide at least %" PRId64 " training points\n",
325
0
                nx,
326
0
                k,
327
0
                idx_t(k) * min_points_per_centroid);
328
0
    }
329
330
0
    if (nx == k) {
331
        // this is a corner case, just copy training set to clusters
332
0
        if (verbose) {
333
0
            printf("Number of training points (%" PRId64
334
0
                   ") same as number of "
335
0
                   "clusters, just copying\n",
336
0
                   nx);
337
0
        }
338
0
        centroids.resize(d * k);
339
0
        if (!codec) {
340
0
            memcpy(centroids.data(), x_in, sizeof(float) * d * k);
341
0
        } else {
342
0
            codec->sa_decode(nx, x_in, centroids.data());
343
0
        }
344
345
        // one fake iteration...
346
0
        ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
347
0
        iteration_stats.push_back(stats);
348
349
0
        index.reset();
350
0
        index.add(k, centroids.data());
351
0
        return;
352
0
    }
353
354
0
    if (verbose) {
355
0
        printf("Clustering %" PRId64
356
0
               " points in %zdD to %zd clusters, "
357
0
               "redo %d times, %d iterations\n",
358
0
               nx,
359
0
               d,
360
0
               k,
361
0
               nredo,
362
0
               niter);
363
0
        if (codec) {
364
0
            printf("Input data encoded in %zd bytes per vector\n",
365
0
                   codec->sa_code_size());
366
0
        }
367
0
    }
368
369
0
    std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
370
0
    std::unique_ptr<float[]> dis(new float[nx]);
371
372
    // remember best iteration for redo
373
0
    bool lower_is_better = !is_similarity_metric(index.metric_type);
374
0
    float best_obj = lower_is_better ? HUGE_VALF : -HUGE_VALF;
375
0
    std::vector<ClusteringIterationStats> best_iteration_stats;
376
0
    std::vector<float> best_centroids;
377
378
    // support input centroids
379
380
0
    FAISS_THROW_IF_NOT_MSG(
381
0
            centroids.size() % d == 0,
382
0
            "size of provided input centroids not a multiple of dimension");
383
384
0
    size_t n_input_centroids = centroids.size() / d;
385
386
0
    if (verbose && n_input_centroids > 0) {
387
0
        printf("  Using %zd centroids provided as input (%sfrozen)\n",
388
0
               n_input_centroids,
389
0
               frozen_centroids ? "" : "not ");
390
0
    }
391
392
0
    double t_search_tot = 0;
393
0
    if (verbose) {
394
0
        printf("  Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
395
0
    }
396
0
    t0 = getmillisecs();
397
398
    // initialize seed
399
0
    const uint64_t actual_seed = get_actual_rng_seed(seed);
400
401
    // temporary buffer to decode vectors during the optimization
402
0
    std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
403
404
0
    for (int redo = 0; redo < nredo; redo++) {
405
0
        if (verbose && nredo > 1) {
406
0
            printf("Outer iteration %d / %d\n", redo, nredo);
407
0
        }
408
409
        // initialize (remaining) centroids with random points from the dataset
410
0
        centroids.resize(d * k);
411
0
        std::vector<int> perm(nx);
412
413
0
        rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
414
415
0
        if (!codec) {
416
0
            for (int i = n_input_centroids; i < k; i++) {
417
0
                memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
418
0
            }
419
0
        } else {
420
0
            for (int i = n_input_centroids; i < k; i++) {
421
0
                codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
422
0
            }
423
0
        }
424
425
0
        post_process_centroids();
426
427
        // prepare the index
428
429
0
        if (index.ntotal != 0) {
430
0
            index.reset();
431
0
        }
432
433
0
        if (!index.is_trained) {
434
0
            index.train(k, centroids.data());
435
0
        }
436
437
0
        index.add(k, centroids.data());
438
439
        // k-means iterations
440
441
0
        float obj = 0;
442
0
        for (int i = 0; i < niter; i++) {
443
0
            double t0s = getmillisecs();
444
445
0
            if (!codec) {
446
0
                index.search(
447
0
                        nx,
448
0
                        reinterpret_cast<const float*>(x),
449
0
                        1,
450
0
                        dis.get(),
451
0
                        assign.get());
452
0
            } else {
453
                // search by blocks of decode_block_size vectors
454
0
                size_t code_size = codec->sa_code_size();
455
0
                for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
456
0
                    size_t i1 = i0 + decode_block_size;
457
0
                    if (i1 > nx) {
458
0
                        i1 = nx;
459
0
                    }
460
0
                    codec->sa_decode(
461
0
                            i1 - i0, x + code_size * i0, decode_buffer.data());
462
0
                    index.search(
463
0
                            i1 - i0,
464
0
                            decode_buffer.data(),
465
0
                            1,
466
0
                            dis.get() + i0,
467
0
                            assign.get() + i0);
468
0
                }
469
0
            }
470
471
0
            InterruptCallback::check();
472
0
            t_search_tot += getmillisecs() - t0s;
473
474
            // accumulate objective
475
0
            obj = 0;
476
0
            for (int j = 0; j < nx; j++) {
477
0
                obj += dis[j];
478
0
            }
479
480
            // update the centroids
481
0
            std::vector<float> hassign(k);
482
483
0
            size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
484
0
            compute_centroids(
485
0
                    d,
486
0
                    k,
487
0
                    nx,
488
0
                    k_frozen,
489
0
                    x,
490
0
                    codec,
491
0
                    assign.get(),
492
0
                    weights,
493
0
                    hassign.data(),
494
0
                    centroids.data());
495
496
0
            int nsplit = split_clusters(
497
0
                    d, k, nx, k_frozen, hassign.data(), centroids.data());
498
499
            // collect statistics
500
0
            ClusteringIterationStats stats = {
501
0
                    obj,
502
0
                    (getmillisecs() - t0) / 1000.0,
503
0
                    t_search_tot / 1000,
504
0
                    imbalance_factor(nx, k, assign.get()),
505
0
                    nsplit};
506
0
            iteration_stats.push_back(stats);
507
508
0
            if (verbose) {
509
0
                printf("  Iteration %d (%.2f s, search %.2f s): "
510
0
                       "objective=%g imbalance=%.3f nsplit=%d       \r",
511
0
                       i,
512
0
                       stats.time,
513
0
                       stats.time_search,
514
0
                       stats.obj,
515
0
                       stats.imbalance_factor,
516
0
                       nsplit);
517
0
                fflush(stdout);
518
0
            }
519
520
0
            post_process_centroids();
521
522
            // add centroids to index for the next iteration (or for output)
523
524
0
            index.reset();
525
0
            if (update_index) {
526
0
                index.train(k, centroids.data());
527
0
            }
528
529
0
            index.add(k, centroids.data());
530
0
            InterruptCallback::check();
531
0
        }
532
533
0
        if (verbose)
534
0
            printf("\n");
535
0
        if (nredo > 1) {
536
0
            if ((lower_is_better && obj < best_obj) ||
537
0
                (!lower_is_better && obj > best_obj)) {
538
0
                if (verbose) {
539
0
                    printf("Objective improved: keep new clusters\n");
540
0
                }
541
0
                best_centroids = centroids;
542
0
                best_iteration_stats = iteration_stats;
543
0
                best_obj = obj;
544
0
            }
545
0
            index.reset();
546
0
        }
547
0
    }
548
0
    if (nredo > 1) {
549
0
        centroids = best_centroids;
550
0
        iteration_stats = best_iteration_stats;
551
0
        index.reset();
552
0
        index.add(k, best_centroids.data());
553
0
    }
554
0
}
555
556
0
/// Clustering of scalar values (d = 1) into k centroids, default parameters.
Clustering1D::Clustering1D(int k)
        : Clustering(1, k) {}
557
558
/// Clustering of scalar values (d = 1) into k centroids, parameters from cp.
Clustering1D::Clustering1D(int k, const ClusteringParameters& cp)
        : Clustering(1, k, cp) {}
560
561
0
/** Cluster n scalar values exactly with kmeans1d.
 *
 * If n exceeds k * max_points_per_centroid, the values are first randomly
 * subsampled with subsample_training_set. The k centroids are written to
 * `centroids`. One entry is appended to `iteration_stats`, with kmeans1d's
 * return value stored in its imbalance_factor field.
 */
void Clustering1D::train_exact(idx_t n, const float* x) {
    const float* xt = x;

    std::unique_ptr<uint8_t[]> del;
    if (n > k * max_points_per_centroid) {
        uint8_t* x_new;
        float* weights_new; // stays nullptr: no weights are passed in
        n = subsample_training_set(
                *this,
                n,
                reinterpret_cast<const uint8_t*>(x),
                sizeof(float) * d,
                nullptr,
                &x_new,
                &weights_new);
        del.reset(x_new);
        xt = reinterpret_cast<const float*>(x_new);
    }

    centroids.resize(k);
    double uf = kmeans1d(xt, n, k, centroids.data());

    ClusteringIterationStats stats = {0.0, 0.0, 0.0, uf, 0};
    iteration_stats.push_back(stats);
}
586
587
float kmeans_clustering(
588
        size_t d,
589
        size_t n,
590
        size_t k,
591
        const float* x,
592
0
        float* centroids) {
593
0
    Clustering clus(d, k);
594
0
    clus.verbose = d * n * k > (size_t(1) << 30);
595
    // display logs if > 1Gflop per iteration
596
0
    IndexFlatL2 index(d);
597
0
    clus.train(n, x, index);
598
0
    memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
599
0
    return clus.iteration_stats.back().obj;
600
0
}
601
602
/******************************************************************************
603
 * ProgressiveDimClustering implementation
604
 ******************************************************************************/
605
606
0
/// Defaults for progressive-dimension clustering: 10 dimension steps, PCA
/// applied to the data first, and 10 k-means iterations per step.
ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
    niter = 10;                 // reduce nb of iterations per step
    apply_pca = true;           // seems a good idea to do this by default
    progressive_dim_steps = 10;
}
611
612
0
/// Default factory: a new flat L2 index of dimension `dim`, heap-allocated;
/// the caller is responsible for deleting it.
Index* ProgressiveDimIndexFactory::operator()(int dim) {
    IndexFlatL2* flat_index = new IndexFlatL2(dim);
    return flat_index;
}
615
616
0
/// Progressive-dimension clustering of d-dim vectors into k centroids,
/// with default ProgressiveDimClusteringParameters.
ProgressiveDimClustering::ProgressiveDimClustering(int d, int k)
        : d(d),
          k(k) {}
617
618
/// Progressive-dimension clustering of d-dim vectors into k centroids,
/// with all parameters copied from `cp`.
ProgressiveDimClustering::ProgressiveDimClustering(
        int d,
        int k,
        const ProgressiveDimClusteringParameters& cp)
        : ProgressiveDimClusteringParameters(cp),
          d(d),
          k(k) {}
623
624
namespace {
625
626
0
/// Copy the first min(d1, d2) columns of an n-row matrix `src` (row stride
/// d1) into `dest` (row stride d2). When d2 > d1, the extra columns of each
/// `dest` row are not written.
void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
    const idx_t ncols = std::min(d1, d2);
    const size_t row_bytes = sizeof(float) * ncols;
    for (idx_t row = 0; row < n; ++row) {
        memcpy(dest + row * d2, src + row * d1, row_bytes);
    }
}
634
635
} // namespace
636
637
void ProgressiveDimClustering::train(
638
        idx_t n,
639
        const float* x,
640
0
        ProgressiveDimIndexFactory& factory) {
641
0
    int d_prev = 0;
642
643
0
    PCAMatrix pca(d, d);
644
645
0
    std::vector<float> xbuf;
646
0
    if (apply_pca) {
647
0
        if (verbose) {
648
0
            printf("Training PCA transform\n");
649
0
        }
650
0
        pca.train(n, x);
651
0
        if (verbose) {
652
0
            printf("Apply PCA\n");
653
0
        }
654
0
        xbuf.resize(n * d);
655
0
        pca.apply_noalloc(n, x, xbuf.data());
656
0
        x = xbuf.data();
657
0
    }
658
659
0
    for (int iter = 0; iter < progressive_dim_steps; iter++) {
660
0
        int di = int(pow(d, (1. + iter) / progressive_dim_steps));
661
0
        if (verbose) {
662
0
            printf("Progressive dim step %d: cluster in dimension %d\n",
663
0
                   iter,
664
0
                   di);
665
0
        }
666
0
        std::unique_ptr<Index> clustering_index(factory(di));
667
668
0
        Clustering clus(di, k, *this);
669
0
        if (d_prev > 0) {
670
            // copy warm-start centroids (padded with 0s)
671
0
            clus.centroids.resize(k * di);
672
0
            copy_columns(
673
0
                    k, d_prev, centroids.data(), di, clus.centroids.data());
674
0
        }
675
0
        std::vector<float> xsub(n * di);
676
0
        copy_columns(n, d, x, di, xsub.data());
677
678
0
        clus.train(n, xsub.data(), *clustering_index.get());
679
680
0
        centroids = clus.centroids;
681
0
        iteration_stats.insert(
682
0
                iteration_stats.end(),
683
0
                clus.iteration_stats.begin(),
684
0
                clus.iteration_stats.end());
685
686
0
        d_prev = di;
687
0
    }
688
689
0
    if (apply_pca) {
690
0
        if (verbose) {
691
0
            printf("Revert PCA transform on centroids\n");
692
0
        }
693
0
        std::vector<float> cent_transformed(d * k);
694
0
        pca.reverse_transform(k, centroids.data(), cent_transformed.data());
695
0
        cent_transformed.swap(centroids);
696
0
    }
697
0
}
698
699
} // namespace faiss