Coverage Report

Created: 2025-10-13 05:37

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexAdditiveQuantizerFastScan.h>
9
10
#include <cassert>
11
#include <memory>
12
13
#include <faiss/impl/FaissAssert.h>
14
#include <faiss/impl/LocalSearchQuantizer.h>
15
#include <faiss/impl/LookupTableScaler.h>
16
#include <faiss/impl/ResidualQuantizer.h>
17
#include <faiss/impl/pq4_fast_scan.h>
18
#include <faiss/utils/quantize_lut.h>
19
#include <faiss/utils/utils.h>
20
21
namespace faiss {
22
23
/// Round `a` up to the closest multiple of `b` that is >= `a`.
/// `b` is used as a divisor, so it must not be zero.
inline size_t roundup(size_t a, size_t b) {
    const size_t num_blocks = (a + b - 1) / b;
    return num_blocks * b;
}
26
27
/// Construct the index around an existing additive quantizer.
/// All validation and set-up work is delegated to init().
IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
        AdditiveQuantizer* aq,
        MetricType metric,
        int bbs) {
    this->init(aq, metric, bbs);
}
33
34
void IndexAdditiveQuantizerFastScan::init(
35
        AdditiveQuantizer* aq_init,
36
        MetricType metric,
37
0
        int bbs) {
38
0
    FAISS_THROW_IF_NOT(aq_init != nullptr);
39
0
    FAISS_THROW_IF_NOT(!aq_init->nbits.empty());
40
0
    FAISS_THROW_IF_NOT(aq_init->nbits[0] == 4);
41
0
    if (metric == METRIC_INNER_PRODUCT) {
42
0
        FAISS_THROW_IF_NOT_MSG(
43
0
                aq_init->search_type == AdditiveQuantizer::ST_LUT_nonorm,
44
0
                "Search type must be ST_LUT_nonorm for IP metric");
45
0
    } else {
46
0
        FAISS_THROW_IF_NOT_MSG(
47
0
                aq_init->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
48
0
                        aq_init->search_type ==
49
0
                                AdditiveQuantizer::ST_norm_rq2x4,
50
0
                "Search type must be lsq2x4 or rq2x4 for L2 metric");
51
0
    }
52
53
0
    this->aq = aq_init;
54
0
    if (metric == METRIC_L2) {
55
0
        M = aq_init->M + 2; // 2x4 bits AQ
56
0
    } else {
57
0
        M = aq_init->M;
58
0
    }
59
0
    init_fastscan(aq_init->d, M, 4, metric, bbs);
60
61
0
    max_train_points = 1024 * ksub * M;
62
0
}
63
64
/// Default constructor: the index starts untrained and with no quantizer
/// attached.
IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan()
        : IndexFastScan() {
    aq = nullptr;
    is_trained = false;
}
69
70
/// Build a fast-scan index from an existing IndexAdditiveQuantizer: reuse its
/// quantizer and metric, copy its counters and repack its codes.
///
/// NOTE(review): orig_codes keeps pointing into orig.codes after
/// construction; presumably `orig` must outlive any later use of that
/// pointer - confirm against the callers.
IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
        const IndexAdditiveQuantizer& orig,
        int bbs) {
    init(orig.aq, orig.metric_type, bbs);

    // mirror the bookkeeping of the source index
    is_trained = orig.is_trained;
    ntotal = orig.ntotal;

    // the packed storage holds a whole number of bbs-sized blocks
    ntotal2 = roundup(ntotal, bbs);
    codes.resize(ntotal2 * M2 / 2);

    orig_codes = orig.codes.data();
    pq4_pack_codes(orig_codes, ntotal, M, ntotal2, bbs, M2, codes.get());
}
83
84
0
/// Defaulted destructor: no explicit clean-up is performed here.
IndexAdditiveQuantizerFastScan::~IndexAdditiveQuantizerFastScan() = default;
85
86
0
/// Train the attached additive quantizer; does nothing if the index is
/// already trained.
///
/// At most max_train_points vectors are used (the input is subsampled with a
/// fixed seed when it is larger). For the L2 metric the norm scale is
/// estimated afterwards on the same training set.
///
/// @param n     number of training vectors in x_in
/// @param x_in  training vectors (presumably n * d floats - confirm)
void IndexAdditiveQuantizerFastScan::train(idx_t n, const float* x_in) {
    if (is_trained) {
        return;
    }

    const int seed = 0x12345;
    size_t nt = n;
    const float* x = fvecs_maybe_subsample(
            d, &nt, max_train_points, x_in, verbose, seed);
    // When subsampling took place, x is a separate buffer that this function
    // owns. Release it on every exit path (including exceptions thrown while
    // training), with the same handling as in estimate_norm_scale().
    std::unique_ptr<float[]> del_x;
    if (x != x_in) {
        del_x.reset(const_cast<float*>(x));
    }
    n = nt;
    if (verbose) {
        // nt is a size_t, hence %zu
        printf("training additive quantizer on %zu vectors\n", nt);
    }

    aq->verbose = verbose;
    aq->train(n, x);
    if (metric_type == METRIC_L2) {
        estimate_norm_scale(n, x);
    }

    is_trained = true;
}
108
109
void IndexAdditiveQuantizerFastScan::estimate_norm_scale(
110
        idx_t n,
111
0
        const float* x_in) {
112
0
    FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
113
114
0
    constexpr int seed = 0x980903;
115
0
    constexpr size_t max_points_estimated = 65536;
116
0
    size_t ns = n;
117
0
    const float* x = fvecs_maybe_subsample(
118
0
            d, &ns, max_points_estimated, x_in, verbose, seed);
119
0
    n = ns;
120
0
    std::unique_ptr<float[]> del_x;
121
0
    if (x != x_in) {
122
0
        del_x.reset((float*)x);
123
0
    }
124
125
0
    std::vector<float> dis_tables(n * M * ksub);
126
0
    compute_float_LUT(dis_tables.data(), n, x);
127
128
    // here we compute the mean of scales for each query
129
    // TODO: try max of scales
130
0
    double scale = 0;
131
132
0
#pragma omp parallel for reduction(+ : scale)
133
0
    for (idx_t i = 0; i < n; i++) {
134
0
        const float* lut = dis_tables.data() + i * M * ksub;
135
0
        scale += quantize_lut::aq_estimate_norm_scale(M, ksub, 2, lut);
136
0
    }
137
0
    scale /= n;
138
0
    norm_scale = (int)std::roundf(std::max(scale, 1.0));
139
140
0
    if (verbose) {
141
0
        printf("estimated norm scale: %lf\n", scale);
142
0
        printf("rounded norm scale: %d\n", norm_scale);
143
0
    }
144
0
}
145
146
void IndexAdditiveQuantizerFastScan::compute_codes(
147
        uint8_t* tmp_codes,
148
        idx_t n,
149
0
        const float* x) const {
150
0
    aq->compute_codes(x, tmp_codes, n);
151
0
}
152
153
void IndexAdditiveQuantizerFastScan::compute_float_LUT(
154
        float* lut,
155
        idx_t n,
156
0
        const float* x) const {
157
0
    if (metric_type == METRIC_INNER_PRODUCT) {
158
0
        aq->compute_LUT(n, x, lut, 1.0f);
159
0
    } else {
160
        // compute inner product look-up tables
161
0
        const size_t ip_dim12 = aq->M * ksub;
162
0
        const size_t norm_dim12 = 2 * ksub;
163
0
        std::vector<float> ip_lut(n * ip_dim12);
164
0
        aq->compute_LUT(n, x, ip_lut.data(), -2.0f);
165
166
        // copy and rescale norm look-up tables
167
0
        auto norm_tabs = aq->norm_tabs;
168
0
        if (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2) {
169
0
            for (size_t i = 0; i < norm_tabs.size(); i++) {
170
0
                norm_tabs[i] /= norm_scale;
171
0
            }
172
0
        }
173
0
        const float* norm_lut = norm_tabs.data();
174
0
        FAISS_THROW_IF_NOT(norm_tabs.size() == norm_dim12);
175
176
        // combine them
177
0
        for (idx_t i = 0; i < n; i++) {
178
0
            memcpy(lut, ip_lut.data() + i * ip_dim12, ip_dim12 * sizeof(*lut));
179
0
            lut += ip_dim12;
180
0
            memcpy(lut, norm_lut, norm_dim12 * sizeof(*lut));
181
0
            lut += norm_dim12;
182
0
        }
183
0
    }
184
0
}
185
186
void IndexAdditiveQuantizerFastScan::search(
187
        idx_t n,
188
        const float* x,
189
        idx_t k,
190
        float* distances,
191
        idx_t* labels,
192
0
        const SearchParameters* params) const {
193
0
    FAISS_THROW_IF_NOT_MSG(
194
0
            !params, "search params not supported for this index");
195
0
    FAISS_THROW_IF_NOT(k > 0);
196
0
    bool rescale = (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2);
197
0
    if (!rescale) {
198
0
        IndexFastScan::search(n, x, k, distances, labels);
199
0
        return;
200
0
    }
201
202
0
    NormTableScaler scaler(norm_scale);
203
0
    if (metric_type == METRIC_L2) {
204
0
        search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
205
0
    } else {
206
0
        search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
207
0
    }
208
0
}
209
210
void IndexAdditiveQuantizerFastScan::sa_decode(
211
        idx_t n,
212
        const uint8_t* bytes,
213
0
        float* x) const {
214
0
    aq->decode(bytes, x, n);
215
0
}
216
217
/**************************************************************************************
218
 * IndexResidualQuantizerFastScan
219
 **************************************************************************************/
220
221
/// Construct the embedded residual quantizer, then attach it via init().
///
/// @param d            dimensionality of the input vectors
/// @param M            number of subquantizers
/// @param nbits        number of bit per subvector index
/// @param metric       forwarded to init()
/// @param search_type  forwarded to the quantizer's constructor
/// @param bbs          forwarded to init()
IndexResidualQuantizerFastScan::IndexResidualQuantizerFastScan(
        int d,
        size_t M,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type,
        int bbs)
        : rq(d, M, nbits, search_type) {
    this->init(&this->rq, metric, bbs);
}
231
232
0
/// Default constructor: only points the base-class quantizer pointer at the
/// embedded residual quantizer.
IndexResidualQuantizerFastScan::IndexResidualQuantizerFastScan() {
    this->aq = &this->rq;
}
235
236
/**************************************************************************************
237
 * IndexLocalSearchQuantizerFastScan
238
 **************************************************************************************/
239
240
/// Construct the embedded local-search quantizer, then attach it via init().
///
/// @param d            first argument forwarded to the quantizer's constructor
/// @param M            number of subquantizers
/// @param nbits        number of bit per subvector index
/// @param metric       forwarded to init()
/// @param search_type  forwarded to the quantizer's constructor
/// @param bbs          forwarded to init()
IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan(
        int d,
        size_t M,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type,
        int bbs)
        : lsq(d, M, nbits, search_type) {
    this->init(&this->lsq, metric, bbs);
}
250
251
0
/// Default constructor: only points the base-class quantizer pointer at the
/// embedded local-search quantizer.
IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan() {
    this->aq = &this->lsq;
}
254
255
/**************************************************************************************
256
 * IndexProductResidualQuantizerFastScan
257
 **************************************************************************************/
258
259
/// Construct the embedded product residual quantizer, then attach it via
/// init().
///
/// @param d            dimensionality of the input vectors
/// @param nsplits      number of residual quantizers
/// @param Msub         number of subquantizers per RQ
/// @param nbits        number of bit per subvector index
/// @param metric       forwarded to init()
/// @param search_type  forwarded to the quantizer's constructor
/// @param bbs          forwarded to init()
IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan(
        int d,
        size_t nsplits,
        size_t Msub,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type,
        int bbs)
        : prq(d, nsplits, Msub, nbits, search_type) {
    this->init(&this->prq, metric, bbs);
}
270
271
0
/// Default constructor: only points the base-class quantizer pointer at the
/// embedded product residual quantizer.
IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan() {
    this->aq = &this->prq;
}
274
275
/**************************************************************************************
276
 * IndexProductLocalSearchQuantizerFastScan
277
 **************************************************************************************/
278
279
/// Construct the embedded product local-search quantizer, then attach it via
/// init().
///
/// @param d            dimensionality of the input vectors
/// @param nsplits      number of local search quantizers
/// @param Msub         number of subquantizers per LSQ
/// @param nbits        number of bit per subvector index
/// @param metric       forwarded to init()
/// @param search_type  forwarded to the quantizer's constructor
/// @param bbs          forwarded to init()
IndexProductLocalSearchQuantizerFastScan::
        IndexProductLocalSearchQuantizerFastScan(
                int d,
                size_t nsplits,
                size_t Msub,
                size_t nbits,
                MetricType metric,
                Search_type_t search_type,
                int bbs)
        : plsq(d, nsplits, Msub, nbits, search_type) {
    this->init(&this->plsq, metric, bbs);
}
291
292
/// Default constructor: only points the base-class quantizer pointer at the
/// embedded product local-search quantizer.
IndexProductLocalSearchQuantizerFastScan::
        IndexProductLocalSearchQuantizerFastScan() {
    this->aq = &this->plsq;
}
296
297
} // namespace faiss