Coverage Report

Created: 2025-12-11 01:17

/root/doris/contrib/faiss/faiss/Clustering.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/Clustering.h>
11
#include <faiss/VectorTransform.h>
12
#include <faiss/impl/AuxIndexStructures.h>
13
14
#include <chrono>
15
#include <cinttypes>
16
#include <cmath>
17
#include <cstdio>
18
#include <cstring>
19
20
#include <omp.h>
21
22
#include <faiss/IndexFlat.h>
23
#include <faiss/impl/FaissAssert.h>
24
#include <faiss/impl/kmeans1d.h>
25
#include <faiss/utils/distances.h>
26
#include <faiss/utils/random.h>
27
#include <faiss/utils/utils.h>
28
29
namespace faiss {
30
31
0
Clustering::Clustering(int d, int k) : d(d), k(k) {}
32
33
Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
34
29
        : ClusteringParameters(cp), d(d), k(k) {}
35
36
297
void Clustering::post_process_centroids() {
37
297
    if (spherical) {
38
143
        fvec_renorm_L2(d, k, centroids.data());
39
143
    }
40
41
297
    if (int_centroids) {
42
0
        for (size_t i = 0; i < centroids.size(); i++)
43
0
            centroids[i] = roundf(centroids[i]);
44
0
    }
45
297
}
46
47
void Clustering::train(
48
        idx_t nx,
49
        const float* x_in,
50
        Index& index,
51
29
        const float* weights) {
52
29
    train_encoded(
53
29
            nx,
54
29
            reinterpret_cast<const uint8_t*>(x_in),
55
29
            nullptr,
56
29
            index,
57
29
            weights);
58
29
}
59
60
namespace {
61
62
27
uint64_t get_actual_rng_seed(const int seed) {
63
27
    return (seed >= 0)
64
27
            ? seed
65
27
            : static_cast<uint64_t>(std::chrono::high_resolution_clock::now()
66
0
                                            .time_since_epoch()
67
0
                                            .count());
68
27
}
69
70
idx_t subsample_training_set(
71
        const Clustering& clus,
72
        idx_t nx,
73
        const uint8_t* x,
74
        size_t line_size,
75
        const float* weights,
76
        uint8_t** x_out,
77
0
        float** weights_out) {
78
0
    if (clus.verbose) {
79
0
        printf("Sampling a subset of %zd / %" PRId64 " for training\n",
80
0
               clus.k * clus.max_points_per_centroid,
81
0
               nx);
82
0
    }
83
84
0
    const uint64_t actual_seed = get_actual_rng_seed(clus.seed);
85
86
0
    std::vector<int> perm;
87
0
    if (clus.use_faster_subsampling) {
88
        // use subsampling with splitmix64 rng
89
0
        SplitMix64RandomGenerator rng(actual_seed);
90
91
0
        const idx_t new_nx = clus.k * clus.max_points_per_centroid;
92
0
        perm.resize(new_nx);
93
0
        for (idx_t i = 0; i < new_nx; i++) {
94
0
            perm[i] = rng.rand_int(nx);
95
0
        }
96
0
    } else {
97
        // use subsampling with a default std rng
98
0
        perm.resize(nx);
99
0
        rand_perm(perm.data(), nx, actual_seed);
100
0
    }
101
102
0
    nx = clus.k * clus.max_points_per_centroid;
103
0
    uint8_t* x_new = new uint8_t[nx * line_size];
104
0
    *x_out = x_new;
105
106
    // might be worth omp-ing as well
107
0
    for (idx_t i = 0; i < nx; i++) {
108
0
        memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
109
0
    }
110
0
    if (weights) {
111
0
        float* weights_new = new float[nx];
112
0
        for (idx_t i = 0; i < nx; i++) {
113
0
            weights_new[i] = weights[perm[i]];
114
0
        }
115
0
        *weights_out = weights_new;
116
0
    } else {
117
0
        *weights_out = nullptr;
118
0
    }
119
0
    return nx;
120
0
}
121
122
/** compute centroids as (weighted) sum of training points
123
 *
124
 * @param x            training vectors, size n * code_size (from codec)
125
 * @param codec        how to decode the vectors (if NULL then cast to float*)
126
 * @param weights      per-training vector weight, size n (or NULL)
127
 * @param assign       nearest centroid for each training vector, size n
128
 * @param k_frozen     do not update the k_frozen first centroids
129
 * @param centroids    centroid vectors (output only), size k * d
130
 * @param hassign      histogram of assignments per centroid (size k),
131
 *                     should be 0 on input
132
 *
133
 */
134
135
void compute_centroids(
136
        size_t d,
137
        size_t k,
138
        size_t n,
139
        size_t k_frozen,
140
        const uint8_t* x,
141
        const Index* codec,
142
        const int64_t* assign,
143
        const float* weights,
144
        float* hassign,
145
270
        float* centroids) {
146
270
    k -= k_frozen;
147
270
    centroids += k_frozen * d;
148
149
270
    memset(centroids, 0, sizeof(*centroids) * d * k);
150
151
270
    size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
152
153
270
#pragma omp parallel
154
837
    {
155
837
        int nt = omp_get_num_threads();
156
837
        int rank = omp_get_thread_num();
157
158
        // this thread is taking care of centroids c0:c1
159
837
        size_t c0 = (k * rank) / nt;
160
837
        size_t c1 = (k * (rank + 1)) / nt;
161
837
        std::vector<float> decode_buffer(d);
162
163
135k
        for (size_t i = 0; i < n; i++) {
164
134k
            int64_t ci = assign[i];
165
134k
            assert(ci >= 0 && ci < k + k_frozen);
166
135k
            ci -= k_frozen;
167
135k
            if (ci >= c0 && ci < c1) {
168
34.6k
                float* c = centroids + ci * d;
169
34.6k
                const float* xi;
170
34.6k
                if (!codec) {
171
34.6k
                    xi = reinterpret_cast<const float*>(x + i * line_size);
172
18.4E
                } else {
173
18.4E
                    float* xif = decode_buffer.data();
174
18.4E
                    codec->sa_decode(1, x + i * line_size, xif);
175
18.4E
                    xi = xif;
176
18.4E
                }
177
34.6k
                if (weights) {
178
0
                    float w = weights[i];
179
0
                    hassign[ci] += w;
180
0
                    for (size_t j = 0; j < d; j++) {
181
0
                        c[j] += xi[j] * w;
182
0
                    }
183
34.6k
                } else {
184
34.6k
                    hassign[ci] += 1.0;
185
2.21M
                    for (size_t j = 0; j < d; j++) {
186
2.18M
                        c[j] += xi[j];
187
2.18M
                    }
188
34.6k
                }
189
34.6k
            }
190
135k
        }
191
837
    }
192
193
270
#pragma omp parallel for
194
1.41k
    for (idx_t ci = 0; ci < k; ci++) {
195
709
        if (hassign[ci] == 0) {
196
0
            continue;
197
0
        }
198
709
        float norm = 1 / hassign[ci];
199
709
        float* c = centroids + ci * d;
200
84.2k
        for (size_t j = 0; j < d; j++) {
201
83.5k
            c[j] *= norm;
202
83.5k
        }
203
709
    }
204
270
}
205
206
// a bit above machine epsilon for float16
207
0
#define EPS (1 / 1024.)
208
209
/** Handle empty clusters by splitting larger ones.
210
 *
211
 * It works by slightly changing the centroids to make 2 clusters from
212
 * a single one. Takes the same arguments as compute_centroids.
213
 *
214
 * @return           nb of splitting operations (larger is worse)
215
 */
216
int split_clusters(
217
        size_t d,
218
        size_t k,
219
        size_t n,
220
        size_t k_frozen,
221
        float* hassign,
222
270
        float* centroids) {
223
270
    k -= k_frozen;
224
270
    centroids += k_frozen * d;
225
226
    /* Take care of void clusters */
227
270
    size_t nsplit = 0;
228
270
    RandomGenerator rng(1234);
229
1.35k
    for (size_t ci = 0; ci < k; ci++) {
230
1.08k
        if (hassign[ci] == 0) { /* need to redefine a centroid */
231
0
            size_t cj;
232
0
            for (cj = 0; true; cj = (cj + 1) % k) {
233
                /* probability to pick this cluster for split */
234
0
                float p = (hassign[cj] - 1.0) / (float)(n - k);
235
0
                float r = rng.rand_float();
236
0
                if (r < p) {
237
0
                    break; /* found our cluster to be split */
238
0
                }
239
0
            }
240
0
            memcpy(centroids + ci * d,
241
0
                   centroids + cj * d,
242
0
                   sizeof(*centroids) * d);
243
244
            /* small symmetric perturbation */
245
0
            for (size_t j = 0; j < d; j++) {
246
0
                if (j % 2 == 0) {
247
0
                    centroids[ci * d + j] *= 1 + EPS;
248
0
                    centroids[cj * d + j] *= 1 - EPS;
249
0
                } else {
250
0
                    centroids[ci * d + j] *= 1 - EPS;
251
0
                    centroids[cj * d + j] *= 1 + EPS;
252
0
                }
253
0
            }
254
255
            /* assume even split of the cluster */
256
0
            hassign[ci] = hassign[cj] / 2;
257
0
            hassign[cj] -= hassign[ci];
258
0
            nsplit++;
259
0
        }
260
1.08k
    }
261
262
270
    return nsplit;
263
270
}
264
265
} // namespace
266
267
void Clustering::train_encoded(
268
        idx_t nx,
269
        const uint8_t* x_in,
270
        const Index* codec,
271
        Index& index,
272
29
        const float* weights) {
273
29
    FAISS_THROW_IF_NOT_FMT(
274
29
            nx >= k,
275
29
            "Number of training points (%" PRId64
276
29
            ") should be at least "
277
29
            "as large as number of clusters (%zd)",
278
29
            nx,
279
29
            k);
280
281
27
    FAISS_THROW_IF_NOT_FMT(
282
27
            (!codec || codec->d == d),
283
27
            "Codec dimension %d not the same as data dimension %d",
284
27
            int(codec->d),
285
27
            int(d));
286
287
27
    FAISS_THROW_IF_NOT_FMT(
288
27
            index.d == d,
289
27
            "Index dimension %d not the same as data dimension %d",
290
27
            int(index.d),
291
27
            int(d));
292
293
27
    double t0 = getmillisecs();
294
295
27
    if (!codec && check_input_data_for_NaNs) {
296
        // Check for NaNs in input data. Normally it is the user's
297
        // responsibility, but it may spare us some hard-to-debug
298
        // reports.
299
27
        const float* x = reinterpret_cast<const float*>(x_in);
300
618k
        for (size_t i = 0; i < nx * d; i++) {
301
618k
            FAISS_THROW_IF_NOT_MSG(
302
618k
                    std::isfinite(x[i]), "input contains NaN's or Inf's");
303
618k
        }
304
27
    }
305
306
27
    const uint8_t* x = x_in;
307
27
    std::unique_ptr<uint8_t[]> del1;
308
27
    std::unique_ptr<float[]> del3;
309
27
    size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
310
311
27
    if (nx > k * max_points_per_centroid) {
312
0
        uint8_t* x_new;
313
0
        float* weights_new;
314
0
        nx = subsample_training_set(
315
0
                *this, nx, x, line_size, weights, &x_new, &weights_new);
316
0
        del1.reset(x_new);
317
0
        x = x_new;
318
0
        del3.reset(weights_new);
319
0
        weights = weights_new;
320
27
    } else if (nx < k * min_points_per_centroid) {
321
20
        fprintf(stderr,
322
20
                "WARNING clustering %" PRId64
323
20
                " points to %zd centroids: "
324
20
                "please provide at least %" PRId64 " training points\n",
325
20
                nx,
326
20
                k,
327
20
                idx_t(k) * min_points_per_centroid);
328
20
    }
329
330
27
    if (nx == k) {
331
        // this is a corner case, just copy training set to clusters
332
0
        if (verbose) {
333
0
            printf("Number of training points (%" PRId64
334
0
                   ") same as number of "
335
0
                   "clusters, just copying\n",
336
0
                   nx);
337
0
        }
338
0
        centroids.resize(d * k);
339
0
        if (!codec) {
340
0
            memcpy(centroids.data(), x_in, sizeof(float) * d * k);
341
0
        } else {
342
0
            codec->sa_decode(nx, x_in, centroids.data());
343
0
        }
344
345
        // one fake iteration...
346
0
        ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
347
0
        iteration_stats.push_back(stats);
348
349
0
        index.reset();
350
0
        index.add(k, centroids.data());
351
0
        return;
352
0
    }
353
354
27
    if (verbose) {
355
0
        printf("Clustering %" PRId64
356
0
               " points in %zdD to %zd clusters, "
357
0
               "redo %d times, %d iterations\n",
358
0
               nx,
359
0
               d,
360
0
               k,
361
0
               nredo,
362
0
               niter);
363
0
        if (codec) {
364
0
            printf("Input data encoded in %zd bytes per vector\n",
365
0
                   codec->sa_code_size());
366
0
        }
367
0
    }
368
369
27
    std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
370
27
    std::unique_ptr<float[]> dis(new float[nx]);
371
372
    // remember best iteration for redo
373
27
    bool lower_is_better = !is_similarity_metric(index.metric_type);
374
27
    float best_obj = lower_is_better ? HUGE_VALF : -HUGE_VALF;
375
27
    std::vector<ClusteringIterationStats> best_iteration_stats;
376
27
    std::vector<float> best_centroids;
377
378
    // support input centroids
379
380
27
    FAISS_THROW_IF_NOT_MSG(
381
27
            centroids.size() % d == 0,
382
27
            "size of provided input centroids not a multiple of dimension");
383
384
27
    size_t n_input_centroids = centroids.size() / d;
385
386
27
    if (verbose && n_input_centroids > 0) {
387
0
        printf("  Using %zd centroids provided as input (%sfrozen)\n",
388
0
               n_input_centroids,
389
0
               frozen_centroids ? "" : "not ");
390
0
    }
391
392
27
    double t_search_tot = 0;
393
27
    if (verbose) {
394
0
        printf("  Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
395
0
    }
396
27
    t0 = getmillisecs();
397
398
    // initialize seed
399
27
    const uint64_t actual_seed = get_actual_rng_seed(seed);
400
401
    // temporary buffer to decode vectors during the optimization
402
27
    std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
403
404
54
    for (int redo = 0; redo < nredo; redo++) {
405
27
        if (verbose && nredo > 1) {
406
0
            printf("Outer iteration %d / %d\n", redo, nredo);
407
0
        }
408
409
        // initialize (remaining) centroids with random points from the dataset
410
27
        centroids.resize(d * k);
411
27
        std::vector<int> perm(nx);
412
413
27
        rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
414
415
27
        if (!codec) {
416
135
            for (int i = n_input_centroids; i < k; i++) {
417
108
                memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
418
108
            }
419
27
        } else {
420
0
            for (int i = n_input_centroids; i < k; i++) {
421
0
                codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
422
0
            }
423
0
        }
424
425
27
        post_process_centroids();
426
427
        // prepare the index
428
429
27
        if (index.ntotal != 0) {
430
0
            index.reset();
431
0
        }
432
433
27
        if (!index.is_trained) {
434
0
            index.train(k, centroids.data());
435
0
        }
436
437
27
        index.add(k, centroids.data());
438
439
        // k-means iterations
440
441
27
        float obj = 0;
442
297
        for (int i = 0; i < niter; i++) {
443
270
            double t0s = getmillisecs();
444
445
270
            if (!codec) {
446
270
                index.search(
447
270
                        nx,
448
270
                        reinterpret_cast<const float*>(x),
449
270
                        1,
450
270
                        dis.get(),
451
270
                        assign.get());
452
270
            } else {
453
                // search by blocks of decode_block_size vectors
454
0
                size_t code_size = codec->sa_code_size();
455
0
                for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
456
0
                    size_t i1 = i0 + decode_block_size;
457
0
                    if (i1 > nx) {
458
0
                        i1 = nx;
459
0
                    }
460
0
                    codec->sa_decode(
461
0
                            i1 - i0, x + code_size * i0, decode_buffer.data());
462
0
                    index.search(
463
0
                            i1 - i0,
464
0
                            decode_buffer.data(),
465
0
                            1,
466
0
                            dis.get() + i0,
467
0
                            assign.get() + i0);
468
0
                }
469
0
            }
470
471
270
            InterruptCallback::check();
472
270
            t_search_tot += getmillisecs() - t0s;
473
474
            // accumulate objective
475
270
            obj = 0;
476
38.7k
            for (int j = 0; j < nx; j++) {
477
38.4k
                obj += dis[j];
478
38.4k
            }
479
480
            // update the centroids
481
270
            std::vector<float> hassign(k);
482
483
270
            size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
484
270
            compute_centroids(
485
270
                    d,
486
270
                    k,
487
270
                    nx,
488
270
                    k_frozen,
489
270
                    x,
490
270
                    codec,
491
270
                    assign.get(),
492
270
                    weights,
493
270
                    hassign.data(),
494
270
                    centroids.data());
495
496
270
            int nsplit = split_clusters(
497
270
                    d, k, nx, k_frozen, hassign.data(), centroids.data());
498
499
            // collect statistics
500
270
            ClusteringIterationStats stats = {
501
270
                    obj,
502
270
                    (getmillisecs() - t0) / 1000.0,
503
270
                    t_search_tot / 1000,
504
270
                    imbalance_factor(nx, k, assign.get()),
505
270
                    nsplit};
506
270
            iteration_stats.push_back(stats);
507
508
270
            if (verbose) {
509
0
                printf("  Iteration %d (%.2f s, search %.2f s): "
510
0
                       "objective=%g imbalance=%.3f nsplit=%d       \r",
511
0
                       i,
512
0
                       stats.time,
513
0
                       stats.time_search,
514
0
                       stats.obj,
515
0
                       stats.imbalance_factor,
516
0
                       nsplit);
517
0
                fflush(stdout);
518
0
            }
519
520
270
            post_process_centroids();
521
522
            // add centroids to index for the next iteration (or for output)
523
524
270
            index.reset();
525
270
            if (update_index) {
526
0
                index.train(k, centroids.data());
527
0
            }
528
529
270
            index.add(k, centroids.data());
530
270
            InterruptCallback::check();
531
270
        }
532
533
27
        if (verbose)
534
0
            printf("\n");
535
27
        if (nredo > 1) {
536
0
            if ((lower_is_better && obj < best_obj) ||
537
0
                (!lower_is_better && obj > best_obj)) {
538
0
                if (verbose) {
539
0
                    printf("Objective improved: keep new clusters\n");
540
0
                }
541
0
                best_centroids = centroids;
542
0
                best_iteration_stats = iteration_stats;
543
0
                best_obj = obj;
544
0
            }
545
0
            index.reset();
546
0
        }
547
27
    }
548
27
    if (nredo > 1) {
549
0
        centroids = best_centroids;
550
0
        iteration_stats = best_iteration_stats;
551
0
        index.reset();
552
0
        index.add(k, best_centroids.data());
553
0
    }
554
27
}
555
556
0
Clustering1D::Clustering1D(int k) : Clustering(1, k) {}
557
558
Clustering1D::Clustering1D(int k, const ClusteringParameters& cp)
559
0
        : Clustering(1, k, cp) {}
560
561
0
void Clustering1D::train_exact(idx_t n, const float* x) {
562
0
    const float* xt = x;
563
564
0
    std::unique_ptr<uint8_t[]> del;
565
0
    if (n > k * max_points_per_centroid) {
566
0
        uint8_t* x_new;
567
0
        float* weights_new;
568
0
        n = subsample_training_set(
569
0
                *this,
570
0
                n,
571
0
                (uint8_t*)x,
572
0
                sizeof(float) * d,
573
0
                nullptr,
574
0
                &x_new,
575
0
                &weights_new);
576
0
        del.reset(x_new);
577
0
        xt = (float*)x_new;
578
0
    }
579
580
0
    centroids.resize(k);
581
0
    double uf = kmeans1d(xt, n, k, centroids.data());
582
583
0
    ClusteringIterationStats stats = {0.0, 0.0, 0.0, uf, 0};
584
0
    iteration_stats.push_back(stats);
585
0
}
586
587
float kmeans_clustering(
588
        size_t d,
589
        size_t n,
590
        size_t k,
591
        const float* x,
592
0
        float* centroids) {
593
0
    Clustering clus(d, k);
594
0
    clus.verbose = d * n * k > (size_t(1) << 30);
595
    // display logs if > 1Gflop per iteration
596
0
    IndexFlatL2 index(d);
597
0
    clus.train(n, x, index);
598
0
    memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
599
0
    return clus.iteration_stats.back().obj;
600
0
}
601
602
/******************************************************************************
603
 * ProgressiveDimClustering implementation
604
 ******************************************************************************/
605
606
0
ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
607
0
    progressive_dim_steps = 10;
608
0
    apply_pca = true; // seems a good idea to do this by default
609
0
    niter = 10;       // reduce nb of iterations per step
610
0
}
611
612
0
Index* ProgressiveDimIndexFactory::operator()(int dim) {
613
0
    return new IndexFlatL2(dim);
614
0
}
615
616
0
ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
617
618
ProgressiveDimClustering::ProgressiveDimClustering(
619
        int d,
620
        int k,
621
        const ProgressiveDimClusteringParameters& cp)
622
0
        : ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
623
624
namespace {
625
626
0
void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
627
0
    idx_t d = std::min(d1, d2);
628
0
    for (idx_t i = 0; i < n; i++) {
629
0
        memcpy(dest, src, sizeof(float) * d);
630
0
        src += d1;
631
0
        dest += d2;
632
0
    }
633
0
}
634
635
} // namespace
636
637
void ProgressiveDimClustering::train(
638
        idx_t n,
639
        const float* x,
640
0
        ProgressiveDimIndexFactory& factory) {
641
0
    int d_prev = 0;
642
643
0
    PCAMatrix pca(d, d);
644
645
0
    std::vector<float> xbuf;
646
0
    if (apply_pca) {
647
0
        if (verbose) {
648
0
            printf("Training PCA transform\n");
649
0
        }
650
0
        pca.train(n, x);
651
0
        if (verbose) {
652
0
            printf("Apply PCA\n");
653
0
        }
654
0
        xbuf.resize(n * d);
655
0
        pca.apply_noalloc(n, x, xbuf.data());
656
0
        x = xbuf.data();
657
0
    }
658
659
0
    for (int iter = 0; iter < progressive_dim_steps; iter++) {
660
0
        int di = int(pow(d, (1. + iter) / progressive_dim_steps));
661
0
        if (verbose) {
662
0
            printf("Progressive dim step %d: cluster in dimension %d\n",
663
0
                   iter,
664
0
                   di);
665
0
        }
666
0
        std::unique_ptr<Index> clustering_index(factory(di));
667
668
0
        Clustering clus(di, k, *this);
669
0
        if (d_prev > 0) {
670
            // copy warm-start centroids (padded with 0s)
671
0
            clus.centroids.resize(k * di);
672
0
            copy_columns(
673
0
                    k, d_prev, centroids.data(), di, clus.centroids.data());
674
0
        }
675
0
        std::vector<float> xsub(n * di);
676
0
        copy_columns(n, d, x, di, xsub.data());
677
678
0
        clus.train(n, xsub.data(), *clustering_index.get());
679
680
0
        centroids = clus.centroids;
681
0
        iteration_stats.insert(
682
0
                iteration_stats.end(),
683
0
                clus.iteration_stats.begin(),
684
0
                clus.iteration_stats.end());
685
686
0
        d_prev = di;
687
0
    }
688
689
0
    if (apply_pca) {
690
0
        if (verbose) {
691
0
            printf("Revert PCA transform on centroids\n");
692
0
        }
693
0
        std::vector<float> cent_transformed(d * k);
694
0
        pca.reverse_transform(k, centroids.data(), cent_transformed.data());
695
0
        cent_transformed.swap(centroids);
696
0
    }
697
0
}
698
699
} // namespace faiss
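
For context on the counts above: the non-zero paths correspond to Clustering::train called without a codec and without weights (train_encoded with codec == nullptr). A minimal sketch of that covered path follows; it assumes a standard Faiss build with IndexFlat available, and the data, dimensions and parameter values are illustrative only, not taken from the report.

// Sketch: exercise the covered k-means path (no codec, no weights).
#include <faiss/Clustering.h>
#include <faiss/IndexFlat.h>

#include <random>
#include <vector>

int main() {
    const int d = 64;       // vector dimension (illustrative)
    const int k = 16;       // number of clusters (illustrative)
    const size_t nx = 4096; // number of training points (illustrative)

    // synthetic row-major training data, nx vectors of dimension d
    std::vector<float> x(nx * d);
    std::mt19937 gen(123);
    std::uniform_real_distribution<float> dist(0.f, 1.f);
    for (auto& v : x) {
        v = dist(gen);
    }

    faiss::Clustering clus(d, k);
    clus.niter = 10; // k-means iterations per redo

    // assignment index used during the iterations; IndexFlatL2 keeps
    // codec == nullptr, i.e. the branch with non-zero counts above
    faiss::IndexFlatL2 index(d);
    clus.train(nx, x.data(), index);

    // centroids are now in clus.centroids, size k * d
    return clus.centroids.size() == size_t(k) * d ? 0 : 1;
}

The weighted and codec-decoding branches (the rows with zero counts) would be reached by passing a non-null weights array to train(), or by calling train_encoded() with a codec index, respectively.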