Coverage Report

Created: 2026-03-15 18:01

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/impl/pq4_fast_scan.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/impl/FaissAssert.h>
9
#include <faiss/impl/platform_macros.h>
10
#include <faiss/impl/pq4_fast_scan.h>
11
#include <faiss/impl/simd_result_handlers.h>
12
13
#include <array>
14
15
namespace faiss {
16
17
using namespace simd_result_handlers;
18
19
/***************************************************************
20
 * Packing functions for codes
21
 ***************************************************************/
22
23
namespace {
24
25
/* Copy a window of one column of the row-major (m, n) matrix src into dest.
 *
 * dest[k] receives src[(i + k) * n + j] for k = 0 .. dest.size() - 1.
 * Rows outside [0, m) are written as 0. This covers a negative i and a window
 * that runs past the last row, so callers can pad partial blocks with zeros.
 */
template <typename T, class TA>
void get_matrix_column(
        T* src,
        size_t m,
        size_t n,
        int64_t i,
        int64_t j,
        TA& dest) {
    int64_t row = i;
    for (auto& out : dest) {
        // check sign first so the unsigned comparison against m is safe
        const bool inside = row >= 0 && static_cast<size_t>(row) < m;
        out = inside ? src[row * n + j] : 0;
        ++row;
    }
}
43
44
} // anonymous namespace
45
46
void pq4_pack_codes(
47
        const uint8_t* codes,
48
        size_t ntotal,
49
        size_t M,
50
        size_t nb,
51
        size_t bbs,
52
        size_t nsq,
53
0
        uint8_t* blocks) {
54
0
    FAISS_THROW_IF_NOT(bbs % 32 == 0);
55
0
    FAISS_THROW_IF_NOT(nb % bbs == 0);
56
0
    FAISS_THROW_IF_NOT(nsq % 2 == 0);
57
58
0
    if (nb == 0) {
59
0
        return;
60
0
    }
61
0
    memset(blocks, 0, nb * nsq / 2);
62
#ifdef FAISS_BIG_ENDIAN
63
    const uint8_t perm0[16] = {
64
            8, 0, 9, 1, 10, 2, 11, 3, 12, 4, 13, 5, 14, 6, 15, 7};
65
#else
66
0
    const uint8_t perm0[16] = {
67
0
            0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
68
0
#endif
69
70
0
    uint8_t* codes2 = blocks;
71
0
    for (size_t i0 = 0; i0 < nb; i0 += bbs) {
72
0
        for (int sq = 0; sq < nsq; sq += 2) {
73
0
            for (size_t i = 0; i < bbs; i += 32) {
74
0
                std::array<uint8_t, 32> c, c0, c1;
75
0
                get_matrix_column(
76
0
                        codes, ntotal, (M + 1) / 2, i0 + i, sq / 2, c);
77
0
                for (int j = 0; j < 32; j++) {
78
0
                    c0[j] = c[j] & 15;
79
0
                    c1[j] = c[j] >> 4;
80
0
                }
81
0
                for (int j = 0; j < 16; j++) {
82
0
                    uint8_t d0, d1;
83
0
                    d0 = c0[perm0[j]] | (c0[perm0[j] + 16] << 4);
84
0
                    d1 = c1[perm0[j]] | (c1[perm0[j] + 16] << 4);
85
0
                    codes2[j] = d0;
86
0
                    codes2[j + 16] = d1;
87
0
                }
88
0
                codes2 += 32;
89
0
            }
90
0
        }
91
0
    }
92
0
}
93
94
void pq4_pack_codes_range(
95
        const uint8_t* codes,
96
        size_t M,
97
        size_t i0,
98
        size_t i1,
99
        size_t bbs,
100
        size_t nsq,
101
0
        uint8_t* blocks) {
102
#ifdef FAISS_BIG_ENDIAN
103
    const uint8_t perm0[16] = {
104
            8, 0, 9, 1, 10, 2, 11, 3, 12, 4, 13, 5, 14, 6, 15, 7};
105
#else
106
0
    const uint8_t perm0[16] = {
107
0
            0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
108
0
#endif
109
110
    // range of affected blocks
111
0
    size_t block0 = i0 / bbs;
112
0
    size_t block1 = ((i1 - 1) / bbs) + 1;
113
114
0
    for (size_t b = block0; b < block1; b++) {
115
0
        uint8_t* codes2 = blocks + b * bbs * nsq / 2;
116
0
        int64_t i_base = b * bbs - i0;
117
0
        for (int sq = 0; sq < nsq; sq += 2) {
118
0
            for (size_t i = 0; i < bbs; i += 32) {
119
0
                std::array<uint8_t, 32> c, c0, c1;
120
0
                get_matrix_column(
121
0
                        codes, i1 - i0, (M + 1) / 2, i_base + i, sq / 2, c);
122
0
                for (int j = 0; j < 32; j++) {
123
0
                    c0[j] = c[j] & 15;
124
0
                    c1[j] = c[j] >> 4;
125
0
                }
126
0
                for (int j = 0; j < 16; j++) {
127
0
                    uint8_t d0, d1;
128
0
                    d0 = c0[perm0[j]] | (c0[perm0[j] + 16] << 4);
129
0
                    d1 = c1[perm0[j]] | (c1[perm0[j] + 16] << 4);
130
0
                    codes2[j] |= d0;
131
0
                    codes2[j + 16] |= d1;
132
0
                }
133
0
                codes2 += 32;
134
0
            }
135
0
        }
136
0
    }
137
0
}
138
139
namespace {
140
141
// Byte offset, inside one bbs-vector block, that holds the 4-bit code of
// sub-quantizer `sq` for `vector_id`.
// On return, `shift` tells which nibble of that byte holds the code:
// false -> bits 0..3, true -> bits 4..7.
size_t get_vector_specific_address(
        size_t bbs,
        size_t vector_id,
        size_t sq,
        bool& shift) {
    // position of the vector inside its block
    const size_t in_block = vector_id % bbs;
    shift = in_block >= 16;
    const size_t lane = in_block % 16;

    // lanes 0..7 map to even bytes, lanes 8..15 to odd bytes
    const size_t interleaved = lane < 8 ? 2 * lane : 2 * (lane - 8) + 1;
    // odd sub-quantizers live in the upper 16 bytes of each 32-byte group
    const size_t half = (sq % 2 == 1) ? 16 : 0;
    return (sq / 2) * bbs + interleaved + half;
}
166
167
} // anonymous namespace
168
169
uint8_t pq4_get_packed_element(
170
        const uint8_t* data,
171
        size_t bbs,
172
        size_t nsq,
173
        size_t vector_id,
174
0
        size_t sq) {
175
    // move to correct bbs-sized block
176
    // number of blocks * block size
177
0
    data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
178
0
    bool shift;
179
0
    size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
180
0
    if (shift) {
181
0
        return data[address] >> 4;
182
0
    } else {
183
0
        return data[address] & 15;
184
0
    }
185
0
}
186
187
void pq4_set_packed_element(
188
        uint8_t* data,
189
        uint8_t code,
190
        size_t bbs,
191
        size_t nsq,
192
        size_t vector_id,
193
0
        size_t sq) {
194
    // move to correct bbs-sized block
195
    // number of blocks * block size
196
0
    data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
197
0
    bool shift;
198
0
    size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
199
0
    if (shift) {
200
0
        data[address] = (code << 4) | (data[address] & 15);
201
0
    } else {
202
0
        data[address] = code | (data[address] & ~15);
203
0
    }
204
0
}
205
206
/***************************************************************
207
 * CodePackerPQ4 implementation
208
 ***************************************************************/
209
210
0
// Code packer for 4-bit PQ codes stored in blocks of bbs vectors.
// A flat code stores two 4-bit codes per byte; a block stores
// ((nsq + 1) / 2) * bbs bytes.
CodePackerPQ4::CodePackerPQ4(size_t nsq, size_t bbs) {
    const size_t flat_bytes = (nsq * 4 + 7) / 8;
    const size_t bytes_per_block = ((nsq + 1) / 2) * bbs;

    this->nsq = nsq;
    this->nvec = bbs;
    this->code_size = flat_bytes;
    this->block_size = bytes_per_block;
}
216
217
void CodePackerPQ4::pack_1(
218
        const uint8_t* flat_code,
219
        size_t offset,
220
0
        uint8_t* block) const {
221
0
    size_t bbs = nvec;
222
0
    if (offset >= nvec) {
223
0
        block += (offset / nvec) * block_size;
224
0
        offset = offset % nvec;
225
0
    }
226
0
    for (size_t i = 0; i < code_size; i++) {
227
0
        uint8_t code = flat_code[i];
228
0
        pq4_set_packed_element(block, code & 15, bbs, nsq, offset, 2 * i);
229
0
        pq4_set_packed_element(block, code >> 4, bbs, nsq, offset, 2 * i + 1);
230
0
    }
231
0
}
232
233
void CodePackerPQ4::unpack_1(
234
        const uint8_t* block,
235
        size_t offset,
236
0
        uint8_t* flat_code) const {
237
0
    size_t bbs = nvec;
238
0
    if (offset >= nvec) {
239
0
        block += (offset / nvec) * block_size;
240
0
        offset = offset % nvec;
241
0
    }
242
0
    for (size_t i = 0; i < code_size; i++) {
243
0
        uint8_t code0, code1;
244
0
        code0 = pq4_get_packed_element(block, bbs, nsq, offset, 2 * i);
245
0
        code1 = pq4_get_packed_element(block, bbs, nsq, offset, 2 * i + 1);
246
0
        flat_code[i] = code0 | (code1 << 4);
247
0
    }
248
0
}
249
250
/***************************************************************
251
 * Packing functions for Look-Up Tables (LUT)
252
 ***************************************************************/
253
254
0
// Reorder nq query look-up tables (each nsq x 16 bytes, query-major in src)
// so that, for every pair of sub-quantizers, the 32 bytes of all nq queries
// are contiguous in dest.
void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest) {
    for (int q = 0; q < nq; q++) {
        const uint8_t* q_src = src + q * nsq * 16;
        for (int sq = 0; sq < nsq; sq += 2) {
            uint8_t* out = dest + (sq / 2 * nq + q) * 32;
            memcpy(out, q_src + sq * 16, 16);
            memcpy(out + 16, q_src + (sq + 1) * 16, 16);
        }
    }
}
266
267
0
int pq4_pack_LUT_qbs(int qbs, int nsq, const uint8_t* src, uint8_t* dest) {
268
0
    FAISS_THROW_IF_NOT(nsq % 2 == 0);
269
0
    size_t dim12 = 16 * nsq;
270
0
    int i0 = 0;
271
0
    int qi = qbs;
272
0
    while (qi) {
273
0
        int nq = qi & 15;
274
0
        qi >>= 4;
275
0
        pq4_pack_LUT(nq, nsq, src + i0 * dim12, dest + i0 * dim12);
276
0
        i0 += nq;
277
0
    }
278
0
    return i0;
279
0
}
280
281
namespace {
282
283
// Like pq4_pack_LUT, but output slot qi takes the table of query q_map[qi]
// from src (each table is nsq x 16 bytes, query-major).
void pack_LUT_1_q_map(
        int nq,
        const int* q_map,
        int nsq,
        const uint8_t* src,
        uint8_t* dest) {
    for (int qi = 0; qi < nq; ++qi) {
        const uint8_t* table = src + q_map[qi] * nsq * 16;
        // one iteration per pair of sub-quantizers (2 * pair, 2 * pair + 1)
        for (int pair = 0; 2 * pair < nsq; ++pair) {
            uint8_t* out = dest + (pair * nq + qi) * 32;
            memcpy(out, table + (2 * pair) * 16, 16);
            memcpy(out + 16, table + (2 * pair + 1) * 16, 16);
        }
    }
}
301
302
} // anonymous namespace
303
304
int pq4_pack_LUT_qbs_q_map(
305
        int qbs,
306
        int nsq,
307
        const uint8_t* src,
308
        const int* q_map,
309
0
        uint8_t* dest) {
310
0
    FAISS_THROW_IF_NOT(nsq % 2 == 0);
311
0
    size_t dim12 = 16 * nsq;
312
0
    int i0 = 0;
313
0
    int qi = qbs;
314
0
    while (qi) {
315
0
        int nq = qi & 15;
316
0
        qi >>= 4;
317
0
        pack_LUT_1_q_map(nq, q_map + i0, nsq, src, dest + i0 * dim12);
318
0
        i0 += nq;
319
0
    }
320
0
    return i0;
321
0
}
322
323
} // namespace faiss