/root/doris/contrib/faiss/faiss/IndexIVFFlat.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/IndexIVFFlat.h> |
11 | | |
12 | | #include <omp.h> |
13 | | |
14 | | #include <cinttypes> |
15 | | #include <cstdio> |
16 | | |
17 | | #include <faiss/IndexFlat.h> |
18 | | |
19 | | #include <faiss/impl/AuxIndexStructures.h> |
20 | | #include <faiss/impl/IDSelector.h> |
21 | | |
22 | | #include <faiss/impl/FaissAssert.h> |
23 | | #include <faiss/utils/distances.h> |
24 | | #include <faiss/utils/utils.h> |
25 | | |
26 | | namespace faiss { |
27 | | |
28 | | /***************************************** |
29 | | * IndexIVFFlat implementation |
30 | | ******************************************/ |
31 | | |
32 | | IndexIVFFlat::IndexIVFFlat( |
33 | | Index* quantizer, |
34 | | size_t d, |
35 | | size_t nlist, |
36 | | MetricType metric) |
37 | 0 | : IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) { |
38 | 0 | code_size = sizeof(float) * d; |
39 | 0 | by_residual = false; |
40 | 0 | } |
41 | | |
42 | 0 | IndexIVFFlat::IndexIVFFlat() { |
43 | 0 | by_residual = false; |
44 | 0 | } |
45 | | |
46 | | void IndexIVFFlat::add_core( |
47 | | idx_t n, |
48 | | const float* x, |
49 | | const idx_t* xids, |
50 | | const idx_t* coarse_idx, |
51 | 0 | void* inverted_list_context) { |
52 | 0 | FAISS_THROW_IF_NOT(is_trained); |
53 | 0 | FAISS_THROW_IF_NOT(coarse_idx); |
54 | 0 | FAISS_THROW_IF_NOT(!by_residual); |
55 | 0 | assert(invlists); |
56 | 0 | direct_map.check_can_add(xids); |
57 | |
|
58 | 0 | int64_t n_add = 0; |
59 | |
|
60 | 0 | DirectMapAdd dm_adder(direct_map, n, xids); |
61 | |
|
62 | 0 | #pragma omp parallel reduction(+ : n_add) |
63 | 0 | { |
64 | 0 | int nt = omp_get_num_threads(); |
65 | 0 | int rank = omp_get_thread_num(); |
66 | | |
67 | | // each thread takes care of a subset of lists |
68 | 0 | for (size_t i = 0; i < n; i++) { |
69 | 0 | idx_t list_no = coarse_idx[i]; |
70 | |
|
71 | 0 | if (list_no >= 0 && list_no % nt == rank) { |
72 | 0 | idx_t id = xids ? xids[i] : ntotal + i; |
73 | 0 | const float* xi = x + i * d; |
74 | 0 | size_t offset = invlists->add_entry( |
75 | 0 | list_no, id, (const uint8_t*)xi, inverted_list_context); |
76 | 0 | dm_adder.add(i, list_no, offset); |
77 | 0 | n_add++; |
78 | 0 | } else if (rank == 0 && list_no == -1) { |
79 | 0 | dm_adder.add(i, -1, 0); |
80 | 0 | } |
81 | 0 | } |
82 | 0 | } |
83 | |
|
84 | 0 | if (verbose) { |
85 | 0 | printf("IndexIVFFlat::add_core: added %" PRId64 " / %" PRId64 |
86 | 0 | " vectors\n", |
87 | 0 | n_add, |
88 | 0 | n); |
89 | 0 | } |
90 | 0 | ntotal += n; |
91 | 0 | } |
92 | | |
93 | | void IndexIVFFlat::encode_vectors( |
94 | | idx_t n, |
95 | | const float* x, |
96 | | const idx_t* list_nos, |
97 | | uint8_t* codes, |
98 | 0 | bool include_listnos) const { |
99 | 0 | FAISS_THROW_IF_NOT(!by_residual); |
100 | 0 | if (!include_listnos) { |
101 | 0 | memcpy(codes, x, code_size * n); |
102 | 0 | } else { |
103 | 0 | size_t coarse_size = coarse_code_size(); |
104 | 0 | for (size_t i = 0; i < n; i++) { |
105 | 0 | int64_t list_no = list_nos[i]; |
106 | 0 | uint8_t* code = codes + i * (code_size + coarse_size); |
107 | 0 | const float* xi = x + i * d; |
108 | 0 | if (list_no >= 0) { |
109 | 0 | encode_listno(list_no, code); |
110 | 0 | memcpy(code + coarse_size, xi, code_size); |
111 | 0 | } else { |
112 | 0 | memset(code, 0, code_size + coarse_size); |
113 | 0 | } |
114 | 0 | } |
115 | 0 | } |
116 | 0 | } |
117 | | |
118 | 0 | void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { |
119 | 0 | size_t coarse_size = coarse_code_size(); |
120 | 0 | for (size_t i = 0; i < n; i++) { |
121 | 0 | const uint8_t* code = bytes + i * (code_size + coarse_size); |
122 | 0 | float* xi = x + i * d; |
123 | 0 | memcpy(xi, code + coarse_size, code_size); |
124 | 0 | } |
125 | 0 | } |
126 | | |
127 | | namespace { |
128 | | |
129 | | template <MetricType metric, class C, bool use_sel> |
130 | | struct IVFFlatScanner : InvertedListScanner { |
131 | | size_t d; |
132 | | |
133 | | IVFFlatScanner(size_t d, bool store_pairs, const IDSelector* sel) |
134 | 0 | : InvertedListScanner(store_pairs, sel), d(d) { |
135 | 0 | keep_max = is_similarity_metric(metric); |
136 | 0 | } Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EEC2EmbPKNS_10IDSelectorE Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EEC2EmbPKNS_10IDSelectorE Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EEC2EmbPKNS_10IDSelectorE Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EEC2EmbPKNS_10IDSelectorE |
137 | | |
138 | | const float* xi; |
139 | 0 | void set_query(const float* query) override { |
140 | 0 | this->xi = query; |
141 | 0 | } Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE9set_queryEPKf Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE9set_queryEPKf Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE9set_queryEPKf Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE9set_queryEPKf |
142 | | |
143 | 0 | void set_list(idx_t list_no, float /* coarse_dis */) override { |
144 | 0 | this->list_no = list_no; |
145 | 0 | } Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE8set_listElf Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE8set_listElf Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE8set_listElf Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE8set_listElf |
146 | | |
147 | 0 | float distance_to_code(const uint8_t* code) const override { |
148 | 0 | const float* yj = (float*)code; |
149 | 0 | float dis = metric == METRIC_INNER_PRODUCT |
150 | 0 | ? fvec_inner_product(xi, yj, d) |
151 | 0 | : fvec_L2sqr(xi, yj, d); |
152 | 0 | return dis; |
153 | 0 | } Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE16distance_to_codeEPKh Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE16distance_to_codeEPKh Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE16distance_to_codeEPKh Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE16distance_to_codeEPKh |
154 | | |
155 | | size_t scan_codes( |
156 | | size_t list_size, |
157 | | const uint8_t* codes, |
158 | | const idx_t* ids, |
159 | | float* simi, |
160 | | idx_t* idxi, |
161 | 0 | size_t k) const override { |
162 | 0 | const float* list_vecs = (const float*)codes; |
163 | 0 | size_t nup = 0; |
164 | 0 | for (size_t j = 0; j < list_size; j++) { |
165 | 0 | const float* yj = list_vecs + d * j; |
166 | 0 | if (use_sel && !sel->is_member(ids[j])) { |
167 | 0 | continue; |
168 | 0 | } |
169 | 0 | float dis = metric == METRIC_INNER_PRODUCT |
170 | 0 | ? fvec_inner_product(xi, yj, d) |
171 | 0 | : fvec_L2sqr(xi, yj, d); |
172 | 0 | if (C::cmp(simi[0], dis)) { |
173 | 0 | int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; |
174 | 0 | heap_replace_top<C>(k, simi, idxi, dis, id); |
175 | 0 | nup++; |
176 | 0 | } |
177 | 0 | } |
178 | 0 | return nup; |
179 | 0 | } Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE10scan_codesEmPKhPKlPfPlm Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE10scan_codesEmPKhPKlPfPlm Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE10scan_codesEmPKhPKlPfPlm Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE10scan_codesEmPKhPKlPfPlm |
180 | | |
181 | | void scan_codes_range( |
182 | | size_t list_size, |
183 | | const uint8_t* codes, |
184 | | const idx_t* ids, |
185 | | float radius, |
186 | 0 | RangeQueryResult& res) const override { |
187 | 0 | const float* list_vecs = (const float*)codes; |
188 | 0 | for (size_t j = 0; j < list_size; j++) { |
189 | 0 | const float* yj = list_vecs + d * j; |
190 | 0 | if (use_sel && !sel->is_member(ids[j])) { |
191 | 0 | continue; |
192 | 0 | } |
193 | 0 | float dis = metric == METRIC_INNER_PRODUCT |
194 | 0 | ? fvec_inner_product(xi, yj, d) |
195 | 0 | : fvec_L2sqr(xi, yj, d); |
196 | 0 | if (C::cmp(radius, dis)) { |
197 | 0 | int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; |
198 | 0 | res.add(dis, id); |
199 | 0 | } |
200 | 0 | } |
201 | 0 | } Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE16scan_codes_rangeEmPKhPKlfRNS_16RangeQueryResultE Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE16scan_codes_rangeEmPKhPKlfRNS_16RangeQueryResultE Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE16scan_codes_rangeEmPKhPKlfRNS_16RangeQueryResultE Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE16scan_codes_rangeEmPKhPKlfRNS_16RangeQueryResultE |
202 | | }; |
203 | | |
204 | | template <bool use_sel> |
205 | | InvertedListScanner* get_InvertedListScanner1( |
206 | | const IndexIVFFlat* ivf, |
207 | | bool store_pairs, |
208 | 0 | const IDSelector* sel) { |
209 | 0 | if (ivf->metric_type == METRIC_INNER_PRODUCT) { |
210 | 0 | return new IVFFlatScanner< |
211 | 0 | METRIC_INNER_PRODUCT, |
212 | 0 | CMin<float, int64_t>, |
213 | 0 | use_sel>(ivf->d, store_pairs, sel); |
214 | 0 | } else if (ivf->metric_type == METRIC_L2) { |
215 | 0 | return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>, use_sel>( |
216 | 0 | ivf->d, store_pairs, sel); |
217 | 0 | } else { |
218 | 0 | FAISS_THROW_MSG("metric type not supported"); |
219 | 0 | } |
220 | 0 | } Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_124get_InvertedListScanner1ILb1EEEPNS_19InvertedListScannerEPKNS_12IndexIVFFlatEbPKNS_10IDSelectorE Unexecuted instantiation: IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_124get_InvertedListScanner1ILb0EEEPNS_19InvertedListScannerEPKNS_12IndexIVFFlatEbPKNS_10IDSelectorE |
221 | | |
222 | | } // anonymous namespace |
223 | | |
224 | | InvertedListScanner* IndexIVFFlat::get_InvertedListScanner( |
225 | | bool store_pairs, |
226 | | const IDSelector* sel, |
227 | 0 | const IVFSearchParameters*) const { |
228 | 0 | if (sel) { |
229 | 0 | return get_InvertedListScanner1<true>(this, store_pairs, sel); |
230 | 0 | } else { |
231 | 0 | return get_InvertedListScanner1<false>(this, store_pairs, sel); |
232 | 0 | } |
233 | 0 | } |
234 | | |
235 | | void IndexIVFFlat::reconstruct_from_offset( |
236 | | int64_t list_no, |
237 | | int64_t offset, |
238 | 0 | float* recons) const { |
239 | 0 | memcpy(recons, invlists->get_single_code(list_no, offset), code_size); |
240 | 0 | } |
241 | | |
242 | | /***************************************** |
243 | | * IndexIVFFlatDedup implementation |
244 | | ******************************************/ |
245 | | |
246 | | IndexIVFFlatDedup::IndexIVFFlatDedup( |
247 | | Index* quantizer, |
248 | | size_t d, |
249 | | size_t nlist_, |
250 | | MetricType metric_type) |
251 | 0 | : IndexIVFFlat(quantizer, d, nlist_, metric_type) {} |
252 | | |
253 | 0 | void IndexIVFFlatDedup::train(idx_t n, const float* x) { |
254 | 0 | std::unordered_map<uint64_t, idx_t> map; |
255 | 0 | std::unique_ptr<float[]> x2(new float[n * d]); |
256 | |
|
257 | 0 | int64_t n2 = 0; |
258 | 0 | for (int64_t i = 0; i < n; i++) { |
259 | 0 | uint64_t hash = hash_bytes((uint8_t*)(x + i * d), code_size); |
260 | 0 | if (map.count(hash) && |
261 | 0 | !memcmp(x2.get() + map[hash] * d, x + i * d, code_size)) { |
262 | | // is duplicate, skip |
263 | 0 | } else { |
264 | 0 | map[hash] = n2; |
265 | 0 | memcpy(x2.get() + n2 * d, x + i * d, code_size); |
266 | 0 | n2++; |
267 | 0 | } |
268 | 0 | } |
269 | 0 | if (verbose) { |
270 | 0 | printf("IndexIVFFlatDedup::train: train on %" PRId64 |
271 | 0 | " points after dedup " |
272 | 0 | "(was %" PRId64 " points)\n", |
273 | 0 | n2, |
274 | 0 | n); |
275 | 0 | } |
276 | 0 | IndexIVFFlat::train(n2, x2.get()); |
277 | 0 | } |
278 | | |
279 | | void IndexIVFFlatDedup::add_with_ids( |
280 | | idx_t na, |
281 | | const float* x, |
282 | 0 | const idx_t* xids) { |
283 | 0 | FAISS_THROW_IF_NOT(is_trained); |
284 | 0 | assert(invlists); |
285 | 0 | FAISS_THROW_IF_NOT_MSG( |
286 | 0 | direct_map.no(), "IVFFlatDedup not implemented with direct_map"); |
287 | 0 | std::unique_ptr<int64_t[]> idx(new int64_t[na]); |
288 | 0 | quantizer->assign(na, x, idx.get()); |
289 | |
|
290 | 0 | int64_t n_add = 0, n_dup = 0; |
291 | |
|
292 | 0 | #pragma omp parallel reduction(+ : n_add, n_dup) |
293 | 0 | { |
294 | 0 | int nt = omp_get_num_threads(); |
295 | 0 | int rank = omp_get_thread_num(); |
296 | | |
297 | | // each thread takes care of a subset of lists |
298 | 0 | for (size_t i = 0; i < na; i++) { |
299 | 0 | int64_t list_no = idx[i]; |
300 | |
|
301 | 0 | if (list_no < 0 || list_no % nt != rank) { |
302 | 0 | continue; |
303 | 0 | } |
304 | | |
305 | 0 | idx_t id = xids ? xids[i] : ntotal + i; |
306 | 0 | const float* xi = x + i * d; |
307 | | |
308 | | // search if there is already an entry with that id |
309 | 0 | InvertedLists::ScopedCodes codes(invlists, list_no); |
310 | |
|
311 | 0 | int64_t n = invlists->list_size(list_no); |
312 | 0 | int64_t offset = -1; |
313 | 0 | for (int64_t o = 0; o < n; o++) { |
314 | 0 | if (!memcmp(codes.get() + o * code_size, xi, code_size)) { |
315 | 0 | offset = o; |
316 | 0 | break; |
317 | 0 | } |
318 | 0 | } |
319 | |
|
320 | 0 | if (offset == -1) { // not found |
321 | 0 | invlists->add_entry(list_no, id, (const uint8_t*)xi); |
322 | 0 | } else { |
323 | | // mark equivalence |
324 | 0 | idx_t id2 = invlists->get_single_id(list_no, offset); |
325 | 0 | std::pair<idx_t, idx_t> pair(id2, id); |
326 | |
|
327 | 0 | #pragma omp critical |
328 | | // executed by one thread at a time |
329 | 0 | instances.insert(pair); |
330 | |
|
331 | 0 | n_dup++; |
332 | 0 | } |
333 | 0 | n_add++; |
334 | 0 | } |
335 | 0 | } |
336 | 0 | if (verbose) { |
337 | 0 | printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64 |
338 | 0 | " vectors" |
339 | 0 | " (out of which %" PRId64 " are duplicates)\n", |
340 | 0 | n_add, |
341 | 0 | na, |
342 | 0 | n_dup); |
343 | 0 | } |
344 | 0 | ntotal += n_add; |
345 | 0 | } |
346 | | |
347 | | void IndexIVFFlatDedup::search_preassigned( |
348 | | idx_t n, |
349 | | const float* x, |
350 | | idx_t k, |
351 | | const idx_t* assign, |
352 | | const float* centroid_dis, |
353 | | float* distances, |
354 | | idx_t* labels, |
355 | | bool store_pairs, |
356 | | const IVFSearchParameters* params, |
357 | 0 | IndexIVFStats* stats) const { |
358 | 0 | FAISS_THROW_IF_NOT_MSG( |
359 | 0 | !store_pairs, "store_pairs not supported in IVFDedup"); |
360 | | |
361 | 0 | IndexIVFFlat::search_preassigned( |
362 | 0 | n, x, k, assign, centroid_dis, distances, labels, false, params); |
363 | |
|
364 | 0 | std::vector<idx_t> labels2(k); |
365 | 0 | std::vector<float> dis2(k); |
366 | |
|
367 | 0 | for (int64_t i = 0; i < n; i++) { |
368 | 0 | idx_t* labels1 = labels + i * k; |
369 | 0 | float* dis1 = distances + i * k; |
370 | 0 | int64_t j = 0; |
371 | 0 | for (; j < k; j++) { |
372 | 0 | if (instances.find(labels1[j]) != instances.end()) { |
373 | | // a duplicate: special handling |
374 | 0 | break; |
375 | 0 | } |
376 | 0 | } |
377 | 0 | if (j < k) { |
378 | | // there are duplicates, special handling |
379 | 0 | int64_t j0 = j; |
380 | 0 | int64_t rp = j; |
381 | 0 | while (j < k) { |
382 | 0 | auto range = instances.equal_range(labels1[rp]); |
383 | 0 | float dis = dis1[rp]; |
384 | 0 | labels2[j] = labels1[rp]; |
385 | 0 | dis2[j] = dis; |
386 | 0 | j++; |
387 | 0 | for (auto it = range.first; j < k && it != range.second; ++it) { |
388 | 0 | labels2[j] = it->second; |
389 | 0 | dis2[j] = dis; |
390 | 0 | j++; |
391 | 0 | } |
392 | 0 | rp++; |
393 | 0 | } |
394 | 0 | memcpy(labels1 + j0, |
395 | 0 | labels2.data() + j0, |
396 | 0 | sizeof(labels1[0]) * (k - j0)); |
397 | 0 | memcpy(dis1 + j0, dis2.data() + j0, sizeof(dis2[0]) * (k - j0)); |
398 | 0 | } |
399 | 0 | } |
400 | 0 | } |
401 | | |
402 | 0 | size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel) { |
403 | 0 | std::unordered_map<idx_t, idx_t> replace; |
404 | 0 | std::vector<std::pair<idx_t, idx_t>> toadd; |
405 | 0 | for (auto it = instances.begin(); it != instances.end();) { |
406 | 0 | if (sel.is_member(it->first)) { |
407 | | // then we erase this entry |
408 | 0 | if (!sel.is_member(it->second)) { |
409 | | // if the second is not erased |
410 | 0 | if (replace.count(it->first) == 0) { |
411 | 0 | replace[it->first] = it->second; |
412 | 0 | } else { // remember we should add an element |
413 | 0 | std::pair<idx_t, idx_t> new_entry( |
414 | 0 | replace[it->first], it->second); |
415 | 0 | toadd.push_back(new_entry); |
416 | 0 | } |
417 | 0 | } |
418 | 0 | it = instances.erase(it); |
419 | 0 | } else { |
420 | 0 | if (sel.is_member(it->second)) { |
421 | 0 | it = instances.erase(it); |
422 | 0 | } else { |
423 | 0 | ++it; |
424 | 0 | } |
425 | 0 | } |
426 | 0 | } |
427 | |
|
428 | 0 | instances.insert(toadd.begin(), toadd.end()); |
429 | | |
430 | | // mostly copied from IndexIVF.cpp |
431 | |
|
432 | 0 | FAISS_THROW_IF_NOT_MSG( |
433 | 0 | direct_map.no(), "direct map remove not implemented"); |
434 | | |
435 | 0 | std::vector<int64_t> toremove(nlist); |
436 | |
|
437 | 0 | #pragma omp parallel for |
438 | 0 | for (int64_t i = 0; i < nlist; i++) { |
439 | 0 | int64_t l0 = invlists->list_size(i), l = l0, j = 0; |
440 | 0 | InvertedLists::ScopedIds idsi(invlists, i); |
441 | 0 | while (j < l) { |
442 | 0 | if (sel.is_member(idsi[j])) { |
443 | 0 | if (replace.count(idsi[j]) == 0) { |
444 | 0 | l--; |
445 | 0 | invlists->update_entry( |
446 | 0 | i, |
447 | 0 | j, |
448 | 0 | invlists->get_single_id(i, l), |
449 | 0 | InvertedLists::ScopedCodes(invlists, i, l).get()); |
450 | 0 | } else { |
451 | 0 | invlists->update_entry( |
452 | 0 | i, |
453 | 0 | j, |
454 | 0 | replace[idsi[j]], |
455 | 0 | InvertedLists::ScopedCodes(invlists, i, j).get()); |
456 | 0 | j++; |
457 | 0 | } |
458 | 0 | } else { |
459 | 0 | j++; |
460 | 0 | } |
461 | 0 | } |
462 | 0 | toremove[i] = l0 - l; |
463 | 0 | } |
464 | | // this will not run well in parallel on ondisk because of possible shrinks |
465 | 0 | int64_t nremove = 0; |
466 | 0 | for (int64_t i = 0; i < nlist; i++) { |
467 | 0 | if (toremove[i] > 0) { |
468 | 0 | nremove += toremove[i]; |
469 | 0 | invlists->resize(i, invlists->list_size(i) - toremove[i]); |
470 | 0 | } |
471 | 0 | } |
472 | 0 | ntotal -= nremove; |
473 | 0 | return nremove; |
474 | 0 | } |
475 | | |
476 | | void IndexIVFFlatDedup::range_search( |
477 | | idx_t, |
478 | | const float*, |
479 | | float, |
480 | | RangeSearchResult*, |
481 | 0 | const SearchParameters*) const { |
482 | 0 | FAISS_THROW_MSG("not implemented"); |
483 | 0 | } |
484 | | |
485 | 0 | void IndexIVFFlatDedup::update_vectors(int, const idx_t*, const float*) { |
486 | 0 | FAISS_THROW_MSG("not implemented"); |
487 | 0 | } |
488 | | |
489 | | void IndexIVFFlatDedup::reconstruct_from_offset(int64_t, int64_t, float*) |
490 | 0 | const { |
491 | 0 | FAISS_THROW_MSG("not implemented"); |
492 | 0 | } |
493 | | |
494 | | } // namespace faiss |