contrib/faiss/faiss/invlists/InvertedLists.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/invlists/InvertedLists.h> |
9 | | |
10 | | #include <cstdio> |
11 | | #include <memory> |
12 | | |
13 | | #include <faiss/impl/FaissAssert.h> |
14 | | #include <faiss/utils/utils.h> |
15 | | |
16 | | namespace faiss { |
17 | | |
18 | 0 | InvertedListsIterator::~InvertedListsIterator() {} |
19 | | |
20 | | /***************************************** |
21 | | * InvertedLists implementation |
22 | | ******************************************/ |
23 | | |
24 | | InvertedLists::InvertedLists(size_t nlist, size_t code_size) |
25 | 41 | : nlist(nlist), code_size(code_size) {} |
26 | | |
27 | 41 | InvertedLists::~InvertedLists() {} |
28 | | |
29 | 0 | idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const { |
30 | 0 | assert(offset < list_size(list_no)); |
31 | 0 | const idx_t* ids = get_ids(list_no); |
32 | 0 | idx_t id = ids[offset]; |
33 | 0 | release_ids(list_no, ids); |
34 | 0 | return id; |
35 | 0 | } |
36 | | |
37 | 122 | void InvertedLists::release_codes(size_t, const uint8_t*) const {} |
38 | | |
39 | 122 | void InvertedLists::release_ids(size_t, const idx_t*) const {} |
40 | | |
41 | 42 | void InvertedLists::prefetch_lists(const idx_t*, int) const {} |
42 | | |
43 | | const uint8_t* InvertedLists::get_single_code(size_t list_no, size_t offset) |
44 | 0 | const { |
45 | 0 | assert(offset < list_size(list_no)); |
46 | 0 | return get_codes(list_no) + offset * code_size; |
47 | 0 | } |
48 | | |
49 | | size_t InvertedLists::add_entry( |
50 | | size_t list_no, |
51 | | idx_t theid, |
52 | | const uint8_t* code, |
53 | 3.86k | void* /*inverted_list_context*/) { |
54 | 3.86k | return add_entries(list_no, 1, &theid, code); |
55 | 3.86k | } |
56 | | |
57 | | void InvertedLists::update_entry( |
58 | | size_t list_no, |
59 | | size_t offset, |
60 | | idx_t id, |
61 | 0 | const uint8_t* code) { |
62 | 0 | update_entries(list_no, offset, 1, &id, code); |
63 | 0 | } |
64 | | |
65 | 0 | void InvertedLists::reset() { |
66 | 0 | for (size_t i = 0; i < nlist; i++) { |
67 | 0 | resize(i, 0); |
68 | 0 | } |
69 | 0 | } |
70 | | |
71 | 0 | void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) { |
72 | 0 | #pragma omp parallel for |
73 | 0 | for (idx_t i = 0; i < nlist; i++) { |
74 | 0 | size_t list_size = oivf->list_size(i); |
75 | 0 | ScopedIds ids(oivf, i); |
76 | 0 | if (add_id == 0) { |
77 | 0 | add_entries(i, list_size, ids.get(), ScopedCodes(oivf, i).get()); |
78 | 0 | } else { |
79 | 0 | std::vector<idx_t> new_ids(list_size); |
80 | |
|
81 | 0 | for (size_t j = 0; j < list_size; j++) { |
82 | 0 | new_ids[j] = ids[j] + add_id; |
83 | 0 | } |
84 | 0 | add_entries( |
85 | 0 | i, list_size, new_ids.data(), ScopedCodes(oivf, i).get()); |
86 | 0 | } |
87 | 0 | oivf->resize(i, 0); |
88 | 0 | } |
89 | 0 | } |
90 | | |
91 | | size_t InvertedLists::copy_subset_to( |
92 | | InvertedLists& oivf, |
93 | | subset_type_t subset_type, |
94 | | idx_t a1, |
95 | 0 | idx_t a2) const { |
96 | 0 | FAISS_THROW_IF_NOT(nlist == oivf.nlist); |
97 | 0 | FAISS_THROW_IF_NOT(code_size == oivf.code_size); |
98 | 0 | FAISS_THROW_IF_NOT_FMT( |
99 | 0 | subset_type >= 0 && subset_type <= 4, |
100 | 0 | "subset type %d not implemented", |
101 | 0 | subset_type); |
102 | 0 | size_t accu_n = 0; |
103 | 0 | size_t accu_a1 = 0; |
104 | 0 | size_t accu_a2 = 0; |
105 | 0 | size_t n_added = 0; |
106 | |
|
107 | 0 | size_t ntotal = 0; |
108 | 0 | if (subset_type == 2) { |
109 | 0 | ntotal = compute_ntotal(); |
110 | 0 | } |
111 | |
|
112 | 0 | for (idx_t list_no = 0; list_no < nlist; list_no++) { |
113 | 0 | size_t n = list_size(list_no); |
114 | 0 | ScopedIds ids_in(this, list_no); |
115 | |
|
116 | 0 | if (subset_type == SUBSET_TYPE_ID_RANGE) { |
117 | 0 | for (idx_t i = 0; i < n; i++) { |
118 | 0 | idx_t id = ids_in[i]; |
119 | 0 | if (a1 <= id && id < a2) { |
120 | 0 | oivf.add_entry( |
121 | 0 | list_no, |
122 | 0 | get_single_id(list_no, i), |
123 | 0 | ScopedCodes(this, list_no, i).get()); |
124 | 0 | n_added++; |
125 | 0 | } |
126 | 0 | } |
127 | 0 | } else if (subset_type == SUBSET_TYPE_ID_MOD) { |
128 | 0 | for (idx_t i = 0; i < n; i++) { |
129 | 0 | idx_t id = ids_in[i]; |
130 | 0 | if (id % a1 == a2) { |
131 | 0 | oivf.add_entry( |
132 | 0 | list_no, |
133 | 0 | get_single_id(list_no, i), |
134 | 0 | ScopedCodes(this, list_no, i).get()); |
135 | 0 | n_added++; |
136 | 0 | } |
137 | 0 | } |
138 | 0 | } else if (subset_type == SUBSET_TYPE_ELEMENT_RANGE) { |
139 | | // see what is allocated to a1 and to a2 |
140 | 0 | size_t next_accu_n = accu_n + n; |
141 | 0 | size_t next_accu_a1 = next_accu_n * a1 / ntotal; |
142 | 0 | size_t i1 = next_accu_a1 - accu_a1; |
143 | 0 | size_t next_accu_a2 = next_accu_n * a2 / ntotal; |
144 | 0 | size_t i2 = next_accu_a2 - accu_a2; |
145 | |
|
146 | 0 | for (idx_t i = i1; i < i2; i++) { |
147 | 0 | oivf.add_entry( |
148 | 0 | list_no, |
149 | 0 | get_single_id(list_no, i), |
150 | 0 | ScopedCodes(this, list_no, i).get()); |
151 | 0 | } |
152 | |
|
153 | 0 | n_added += i2 - i1; |
154 | 0 | accu_a1 = next_accu_a1; |
155 | 0 | accu_a2 = next_accu_a2; |
156 | 0 | } else if (subset_type == SUBSET_TYPE_INVLIST_FRACTION) { |
157 | 0 | size_t i1 = n * a2 / a1; |
158 | 0 | size_t i2 = n * (a2 + 1) / a1; |
159 | |
|
160 | 0 | for (idx_t i = i1; i < i2; i++) { |
161 | 0 | oivf.add_entry( |
162 | 0 | list_no, |
163 | 0 | get_single_id(list_no, i), |
164 | 0 | ScopedCodes(this, list_no, i).get()); |
165 | 0 | } |
166 | |
|
167 | 0 | n_added += i2 - i1; |
168 | 0 | } else if (subset_type == SUBSET_TYPE_INVLIST) { |
169 | 0 | if (list_no >= a1 && list_no < a2) { |
170 | 0 | oivf.add_entries( |
171 | 0 | list_no, |
172 | 0 | n, |
173 | 0 | ScopedIds(this, list_no).get(), |
174 | 0 | ScopedCodes(this, list_no).get()); |
175 | 0 | n_added += n; |
176 | 0 | } |
177 | 0 | } |
178 | 0 | accu_n += n; |
179 | 0 | } |
180 | 0 | return n_added; |
181 | 0 | } |
182 | | |
183 | 0 | double InvertedLists::imbalance_factor() const { |
184 | 0 | std::vector<int64_t> hist(nlist); |
185 | |
|
186 | 0 | for (size_t i = 0; i < nlist; i++) { |
187 | 0 | hist[i] = list_size(i); |
188 | 0 | } |
189 | |
|
190 | 0 | return faiss::imbalance_factor(nlist, hist.data()); |
191 | 0 | } |
192 | | |
193 | 0 | void InvertedLists::print_stats() const { |
194 | 0 | std::vector<int> sizes(40); |
195 | 0 | for (size_t i = 0; i < nlist; i++) { |
196 | 0 | for (size_t j = 0; j < sizes.size(); j++) { |
197 | 0 | if ((list_size(i) >> j) == 0) { |
198 | 0 | sizes[j]++; |
199 | 0 | break; |
200 | 0 | } |
201 | 0 | } |
202 | 0 | } |
203 | 0 | for (size_t i = 0; i < sizes.size(); i++) { |
204 | 0 | if (sizes[i]) { |
205 | 0 | printf("list size in < %zu: %d instances\n", |
206 | 0 | static_cast<size_t>(1) << i, |
207 | 0 | sizes[i]); |
208 | 0 | } |
209 | 0 | } |
210 | 0 | } |
211 | | |
212 | 0 | size_t InvertedLists::compute_ntotal() const { |
213 | 0 | size_t tot = 0; |
214 | 0 | for (size_t i = 0; i < nlist; i++) { |
215 | 0 | tot += list_size(i); |
216 | 0 | } |
217 | 0 | return tot; |
218 | 0 | } |
219 | | |
220 | | bool InvertedLists::is_empty(size_t list_no, void* inverted_list_context) |
221 | 0 | const { |
222 | 0 | if (use_iterator) { |
223 | 0 | return !std::unique_ptr<InvertedListsIterator>( |
224 | 0 | get_iterator(list_no, inverted_list_context)) |
225 | 0 | ->is_available(); |
226 | 0 | } else { |
227 | 0 | FAISS_THROW_IF_NOT(inverted_list_context == nullptr); |
228 | 0 | return list_size(list_no) == 0; |
229 | 0 | } |
230 | 0 | } |
231 | | |
232 | | // implemnent iterator on top of get_codes / get_ids |
233 | | namespace { |
234 | | |
235 | | struct CodeArrayIterator : InvertedListsIterator { |
236 | | size_t list_size; |
237 | | size_t code_size; |
238 | | InvertedLists::ScopedCodes codes; |
239 | | InvertedLists::ScopedIds ids; |
240 | | size_t idx = 0; |
241 | | |
242 | | CodeArrayIterator(const InvertedLists* il, size_t list_no) |
243 | 0 | : list_size(il->list_size(list_no)), |
244 | 0 | code_size(il->code_size), |
245 | 0 | codes(il, list_no), |
246 | 0 | ids(il, list_no) {} |
247 | | |
248 | 0 | bool is_available() const override { |
249 | 0 | return idx < list_size; |
250 | 0 | } |
251 | 0 | void next() override { |
252 | 0 | idx++; |
253 | 0 | } |
254 | 0 | std::pair<idx_t, const uint8_t*> get_id_and_codes() override { |
255 | 0 | return {ids[idx], codes.get() + code_size * idx}; |
256 | 0 | } |
257 | | }; |
258 | | |
259 | | } // namespace |
260 | | |
261 | | InvertedListsIterator* InvertedLists::get_iterator( |
262 | | size_t list_no, |
263 | 0 | void* inverted_list_context) const { |
264 | 0 | FAISS_THROW_IF_NOT(inverted_list_context == nullptr); |
265 | 0 | return new CodeArrayIterator(this, list_no); |
266 | 0 | } |
267 | | |
268 | | /***************************************** |
269 | | * ArrayInvertedLists implementation |
270 | | ******************************************/ |
271 | | |
272 | | ArrayInvertedLists::ArrayInvertedLists(size_t nlist, size_t code_size) |
273 | 41 | : InvertedLists(nlist, code_size) { |
274 | 41 | ids.resize(nlist); |
275 | 41 | codes.resize(nlist); |
276 | 41 | } |
277 | | |
278 | | size_t ArrayInvertedLists::add_entries( |
279 | | size_t list_no, |
280 | | size_t n_entry, |
281 | | const idx_t* ids_in, |
282 | 3.86k | const uint8_t* code) { |
283 | 3.86k | if (n_entry == 0) |
284 | 0 | return 0; |
285 | 3.86k | assert(list_no < nlist); |
286 | 3.86k | size_t o = ids[list_no].size(); |
287 | 3.86k | ids[list_no].resize(o + n_entry); |
288 | 3.86k | memcpy(&ids[list_no][o], ids_in, sizeof(ids_in[0]) * n_entry); |
289 | 3.86k | codes[list_no].resize((o + n_entry) * code_size); |
290 | 3.86k | memcpy(&codes[list_no][o * code_size], code, code_size * n_entry); |
291 | 3.86k | return o; |
292 | 3.86k | } |
293 | | |
294 | 122 | size_t ArrayInvertedLists::list_size(size_t list_no) const { |
295 | 122 | assert(list_no < nlist); |
296 | 122 | return ids[list_no].size(); |
297 | 122 | } |
298 | | |
299 | | bool ArrayInvertedLists::is_empty(size_t list_no, void* inverted_list_context) |
300 | 132 | const { |
301 | 132 | FAISS_THROW_IF_NOT(inverted_list_context == nullptr); |
302 | 132 | return ids[list_no].size() == 0; |
303 | 132 | } |
304 | | |
305 | 122 | const uint8_t* ArrayInvertedLists::get_codes(size_t list_no) const { |
306 | 122 | assert(list_no < nlist); |
307 | 122 | return codes[list_no].data(); |
308 | 122 | } |
309 | | |
310 | 122 | const idx_t* ArrayInvertedLists::get_ids(size_t list_no) const { |
311 | 122 | assert(list_no < nlist); |
312 | 122 | return ids[list_no].data(); |
313 | 122 | } |
314 | | |
315 | 0 | void ArrayInvertedLists::resize(size_t list_no, size_t new_size) { |
316 | 0 | ids[list_no].resize(new_size); |
317 | 0 | codes[list_no].resize(new_size * code_size); |
318 | 0 | } |
319 | | |
320 | | void ArrayInvertedLists::update_entries( |
321 | | size_t list_no, |
322 | | size_t offset, |
323 | | size_t n_entry, |
324 | | const idx_t* ids_in, |
325 | 0 | const uint8_t* codes_in) { |
326 | 0 | assert(list_no < nlist); |
327 | 0 | assert(n_entry + offset <= ids[list_no].size()); |
328 | 0 | memcpy(&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry); |
329 | 0 | memcpy(&codes[list_no][offset * code_size], codes_in, code_size * n_entry); |
330 | 0 | } |
331 | | |
332 | 0 | void ArrayInvertedLists::permute_invlists(const idx_t* map) { |
333 | 0 | std::vector<MaybeOwnedVector<uint8_t>> new_codes(nlist); |
334 | 0 | std::vector<MaybeOwnedVector<idx_t>> new_ids(nlist); |
335 | |
|
336 | 0 | for (size_t i = 0; i < nlist; i++) { |
337 | 0 | size_t o = map[i]; |
338 | 0 | FAISS_THROW_IF_NOT(o < nlist); |
339 | 0 | std::swap(new_codes[i], codes[o]); |
340 | 0 | std::swap(new_ids[i], ids[o]); |
341 | 0 | } |
342 | 0 | std::swap(codes, new_codes); |
343 | 0 | std::swap(ids, new_ids); |
344 | 0 | } |
345 | | |
346 | 41 | ArrayInvertedLists::~ArrayInvertedLists() {} |
347 | | |
348 | | /***************************************************************** |
349 | | * Meta-inverted list implementations |
350 | | *****************************************************************/ |
351 | | |
352 | | size_t ReadOnlyInvertedLists::add_entries( |
353 | | size_t, |
354 | | size_t, |
355 | | const idx_t*, |
356 | 0 | const uint8_t*) { |
357 | 0 | FAISS_THROW_MSG("not implemented"); |
358 | 0 | } |
359 | | |
360 | | void ReadOnlyInvertedLists::update_entries( |
361 | | size_t, |
362 | | size_t, |
363 | | size_t, |
364 | | const idx_t*, |
365 | 0 | const uint8_t*) { |
366 | 0 | FAISS_THROW_MSG("not implemented"); |
367 | 0 | } |
368 | | |
369 | 0 | void ReadOnlyInvertedLists::resize(size_t, size_t) { |
370 | 0 | FAISS_THROW_MSG("not implemented"); |
371 | 0 | } |
372 | | |
373 | | /***************************************** |
374 | | * HStackInvertedLists implementation |
375 | | ******************************************/ |
376 | | |
377 | | HStackInvertedLists::HStackInvertedLists(int nil, const InvertedLists** ils_in) |
378 | 0 | : ReadOnlyInvertedLists( |
379 | 0 | nil > 0 ? ils_in[0]->nlist : 0, |
380 | 0 | nil > 0 ? ils_in[0]->code_size : 0) { |
381 | 0 | FAISS_THROW_IF_NOT(nil > 0); |
382 | 0 | for (int i = 0; i < nil; i++) { |
383 | 0 | ils.push_back(ils_in[i]); |
384 | 0 | FAISS_THROW_IF_NOT( |
385 | 0 | ils_in[i]->code_size == code_size && ils_in[i]->nlist == nlist); |
386 | 0 | } |
387 | 0 | } |
388 | | |
389 | 0 | size_t HStackInvertedLists::list_size(size_t list_no) const { |
390 | 0 | size_t sz = 0; |
391 | 0 | for (int i = 0; i < ils.size(); i++) { |
392 | 0 | const InvertedLists* il = ils[i]; |
393 | 0 | sz += il->list_size(list_no); |
394 | 0 | } |
395 | 0 | return sz; |
396 | 0 | } |
397 | | |
398 | 0 | const uint8_t* HStackInvertedLists::get_codes(size_t list_no) const { |
399 | 0 | uint8_t *codes = new uint8_t[code_size * list_size(list_no)], *c = codes; |
400 | |
|
401 | 0 | for (int i = 0; i < ils.size(); i++) { |
402 | 0 | const InvertedLists* il = ils[i]; |
403 | 0 | size_t sz = il->list_size(list_no) * code_size; |
404 | 0 | if (sz > 0) { |
405 | 0 | memcpy(c, ScopedCodes(il, list_no).get(), sz); |
406 | 0 | c += sz; |
407 | 0 | } |
408 | 0 | } |
409 | 0 | return codes; |
410 | 0 | } |
411 | | |
412 | | const uint8_t* HStackInvertedLists::get_single_code( |
413 | | size_t list_no, |
414 | 0 | size_t offset) const { |
415 | 0 | for (int i = 0; i < ils.size(); i++) { |
416 | 0 | const InvertedLists* il = ils[i]; |
417 | 0 | size_t sz = il->list_size(list_no); |
418 | 0 | if (offset < sz) { |
419 | | // here we have to copy the code, otherwise it will crash at dealloc |
420 | 0 | uint8_t* code = new uint8_t[code_size]; |
421 | 0 | memcpy(code, ScopedCodes(il, list_no, offset).get(), code_size); |
422 | 0 | return code; |
423 | 0 | } |
424 | 0 | offset -= sz; |
425 | 0 | } |
426 | 0 | FAISS_THROW_FMT("offset %zd unknown", offset); |
427 | 0 | } |
428 | | |
429 | 0 | void HStackInvertedLists::release_codes(size_t, const uint8_t* codes) const { |
430 | 0 | delete[] codes; |
431 | 0 | } |
432 | | |
433 | 0 | const idx_t* HStackInvertedLists::get_ids(size_t list_no) const { |
434 | 0 | idx_t *ids = new idx_t[list_size(list_no)], *c = ids; |
435 | |
|
436 | 0 | for (int i = 0; i < ils.size(); i++) { |
437 | 0 | const InvertedLists* il = ils[i]; |
438 | 0 | size_t sz = il->list_size(list_no); |
439 | 0 | if (sz > 0) { |
440 | 0 | memcpy(c, ScopedIds(il, list_no).get(), sz * sizeof(idx_t)); |
441 | 0 | c += sz; |
442 | 0 | } |
443 | 0 | } |
444 | 0 | return ids; |
445 | 0 | } |
446 | | |
447 | 0 | idx_t HStackInvertedLists::get_single_id(size_t list_no, size_t offset) const { |
448 | 0 | for (int i = 0; i < ils.size(); i++) { |
449 | 0 | const InvertedLists* il = ils[i]; |
450 | 0 | size_t sz = il->list_size(list_no); |
451 | 0 | if (offset < sz) { |
452 | 0 | return il->get_single_id(list_no, offset); |
453 | 0 | } |
454 | 0 | offset -= sz; |
455 | 0 | } |
456 | 0 | FAISS_THROW_FMT("offset %zd unknown", offset); |
457 | 0 | } |
458 | | |
459 | 0 | void HStackInvertedLists::release_ids(size_t, const idx_t* ids) const { |
460 | 0 | delete[] ids; |
461 | 0 | } |
462 | | |
463 | | void HStackInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist) |
464 | 0 | const { |
465 | 0 | for (int i = 0; i < ils.size(); i++) { |
466 | 0 | const InvertedLists* il = ils[i]; |
467 | 0 | il->prefetch_lists(list_nos, nlist); |
468 | 0 | } |
469 | 0 | } |
470 | | |
471 | | /***************************************** |
472 | | * SliceInvertedLists implementation |
473 | | ******************************************/ |
474 | | |
475 | | namespace { |
476 | | |
477 | 0 | idx_t translate_list_no(const SliceInvertedLists* sil, idx_t list_no) { |
478 | 0 | FAISS_THROW_IF_NOT(list_no >= 0 && list_no < sil->nlist); |
479 | 0 | return list_no + sil->i0; |
480 | 0 | } |
481 | | |
482 | | } // namespace |
483 | | |
484 | | SliceInvertedLists::SliceInvertedLists( |
485 | | const InvertedLists* il, |
486 | | idx_t i0, |
487 | | idx_t i1) |
488 | 0 | : ReadOnlyInvertedLists(i1 - i0, il->code_size), |
489 | 0 | il(il), |
490 | 0 | i0(i0), |
491 | 0 | i1(i1) {} |
492 | | |
493 | 0 | size_t SliceInvertedLists::list_size(size_t list_no) const { |
494 | 0 | return il->list_size(translate_list_no(this, list_no)); |
495 | 0 | } |
496 | | |
497 | 0 | const uint8_t* SliceInvertedLists::get_codes(size_t list_no) const { |
498 | 0 | return il->get_codes(translate_list_no(this, list_no)); |
499 | 0 | } |
500 | | |
501 | | const uint8_t* SliceInvertedLists::get_single_code( |
502 | | size_t list_no, |
503 | 0 | size_t offset) const { |
504 | 0 | return il->get_single_code(translate_list_no(this, list_no), offset); |
505 | 0 | } |
506 | | |
507 | | void SliceInvertedLists::release_codes(size_t list_no, const uint8_t* codes) |
508 | 0 | const { |
509 | 0 | return il->release_codes(translate_list_no(this, list_no), codes); |
510 | 0 | } |
511 | | |
512 | 0 | const idx_t* SliceInvertedLists::get_ids(size_t list_no) const { |
513 | 0 | return il->get_ids(translate_list_no(this, list_no)); |
514 | 0 | } |
515 | | |
516 | 0 | idx_t SliceInvertedLists::get_single_id(size_t list_no, size_t offset) const { |
517 | 0 | return il->get_single_id(translate_list_no(this, list_no), offset); |
518 | 0 | } |
519 | | |
520 | 0 | void SliceInvertedLists::release_ids(size_t list_no, const idx_t* ids) const { |
521 | 0 | return il->release_ids(translate_list_no(this, list_no), ids); |
522 | 0 | } |
523 | | |
524 | | void SliceInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist) |
525 | 0 | const { |
526 | 0 | std::vector<idx_t> translated_list_nos; |
527 | 0 | for (int j = 0; j < nlist; j++) { |
528 | 0 | idx_t list_no = list_nos[j]; |
529 | 0 | if (list_no < 0) |
530 | 0 | continue; |
531 | 0 | translated_list_nos.push_back(translate_list_no(this, list_no)); |
532 | 0 | } |
533 | 0 | il->prefetch_lists(translated_list_nos.data(), translated_list_nos.size()); |
534 | 0 | } |
535 | | |
536 | | /***************************************** |
537 | | * VStackInvertedLists implementation |
538 | | ******************************************/ |
539 | | |
540 | | namespace { |
541 | | |
542 | | // find the invlist this number belongs to |
543 | 0 | int translate_list_no(const VStackInvertedLists* vil, idx_t list_no) { |
544 | 0 | FAISS_THROW_IF_NOT(list_no >= 0 && list_no < vil->nlist); |
545 | 0 | int i0 = 0, i1 = vil->ils.size(); |
546 | 0 | const idx_t* cumsz = vil->cumsz.data(); |
547 | 0 | while (i0 + 1 < i1) { |
548 | 0 | int imed = (i0 + i1) / 2; |
549 | 0 | if (list_no >= cumsz[imed]) { |
550 | 0 | i0 = imed; |
551 | 0 | } else { |
552 | 0 | i1 = imed; |
553 | 0 | } |
554 | 0 | } |
555 | 0 | assert(list_no >= cumsz[i0] && list_no < cumsz[i0 + 1]); |
556 | 0 | return i0; |
557 | 0 | } |
558 | | |
559 | 0 | idx_t sum_il_sizes(int nil, const InvertedLists** ils_in) { |
560 | 0 | idx_t tot = 0; |
561 | 0 | for (int i = 0; i < nil; i++) { |
562 | 0 | tot += ils_in[i]->nlist; |
563 | 0 | } |
564 | 0 | return tot; |
565 | 0 | } |
566 | | |
567 | | } // namespace |
568 | | |
569 | | VStackInvertedLists::VStackInvertedLists(int nil, const InvertedLists** ils_in) |
570 | 0 | : ReadOnlyInvertedLists( |
571 | 0 | sum_il_sizes(nil, ils_in), |
572 | 0 | nil > 0 ? ils_in[0]->code_size : 0) { |
573 | 0 | FAISS_THROW_IF_NOT(nil > 0); |
574 | 0 | cumsz.resize(nil + 1); |
575 | 0 | for (int i = 0; i < nil; i++) { |
576 | 0 | ils.push_back(ils_in[i]); |
577 | 0 | FAISS_THROW_IF_NOT(ils_in[i]->code_size == code_size); |
578 | 0 | cumsz[i + 1] = cumsz[i] + ils_in[i]->nlist; |
579 | 0 | } |
580 | 0 | } |
581 | | |
582 | 0 | size_t VStackInvertedLists::list_size(size_t list_no) const { |
583 | 0 | int i = translate_list_no(this, list_no); |
584 | 0 | list_no -= cumsz[i]; |
585 | 0 | return ils[i]->list_size(list_no); |
586 | 0 | } |
587 | | |
588 | 0 | const uint8_t* VStackInvertedLists::get_codes(size_t list_no) const { |
589 | 0 | int i = translate_list_no(this, list_no); |
590 | 0 | list_no -= cumsz[i]; |
591 | 0 | return ils[i]->get_codes(list_no); |
592 | 0 | } |
593 | | |
594 | | const uint8_t* VStackInvertedLists::get_single_code( |
595 | | size_t list_no, |
596 | 0 | size_t offset) const { |
597 | 0 | int i = translate_list_no(this, list_no); |
598 | 0 | list_no -= cumsz[i]; |
599 | 0 | return ils[i]->get_single_code(list_no, offset); |
600 | 0 | } |
601 | | |
602 | | void VStackInvertedLists::release_codes(size_t list_no, const uint8_t* codes) |
603 | 0 | const { |
604 | 0 | int i = translate_list_no(this, list_no); |
605 | 0 | list_no -= cumsz[i]; |
606 | 0 | return ils[i]->release_codes(list_no, codes); |
607 | 0 | } |
608 | | |
609 | 0 | const idx_t* VStackInvertedLists::get_ids(size_t list_no) const { |
610 | 0 | int i = translate_list_no(this, list_no); |
611 | 0 | list_no -= cumsz[i]; |
612 | 0 | return ils[i]->get_ids(list_no); |
613 | 0 | } |
614 | | |
615 | 0 | idx_t VStackInvertedLists::get_single_id(size_t list_no, size_t offset) const { |
616 | 0 | int i = translate_list_no(this, list_no); |
617 | 0 | list_no -= cumsz[i]; |
618 | 0 | return ils[i]->get_single_id(list_no, offset); |
619 | 0 | } |
620 | | |
621 | 0 | void VStackInvertedLists::release_ids(size_t list_no, const idx_t* ids) const { |
622 | 0 | int i = translate_list_no(this, list_no); |
623 | 0 | list_no -= cumsz[i]; |
624 | 0 | return ils[i]->release_ids(list_no, ids); |
625 | 0 | } |
626 | | |
627 | | void VStackInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist) |
628 | 0 | const { |
629 | 0 | std::vector<int> ilno(nlist, -1); |
630 | 0 | std::vector<int> n_per_il(ils.size(), 0); |
631 | 0 | for (int j = 0; j < nlist; j++) { |
632 | 0 | idx_t list_no = list_nos[j]; |
633 | 0 | if (list_no < 0) |
634 | 0 | continue; |
635 | 0 | int i = ilno[j] = translate_list_no(this, list_no); |
636 | 0 | n_per_il[i]++; |
637 | 0 | } |
638 | 0 | std::vector<int> cum_n_per_il(ils.size() + 1, 0); |
639 | 0 | for (int j = 0; j < ils.size(); j++) { |
640 | 0 | cum_n_per_il[j + 1] = cum_n_per_il[j] + n_per_il[j]; |
641 | 0 | } |
642 | 0 | std::vector<idx_t> sorted_list_nos(cum_n_per_il.back()); |
643 | 0 | for (int j = 0; j < nlist; j++) { |
644 | 0 | idx_t list_no = list_nos[j]; |
645 | 0 | if (list_no < 0) |
646 | 0 | continue; |
647 | 0 | int i = ilno[j]; |
648 | 0 | list_no -= cumsz[i]; |
649 | 0 | sorted_list_nos[cum_n_per_il[i]++] = list_no; |
650 | 0 | } |
651 | |
|
652 | 0 | int i0 = 0; |
653 | 0 | for (int j = 0; j < ils.size(); j++) { |
654 | 0 | int i1 = i0 + n_per_il[j]; |
655 | 0 | if (i1 > i0) { |
656 | 0 | ils[j]->prefetch_lists(sorted_list_nos.data() + i0, i1 - i0); |
657 | 0 | } |
658 | 0 | i0 = i1; |
659 | 0 | } |
660 | 0 | } |
661 | | |
662 | | /***************************************** |
663 | | * MaskedInvertedLists implementation |
664 | | ******************************************/ |
665 | | |
666 | | MaskedInvertedLists::MaskedInvertedLists( |
667 | | const InvertedLists* il0, |
668 | | const InvertedLists* il1) |
669 | 0 | : ReadOnlyInvertedLists(il0->nlist, il0->code_size), |
670 | 0 | il0(il0), |
671 | 0 | il1(il1) { |
672 | 0 | FAISS_THROW_IF_NOT(il1->nlist == nlist); |
673 | 0 | FAISS_THROW_IF_NOT(il1->code_size == code_size); |
674 | 0 | } |
675 | | |
676 | 0 | size_t MaskedInvertedLists::list_size(size_t list_no) const { |
677 | 0 | size_t sz = il0->list_size(list_no); |
678 | 0 | return sz ? sz : il1->list_size(list_no); |
679 | 0 | } |
680 | | |
681 | 0 | const uint8_t* MaskedInvertedLists::get_codes(size_t list_no) const { |
682 | 0 | size_t sz = il0->list_size(list_no); |
683 | 0 | return (sz ? il0 : il1)->get_codes(list_no); |
684 | 0 | } |
685 | | |
686 | 0 | const idx_t* MaskedInvertedLists::get_ids(size_t list_no) const { |
687 | 0 | size_t sz = il0->list_size(list_no); |
688 | 0 | return (sz ? il0 : il1)->get_ids(list_no); |
689 | 0 | } |
690 | | |
691 | | void MaskedInvertedLists::release_codes(size_t list_no, const uint8_t* codes) |
692 | 0 | const { |
693 | 0 | size_t sz = il0->list_size(list_no); |
694 | 0 | (sz ? il0 : il1)->release_codes(list_no, codes); |
695 | 0 | } |
696 | | |
697 | 0 | void MaskedInvertedLists::release_ids(size_t list_no, const idx_t* ids) const { |
698 | 0 | size_t sz = il0->list_size(list_no); |
699 | 0 | (sz ? il0 : il1)->release_ids(list_no, ids); |
700 | 0 | } |
701 | | |
702 | 0 | idx_t MaskedInvertedLists::get_single_id(size_t list_no, size_t offset) const { |
703 | 0 | size_t sz = il0->list_size(list_no); |
704 | 0 | return (sz ? il0 : il1)->get_single_id(list_no, offset); |
705 | 0 | } |
706 | | |
707 | | const uint8_t* MaskedInvertedLists::get_single_code( |
708 | | size_t list_no, |
709 | 0 | size_t offset) const { |
710 | 0 | size_t sz = il0->list_size(list_no); |
711 | 0 | return (sz ? il0 : il1)->get_single_code(list_no, offset); |
712 | 0 | } |
713 | | |
714 | | void MaskedInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist) |
715 | 0 | const { |
716 | 0 | std::vector<idx_t> list0, list1; |
717 | 0 | for (int i = 0; i < nlist; i++) { |
718 | 0 | idx_t list_no = list_nos[i]; |
719 | 0 | if (list_no < 0) |
720 | 0 | continue; |
721 | 0 | size_t sz = il0->list_size(list_no); |
722 | 0 | (sz ? list0 : list1).push_back(list_no); |
723 | 0 | } |
724 | 0 | il0->prefetch_lists(list0.data(), list0.size()); |
725 | 0 | il1->prefetch_lists(list1.data(), list1.size()); |
726 | 0 | } |
727 | | |
728 | | /***************************************** |
729 | | * MaskedInvertedLists implementation |
730 | | ******************************************/ |
731 | | |
732 | | StopWordsInvertedLists::StopWordsInvertedLists( |
733 | | const InvertedLists* il0, |
734 | | size_t maxsize) |
735 | 0 | : ReadOnlyInvertedLists(il0->nlist, il0->code_size), |
736 | 0 | il0(il0), |
737 | 0 | maxsize(maxsize) {} |
738 | | |
739 | 0 | size_t StopWordsInvertedLists::list_size(size_t list_no) const { |
740 | 0 | size_t sz = il0->list_size(list_no); |
741 | 0 | return sz < maxsize ? sz : 0; |
742 | 0 | } |
743 | | |
744 | 0 | const uint8_t* StopWordsInvertedLists::get_codes(size_t list_no) const { |
745 | 0 | return il0->list_size(list_no) < maxsize ? il0->get_codes(list_no) |
746 | 0 | : nullptr; |
747 | 0 | } |
748 | | |
749 | 0 | const idx_t* StopWordsInvertedLists::get_ids(size_t list_no) const { |
750 | 0 | return il0->list_size(list_no) < maxsize ? il0->get_ids(list_no) : nullptr; |
751 | 0 | } |
752 | | |
753 | | void StopWordsInvertedLists::release_codes(size_t list_no, const uint8_t* codes) |
754 | 0 | const { |
755 | 0 | if (il0->list_size(list_no) < maxsize) { |
756 | 0 | il0->release_codes(list_no, codes); |
757 | 0 | } |
758 | 0 | } |
759 | | |
760 | | void StopWordsInvertedLists::release_ids(size_t list_no, const idx_t* ids) |
761 | 0 | const { |
762 | 0 | if (il0->list_size(list_no) < maxsize) { |
763 | 0 | il0->release_ids(list_no, ids); |
764 | 0 | } |
765 | 0 | } |
766 | | |
767 | | idx_t StopWordsInvertedLists::get_single_id(size_t list_no, size_t offset) |
768 | 0 | const { |
769 | 0 | FAISS_THROW_IF_NOT(il0->list_size(list_no) < maxsize); |
770 | 0 | return il0->get_single_id(list_no, offset); |
771 | 0 | } |
772 | | |
773 | | const uint8_t* StopWordsInvertedLists::get_single_code( |
774 | | size_t list_no, |
775 | 0 | size_t offset) const { |
776 | 0 | FAISS_THROW_IF_NOT(il0->list_size(list_no) < maxsize); |
777 | 0 | return il0->get_single_code(list_no, offset); |
778 | 0 | } |
779 | | |
780 | | void StopWordsInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist) |
781 | 0 | const { |
782 | 0 | std::vector<idx_t> list0; |
783 | 0 | for (int i = 0; i < nlist; i++) { |
784 | 0 | idx_t list_no = list_nos[i]; |
785 | 0 | if (list_no < 0) |
786 | 0 | continue; |
787 | 0 | if (il0->list_size(list_no) < maxsize) { |
788 | 0 | list0.push_back(list_no); |
789 | 0 | } |
790 | 0 | } |
791 | 0 | il0->prefetch_lists(list0.data(), list0.size()); |
792 | 0 | } |
793 | | |
794 | | } // namespace faiss |