Coverage Report

Created: 2025-10-14 04:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexRefine.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexRefine.h>
9
10
#include <faiss/IndexFlat.h>
11
#include <faiss/impl/AuxIndexStructures.h>
12
#include <faiss/impl/FaissAssert.h>
13
#include <faiss/utils/Heap.h>
14
15
namespace faiss {
16
17
/***************************************************
18
 * IndexRefine
19
 ***************************************************/
20
21
/** Combine a fast base_index with a (presumably more accurate) refine_index.
 *
 * Neither sub-index is owned unless the caller later sets own_fields /
 * own_refine_index. refine_index may be null; that case only serves
 * IndexRefineFlat, which installs its own refine index afterwards.
 *
 * @throws FaissException if the two indexes disagree on d, metric or ntotal.
 */
IndexRefine::IndexRefine(Index* base_index, Index* refine_index)
        : Index(base_index->d, base_index->metric_type),
          base_index(base_index),
          refine_index(refine_index) {
    own_refine_index = false;
    own_fields = false;

    if (refine_index) {
        // Both stages must describe the same vectors in the same space.
        FAISS_THROW_IF_NOT(base_index->d == refine_index->d);
        FAISS_THROW_IF_NOT(
                base_index->metric_type == refine_index->metric_type);
        is_trained = base_index->is_trained && refine_index->is_trained;
        FAISS_THROW_IF_NOT(base_index->ntotal == refine_index->ntotal);
    }

    ntotal = base_index->ntotal;
}
35
36
IndexRefine::IndexRefine()
37
0
        : base_index(nullptr),
38
0
          refine_index(nullptr),
39
0
          own_fields(false),
40
0
          own_refine_index(false) {}
41
42
0
/** Train the base stage, then the refine stage, on the same n vectors x. */
void IndexRefine::train(idx_t n, const float* x) {
    Index* const stages[] = {base_index, refine_index};
    for (Index* stage : stages) {
        stage->train(n, x);
    }
    is_trained = true;
}
47
48
0
/** Append n vectors to both stages so that their ids stay aligned
 * (search() uses base-index labels as refine-index ids).
 *
 * @throws FaissException if the index is not trained.
 */
void IndexRefine::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT(is_trained);
    Index* const stages[] = {base_index, refine_index};
    for (Index* stage : stages) {
        stage->add(n, x);
    }
    ntotal = refine_index->ntotal;
}
54
55
0
void IndexRefine::reset() {
56
0
    base_index->reset();
57
0
    refine_index->reset();
58
0
    ntotal = 0;
59
0
}
60
61
namespace {
62
63
using idx_t = faiss::idx_t;
64
65
template <class C>
66
static void reorder_2_heaps(
67
        idx_t n,
68
        idx_t k,
69
        idx_t* __restrict labels,
70
        float* __restrict distances,
71
        idx_t k_base,
72
        const idx_t* __restrict base_labels,
73
0
        const float* __restrict base_distances) {
74
0
#pragma omp parallel for if (n > 1)
75
0
    for (idx_t i = 0; i < n; i++) {
76
0
        idx_t* idxo = labels + i * k;
77
0
        float* diso = distances + i * k;
78
0
        const idx_t* idxi = base_labels + i * k_base;
79
0
        const float* disi = base_distances + i * k_base;
80
81
0
        heap_heapify<C>(k, diso, idxo, disi, idxi, k);
82
0
        if (k_base != k) { // add remaining elements
83
0
            heap_addn<C>(k, diso, idxo, disi + k, idxi + k, k_base - k);
84
0
        }
85
0
        heap_reorder<C>(k, diso, idxo);
86
0
    }
Unexecuted instantiation: IndexRefine.cpp:_ZN5faiss12_GLOBAL__N_115reorder_2_heapsINS_4CMaxIflEEEEvllPlPflPKlPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRefine.cpp:_ZN5faiss12_GLOBAL__N_115reorder_2_heapsINS_4CMinIflEEEEvllPlPflPKlPKf.omp_outlined_debug__
87
0
}
Unexecuted instantiation: IndexRefine.cpp:_ZN5faiss12_GLOBAL__N_115reorder_2_heapsINS_4CMaxIflEEEEvllPlPflPKlPKf
Unexecuted instantiation: IndexRefine.cpp:_ZN5faiss12_GLOBAL__N_115reorder_2_heapsINS_4CMinIflEEEEvllPlPflPKlPKf
88
89
} // anonymous namespace
90
91
void IndexRefine::search(
92
        idx_t n,
93
        const float* x,
94
        idx_t k,
95
        float* distances,
96
        idx_t* labels,
97
0
        const SearchParameters* params_in) const {
98
0
    const IndexRefineSearchParameters* params = nullptr;
99
0
    if (params_in) {
100
0
        params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
101
0
        FAISS_THROW_IF_NOT_MSG(
102
0
                params, "IndexRefine params have incorrect type");
103
0
    }
104
105
0
    idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
106
0
                                       : idx_t(k * k_factor);
107
0
    SearchParameters* base_index_params =
108
0
            (params != nullptr) ? params->base_index_params : nullptr;
109
110
0
    FAISS_THROW_IF_NOT(k_base >= k);
111
112
0
    FAISS_THROW_IF_NOT(base_index);
113
0
    FAISS_THROW_IF_NOT(refine_index);
114
115
0
    FAISS_THROW_IF_NOT(k > 0);
116
0
    FAISS_THROW_IF_NOT(is_trained);
117
0
    idx_t* base_labels = labels;
118
0
    float* base_distances = distances;
119
0
    std::unique_ptr<idx_t[]> del1;
120
0
    std::unique_ptr<float[]> del2;
121
122
0
    if (k != k_base) {
123
0
        base_labels = new idx_t[n * k_base];
124
0
        del1.reset(base_labels);
125
0
        base_distances = new float[n * k_base];
126
0
        del2.reset(base_distances);
127
0
    }
128
129
0
    base_index->search(
130
0
            n, x, k_base, base_distances, base_labels, base_index_params);
131
132
0
    for (int i = 0; i < n * k_base; i++)
133
0
        assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
134
135
        // parallelize over queries
136
0
#pragma omp parallel if (n > 1)
137
0
    {
138
0
        std::unique_ptr<DistanceComputer> dc(
139
0
                refine_index->get_distance_computer());
140
0
#pragma omp for
141
0
        for (idx_t i = 0; i < n; i++) {
142
0
            dc->set_query(x + i * d);
143
0
            idx_t ij = i * k_base;
144
0
            for (idx_t j = 0; j < k_base; j++) {
145
0
                idx_t idx = base_labels[ij];
146
0
                if (idx < 0)
147
0
                    break;
148
0
                base_distances[ij] = (*dc)(idx);
149
0
                ij++;
150
0
            }
151
0
        }
152
0
    }
153
154
    // sort and store result
155
0
    if (metric_type == METRIC_L2) {
156
0
        typedef CMax<float, idx_t> C;
157
0
        reorder_2_heaps<C>(
158
0
                n, k, labels, distances, k_base, base_labels, base_distances);
159
160
0
    } else if (metric_type == METRIC_INNER_PRODUCT) {
161
0
        typedef CMin<float, idx_t> C;
162
0
        reorder_2_heaps<C>(
163
0
                n, k, labels, distances, k_base, base_labels, base_distances);
164
0
    } else {
165
0
        FAISS_THROW_MSG("Metric type not supported");
166
0
    }
167
0
}
168
169
void IndexRefine::range_search(
170
        idx_t n,
171
        const float* x,
172
        float radius,
173
        RangeSearchResult* result,
174
0
        const SearchParameters* params_in) const {
175
0
    const IndexRefineSearchParameters* params = nullptr;
176
0
    if (params_in) {
177
0
        params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
178
0
        FAISS_THROW_IF_NOT_MSG(
179
0
                params, "IndexRefine params have incorrect type");
180
0
    }
181
182
0
    SearchParameters* base_index_params =
183
0
            (params != nullptr) ? params->base_index_params : nullptr;
184
185
0
    base_index->range_search(n, x, radius, result, base_index_params);
186
187
0
#pragma omp parallel if (n > 1)
188
0
    {
189
0
        std::unique_ptr<DistanceComputer> dc(
190
0
                refine_index->get_distance_computer());
191
192
0
#pragma omp for
193
0
        for (idx_t i = 0; i < n; i++) {
194
0
            dc->set_query(x + i * d);
195
196
            // reevaluate distances
197
0
            const size_t idx_start = result->lims[i];
198
0
            const size_t idx_end = result->lims[i + 1];
199
200
0
            for (size_t j = idx_start; j < idx_end; j++) {
201
0
                const auto label = result->labels[j];
202
0
                result->distances[j] = (*dc)(label);
203
0
            }
204
0
        }
205
0
    }
206
0
}
207
208
0
/** Reconstruct vector `key` from the refine stage; base_index is not used. */
void IndexRefine::reconstruct(idx_t key, float* recons) const {
    Index* source = refine_index;
    source->reconstruct(key, recons);
}
211
212
0
size_t IndexRefine::sa_code_size() const {
213
0
    return base_index->sa_code_size() + refine_index->sa_code_size();
214
0
}
215
216
0
/** Encode n vectors into standalone codes.
 *
 * Each output code is laid out as [base code (cs1 bytes) | refine code
 * (cs2 bytes)], so `bytes` must hold n * sa_code_size() bytes.
 */
void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
    const size_t cs1 = base_index->sa_code_size();
    const size_t cs2 = refine_index->sa_code_size();
    // Convert the signed count once so the loop does not mix signed/unsigned.
    const size_t nvec = static_cast<size_t>(n);

    std::unique_ptr<uint8_t[]> tmp1(new uint8_t[nvec * cs1]);
    base_index->sa_encode(n, x, tmp1.get());
    std::unique_ptr<uint8_t[]> tmp2(new uint8_t[nvec * cs2]);
    refine_index->sa_encode(n, x, tmp2.get());

    // Interleave the two per-stage code arrays into one code per vector.
    for (size_t i = 0; i < nvec; i++) {
        uint8_t* b = bytes + i * (cs1 + cs2);
        memcpy(b, tmp1.get() + cs1 * i, cs1);
        memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
    }
}
228
229
0
/** Decode n standalone codes produced by sa_encode.
 *
 * Each code is [base code (cs1 bytes) | refine code (cs2 bytes)]. Only the
 * refine part is decoded, consistent with reconstruct(), which also uses only
 * refine_index.
 */
void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
    const size_t cs1 = base_index->sa_code_size();
    const size_t cs2 = refine_index->sa_code_size();
    const size_t nvec = static_cast<size_t>(n);

    // Gather the refine-index part of every code into a contiguous buffer.
    std::unique_ptr<uint8_t[]> tmp2(new uint8_t[nvec * cs2]);
    for (size_t i = 0; i < nvec; i++) {
        // The refine code starts cs1 bytes into each code (see sa_encode),
        // not at offset 0, which holds the base code.
        memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2) + cs1, cs2);
    }

    refine_index->sa_decode(n, tmp2.get(), x);
}
239
240
0
/** Delete each sub-index only if this object is flagged as its owner. */
IndexRefine::~IndexRefine() {
    auto release = [](bool owned, Index* idx) {
        if (owned) {
            delete idx;
        }
    };
    release(own_fields, base_index);
    release(own_refine_index, refine_index);
}
246
247
/***************************************************
248
 * IndexRefineFlat
249
 ***************************************************/
250
251
IndexRefineFlat::IndexRefineFlat(Index* base_index)
252
0
        : IndexRefine(
253
0
                  base_index,
254
0
                  new IndexFlat(base_index->d, base_index->metric_type)) {
255
0
    is_trained = base_index->is_trained;
256
0
    own_refine_index = true;
257
0
    FAISS_THROW_IF_NOT_MSG(
258
0
            base_index->ntotal == 0,
259
0
            "base_index should be empty in the beginning");
260
0
}
261
262
/** Wrap an already-populated base_index. An owned IndexFlat is filled with
 * base_index->ntotal vectors read from xb.
 *
 * xb must hold base_index->ntotal * d floats; presumably these are the same
 * vectors that were added to base_index (caller's responsibility — confirm).
 */
IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
        : IndexRefine(base_index, nullptr) {
    is_trained = base_index->is_trained;
    own_refine_index = true;
    refine_index = new IndexFlat(base_index->d, base_index->metric_type);
    refine_index->add(base_index->ntotal, xb);
}
269
270
0
/** Empty shell with null sub-indexes. Any refine index attached later is
 * treated as owned (own_refine_index = true).
 */
IndexRefineFlat::IndexRefineFlat() {
    own_refine_index = true;
}
273
274
void IndexRefineFlat::search(
275
        idx_t n,
276
        const float* x,
277
        idx_t k,
278
        float* distances,
279
        idx_t* labels,
280
0
        const SearchParameters* params_in) const {
281
0
    const IndexRefineSearchParameters* params = nullptr;
282
0
    if (params_in) {
283
0
        params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
284
0
        FAISS_THROW_IF_NOT_MSG(
285
0
                params, "IndexRefineFlat params have incorrect type");
286
0
    }
287
288
0
    idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
289
0
                                       : idx_t(k * k_factor);
290
0
    SearchParameters* base_index_params =
291
0
            (params != nullptr) ? params->base_index_params : nullptr;
292
293
0
    FAISS_THROW_IF_NOT(k_base >= k);
294
295
0
    FAISS_THROW_IF_NOT(base_index);
296
0
    FAISS_THROW_IF_NOT(refine_index);
297
298
0
    FAISS_THROW_IF_NOT(k > 0);
299
0
    FAISS_THROW_IF_NOT(is_trained);
300
0
    idx_t* base_labels = labels;
301
0
    float* base_distances = distances;
302
0
    std::unique_ptr<idx_t[]> del1;
303
0
    std::unique_ptr<float[]> del2;
304
305
0
    if (k != k_base) {
306
0
        base_labels = new idx_t[n * k_base];
307
0
        del1.reset(base_labels);
308
0
        base_distances = new float[n * k_base];
309
0
        del2.reset(base_distances);
310
0
    }
311
312
0
    base_index->search(
313
0
            n, x, k_base, base_distances, base_labels, base_index_params);
314
315
0
    for (int i = 0; i < n * k_base; i++)
316
0
        assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
317
318
    // compute refined distances
319
0
    auto rf = dynamic_cast<const IndexFlat*>(refine_index);
320
0
    FAISS_THROW_IF_NOT(rf);
321
322
0
    rf->compute_distance_subset(n, x, k_base, base_distances, base_labels);
323
324
    // sort and store result
325
0
    if (metric_type == METRIC_L2) {
326
0
        typedef CMax<float, idx_t> C;
327
0
        reorder_2_heaps<C>(
328
0
                n, k, labels, distances, k_base, base_labels, base_distances);
329
330
0
    } else if (metric_type == METRIC_INNER_PRODUCT) {
331
0
        typedef CMin<float, idx_t> C;
332
0
        reorder_2_heaps<C>(
333
0
                n, k, labels, distances, k_base, base_labels, base_distances);
334
0
    } else {
335
0
        FAISS_THROW_MSG("Metric type not supported");
336
0
    }
337
0
}
338
339
} // namespace faiss