Coverage Report

Created: 2025-12-29 20:29

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexIVFFlat.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#include <faiss/IndexIVFFlat.h>
11
12
#include <omp.h>
13
14
#include <cinttypes>
15
#include <cstdio>
16
17
#include <faiss/IndexFlat.h>
18
19
#include <faiss/impl/AuxIndexStructures.h>
20
#include <faiss/impl/IDSelector.h>
21
22
#include <faiss/impl/FaissAssert.h>
23
#include <faiss/utils/distances.h>
24
#include <faiss/utils/utils.h>
25
26
namespace faiss {
27
28
/*****************************************
29
 * IndexIVFFlat implementation
30
 ******************************************/
31
32
/// Builds a flat IVF index: each inverted-list entry holds the raw float
/// vector, so a code is exactly `d` floats.
IndexIVFFlat::IndexIVFFlat(
        Index* quantizer,
        size_t d,
        size_t nlist,
        MetricType metric)
        : IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) {
    // Vectors are stored as-is (add_core / encode_vectors reject residuals).
    by_residual = false;
    code_size = d * sizeof(float);
}
41
42
2
/// Default constructor (used e.g. before deserialization); only disables
/// residual encoding, which this index type never uses.
IndexIVFFlat::IndexIVFFlat() {
    by_residual = false;
}
45
46
void IndexIVFFlat::add_core(
47
        idx_t n,
48
        const float* x,
49
        const idx_t* xids,
50
        const idx_t* coarse_idx,
51
1.02k
        void* inverted_list_context) {
52
1.02k
    FAISS_THROW_IF_NOT(is_trained);
53
24
    FAISS_THROW_IF_NOT(coarse_idx);
54
24
    FAISS_THROW_IF_NOT(!by_residual);
55
24
    assert(invlists);
56
24
    direct_map.check_can_add(xids);
57
58
24
    int64_t n_add = 0;
59
60
24
    DirectMapAdd dm_adder(direct_map, n, xids);
61
62
24
#pragma omp parallel reduction(+ : n_add)
63
80
    {
64
80
        int nt = omp_get_num_threads();
65
80
        int rank = omp_get_thread_num();
66
67
        // each thread takes care of a subset of lists
68
11.4k
        for (size_t i = 0; i < n; i++) {
69
11.4k
            idx_t list_no = coarse_idx[i];
70
71
11.4k
            if (list_no >= 0 && list_no % nt == rank) {
72
2.77k
                idx_t id = xids ? xids[i] : ntotal + i;
73
2.77k
                const float* xi = x + i * d;
74
2.77k
                size_t offset = invlists->add_entry(
75
2.77k
                        list_no, id, (const uint8_t*)xi, inverted_list_context);
76
2.77k
                dm_adder.add(i, list_no, offset);
77
2.77k
                n_add++;
78
8.63k
            } else if (rank == 0 && list_no == -1) {
79
0
                dm_adder.add(i, -1, 0);
80
0
            }
81
11.4k
        }
82
80
    }
83
84
24
    if (verbose) {
85
0
        printf("IndexIVFFlat::add_core: added %" PRId64 " / %" PRId64
86
0
               " vectors\n",
87
0
               n_add,
88
0
               n);
89
0
    }
90
24
    ntotal += n;
91
24
}
92
93
void IndexIVFFlat::encode_vectors(
94
        idx_t n,
95
        const float* x,
96
        const idx_t* list_nos,
97
        uint8_t* codes,
98
0
        bool include_listnos) const {
99
0
    FAISS_THROW_IF_NOT(!by_residual);
100
0
    if (!include_listnos) {
101
0
        memcpy(codes, x, code_size * n);
102
0
    } else {
103
0
        size_t coarse_size = coarse_code_size();
104
0
        for (size_t i = 0; i < n; i++) {
105
0
            int64_t list_no = list_nos[i];
106
0
            uint8_t* code = codes + i * (code_size + coarse_size);
107
0
            const float* xi = x + i * d;
108
0
            if (list_no >= 0) {
109
0
                encode_listno(list_no, code);
110
0
                memcpy(code + coarse_size, xi, code_size);
111
0
            } else {
112
0
                memset(code, 0, code_size + coarse_size);
113
0
            }
114
0
        }
115
0
    }
116
0
}
117
118
0
/// Decodes `n` standalone codes (as produced by encode_vectors with list
/// numbers) back into floats by skipping each code's list-number prefix.
void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
    const size_t coarse_size = coarse_code_size();
    const size_t stride = code_size + coarse_size;

    const uint8_t* src = bytes + coarse_size;
    float* dst = x;
    for (size_t i = 0; i < n; i++) {
        memcpy(dst, src, code_size);
        src += stride;
        dst += d;
    }
}
126
127
namespace {
128
129
template <MetricType metric, class C, bool use_sel>
130
struct IVFFlatScanner : InvertedListScanner {
131
    size_t d;
132
133
    IVFFlatScanner(size_t d, bool store_pairs, const IDSelector* sel)
134
30
            : InvertedListScanner(store_pairs, sel), d(d) {
135
30
        keep_max = is_similarity_metric(metric);
136
30
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EEC2EmbPKNS_10IDSelectorE
Line
Count
Source
134
8
            : InvertedListScanner(store_pairs, sel), d(d) {
135
8
        keep_max = is_similarity_metric(metric);
136
8
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EEC2EmbPKNS_10IDSelectorE
Line
Count
Source
134
10
            : InvertedListScanner(store_pairs, sel), d(d) {
135
10
        keep_max = is_similarity_metric(metric);
136
10
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EEC2EmbPKNS_10IDSelectorE
Line
Count
Source
134
5
            : InvertedListScanner(store_pairs, sel), d(d) {
135
5
        keep_max = is_similarity_metric(metric);
136
5
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EEC2EmbPKNS_10IDSelectorE
Line
Count
Source
134
7
            : InvertedListScanner(store_pairs, sel), d(d) {
135
7
        keep_max = is_similarity_metric(metric);
136
7
    }
137
138
    const float* xi;
139
30
    void set_query(const float* query) override {
140
30
        this->xi = query;
141
30
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE9set_queryEPKf
Line
Count
Source
139
8
    void set_query(const float* query) override {
140
8
        this->xi = query;
141
8
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE9set_queryEPKf
Line
Count
Source
139
10
    void set_query(const float* query) override {
140
10
        this->xi = query;
141
10
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE9set_queryEPKf
Line
Count
Source
139
5
    void set_query(const float* query) override {
140
5
        this->xi = query;
141
5
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE9set_queryEPKf
Line
Count
Source
139
7
    void set_query(const float* query) override {
140
7
        this->xi = query;
141
7
    }
142
143
102
    void set_list(idx_t list_no, float /* coarse_dis */) override {
144
102
        this->list_no = list_no;
145
102
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE8set_listElf
Line
Count
Source
143
32
    void set_list(idx_t list_no, float /* coarse_dis */) override {
144
32
        this->list_no = list_no;
145
32
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE8set_listElf
Line
Count
Source
143
22
    void set_list(idx_t list_no, float /* coarse_dis */) override {
144
22
        this->list_no = list_no;
145
22
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE8set_listElf
Line
Count
Source
143
20
    void set_list(idx_t list_no, float /* coarse_dis */) override {
144
20
        this->list_no = list_no;
145
20
    }
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE8set_listElf
Line
Count
Source
143
28
    void set_list(idx_t list_no, float /* coarse_dis */) override {
144
28
        this->list_no = list_no;
145
28
    }
146
147
0
    float distance_to_code(const uint8_t* code) const override {
148
0
        const float* yj = (float*)code;
149
0
        float dis = metric == METRIC_INNER_PRODUCT
150
0
                ? fvec_inner_product(xi, yj, d)
151
0
                : fvec_L2sqr(xi, yj, d);
152
0
        return dis;
153
0
    }
Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE16distance_to_codeEPKh
Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE16distance_to_codeEPKh
Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE16distance_to_codeEPKh
Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE16distance_to_codeEPKh
154
155
    size_t scan_codes(
156
            size_t list_size,
157
            const uint8_t* codes,
158
            const idx_t* ids,
159
            float* simi,
160
            idx_t* idxi,
161
8
            size_t k) const override {
162
8
        const float* list_vecs = (const float*)codes;
163
8
        size_t nup = 0;
164
208
        for (size_t j = 0; j < list_size; j++) {
165
200
            const float* yj = list_vecs + d * j;
166
200
            if (use_sel && !sel->is_member(ids[j])) {
167
0
                continue;
168
0
            }
169
200
            float dis = metric == METRIC_INNER_PRODUCT
170
200
                    ? fvec_inner_product(xi, yj, d)
171
200
                    : fvec_L2sqr(xi, yj, d);
172
200
            if (C::cmp(simi[0], dis)) {
173
54
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
174
54
                heap_replace_top<C>(k, simi, idxi, dis, id);
175
54
                nup++;
176
54
            }
177
200
        }
178
8
        return nup;
179
8
    }
Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE10scan_codesEmPKhPKlPfPlm
Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE10scan_codesEmPKhPKlPfPlm
Unexecuted instantiation: IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE10scan_codesEmPKhPKlPfPlm
IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE10scan_codesEmPKhPKlPfPlm
Line
Count
Source
161
8
            size_t k) const override {
162
8
        const float* list_vecs = (const float*)codes;
163
8
        size_t nup = 0;
164
208
        for (size_t j = 0; j < list_size; j++) {
165
200
            const float* yj = list_vecs + d * j;
166
200
            if (use_sel && !sel->is_member(ids[j])) {
167
0
                continue;
168
0
            }
169
200
            float dis = metric == METRIC_INNER_PRODUCT
170
200
                    ? fvec_inner_product(xi, yj, d)
171
200
                    : fvec_L2sqr(xi, yj, d);
172
200
            if (C::cmp(simi[0], dis)) {
173
54
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
174
54
                heap_replace_top<C>(k, simi, idxi, dis, id);
175
54
                nup++;
176
54
            }
177
200
        }
178
8
        return nup;
179
8
    }
180
181
    void scan_codes_range(
182
            size_t list_size,
183
            const uint8_t* codes,
184
            const idx_t* ids,
185
            float radius,
186
94
            RangeQueryResult& res) const override {
187
94
        const float* list_vecs = (const float*)codes;
188
3.37k
        for (size_t j = 0; j < list_size; j++) {
189
3.27k
            const float* yj = list_vecs + d * j;
190
3.27k
            if (use_sel && !sel->is_member(ids[j])) {
191
0
                continue;
192
0
            }
193
3.27k
            float dis = metric == METRIC_INNER_PRODUCT
194
3.27k
                    ? fvec_inner_product(xi, yj, d)
195
3.27k
                    : fvec_L2sqr(xi, yj, d);
196
3.27k
            if (C::cmp(radius, dis)) {
197
1.18k
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
198
1.18k
                res.add(dis, id);
199
1.18k
            }
200
3.27k
        }
201
94
    }
IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb1EE16scan_codes_rangeEmPKhPKlfRNS_16RangeQueryResultE
Line
Count
Source
186
32
            RangeQueryResult& res) const override {
187
32
        const float* list_vecs = (const float*)codes;
188
1.92k
        for (size_t j = 0; j < list_size; j++) {
189
1.89k
            const float* yj = list_vecs + d * j;
190
1.89k
            if (use_sel && !sel->is_member(ids[j])) {
191
0
                continue;
192
0
            }
193
1.89k
            float dis = metric == METRIC_INNER_PRODUCT
194
1.89k
                    ? fvec_inner_product(xi, yj, d)
195
1.89k
                    : fvec_L2sqr(xi, yj, d);
196
1.89k
            if (C::cmp(radius, dis)) {
197
757
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
198
757
                res.add(dis, id);
199
757
            }
200
1.89k
        }
201
32
    }
IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb1EE16scan_codes_rangeEmPKhPKlfRNS_16RangeQueryResultE
Line
Count
Source
186
22
            RangeQueryResult& res) const override {
187
22
        const float* list_vecs = (const float*)codes;
188
541
        for (size_t j = 0; j < list_size; j++) {
189
519
            const float* yj = list_vecs + d * j;
190
519
            if (use_sel && !sel->is_member(ids[j])) {
191
0
                continue;
192
0
            }
193
519
            float dis = metric == METRIC_INNER_PRODUCT
194
519
                    ? fvec_inner_product(xi, yj, d)
195
519
                    : fvec_L2sqr(xi, yj, d);
196
519
            if (C::cmp(radius, dis)) {
197
237
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
198
237
                res.add(dis, id);
199
237
            }
200
519
        }
201
22
    }
IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE0ENS_4CMinIflEELb0EE16scan_codes_rangeEmPKhPKlfRNS_16RangeQueryResultE
Line
Count
Source
186
20
            RangeQueryResult& res) const override {
187
20
        const float* list_vecs = (const float*)codes;
188
413
        for (size_t j = 0; j < list_size; j++) {
189
393
            const float* yj = list_vecs + d * j;
190
393
            if (use_sel && !sel->is_member(ids[j])) {
191
0
                continue;
192
0
            }
193
393
            float dis = metric == METRIC_INNER_PRODUCT
194
393
                    ? fvec_inner_product(xi, yj, d)
195
393
                    : fvec_L2sqr(xi, yj, d);
196
393
            if (C::cmp(radius, dis)) {
197
5
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
198
5
                res.add(dis, id);
199
5
            }
200
393
        }
201
20
    }
IndexIVFFlat.cpp:_ZNK5faiss12_GLOBAL__N_114IVFFlatScannerILNS_10MetricTypeE1ENS_4CMaxIflEELb0EE16scan_codes_rangeEmPKhPKlfRNS_16RangeQueryResultE
Line
Count
Source
186
20
            RangeQueryResult& res) const override {
187
20
        const float* list_vecs = (const float*)codes;
188
491
        for (size_t j = 0; j < list_size; j++) {
189
471
            const float* yj = list_vecs + d * j;
190
471
            if (use_sel && !sel->is_member(ids[j])) {
191
0
                continue;
192
0
            }
193
471
            float dis = metric == METRIC_INNER_PRODUCT
194
471
                    ? fvec_inner_product(xi, yj, d)
195
471
                    : fvec_L2sqr(xi, yj, d);
196
471
            if (C::cmp(radius, dis)) {
197
189
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
198
189
                res.add(dis, id);
199
189
            }
200
471
        }
201
20
    }
202
};
203
204
template <bool use_sel>
205
InvertedListScanner* get_InvertedListScanner1(
206
        const IndexIVFFlat* ivf,
207
        bool store_pairs,
208
30
        const IDSelector* sel) {
209
30
    if (ivf->metric_type == METRIC_INNER_PRODUCT) {
210
13
        return new IVFFlatScanner<
211
13
                METRIC_INNER_PRODUCT,
212
13
                CMin<float, int64_t>,
213
13
                use_sel>(ivf->d, store_pairs, sel);
214
17
    } else if (ivf->metric_type == METRIC_L2) {
215
17
        return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>, use_sel>(
216
17
                ivf->d, store_pairs, sel);
217
17
    } else {
218
0
        FAISS_THROW_MSG("metric type not supported");
219
0
    }
220
30
}
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_124get_InvertedListScanner1ILb1EEEPNS_19InvertedListScannerEPKNS_12IndexIVFFlatEbPKNS_10IDSelectorE
Line
Count
Source
208
18
        const IDSelector* sel) {
209
18
    if (ivf->metric_type == METRIC_INNER_PRODUCT) {
210
8
        return new IVFFlatScanner<
211
8
                METRIC_INNER_PRODUCT,
212
8
                CMin<float, int64_t>,
213
8
                use_sel>(ivf->d, store_pairs, sel);
214
10
    } else if (ivf->metric_type == METRIC_L2) {
215
10
        return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>, use_sel>(
216
10
                ivf->d, store_pairs, sel);
217
10
    } else {
218
0
        FAISS_THROW_MSG("metric type not supported");
219
0
    }
220
18
}
IndexIVFFlat.cpp:_ZN5faiss12_GLOBAL__N_124get_InvertedListScanner1ILb0EEEPNS_19InvertedListScannerEPKNS_12IndexIVFFlatEbPKNS_10IDSelectorE
Line
Count
Source
208
12
        const IDSelector* sel) {
209
12
    if (ivf->metric_type == METRIC_INNER_PRODUCT) {
210
5
        return new IVFFlatScanner<
211
5
                METRIC_INNER_PRODUCT,
212
5
                CMin<float, int64_t>,
213
5
                use_sel>(ivf->d, store_pairs, sel);
214
7
    } else if (ivf->metric_type == METRIC_L2) {
215
7
        return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>, use_sel>(
216
7
                ivf->d, store_pairs, sel);
217
7
    } else {
218
0
        FAISS_THROW_MSG("metric type not supported");
219
0
    }
220
12
}
221
222
} // anonymous namespace
223
224
/// Returns a new scanner for this index; the id-selector check is compiled
/// in only when a selector is supplied. Search parameters are unused.
InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
        bool store_pairs,
        const IDSelector* sel,
        const IVFSearchParameters*) const {
    return sel != nullptr
            ? get_InvertedListScanner1<true>(this, store_pairs, sel)
            : get_InvertedListScanner1<false>(this, store_pairs, sel);
}
234
235
void IndexIVFFlat::reconstruct_from_offset(
236
        int64_t list_no,
237
        int64_t offset,
238
0
        float* recons) const {
239
0
    memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
240
0
}
241
242
/*****************************************
243
 * IndexIVFFlatDedup implementation
244
 ******************************************/
245
246
/// Flat IVF index that stores each distinct vector once per list and keeps
/// the ids of identical vectors in `instances`. Construction is fully
/// delegated to IndexIVFFlat.
IndexIVFFlatDedup::IndexIVFFlatDedup(
        Index* quantizer,
        size_t d,
        size_t nlist_,
        MetricType metric_type)
        : IndexIVFFlat(quantizer, d, nlist_, metric_type) {
}
252
253
0
/// Trains the coarse quantizer on the training set with exact duplicate
/// vectors removed (first occurrence kept).
void IndexIVFFlatDedup::train(idx_t n, const float* x) {
    // hash of a vector -> row of its most recent unique copy in `unique_vecs`
    std::unordered_map<uint64_t, idx_t> row_by_hash;
    std::unique_ptr<float[]> unique_vecs(new float[n * d]);

    int64_t n_unique = 0;
    for (int64_t i = 0; i < n; i++) {
        const float* xi = x + i * d;
        const uint64_t hash =
                hash_bytes(reinterpret_cast<const uint8_t*>(xi), code_size);

        // A hash hit is only a duplicate if the bytes really match.
        auto hit = row_by_hash.find(hash);
        const bool is_duplicate = hit != row_by_hash.end() &&
                memcmp(unique_vecs.get() + hit->second * d, xi, code_size) ==
                        0;
        if (is_duplicate) {
            continue;
        }

        row_by_hash[hash] = n_unique;
        memcpy(unique_vecs.get() + n_unique * d, xi, code_size);
        n_unique++;
    }

    if (verbose) {
        printf("IndexIVFFlatDedup::train: train on %" PRId64
               " points after dedup "
               "(was %" PRId64 " points)\n",
               n_unique,
               n);
    }
    IndexIVFFlat::train(n_unique, unique_vecs.get());
}
278
279
void IndexIVFFlatDedup::add_with_ids(
280
        idx_t na,
281
        const float* x,
282
0
        const idx_t* xids) {
283
0
    FAISS_THROW_IF_NOT(is_trained);
284
0
    assert(invlists);
285
0
    FAISS_THROW_IF_NOT_MSG(
286
0
            direct_map.no(), "IVFFlatDedup not implemented with direct_map");
287
0
    std::unique_ptr<int64_t[]> idx(new int64_t[na]);
288
0
    quantizer->assign(na, x, idx.get());
289
290
0
    int64_t n_add = 0, n_dup = 0;
291
292
0
#pragma omp parallel reduction(+ : n_add, n_dup)
293
0
    {
294
0
        int nt = omp_get_num_threads();
295
0
        int rank = omp_get_thread_num();
296
297
        // each thread takes care of a subset of lists
298
0
        for (size_t i = 0; i < na; i++) {
299
0
            int64_t list_no = idx[i];
300
301
0
            if (list_no < 0 || list_no % nt != rank) {
302
0
                continue;
303
0
            }
304
305
0
            idx_t id = xids ? xids[i] : ntotal + i;
306
0
            const float* xi = x + i * d;
307
308
            // search if there is already an entry with that id
309
0
            InvertedLists::ScopedCodes codes(invlists, list_no);
310
311
0
            int64_t n = invlists->list_size(list_no);
312
0
            int64_t offset = -1;
313
0
            for (int64_t o = 0; o < n; o++) {
314
0
                if (!memcmp(codes.get() + o * code_size, xi, code_size)) {
315
0
                    offset = o;
316
0
                    break;
317
0
                }
318
0
            }
319
320
0
            if (offset == -1) { // not found
321
0
                invlists->add_entry(list_no, id, (const uint8_t*)xi);
322
0
            } else {
323
                // mark equivalence
324
0
                idx_t id2 = invlists->get_single_id(list_no, offset);
325
0
                std::pair<idx_t, idx_t> pair(id2, id);
326
327
0
#pragma omp critical
328
                // executed by one thread at a time
329
0
                instances.insert(pair);
330
331
0
                n_dup++;
332
0
            }
333
0
            n_add++;
334
0
        }
335
0
    }
336
0
    if (verbose) {
337
0
        printf("IndexIVFFlat::add_with_ids: added %" PRId64 " / %" PRId64
338
0
               " vectors"
339
0
               " (out of which %" PRId64 " are duplicates)\n",
340
0
               n_add,
341
0
               na,
342
0
               n_dup);
343
0
    }
344
0
    ntotal += n_add;
345
0
}
346
347
void IndexIVFFlatDedup::search_preassigned(
348
        idx_t n,
349
        const float* x,
350
        idx_t k,
351
        const idx_t* assign,
352
        const float* centroid_dis,
353
        float* distances,
354
        idx_t* labels,
355
        bool store_pairs,
356
        const IVFSearchParameters* params,
357
0
        IndexIVFStats* stats) const {
358
0
    FAISS_THROW_IF_NOT_MSG(
359
0
            !store_pairs, "store_pairs not supported in IVFDedup");
360
361
0
    IndexIVFFlat::search_preassigned(
362
0
            n, x, k, assign, centroid_dis, distances, labels, false, params);
363
364
0
    std::vector<idx_t> labels2(k);
365
0
    std::vector<float> dis2(k);
366
367
0
    for (int64_t i = 0; i < n; i++) {
368
0
        idx_t* labels1 = labels + i * k;
369
0
        float* dis1 = distances + i * k;
370
0
        int64_t j = 0;
371
0
        for (; j < k; j++) {
372
0
            if (instances.find(labels1[j]) != instances.end()) {
373
                // a duplicate: special handling
374
0
                break;
375
0
            }
376
0
        }
377
0
        if (j < k) {
378
            // there are duplicates, special handling
379
0
            int64_t j0 = j;
380
0
            int64_t rp = j;
381
0
            while (j < k) {
382
0
                auto range = instances.equal_range(labels1[rp]);
383
0
                float dis = dis1[rp];
384
0
                labels2[j] = labels1[rp];
385
0
                dis2[j] = dis;
386
0
                j++;
387
0
                for (auto it = range.first; j < k && it != range.second; ++it) {
388
0
                    labels2[j] = it->second;
389
0
                    dis2[j] = dis;
390
0
                    j++;
391
0
                }
392
0
                rp++;
393
0
            }
394
0
            memcpy(labels1 + j0,
395
0
                   labels2.data() + j0,
396
0
                   sizeof(labels1[0]) * (k - j0));
397
0
            memcpy(dis1 + j0, dis2.data() + j0, sizeof(dis2[0]) * (k - j0));
398
0
        }
399
0
    }
400
0
}
401
402
0
/// Removes the ids selected by `sel`.
///
/// When a stored vector's id is removed but one of its duplicates survives,
/// the stored entry is relabelled with that surviving duplicate instead of
/// being deleted. Remaining surviving duplicates are re-attached to it.
/// Returns the number of entries physically removed from the lists.
size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel) {
    // stored id being removed -> surviving duplicate id that replaces it
    std::unordered_map<idx_t, idx_t> replace;
    std::vector<std::pair<idx_t, idx_t>> toadd;
    for (auto it = instances.begin(); it != instances.end();) {
        if (sel.is_member(it->first)) {
            // then we erase this entry
            if (!sel.is_member(it->second)) {
                // The first surviving duplicate becomes the replacement; any
                // further ones are re-attached to that replacement.
                auto ins = replace.emplace(it->first, it->second);
                if (!ins.second) {
                    toadd.emplace_back(ins.first->second, it->second);
                }
            }
            it = instances.erase(it);
        } else if (sel.is_member(it->second)) {
            it = instances.erase(it);
        } else {
            ++it;
        }
    }

    instances.insert(toadd.begin(), toadd.end());

    // mostly copied from IndexIVF.cpp

    FAISS_THROW_IF_NOT_MSG(
            direct_map.no(), "direct map remove not implemented");

    std::vector<int64_t> toremove(nlist);

#pragma omp parallel for
    for (int64_t i = 0; i < nlist; i++) {
        int64_t l0 = invlists->list_size(i), l = l0, j = 0;
        InvertedLists::ScopedIds idsi(invlists, i);
        while (j < l) {
            if (!sel.is_member(idsi[j])) {
                j++;
                continue;
            }
            // `replace` is shared by all threads: only find() is used here,
            // since operator[] is not safe for concurrent calls.
            auto rep = replace.find(idsi[j]);
            if (rep == replace.end()) {
                // Remove: move the last entry into slot j.
                l--;
                invlists->update_entry(
                        i,
                        j,
                        invlists->get_single_id(i, l),
                        InvertedLists::ScopedCodes(invlists, i, l).get());
            } else {
                // Keep the vector, relabel it with the surviving duplicate.
                // NOTE(review): the code is copied onto itself here.
                invlists->update_entry(
                        i,
                        j,
                        rep->second,
                        InvertedLists::ScopedCodes(invlists, i, j).get());
                j++;
            }
        }
        toremove[i] = l0 - l;
    }
    // this will not run well in parallel on ondisk because of possible shrinks
    int64_t nremove = 0;
    for (int64_t i = 0; i < nlist; i++) {
        if (toremove[i] > 0) {
            nremove += toremove[i];
            invlists->resize(i, invlists->list_size(i) - toremove[i]);
        }
    }
    ntotal -= nremove;
    return nremove;
}
475
476
void IndexIVFFlatDedup::range_search(
477
        idx_t,
478
        const float*,
479
        float,
480
        RangeSearchResult*,
481
0
        const SearchParameters*) const {
482
0
    FAISS_THROW_MSG("not implemented");
483
0
}
484
485
0
void IndexIVFFlatDedup::update_vectors(int, const idx_t*, const float*) {
486
0
    FAISS_THROW_MSG("not implemented");
487
0
}
488
489
void IndexIVFFlatDedup::reconstruct_from_offset(int64_t, int64_t, float*)
490
0
        const {
491
0
    FAISS_THROW_MSG("not implemented");
492
0
}
493
494
} // namespace faiss