contrib/faiss/faiss/invlists/DirectMap.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/invlists/DirectMap.h> |
11 | | |
12 | | #include <cassert> |
13 | | #include <cstdio> |
14 | | |
15 | | #include <faiss/impl/AuxIndexStructures.h> |
16 | | #include <faiss/impl/FaissAssert.h> |
17 | | #include <faiss/impl/IDSelector.h> |
18 | | #include <faiss/invlists/BlockInvertedLists.h> |
19 | | |
20 | | namespace faiss { |
21 | | |
22 | 41 | DirectMap::DirectMap() : type(NoMap) {} |
23 | | |
24 | | void DirectMap::set_type( |
25 | | Type new_type, |
26 | | const InvertedLists* invlists, |
27 | 0 | size_t ntotal) { |
28 | 0 | FAISS_THROW_IF_NOT( |
29 | 0 | new_type == NoMap || new_type == Array || new_type == Hashtable); |
30 | | |
31 | 0 | if (new_type == type) { |
32 | | // nothing to do |
33 | 0 | return; |
34 | 0 | } |
35 | | |
36 | 0 | array.clear(); |
37 | 0 | hashtable.clear(); |
38 | 0 | type = new_type; |
39 | |
|
40 | 0 | if (new_type == NoMap) { |
41 | 0 | return; |
42 | 0 | } else if (new_type == Array) { |
43 | 0 | array.resize(ntotal, -1); |
44 | 0 | } else if (new_type == Hashtable) { |
45 | 0 | hashtable.reserve(ntotal); |
46 | 0 | } |
47 | | |
48 | 0 | for (size_t key = 0; key < invlists->nlist; key++) { |
49 | 0 | size_t list_size = invlists->list_size(key); |
50 | 0 | InvertedLists::ScopedIds idlist(invlists, key); |
51 | |
|
52 | 0 | if (new_type == Array) { |
53 | 0 | for (long ofs = 0; ofs < list_size; ofs++) { |
54 | 0 | FAISS_THROW_IF_NOT_MSG( |
55 | 0 | 0 <= idlist[ofs] && idlist[ofs] < ntotal, |
56 | 0 | "direct map supported only for seuquential ids"); |
57 | 0 | array[idlist[ofs]] = lo_build(key, ofs); |
58 | 0 | } |
59 | 0 | } else if (new_type == Hashtable) { |
60 | 0 | for (long ofs = 0; ofs < list_size; ofs++) { |
61 | 0 | hashtable[idlist[ofs]] = lo_build(key, ofs); |
62 | 0 | } |
63 | 0 | } |
64 | 0 | } |
65 | 0 | } |
66 | | |
67 | 0 | void DirectMap::clear() { |
68 | 0 | array.clear(); |
69 | 0 | hashtable.clear(); |
70 | 0 | } |
71 | | |
72 | 0 | idx_t DirectMap::get(idx_t key) const { |
73 | 0 | if (type == Array) { |
74 | 0 | FAISS_THROW_IF_NOT_MSG(key >= 0 && key < array.size(), "invalid key"); |
75 | 0 | idx_t lo = array[key]; |
76 | 0 | FAISS_THROW_IF_NOT_MSG(lo >= 0, "-1 entry in direct_map"); |
77 | 0 | return lo; |
78 | 0 | } else if (type == Hashtable) { |
79 | 0 | auto res = hashtable.find(key); |
80 | 0 | FAISS_THROW_IF_NOT_MSG(res != hashtable.end(), "key not found"); |
81 | 0 | return res->second; |
82 | 0 | } else { |
83 | 0 | FAISS_THROW_MSG("direct map not initialized"); |
84 | 0 | } |
85 | 0 | } |
86 | | |
87 | 0 | void DirectMap::add_single_id(idx_t id, idx_t list_no, size_t offset) { |
88 | 0 | if (type == NoMap) |
89 | 0 | return; |
90 | | |
91 | 0 | if (type == Array) { |
92 | 0 | assert(id == array.size()); |
93 | 0 | if (list_no >= 0) { |
94 | 0 | array.push_back(lo_build(list_no, offset)); |
95 | 0 | } else { |
96 | 0 | array.push_back(-1); |
97 | 0 | } |
98 | 0 | } else if (type == Hashtable) { |
99 | 0 | if (list_no >= 0) { |
100 | 0 | hashtable[id] = lo_build(list_no, offset); |
101 | 0 | } |
102 | 0 | } |
103 | 0 | } |
104 | | |
105 | 25 | void DirectMap::check_can_add(const idx_t* ids) { |
106 | 25 | if (type == Array && ids) { |
107 | 0 | FAISS_THROW_MSG("cannot have array direct map and add with ids"); |
108 | 0 | } |
109 | 25 | } |
110 | | |
111 | | /********************* DirectMapAdd implementation */ |
112 | | |
113 | | DirectMapAdd::DirectMapAdd(DirectMap& direct_map, size_t n, const idx_t* xids) |
114 | 26 | : direct_map(direct_map), type(direct_map.type), n(n), xids(xids) { |
115 | 26 | if (type == DirectMap::Array) { |
116 | 0 | FAISS_THROW_IF_NOT(xids == nullptr); |
117 | 0 | ntotal = direct_map.array.size(); |
118 | 0 | direct_map.array.resize(ntotal + n, -1); |
119 | 26 | } else if (type == DirectMap::Hashtable) { |
120 | | // can't parallel update hashtable so use temp array |
121 | 0 | all_ofs.resize(n, -1); |
122 | 0 | } |
123 | 26 | } |
124 | | |
125 | 3.47k | void DirectMapAdd::add(size_t i, idx_t list_no, size_t ofs) { |
126 | 3.47k | if (type == DirectMap::Array) { |
127 | 0 | direct_map.array[ntotal + i] = lo_build(list_no, ofs); |
128 | 3.47k | } else if (type == DirectMap::Hashtable) { |
129 | 0 | all_ofs[i] = lo_build(list_no, ofs); |
130 | 0 | } |
131 | 3.47k | } |
132 | | |
133 | 26 | DirectMapAdd::~DirectMapAdd() { |
134 | 26 | if (type == DirectMap::Hashtable) { |
135 | 0 | for (int i = 0; i < n; i++) { |
136 | 0 | idx_t id = xids ? xids[i] : ntotal + i; |
137 | 0 | direct_map.hashtable[id] = all_ofs[i]; |
138 | 0 | } |
139 | 0 | } |
140 | 26 | } |
141 | | |
142 | | /********************************************************/ |
143 | | |
144 | | using ScopedCodes = InvertedLists::ScopedCodes; |
145 | | using ScopedIds = InvertedLists::ScopedIds; |
146 | | |
147 | 0 | size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) { |
148 | 0 | size_t nlist = invlists->nlist; |
149 | 0 | std::vector<idx_t> toremove(nlist); |
150 | |
|
151 | 0 | size_t nremove = 0; |
152 | 0 | BlockInvertedLists* block_invlists = |
153 | 0 | dynamic_cast<BlockInvertedLists*>(invlists); |
154 | 0 | if (type == NoMap) { |
155 | 0 | if (block_invlists != nullptr) { |
156 | 0 | return block_invlists->remove_ids(sel); |
157 | 0 | } |
158 | | // exhaustive scan of IVF |
159 | 0 | #pragma omp parallel for |
160 | 0 | for (idx_t i = 0; i < nlist; i++) { |
161 | 0 | idx_t l0 = invlists->list_size(i), l = l0, j = 0; |
162 | 0 | ScopedIds idsi(invlists, i); |
163 | 0 | while (j < l) { |
164 | 0 | if (sel.is_member(idsi[j])) { |
165 | 0 | l--; |
166 | 0 | invlists->update_entry( |
167 | 0 | i, |
168 | 0 | j, |
169 | 0 | invlists->get_single_id(i, l), |
170 | 0 | ScopedCodes(invlists, i, l).get()); |
171 | 0 | } else { |
172 | 0 | j++; |
173 | 0 | } |
174 | 0 | } |
175 | 0 | toremove[i] = l0 - l; |
176 | 0 | } |
177 | | // this will not run well in parallel on ondisk because of |
178 | | // possible shrinks |
179 | 0 | for (idx_t i = 0; i < nlist; i++) { |
180 | 0 | if (toremove[i] > 0) { |
181 | 0 | nremove += toremove[i]; |
182 | 0 | invlists->resize(i, invlists->list_size(i) - toremove[i]); |
183 | 0 | } |
184 | 0 | } |
185 | 0 | } else if (type == Hashtable) { |
186 | 0 | FAISS_THROW_IF_MSG( |
187 | 0 | block_invlists, |
188 | 0 | "remove with hashtable is not supported with BlockInvertedLists"); |
189 | 0 | const IDSelectorArray* sela = |
190 | 0 | dynamic_cast<const IDSelectorArray*>(&sel); |
191 | 0 | FAISS_THROW_IF_NOT_MSG( |
192 | 0 | sela, "remove with hashtable works only with IDSelectorArray"); |
193 | | |
194 | 0 | for (idx_t i = 0; i < sela->n; i++) { |
195 | 0 | idx_t id = sela->ids[i]; |
196 | 0 | auto res = hashtable.find(id); |
197 | 0 | if (res != hashtable.end()) { |
198 | 0 | size_t list_no = lo_listno(res->second); |
199 | 0 | size_t offset = lo_offset(res->second); |
200 | 0 | idx_t last = invlists->list_size(list_no) - 1; |
201 | 0 | hashtable.erase(res); |
202 | 0 | if (offset < last) { |
203 | 0 | idx_t last_id = invlists->get_single_id(list_no, last); |
204 | 0 | invlists->update_entry( |
205 | 0 | list_no, |
206 | 0 | offset, |
207 | 0 | last_id, |
208 | 0 | ScopedCodes(invlists, list_no, last).get()); |
209 | | // update hash entry for last element |
210 | 0 | hashtable[last_id] = lo_build(list_no, offset); |
211 | 0 | } |
212 | 0 | invlists->resize(list_no, last); |
213 | 0 | nremove++; |
214 | 0 | } |
215 | 0 | } |
216 | |
|
217 | 0 | } else { |
218 | 0 | FAISS_THROW_MSG("remove not supported with this direct_map format"); |
219 | 0 | } |
220 | 0 | return nremove; |
221 | 0 | } |
222 | | |
223 | | void DirectMap::update_codes( |
224 | | InvertedLists* invlists, |
225 | | int n, |
226 | | const idx_t* ids, |
227 | | const idx_t* assign, |
228 | 0 | const uint8_t* codes) { |
229 | 0 | FAISS_THROW_IF_NOT(type == Array); |
230 | | |
231 | 0 | size_t code_size = invlists->code_size; |
232 | |
|
233 | 0 | for (size_t i = 0; i < n; i++) { |
234 | 0 | idx_t id = ids[i]; |
235 | 0 | FAISS_THROW_IF_NOT_MSG( |
236 | 0 | 0 <= id && id < array.size(), "id to update out of range"); |
237 | 0 | { // remove old one |
238 | 0 | idx_t dm = array[id]; |
239 | 0 | int64_t ofs = lo_offset(dm); |
240 | 0 | int64_t il = lo_listno(dm); |
241 | 0 | size_t l = invlists->list_size(il); |
242 | 0 | if (ofs != l - 1) { // move l - 1 to ofs |
243 | 0 | int64_t id2 = invlists->get_single_id(il, l - 1); |
244 | 0 | array[id2] = lo_build(il, ofs); |
245 | 0 | invlists->update_entry( |
246 | 0 | il, ofs, id2, invlists->get_single_code(il, l - 1)); |
247 | 0 | } |
248 | 0 | invlists->resize(il, l - 1); |
249 | 0 | } |
250 | 0 | { // insert new one |
251 | 0 | int64_t il = assign[i]; |
252 | 0 | size_t l = invlists->list_size(il); |
253 | 0 | idx_t dm = lo_build(il, l); |
254 | 0 | array[id] = dm; |
255 | 0 | invlists->add_entry(il, id, codes + i * code_size); |
256 | 0 | } |
257 | 0 | } |
258 | 0 | } |
259 | | |
260 | | } // namespace faiss |