Coverage Report

Created: 2025-10-01 20:59

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/LocalSearchQuantizer.h
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#pragma once
9
10
#include <stdint.h>
11
12
#include <random>
13
#include <string>
14
#include <unordered_map>
15
#include <vector>
16
17
#include <faiss/impl/AdditiveQuantizer.h>
18
#include <faiss/impl/platform_macros.h>
19
#include <faiss/utils/utils.h>
20
21
namespace faiss {
22
23
namespace lsq {
24
25
struct IcmEncoderFactory;
26
27
} // namespace lsq
28
29
/** Implementation of LSQ/LSQ++ described in the following two papers:
30
 *
31
 * Revisiting additive quantization
32
 * Julieta Martinez, et al. ECCV 2016
33
 *
34
 * LSQ++: Lower running time and higher recall in multi-codebook quantization
35
 * Julieta Martinez, et al. ECCV 2018
36
 *
37
 * This implementation is mostly translated from the Julia implementations
38
 * by Julieta Martinez:
39
 * (https://github.com/una-dinosauria/local-search-quantization,
40
 *  https://github.com/una-dinosauria/Rayuela.jl)
41
 *
42
 * The trained codes are stored in `codebooks` which is called
43
 * `centroids` in PQ and RQ.
44
 */
45
struct LocalSearchQuantizer : AdditiveQuantizer {
46
    size_t K; ///< number of codes per codebook
47
48
    size_t train_iters = 25;      ///< number of iterations in training
49
    size_t encode_ils_iters = 16; ///< iterations of local search in encoding
50
    size_t train_ils_iters = 8;   ///< iterations of local search in training
51
    size_t icm_iters = 4;         ///< number of iterations in icm
52
53
    float p = 0.5f;      ///< temperature factor
54
    float lambd = 1e-2f; ///< regularization factor
55
56
    size_t chunk_size = 10000; ///< nb of vectors to encode at a time
57
58
    int random_seed = 0x12345; ///< seed for random generator
59
    size_t nperts = 4;         ///< number of perturbation in each code
60
61
    /// if non-NULL, use this encoder to encode (owned by the object)
62
    lsq::IcmEncoderFactory* icm_encoder_factory = nullptr;
63
64
    bool update_codebooks_with_double = true;
65
66
    LocalSearchQuantizer(
67
            size_t d,     /* dimensionality of the input vectors */
68
            size_t M,     /* number of subquantizers */
69
            size_t nbits, /* number of bit per subvector index */
70
            Search_type_t search_type =
71
                    ST_decompress); /* determines the storage type */
72
73
    LocalSearchQuantizer();
74
75
    ~LocalSearchQuantizer() override;
76
77
    // Train the local search quantizer
78
    void train(size_t n, const float* x) override;
79
80
    /** Encode a set of vectors
81
     *
82
     * @param x      vectors to encode, size n * d
83
     * @param codes  output codes, size n * code_size
84
     * @param n      number of vectors
85
     * @param centroids  centroids to be added to x, size n * d
86
     */
87
    void compute_codes_add_centroids(
88
            const float* x,
89
            uint8_t* codes,
90
            size_t n,
91
            const float* centroids = nullptr) const override;
92
93
    /** Update codebooks given encodings
94
     *
95
     * @param x      training vectors, size n * d
96
     * @param codes  encoded training vectors, size n * M
97
     * @param n      number of vectors
98
     */
99
    void update_codebooks(const float* x, const int32_t* codes, size_t n);
100
101
    /** Encode vectors given codebooks using iterative conditional mode (icm).
102
     *
103
     * @param codes     output codes, size n * M
104
     * @param x         vectors to encode, size n * d
105
     * @param n         number of vectors
106
     * @param ils_iters number of iterations of iterative local search
107
     */
108
    void icm_encode(
109
            int32_t* codes,
110
            const float* x,
111
            size_t n,
112
            size_t ils_iters,
113
            std::mt19937& gen) const;
114
115
    void icm_encode_impl(
116
            int32_t* codes,
117
            const float* x,
118
            const float* unaries,
119
            std::mt19937& gen,
120
            size_t n,
121
            size_t ils_iters,
122
            bool verbose) const;
123
124
    void icm_encode_step(
125
            int32_t* codes,
126
            const float* unaries,
127
            const float* binaries,
128
            size_t n,
129
            size_t n_iters) const;
130
131
    /** Add some perturbation to codes
132
     *
133
     * @param codes codes to be perturbed, size n * M
134
     * @param n     number of vectors
135
     */
136
    void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
137
138
    /** Add some perturbation to codebooks
139
     *
140
     * @param T         temperature of simulated annealing
141
     * @param stddev    standard deviations of each dimension in training data
142
     */
143
    void perturb_codebooks(
144
            float T,
145
            const std::vector<float>& stddev,
146
            std::mt19937& gen);
147
148
    /** Compute binary terms
149
     *
150
     * @param binaries binary terms, size M * M * K * K
151
     */
152
    void compute_binary_terms(float* binaries) const;
153
154
    /** Compute unary terms
155
     *
156
     * @param n       number of vectors
157
     * @param x       vectors to encode, size n * d
158
     * @param unaries unary terms, size n * M * K
159
     */
160
    void compute_unary_terms(const float* x, float* unaries, size_t n) const;
161
162
    /** Helper function to compute reconstruction error
163
     *
164
     * @param codes encoded codes, size n * M
165
     * @param x     vectors to encode, size n * d
166
     * @param n     number of vectors
167
     * @param objs  if it is not null, store reconstruction
168
                    error of each vector into it, size n
169
     */
170
    float evaluate(
171
            const int32_t* codes,
172
            const float* x,
173
            size_t n,
174
            float* objs = nullptr) const;
175
};
176
177
namespace lsq {
178
179
struct IcmEncoder {
180
    std::vector<float> binaries;
181
182
    bool verbose;
183
184
    const LocalSearchQuantizer* lsq;
185
186
    explicit IcmEncoder(const LocalSearchQuantizer* lsq);
187
188
0
    virtual ~IcmEncoder() {}
189
190
    /// compute binary terms
191
    virtual void set_binary_term();
192
193
    /** Encode vectors given codebooks
194
     *
195
     * @param codes     output codes, size n * M
196
     * @param x         vectors to encode, size n * d
197
     * @param gen       random generator
198
     * @param n         number of vectors
199
     * @param ils_iters number of iterations of iterative local search
200
     */
201
    virtual void encode(
202
            int32_t* codes,
203
            const float* x,
204
            std::mt19937& gen,
205
            size_t n,
206
            size_t ils_iters) const;
207
};
208
209
struct IcmEncoderFactory {
210
0
    virtual IcmEncoder* get(const LocalSearchQuantizer* lsq) {
211
0
        return new IcmEncoder(lsq);
212
0
    }
213
0
    virtual ~IcmEncoderFactory() {}
214
};
215
216
/** A helper struct to keep track of the time consumed during training.
217
 *  It is NOT thread-safe.
218
 */
219
struct LSQTimer {
220
    std::unordered_map<std::string, double> t;
221
222
1
    LSQTimer() {
223
1
        reset();
224
1
    }
225
226
    double get(const std::string& name);
227
228
    void add(const std::string& name, double delta);
229
230
    void reset();
231
};
232
233
struct LSQTimerScope {
234
    double t0;
235
    LSQTimer* timer;
236
    std::string name;
237
    bool finished;
238
239
    LSQTimerScope(LSQTimer* timer, std::string name);
240
241
    void finish();
242
243
    ~LSQTimerScope();
244
};
245
246
} // namespace lsq
247
248
} // namespace faiss