Coverage Report

Created: 2026-03-17 19:28

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/invlists/BlockInvertedLists.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/invlists/BlockInvertedLists.h>

#include <memory>

#include <faiss/impl/CodePacker.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>

#include <faiss/impl/io.h>
#include <faiss/impl/io_macros.h>
16
17
namespace faiss {
18
19
// Creates `nlist` empty lists with an explicit block geometry:
// `n_per_block` entries share one block of `block_size` bytes.
// No CodePacker is given here, so `packer` is not set by this constructor
// (presumably left at its in-class default — confirm in the header); the
// unaligned path of add_entries() and remove_ids() rely on a packer.
BlockInvertedLists::BlockInvertedLists(
        size_t nlist,
        size_t n_per_block,
        size_t block_size)
        : InvertedLists(nlist, InvertedLists::INVALID_CODE_SIZE),
          n_per_block(n_per_block),
          block_size(block_size) {
    // one empty code table and one empty id vector per list
    codes.resize(nlist);
    ids.resize(nlist);
}
29
30
// Creates `nlist` empty lists whose block geometry comes from `packer`:
// entries per block = packer->nvec, bytes per block = packer->block_size.
// The lists keep `packer` and delete it in the destructor.
BlockInvertedLists::BlockInvertedLists(size_t nlist, const CodePacker* packer)
        : InvertedLists(nlist, InvertedLists::INVALID_CODE_SIZE),
          n_per_block(packer->nvec),
          block_size(packer->block_size),
          packer(packer) {
    // one empty code table and one empty id vector per list
    codes.resize(nlist);
    ids.resize(nlist);
}
38
39
// Creates an object with zero lists and no block geometry; the remaining
// fields are expected to be filled in afterwards (the IO hook's read() does
// this).
BlockInvertedLists::BlockInvertedLists()
        : InvertedLists(0, InvertedLists::INVALID_CODE_SIZE) {
    // intentionally empty
}
41
42
size_t BlockInvertedLists::add_entries(
43
        size_t list_no,
44
        size_t n_entry,
45
        const idx_t* ids_in,
46
0
        const uint8_t* code) {
47
0
    if (n_entry == 0) {
48
0
        return 0;
49
0
    }
50
0
    FAISS_THROW_IF_NOT(list_no < nlist);
51
0
    size_t o = ids[list_no].size();
52
0
    ids[list_no].resize(o + n_entry);
53
0
    memcpy(&ids[list_no][o], ids_in, sizeof(ids_in[0]) * n_entry);
54
0
    size_t n_block = (o + n_entry + n_per_block - 1) / n_per_block;
55
0
    codes[list_no].resize(n_block * block_size);
56
0
    if (o % block_size == 0) {
57
        // copy whole blocks
58
0
        memcpy(&codes[list_no][o * packer->code_size],
59
0
               code,
60
0
               n_block * block_size);
61
0
    } else {
62
0
        FAISS_THROW_IF_NOT_MSG(packer, "missing code packer");
63
0
        std::vector<uint8_t> buffer(packer->code_size);
64
0
        for (size_t i = 0; i < n_entry; i++) {
65
0
            packer->unpack_1(code, i, buffer.data());
66
0
            packer->pack_1(buffer.data(), i + o, codes[list_no].data());
67
0
        }
68
0
    }
69
0
    return o;
70
0
}
71
72
0
size_t BlockInvertedLists::list_size(size_t list_no) const {
73
0
    assert(list_no < nlist);
74
0
    return ids[list_no].size();
75
0
}
76
77
0
const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
78
0
    assert(list_no < nlist);
79
0
    return codes[list_no].get();
80
0
}
81
82
0
size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
83
0
    idx_t nremove = 0;
84
0
#pragma omp parallel for
85
0
    for (idx_t i = 0; i < nlist; i++) {
86
0
        std::vector<uint8_t> buffer(packer->code_size);
87
0
        idx_t l = ids[i].size(), j = 0;
88
0
        while (j < l) {
89
0
            if (sel.is_member(ids[i][j])) {
90
0
                l--;
91
0
                ids[i][j] = ids[i][l];
92
0
                packer->unpack_1(codes[i].data(), l, buffer.data());
93
0
                packer->pack_1(buffer.data(), j, codes[i].data());
94
0
            } else {
95
0
                j++;
96
0
            }
97
0
        }
98
0
        resize(i, l);
99
0
        nremove += ids[i].size() - l;
100
0
    }
101
102
0
    return nremove;
103
0
}
104
105
0
// Read-only pointer to the ids of list `list_no` (list_size(list_no) of
// them). The bound is only checked by assert.
const idx_t* BlockInvertedLists::get_ids(size_t list_no) const {
    assert(list_no < nlist);
    const auto& list_ids = ids[list_no];
    return list_ids.data();
}
109
110
0
void BlockInvertedLists::resize(size_t list_no, size_t new_size) {
111
0
    ids[list_no].resize(new_size);
112
0
    size_t prev_nbytes = codes[list_no].size();
113
0
    size_t n_block = (new_size + n_per_block - 1) / n_per_block;
114
0
    size_t new_nbytes = n_block * block_size;
115
0
    codes[list_no].resize(new_nbytes);
116
0
    if (prev_nbytes < new_nbytes) {
117
        // set new elements to 0
118
0
        memset(codes[list_no].data() + prev_nbytes,
119
0
               0,
120
0
               new_nbytes - prev_nbytes);
121
0
    }
122
0
}
123
124
void BlockInvertedLists::update_entries(
125
        size_t,
126
        size_t,
127
        size_t,
128
        const idx_t*,
129
0
        const uint8_t*) {
130
0
    FAISS_THROW_MSG("not implemented");
131
0
}
132
133
0
// Destroys the CodePacker held by these lists (deleting a null pointer is a
// no-op).
// NOTE(review): the packer is held through a raw pointer, so copying a
// BlockInvertedLists would lead to a double delete unless copying is
// disabled in the header — confirm.
BlockInvertedLists::~BlockInvertedLists() {
    delete this->packer;
}
136
137
/**************************************************
138
 * IO hook implementation
139
 **************************************************/
140
141
// Associates the "ilbl" fourcc with the runtime type name of
// BlockInvertedLists in the base InvertedListsIOHook.
BlockInvertedListsIOHook::BlockInvertedListsIOHook()
        : InvertedListsIOHook(
                  "ilbl",
                  typeid(BlockInvertedLists).name()) {}
143
144
void BlockInvertedListsIOHook::write(const InvertedLists* ils_in, IOWriter* f)
145
0
        const {
146
0
    uint32_t h = fourcc("ilbl");
147
0
    WRITE1(h);
148
0
    const BlockInvertedLists* il =
149
0
            dynamic_cast<const BlockInvertedLists*>(ils_in);
150
0
    WRITE1(il->nlist);
151
0
    WRITE1(il->code_size);
152
0
    WRITE1(il->n_per_block);
153
0
    WRITE1(il->block_size);
154
155
0
    for (size_t i = 0; i < il->nlist; i++) {
156
0
        WRITEVECTOR(il->ids[i]);
157
0
        WRITEVECTOR(il->codes[i]);
158
0
    }
159
0
}
160
161
/** Deserialize a BlockInvertedLists from `f`.
 *
 * Reads fields in the order write() emits them after the fourcc; the fourcc
 * itself is not read here. The result has no CodePacker.
 *
 * The object is held in a unique_ptr while reading. If a READ macro throws
 * on a short or corrupt stream, it is freed instead of leaked. Ownership
 * passes to the caller only after everything has been read.
 */
InvertedLists* BlockInvertedListsIOHook::read(IOReader* f, int /* io_flags */)
        const {
    auto il = std::make_unique<BlockInvertedLists>();
    READ1(il->nlist);
    READ1(il->code_size);
    READ1(il->n_per_block);
    READ1(il->block_size);

    il->ids.resize(il->nlist);
    il->codes.resize(il->nlist);

    for (size_t i = 0; i < il->nlist; i++) {
        READVECTOR(il->ids[i]);
        READVECTOR(il->codes[i]);
    }

    return il.release();
}
179
180
} // namespace faiss