/root/doris/contrib/faiss/faiss/impl/AdditiveQuantizer.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
10 | | #include <cmath> |
11 | | #include <cstdint> |
12 | | #include <vector> |
13 | | |
14 | | #include <faiss/Index.h> |
15 | | #include <faiss/IndexFlat.h> |
16 | | #include <faiss/impl/Quantizer.h> |
17 | | |
18 | | namespace faiss { |
19 | | |
20 | | /** Abstract structure for additive quantizers |
21 | | * |
22 | | * Unlike the product quantizer, in which the decoded vector is the |
23 | | * concatenation of M sub-vectors, additive quantizers sum M sub-vectors |
24 | | * to get the decoded vector. |
25 | | */ |
26 | | struct AdditiveQuantizer : Quantizer { |
27 | | size_t M; ///< number of codebooks |
28 | | std::vector<size_t> nbits; ///< bits for each step (indexed as nbits[m], m=0..M-1) |
29 | | std::vector<float> codebooks; ///< codebook entries, a table of total_codebook_size rows by d |
30 | | |
31 | | // derived values |
32 | | /// codebook #i is stored in rows codebook_offsets[i]:codebook_offsets[i+1] |
33 | | /// in the codebooks table of size total_codebook_size by d |
34 | | std::vector<uint64_t> codebook_offsets; |
35 | | size_t tot_bits = 0; ///< total number of bits (indexes + norms) |
36 | | size_t norm_bits = 0; ///< bits allocated for the norms |
37 | | size_t total_codebook_size = 0; ///< size of the codebook in vectors |
38 | | bool only_8bit = false; ///< are all nbits = 8 (use faster decoder) |
39 | | |
40 | | bool verbose = false; ///< verbose during training? |
41 | | bool is_trained = false; ///< is trained or not |
42 | | |
43 | | /// auxiliary data for ST_norm_lsq2x4 and ST_norm_rq2x4 |
44 | | /// store norms of codebook entries for 4-bit fastscan |
45 | | std::vector<float> norm_tabs; |
46 | | IndexFlat1D qnorm; ///< store and search norms |
47 | | |
48 | | void compute_codebook_tables(); /* NOTE(review): presumably fills centroid_norms and codebook_cross_products declared below; confirm in the .cpp */ |
49 | | |
50 | | /// norms of all codebook entries (size total_codebook_size) |
51 | | std::vector<float> centroid_norms; |
52 | | |
53 | | /// dot products of all codebook entries with the previous codebooks |
54 | | /// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1) |
55 | | std::vector<float> codebook_cross_products; |
56 | | |
57 | | /// norms and distance matrices with beam search can get large, so use this |
58 | | /// to control the amount of memory that can be allocated (default: 5 * 2^30) |
59 | | size_t max_mem_distances = 5 * (size_t(1) << 30); |
60 | | |
61 | | /// encode a norm into norm_bits bits |
62 | | uint64_t encode_norm(float norm) const; |
63 | | |
64 | | /// encode norm by non-uniform scalar quantization |
65 | | uint32_t encode_qcint(float x) const; |
66 | | |
67 | | /// decode norm by non-uniform scalar quantization |
68 | | float decode_qcint(uint32_t c) const; |
69 | | |
70 | | /// Encodes how search is performed and how vectors are encoded |
71 | | enum Search_type_t { |
72 | | ST_decompress, ///< decompress database vector |
73 | | ST_LUT_nonorm, ///< use a LUT, don't include norms (OK for IP or |
74 | | ///< normalized vectors) |
75 | | ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost |
76 | | ///< is in O(M^2)) |
77 | | ST_norm_float, ///< use a LUT, and store float32 norm with the vectors |
78 | | ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm |
79 | | ST_norm_qint4, /**< presumably the 4-bit counterpart of ST_norm_qint8 (TODO confirm) */ |
80 | | ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm |
81 | | ST_norm_cqint4, /**< presumably the 4-bit counterpart of ST_norm_cqint8 (TODO confirm) */ |
82 | | |
83 | | ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast |
84 | | ///< scan) |
85 | | ST_norm_rq2x4, ///< use a 2x4 bits rq as norm quantizer (for fast scan) |
86 | | }; |
87 | | |
88 | | AdditiveQuantizer( |
89 | | size_t d, |
90 | | const std::vector<size_t>& nbits, |
91 | | Search_type_t search_type = ST_decompress); |
92 | | |
93 | | AdditiveQuantizer(); |
94 | | |
95 | | /// compute derived values when d, M and nbits have been set |
96 | | void set_derived_values(); |
97 | | |
98 | | /// Train the norm quantizer |
99 | | void train_norm(size_t n, const float* norms); |
100 | | |
101 | | void compute_codes(const float* x, uint8_t* codes, size_t n) |
102 | 0 | const override { |
103 | 0 | compute_codes_add_centroids(x, codes, n); /* centroids is left at its nullptr default */ |
104 | 0 | } |
105 | | |
106 | | /** Encode a set of vectors |
107 | | * |
108 | | * @param x vectors to encode, size n * d |
109 | | * @param codes output codes, size n * code_size |
110 | | * @param centroids centroids to be added to x, size n * d (optional, defaults to nullptr) |
111 | | */ |
112 | | virtual void compute_codes_add_centroids( |
113 | | const float* x, |
114 | | uint8_t* codes, |
115 | | size_t n, |
116 | | const float* centroids = nullptr) const = 0; |
117 | | |
118 | | /** pack a series of codes into bit-compact format |
119 | | * |
120 | | * @param codes codes to be packed, size n * code_size (TODO confirm: int32 input with ld_codes suggests n * ld_codes) |
121 | | * @param packed_codes output bit-compact codes |
122 | | * @param ld_codes leading dimension of codes |
123 | | * @param norms norms of the vectors (size n). Will be computed if |
124 | | * needed but not provided |
125 | | * @param centroids centroids to be added to x, size n * d |
126 | | */ |
127 | | void pack_codes( |
128 | | size_t n, |
129 | | const int32_t* codes, |
130 | | uint8_t* packed_codes, |
131 | | int64_t ld_codes = -1, |
132 | | const float* norms = nullptr, |
133 | | const float* centroids = nullptr) const; |
134 | | |
135 | | /** Decode a set of vectors |
136 | | * |
137 | | * @param codes codes to decode, size n * code_size |
138 | | * @param x output vectors, size n * d |
139 | | */ |
140 | | void decode(const uint8_t* codes, float* x, size_t n) const override; |
141 | | |
142 | | /** Decode a set of vectors in non-packed format |
143 | | * |
144 | | * @param codes codes to decode, size n * ld_codes |
145 | | * @param x output vectors, size n * d |
146 | | */ |
147 | | virtual void decode_unpacked( |
148 | | const int32_t* codes, |
149 | | float* x, |
150 | | size_t n, |
151 | | int64_t ld_codes = -1) const; |
152 | | |
153 | | /**************************************************************************** |
154 | | * Search functions in an external set of codes. |
155 | | ****************************************************************************/ |
156 | | |
157 | | /// Also determines what's in the codes |
158 | | Search_type_t search_type; |
159 | | |
160 | | /// min/max for quantization of norms (both initialized to NAN) |
161 | | float norm_min = NAN, norm_max = NAN; |
162 | | |
163 | | template <bool is_IP, Search_type_t effective_search_type> |
164 | | float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const; |
165 | | |
166 | | /* |
167 | | float compute_1_L2sqr(const uint8_t* codes, const float* LUT); |
168 | | */ |
169 | | /**************************************************************************** |
170 | | * Support for exhaustive distance computations with all the centroids. |
171 | | * Hence, the number of these centroids should not be too large. |
172 | | ****************************************************************************/ |
173 | | |
174 | | /// decoding function for a code in a 64-bit word |
175 | | void decode_64bit(idx_t n, float* x) const; |
176 | | |
177 | | /** Compute inner-product look-up tables. Used in the centroid search |
178 | | * functions. |
179 | | * |
180 | | * @param xq query vectors, size (n, d) |
181 | | * @param LUT look-up table, size (n, total_codebook_size) |
182 | | * @param alpha compute alpha * inner-product |
183 | | * @param ld_lut leading dimension of LUT |
184 | | */ |
185 | | virtual void compute_LUT( |
186 | | size_t n, |
187 | | const float* xq, |
188 | | float* LUT, |
189 | | float alpha = 1.0f, |
190 | | long ld_lut = -1) const; |
191 | | |
192 | | /// exact IP search |
193 | | void knn_centroids_inner_product( |
194 | | idx_t n, |
195 | | const float* xq, |
196 | | idx_t k, |
197 | | float* distances, |
198 | | idx_t* labels) const; |
199 | | |
200 | | /** For L2 search we need the L2 norms of the centroids |
201 | | * |
202 | | * @param norms output norms table, size total_codebook_size |
203 | | */ |
204 | | void compute_centroid_norms(float* norms) const; |
205 | | |
206 | | /** Exact L2 search, with precomputed norms (the centroid_norms parameter, which has the same name as the centroid_norms member) */ |
207 | | void knn_centroids_L2( |
208 | | idx_t n, |
209 | | const float* xq, |
210 | | idx_t k, |
211 | | float* distances, |
212 | | idx_t* labels, |
213 | | const float* centroid_norms) const; |
214 | | |
215 | | virtual ~AdditiveQuantizer(); |
216 | | }; |
217 | | |
218 | | } // namespace faiss |