contrib/faiss/faiss/invlists/BlockInvertedLists.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
#include <faiss/invlists/BlockInvertedLists.h>

#include <cstring>
#include <memory>

#include <faiss/impl/CodePacker.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>

#include <faiss/impl/io.h>
#include <faiss/impl/io_macros.h>
16 | | |
17 | | namespace faiss { |
18 | | |
19 | | BlockInvertedLists::BlockInvertedLists( |
20 | | size_t nlist, |
21 | | size_t n_per_block, |
22 | | size_t block_size) |
23 | 0 | : InvertedLists(nlist, InvertedLists::INVALID_CODE_SIZE), |
24 | 0 | n_per_block(n_per_block), |
25 | 0 | block_size(block_size) { |
26 | 0 | ids.resize(nlist); |
27 | 0 | codes.resize(nlist); |
28 | 0 | } |
29 | | |
30 | | BlockInvertedLists::BlockInvertedLists(size_t nlist, const CodePacker* packer) |
31 | 0 | : InvertedLists(nlist, InvertedLists::INVALID_CODE_SIZE), |
32 | 0 | n_per_block(packer->nvec), |
33 | 0 | block_size(packer->block_size), |
34 | 0 | packer(packer) { |
35 | 0 | ids.resize(nlist); |
36 | 0 | codes.resize(nlist); |
37 | 0 | } |
38 | | |
39 | | BlockInvertedLists::BlockInvertedLists() |
40 | 0 | : InvertedLists(0, InvertedLists::INVALID_CODE_SIZE) {} |
41 | | |
42 | | size_t BlockInvertedLists::add_entries( |
43 | | size_t list_no, |
44 | | size_t n_entry, |
45 | | const idx_t* ids_in, |
46 | 0 | const uint8_t* code) { |
47 | 0 | if (n_entry == 0) { |
48 | 0 | return 0; |
49 | 0 | } |
50 | 0 | FAISS_THROW_IF_NOT(list_no < nlist); |
51 | 0 | size_t o = ids[list_no].size(); |
52 | 0 | ids[list_no].resize(o + n_entry); |
53 | 0 | memcpy(&ids[list_no][o], ids_in, sizeof(ids_in[0]) * n_entry); |
54 | 0 | size_t n_block = (o + n_entry + n_per_block - 1) / n_per_block; |
55 | 0 | codes[list_no].resize(n_block * block_size); |
56 | 0 | if (o % block_size == 0) { |
57 | | // copy whole blocks |
58 | 0 | memcpy(&codes[list_no][o * packer->code_size], |
59 | 0 | code, |
60 | 0 | n_block * block_size); |
61 | 0 | } else { |
62 | 0 | FAISS_THROW_IF_NOT_MSG(packer, "missing code packer"); |
63 | 0 | std::vector<uint8_t> buffer(packer->code_size); |
64 | 0 | for (size_t i = 0; i < n_entry; i++) { |
65 | 0 | packer->unpack_1(code, i, buffer.data()); |
66 | 0 | packer->pack_1(buffer.data(), i + o, codes[list_no].data()); |
67 | 0 | } |
68 | 0 | } |
69 | 0 | return o; |
70 | 0 | } |
71 | | |
72 | 0 | size_t BlockInvertedLists::list_size(size_t list_no) const { |
73 | 0 | assert(list_no < nlist); |
74 | 0 | return ids[list_no].size(); |
75 | 0 | } |
76 | | |
77 | 0 | const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const { |
78 | 0 | assert(list_no < nlist); |
79 | 0 | return codes[list_no].get(); |
80 | 0 | } |
81 | | |
82 | 0 | size_t BlockInvertedLists::remove_ids(const IDSelector& sel) { |
83 | 0 | idx_t nremove = 0; |
84 | 0 | #pragma omp parallel for |
85 | 0 | for (idx_t i = 0; i < nlist; i++) { |
86 | 0 | std::vector<uint8_t> buffer(packer->code_size); |
87 | 0 | idx_t l = ids[i].size(), j = 0; |
88 | 0 | while (j < l) { |
89 | 0 | if (sel.is_member(ids[i][j])) { |
90 | 0 | l--; |
91 | 0 | ids[i][j] = ids[i][l]; |
92 | 0 | packer->unpack_1(codes[i].data(), l, buffer.data()); |
93 | 0 | packer->pack_1(buffer.data(), j, codes[i].data()); |
94 | 0 | } else { |
95 | 0 | j++; |
96 | 0 | } |
97 | 0 | } |
98 | 0 | resize(i, l); |
99 | 0 | nremove += ids[i].size() - l; |
100 | 0 | } |
101 | |
|
102 | 0 | return nremove; |
103 | 0 | } |
104 | | |
105 | 0 | const idx_t* BlockInvertedLists::get_ids(size_t list_no) const { |
106 | 0 | assert(list_no < nlist); |
107 | 0 | return ids[list_no].data(); |
108 | 0 | } |
109 | | |
110 | 0 | void BlockInvertedLists::resize(size_t list_no, size_t new_size) { |
111 | 0 | ids[list_no].resize(new_size); |
112 | 0 | size_t prev_nbytes = codes[list_no].size(); |
113 | 0 | size_t n_block = (new_size + n_per_block - 1) / n_per_block; |
114 | 0 | size_t new_nbytes = n_block * block_size; |
115 | 0 | codes[list_no].resize(new_nbytes); |
116 | 0 | if (prev_nbytes < new_nbytes) { |
117 | | // set new elements to 0 |
118 | 0 | memset(codes[list_no].data() + prev_nbytes, |
119 | 0 | 0, |
120 | 0 | new_nbytes - prev_nbytes); |
121 | 0 | } |
122 | 0 | } |
123 | | |
124 | | void BlockInvertedLists::update_entries( |
125 | | size_t, |
126 | | size_t, |
127 | | size_t, |
128 | | const idx_t*, |
129 | 0 | const uint8_t*) { |
130 | 0 | FAISS_THROW_MSG("not implemented"); |
131 | 0 | } |
132 | | |
133 | 0 | BlockInvertedLists::~BlockInvertedLists() { |
134 | 0 | delete packer; |
135 | 0 | } |
136 | | |
137 | | /************************************************** |
138 | | * IO hook implementation |
139 | | **************************************************/ |
140 | | |
141 | | BlockInvertedListsIOHook::BlockInvertedListsIOHook() |
142 | 1 | : InvertedListsIOHook("ilbl", typeid(BlockInvertedLists).name()) {} |
143 | | |
144 | | void BlockInvertedListsIOHook::write(const InvertedLists* ils_in, IOWriter* f) |
145 | 0 | const { |
146 | 0 | uint32_t h = fourcc("ilbl"); |
147 | 0 | WRITE1(h); |
148 | 0 | const BlockInvertedLists* il = |
149 | 0 | dynamic_cast<const BlockInvertedLists*>(ils_in); |
150 | 0 | WRITE1(il->nlist); |
151 | 0 | WRITE1(il->code_size); |
152 | 0 | WRITE1(il->n_per_block); |
153 | 0 | WRITE1(il->block_size); |
154 | |
|
155 | 0 | for (size_t i = 0; i < il->nlist; i++) { |
156 | 0 | WRITEVECTOR(il->ids[i]); |
157 | 0 | WRITEVECTOR(il->codes[i]); |
158 | 0 | } |
159 | 0 | } |
160 | | |
161 | | InvertedLists* BlockInvertedListsIOHook::read(IOReader* f, int /* io_flags */) |
162 | 0 | const { |
163 | 0 | BlockInvertedLists* il = new BlockInvertedLists(); |
164 | 0 | READ1(il->nlist); |
165 | 0 | READ1(il->code_size); |
166 | 0 | READ1(il->n_per_block); |
167 | 0 | READ1(il->block_size); |
168 | |
|
169 | 0 | il->ids.resize(il->nlist); |
170 | 0 | il->codes.resize(il->nlist); |
171 | |
|
172 | 0 | for (size_t i = 0; i < il->nlist; i++) { |
173 | 0 | READVECTOR(il->ids[i]); |
174 | 0 | READVECTOR(il->codes[i]); |
175 | 0 | } |
176 | | |
177 | 0 | return il; |
178 | 0 | } |
179 | | |
180 | | } // namespace faiss |