Coverage Report

Created: 2026-03-14 03:07

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/impl/residual_quantizer_encode_steps.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/impl/residual_quantizer_encode_steps.h>
9
10
#include <faiss/impl/AuxIndexStructures.h>
11
#include <faiss/impl/FaissAssert.h>
12
#include <faiss/impl/ResidualQuantizer.h>
13
#include <faiss/utils/Heap.h>
14
#include <faiss/utils/distances.h>
15
#include <faiss/utils/simdlib.h>
16
#include <faiss/utils/utils.h>
17
18
#include <faiss/utils/approx_topk/approx_topk.h>
19
20
extern "C" {
21
22
// general matrix multiplication
23
int sgemm_(
24
        const char* transa,
25
        const char* transb,
26
        FINTEGER* m,
27
        FINTEGER* n,
28
        FINTEGER* k,
29
        const float* alpha,
30
        const float* a,
31
        FINTEGER* lda,
32
        const float* b,
33
        FINTEGER* ldb,
34
        float* beta,
35
        float* c,
36
        FINTEGER* ldc);
37
}
38
39
namespace faiss {
40
41
/********************************************************************
42
 * Basic routines
43
 ********************************************************************/
44
45
namespace {
46
47
template <size_t M, size_t NK>
48
void accum_and_store_tab(
49
        const size_t m_offset,
50
        const float* const __restrict codebook_cross_norms,
51
        const uint64_t* const __restrict codebook_offsets,
52
        const int32_t* const __restrict codes_i,
53
        const size_t b,
54
        const size_t ldc,
55
        const size_t K,
56
0
        float* const __restrict output) {
57
    // load pointers into registers
58
0
    const float* cbs[M];
59
0
    for (size_t ij = 0; ij < M; ij++) {
60
0
        const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
61
0
        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
62
0
    }
63
64
    // do accumulation in registers using SIMD.
65
    // It is possible that compiler may be smart enough so that
66
    //   this manual SIMD unrolling might be unneeded.
67
0
#if defined(__AVX2__) || defined(__aarch64__)
68
0
    const size_t K8 = (K / (8 * NK)) * (8 * NK);
69
70
    // process in chunks of size (8 * NK) floats
71
0
    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
72
0
        simd8float32 regs[NK];
73
0
        for (size_t ik = 0; ik < NK; ik++) {
74
0
            regs[ik].loadu(cbs[0] + kk + ik * 8);
75
0
        }
76
77
0
        for (size_t ij = 1; ij < M; ij++) {
78
0
            for (size_t ik = 0; ik < NK; ik++) {
79
0
                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
80
0
            }
81
0
        }
82
83
        // write the result
84
0
        for (size_t ik = 0; ik < NK; ik++) {
85
0
            regs[ik].storeu(output + kk + ik * 8);
86
0
        }
87
0
    }
88
#else
89
    const size_t K8 = 0;
90
#endif
91
92
    // process leftovers
93
0
    for (size_t kk = K8; kk < K; kk++) {
94
0
        float reg = cbs[0][kk];
95
0
        for (size_t ij = 1; ij < M; ij++) {
96
0
            reg += cbs[ij][kk];
97
0
        }
98
0
        output[kk] = reg;
99
0
    }
100
0
}
101
102
template <size_t M, size_t NK>
103
void accum_and_add_tab(
104
        const size_t m_offset,
105
        const float* const __restrict codebook_cross_norms,
106
        const uint64_t* const __restrict codebook_offsets,
107
        const int32_t* const __restrict codes_i,
108
        const size_t b,
109
        const size_t ldc,
110
        const size_t K,
111
0
        float* const __restrict output) {
112
    // load pointers into registers
113
0
    const float* cbs[M];
114
0
    for (size_t ij = 0; ij < M; ij++) {
115
0
        const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
116
0
        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
117
0
    }
118
119
    // do accumulation in registers using SIMD.
120
    // It is possible that compiler may be smart enough so that
121
    //   this manual SIMD unrolling might be unneeded.
122
0
#if defined(__AVX2__) || defined(__aarch64__)
123
0
    const size_t K8 = (K / (8 * NK)) * (8 * NK);
124
125
    // process in chunks of size (8 * NK) floats
126
0
    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
127
0
        simd8float32 regs[NK];
128
0
        for (size_t ik = 0; ik < NK; ik++) {
129
0
            regs[ik].loadu(cbs[0] + kk + ik * 8);
130
0
        }
131
132
0
        for (size_t ij = 1; ij < M; ij++) {
133
0
            for (size_t ik = 0; ik < NK; ik++) {
134
0
                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
135
0
            }
136
0
        }
137
138
        // write the result
139
0
        for (size_t ik = 0; ik < NK; ik++) {
140
0
            simd8float32 existing(output + kk + ik * 8);
141
0
            existing += regs[ik];
142
0
            existing.storeu(output + kk + ik * 8);
143
0
        }
144
0
    }
145
#else
146
    const size_t K8 = 0;
147
#endif
148
149
    // process leftovers
150
0
    for (size_t kk = K8; kk < K; kk++) {
151
0
        float reg = cbs[0][kk];
152
0
        for (size_t ij = 1; ij < M; ij++) {
153
0
            reg += cbs[ij][kk];
154
0
        }
155
0
        output[kk] += reg;
156
0
    }
157
0
}
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_117accum_and_add_tabILm1ELm4EEEvmPKfPKmPKimmmPf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_117accum_and_add_tabILm2ELm4EEEvmPKfPKmPKimmmPf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_117accum_and_add_tabILm3ELm4EEEvmPKfPKmPKimmmPf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_117accum_and_add_tabILm4ELm4EEEvmPKfPKmPKimmmPf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_117accum_and_add_tabILm5ELm4EEEvmPKfPKmPKimmmPf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_117accum_and_add_tabILm6ELm4EEEvmPKfPKmPKimmmPf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_117accum_and_add_tabILm7ELm4EEEvmPKfPKmPKimmmPf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_117accum_and_add_tabILm8ELm4EEEvmPKfPKmPKimmmPf
158
159
template <size_t M, size_t NK>
160
void accum_and_finalize_tab(
161
        const float* const __restrict codebook_cross_norms,
162
        const uint64_t* const __restrict codebook_offsets,
163
        const int32_t* const __restrict codes_i,
164
        const size_t b,
165
        const size_t ldc,
166
        const size_t K,
167
        const float* const __restrict distances_i,
168
        const float* const __restrict cd_common,
169
0
        float* const __restrict output) {
170
    // load pointers into registers
171
0
    const float* cbs[M];
172
0
    for (size_t ij = 0; ij < M; ij++) {
173
0
        const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
174
0
        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
175
0
    }
176
177
    // do accumulation in registers using SIMD.
178
    // It is possible that compiler may be smart enough so that
179
    //   this manual SIMD unrolling might be unneeded.
180
0
#if defined(__AVX2__) || defined(__aarch64__)
181
0
    const size_t K8 = (K / (8 * NK)) * (8 * NK);
182
183
    // process in chunks of size (8 * NK) floats
184
0
    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
185
0
        simd8float32 regs[NK];
186
0
        for (size_t ik = 0; ik < NK; ik++) {
187
0
            regs[ik].loadu(cbs[0] + kk + ik * 8);
188
0
        }
189
190
0
        for (size_t ij = 1; ij < M; ij++) {
191
0
            for (size_t ik = 0; ik < NK; ik++) {
192
0
                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
193
0
            }
194
0
        }
195
196
0
        simd8float32 two(2.0f);
197
0
        for (size_t ik = 0; ik < NK; ik++) {
198
            // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
199
            //     + 2 * dp[k];
200
201
0
            simd8float32 common_v(cd_common + kk + ik * 8);
202
0
            common_v = fmadd(two, regs[ik], common_v);
203
204
0
            common_v += simd8float32(distances_i[b]);
205
0
            common_v.storeu(output + b * K + kk + ik * 8);
206
0
        }
207
0
    }
208
#else
209
    const size_t K8 = 0;
210
#endif
211
212
    // process leftovers
213
0
    for (size_t kk = K8; kk < K; kk++) {
214
0
        float reg = cbs[0][kk];
215
0
        for (size_t ij = 1; ij < M; ij++) {
216
0
            reg += cbs[ij][kk];
217
0
        }
218
219
0
        output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
220
0
    }
221
0
}
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_122accum_and_finalize_tabILm1ELm4EEEvPKfPKmPKimmmS3_S3_Pf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_122accum_and_finalize_tabILm2ELm4EEEvPKfPKmPKimmmS3_S3_Pf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_122accum_and_finalize_tabILm3ELm4EEEvPKfPKmPKimmmS3_S3_Pf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_122accum_and_finalize_tabILm4ELm4EEEvPKfPKmPKimmmS3_S3_Pf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_122accum_and_finalize_tabILm5ELm4EEEvPKfPKmPKimmmS3_S3_Pf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_122accum_and_finalize_tabILm6ELm4EEEvPKfPKmPKimmmS3_S3_Pf
Unexecuted instantiation: residual_quantizer_encode_steps.cpp:_ZN5faiss12_GLOBAL__N_122accum_and_finalize_tabILm7ELm4EEEvPKfPKmPKimmmS3_S3_Pf
222
223
} // anonymous namespace
224
225
/********************************************************************
226
 * Single encoding step
227
 ********************************************************************/
228
229
void beam_search_encode_step(
230
        size_t d,
231
        size_t K,
232
        const float* cent, /// size (K, d)
233
        size_t n,
234
        size_t beam_size,
235
        const float* residuals, /// size (n, beam_size, d)
236
        size_t m,
237
        const int32_t* codes, /// size (n, beam_size, m)
238
        size_t new_beam_size,
239
        int32_t* new_codes,   /// size (n, new_beam_size, m + 1)
240
        float* new_residuals, /// size (n, new_beam_size, d)
241
        float* new_distances, /// size (n, new_beam_size)
242
        Index* assign_index,
243
0
        ApproxTopK_mode_t approx_topk_mode) {
244
    // we have to fill in the whole output matrix
245
0
    FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
246
247
0
    std::vector<float> cent_distances;
248
0
    std::vector<idx_t> cent_ids;
249
250
0
    if (assign_index) {
251
        // search beam_size distances per query
252
0
        FAISS_THROW_IF_NOT(assign_index->d == d);
253
0
        cent_distances.resize(n * beam_size * new_beam_size);
254
0
        cent_ids.resize(n * beam_size * new_beam_size);
255
0
        if (assign_index->ntotal != 0) {
256
            // then we assume the codebooks are already added to the index
257
0
            FAISS_THROW_IF_NOT(assign_index->ntotal == K);
258
0
        } else {
259
0
            assign_index->add(K, cent);
260
0
        }
261
262
        // printf("beam_search_encode_step -- mem usage %zd\n",
263
        // get_mem_usage_kb());
264
0
        assign_index->search(
265
0
                n * beam_size,
266
0
                residuals,
267
0
                new_beam_size,
268
0
                cent_distances.data(),
269
0
                cent_ids.data());
270
0
    } else {
271
        // do one big distance computation
272
0
        cent_distances.resize(n * beam_size * K);
273
0
        pairwise_L2sqr(
274
0
                d, n * beam_size, residuals, K, cent, cent_distances.data());
275
0
    }
276
0
    InterruptCallback::check();
277
278
0
#pragma omp parallel for if (n > 100)
279
0
    for (int64_t i = 0; i < n; i++) {
280
0
        const int32_t* codes_i = codes + i * m * beam_size;
281
0
        int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
282
0
        const float* residuals_i = residuals + i * d * beam_size;
283
0
        float* new_residuals_i = new_residuals + i * d * new_beam_size;
284
285
0
        float* new_distances_i = new_distances + i * new_beam_size;
286
0
        using C = CMax<float, int>;
287
288
0
        if (assign_index) {
289
0
            const float* cent_distances_i =
290
0
                    cent_distances.data() + i * beam_size * new_beam_size;
291
0
            const idx_t* cent_ids_i =
292
0
                    cent_ids.data() + i * beam_size * new_beam_size;
293
294
            // here we could be a tad more efficient by merging sorted arrays
295
0
            for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
296
0
                new_distances_i[i_2] = C::neutral();
297
0
            }
298
0
            std::vector<int> perm(new_beam_size, -1);
299
0
            heap_addn<C>(
300
0
                    new_beam_size,
301
0
                    new_distances_i,
302
0
                    perm.data(),
303
0
                    cent_distances_i,
304
0
                    nullptr,
305
0
                    beam_size * new_beam_size);
306
0
            heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
307
308
0
            for (int j = 0; j < new_beam_size; j++) {
309
0
                int js = perm[j] / new_beam_size;
310
0
                int ls = cent_ids_i[perm[j]];
311
0
                if (m > 0) {
312
0
                    memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
313
0
                }
314
0
                new_codes_i[m] = ls;
315
0
                new_codes_i += m + 1;
316
0
                fvec_sub(
317
0
                        d,
318
0
                        residuals_i + js * d,
319
0
                        cent + ls * d,
320
0
                        new_residuals_i);
321
0
                new_residuals_i += d;
322
0
            }
323
324
0
        } else {
325
0
            const float* cent_distances_i =
326
0
                    cent_distances.data() + i * beam_size * K;
327
            // then we have to select the best results
328
0
            for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
329
0
                new_distances_i[i_2] = C::neutral();
330
0
            }
331
0
            std::vector<int> perm(new_beam_size, -1);
332
333
0
#define HANDLE_APPROX(NB, BD)                                  \
334
0
    case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
335
0
        HeapWithBuckets<C, NB, BD>::bs_addn(                   \
336
0
                beam_size,                                     \
337
0
                K,                                             \
338
0
                cent_distances_i,                              \
339
0
                new_beam_size,                                 \
340
0
                new_distances_i,                               \
341
0
                perm.data());                                  \
342
0
        break;
343
344
0
            switch (approx_topk_mode) {
345
0
                HANDLE_APPROX(8, 3)
346
0
                HANDLE_APPROX(8, 2)
347
0
                HANDLE_APPROX(16, 2)
348
0
                HANDLE_APPROX(32, 2)
349
0
                default:
350
0
                    heap_addn<C>(
351
0
                            new_beam_size,
352
0
                            new_distances_i,
353
0
                            perm.data(),
354
0
                            cent_distances_i,
355
0
                            nullptr,
356
0
                            beam_size * K);
357
0
            }
358
0
            heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
359
360
0
#undef HANDLE_APPROX
361
362
0
            for (int j = 0; j < new_beam_size; j++) {
363
0
                int js = perm[j] / K;
364
0
                int ls = perm[j] % K;
365
0
                if (m > 0) {
366
0
                    memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
367
0
                }
368
0
                new_codes_i[m] = ls;
369
0
                new_codes_i += m + 1;
370
0
                fvec_sub(
371
0
                        d,
372
0
                        residuals_i + js * d,
373
0
                        cent + ls * d,
374
0
                        new_residuals_i);
375
0
                new_residuals_i += d;
376
0
            }
377
0
        }
378
0
    }
379
0
}
380
381
// exposed in the faiss namespace
382
void beam_search_encode_step_tab(
383
        size_t K,
384
        size_t n,
385
        size_t beam_size,                  // input sizes
386
        const float* codebook_cross_norms, // size K * ldc
387
        size_t ldc,
388
        const uint64_t* codebook_offsets, // m
389
        const float* query_cp,            // size n * ldqc
390
        size_t ldqc,                      // >= K
391
        const float* cent_norms_i,        // size K
392
        size_t m,
393
        const int32_t* codes,   // n * beam_size * m
394
        const float* distances, // n * beam_size
395
        size_t new_beam_size,
396
        int32_t* new_codes,                 // n * new_beam_size * (m + 1)
397
        float* new_distances,               // n * new_beam_size
398
        ApproxTopK_mode_t approx_topk_mode) //
399
0
{
400
0
    FAISS_THROW_IF_NOT(ldc >= K);
401
402
0
#pragma omp parallel for if (n > 100) schedule(dynamic)
403
0
    for (int64_t i = 0; i < n; i++) {
404
0
        std::vector<float> cent_distances(beam_size * K);
405
0
        std::vector<float> cd_common(K);
406
407
0
        const int32_t* codes_i = codes + i * m * beam_size;
408
0
        const float* query_cp_i = query_cp + i * ldqc;
409
0
        const float* distances_i = distances + i * beam_size;
410
411
0
        for (size_t k = 0; k < K; k++) {
412
0
            cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
413
0
        }
414
415
0
        bool use_baseline_implementation = false;
416
417
        // This is the baseline implementation. Its primary flaw
418
        //   that it writes way too many info to the temporary buffer
419
        //   called dp.
420
        //
421
        // This baseline code is kept intentionally because it is easy to
422
        // understand what an optimized version optimizes exactly.
423
        //
424
0
        if (use_baseline_implementation) {
425
0
            for (size_t b = 0; b < beam_size; b++) {
426
0
                std::vector<float> dp(K);
427
428
0
                for (size_t m1 = 0; m1 < m; m1++) {
429
0
                    size_t c = codes_i[b * m + m1];
430
0
                    const float* cb =
431
0
                            &codebook_cross_norms
432
0
                                    [(codebook_offsets[m1] + c) * ldc];
433
0
                    fvec_add(K, cb, dp.data(), dp.data());
434
0
                }
435
436
0
                for (size_t k = 0; k < K; k++) {
437
0
                    cent_distances[b * K + k] =
438
0
                            distances_i[b] + cd_common[k] + 2 * dp[k];
439
0
                }
440
0
            }
441
442
0
        } else {
443
            // An optimized implementation that avoids using a temporary buffer
444
            // and does the accumulation in registers.
445
446
            // Compute a sum of NK AQ codes.
447
0
#define ACCUM_AND_FINALIZE_TAB(NK)               \
448
0
    case NK:                                     \
449
0
        for (size_t b = 0; b < beam_size; b++) { \
450
0
            accum_and_finalize_tab<NK, 4>(       \
451
0
                    codebook_cross_norms,        \
452
0
                    codebook_offsets,            \
453
0
                    codes_i,                     \
454
0
                    b,                           \
455
0
                    ldc,                         \
456
0
                    K,                           \
457
0
                    distances_i,                 \
458
0
                    cd_common.data(),            \
459
0
                    cent_distances.data());      \
460
0
        }                                        \
461
0
        break;
462
463
            // this version contains many switch-case scenarios, but
464
            // they won't affect branch predictor.
465
0
            switch (m) {
466
0
                case 0:
467
                    // trivial case
468
0
                    for (size_t b = 0; b < beam_size; b++) {
469
0
                        for (size_t k = 0; k < K; k++) {
470
0
                            cent_distances[b * K + k] =
471
0
                                    distances_i[b] + cd_common[k];
472
0
                        }
473
0
                    }
474
0
                    break;
475
476
0
                    ACCUM_AND_FINALIZE_TAB(1)
477
0
                    ACCUM_AND_FINALIZE_TAB(2)
478
0
                    ACCUM_AND_FINALIZE_TAB(3)
479
0
                    ACCUM_AND_FINALIZE_TAB(4)
480
0
                    ACCUM_AND_FINALIZE_TAB(5)
481
0
                    ACCUM_AND_FINALIZE_TAB(6)
482
0
                    ACCUM_AND_FINALIZE_TAB(7)
483
484
0
                default: {
485
                    // m >= 8 case.
486
487
                    // A temporary buffer has to be used due to the lack of
488
                    // registers. But we'll try to accumulate up to 8 AQ codes
489
                    // in registers and issue a single write operation to the
490
                    // buffer, while the baseline does no accumulation. So, the
491
                    // number of write operations to the temporary buffer is
492
                    // reduced 8x.
493
494
                    // allocate a temporary buffer
495
0
                    std::vector<float> dp(K);
496
497
0
                    for (size_t b = 0; b < beam_size; b++) {
498
                        // Initialize it. Compute a sum of first 8 AQ codes
499
                        // because m >= 8 .
500
0
                        accum_and_store_tab<8, 4>(
501
0
                                m,
502
0
                                codebook_cross_norms,
503
0
                                codebook_offsets,
504
0
                                codes_i,
505
0
                                b,
506
0
                                ldc,
507
0
                                K,
508
0
                                dp.data());
509
510
0
#define ACCUM_AND_ADD_TAB(NK)          \
511
0
    case NK:                           \
512
0
        accum_and_add_tab<NK, 4>(      \
513
0
                m,                     \
514
0
                codebook_cross_norms,  \
515
0
                codebook_offsets + im, \
516
0
                codes_i + im,          \
517
0
                b,                     \
518
0
                ldc,                   \
519
0
                K,                     \
520
0
                dp.data());            \
521
0
        break;
522
523
                        // accumulate up to 8 additional AQ codes into
524
                        // a temporary buffer
525
0
                        for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
526
0
                            size_t m_left = m - im;
527
0
                            if (m_left > 8) {
528
0
                                m_left = 8;
529
0
                            }
530
531
0
                            switch (m_left) {
532
0
                                ACCUM_AND_ADD_TAB(1)
533
0
                                ACCUM_AND_ADD_TAB(2)
534
0
                                ACCUM_AND_ADD_TAB(3)
535
0
                                ACCUM_AND_ADD_TAB(4)
536
0
                                ACCUM_AND_ADD_TAB(5)
537
0
                                ACCUM_AND_ADD_TAB(6)
538
0
                                ACCUM_AND_ADD_TAB(7)
539
0
                                ACCUM_AND_ADD_TAB(8)
540
0
                            }
541
0
                        }
542
543
                        // done. finalize the result
544
0
                        for (size_t k = 0; k < K; k++) {
545
0
                            cent_distances[b * K + k] =
546
0
                                    distances_i[b] + cd_common[k] + 2 * dp[k];
547
0
                        }
548
0
                    }
549
0
                }
550
0
            }
551
552
            // the optimized implementation ends here
553
0
        }
554
0
        using C = CMax<float, int>;
555
0
        int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
556
0
        float* new_distances_i = new_distances + i * new_beam_size;
557
558
0
        const float* cent_distances_i = cent_distances.data();
559
560
        // then we have to select the best results
561
0
        for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
562
0
            new_distances_i[i_2] = C::neutral();
563
0
        }
564
0
        std::vector<int> perm(new_beam_size, -1);
565
566
0
#define HANDLE_APPROX(NB, BD)                                  \
567
0
    case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
568
0
        HeapWithBuckets<C, NB, BD>::bs_addn(                   \
569
0
                beam_size,                                     \
570
0
                K,                                             \
571
0
                cent_distances_i,                              \
572
0
                new_beam_size,                                 \
573
0
                new_distances_i,                               \
574
0
                perm.data());                                  \
575
0
        break;
576
577
0
        switch (approx_topk_mode) {
578
0
            HANDLE_APPROX(8, 3)
579
0
            HANDLE_APPROX(8, 2)
580
0
            HANDLE_APPROX(16, 2)
581
0
            HANDLE_APPROX(32, 2)
582
0
            default:
583
0
                heap_addn<C>(
584
0
                        new_beam_size,
585
0
                        new_distances_i,
586
0
                        perm.data(),
587
0
                        cent_distances_i,
588
0
                        nullptr,
589
0
                        beam_size * K);
590
0
                break;
591
0
        }
592
593
0
        heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
594
595
0
#undef HANDLE_APPROX
596
597
0
        for (int j = 0; j < new_beam_size; j++) {
598
0
            int js = perm[j] / K;
599
0
            int ls = perm[j] % K;
600
0
            if (m > 0) {
601
0
                memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
602
0
            }
603
0
            new_codes_i[m] = ls;
604
0
            new_codes_i += m + 1;
605
0
        }
606
0
    }
607
0
}
608
609
/********************************************************************
610
 * Multiple encoding steps
611
 ********************************************************************/
612
613
namespace rq_encode_steps {
614
615
void refine_beam_mp(
616
        const ResidualQuantizer& rq,
617
        size_t n,
618
        size_t beam_size,
619
        const float* x,
620
        int out_beam_size,
621
        int32_t* out_codes,
622
        float* out_residuals,
623
        float* out_distances,
624
0
        RefineBeamMemoryPool& pool) {
625
0
    int cur_beam_size = beam_size;
626
627
0
    double t0 = getmillisecs();
628
629
    // find the max_beam_size
630
0
    int max_beam_size = 0;
631
0
    {
632
0
        int tmp_beam_size = cur_beam_size;
633
0
        for (int m = 0; m < rq.M; m++) {
634
0
            int K = 1 << rq.nbits[m];
635
0
            int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
636
0
            tmp_beam_size = new_beam_size;
637
638
0
            if (max_beam_size < new_beam_size) {
639
0
                max_beam_size = new_beam_size;
640
0
            }
641
0
        }
642
0
    }
643
644
    // preallocate buffers
645
0
    pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
646
0
    pool.new_residuals.resize(n * max_beam_size * rq.d);
647
648
0
    pool.codes.resize(n * max_beam_size * (rq.M + 1));
649
0
    pool.distances.resize(n * max_beam_size);
650
0
    pool.residuals.resize(n * rq.d * max_beam_size);
651
652
0
    for (size_t i = 0; i < n * rq.d * beam_size; i++) {
653
0
        pool.residuals[i] = x[i];
654
0
    }
655
656
    // set up pointers to buffers
657
0
    int32_t* __restrict codes_ptr = pool.codes.data();
658
0
    float* __restrict residuals_ptr = pool.residuals.data();
659
660
0
    int32_t* __restrict new_codes_ptr = pool.new_codes.data();
661
0
    float* __restrict new_residuals_ptr = pool.new_residuals.data();
662
663
    // index
664
0
    std::unique_ptr<Index> assign_index;
665
0
    if (rq.assign_index_factory) {
666
0
        assign_index.reset((*rq.assign_index_factory)(rq.d));
667
0
    }
668
669
    // main loop
670
0
    size_t codes_size = 0;
671
0
    size_t distances_size = 0;
672
0
    size_t residuals_size = 0;
673
674
0
    for (int m = 0; m < rq.M; m++) {
675
0
        int K = 1 << rq.nbits[m];
676
677
0
        const float* __restrict codebooks_m =
678
0
                rq.codebooks.data() + rq.codebook_offsets[m] * rq.d;
679
680
0
        const int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
681
682
0
        codes_size = n * new_beam_size * (m + 1);
683
0
        residuals_size = n * new_beam_size * rq.d;
684
0
        distances_size = n * new_beam_size;
685
686
0
        beam_search_encode_step(
687
0
                rq.d,
688
0
                K,
689
0
                codebooks_m,
690
0
                n,
691
0
                cur_beam_size,
692
0
                residuals_ptr,
693
0
                m,
694
0
                codes_ptr,
695
0
                new_beam_size,
696
0
                new_codes_ptr,
697
0
                new_residuals_ptr,
698
0
                pool.distances.data(),
699
0
                assign_index.get(),
700
0
                rq.approx_topk_mode);
701
702
0
        if (assign_index != nullptr) {
703
0
            assign_index->reset();
704
0
        }
705
706
0
        std::swap(codes_ptr, new_codes_ptr);
707
0
        std::swap(residuals_ptr, new_residuals_ptr);
708
709
0
        cur_beam_size = new_beam_size;
710
711
0
        if (rq.verbose) {
712
0
            float sum_distances = 0;
713
0
            for (int j = 0; j < distances_size; j++) {
714
0
                sum_distances += pool.distances[j];
715
0
            }
716
717
0
            printf("[%.3f s] encode stage %d, %d bits, "
718
0
                   "total error %g, beam_size %d\n",
719
0
                   (getmillisecs() - t0) / 1000,
720
0
                   m,
721
0
                   int(rq.nbits[m]),
722
0
                   sum_distances,
723
0
                   cur_beam_size);
724
0
        }
725
0
    }
726
727
0
    if (out_codes) {
728
0
        memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
729
0
    }
730
0
    if (out_residuals) {
731
0
        memcpy(out_residuals,
732
0
               residuals_ptr,
733
0
               residuals_size * sizeof(*residuals_ptr));
734
0
    }
735
0
    if (out_distances) {
736
0
        memcpy(out_distances,
737
0
               pool.distances.data(),
738
0
               distances_size * sizeof(pool.distances[0]));
739
0
    }
740
0
}
741
742
/// Beam-search encoding that works entirely from precomputed look-up tables
/// (codebook cross-products and query/codebook dot products) instead of
/// explicit residual vectors.
///
/// @param rq             the quantizer being used (codebooks, nbits, offsets)
/// @param n              number of query vectors
/// @param query_norms    size n, squared L2 norms of the queries; used to seed
///                       the running distance of the initial (empty) beam
/// @param query_cp       dot products between queries and all codebook
///                       centroids, laid out with stride rq.total_codebook_size
/// @param out_beam_size  cap on the beam width at every stage
/// @param out_codes      if non-null, receives the final codes
/// @param out_distances  if non-null, receives the final beam distances
/// @param pool           scratch buffers reused across calls
void refine_beam_LUT_mp(
        const ResidualQuantizer& rq,
        size_t n,
        const float* query_norms, // size n
        const float* query_cp,    //
        int out_beam_size,
        int32_t* out_codes,
        float* out_distances,
        RefineBeamLUTMemoryPool& pool) {
    int beam_size = 1;

    double t0 = getmillisecs();

    // find the max_beam_size
    // Simulate the beam growth (beam *= K, clamped to out_beam_size) over all
    // M stages so the pool buffers can be sized once, up front.
    int max_beam_size = 0;
    {
        int tmp_beam_size = beam_size;
        for (int m = 0; m < rq.M; m++) {
            int K = 1 << rq.nbits[m];
            int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
            tmp_beam_size = new_beam_size;

            if (max_beam_size < new_beam_size) {
                max_beam_size = new_beam_size;
            }
        }
    }

    // preallocate buffers
    // (rq.M + 1) leaves headroom for one extra code per beam entry, matching
    // the layout beam_search_encode_step_tab writes.
    pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
    pool.new_distances.resize(n * max_beam_size);

    pool.codes.resize(n * max_beam_size * (rq.M + 1));
    pool.distances.resize(n * max_beam_size);

    // Seed: the beam is empty, so the current reconstruction error of query i
    // is just its squared norm (beam_size == 1, one entry per query).
    for (size_t i = 0; i < n; i++) {
        pool.distances[i] = query_norms[i];
    }

    // set up pointers to buffers
    // Double buffering: each stage reads from *_ptr and writes to new_*_ptr,
    // then the pointers are swapped.
    int32_t* __restrict new_codes_ptr = pool.new_codes.data();
    float* __restrict new_distances_ptr = pool.new_distances.data();

    int32_t* __restrict codes_ptr = pool.codes.data();
    float* __restrict distances_ptr = pool.distances.data();

    // main loop
    size_t codes_size = 0;
    size_t distances_size = 0;
    // cross_ofs walks through codebook_cross_products, which stores, for each
    // stage m, the K_m columns of cross products against all previous
    // codebooks (codebook_offsets[m] rows each).
    size_t cross_ofs = 0;
    for (int m = 0; m < rq.M; m++) {
        int K = 1 << rq.nbits[m];

        // it is guaranteed that (new_beam_size <= max_beam_size)
        int new_beam_size = std::min(beam_size * K, out_beam_size);

        codes_size = n * new_beam_size * (m + 1);
        distances_size = n * new_beam_size;
        // Sanity check: the slice of cross products consumed by this stage
        // must fit in the precomputed table (see compute_codebook_tables).
        FAISS_THROW_IF_NOT(
                cross_ofs + rq.codebook_offsets[m] * K <=
                rq.codebook_cross_products.size());
        beam_search_encode_step_tab(
                K,
                n,
                beam_size,
                rq.codebook_cross_products.data() + cross_ofs,
                K,
                rq.codebook_offsets.data(),
                query_cp + rq.codebook_offsets[m],
                rq.total_codebook_size,
                rq.centroid_norms.data() + rq.codebook_offsets[m],
                m,
                codes_ptr,
                distances_ptr,
                new_beam_size,
                new_codes_ptr,
                new_distances_ptr,
                rq.approx_topk_mode);
        // advance past the K columns x codebook_offsets[m] rows just consumed
        cross_ofs += rq.codebook_offsets[m] * K;
        std::swap(codes_ptr, new_codes_ptr);
        std::swap(distances_ptr, new_distances_ptr);

        beam_size = new_beam_size;

        if (rq.verbose) {
            // distances_ptr now holds this stage's output (post-swap)
            float sum_distances = 0;
            for (int j = 0; j < distances_size; j++) {
                sum_distances += distances_ptr[j];
            }
            printf("[%.3f s] encode stage %d, %d bits, "
                   "total error %g, beam_size %d\n",
                   (getmillisecs() - t0) / 1000,
                   m,
                   int(rq.nbits[m]),
                   sum_distances,
                   beam_size);
        }
    }
    // After the last swap, codes_ptr / distances_ptr hold the final results.
    if (out_codes) {
        memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
    }
    if (out_distances) {
        memcpy(out_distances,
               distances_ptr,
               distances_size * sizeof(*distances_ptr));
    }
}
849
850
// this is for use_beam_LUT == 0
851
void compute_codes_add_centroids_mp_lut0(
852
        const ResidualQuantizer& rq,
853
        const float* x,
854
        uint8_t* codes_out,
855
        size_t n,
856
        const float* centroids,
857
0
        ComputeCodesAddCentroidsLUT0MemoryPool& pool) {
858
0
    pool.codes.resize(rq.max_beam_size * rq.M * n);
859
0
    pool.distances.resize(rq.max_beam_size * n);
860
861
0
    pool.residuals.resize(rq.max_beam_size * n * rq.d);
862
863
0
    refine_beam_mp(
864
0
            rq,
865
0
            n,
866
0
            1,
867
0
            x,
868
0
            rq.max_beam_size,
869
0
            pool.codes.data(),
870
0
            pool.residuals.data(),
871
0
            pool.distances.data(),
872
0
            pool.refine_beam_pool);
873
874
0
    if (rq.search_type == ResidualQuantizer::ST_norm_float ||
875
0
        rq.search_type == ResidualQuantizer::ST_norm_qint8 ||
876
0
        rq.search_type == ResidualQuantizer::ST_norm_qint4) {
877
0
        pool.norms.resize(n);
878
        // recover the norms of reconstruction as
879
        // || original_vector - residual ||^2
880
0
        for (size_t i = 0; i < n; i++) {
881
0
            pool.norms[i] = fvec_L2sqr(
882
0
                    x + i * rq.d,
883
0
                    pool.residuals.data() + i * rq.max_beam_size * rq.d,
884
0
                    rq.d);
885
0
        }
886
0
    }
887
888
    // pack only the first code of the beam
889
    //   (hence the ld_codes=M * max_beam_size)
890
0
    rq.pack_codes(
891
0
            n,
892
0
            pool.codes.data(),
893
0
            codes_out,
894
0
            rq.M * rq.max_beam_size,
895
0
            (pool.norms.size() > 0) ? pool.norms.data() : nullptr,
896
0
            centroids);
897
0
}
898
899
// use_beam_LUT == 1
//
// Encode vectors x using the LUT-based beam search: precompute the query
// norms and the query x codebook dot-product table, run refine_beam_LUT_mp,
// then pack the best code of each beam (adding `centroids` if provided).
// Requires rq.codebook_cross_products to be populated beforehand.
void compute_codes_add_centroids_mp_lut1(
        const ResidualQuantizer& rq,
        const float* x,
        uint8_t* codes_out,
        size_t n,
        const float* centroids,
        ComputeCodesAddCentroidsLUT1MemoryPool& pool) {
    // one full beam of codes/distances per query
    pool.codes.resize(rq.max_beam_size * rq.M * n);
    pool.distances.resize(rq.max_beam_size * n);

    // cross-product tables are built by compute_codebook_tables; with a
    // single codebook (M == 1) none are needed
    FAISS_THROW_IF_NOT_MSG(
            rq.M == 1 || rq.codebook_cross_products.size() > 0,
            "call compute_codebook_tables first");

    // squared L2 norms of the queries seed the beam-search distances
    pool.query_norms.resize(n);
    fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n);

    // query_cp[i, j] = <x_i, centroid_j> over all codebooks, computed in one
    // sgemm: query_cp = codebooks^T * x
    // (BLAS only reads the first character of the transpose flags)
    pool.query_cp.resize(n * rq.total_codebook_size);
    {
        FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n;
        float zero = 0, one = 1;
        sgemm_("Transposed",
               "Not transposed",
               &ti,
               &ni,
               &di,
               &one,
               rq.codebooks.data(),
               &di,
               x,
               &di,
               &zero,
               pool.query_cp.data(),
               &ti);
    }

    refine_beam_LUT_mp(
            rq,
            n,
            pool.query_norms.data(),
            pool.query_cp.data(),
            rq.max_beam_size,
            pool.codes.data(),
            pool.distances.data(),
            pool.refine_beam_lut_pool);

    // pack only the first code of the beam
    //   (hence the ld_codes=M * max_beam_size)
    rq.pack_codes(
            n,
            pool.codes.data(),
            codes_out,
            rq.M * rq.max_beam_size,
            nullptr,
            centroids);
}
957
958
} // namespace rq_encode_steps
959
960
} // namespace faiss