Coverage Report

Created: 2025-10-28 13:31

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/ResidualQuantizer.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/impl/ResidualQuantizer.h>
9
10
#include <algorithm>
11
#include <cmath>
12
#include <cstddef>
13
#include <cstdio>
14
#include <cstring>
15
#include <memory>
16
17
#include <faiss/IndexFlat.h>
18
#include <faiss/VectorTransform.h>
19
#include <faiss/impl/FaissAssert.h>
20
#include <faiss/impl/residual_quantizer_encode_steps.h>
21
#include <faiss/utils/distances.h>
22
#include <faiss/utils/hamming.h>
23
#include <faiss/utils/utils.h>
24
25
extern "C" {
26
27
// general matrix multiplication
28
int sgemm_(
29
        const char* transa,
30
        const char* transb,
31
        FINTEGER* m,
32
        FINTEGER* n,
33
        FINTEGER* k,
34
        const float* alpha,
35
        const float* a,
36
        FINTEGER* lda,
37
        const float* b,
38
        FINTEGER* ldb,
39
        float* beta,
40
        float* c,
41
        FINTEGER* ldc);
42
43
// http://www.netlib.org/clapack/old/single/sgels.c
44
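// solve linear least squares; note that the routine declared below is
// sgelsd (SVD-based, minimum-norm solution), not the sgels linked above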
45
46
int sgelsd_(
47
        FINTEGER* m,
48
        FINTEGER* n,
49
        FINTEGER* nrhs,
50
        float* a,
51
        FINTEGER* lda,
52
        float* b,
53
        FINTEGER* ldb,
54
        float* s,
55
        float* rcond,
56
        FINTEGER* rank,
57
        float* work,
58
        FINTEGER* lwork,
59
        FINTEGER* iwork,
60
        FINTEGER* info);
61
}
62
63
namespace faiss {
64
65
0
/// Default constructor: sets the dimension and the number of quantization
/// steps to zero and turns verbose output off.
ResidualQuantizer::ResidualQuantizer() {
    this->d = 0;
    this->M = 0;
    this->verbose = false;
}
70
71
/// Build a quantizer for d-dimensional vectors with one quantization step per
/// entry of nbits.
///
/// @param d            dimension of the input vectors
/// @param nbits        per-step bit counts; its length becomes M
/// @param search_type  stored as-is in the search_type member
ResidualQuantizer::ResidualQuantizer(
        size_t d,
        const std::vector<size_t>& nbits,
        Search_type_t search_type)
        : ResidualQuantizer() {
    // the parameters shadow the members, hence the explicit this->
    this->d = d;
    this->nbits = nbits;
    this->M = nbits.size();
    this->search_type = search_type;

    // recompute everything that depends on d / M / nbits
    set_derived_values();
}
82
83
/// Convenience constructor: M quantization steps that all use the same number
/// of bits. Expands nbits into a vector of M identical entries and delegates
/// to the per-step constructor.
ResidualQuantizer::ResidualQuantizer(
        size_t d,
        size_t M,
        size_t nbits,
        Search_type_t search_type)
        : ResidualQuantizer(
                  d,
                  std::vector<size_t>(M, nbits),
                  search_type) {
    // nothing left to do: the delegated constructor performs all the setup
}
89
90
void ResidualQuantizer::initialize_from(
91
        const ResidualQuantizer& other,
92
0
        int skip_M) {
93
0
    FAISS_THROW_IF_NOT(M + skip_M <= other.M);
94
0
    FAISS_THROW_IF_NOT(skip_M >= 0);
95
96
0
    Search_type_t this_search_type = search_type;
97
0
    int this_M = M;
98
99
    // a first good approximation: override everything
100
0
    *this = other;
101
102
    // adjust derived values
103
0
    M = this_M;
104
0
    search_type = this_search_type;
105
0
    nbits.resize(M);
106
0
    memcpy(nbits.data(),
107
0
           other.nbits.data() + skip_M,
108
0
           nbits.size() * sizeof(nbits[0]));
109
110
0
    set_derived_values();
111
112
    // resize codebooks if trained
113
0
    if (codebooks.size() > 0) {
114
0
        FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
115
0
        codebooks.resize(total_codebook_size * d);
116
0
        memcpy(codebooks.data(),
117
0
               other.codebooks.data() + other.codebook_offsets[skip_M] * d,
118
0
               codebooks.size() * sizeof(codebooks[0]));
119
        // TODO: norm_tabs?
120
0
    }
121
0
}
122
123
/****************************************************************
124
 * Training
125
 ****************************************************************/
126
127
0
void ResidualQuantizer::train(size_t n, const float* x) {
128
0
    codebooks.resize(d * codebook_offsets.back());
129
130
0
    if (verbose) {
131
0
        printf("Training ResidualQuantizer, with %zd steps on %zd %zdD vectors\n",
132
0
               M,
133
0
               n,
134
0
               size_t(d));
135
0
    }
136
137
0
    int cur_beam_size = 1;
138
0
    std::vector<float> residuals(x, x + n * d);
139
0
    std::vector<int32_t> codes;
140
0
    std::vector<float> distances;
141
0
    double t0 = getmillisecs();
142
0
    double clustering_time = 0;
143
144
0
    for (int m = 0; m < M; m++) {
145
0
        int K = 1 << nbits[m];
146
147
        // on which residuals to train
148
0
        std::vector<float>& train_residuals = residuals;
149
0
        std::vector<float> residuals1;
150
0
        if (train_type & Train_top_beam) {
151
0
            residuals1.resize(n * d);
152
0
            for (size_t j = 0; j < n; j++) {
153
0
                memcpy(residuals1.data() + j * d,
154
0
                       residuals.data() + j * d * cur_beam_size,
155
0
                       sizeof(residuals[0]) * d);
156
0
            }
157
0
            train_residuals = residuals1;
158
0
        }
159
0
        std::vector<float> codebooks;
160
0
        float obj = 0;
161
162
0
        std::unique_ptr<Index> assign_index;
163
0
        if (assign_index_factory) {
164
0
            assign_index.reset((*assign_index_factory)(d));
165
0
        } else {
166
0
            assign_index.reset(new IndexFlatL2(d));
167
0
        }
168
169
0
        double t1 = getmillisecs();
170
171
0
        if (!(train_type & Train_progressive_dim)) { // regular kmeans
172
0
            Clustering clus(d, K, cp);
173
0
            clus.train(
174
0
                    train_residuals.size() / d,
175
0
                    train_residuals.data(),
176
0
                    *assign_index.get());
177
0
            codebooks.swap(clus.centroids);
178
0
            assign_index->reset();
179
0
            obj = clus.iteration_stats.back().obj;
180
0
        } else { // progressive dim clustering
181
0
            ProgressiveDimClustering clus(d, K, cp);
182
0
            ProgressiveDimIndexFactory default_fac;
183
0
            clus.train(
184
0
                    train_residuals.size() / d,
185
0
                    train_residuals.data(),
186
0
                    assign_index_factory ? *assign_index_factory : default_fac);
187
0
            codebooks.swap(clus.centroids);
188
0
            obj = clus.iteration_stats.back().obj;
189
0
        }
190
0
        clustering_time += (getmillisecs() - t1) / 1000;
191
192
0
        memcpy(this->codebooks.data() + codebook_offsets[m] * d,
193
0
               codebooks.data(),
194
0
               codebooks.size() * sizeof(codebooks[0]));
195
196
        // quantize using the new codebooks
197
198
0
        int new_beam_size = std::min(cur_beam_size * K, max_beam_size);
199
0
        std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
200
0
        std::vector<float> new_residuals(n * new_beam_size * d);
201
0
        std::vector<float> new_distances(n * new_beam_size);
202
203
0
        size_t bs;
204
0
        { // determine batch size
205
0
            size_t mem = memory_per_point();
206
0
            if (n > 1 && mem * n > max_mem_distances) {
207
                // then split queries to reduce temp memory
208
0
                bs = std::max(max_mem_distances / mem, size_t(1));
209
0
            } else {
210
0
                bs = n;
211
0
            }
212
0
        }
213
214
0
        for (size_t i0 = 0; i0 < n; i0 += bs) {
215
0
            size_t i1 = std::min(i0 + bs, n);
216
217
            /* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n",
218
                i0, i1, K, assign_index->ntotal); */
219
220
0
            beam_search_encode_step(
221
0
                    d,
222
0
                    K,
223
0
                    codebooks.data(),
224
0
                    i1 - i0,
225
0
                    cur_beam_size,
226
0
                    residuals.data() + i0 * cur_beam_size * d,
227
0
                    m,
228
0
                    codes.data() + i0 * cur_beam_size * m,
229
0
                    new_beam_size,
230
0
                    new_codes.data() + i0 * new_beam_size * (m + 1),
231
0
                    new_residuals.data() + i0 * new_beam_size * d,
232
0
                    new_distances.data() + i0 * new_beam_size,
233
0
                    assign_index.get(),
234
0
                    approx_topk_mode);
235
0
        }
236
0
        codes.swap(new_codes);
237
0
        residuals.swap(new_residuals);
238
0
        distances.swap(new_distances);
239
240
0
        float sum_distances = 0;
241
0
        for (int j = 0; j < distances.size(); j++) {
242
0
            sum_distances += distances[j];
243
0
        }
244
245
0
        if (verbose) {
246
0
            printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, "
247
0
                   "total distance %g, beam_size %d->%d (batch size %zd)\n",
248
0
                   (getmillisecs() - t0) / 1000,
249
0
                   clustering_time,
250
0
                   m,
251
0
                   int(nbits[m]),
252
0
                   obj,
253
0
                   sum_distances,
254
0
                   cur_beam_size,
255
0
                   new_beam_size,
256
0
                   bs);
257
0
        }
258
0
        cur_beam_size = new_beam_size;
259
0
    }
260
261
0
    is_trained = true;
262
263
0
    if (train_type & Train_refine_codebook) {
264
0
        for (int iter = 0; iter < niter_codebook_refine; iter++) {
265
0
            if (verbose) {
266
0
                printf("re-estimating the codebooks to minimize "
267
0
                       "quantization errors (iter %d).\n",
268
0
                       iter);
269
0
            }
270
0
            retrain_AQ_codebook(n, x);
271
0
        }
272
0
    }
273
274
    // find min and max norms
275
0
    std::vector<float> norms(n);
276
277
0
    for (size_t i = 0; i < n; i++) {
278
0
        norms[i] = fvec_L2sqr(
279
0
                x + i * d, residuals.data() + i * cur_beam_size * d, d);
280
0
    }
281
282
    // fvec_norms_L2sqr(norms.data(), x, d, n);
283
0
    train_norm(n, norms.data());
284
285
0
    if (!(train_type & Skip_codebook_tables)) {
286
0
        compute_codebook_tables();
287
0
    }
288
0
}
289
290
0
float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
291
0
    FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points");
292
293
0
    if (verbose) {
294
0
        printf("  encoding %zd training vectors\n", n);
295
0
    }
296
0
    std::vector<uint8_t> codes(n * code_size);
297
0
    compute_codes(x, codes.data(), n);
298
299
    // compute reconstruction error
300
0
    float input_recons_error;
301
0
    {
302
0
        std::vector<float> x_recons(n * d);
303
0
        decode(codes.data(), x_recons.data(), n);
304
0
        input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d);
305
0
        if (verbose) {
306
0
            printf("  input quantization error %g\n", input_recons_error);
307
0
        }
308
0
    }
309
310
    // build matrix of the linear system
311
0
    std::vector<float> C(n * total_codebook_size);
312
0
    for (size_t i = 0; i < n; i++) {
313
0
        BitstringReader bsr(codes.data() + i * code_size, code_size);
314
0
        for (int m = 0; m < M; m++) {
315
0
            int idx = bsr.read(nbits[m]);
316
0
            C[i + (codebook_offsets[m] + idx) * n] = 1;
317
0
        }
318
0
    }
319
320
    // transpose training vectors
321
0
    std::vector<float> xt(n * d);
322
323
0
    for (size_t i = 0; i < n; i++) {
324
0
        for (size_t j = 0; j < d; j++) {
325
0
            xt[j * n + i] = x[i * d + j];
326
0
        }
327
0
    }
328
329
0
    { // solve least squares
330
0
        FINTEGER lwork = -1;
331
0
        FINTEGER di = d, ni = n, tcsi = total_codebook_size;
332
0
        FINTEGER info = -1, rank = -1;
333
334
0
        float rcond = 1e-4; // this is an important parameter because the code
335
                            // matrix can be rank deficient for small problems,
336
                            // the default rcond=-1 does not work
337
0
        float worksize;
338
0
        std::vector<float> sing_vals(total_codebook_size);
339
0
        FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an
340
                              // upper bound
341
0
        std::vector<FINTEGER> iwork(
342
0
                3 * total_codebook_size * nlvl + 11 * total_codebook_size);
343
344
        // worksize query
345
0
        sgelsd_(&ni,
346
0
                &tcsi,
347
0
                &di,
348
0
                C.data(),
349
0
                &ni,
350
0
                xt.data(),
351
0
                &ni,
352
0
                sing_vals.data(),
353
0
                &rcond,
354
0
                &rank,
355
0
                &worksize,
356
0
                &lwork,
357
0
                iwork.data(),
358
0
                &info);
359
0
        FAISS_THROW_IF_NOT(info == 0);
360
361
0
        lwork = worksize;
362
0
        std::vector<float> work(lwork);
363
        // actual call
364
0
        sgelsd_(&ni,
365
0
                &tcsi,
366
0
                &di,
367
0
                C.data(),
368
0
                &ni,
369
0
                xt.data(),
370
0
                &ni,
371
0
                sing_vals.data(),
372
0
                &rcond,
373
0
                &rank,
374
0
                work.data(),
375
0
                &lwork,
376
0
                iwork.data(),
377
0
                &info);
378
0
        FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info));
379
0
        if (verbose) {
380
0
            printf("   sgelsd rank=%d/%d\n",
381
0
                   int(rank),
382
0
                   int(total_codebook_size));
383
0
        }
384
0
    }
385
386
    // result is in xt, re-transpose to codebook
387
388
0
    for (size_t i = 0; i < total_codebook_size; i++) {
389
0
        for (size_t j = 0; j < d; j++) {
390
0
            codebooks[i * d + j] = xt[j * n + i];
391
0
            FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j]));
392
0
        }
393
0
    }
394
395
0
    float output_recons_error = 0;
396
0
    for (size_t j = 0; j < d; j++) {
397
0
        output_recons_error += fvec_norm_L2sqr(
398
0
                xt.data() + total_codebook_size + n * j,
399
0
                n - total_codebook_size);
400
0
    }
401
0
    if (verbose) {
402
0
        printf("  output quantization error %g\n", output_recons_error);
403
0
    }
404
0
    return output_recons_error;
405
0
}
406
407
0
size_t ResidualQuantizer::memory_per_point(int beam_size) const {
408
0
    if (beam_size < 0) {
409
0
        beam_size = max_beam_size;
410
0
    }
411
0
    size_t mem;
412
0
    mem = beam_size * d * 2 * sizeof(float); // size for 2 beams at a time
413
0
    mem += beam_size * beam_size *
414
0
            (sizeof(float) + sizeof(idx_t)); // size for 1 beam search result
415
0
    return mem;
416
0
}
417
418
/****************************************************************
419
 * Encoding
420
 ****************************************************************/
421
422
using namespace rq_encode_steps;
423
424
void ResidualQuantizer::compute_codes_add_centroids(
425
        const float* x,
426
        uint8_t* codes_out,
427
        size_t n,
428
0
        const float* centroids) const {
429
0
    FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
430
431
    //
432
0
    size_t mem = memory_per_point();
433
434
0
    size_t bs = max_mem_distances / mem;
435
0
    if (bs == 0) {
436
0
        bs = 1; // otherwise we can't do much
437
0
    }
438
439
    // prepare memory pools
440
0
    ComputeCodesAddCentroidsLUT0MemoryPool pool0;
441
0
    ComputeCodesAddCentroidsLUT1MemoryPool pool1;
442
443
0
    for (size_t i0 = 0; i0 < n; i0 += bs) {
444
0
        size_t i1 = std::min(n, i0 + bs);
445
0
        const float* cent = nullptr;
446
0
        if (centroids != nullptr) {
447
0
            cent = centroids + i0 * d;
448
0
        }
449
450
0
        if (use_beam_LUT == 0) {
451
0
            compute_codes_add_centroids_mp_lut0(
452
0
                    *this,
453
0
                    x + i0 * d,
454
0
                    codes_out + i0 * code_size,
455
0
                    i1 - i0,
456
0
                    cent,
457
0
                    pool0);
458
0
        } else if (use_beam_LUT == 1) {
459
0
            compute_codes_add_centroids_mp_lut1(
460
0
                    *this,
461
0
                    x + i0 * d,
462
0
                    codes_out + i0 * code_size,
463
0
                    i1 - i0,
464
0
                    cent,
465
0
                    pool1);
466
0
        }
467
0
    }
468
0
}
469
470
void ResidualQuantizer::refine_beam(
471
        size_t n,
472
        size_t beam_size,
473
        const float* x,
474
        int out_beam_size,
475
        int32_t* out_codes,
476
        float* out_residuals,
477
0
        float* out_distances) const {
478
0
    RefineBeamMemoryPool pool;
479
0
    refine_beam_mp(
480
0
            *this,
481
0
            n,
482
0
            beam_size,
483
0
            x,
484
0
            out_beam_size,
485
0
            out_codes,
486
0
            out_residuals,
487
0
            out_distances,
488
0
            pool);
489
0
}
490
491
/*******************************************************************
492
 * Functions using the dot products between codebook entries
493
 *******************************************************************/
494
495
void ResidualQuantizer::refine_beam_LUT(
496
        size_t n,
497
        const float* query_norms, // size n
498
        const float* query_cp,    //
499
        int out_beam_size,
500
        int32_t* out_codes,
501
0
        float* out_distances) const {
502
0
    RefineBeamLUTMemoryPool pool;
503
0
    refine_beam_LUT_mp(
504
0
            *this,
505
0
            n,
506
0
            query_norms,
507
0
            query_cp,
508
0
            out_beam_size,
509
0
            out_codes,
510
0
            out_distances,
511
0
            pool);
512
0
}
513
514
} // namespace faiss