/root/doris/contrib/faiss/faiss/impl/LocalSearchQuantizer.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
10 | | #include <stdint.h> |
11 | | |
12 | | #include <random> |
13 | | #include <string> |
14 | | #include <unordered_map> |
15 | | #include <vector> |
16 | | |
17 | | #include <faiss/impl/AdditiveQuantizer.h> |
18 | | #include <faiss/impl/platform_macros.h> |
19 | | #include <faiss/utils/utils.h> |
20 | | |
21 | | namespace faiss { |
22 | | |
namespace lsq {

// Forward declaration only: LocalSearchQuantizer (below) stores a pointer to
// the factory, and the full definition appears later in this header.
struct IcmEncoderFactory;

} // namespace lsq
28 | | |
29 | | /** Implementation of LSQ/LSQ++ described in the following two papers: |
30 | | * |
31 | | * Revisiting additive quantization |
32 | | * Julieta Martinez, et al. ECCV 2016 |
33 | | * |
34 | | * LSQ++: Lower running time and higher recall in multi-codebook quantization |
35 | | * Julieta Martinez, et al. ECCV 2018 |
36 | | * |
37 | | * This implementation is mostly translated from the Julia implementations |
38 | | * by Julieta Martinez: |
39 | | * (https://github.com/una-dinosauria/local-search-quantization, |
40 | | * https://github.com/una-dinosauria/Rayuela.jl) |
41 | | * |
42 | | * The trained codes are stored in `codebooks` which is called |
43 | | * `centroids` in PQ and RQ. |
44 | | */ |
45 | | struct LocalSearchQuantizer : AdditiveQuantizer { |
46 | | size_t K; ///< number of codes per codebook |
47 | | |
48 | | size_t train_iters = 25; ///< number of iterations in training |
49 | | size_t encode_ils_iters = 16; ///< iterations of local search in encoding |
50 | | size_t train_ils_iters = 8; ///< iterations of local search in training |
51 | | size_t icm_iters = 4; ///< number of iterations in icm |
52 | | |
53 | | float p = 0.5f; ///< temperature factor |
54 | | float lambd = 1e-2f; ///< regularization factor |
55 | | |
56 | | size_t chunk_size = 10000; ///< nb of vectors to encode at a time |
57 | | |
58 | | int random_seed = 0x12345; ///< seed for random generator |
59 | | size_t nperts = 4; ///< number of perturbation in each code |
60 | | |
61 | | ///< if non-NULL, use this encoder to encode (owned by the object) |
62 | | lsq::IcmEncoderFactory* icm_encoder_factory = nullptr; |
63 | | |
64 | | bool update_codebooks_with_double = true; |
65 | | |
66 | | LocalSearchQuantizer( |
67 | | size_t d, /* dimensionality of the input vectors */ |
68 | | size_t M, /* number of subquantizers */ |
69 | | size_t nbits, /* number of bit per subvector index */ |
70 | | Search_type_t search_type = |
71 | | ST_decompress); /* determines the storage type */ |
72 | | |
73 | | LocalSearchQuantizer(); |
74 | | |
75 | | ~LocalSearchQuantizer() override; |
76 | | |
77 | | // Train the local search quantizer |
78 | | void train(size_t n, const float* x) override; |
79 | | |
80 | | /** Encode a set of vectors |
81 | | * |
82 | | * @param x vectors to encode, size n * d |
83 | | * @param codes output codes, size n * code_size |
84 | | * @param n number of vectors |
85 | | * @param centroids centroids to be added to x, size n * d |
86 | | */ |
87 | | void compute_codes_add_centroids( |
88 | | const float* x, |
89 | | uint8_t* codes, |
90 | | size_t n, |
91 | | const float* centroids = nullptr) const override; |
92 | | |
93 | | /** Update codebooks given encodings |
94 | | * |
95 | | * @param x training vectors, size n * d |
96 | | * @param codes encoded training vectors, size n * M |
97 | | * @param n number of vectors |
98 | | */ |
99 | | void update_codebooks(const float* x, const int32_t* codes, size_t n); |
100 | | |
101 | | /** Encode vectors given codebooks using iterative conditional mode (icm). |
102 | | * |
103 | | * @param codes output codes, size n * M |
104 | | * @param x vectors to encode, size n * d |
105 | | * @param n number of vectors |
106 | | * @param ils_iters number of iterations of iterative local search |
107 | | */ |
108 | | void icm_encode( |
109 | | int32_t* codes, |
110 | | const float* x, |
111 | | size_t n, |
112 | | size_t ils_iters, |
113 | | std::mt19937& gen) const; |
114 | | |
115 | | void icm_encode_impl( |
116 | | int32_t* codes, |
117 | | const float* x, |
118 | | const float* unaries, |
119 | | std::mt19937& gen, |
120 | | size_t n, |
121 | | size_t ils_iters, |
122 | | bool verbose) const; |
123 | | |
124 | | void icm_encode_step( |
125 | | int32_t* codes, |
126 | | const float* unaries, |
127 | | const float* binaries, |
128 | | size_t n, |
129 | | size_t n_iters) const; |
130 | | |
131 | | /** Add some perturbation to codes |
132 | | * |
133 | | * @param codes codes to be perturbed, size n * M |
134 | | * @param n number of vectors |
135 | | */ |
136 | | void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const; |
137 | | |
138 | | /** Add some perturbation to codebooks |
139 | | * |
140 | | * @param T temperature of simulated annealing |
141 | | * @param stddev standard derivations of each dimension in training data |
142 | | */ |
143 | | void perturb_codebooks( |
144 | | float T, |
145 | | const std::vector<float>& stddev, |
146 | | std::mt19937& gen); |
147 | | |
148 | | /** Compute binary terms |
149 | | * |
150 | | * @param binaries binary terms, size M * M * K * K |
151 | | */ |
152 | | void compute_binary_terms(float* binaries) const; |
153 | | |
154 | | /** Compute unary terms |
155 | | * |
156 | | * @param n number of vectors |
157 | | * @param x vectors to encode, size n * d |
158 | | * @param unaries unary terms, size n * M * K |
159 | | */ |
160 | | void compute_unary_terms(const float* x, float* unaries, size_t n) const; |
161 | | |
162 | | /** Helper function to compute reconstruction error |
163 | | * |
164 | | * @param codes encoded codes, size n * M |
165 | | * @param x vectors to encode, size n * d |
166 | | * @param n number of vectors |
167 | | * @param objs if it is not null, store reconstruction |
168 | | error of each vector into it, size n |
169 | | */ |
170 | | float evaluate( |
171 | | const int32_t* codes, |
172 | | const float* x, |
173 | | size_t n, |
174 | | float* objs = nullptr) const; |
175 | | }; |
176 | | |
177 | | namespace lsq { |
178 | | |
179 | | struct IcmEncoder { |
180 | | std::vector<float> binaries; |
181 | | |
182 | | bool verbose; |
183 | | |
184 | | const LocalSearchQuantizer* lsq; |
185 | | |
186 | | explicit IcmEncoder(const LocalSearchQuantizer* lsq); |
187 | | |
188 | 0 | virtual ~IcmEncoder() {} |
189 | | |
190 | | ///< compute binary terms |
191 | | virtual void set_binary_term(); |
192 | | |
193 | | /** Encode vectors given codebooks |
194 | | * |
195 | | * @param codes output codes, size n * M |
196 | | * @param x vectors to encode, size n * d |
197 | | * @param gen random generator |
198 | | * @param n number of vectors |
199 | | * @param ils_iters number of iterations of iterative local search |
200 | | */ |
201 | | virtual void encode( |
202 | | int32_t* codes, |
203 | | const float* x, |
204 | | std::mt19937& gen, |
205 | | size_t n, |
206 | | size_t ils_iters) const; |
207 | | }; |
208 | | |
209 | | struct IcmEncoderFactory { |
210 | 0 | virtual IcmEncoder* get(const LocalSearchQuantizer* lsq) { |
211 | 0 | return new IcmEncoder(lsq); |
212 | 0 | } |
213 | 0 | virtual ~IcmEncoderFactory() {} |
214 | | }; |
215 | | |
/** Helper that keeps track of the time spent during training.
 * It is NOT thread-safe.
 */
struct LSQTimer {
    /// one double per timer name
    std::unordered_map<std::string, double> t;

    /// Construction simply delegates to reset().
    LSQTimer() {
        this->reset();
    }

    /// Value stored for `name` (NOTE(review): the result for a name that was
    /// never added is decided by the out-of-line definition -- confirm).
    double get(const std::string& name);

    /// Record `delta` under `name` (NOTE(review): presumably accumulated onto
    /// the existing entry -- confirm in the implementation file).
    void add(const std::string& name, double delta);

    /// Reset the timer state; also invoked by the constructor.
    void reset();
};
232 | | |
233 | | struct LSQTimerScope { |
234 | | double t0; |
235 | | LSQTimer* timer; |
236 | | std::string name; |
237 | | bool finished; |
238 | | |
239 | | LSQTimerScope(LSQTimer* timer, std::string name); |
240 | | |
241 | | void finish(); |
242 | | |
243 | | ~LSQTimerScope(); |
244 | | }; |
245 | | |
246 | | } // namespace lsq |
247 | | |
248 | | } // namespace faiss |