Coverage Report

Created: 2025-12-30 21:06

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/AdditiveQuantizer.h
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#pragma once
9
10
#include <cmath>
11
#include <cstdint>
12
#include <vector>
13
14
#include <faiss/Index.h>
15
#include <faiss/IndexFlat.h>
16
#include <faiss/impl/Quantizer.h>
17
18
namespace faiss {
19
20
/** Abstract structure for additive quantizers
21
 *
22
 * Unlike the product quantizer, in which the decoded vector is the
23
 * concatenation of M sub-vectors, additive quantizers sum M sub-vectors
24
 * to get the decoded vector.
25
 */
26
struct AdditiveQuantizer : Quantizer {
27
    size_t M;                     ///< number of codebooks
28
    std::vector<size_t> nbits;    ///< bits for each step
29
    std::vector<float> codebooks; ///< codebooks
30
31
    // derived values
32
    /// codebook #1 is stored in rows codebook_offsets[i]:codebook_offsets[i+1]
33
    /// in the codebooks table of size total_codebook_size by d
34
    std::vector<uint64_t> codebook_offsets;
35
    size_t tot_bits = 0;            ///< total number of bits (indexes + norms)
36
    size_t norm_bits = 0;           ///< bits allocated for the norms
37
    size_t total_codebook_size = 0; ///< size of the codebook in vectors
38
    bool only_8bit = false;         ///< are all nbits = 8 (use faster decoder)
39
40
    bool verbose = false;    ///< verbose during training?
41
    bool is_trained = false; ///< is trained or not
42
43
    /// auxiliary data for ST_norm_lsq2x4 and ST_norm_rq2x4
44
    /// store norms of codebook entries for 4-bit fastscan
45
    std::vector<float> norm_tabs;
46
    IndexFlat1D qnorm; ///< store and search norms
47
48
    void compute_codebook_tables();
49
50
    /// norms of all codebook entries (size total_codebook_size)
51
    std::vector<float> centroid_norms;
52
53
    /// dot products of all codebook entries with the previous codebooks
54
    /// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
55
    std::vector<float> codebook_cross_products;
56
57
    /// norms and distance matrixes with beam search can get large, so use this
58
    /// to control for the amount of memory that can be allocated
59
    size_t max_mem_distances = 5 * (size_t(1) << 30);
60
61
    /// encode a norm into norm_bits bits
62
    uint64_t encode_norm(float norm) const;
63
64
    /// encode norm by non-uniform scalar quantization
65
    uint32_t encode_qcint(float x) const;
66
67
    /// decode norm by non-uniform scalar quantization
68
    float decode_qcint(uint32_t c) const;
69
70
    /// Encodes how search is performed and how vectors are encoded
71
    enum Search_type_t {
72
        ST_decompress,    ///< decompress database vector
73
        ST_LUT_nonorm,    ///< use a LUT, don't include norms (OK for IP or
74
                          ///< normalized vectors)
75
        ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost
76
                          ///< is in O(M^2))
77
        ST_norm_float, ///< use a LUT, and store float32 norm with the vectors
78
        ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm
79
        ST_norm_qint4,
80
        ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
81
        ST_norm_cqint4,
82
83
        ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast
84
                        ///< scan)
85
        ST_norm_rq2x4,  ///< use a 2x4 bits rq as norm quantizer (for fast scan)
86
    };
87
88
    AdditiveQuantizer(
89
            size_t d,
90
            const std::vector<size_t>& nbits,
91
            Search_type_t search_type = ST_decompress);
92
93
    AdditiveQuantizer();
94
95
    ///< compute derived values when d, M and nbits have been set
96
    void set_derived_values();
97
98
    ///< Train the norm quantizer
99
    void train_norm(size_t n, const float* norms);
100
101
    void compute_codes(const float* x, uint8_t* codes, size_t n)
102
0
            const override {
103
0
        compute_codes_add_centroids(x, codes, n);
104
0
    }
105
106
    /** Encode a set of vectors
107
     *
108
     * @param x      vectors to encode, size n * d
109
     * @param codes  output codes, size n * code_size
110
     * @param centroids  centroids to be added to x, size n * d
111
     */
112
    virtual void compute_codes_add_centroids(
113
            const float* x,
114
            uint8_t* codes,
115
            size_t n,
116
            const float* centroids = nullptr) const = 0;
117
118
    /** pack a series of code to bit-compact format
119
     *
120
     * @param codes        codes to be packed, size n * code_size
121
     * @param packed_codes output bit-compact codes
122
     * @param ld_codes     leading dimension of codes
123
     * @param norms        norms of the vectors (size n). Will be computed if
124
     *                     needed but not provided
125
     * @param centroids    centroids to be added to x, size n * d
126
     */
127
    void pack_codes(
128
            size_t n,
129
            const int32_t* codes,
130
            uint8_t* packed_codes,
131
            int64_t ld_codes = -1,
132
            const float* norms = nullptr,
133
            const float* centroids = nullptr) const;
134
135
    /** Decode a set of vectors
136
     *
137
     * @param codes  codes to decode, size n * code_size
138
     * @param x      output vectors, size n * d
139
     */
140
    void decode(const uint8_t* codes, float* x, size_t n) const override;
141
142
    /** Decode a set of vectors in non-packed format
143
     *
144
     * @param codes  codes to decode, size n * ld_codes
145
     * @param x      output vectors, size n * d
146
     */
147
    virtual void decode_unpacked(
148
            const int32_t* codes,
149
            float* x,
150
            size_t n,
151
            int64_t ld_codes = -1) const;
152
153
    /****************************************************************************
154
     * Search functions in an external set of codes.
155
     ****************************************************************************/
156
157
    /// Also determines what's in the codes
158
    Search_type_t search_type;
159
160
    /// min/max for quantization of norms
161
    float norm_min = NAN, norm_max = NAN;
162
163
    template <bool is_IP, Search_type_t effective_search_type>
164
    float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const;
165
166
    /*
167
        float compute_1_L2sqr(const uint8_t* codes, const float* LUT);
168
    */
169
    /****************************************************************************
170
     * Support for exhaustive distance computations with all the centroids.
171
     * Hence, the number of these centroids should not be too large.
172
     ****************************************************************************/
173
174
    /// decoding function for a code in a 64-bit word
175
    void decode_64bit(idx_t n, float* x) const;
176
177
    /** Compute inner-product look-up tables. Used in the centroid search
178
     * functions.
179
     *
180
     * @param xq     query vector, size (n, d)
181
     * @param LUT    look-up table, size (n, total_codebook_size)
182
     * @param alpha  compute alpha * inner-product
183
     * @param ld_lut  leading dimension of LUT
184
     */
185
    virtual void compute_LUT(
186
            size_t n,
187
            const float* xq,
188
            float* LUT,
189
            float alpha = 1.0f,
190
            long ld_lut = -1) const;
191
192
    /// exact IP search
193
    void knn_centroids_inner_product(
194
            idx_t n,
195
            const float* xq,
196
            idx_t k,
197
            float* distances,
198
            idx_t* labels) const;
199
200
    /** For L2 search we need the L2 norms of the centroids
201
     *
202
     * @param norms    output norms table, size total_codebook_size
203
     */
204
    void compute_centroid_norms(float* norms) const;
205
206
    /** Exact L2 search, with precomputed norms */
207
    void knn_centroids_L2(
208
            idx_t n,
209
            const float* xq,
210
            idx_t k,
211
            float* distances,
212
            idx_t* labels,
213
            const float* centroid_norms) const;
214
215
    virtual ~AdditiveQuantizer();
216
};
217
218
} // namespace faiss