contrib/faiss/faiss/impl/pq4_fast_scan.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/impl/FaissAssert.h> |
9 | | #include <faiss/impl/platform_macros.h> |
10 | | #include <faiss/impl/pq4_fast_scan.h> |
11 | | #include <faiss/impl/simd_result_handlers.h> |
12 | | |
13 | | #include <array> |
14 | | |
15 | | namespace faiss { |
16 | | |
17 | | using namespace simd_result_handlers; |
18 | | |
19 | | /*************************************************************** |
20 | | * Packing functions for codes |
21 | | ***************************************************************/ |
22 | | |
23 | | namespace { |
24 | | |
25 | | /* extract the column starting at (i, j) |
26 | | * from packed matrix src of size (m, n)*/ |
27 | | template <typename T, class TA> |
28 | | void get_matrix_column( |
29 | | T* src, |
30 | | size_t m, |
31 | | size_t n, |
32 | | int64_t i, |
33 | | int64_t j, |
34 | 0 | TA& dest) { |
35 | 0 | for (int64_t k = 0; k < dest.size(); k++) { |
36 | 0 | if (k + i >= 0 && k + i < m) { |
37 | 0 | dest[k] = src[(k + i) * n + j]; |
38 | 0 | } else { |
39 | 0 | dest[k] = 0; |
40 | 0 | } |
41 | 0 | } |
42 | 0 | } |
43 | | |
44 | | } // anonymous namespace |
45 | | |
46 | | void pq4_pack_codes( |
47 | | const uint8_t* codes, |
48 | | size_t ntotal, |
49 | | size_t M, |
50 | | size_t nb, |
51 | | size_t bbs, |
52 | | size_t nsq, |
53 | 0 | uint8_t* blocks) { |
54 | 0 | FAISS_THROW_IF_NOT(bbs % 32 == 0); |
55 | 0 | FAISS_THROW_IF_NOT(nb % bbs == 0); |
56 | 0 | FAISS_THROW_IF_NOT(nsq % 2 == 0); |
57 | | |
58 | 0 | if (nb == 0) { |
59 | 0 | return; |
60 | 0 | } |
61 | 0 | memset(blocks, 0, nb * nsq / 2); |
62 | | #ifdef FAISS_BIG_ENDIAN |
63 | | const uint8_t perm0[16] = { |
64 | | 8, 0, 9, 1, 10, 2, 11, 3, 12, 4, 13, 5, 14, 6, 15, 7}; |
65 | | #else |
66 | 0 | const uint8_t perm0[16] = { |
67 | 0 | 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}; |
68 | 0 | #endif |
69 | |
|
70 | 0 | uint8_t* codes2 = blocks; |
71 | 0 | for (size_t i0 = 0; i0 < nb; i0 += bbs) { |
72 | 0 | for (int sq = 0; sq < nsq; sq += 2) { |
73 | 0 | for (size_t i = 0; i < bbs; i += 32) { |
74 | 0 | std::array<uint8_t, 32> c, c0, c1; |
75 | 0 | get_matrix_column( |
76 | 0 | codes, ntotal, (M + 1) / 2, i0 + i, sq / 2, c); |
77 | 0 | for (int j = 0; j < 32; j++) { |
78 | 0 | c0[j] = c[j] & 15; |
79 | 0 | c1[j] = c[j] >> 4; |
80 | 0 | } |
81 | 0 | for (int j = 0; j < 16; j++) { |
82 | 0 | uint8_t d0, d1; |
83 | 0 | d0 = c0[perm0[j]] | (c0[perm0[j] + 16] << 4); |
84 | 0 | d1 = c1[perm0[j]] | (c1[perm0[j] + 16] << 4); |
85 | 0 | codes2[j] = d0; |
86 | 0 | codes2[j + 16] = d1; |
87 | 0 | } |
88 | 0 | codes2 += 32; |
89 | 0 | } |
90 | 0 | } |
91 | 0 | } |
92 | 0 | } |
93 | | |
94 | | void pq4_pack_codes_range( |
95 | | const uint8_t* codes, |
96 | | size_t M, |
97 | | size_t i0, |
98 | | size_t i1, |
99 | | size_t bbs, |
100 | | size_t nsq, |
101 | 0 | uint8_t* blocks) { |
102 | | #ifdef FAISS_BIG_ENDIAN |
103 | | const uint8_t perm0[16] = { |
104 | | 8, 0, 9, 1, 10, 2, 11, 3, 12, 4, 13, 5, 14, 6, 15, 7}; |
105 | | #else |
106 | 0 | const uint8_t perm0[16] = { |
107 | 0 | 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}; |
108 | 0 | #endif |
109 | | |
110 | | // range of affected blocks |
111 | 0 | size_t block0 = i0 / bbs; |
112 | 0 | size_t block1 = ((i1 - 1) / bbs) + 1; |
113 | |
|
114 | 0 | for (size_t b = block0; b < block1; b++) { |
115 | 0 | uint8_t* codes2 = blocks + b * bbs * nsq / 2; |
116 | 0 | int64_t i_base = b * bbs - i0; |
117 | 0 | for (int sq = 0; sq < nsq; sq += 2) { |
118 | 0 | for (size_t i = 0; i < bbs; i += 32) { |
119 | 0 | std::array<uint8_t, 32> c, c0, c1; |
120 | 0 | get_matrix_column( |
121 | 0 | codes, i1 - i0, (M + 1) / 2, i_base + i, sq / 2, c); |
122 | 0 | for (int j = 0; j < 32; j++) { |
123 | 0 | c0[j] = c[j] & 15; |
124 | 0 | c1[j] = c[j] >> 4; |
125 | 0 | } |
126 | 0 | for (int j = 0; j < 16; j++) { |
127 | 0 | uint8_t d0, d1; |
128 | 0 | d0 = c0[perm0[j]] | (c0[perm0[j] + 16] << 4); |
129 | 0 | d1 = c1[perm0[j]] | (c1[perm0[j] + 16] << 4); |
130 | 0 | codes2[j] |= d0; |
131 | 0 | codes2[j + 16] |= d1; |
132 | 0 | } |
133 | 0 | codes2 += 32; |
134 | 0 | } |
135 | 0 | } |
136 | 0 | } |
137 | 0 | } |
138 | | |
139 | | namespace { |
140 | | |
141 | | // get the specific address of the vector inside a block |
142 | | // shift is used for determine the if the saved in bits 0..3 (false) or |
143 | | // bits 4..7 (true) |
144 | | size_t get_vector_specific_address( |
145 | | size_t bbs, |
146 | | size_t vector_id, |
147 | | size_t sq, |
148 | 0 | bool& shift) { |
149 | | // get the vector_id inside the block |
150 | 0 | vector_id = vector_id % bbs; |
151 | 0 | shift = vector_id > 15; |
152 | 0 | vector_id = vector_id & 15; |
153 | | |
154 | | // get the address of the vector in sq |
155 | 0 | size_t address; |
156 | 0 | if (vector_id < 8) { |
157 | 0 | address = vector_id << 1; |
158 | 0 | } else { |
159 | 0 | address = ((vector_id - 8) << 1) + 1; |
160 | 0 | } |
161 | 0 | if (sq & 1) { |
162 | 0 | address += 16; |
163 | 0 | } |
164 | 0 | return (sq >> 1) * bbs + address; |
165 | 0 | } |
166 | | |
167 | | } // anonymous namespace |
168 | | |
169 | | uint8_t pq4_get_packed_element( |
170 | | const uint8_t* data, |
171 | | size_t bbs, |
172 | | size_t nsq, |
173 | | size_t vector_id, |
174 | 0 | size_t sq) { |
175 | | // move to correct bbs-sized block |
176 | | // number of blocks * block size |
177 | 0 | data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs); |
178 | 0 | bool shift; |
179 | 0 | size_t address = get_vector_specific_address(bbs, vector_id, sq, shift); |
180 | 0 | if (shift) { |
181 | 0 | return data[address] >> 4; |
182 | 0 | } else { |
183 | 0 | return data[address] & 15; |
184 | 0 | } |
185 | 0 | } |
186 | | |
187 | | void pq4_set_packed_element( |
188 | | uint8_t* data, |
189 | | uint8_t code, |
190 | | size_t bbs, |
191 | | size_t nsq, |
192 | | size_t vector_id, |
193 | 0 | size_t sq) { |
194 | | // move to correct bbs-sized block |
195 | | // number of blocks * block size |
196 | 0 | data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs); |
197 | 0 | bool shift; |
198 | 0 | size_t address = get_vector_specific_address(bbs, vector_id, sq, shift); |
199 | 0 | if (shift) { |
200 | 0 | data[address] = (code << 4) | (data[address] & 15); |
201 | 0 | } else { |
202 | 0 | data[address] = code | (data[address] & ~15); |
203 | 0 | } |
204 | 0 | } |
205 | | |
206 | | /*************************************************************** |
207 | | * CodePackerPQ4 implementation |
208 | | ***************************************************************/ |
209 | | |
210 | 0 | CodePackerPQ4::CodePackerPQ4(size_t nsq, size_t bbs) { |
211 | 0 | this->nsq = nsq; |
212 | 0 | nvec = bbs; |
213 | 0 | code_size = (nsq * 4 + 7) / 8; |
214 | 0 | block_size = ((nsq + 1) / 2) * bbs; |
215 | 0 | } |
216 | | |
217 | | void CodePackerPQ4::pack_1( |
218 | | const uint8_t* flat_code, |
219 | | size_t offset, |
220 | 0 | uint8_t* block) const { |
221 | 0 | size_t bbs = nvec; |
222 | 0 | if (offset >= nvec) { |
223 | 0 | block += (offset / nvec) * block_size; |
224 | 0 | offset = offset % nvec; |
225 | 0 | } |
226 | 0 | for (size_t i = 0; i < code_size; i++) { |
227 | 0 | uint8_t code = flat_code[i]; |
228 | 0 | pq4_set_packed_element(block, code & 15, bbs, nsq, offset, 2 * i); |
229 | 0 | pq4_set_packed_element(block, code >> 4, bbs, nsq, offset, 2 * i + 1); |
230 | 0 | } |
231 | 0 | } |
232 | | |
233 | | void CodePackerPQ4::unpack_1( |
234 | | const uint8_t* block, |
235 | | size_t offset, |
236 | 0 | uint8_t* flat_code) const { |
237 | 0 | size_t bbs = nvec; |
238 | 0 | if (offset >= nvec) { |
239 | 0 | block += (offset / nvec) * block_size; |
240 | 0 | offset = offset % nvec; |
241 | 0 | } |
242 | 0 | for (size_t i = 0; i < code_size; i++) { |
243 | 0 | uint8_t code0, code1; |
244 | 0 | code0 = pq4_get_packed_element(block, bbs, nsq, offset, 2 * i); |
245 | 0 | code1 = pq4_get_packed_element(block, bbs, nsq, offset, 2 * i + 1); |
246 | 0 | flat_code[i] = code0 | (code1 << 4); |
247 | 0 | } |
248 | 0 | } |
249 | | |
250 | | /*************************************************************** |
251 | | * Packing functions for Look-Up Tables (LUT) |
252 | | ***************************************************************/ |
253 | | |
254 | 0 | void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest) { |
255 | 0 | for (int q = 0; q < nq; q++) { |
256 | 0 | for (int sq = 0; sq < nsq; sq += 2) { |
257 | 0 | memcpy(dest + (sq / 2 * nq + q) * 32, |
258 | 0 | src + (q * nsq + sq) * 16, |
259 | 0 | 16); |
260 | 0 | memcpy(dest + (sq / 2 * nq + q) * 32 + 16, |
261 | 0 | src + (q * nsq + sq + 1) * 16, |
262 | 0 | 16); |
263 | 0 | } |
264 | 0 | } |
265 | 0 | } |
266 | | |
267 | 0 | int pq4_pack_LUT_qbs(int qbs, int nsq, const uint8_t* src, uint8_t* dest) { |
268 | 0 | FAISS_THROW_IF_NOT(nsq % 2 == 0); |
269 | 0 | size_t dim12 = 16 * nsq; |
270 | 0 | int i0 = 0; |
271 | 0 | int qi = qbs; |
272 | 0 | while (qi) { |
273 | 0 | int nq = qi & 15; |
274 | 0 | qi >>= 4; |
275 | 0 | pq4_pack_LUT(nq, nsq, src + i0 * dim12, dest + i0 * dim12); |
276 | 0 | i0 += nq; |
277 | 0 | } |
278 | 0 | return i0; |
279 | 0 | } |
280 | | |
281 | | namespace { |
282 | | |
283 | | void pack_LUT_1_q_map( |
284 | | int nq, |
285 | | const int* q_map, |
286 | | int nsq, |
287 | | const uint8_t* src, |
288 | 0 | uint8_t* dest) { |
289 | 0 | for (int qi = 0; qi < nq; qi++) { |
290 | 0 | int q = q_map[qi]; |
291 | 0 | for (int sq = 0; sq < nsq; sq += 2) { |
292 | 0 | memcpy(dest + (sq / 2 * nq + qi) * 32, |
293 | 0 | src + (q * nsq + sq) * 16, |
294 | 0 | 16); |
295 | 0 | memcpy(dest + (sq / 2 * nq + qi) * 32 + 16, |
296 | 0 | src + (q * nsq + sq + 1) * 16, |
297 | 0 | 16); |
298 | 0 | } |
299 | 0 | } |
300 | 0 | } |
301 | | |
302 | | } // anonymous namespace |
303 | | |
304 | | int pq4_pack_LUT_qbs_q_map( |
305 | | int qbs, |
306 | | int nsq, |
307 | | const uint8_t* src, |
308 | | const int* q_map, |
309 | 0 | uint8_t* dest) { |
310 | 0 | FAISS_THROW_IF_NOT(nsq % 2 == 0); |
311 | 0 | size_t dim12 = 16 * nsq; |
312 | 0 | int i0 = 0; |
313 | 0 | int qi = qbs; |
314 | 0 | while (qi) { |
315 | 0 | int nq = qi & 15; |
316 | 0 | qi >>= 4; |
317 | 0 | pack_LUT_1_q_map(nq, q_map + i0, nsq, src, dest + i0 * dim12); |
318 | 0 | i0 += nq; |
319 | 0 | } |
320 | 0 | return i0; |
321 | 0 | } |
322 | | |
323 | | } // namespace faiss |