Coverage Report

Created: 2026-03-13 03:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/impl/lattice_Zn.h
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
#ifndef FAISS_LATTICE_ZN_H
10
#define FAISS_LATTICE_ZN_H
11
12
#include <stddef.h>
13
#include <stdint.h>
14
#include <vector>
15
16
namespace faiss {
17
18
/** Returns the nearest vertex in the sphere to a query. Returns only
19
 * the coordinates, not an id.
20
 *
21
 * Algorithm: every point is derived from an atom vector, up to a
22
 * permutation and sign changes. The search function finds the most
23
 * appropriate atom and transformation.
24
 */
25
struct ZnSphereSearch {
26
    int dimS, r2;
27
    int natom;
28
29
    /// atom vocabulary, size dim * natom
30
    std::vector<float> voc;
31
32
    ZnSphereSearch(int dim, int r2);
33
34
    /// Find the nearest centroid. x does not need to be normalized.
35
    float search(const float* x, float* c) const;
36
37
    /// Full call. Requires externally-allocated temp space.
38
    float search(
39
            const float* x,
40
            float* c,
41
            float* tmp,   // size 2 * dim
42
            int* tmp_int, // size dim
43
            int* ibest_out = nullptr) const;
44
45
    // multi-threaded. NOTE(review): not declared const, unlike both search overloads -- confirm this is intended
46
    void search_multi(int n, const float* x, float* c_out, float* dp_out);
47
};
48
49
/***************************************************************************
50
 * Support ids as well.
51
 *
52
 * Limitation: ids are limited to 64 bits
53
 ***************************************************************************/
54
55
struct EnumeratedVectors {
56
    /// size of the collection (set to 0 by the constructor below)
57
    uint64_t nv;
58
    int dim;
59
60
0
    explicit EnumeratedVectors(int dim) : nv(0), dim(dim) {}
61
62
    /// encode a vector from the collection as a 64-bit code
63
    virtual uint64_t encode(const float* x) const = 0;
64
65
    /// decode a code (counterpart of encode)
66
    virtual void decode(uint64_t code, float* c) const = 0;
67
68
    // call encode on nc vectors
69
    void encode_multi(size_t nc, const float* c, uint64_t* codes) const;
70
71
    // call decode on nc codes
72
    void decode_multi(size_t nc, const uint64_t* codes, float* c) const;
73
74
    // find the nearest neighbor of each xq among the given codes
75
    // (decodes and computes distances)
76
    void find_nn(
77
            size_t n,
78
            const uint64_t* codes,
79
            size_t nq,
80
            const float* xq,
81
            int64_t* idx,
82
            float* dis);
83
84
0
    virtual ~EnumeratedVectors() {}
85
};
86
87
struct Repeat {
88
    float val; // value that is repeated
89
    int n; // number of occurrences of val
90
};
91
92
/** Repeats: used to encode a vector in which each value val occurs
93
 *  n times. Encodes the signs and permutation of the vector. Useful
94
 *  for atoms.
95
 */
96
struct Repeats {
97
    int dim;
98
    std::vector<Repeat> repeats;
99
100
    // initialize from a template of the atom (both arguments are optional).
101
    Repeats(int dim = 0, const float* c = nullptr);
102
103
    // count the number of possible codes for this atom
104
    uint64_t count() const;
105
106
    uint64_t encode(const float* c) const; // NOTE(review): presumably returns a code below count() -- confirm
107
108
    void decode(uint64_t code, float* c) const; // counterpart of encode
109
};
110
111
/** Codec that can return ids for the encoded vectors.
112
 *
113
 * Uses the ZnSphereSearch to encode the vector by encoding the
114
 * permutation and signs. Depends on ZnSphereSearch because it uses
115
 * the atom numbers. */
116
struct ZnSphereCodec : ZnSphereSearch, EnumeratedVectors {
117
    struct CodeSegment : Repeats {
118
0
        explicit CodeSegment(const Repeats& r) : Repeats(r) {}
119
        uint64_t c0; // first code assigned to segment
120
        int signbits;
121
    };
122
123
    std::vector<CodeSegment> code_segments;
124
    uint64_t nv; // NOTE(review): hides the inherited EnumeratedVectors::nv -- confirm this is intended
125
    size_t code_size;
126
127
    ZnSphereCodec(int dim, int r2);
128
129
    uint64_t search_and_encode(const float* x) const;
130
131
    void decode(uint64_t code, float* c) const override;
132
133
    /// takes vectors that do not need to be centroids
134
    uint64_t encode(const float* x) const override;
135
};
136
137
/** Recursive sphere codec.
138
 *
139
 * Uses a recursive decomposition on the dimensions to encode
140
 * centroids found by the ZnSphereSearch. The codes are *not*
141
 * compatible with the ones of ZnSphereCodec.
142
 */
143
struct ZnSphereCodecRec : EnumeratedVectors {
144
    int r2;
145
146
    int log2_dim;
147
    int code_size;
148
149
    ZnSphereCodecRec(int dim, int r2);
150
151
    uint64_t encode_centroid(const float* c) const;
152
153
    void decode(uint64_t code, float* c) const override;
154
155
    /// vectors need to be centroids (does not work on arbitrary
156
    /// vectors)
157
    uint64_t encode(const float* x) const override;
158
159
    std::vector<uint64_t> all_nv;
160
    std::vector<uint64_t> all_nv_cum;
161
162
    int decode_cache_ld;
163
    std::vector<std::vector<float>> decode_cache;
164
165
    // number of vectors in the sphere in dim 2^ld with r2 radius
166
    uint64_t get_nv(int ld, int r2a) const;
167
168
    // cumulative version (getter and setter)
169
    uint64_t get_nv_cum(int ld, int r2t, int r2a) const;
170
    void set_nv_cum(int ld, int r2t, int r2a, uint64_t v);
171
};
172
173
/** Codec that uses the recursive codec if dim is a power of 2 and
174
 * the regular one otherwise. */
175
struct ZnSphereCodecAlt : ZnSphereCodec {
176
    bool use_rec; // NOTE(review): presumably true when the recursive codec is selected -- confirm in the constructor
177
    ZnSphereCodecRec znc_rec; // recursive codec, held by value
178
179
    ZnSphereCodecAlt(int dim, int r2);
180
181
    uint64_t encode(const float* x) const override;
182
183
    void decode(uint64_t code, float* c) const override;
184
};
185
186
} // namespace faiss
187
188
#endif