/root/doris/contrib/faiss/faiss/IndexScalarQuantizer.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/IndexScalarQuantizer.h> |
11 | | |
12 | | #include <algorithm> |
13 | | #include <cstdio> |
14 | | |
15 | | #include <omp.h> |
16 | | |
17 | | #include <faiss/impl/FaissAssert.h> |
18 | | #include <faiss/impl/IDSelector.h> |
19 | | #include <faiss/impl/ScalarQuantizer.h> |
20 | | #include <faiss/utils/utils.h> |
21 | | |
22 | | namespace faiss { |
23 | | |
24 | | /******************************************************************* |
25 | | * IndexScalarQuantizer implementation |
26 | | ********************************************************************/ |
27 | | |
28 | | IndexScalarQuantizer::IndexScalarQuantizer( |
29 | | int d, |
30 | | ScalarQuantizer::QuantizerType qtype, |
31 | | MetricType metric) |
32 | 0 | : IndexFlatCodes(0, d, metric), sq(d, qtype) { |
33 | 0 | is_trained = qtype == ScalarQuantizer::QT_fp16 || |
34 | 0 | qtype == ScalarQuantizer::QT_8bit_direct || |
35 | 0 | qtype == ScalarQuantizer::QT_bf16 || |
36 | 0 | qtype == ScalarQuantizer::QT_8bit_direct_signed; |
37 | 0 | code_size = sq.code_size; |
38 | 0 | } |
39 | | |
// Default constructor: dimension 0, 8-bit uniform quantizer, default metric.
// Delegates to the main constructor; fields are expected to be filled in
// afterwards (e.g. by deserialization — TODO confirm against the I/O code).
IndexScalarQuantizer::IndexScalarQuantizer()
        : IndexScalarQuantizer(0, ScalarQuantizer::QT_8bit) {}
42 | | |
// Train the scalar quantizer statistics on n vectors of size d, then mark
// the index ready for add()/search().
void IndexScalarQuantizer::train(idx_t n, const float* x) {
    sq.train(n, x);
    is_trained = true;
}
47 | | |
/** Exhaustive kNN search over the encoded vectors.
 *
 * The whole flat code array is presented to an InvertedListScanner as a
 * single list (list_no = 0), so the returned offsets are directly the
 * vector ids.  Queries are processed in parallel; each OpenMP thread owns
 * its own scanner.
 *
 * @param n          number of query vectors
 * @param x          queries, size n * d
 * @param k          number of neighbors requested (must be > 0)
 * @param distances  output distances, size n * k
 * @param labels     output ids, size n * k
 * @param params     optional; only params->sel (id selector) is used here
 */
void IndexScalarQuantizer::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    const IDSelector* sel = params ? params->sel : nullptr;

    FAISS_THROW_IF_NOT(k > 0);
    FAISS_THROW_IF_NOT(is_trained);
    // only these two metrics are supported by the SQ scanners
    FAISS_THROW_IF_NOT(
            metric_type == METRIC_L2 || metric_type == METRIC_INNER_PRODUCT);

#pragma omp parallel
    {
        // per-thread scanner; the `true` argument presumably requests
        // store_pairs so ids come from offsets — TODO confirm against
        // select_InvertedListScanner's signature
        std::unique_ptr<InvertedListScanner> scanner(
                sq.select_InvertedListScanner(metric_type, nullptr, true, sel));

        scanner->list_no = 0; // directly the list number

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            float* D = distances + k * i;
            idx_t* I = labels + k * i;
            // initialize result heap: max-heap for L2 (keep k smallest),
            // min-heap for inner product (keep k largest)
            if (metric_type == METRIC_L2) {
                maxheap_heapify(k, D, I);
            } else {
                minheap_heapify(k, D, I);
            }
            scanner->set_query(x + i * d);
            // scan every stored code against the current query
            scanner->scan_codes(ntotal, codes.data(), nullptr, D, I, k);

            // sort the heap so results come out in rank order
            if (metric_type == METRIC_L2) {
                maxheap_reorder(k, D, I);
            } else {
                minheap_reorder(k, D, I);
            }
        }
    }
}
91 | | |
// Build a distance computer that works directly on the flat code array.
// NOTE(review): the raw pointer from sq.get_distance_computer appears to be
// heap-allocated, so ownership passes to the caller — confirm at call sites.
FlatCodesDistanceComputer* IndexScalarQuantizer::get_FlatCodesDistanceComputer()
        const {
    ScalarQuantizer::SQDistanceComputer* dc =
            sq.get_distance_computer(metric_type);
    // point the computer at this index's codes
    dc->code_size = sq.code_size;
    dc->codes = codes.data();
    return dc;
}
100 | | |
101 | | /* Codec interface */ |
102 | | |
// Codec interface: encode n vectors x (size n * d) into bytes
// (size n * code_size).  Requires a trained quantizer.
void IndexScalarQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
        const {
    FAISS_THROW_IF_NOT(is_trained);
    sq.compute_codes(x, bytes, n);
}
108 | | |
// Codec interface: decode n codes (size n * code_size) back into float
// vectors x (size n * d).  Requires a trained quantizer.
void IndexScalarQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
        const {
    FAISS_THROW_IF_NOT(is_trained);
    sq.decode(bytes, x, n);
}
114 | | |
115 | | /******************************************************************* |
116 | | * IndexIVFScalarQuantizer implementation |
117 | | ********************************************************************/ |
118 | | |
/** Construct an IVF index whose inverted-list codes are scalar-quantized.
 *
 * @param quantizer   coarse quantizer mapping vectors to lists (not owned
 *                    by default)
 * @param d           vector dimensionality
 * @param nlist       number of inverted lists
 * @param qtype       scalar quantizer type for the stored codes
 * @param metric      distance metric
 * @param by_residual if true, encode residuals w.r.t. the coarse centroid
 */
IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
        Index* quantizer,
        size_t d,
        size_t nlist,
        ScalarQuantizer::QuantizerType qtype,
        MetricType metric,
        bool by_residual)
        : IndexIVF(quantizer, d, nlist, 0, metric), sq(d, qtype) {
    code_size = sq.code_size;
    this->by_residual = by_residual;
    // was not known at construction time of the IndexIVF base (passed 0)
    invlists->code_size = code_size;
    is_trained = false;
}
133 | | |
// Default constructor; residual encoding is the default behavior.
IndexIVFScalarQuantizer::IndexIVFScalarQuantizer() : IndexIVF() {
    by_residual = true;
}
137 | | |
// Train the scalar quantizer on the training vectors.  `assign` (the coarse
// assignment per vector) is deliberately unused: the SQ only needs global
// per-dimension statistics — presumably the caller already passes residuals
// here when by_residual is set (TODO confirm in IndexIVF::train).
void IndexIVFScalarQuantizer::train_encoder(
        idx_t n,
        const float* x,
        const idx_t* assign) {
    sq.train(n, x);
}
144 | | |
145 | 0 | idx_t IndexIVFScalarQuantizer::train_encoder_num_vectors() const { |
146 | 0 | return 100000; |
147 | 0 | } |
148 | | |
/** Encode a batch of vectors into inverted-list codes.
 *
 * @param n                number of vectors
 * @param x                input vectors, size n * d
 * @param list_nos         coarse list assignment per vector; a negative
 *                         entry means the vector is skipped (its code slot
 *                         stays zeroed)
 * @param codes            output buffer, n * (code_size [+ listno bytes])
 * @param include_listnos  if true, prefix each code with the encoded list
 *                         number (for standalone sa_encode-style codes)
 */
void IndexIVFScalarQuantizer::encode_vectors(
        idx_t n,
        const float* x,
        const idx_t* list_nos,
        uint8_t* codes,
        bool include_listnos) const {
    std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
    size_t coarse_size = include_listnos ? coarse_code_size() : 0;
    // zero everything up front so skipped (-1) entries have defined content
    memset(codes, 0, (code_size + coarse_size) * n);

#pragma omp parallel if (n > 1000)
    {
        // per-thread scratch buffer for the residual
        std::vector<float> residual(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            int64_t list_no = list_nos[i];
            if (list_no >= 0) {
                const float* xi = x + i * d;
                uint8_t* code = codes + i * (code_size + coarse_size);
                if (by_residual) {
                    // encode the residual w.r.t. the coarse centroid
                    quantizer->compute_residual(xi, residual.data(), list_no);
                    xi = residual.data();
                }
                if (coarse_size) {
                    encode_listno(list_no, code);
                }
                squant->encode_vector(xi, code + coarse_size);
            }
        }
    }
}
181 | | |
// Codec interface: decode n standalone codes (list-number prefix followed by
// the SQ code) back into float vectors.  When by_residual is set, the coarse
// centroid of the encoded list is added back to each decoded residual.
void IndexIVFScalarQuantizer::sa_decode(idx_t n, const uint8_t* codes, float* x)
        const {
    std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
    size_t coarse_size = coarse_code_size();

#pragma omp parallel if (n > 1000)
    {
        // per-thread scratch buffer for the reconstructed centroid
        std::vector<float> residual(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* code = codes + i * (code_size + coarse_size);
            int64_t list_no = decode_listno(code);
            float* xi = x + i * d;
            squant->decode_vector(code + coarse_size, xi);
            if (by_residual) {
                quantizer->reconstruct(list_no, residual.data());
                for (size_t j = 0; j < d; j++) {
                    xi[j] += residual[j];
                }
            }
        }
    }
}
206 | | |
207 | | void IndexIVFScalarQuantizer::add_core( |
208 | | idx_t n, |
209 | | const float* x, |
210 | | const idx_t* xids, |
211 | | const idx_t* coarse_idx, |
212 | 0 | void* inverted_list_context) { |
213 | 0 | FAISS_THROW_IF_NOT(is_trained); |
214 | | |
215 | 0 | std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer()); |
216 | |
|
217 | 0 | DirectMapAdd dm_add(direct_map, n, xids); |
218 | |
|
219 | 0 | #pragma omp parallel |
220 | 0 | { |
221 | 0 | std::vector<float> residual(d); |
222 | 0 | std::vector<uint8_t> one_code(code_size); |
223 | 0 | int nt = omp_get_num_threads(); |
224 | 0 | int rank = omp_get_thread_num(); |
225 | | |
226 | | // each thread takes care of a subset of lists |
227 | 0 | for (size_t i = 0; i < n; i++) { |
228 | 0 | int64_t list_no = coarse_idx[i]; |
229 | 0 | if (list_no >= 0 && list_no % nt == rank) { |
230 | 0 | int64_t id = xids ? xids[i] : ntotal + i; |
231 | |
|
232 | 0 | const float* xi = x + i * d; |
233 | 0 | if (by_residual) { |
234 | 0 | quantizer->compute_residual(xi, residual.data(), list_no); |
235 | 0 | xi = residual.data(); |
236 | 0 | } |
237 | |
|
238 | 0 | memset(one_code.data(), 0, code_size); |
239 | 0 | squant->encode_vector(xi, one_code.data()); |
240 | |
|
241 | 0 | size_t ofs = invlists->add_entry( |
242 | 0 | list_no, id, one_code.data(), inverted_list_context); |
243 | |
|
244 | 0 | dm_add.add(i, list_no, ofs); |
245 | |
|
246 | 0 | } else if (rank == 0 && list_no == -1) { |
247 | 0 | dm_add.add(i, -1, 0); |
248 | 0 | } |
249 | 0 | } |
250 | 0 | } |
251 | |
|
252 | 0 | ntotal += n; |
253 | 0 | } |
254 | | |
// Build a scanner for searching the inverted lists; simply delegates to the
// scalar quantizer, forwarding the metric, the coarse quantizer (needed for
// residual decoding) and the by_residual flag.  The IVFSearchParameters
// argument is intentionally ignored here.
InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner(
        bool store_pairs,
        const IDSelector* sel,
        const IVFSearchParameters*) const {
    return sq.select_InvertedListScanner(
            metric_type, quantizer, store_pairs, sel, by_residual);
}
262 | | |
263 | | void IndexIVFScalarQuantizer::reconstruct_from_offset( |
264 | | int64_t list_no, |
265 | | int64_t offset, |
266 | 0 | float* recons) const { |
267 | 0 | const uint8_t* code = invlists->get_single_code(list_no, offset); |
268 | |
|
269 | 0 | if (by_residual) { |
270 | 0 | std::vector<float> centroid(d); |
271 | 0 | quantizer->reconstruct(list_no, centroid.data()); |
272 | |
|
273 | 0 | sq.decode(code, recons, 1); |
274 | 0 | for (int i = 0; i < d; ++i) { |
275 | 0 | recons[i] += centroid[i]; |
276 | 0 | } |
277 | 0 | } else { |
278 | 0 | sq.decode(code, recons, 1); |
279 | 0 | } |
280 | 0 | } |
281 | | |
282 | | } // namespace faiss |