/root/doris/contrib/faiss/faiss/IndexPreTransform.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/IndexPreTransform.h> |
11 | | |
12 | | #include <cmath> |
13 | | #include <cstdio> |
14 | | #include <cstring> |
15 | | #include <memory> |
16 | | |
17 | | #include <faiss/impl/AuxIndexStructures.h> |
18 | | #include <faiss/impl/DistanceComputer.h> |
19 | | #include <faiss/impl/FaissAssert.h> |
20 | | |
21 | | namespace faiss { |
22 | | |
23 | | /********************************************* |
24 | | * IndexPreTransform |
25 | | *********************************************/ |
26 | | |
27 | 0 | IndexPreTransform::IndexPreTransform() : index(nullptr), own_fields(false) {} |
28 | | |
29 | | IndexPreTransform::IndexPreTransform(Index* index) |
30 | 0 | : Index(index->d, index->metric_type), index(index), own_fields(false) { |
31 | 0 | is_trained = index->is_trained; |
32 | 0 | ntotal = index->ntotal; |
33 | 0 | } |
34 | | |
35 | | IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index) |
36 | 0 | : Index(index->d, index->metric_type), index(index), own_fields(false) { |
37 | 0 | is_trained = index->is_trained; |
38 | 0 | ntotal = index->ntotal; |
39 | 0 | prepend_transform(ltrans); |
40 | 0 | } |
41 | | |
42 | 0 | void IndexPreTransform::prepend_transform(VectorTransform* ltrans) { |
43 | 0 | FAISS_THROW_IF_NOT(ltrans->d_out == d); |
44 | 0 | is_trained = is_trained && ltrans->is_trained; |
45 | 0 | chain.insert(chain.begin(), ltrans); |
46 | 0 | d = ltrans->d_in; |
47 | 0 | } |
48 | | |
49 | 0 | IndexPreTransform::~IndexPreTransform() { |
50 | 0 | if (own_fields) { |
51 | 0 | for (int i = 0; i < chain.size(); i++) |
52 | 0 | delete chain[i]; |
53 | 0 | delete index; |
54 | 0 | } |
55 | 0 | } |
56 | | |
57 | 0 | void IndexPreTransform::train(idx_t n, const float* x) { |
58 | 0 | int last_untrained = 0; |
59 | 0 | if (!index->is_trained) { |
60 | 0 | last_untrained = chain.size(); |
61 | 0 | } else { |
62 | 0 | for (int i = chain.size() - 1; i >= 0; i--) { |
63 | 0 | if (!chain[i]->is_trained) { |
64 | 0 | last_untrained = i; |
65 | 0 | break; |
66 | 0 | } |
67 | 0 | } |
68 | 0 | } |
69 | 0 | const float* prev_x = x; |
70 | 0 | std::unique_ptr<const float[]> del; |
71 | |
|
72 | 0 | if (verbose) { |
73 | 0 | printf("IndexPreTransform::train: training chain 0 to %d\n", |
74 | 0 | last_untrained); |
75 | 0 | } |
76 | |
|
77 | 0 | for (int i = 0; i <= last_untrained; i++) { |
78 | 0 | if (i < chain.size()) { |
79 | 0 | VectorTransform* ltrans = chain[i]; |
80 | 0 | if (!ltrans->is_trained) { |
81 | 0 | if (verbose) { |
82 | 0 | printf(" Training chain component %d/%zd\n", |
83 | 0 | i, |
84 | 0 | chain.size()); |
85 | 0 | if (OPQMatrix* opqm = dynamic_cast<OPQMatrix*>(ltrans)) { |
86 | 0 | opqm->verbose = true; |
87 | 0 | } |
88 | 0 | } |
89 | 0 | ltrans->train(n, prev_x); |
90 | 0 | } |
91 | 0 | } else { |
92 | 0 | if (verbose) { |
93 | 0 | printf(" Training sub-index\n"); |
94 | 0 | } |
95 | 0 | index->train(n, prev_x); |
96 | 0 | } |
97 | 0 | if (i == last_untrained) |
98 | 0 | break; |
99 | 0 | if (verbose) { |
100 | 0 | printf(" Applying transform %d/%zd\n", i, chain.size()); |
101 | 0 | } |
102 | |
|
103 | 0 | float* xt = chain[i]->apply(n, prev_x); |
104 | |
|
105 | 0 | if (prev_x != x) { |
106 | 0 | del.reset(); |
107 | 0 | } |
108 | |
|
109 | 0 | prev_x = xt; |
110 | 0 | del.reset(xt); |
111 | 0 | } |
112 | |
|
113 | 0 | is_trained = true; |
114 | 0 | } |
115 | | |
116 | 0 | const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const { |
117 | 0 | const float* prev_x = x; |
118 | 0 | std::unique_ptr<const float[]> del; |
119 | |
|
120 | 0 | for (int i = 0; i < chain.size(); i++) { |
121 | 0 | float* xt = chain[i]->apply(n, prev_x); |
122 | 0 | std::unique_ptr<const float[]> del2(xt); |
123 | 0 | del2.swap(del); |
124 | 0 | prev_x = xt; |
125 | 0 | } |
126 | 0 | del.release(); |
127 | 0 | return prev_x; |
128 | 0 | } |
129 | | |
130 | | void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x) |
131 | 0 | const { |
132 | 0 | const float* next_x = xt; |
133 | 0 | std::unique_ptr<const float[]> del; |
134 | |
|
135 | 0 | for (int i = chain.size() - 1; i >= 0; i--) { |
136 | 0 | float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in]; |
137 | 0 | std::unique_ptr<const float[]> del2((prev_x == x) ? nullptr : prev_x); |
138 | 0 | chain[i]->reverse_transform(n, next_x, prev_x); |
139 | 0 | del2.swap(del); |
140 | 0 | next_x = prev_x; |
141 | 0 | } |
142 | 0 | } |
143 | | |
144 | 0 | void IndexPreTransform::add(idx_t n, const float* x) { |
145 | 0 | FAISS_THROW_IF_NOT(is_trained); |
146 | 0 | TransformedVectors tv(x, apply_chain(n, x)); |
147 | 0 | index->add(n, tv.x); |
148 | 0 | ntotal = index->ntotal; |
149 | 0 | } |
150 | | |
151 | | void IndexPreTransform::add_with_ids( |
152 | | idx_t n, |
153 | | const float* x, |
154 | 0 | const idx_t* xids) { |
155 | 0 | FAISS_THROW_IF_NOT(is_trained); |
156 | 0 | TransformedVectors tv(x, apply_chain(n, x)); |
157 | 0 | index->add_with_ids(n, tv.x, xids); |
158 | 0 | ntotal = index->ntotal; |
159 | 0 | } |
160 | | |
161 | | namespace { |
162 | | |
163 | | const SearchParameters* extract_index_search_params( |
164 | 0 | const SearchParameters* params_in) { |
165 | 0 | auto params = dynamic_cast<const SearchParametersPreTransform*>(params_in); |
166 | 0 | return params ? params->index_params : params_in; |
167 | 0 | } |
168 | | |
169 | | } // namespace |
170 | | |
171 | | void IndexPreTransform::search( |
172 | | idx_t n, |
173 | | const float* x, |
174 | | idx_t k, |
175 | | float* distances, |
176 | | idx_t* labels, |
177 | 0 | const SearchParameters* params) const { |
178 | 0 | FAISS_THROW_IF_NOT(k > 0); |
179 | 0 | FAISS_THROW_IF_NOT(is_trained); |
180 | 0 | const float* xt = apply_chain(n, x); |
181 | 0 | std::unique_ptr<const float[]> del(xt == x ? nullptr : xt); |
182 | 0 | index->search( |
183 | 0 | n, xt, k, distances, labels, extract_index_search_params(params)); |
184 | 0 | } |
185 | | |
186 | | void IndexPreTransform::range_search( |
187 | | idx_t n, |
188 | | const float* x, |
189 | | float radius, |
190 | | RangeSearchResult* result, |
191 | 0 | const SearchParameters* params) const { |
192 | 0 | FAISS_THROW_IF_NOT(is_trained); |
193 | 0 | TransformedVectors tv(x, apply_chain(n, x)); |
194 | 0 | index->range_search( |
195 | 0 | n, tv.x, radius, result, extract_index_search_params(params)); |
196 | 0 | } |
197 | | |
198 | 0 | void IndexPreTransform::reset() { |
199 | 0 | index->reset(); |
200 | 0 | ntotal = 0; |
201 | 0 | } |
202 | | |
203 | 0 | size_t IndexPreTransform::remove_ids(const IDSelector& sel) { |
204 | 0 | size_t nremove = index->remove_ids(sel); |
205 | 0 | ntotal = index->ntotal; |
206 | 0 | return nremove; |
207 | 0 | } |
208 | | |
209 | 0 | void IndexPreTransform::reconstruct(idx_t key, float* recons) const { |
210 | 0 | float* x = chain.empty() ? recons : new float[index->d]; |
211 | 0 | std::unique_ptr<float[]> del(recons == x ? nullptr : x); |
212 | | // Initial reconstruction |
213 | 0 | index->reconstruct(key, x); |
214 | | |
215 | | // Revert transformations from last to first |
216 | 0 | reverse_chain(1, x, recons); |
217 | 0 | } |
218 | | |
219 | 0 | void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const { |
220 | 0 | float* x = chain.empty() ? recons : new float[ni * index->d]; |
221 | 0 | std::unique_ptr<float[]> del(recons == x ? nullptr : x); |
222 | | // Initial reconstruction |
223 | 0 | index->reconstruct_n(i0, ni, x); |
224 | | |
225 | | // Revert transformations from last to first |
226 | 0 | reverse_chain(ni, x, recons); |
227 | 0 | } |
228 | | |
229 | | void IndexPreTransform::search_and_reconstruct( |
230 | | idx_t n, |
231 | | const float* x, |
232 | | idx_t k, |
233 | | float* distances, |
234 | | idx_t* labels, |
235 | | float* recons, |
236 | 0 | const SearchParameters* params) const { |
237 | 0 | FAISS_THROW_IF_NOT(k > 0); |
238 | 0 | FAISS_THROW_IF_NOT(is_trained); |
239 | | |
240 | 0 | TransformedVectors trans(x, apply_chain(n, x)); |
241 | |
|
242 | 0 | float* recons_temp = chain.empty() ? recons : new float[n * k * index->d]; |
243 | 0 | std::unique_ptr<float[]> del2( |
244 | 0 | (recons_temp == recons) ? nullptr : recons_temp); |
245 | 0 | index->search_and_reconstruct( |
246 | 0 | n, |
247 | 0 | trans.x, |
248 | 0 | k, |
249 | 0 | distances, |
250 | 0 | labels, |
251 | 0 | recons_temp, |
252 | 0 | extract_index_search_params(params)); |
253 | | |
254 | | // Revert transformations from last to first |
255 | 0 | reverse_chain(n * k, recons_temp, recons); |
256 | 0 | } |
257 | | |
258 | 0 | size_t IndexPreTransform::sa_code_size() const { |
259 | 0 | return index->sa_code_size(); |
260 | 0 | } |
261 | | |
262 | | void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes) |
263 | 0 | const { |
264 | 0 | TransformedVectors tv(x, apply_chain(n, x)); |
265 | 0 | index->sa_encode(n, tv.x, bytes); |
266 | 0 | } |
267 | | |
268 | | void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x) |
269 | 0 | const { |
270 | 0 | if (chain.empty()) { |
271 | 0 | index->sa_decode(n, bytes, x); |
272 | 0 | } else { |
273 | 0 | std::unique_ptr<float[]> x1(new float[index->d * n]); |
274 | 0 | index->sa_decode(n, bytes, x1.get()); |
275 | | // Revert transformations from last to first |
276 | 0 | reverse_chain(n, x1.get(), x); |
277 | 0 | } |
278 | 0 | } |
279 | | |
280 | 0 | void IndexPreTransform::merge_from(Index& otherIndex, idx_t add_id) { |
281 | 0 | check_compatible_for_merge(otherIndex); |
282 | 0 | auto other = static_cast<const IndexPreTransform*>(&otherIndex); |
283 | 0 | index->merge_from(*other->index, add_id); |
284 | 0 | ntotal = index->ntotal; |
285 | 0 | } |
286 | | |
287 | | void IndexPreTransform::check_compatible_for_merge( |
288 | 0 | const Index& otherIndex) const { |
289 | 0 | auto other = dynamic_cast<const IndexPreTransform*>(&otherIndex); |
290 | 0 | FAISS_THROW_IF_NOT(other); |
291 | 0 | FAISS_THROW_IF_NOT(chain.size() == other->chain.size()); |
292 | 0 | for (int i = 0; i < chain.size(); i++) { |
293 | 0 | chain[i]->check_identical(*other->chain[i]); |
294 | 0 | } |
295 | 0 | index->check_compatible_for_merge(*other->index); |
296 | 0 | } |
297 | | |
298 | | namespace { |
299 | | |
300 | | struct PreTransformDistanceComputer : DistanceComputer { |
301 | | const IndexPreTransform* index; |
302 | | std::unique_ptr<DistanceComputer> sub_dc; |
303 | | std::unique_ptr<const float[]> query; |
304 | | |
305 | | explicit PreTransformDistanceComputer(const IndexPreTransform* index) |
306 | 0 | : index(index), sub_dc(index->index->get_distance_computer()) {} |
307 | | |
308 | 0 | void set_query(const float* x) override { |
309 | 0 | const float* xt = index->apply_chain(1, x); |
310 | 0 | if (xt == x) { |
311 | 0 | sub_dc->set_query(x); |
312 | 0 | } else { |
313 | 0 | query.reset(xt); |
314 | 0 | sub_dc->set_query(xt); |
315 | 0 | } |
316 | 0 | } |
317 | | |
318 | 0 | float symmetric_dis(idx_t i, idx_t j) override { |
319 | 0 | return sub_dc->symmetric_dis(i, j); |
320 | 0 | } |
321 | | |
322 | 0 | float operator()(idx_t i) override { |
323 | 0 | return (*sub_dc)(i); |
324 | 0 | } |
325 | | }; |
326 | | |
327 | | } // anonymous namespace |
328 | | |
329 | 0 | DistanceComputer* IndexPreTransform::get_distance_computer() const { |
330 | 0 | if (chain.empty()) { |
331 | 0 | return index->get_distance_computer(); |
332 | 0 | } else { |
333 | 0 | return new PreTransformDistanceComputer(this); |
334 | 0 | } |
335 | 0 | } |
336 | | |
337 | | } // namespace faiss |