Coverage Report

Created: 2025-10-13 20:49

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexPreTransform.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/IndexPreTransform.h>
11
12
#include <cmath>
13
#include <cstdio>
14
#include <cstring>
15
#include <memory>
16
17
#include <faiss/impl/AuxIndexStructures.h>
18
#include <faiss/impl/DistanceComputer.h>
19
#include <faiss/impl/FaissAssert.h>
20
21
namespace faiss {
22
23
/*********************************************
24
 * IndexPreTransform
25
 *********************************************/
26
27
0
IndexPreTransform::IndexPreTransform() : index(nullptr), own_fields(false) {}
28
29
/// Wrap `index` with an (initially empty) transform chain.
/// The wrapper copies the sub-index's dimension, metric, training state and
/// vector count. It does not take ownership of `index` (own_fields is false).
IndexPreTransform::IndexPreTransform(Index* index)
        : Index(index->d, index->metric_type),
          index(index),
          own_fields(false) {
    // Mirror the sub-index's state so the wrapper reports it unchanged.
    ntotal = index->ntotal;
    is_trained = index->is_trained;
}
34
35
/// Wrap `index` with a chain made of the single transform `ltrans`.
/// The sub-index is not owned (own_fields is false). prepend_transform
/// validates that ltrans->d_out matches the sub-index dimension, and then sets
/// the wrapper's dimension to ltrans->d_in.
IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index)
        : Index(index->d, index->metric_type),
          index(index),
          own_fields(false) {
    ntotal = index->ntotal;
    is_trained = index->is_trained;
    prepend_transform(ltrans);
}
41
42
0
/// Insert `ltrans` at the front of the chain.
/// Its output dimension must equal the current input dimension `d`. Afterwards
/// `d` becomes the transform's input dimension. The wrapper remains trained
/// only if `ltrans` is trained too.
void IndexPreTransform::prepend_transform(VectorTransform* ltrans) {
    FAISS_THROW_IF_NOT(ltrans->d_out == d);
    if (!ltrans->is_trained) {
        is_trained = false;
    }
    chain.insert(chain.begin(), ltrans);
    d = ltrans->d_in;
}
48
49
0
/// Frees the chain transforms and the sub-index, but only when own_fields is set.
IndexPreTransform::~IndexPreTransform() {
    if (!own_fields) {
        return;
    }
    for (VectorTransform* transform : chain) {
        delete transform;
    }
    delete index;
}
56
57
0
/// Train every component of the pipeline up to the last one that needs it.
///
/// The "target" position is chain.size() (the sub-index) when the sub-index is
/// untrained. Otherwise it is the last untrained transform, or 0 if every
/// transform is trained. Positions 0..target are visited in order:
/// - each untrained transform is trained on the data as transformed by its
///   predecessors;
/// - the data is then pushed through that transform;
/// - visiting stops at the target.
/// NOTE(review): with an empty chain and an already-trained sub-index, the
/// target is 0 == chain.size(), so the sub-index is trained again.
void IndexPreTransform::train(idx_t n, const float* x) {
    int last_untrained = 0;
    if (!index->is_trained) {
        last_untrained = chain.size();
    } else {
        int pos = static_cast<int>(chain.size());
        while (pos-- > 0) {
            if (!chain[pos]->is_trained) {
                last_untrained = pos;
                break;
            }
        }
    }

    if (verbose) {
        printf("IndexPreTransform::train: training chain 0 to %d\n",
               last_untrained);
    }

    const float* cur = x;                 // data as seen by position i
    std::unique_ptr<const float[]> owned; // owns `cur` once it is a copy

    for (int i = 0; i <= last_untrained; i++) {
        const bool at_sub_index = i >= static_cast<int>(chain.size());
        if (at_sub_index) {
            if (verbose) {
                printf("   Training sub-index\n");
            }
            index->train(n, cur);
        } else {
            VectorTransform* ltrans = chain[i];
            if (!ltrans->is_trained) {
                if (verbose) {
                    printf("   Training chain component %d/%zd\n",
                           i,
                           chain.size());
                    if (OPQMatrix* opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
                        opqm->verbose = true;
                    }
                }
                ltrans->train(n, cur);
            }
        }

        if (i == last_untrained) {
            break;
        }

        if (verbose) {
            printf("   Applying transform %d/%zd\n", i, chain.size());
        }
        // apply() returns a freshly allocated buffer. Handing it to `owned`
        // frees the previous intermediate buffer, never the caller's x.
        cur = chain[i]->apply(n, cur);
        owned.reset(cur);
    }

    is_trained = true;
}
115
116
0
/// Push n vectors through every transform, first to last.
/// Returns x itself when the chain is empty. Otherwise returns a buffer
/// allocated with new[] that the caller must delete[]. Each intermediate
/// buffer is freed as soon as the next transform has consumed it.
const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
    const float* cur = x;
    std::unique_ptr<const float[]> holder;
    for (VectorTransform* transform : chain) {
        cur = transform->apply(n, cur);
        holder.reset(cur); // releases the previous intermediate, if any
    }
    holder.release(); // ownership of the final buffer passes to the caller
    return cur;
}
129
130
/// Undo the transform chain.
/// Maps n vectors `xt`, expressed in the sub-index space (dimension index->d),
/// back to the input space (dimension d) and writes them into `x`. Transforms
/// are reverted from last to first, using temporary buffers in between.
/// With an empty chain the reverse map is the identity, so xt is copied into
/// x; nothing is copied when the two pointers are the same buffer.
void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
        const {
    if (chain.empty()) {
        // Previously x was left unwritten in this case.
        if (x != xt && n > 0) {
            memcpy(x, xt, sizeof(float) * n * d);
        }
        return;
    }

    const float* next_x = xt;
    std::unique_ptr<const float[]> del; // owns next_x when it is a temporary

    // Iterate with size_t to avoid wrapping chain.size() - 1 into an int.
    for (size_t i = chain.size(); i-- > 0;) {
        // The first transform writes directly into the caller's output.
        float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
        std::unique_ptr<const float[]> del2((prev_x == x) ? nullptr : prev_x);
        chain[i]->reverse_transform(n, next_x, prev_x);
        // Keep prev_x alive for the next step; free the consumed buffer.
        del2.swap(del);
        next_x = prev_x;
    }
}
143
144
0
/// Transform n vectors through the chain and add them to the sub-index.
void IndexPreTransform::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT(is_trained);
    const TransformedVectors transformed(x, apply_chain(n, x));
    index->add(n, transformed.x);
    ntotal = index->ntotal; // keep the wrapper's count in sync
}
150
151
void IndexPreTransform::add_with_ids(
152
        idx_t n,
153
        const float* x,
154
0
        const idx_t* xids) {
155
0
    FAISS_THROW_IF_NOT(is_trained);
156
0
    TransformedVectors tv(x, apply_chain(n, x));
157
0
    index->add_with_ids(n, tv.x, xids);
158
0
    ntotal = index->ntotal;
159
0
}
160
161
namespace {
162
163
const SearchParameters* extract_index_search_params(
164
0
        const SearchParameters* params_in) {
165
0
    auto params = dynamic_cast<const SearchParametersPreTransform*>(params_in);
166
0
    return params ? params->index_params : params_in;
167
0
}
168
169
} // namespace
170
171
void IndexPreTransform::search(
172
        idx_t n,
173
        const float* x,
174
        idx_t k,
175
        float* distances,
176
        idx_t* labels,
177
0
        const SearchParameters* params) const {
178
0
    FAISS_THROW_IF_NOT(k > 0);
179
0
    FAISS_THROW_IF_NOT(is_trained);
180
0
    const float* xt = apply_chain(n, x);
181
0
    std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);
182
0
    index->search(
183
0
            n, xt, k, distances, labels, extract_index_search_params(params));
184
0
}
185
186
void IndexPreTransform::range_search(
187
        idx_t n,
188
        const float* x,
189
        float radius,
190
        RangeSearchResult* result,
191
0
        const SearchParameters* params) const {
192
0
    FAISS_THROW_IF_NOT(is_trained);
193
0
    TransformedVectors tv(x, apply_chain(n, x));
194
0
    index->range_search(
195
0
            n, tv.x, radius, result, extract_index_search_params(params));
196
0
}
197
198
0
void IndexPreTransform::reset() {
199
0
    index->reset();
200
0
    ntotal = 0;
201
0
}
202
203
0
/// Remove the ids matched by `sel` from the sub-index.
/// Returns the number of removed vectors, as reported by the sub-index.
size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
    const size_t n_removed = index->remove_ids(sel);
    ntotal = index->ntotal;
    return n_removed;
}
208
209
0
/// Reconstruct vector `key` in the input space (dimension d).
/// With an empty chain the sub-index writes straight into recons. Otherwise
/// the vector is reconstructed into a scratch buffer of dimension index->d,
/// and the chain is then reverted into recons.
void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
    std::unique_ptr<float[]> scratch;
    float* sub_recons = recons;
    if (!chain.empty()) {
        scratch.reset(new float[index->d]);
        sub_recons = scratch.get();
    }
    index->reconstruct(key, sub_recons);
    reverse_chain(1, sub_recons, recons);
}
218
219
0
/// Reconstruct the ni consecutive vectors starting at i0, in the input space.
/// Uses the same scratch-buffer scheme as reconstruct(): a buffer of
/// ni * index->d floats is allocated only when the chain is non-empty.
void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
    std::unique_ptr<float[]> scratch;
    float* sub_recons = recons;
    if (!chain.empty()) {
        scratch.reset(new float[ni * index->d]);
        sub_recons = scratch.get();
    }
    index->reconstruct_n(i0, ni, sub_recons);
    reverse_chain(ni, sub_recons, recons);
}
228
229
void IndexPreTransform::search_and_reconstruct(
230
        idx_t n,
231
        const float* x,
232
        idx_t k,
233
        float* distances,
234
        idx_t* labels,
235
        float* recons,
236
0
        const SearchParameters* params) const {
237
0
    FAISS_THROW_IF_NOT(k > 0);
238
0
    FAISS_THROW_IF_NOT(is_trained);
239
240
0
    TransformedVectors trans(x, apply_chain(n, x));
241
242
0
    float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
243
0
    std::unique_ptr<float[]> del2(
244
0
            (recons_temp == recons) ? nullptr : recons_temp);
245
0
    index->search_and_reconstruct(
246
0
            n,
247
0
            trans.x,
248
0
            k,
249
0
            distances,
250
0
            labels,
251
0
            recons_temp,
252
0
            extract_index_search_params(params));
253
254
    // Revert transformations from last to first
255
0
    reverse_chain(n * k, recons_temp, recons);
256
0
}
257
258
0
size_t IndexPreTransform::sa_code_size() const {
259
0
    return index->sa_code_size();
260
0
}
261
262
/// Standalone encoding: transform the n vectors through the chain, then let
/// the sub-index encode them into `bytes`.
void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
        const {
    const TransformedVectors transformed(x, apply_chain(n, x));
    index->sa_encode(n, transformed.x, bytes);
}
267
268
/// Standalone decoding: the sub-index decodes into its own space, and the
/// chain is then reverted to get back to the input space. Without transforms
/// the sub-index decodes directly into x.
void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
        const {
    if (!chain.empty()) {
        std::unique_ptr<float[]> decoded(new float[index->d * n]);
        index->sa_decode(n, bytes, decoded.get());
        reverse_chain(n, decoded.get(), x);
        return;
    }
    index->sa_decode(n, bytes, x);
}
279
280
0
/// Merge the vectors of another IndexPreTransform that has an identical
/// transform chain. The merge is delegated to the sub-indexes.
///
/// The sub-index's merge_from may move (and thereby remove) vectors out of
/// the other sub-index. Both wrappers therefore resynchronize ntotal from
/// their sub-index afterwards, as every other mutating method here does.
/// Previously the other wrapper's ntotal could be left stale.
void IndexPreTransform::merge_from(Index& otherIndex, idx_t add_id) {
    check_compatible_for_merge(otherIndex);
    // Safe downcast: check_compatible_for_merge verified the dynamic type.
    auto other = static_cast<IndexPreTransform*>(&otherIndex);
    index->merge_from(*other->index, add_id);
    ntotal = index->ntotal;
    other->ntotal = other->index->ntotal;
}
286
287
void IndexPreTransform::check_compatible_for_merge(
288
0
        const Index& otherIndex) const {
289
0
    auto other = dynamic_cast<const IndexPreTransform*>(&otherIndex);
290
0
    FAISS_THROW_IF_NOT(other);
291
0
    FAISS_THROW_IF_NOT(chain.size() == other->chain.size());
292
0
    for (int i = 0; i < chain.size(); i++) {
293
0
        chain[i]->check_identical(*other->chain[i]);
294
0
    }
295
0
    index->check_compatible_for_merge(*other->index);
296
0
}
297
298
namespace {
299
300
struct PreTransformDistanceComputer : DistanceComputer {
301
    const IndexPreTransform* index;
302
    std::unique_ptr<DistanceComputer> sub_dc;
303
    std::unique_ptr<const float[]> query;
304
305
    explicit PreTransformDistanceComputer(const IndexPreTransform* index)
306
0
            : index(index), sub_dc(index->index->get_distance_computer()) {}
307
308
0
    void set_query(const float* x) override {
309
0
        const float* xt = index->apply_chain(1, x);
310
0
        if (xt == x) {
311
0
            sub_dc->set_query(x);
312
0
        } else {
313
0
            query.reset(xt);
314
0
            sub_dc->set_query(xt);
315
0
        }
316
0
    }
317
318
0
    float symmetric_dis(idx_t i, idx_t j) override {
319
0
        return sub_dc->symmetric_dis(i, j);
320
0
    }
321
322
0
    float operator()(idx_t i) override {
323
0
        return (*sub_dc)(i);
324
0
    }
325
};
326
327
} // anonymous namespace
328
329
0
/// Return a new distance computer, owned by the caller.
/// Without transforms the sub-index's computer already operates in our input
/// space and is returned directly. Otherwise it is wrapped so that queries go
/// through the chain first.
DistanceComputer* IndexPreTransform::get_distance_computer() const {
    return chain.empty() ? index->get_distance_computer()
                         : new PreTransformDistanceComputer(this);
}
336
337
} // namespace faiss