Coverage Report

Created: 2025-10-17 00:26

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/RaBitQuantizer.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/impl/RaBitQuantizer.h>
9
10
#include <algorithm>
11
#include <cmath>
12
#include <cstring>
13
#include <limits>
14
#include <memory>
15
#include <vector>
16
17
#include <faiss/impl/FaissAssert.h>
18
#include <faiss/utils/distances.h>
19
20
namespace faiss {
21
22
// Per-vector scalars stored in every code, immediately after the packed
// sign bits. Codes embed this struct byte-for-byte, so the field order and
// types must not change.
struct FactorsData {
    // ||or - c||^2, minus ||or||^2 when the metric is METRIC_INNER_PRODUCT
    float or_minus_c_l2sqr{0.0f};
    // ||or - c|| divided by the normalized sign dot product computed at
    // encoding time; scales the estimated dot product at search time
    float dp_multiplier{0.0f};
};
27
28
// Per-query scalars prepared by set_query() and read by distance_to_code().
struct QueryFactorsData {
    // multiplies the dot product between the query and the code bits
    float c1{0.0f};
    // multiplies the popcount of the code bits
    float c2{0.0f};
    // constant subtracted from the estimated dot product
    float c34{0.0f};

    // ||qr - c||^2 (||qr||^2 when there is no centroid)
    float qr_to_c_L2sqr{0.0f};
    // ||qr||^2; only filled in for METRIC_INNER_PRODUCT
    float qr_norm_L2sqr{0.0f};
};
36
37
0
static size_t get_code_size(const size_t d) {
38
0
    return (d + 7) / 8 + sizeof(FactorsData);
39
0
}
40
41
// Builds a quantizer for d-dimensional vectors; the code size covers the d
// sign bits plus the per-vector FactorsData (see get_code_size()).
RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric)
        : Quantizer(d, get_code_size(d)), metric_type(metric) {}
43
44
0
// Intentionally a no-op: this quantizer has no state learned from data here.
void RaBitQuantizer::train(size_t /*n*/, const float* /*x*/) {}
47
48
void RaBitQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
49
0
        const {
50
0
    compute_codes_core(x, codes, n, centroid);
51
0
}
52
53
void RaBitQuantizer::compute_codes_core(
54
        const float* x,
55
        uint8_t* codes,
56
        size_t n,
57
0
        const float* centroid_in) const {
58
0
    FAISS_ASSERT(codes != nullptr);
59
0
    FAISS_ASSERT(x != nullptr);
60
0
    FAISS_ASSERT(
61
0
            (metric_type == MetricType::METRIC_L2 ||
62
0
             metric_type == MetricType::METRIC_INNER_PRODUCT));
63
64
0
    if (n == 0) {
65
0
        return;
66
0
    }
67
68
    // compute some helper constants
69
0
    const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
70
71
    // compute codes
72
0
#pragma omp parallel for if (n > 1000)
73
0
    for (int64_t i = 0; i < n; i++) {
74
        // ||or - c||^2
75
0
        float norm_L2sqr = 0;
76
        // ||or||^2, which is equal to ||P(or)||^2 and ||P^(-1)(or)||^2
77
0
        float or_L2sqr = 0;
78
        // dot product
79
0
        float dp_oO = 0;
80
81
        // the code
82
0
        uint8_t* code = codes + i * code_size;
83
0
        FactorsData* fac = reinterpret_cast<FactorsData*>(code + (d + 7) / 8);
84
85
        // cleanup it
86
0
        if (code != nullptr) {
87
0
            memset(code, 0, code_size);
88
0
        }
89
90
0
        for (size_t j = 0; j < d; j++) {
91
0
            const float or_minus_c = x[i * d + j] -
92
0
                    ((centroid_in == nullptr) ? 0 : centroid_in[j]);
93
0
            norm_L2sqr += or_minus_c * or_minus_c;
94
0
            or_L2sqr += x[i * d + j] * x[i * d + j];
95
96
0
            const bool xb = (or_minus_c > 0);
97
98
0
            dp_oO += xb ? or_minus_c : (-or_minus_c);
99
100
            // store the output data
101
0
            if (code != nullptr) {
102
0
                if (xb) {
103
                    // enable a particular bit
104
0
                    code[j / 8] |= (1 << (j % 8));
105
0
                }
106
0
            }
107
0
        }
108
109
        // compute factors
110
111
        // compute the inverse norm
112
0
        const float inv_norm_L2 =
113
0
                (std::abs(norm_L2sqr) < std::numeric_limits<float>::epsilon())
114
0
                ? 1.0f
115
0
                : (1.0f / std::sqrt(norm_L2sqr));
116
0
        dp_oO *= inv_norm_L2;
117
0
        dp_oO *= inv_d_sqrt;
118
119
0
        const float inv_dp_oO =
120
0
                (std::abs(dp_oO) < std::numeric_limits<float>::epsilon())
121
0
                ? 1.0f
122
0
                : (1.0f / dp_oO);
123
124
0
        fac->or_minus_c_l2sqr = norm_L2sqr;
125
0
        if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
126
0
            fac->or_minus_c_l2sqr -= or_L2sqr;
127
0
        }
128
129
0
        fac->dp_multiplier = inv_dp_oO * std::sqrt(norm_L2sqr);
130
0
    }
131
0
}
132
133
0
void RaBitQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
134
0
    decode_core(codes, x, n, centroid);
135
0
}
136
137
void RaBitQuantizer::decode_core(
138
        const uint8_t* codes,
139
        float* x,
140
        size_t n,
141
0
        const float* centroid_in) const {
142
0
    FAISS_ASSERT(codes != nullptr);
143
0
    FAISS_ASSERT(x != nullptr);
144
145
0
    const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
146
147
0
#pragma omp parallel for if (n > 1000)
148
0
    for (int64_t i = 0; i < n; i++) {
149
0
        const uint8_t* code = codes + i * code_size;
150
151
        // split the code into parts
152
0
        const uint8_t* binary_data = code;
153
0
        const FactorsData* fac =
154
0
                reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
155
156
        //
157
0
        for (size_t j = 0; j < d; j++) {
158
            // extract i-th bit
159
0
            const uint8_t masker = (1 << (j % 8));
160
0
            const float bit = ((binary_data[j / 8] & masker) == masker) ? 1 : 0;
161
162
            // compute the output code
163
0
            x[i * d + j] = (bit - 0.5f) * fac->dp_multiplier * 2 * inv_d_sqrt +
164
0
                    ((centroid_in == nullptr) ? 0 : centroid_in[j]);
165
0
        }
166
0
    }
167
0
}
168
169
struct RaBitDistanceComputer : FlatCodesDistanceComputer {
170
    // dimensionality
171
    size_t d = 0;
172
    // a centroid to use
173
    const float* centroid = nullptr;
174
175
    // the metric
176
    MetricType metric_type = MetricType::METRIC_L2;
177
178
    RaBitDistanceComputer();
179
180
    float symmetric_dis(idx_t i, idx_t j) override;
181
};
182
183
0
// Out-of-line defaulted constructor: d, centroid and metric_type take their
// in-class default initializers.
RaBitDistanceComputer::RaBitDistanceComputer() = default;
184
185
0
// Distances between two stored codes are not supported; always throws.
float RaBitDistanceComputer::symmetric_dis(idx_t /*i*/, idx_t /*j*/) {
    FAISS_THROW_MSG("Not implemented");
}
188
189
struct RaBitDistanceComputerNotQ : RaBitDistanceComputer {
190
    // the rotated query (qr - c)
191
    std::vector<float> rotated_q;
192
    // some additional numbers for the query
193
    QueryFactorsData query_fac;
194
195
    RaBitDistanceComputerNotQ();
196
197
    float distance_to_code(const uint8_t* code) override;
198
199
    void set_query(const float* x) override;
200
};
201
202
0
// Defaulted: rotated_q starts empty and query_fac zero-initialized until
// set_query() is called.
RaBitDistanceComputerNotQ::RaBitDistanceComputerNotQ() = default;
203
204
0
float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) {
205
0
    FAISS_ASSERT(code != nullptr);
206
0
    FAISS_ASSERT(
207
0
            (metric_type == MetricType::METRIC_L2 ||
208
0
             metric_type == MetricType::METRIC_INNER_PRODUCT));
209
210
    // split the code into parts
211
0
    const uint8_t* binary_data = code;
212
0
    const FactorsData* fac =
213
0
            reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
214
215
    // this is the baseline code
216
    //
217
    // compute <q,o> using floats
218
0
    float dot_qo = 0;
219
    // It was a willful decision (after the discussion) to not to pre-cache
220
    //   the sum of all bits, just in order to reduce the overhead per vector.
221
0
    uint64_t sum_q = 0;
222
0
    for (size_t i = 0; i < d; i++) {
223
        // extract i-th bit
224
0
        const uint8_t masker = (1 << (i % 8));
225
0
        const bool b_bit = ((binary_data[i / 8] & masker) == masker);
226
227
        // accumulate dp
228
0
        dot_qo += (b_bit) ? rotated_q[i] : 0;
229
        // accumulate sum-of-bits
230
0
        sum_q += (b_bit) ? 1 : 0;
231
0
    }
232
233
0
    float final_dot = 0;
234
    // dot-product itself
235
0
    final_dot += query_fac.c1 * dot_qo;
236
    // normalizer coefficients
237
0
    final_dot += query_fac.c2 * sum_q;
238
    // normalizer coefficients
239
0
    final_dot -= query_fac.c34;
240
241
    // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
242
0
    const float or_c_l2sqr = fac->or_minus_c_l2sqr;
243
244
    // pre_dist = ||or - c||^2 + ||qr - c||^2 -
245
    //     2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
246
0
    const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr -
247
0
            2 * fac->dp_multiplier * final_dot;
248
249
0
    if (metric_type == MetricType::METRIC_L2) {
250
        // ||or - q||^ 2
251
0
        return pre_dist;
252
0
    } else {
253
        // metric == MetricType::METRIC_INNER_PRODUCT
254
255
        // this is ||q||^2
256
0
        const float query_norm_sqr = query_fac.qr_norm_L2sqr;
257
258
        // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
259
0
        return -0.5f * (pre_dist - query_norm_sqr);
260
0
    }
261
0
}
262
263
0
void RaBitDistanceComputerNotQ::set_query(const float* x) {
264
0
    FAISS_ASSERT(x != nullptr);
265
0
    FAISS_ASSERT(
266
0
            (metric_type == MetricType::METRIC_L2 ||
267
0
             metric_type == MetricType::METRIC_INNER_PRODUCT));
268
269
    // compute the distance from the query to the centroid
270
0
    if (centroid != nullptr) {
271
0
        query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
272
0
    } else {
273
0
        query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
274
0
    }
275
276
    // subtract c, obtain P^(-1)(qr - c)
277
0
    rotated_q.resize(d);
278
0
    for (size_t i = 0; i < d; i++) {
279
0
        rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
280
0
    }
281
282
    // compute some numbers
283
0
    const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
284
285
    // do not quantize the query
286
0
    float sum_q = 0;
287
0
    for (size_t i = 0; i < d; i++) {
288
0
        sum_q += rotated_q[i];
289
0
    }
290
291
0
    query_fac.c1 = 2 * inv_d;
292
0
    query_fac.c2 = 0;
293
0
    query_fac.c34 = sum_q * inv_d;
294
295
0
    if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
296
        // precompute if needed
297
0
        query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
298
0
    }
299
0
}
300
301
//
302
struct RaBitDistanceComputerQ : RaBitDistanceComputer {
303
    // the rotated and quantized query (qr - c)
304
    std::vector<uint8_t> rotated_qq;
305
    // we're using the proposed relayout-ed scheme from 3.3 that allows
306
    //    using popcounts for computing the distance.
307
    std::vector<uint8_t> rearranged_rotated_qq;
308
    // some additional numbers for the query
309
    QueryFactorsData query_fac;
310
311
    // the number of bits for SQ quantization of the query (qb > 0)
312
    uint8_t qb = 8;
313
    // the smallest value divisible by 8 that is not smaller than dim
314
    size_t popcount_aligned_dim = 0;
315
316
    RaBitDistanceComputerQ();
317
318
    float distance_to_code(const uint8_t* code) override;
319
320
    void set_query(const float* x) override;
321
};
322
323
0
// Defaulted: qb keeps its in-class default of 8 unless overridden afterwards
// (get_distance_computer() assigns it); the buffers start empty until
// set_query() fills them.
RaBitDistanceComputerQ::RaBitDistanceComputerQ() = default;
324
325
0
float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) {
326
0
    FAISS_ASSERT(code != nullptr);
327
0
    FAISS_ASSERT(
328
0
            (metric_type == MetricType::METRIC_L2 ||
329
0
             metric_type == MetricType::METRIC_INNER_PRODUCT));
330
331
    // split the code into parts
332
0
    const uint8_t* binary_data = code;
333
0
    const FactorsData* fac =
334
0
            reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
335
336
    // // this is the baseline code
337
    // //
338
    // // compute <q,o> using integers
339
    // size_t dot_qo = 0;
340
    // for (size_t i = 0; i < d; i++) {
341
    //     // extract i-th bit
342
    //     const uint8_t masker = (1 << (i % 8));
343
    //     const uint8_t bit = ((binary_data[i / 8] & masker) == masker) ? 1 :
344
    //     0;
345
    //
346
    //     // accumulate dp
347
    //     dot_qo += bit * rotated_qq[i];
348
    // }
349
350
    // this is the scheme for popcount
351
0
    const size_t di_8b = (d + 7) / 8;
352
0
    const size_t di_64b = (di_8b / 8) * 8;
353
354
0
    uint64_t dot_qo = 0;
355
0
    for (size_t j = 0; j < qb; j++) {
356
0
        const uint8_t* query_j = rearranged_rotated_qq.data() + j * di_8b;
357
358
        // process 64-bit popcounts
359
0
        uint64_t count_dot = 0;
360
0
        for (size_t i = 0; i < di_64b; i += 8) {
361
0
            const auto qv = *(const uint64_t*)(query_j + i);
362
0
            const auto yv = *(const uint64_t*)(binary_data + i);
363
0
            count_dot += __builtin_popcountll(qv & yv);
364
0
        }
365
366
        // process leftovers
367
0
        for (size_t i = di_64b; i < di_8b; i++) {
368
0
            const auto qv = *(query_j + i);
369
0
            const auto yv = *(binary_data + i);
370
0
            count_dot += __builtin_popcount(qv & yv);
371
0
        }
372
373
0
        dot_qo += (count_dot << j);
374
0
    }
375
376
    // It was a willful decision (after the discussion) to not to pre-cache
377
    //   the sum of all bits, just in order to reduce the overhead per vector.
378
0
    uint64_t sum_q = 0;
379
0
    {
380
        // process 64-bit popcounts
381
0
        for (size_t i = 0; i < di_64b; i += 8) {
382
0
            const auto yv = *(const uint64_t*)(binary_data + i);
383
0
            sum_q += __builtin_popcountll(yv);
384
0
        }
385
386
        // process leftovers
387
0
        for (size_t i = di_64b; i < di_8b; i++) {
388
0
            const auto yv = *(binary_data + i);
389
0
            sum_q += __builtin_popcount(yv);
390
0
        }
391
0
    }
392
393
0
    float final_dot = 0;
394
    // dot-product itself
395
0
    final_dot += query_fac.c1 * dot_qo;
396
    // normalizer coefficients
397
0
    final_dot += query_fac.c2 * sum_q;
398
    // normalizer coefficients
399
0
    final_dot -= query_fac.c34;
400
401
    // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
402
0
    const float or_c_l2sqr = fac->or_minus_c_l2sqr;
403
404
    // pre_dist = ||or - c||^2 + ||qr - c||^2 -
405
    //     2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
406
0
    const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr -
407
0
            2 * fac->dp_multiplier * final_dot;
408
409
0
    if (metric_type == MetricType::METRIC_L2) {
410
        // ||or - q||^ 2
411
0
        return pre_dist;
412
0
    } else {
413
        // metric == MetricType::METRIC_INNER_PRODUCT
414
415
        // this is ||q||^2
416
0
        const float query_norm_sqr = query_fac.qr_norm_L2sqr;
417
418
        // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
419
0
        return -0.5f * (pre_dist - query_norm_sqr);
420
0
    }
421
0
}
422
423
0
void RaBitDistanceComputerQ::set_query(const float* x) {
424
0
    FAISS_ASSERT(x != nullptr);
425
0
    FAISS_ASSERT(
426
0
            (metric_type == MetricType::METRIC_L2 ||
427
0
             metric_type == MetricType::METRIC_INNER_PRODUCT));
428
429
    // compute the distance from the query to the centroid
430
0
    if (centroid != nullptr) {
431
0
        query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
432
0
    } else {
433
0
        query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
434
0
    }
435
436
    // allocate space
437
0
    rotated_qq.resize(d);
438
439
    // rotate the query
440
0
    std::vector<float> rotated_q(d);
441
0
    for (size_t i = 0; i < d; i++) {
442
0
        rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
443
0
    }
444
445
    // compute some numbers
446
0
    const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
447
448
    // quantize the query. compute min and max
449
0
    float v_min = std::numeric_limits<float>::max();
450
0
    float v_max = std::numeric_limits<float>::lowest();
451
0
    for (size_t i = 0; i < d; i++) {
452
0
        const float v_q = rotated_q[i];
453
0
        v_min = std::min(v_min, v_q);
454
0
        v_max = std::max(v_max, v_q);
455
0
    }
456
457
0
    const float pow_2_qb = 1 << qb;
458
459
0
    const float delta = (v_max - v_min) / (pow_2_qb - 1);
460
0
    const float inv_delta = 1.0f / delta;
461
462
0
    size_t sum_qq = 0;
463
0
    for (int32_t i = 0; i < d; i++) {
464
0
        const float v_q = rotated_q[i];
465
466
        // a default non-randomized SQ
467
0
        const int v_qq = std::round((v_q - v_min) * inv_delta);
468
469
0
        rotated_qq[i] = std::min(255, std::max(0, v_qq));
470
0
        sum_qq += v_qq;
471
0
    }
472
473
    // rearrange the query vector
474
0
    popcount_aligned_dim = ((d + 7) / 8) * 8;
475
0
    size_t offset = (d + 7) / 8;
476
477
0
    rearranged_rotated_qq.resize(offset * qb);
478
0
    std::fill(rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
479
480
0
    for (size_t idim = 0; idim < d; idim++) {
481
0
        for (size_t iv = 0; iv < qb; iv++) {
482
0
            const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
483
0
            rearranged_rotated_qq[iv * offset + idim / 8] |=
484
0
                    bit ? (1 << (idim % 8)) : 0;
485
0
        }
486
0
    }
487
488
0
    query_fac.c1 = 2 * delta * inv_d;
489
0
    query_fac.c2 = 2 * v_min * inv_d;
490
0
    query_fac.c34 = inv_d * (delta * sum_qq + d * v_min);
491
492
0
    if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
493
        // precompute if needed
494
0
        query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
495
0
    }
496
0
}
497
498
// Creates a distance computer for codes produced by this quantizer. The
// caller takes ownership of the returned object.
//   qb == 0 -> the query stays in float (RaBitDistanceComputerNotQ)
//   qb  > 0 -> the query is quantized to qb bits (RaBitDistanceComputerQ)
FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
        uint8_t qb,
        const float* centroid_in) const {
    if (qb == 0) {
        auto computer = std::make_unique<RaBitDistanceComputerNotQ>();
        computer->metric_type = metric_type;
        computer->d = d;
        computer->centroid = centroid_in;
        return computer.release();
    }

    auto computer = std::make_unique<RaBitDistanceComputerQ>();
    computer->metric_type = metric_type;
    computer->d = d;
    computer->centroid = centroid_in;
    computer->qb = qb;
    return computer.release();
}
518
519
} // namespace faiss