/root/doris/contrib/faiss/faiss/IndexRefine.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexRefine.h> |
9 | | |
10 | | #include <faiss/IndexFlat.h> |
11 | | #include <faiss/impl/AuxIndexStructures.h> |
12 | | #include <faiss/impl/FaissAssert.h> |
13 | | #include <faiss/utils/Heap.h> |
14 | | |
15 | | namespace faiss { |
16 | | |
17 | | /*************************************************** |
18 | | * IndexRefine |
19 | | ***************************************************/ |
20 | | |
21 | | IndexRefine::IndexRefine(Index* base_index, Index* refine_index) |
22 | 0 | : Index(base_index->d, base_index->metric_type), |
23 | 0 | base_index(base_index), |
24 | 0 | refine_index(refine_index) { |
25 | 0 | own_fields = own_refine_index = false; |
26 | 0 | if (refine_index != nullptr) { |
27 | 0 | FAISS_THROW_IF_NOT(base_index->d == refine_index->d); |
28 | 0 | FAISS_THROW_IF_NOT( |
29 | 0 | base_index->metric_type == refine_index->metric_type); |
30 | 0 | is_trained = base_index->is_trained && refine_index->is_trained; |
31 | 0 | FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal); |
32 | 0 | } // other case is useful only to construct an IndexRefineFlat |
33 | 0 | ntotal = base_index->ntotal; |
34 | 0 | } |
35 | | |
36 | | IndexRefine::IndexRefine() |
37 | 0 | : base_index(nullptr), |
38 | 0 | refine_index(nullptr), |
39 | 0 | own_fields(false), |
40 | 0 | own_refine_index(false) {} |
41 | | |
42 | 0 | void IndexRefine::train(idx_t n, const float* x) { |
43 | 0 | base_index->train(n, x); |
44 | 0 | refine_index->train(n, x); |
45 | 0 | is_trained = true; |
46 | 0 | } |
47 | | |
48 | 0 | void IndexRefine::add(idx_t n, const float* x) { |
49 | 0 | FAISS_THROW_IF_NOT(is_trained); |
50 | 0 | base_index->add(n, x); |
51 | 0 | refine_index->add(n, x); |
52 | 0 | ntotal = refine_index->ntotal; |
53 | 0 | } |
54 | | |
55 | 0 | void IndexRefine::reset() { |
56 | 0 | base_index->reset(); |
57 | 0 | refine_index->reset(); |
58 | 0 | ntotal = 0; |
59 | 0 | } |
60 | | |
61 | | namespace { |
62 | | |
63 | | using idx_t = faiss::idx_t; |
64 | | |
65 | | template <class C> |
66 | | static void reorder_2_heaps( |
67 | | idx_t n, |
68 | | idx_t k, |
69 | | idx_t* __restrict labels, |
70 | | float* __restrict distances, |
71 | | idx_t k_base, |
72 | | const idx_t* __restrict base_labels, |
73 | 0 | const float* __restrict base_distances) { |
74 | 0 | #pragma omp parallel for if (n > 1) |
75 | 0 | for (idx_t i = 0; i < n; i++) { |
76 | 0 | idx_t* idxo = labels + i * k; |
77 | 0 | float* diso = distances + i * k; |
78 | 0 | const idx_t* idxi = base_labels + i * k_base; |
79 | 0 | const float* disi = base_distances + i * k_base; |
80 | |
|
81 | 0 | heap_heapify<C>(k, diso, idxo, disi, idxi, k); |
82 | 0 | if (k_base != k) { // add remaining elements |
83 | 0 | heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k); |
84 | 0 | } |
85 | 0 | heap_reorder<C>(k, diso, idxo); |
86 | 0 | } Unexecuted instantiation: IndexRefine.cpp:_ZN5faiss12_GLOBAL__N_115reorder_2_heapsINS_4CMaxIflEEEEvllPlPflPKlPKf.omp_outlined_debug__ Unexecuted instantiation: IndexRefine.cpp:_ZN5faiss12_GLOBAL__N_115reorder_2_heapsINS_4CMinIflEEEEvllPlPflPKlPKf.omp_outlined_debug__ |
87 | 0 | } Unexecuted instantiation: IndexRefine.cpp:_ZN5faiss12_GLOBAL__N_115reorder_2_heapsINS_4CMaxIflEEEEvllPlPflPKlPKf Unexecuted instantiation: IndexRefine.cpp:_ZN5faiss12_GLOBAL__N_115reorder_2_heapsINS_4CMinIflEEEEvllPlPflPKlPKf |
88 | | |
89 | | } // anonymous namespace |
90 | | |
91 | | void IndexRefine::search( |
92 | | idx_t n, |
93 | | const float* x, |
94 | | idx_t k, |
95 | | float* distances, |
96 | | idx_t* labels, |
97 | 0 | const SearchParameters* params_in) const { |
98 | 0 | const IndexRefineSearchParameters* params = nullptr; |
99 | 0 | if (params_in) { |
100 | 0 | params = dynamic_cast<const IndexRefineSearchParameters*>(params_in); |
101 | 0 | FAISS_THROW_IF_NOT_MSG( |
102 | 0 | params, "IndexRefine params have incorrect type"); |
103 | 0 | } |
104 | | |
105 | 0 | idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor) |
106 | 0 | : idx_t(k * k_factor); |
107 | 0 | SearchParameters* base_index_params = |
108 | 0 | (params != nullptr) ? params->base_index_params : nullptr; |
109 | |
|
110 | 0 | FAISS_THROW_IF_NOT(k_base >= k); |
111 | | |
112 | 0 | FAISS_THROW_IF_NOT(base_index); |
113 | 0 | FAISS_THROW_IF_NOT(refine_index); |
114 | | |
115 | 0 | FAISS_THROW_IF_NOT(k > 0); |
116 | 0 | FAISS_THROW_IF_NOT(is_trained); |
117 | 0 | idx_t* base_labels = labels; |
118 | 0 | float* base_distances = distances; |
119 | 0 | std::unique_ptr<idx_t[]> del1; |
120 | 0 | std::unique_ptr<float[]> del2; |
121 | |
|
122 | 0 | if (k != k_base) { |
123 | 0 | base_labels = new idx_t[n * k_base]; |
124 | 0 | del1.reset(base_labels); |
125 | 0 | base_distances = new float[n * k_base]; |
126 | 0 | del2.reset(base_distances); |
127 | 0 | } |
128 | |
|
129 | 0 | base_index->search( |
130 | 0 | n, x, k_base, base_distances, base_labels, base_index_params); |
131 | |
|
132 | 0 | for (int i = 0; i < n * k_base; i++) |
133 | 0 | assert(base_labels[i] >= -1 && base_labels[i] < ntotal); |
134 | | |
135 | | // parallelize over queries |
136 | 0 | #pragma omp parallel if (n > 1) |
137 | 0 | { |
138 | 0 | std::unique_ptr<DistanceComputer> dc( |
139 | 0 | refine_index->get_distance_computer()); |
140 | 0 | #pragma omp for |
141 | 0 | for (idx_t i = 0; i < n; i++) { |
142 | 0 | dc->set_query(x + i * d); |
143 | 0 | idx_t ij = i * k_base; |
144 | 0 | for (idx_t j = 0; j < k_base; j++) { |
145 | 0 | idx_t idx = base_labels[ij]; |
146 | 0 | if (idx < 0) |
147 | 0 | break; |
148 | 0 | base_distances[ij] = (*dc)(idx); |
149 | 0 | ij++; |
150 | 0 | } |
151 | 0 | } |
152 | 0 | } |
153 | | |
154 | | // sort and store result |
155 | 0 | if (metric_type == METRIC_L2) { |
156 | 0 | typedef CMax<float, idx_t> C; |
157 | 0 | reorder_2_heaps<C>( |
158 | 0 | n, k, labels, distances, k_base, base_labels, base_distances); |
159 | |
|
160 | 0 | } else if (metric_type == METRIC_INNER_PRODUCT) { |
161 | 0 | typedef CMin<float, idx_t> C; |
162 | 0 | reorder_2_heaps<C>( |
163 | 0 | n, k, labels, distances, k_base, base_labels, base_distances); |
164 | 0 | } else { |
165 | 0 | FAISS_THROW_MSG("Metric type not supported"); |
166 | 0 | } |
167 | 0 | } |
168 | | |
169 | | void IndexRefine::range_search( |
170 | | idx_t n, |
171 | | const float* x, |
172 | | float radius, |
173 | | RangeSearchResult* result, |
174 | 0 | const SearchParameters* params_in) const { |
175 | 0 | const IndexRefineSearchParameters* params = nullptr; |
176 | 0 | if (params_in) { |
177 | 0 | params = dynamic_cast<const IndexRefineSearchParameters*>(params_in); |
178 | 0 | FAISS_THROW_IF_NOT_MSG( |
179 | 0 | params, "IndexRefine params have incorrect type"); |
180 | 0 | } |
181 | | |
182 | 0 | SearchParameters* base_index_params = |
183 | 0 | (params != nullptr) ? params->base_index_params : nullptr; |
184 | |
|
185 | 0 | base_index->range_search(n, x, radius, result, base_index_params); |
186 | |
|
187 | 0 | #pragma omp parallel if (n > 1) |
188 | 0 | { |
189 | 0 | std::unique_ptr<DistanceComputer> dc( |
190 | 0 | refine_index->get_distance_computer()); |
191 | |
|
192 | 0 | #pragma omp for |
193 | 0 | for (idx_t i = 0; i < n; i++) { |
194 | 0 | dc->set_query(x + i * d); |
195 | | |
196 | | // reevaluate distances |
197 | 0 | const size_t idx_start = result->lims[i]; |
198 | 0 | const size_t idx_end = result->lims[i + 1]; |
199 | |
|
200 | 0 | for (size_t j = idx_start; j < idx_end; j++) { |
201 | 0 | const auto label = result->labels[j]; |
202 | 0 | result->distances[j] = (*dc)(label); |
203 | 0 | } |
204 | 0 | } |
205 | 0 | } |
206 | 0 | } |
207 | | |
208 | 0 | void IndexRefine::reconstruct(idx_t key, float* recons) const { |
209 | 0 | refine_index->reconstruct(key, recons); |
210 | 0 | } |
211 | | |
212 | 0 | size_t IndexRefine::sa_code_size() const { |
213 | 0 | return base_index->sa_code_size() + refine_index->sa_code_size(); |
214 | 0 | } |
215 | | |
216 | 0 | void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { |
217 | 0 | size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size(); |
218 | 0 | std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]); |
219 | 0 | base_index->sa_encode(n, x, tmp1.get()); |
220 | 0 | std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]); |
221 | 0 | refine_index->sa_encode(n, x, tmp2.get()); |
222 | 0 | for (size_t i = 0; i < n; i++) { |
223 | 0 | uint8_t* b = bytes + i * (cs1 + cs2); |
224 | 0 | memcpy(b, tmp1.get() + cs1 * i, cs1); |
225 | 0 | memcpy(b + cs1, tmp2.get() + cs2 * i, cs2); |
226 | 0 | } |
227 | 0 | } |
228 | | |
229 | 0 | void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { |
230 | 0 | size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size(); |
231 | 0 | std::unique_ptr<uint8_t[]> tmp2( |
232 | 0 | new uint8_t[n * refine_index->sa_code_size()]); |
233 | 0 | for (size_t i = 0; i < n; i++) { |
234 | 0 | memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2); |
235 | 0 | } |
236 | |
|
237 | 0 | refine_index->sa_decode(n, tmp2.get(), x); |
238 | 0 | } |
239 | | |
240 | 0 | IndexRefine::~IndexRefine() { |
241 | 0 | if (own_fields) |
242 | 0 | delete base_index; |
243 | 0 | if (own_refine_index) |
244 | 0 | delete refine_index; |
245 | 0 | } |
246 | | |
247 | | /*************************************************** |
248 | | * IndexRefineFlat |
249 | | ***************************************************/ |
250 | | |
251 | | IndexRefineFlat::IndexRefineFlat(Index* base_index) |
252 | 0 | : IndexRefine( |
253 | 0 | base_index, |
254 | 0 | new IndexFlat(base_index->d, base_index->metric_type)) { |
255 | 0 | is_trained = base_index->is_trained; |
256 | 0 | own_refine_index = true; |
257 | 0 | FAISS_THROW_IF_NOT_MSG( |
258 | 0 | base_index->ntotal == 0, |
259 | 0 | "base_index should be empty in the beginning"); |
260 | 0 | } |
261 | | |
262 | | IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb) |
263 | 0 | : IndexRefine(base_index, nullptr) { |
264 | 0 | is_trained = base_index->is_trained; |
265 | 0 | refine_index = new IndexFlat(base_index->d, base_index->metric_type); |
266 | 0 | own_refine_index = true; |
267 | 0 | refine_index->add(base_index->ntotal, xb); |
268 | 0 | } |
269 | | |
270 | 0 | IndexRefineFlat::IndexRefineFlat() : IndexRefine() { |
271 | 0 | own_refine_index = true; |
272 | 0 | } |
273 | | |
274 | | void IndexRefineFlat::search( |
275 | | idx_t n, |
276 | | const float* x, |
277 | | idx_t k, |
278 | | float* distances, |
279 | | idx_t* labels, |
280 | 0 | const SearchParameters* params_in) const { |
281 | 0 | const IndexRefineSearchParameters* params = nullptr; |
282 | 0 | if (params_in) { |
283 | 0 | params = dynamic_cast<const IndexRefineSearchParameters*>(params_in); |
284 | 0 | FAISS_THROW_IF_NOT_MSG( |
285 | 0 | params, "IndexRefineFlat params have incorrect type"); |
286 | 0 | } |
287 | | |
288 | 0 | idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor) |
289 | 0 | : idx_t(k * k_factor); |
290 | 0 | SearchParameters* base_index_params = |
291 | 0 | (params != nullptr) ? params->base_index_params : nullptr; |
292 | |
|
293 | 0 | FAISS_THROW_IF_NOT(k_base >= k); |
294 | | |
295 | 0 | FAISS_THROW_IF_NOT(base_index); |
296 | 0 | FAISS_THROW_IF_NOT(refine_index); |
297 | | |
298 | 0 | FAISS_THROW_IF_NOT(k > 0); |
299 | 0 | FAISS_THROW_IF_NOT(is_trained); |
300 | 0 | idx_t* base_labels = labels; |
301 | 0 | float* base_distances = distances; |
302 | 0 | std::unique_ptr<idx_t[]> del1; |
303 | 0 | std::unique_ptr<float[]> del2; |
304 | |
|
305 | 0 | if (k != k_base) { |
306 | 0 | base_labels = new idx_t[n * k_base]; |
307 | 0 | del1.reset(base_labels); |
308 | 0 | base_distances = new float[n * k_base]; |
309 | 0 | del2.reset(base_distances); |
310 | 0 | } |
311 | |
|
312 | 0 | base_index->search( |
313 | 0 | n, x, k_base, base_distances, base_labels, base_index_params); |
314 | |
|
315 | 0 | for (int i = 0; i < n * k_base; i++) |
316 | 0 | assert(base_labels[i] >= -1 && base_labels[i] < ntotal); |
317 | | |
318 | | // compute refined distances |
319 | 0 | auto rf = dynamic_cast<const IndexFlat*>(refine_index); |
320 | 0 | FAISS_THROW_IF_NOT(rf); |
321 | | |
322 | 0 | rf->compute_distance_subset(n, x, k_base, base_distances, base_labels); |
323 | | |
324 | | // sort and store result |
325 | 0 | if (metric_type == METRIC_L2) { |
326 | 0 | typedef CMax<float, idx_t> C; |
327 | 0 | reorder_2_heaps<C>( |
328 | 0 | n, k, labels, distances, k_base, base_labels, base_distances); |
329 | |
|
330 | 0 | } else if (metric_type == METRIC_INNER_PRODUCT) { |
331 | 0 | typedef CMin<float, idx_t> C; |
332 | 0 | reorder_2_heaps<C>( |
333 | 0 | n, k, labels, distances, k_base, base_labels, base_distances); |
334 | 0 | } else { |
335 | 0 | FAISS_THROW_MSG("Metric type not supported"); |
336 | 0 | } |
337 | 0 | } |
338 | | |
339 | | } // namespace faiss |