/root/doris/contrib/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexAdditiveQuantizerFastScan.h> |
9 | | |
10 | | #include <cassert> |
11 | | #include <memory> |
12 | | |
13 | | #include <faiss/impl/FaissAssert.h> |
14 | | #include <faiss/impl/LocalSearchQuantizer.h> |
15 | | #include <faiss/impl/LookupTableScaler.h> |
16 | | #include <faiss/impl/ResidualQuantizer.h> |
17 | | #include <faiss/impl/pq4_fast_scan.h> |
18 | | #include <faiss/utils/quantize_lut.h> |
19 | | #include <faiss/utils/utils.h> |
20 | | |
21 | | namespace faiss { |
22 | | |
23 | | inline size_t roundup(size_t a, size_t b) { |
24 | | return (a + b - 1) / b * b; |
25 | | } |
26 | | |
27 | | IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan( |
28 | | AdditiveQuantizer* aq, |
29 | | MetricType metric, |
30 | 0 | int bbs) { |
31 | 0 | init(aq, metric, bbs); |
32 | 0 | } |
33 | | |
34 | | void IndexAdditiveQuantizerFastScan::init( |
35 | | AdditiveQuantizer* aq_init, |
36 | | MetricType metric, |
37 | 0 | int bbs) { |
38 | 0 | FAISS_THROW_IF_NOT(aq_init != nullptr); |
39 | 0 | FAISS_THROW_IF_NOT(!aq_init->nbits.empty()); |
40 | 0 | FAISS_THROW_IF_NOT(aq_init->nbits[0] == 4); |
41 | 0 | if (metric == METRIC_INNER_PRODUCT) { |
42 | 0 | FAISS_THROW_IF_NOT_MSG( |
43 | 0 | aq_init->search_type == AdditiveQuantizer::ST_LUT_nonorm, |
44 | 0 | "Search type must be ST_LUT_nonorm for IP metric"); |
45 | 0 | } else { |
46 | 0 | FAISS_THROW_IF_NOT_MSG( |
47 | 0 | aq_init->search_type == AdditiveQuantizer::ST_norm_lsq2x4 || |
48 | 0 | aq_init->search_type == |
49 | 0 | AdditiveQuantizer::ST_norm_rq2x4, |
50 | 0 | "Search type must be lsq2x4 or rq2x4 for L2 metric"); |
51 | 0 | } |
52 | | |
53 | 0 | this->aq = aq_init; |
54 | 0 | if (metric == METRIC_L2) { |
55 | 0 | M = aq_init->M + 2; // 2x4 bits AQ |
56 | 0 | } else { |
57 | 0 | M = aq_init->M; |
58 | 0 | } |
59 | 0 | init_fastscan(aq_init->d, M, 4, metric, bbs); |
60 | |
|
61 | 0 | max_train_points = 1024 * ksub * M; |
62 | 0 | } |
63 | | |
64 | | IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan() |
65 | 0 | : IndexFastScan() { |
66 | 0 | is_trained = false; |
67 | 0 | aq = nullptr; |
68 | 0 | } |
69 | | |
70 | | IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan( |
71 | | const IndexAdditiveQuantizer& orig, |
72 | 0 | int bbs) { |
73 | 0 | init(orig.aq, orig.metric_type, bbs); |
74 | |
|
75 | 0 | ntotal = orig.ntotal; |
76 | 0 | is_trained = orig.is_trained; |
77 | 0 | orig_codes = orig.codes.data(); |
78 | |
|
79 | 0 | ntotal2 = roundup(ntotal, bbs); |
80 | 0 | codes.resize(ntotal2 * M2 / 2); |
81 | 0 | pq4_pack_codes(orig_codes, ntotal, M, ntotal2, bbs, M2, codes.get()); |
82 | 0 | } |
83 | | |
84 | 0 | IndexAdditiveQuantizerFastScan::~IndexAdditiveQuantizerFastScan() = default; |
85 | | |
86 | 0 | void IndexAdditiveQuantizerFastScan::train(idx_t n, const float* x_in) { |
87 | 0 | if (is_trained) { |
88 | 0 | return; |
89 | 0 | } |
90 | | |
91 | 0 | const int seed = 0x12345; |
92 | 0 | size_t nt = n; |
93 | 0 | const float* x = fvecs_maybe_subsample( |
94 | 0 | d, &nt, max_train_points, x_in, verbose, seed); |
95 | 0 | n = nt; |
96 | 0 | if (verbose) { |
97 | 0 | printf("training additive quantizer on %zd vectors\n", nt); |
98 | 0 | } |
99 | |
|
100 | 0 | aq->verbose = verbose; |
101 | 0 | aq->train(n, x); |
102 | 0 | if (metric_type == METRIC_L2) { |
103 | 0 | estimate_norm_scale(n, x); |
104 | 0 | } |
105 | |
|
106 | 0 | is_trained = true; |
107 | 0 | } |
108 | | |
109 | | void IndexAdditiveQuantizerFastScan::estimate_norm_scale( |
110 | | idx_t n, |
111 | 0 | const float* x_in) { |
112 | 0 | FAISS_THROW_IF_NOT(metric_type == METRIC_L2); |
113 | | |
114 | 0 | constexpr int seed = 0x980903; |
115 | 0 | constexpr size_t max_points_estimated = 65536; |
116 | 0 | size_t ns = n; |
117 | 0 | const float* x = fvecs_maybe_subsample( |
118 | 0 | d, &ns, max_points_estimated, x_in, verbose, seed); |
119 | 0 | n = ns; |
120 | 0 | std::unique_ptr<float[]> del_x; |
121 | 0 | if (x != x_in) { |
122 | 0 | del_x.reset((float*)x); |
123 | 0 | } |
124 | |
|
125 | 0 | std::vector<float> dis_tables(n * M * ksub); |
126 | 0 | compute_float_LUT(dis_tables.data(), n, x); |
127 | | |
128 | | // here we compute the mean of scales for each query |
129 | | // TODO: try max of scales |
130 | 0 | double scale = 0; |
131 | |
|
132 | 0 | #pragma omp parallel for reduction(+ : scale) |
133 | 0 | for (idx_t i = 0; i < n; i++) { |
134 | 0 | const float* lut = dis_tables.data() + i * M * ksub; |
135 | 0 | scale += quantize_lut::aq_estimate_norm_scale(M, ksub, 2, lut); |
136 | 0 | } |
137 | 0 | scale /= n; |
138 | 0 | norm_scale = (int)std::roundf(std::max(scale, 1.0)); |
139 | |
|
140 | 0 | if (verbose) { |
141 | 0 | printf("estimated norm scale: %lf\n", scale); |
142 | 0 | printf("rounded norm scale: %d\n", norm_scale); |
143 | 0 | } |
144 | 0 | } |
145 | | |
146 | | void IndexAdditiveQuantizerFastScan::compute_codes( |
147 | | uint8_t* tmp_codes, |
148 | | idx_t n, |
149 | 0 | const float* x) const { |
150 | 0 | aq->compute_codes(x, tmp_codes, n); |
151 | 0 | } |
152 | | |
153 | | void IndexAdditiveQuantizerFastScan::compute_float_LUT( |
154 | | float* lut, |
155 | | idx_t n, |
156 | 0 | const float* x) const { |
157 | 0 | if (metric_type == METRIC_INNER_PRODUCT) { |
158 | 0 | aq->compute_LUT(n, x, lut, 1.0f); |
159 | 0 | } else { |
160 | | // compute inner product look-up tables |
161 | 0 | const size_t ip_dim12 = aq->M * ksub; |
162 | 0 | const size_t norm_dim12 = 2 * ksub; |
163 | 0 | std::vector<float> ip_lut(n * ip_dim12); |
164 | 0 | aq->compute_LUT(n, x, ip_lut.data(), -2.0f); |
165 | | |
166 | | // copy and rescale norm look-up tables |
167 | 0 | auto norm_tabs = aq->norm_tabs; |
168 | 0 | if (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2) { |
169 | 0 | for (size_t i = 0; i < norm_tabs.size(); i++) { |
170 | 0 | norm_tabs[i] /= norm_scale; |
171 | 0 | } |
172 | 0 | } |
173 | 0 | const float* norm_lut = norm_tabs.data(); |
174 | 0 | FAISS_THROW_IF_NOT(norm_tabs.size() == norm_dim12); |
175 | | |
176 | | // combine them |
177 | 0 | for (idx_t i = 0; i < n; i++) { |
178 | 0 | memcpy(lut, ip_lut.data() + i * ip_dim12, ip_dim12 * sizeof(*lut)); |
179 | 0 | lut += ip_dim12; |
180 | 0 | memcpy(lut, norm_lut, norm_dim12 * sizeof(*lut)); |
181 | 0 | lut += norm_dim12; |
182 | 0 | } |
183 | 0 | } |
184 | 0 | } |
185 | | |
186 | | void IndexAdditiveQuantizerFastScan::search( |
187 | | idx_t n, |
188 | | const float* x, |
189 | | idx_t k, |
190 | | float* distances, |
191 | | idx_t* labels, |
192 | 0 | const SearchParameters* params) const { |
193 | 0 | FAISS_THROW_IF_NOT_MSG( |
194 | 0 | !params, "search params not supported for this index"); |
195 | 0 | FAISS_THROW_IF_NOT(k > 0); |
196 | 0 | bool rescale = (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2); |
197 | 0 | if (!rescale) { |
198 | 0 | IndexFastScan::search(n, x, k, distances, labels); |
199 | 0 | return; |
200 | 0 | } |
201 | | |
202 | 0 | NormTableScaler scaler(norm_scale); |
203 | 0 | if (metric_type == METRIC_L2) { |
204 | 0 | search_dispatch_implem<true>(n, x, k, distances, labels, &scaler); |
205 | 0 | } else { |
206 | 0 | search_dispatch_implem<false>(n, x, k, distances, labels, &scaler); |
207 | 0 | } |
208 | 0 | } |
209 | | |
210 | | void IndexAdditiveQuantizerFastScan::sa_decode( |
211 | | idx_t n, |
212 | | const uint8_t* bytes, |
213 | 0 | float* x) const { |
214 | 0 | aq->decode(bytes, x, n); |
215 | 0 | } |
216 | | |
217 | | /************************************************************************************** |
218 | | * IndexResidualQuantizerFastScan |
219 | | **************************************************************************************/ |
220 | | |
221 | | IndexResidualQuantizerFastScan::IndexResidualQuantizerFastScan( |
222 | | int d, ///< dimensionality of the input vectors |
223 | | size_t M, ///< number of subquantizers |
224 | | size_t nbits, ///< number of bit per subvector index |
225 | | MetricType metric, |
226 | | Search_type_t search_type, |
227 | | int bbs) |
228 | 0 | : rq(d, M, nbits, search_type) { |
229 | 0 | init(&rq, metric, bbs); |
230 | 0 | } |
231 | | |
232 | 0 | IndexResidualQuantizerFastScan::IndexResidualQuantizerFastScan() { |
233 | 0 | aq = &rq; |
234 | 0 | } |
235 | | |
236 | | /************************************************************************************** |
237 | | * IndexLocalSearchQuantizerFastScan |
238 | | **************************************************************************************/ |
239 | | |
240 | | IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan( |
241 | | int d, |
242 | | size_t M, ///< number of subquantizers |
243 | | size_t nbits, ///< number of bit per subvector index |
244 | | MetricType metric, |
245 | | Search_type_t search_type, |
246 | | int bbs) |
247 | 0 | : lsq(d, M, nbits, search_type) { |
248 | 0 | init(&lsq, metric, bbs); |
249 | 0 | } |
250 | | |
251 | 0 | IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan() { |
252 | 0 | aq = &lsq; |
253 | 0 | } |
254 | | |
255 | | /************************************************************************************** |
256 | | * IndexProductResidualQuantizerFastScan |
257 | | **************************************************************************************/ |
258 | | |
259 | | IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan( |
260 | | int d, ///< dimensionality of the input vectors |
261 | | size_t nsplits, ///< number of residual quantizers |
262 | | size_t Msub, ///< number of subquantizers per RQ |
263 | | size_t nbits, ///< number of bit per subvector index |
264 | | MetricType metric, |
265 | | Search_type_t search_type, |
266 | | int bbs) |
267 | 0 | : prq(d, nsplits, Msub, nbits, search_type) { |
268 | 0 | init(&prq, metric, bbs); |
269 | 0 | } |
270 | | |
271 | 0 | IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan() { |
272 | 0 | aq = &prq; |
273 | 0 | } |
274 | | |
275 | | /************************************************************************************** |
276 | | * IndexProductLocalSearchQuantizerFastScan |
277 | | **************************************************************************************/ |
278 | | |
279 | | IndexProductLocalSearchQuantizerFastScan:: |
280 | | IndexProductLocalSearchQuantizerFastScan( |
281 | | int d, ///< dimensionality of the input vectors |
282 | | size_t nsplits, ///< number of local search quantizers |
283 | | size_t Msub, ///< number of subquantizers per LSQ |
284 | | size_t nbits, ///< number of bit per subvector index |
285 | | MetricType metric, |
286 | | Search_type_t search_type, |
287 | | int bbs) |
288 | 0 | : plsq(d, nsplits, Msub, nbits, search_type) { |
289 | 0 | init(&plsq, metric, bbs); |
290 | 0 | } |
291 | | |
292 | | IndexProductLocalSearchQuantizerFastScan:: |
293 | 0 | IndexProductLocalSearchQuantizerFastScan() { |
294 | 0 | aq = &plsq; |
295 | 0 | } |
296 | | |
297 | | } // namespace faiss |