/root/doris/contrib/faiss/faiss/IndexBinaryHNSW.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexBinaryHNSW.h> |
9 | | |
10 | | #include <omp.h> |
11 | | #include <cassert> |
12 | | #include <cmath> |
13 | | #include <cstdio> |
14 | | #include <cstdlib> |
15 | | #include <cstring> |
16 | | #include <memory> |
17 | | |
18 | | #include <cstdint> |
19 | | |
20 | | #include <faiss/IndexBinaryFlat.h> |
21 | | #include <faiss/impl/AuxIndexStructures.h> |
22 | | #include <faiss/impl/DistanceComputer.h> |
23 | | #include <faiss/impl/FaissAssert.h> |
24 | | #include <faiss/impl/ResultHandler.h> |
25 | | #include <faiss/utils/Heap.h> |
26 | | #include <faiss/utils/hamming.h> |
27 | | #include <faiss/utils/random.h> |
28 | | |
29 | | namespace faiss { |
30 | | |
31 | | /************************************************************** |
32 | | * add / search blocks of descriptors |
33 | | **************************************************************/ |
34 | | |
35 | | namespace { |
36 | | |
37 | | void hnsw_add_vertices( |
38 | | IndexBinaryHNSW& index_hnsw, |
39 | | size_t n0, |
40 | | size_t n, |
41 | | const uint8_t* x, |
42 | | bool verbose, |
43 | 0 | bool preset_levels = false) { |
44 | 0 | HNSW& hnsw = index_hnsw.hnsw; |
45 | 0 | size_t ntotal = n0 + n; |
46 | 0 | double t0 = getmillisecs(); |
47 | 0 | if (verbose) { |
48 | 0 | printf("hnsw_add_vertices: adding %zd elements on top of %zd " |
49 | 0 | "(preset_levels=%d)\n", |
50 | 0 | n, |
51 | 0 | n0, |
52 | 0 | int(preset_levels)); |
53 | 0 | } |
54 | |
|
55 | 0 | int max_level = hnsw.prepare_level_tab(n, preset_levels); |
56 | |
|
57 | 0 | if (verbose) { |
58 | 0 | printf(" max_level = %d\n", max_level); |
59 | 0 | } |
60 | |
|
61 | 0 | std::vector<omp_lock_t> locks(ntotal); |
62 | 0 | for (int i = 0; i < ntotal; i++) { |
63 | 0 | omp_init_lock(&locks[i]); |
64 | 0 | } |
65 | | |
66 | | // add vectors from highest to lowest level |
67 | 0 | std::vector<int> hist; |
68 | 0 | std::vector<int> order(n); |
69 | |
|
70 | 0 | { // make buckets with vectors of the same level |
71 | | |
72 | | // build histogram |
73 | 0 | for (int i = 0; i < n; i++) { |
74 | 0 | HNSW::storage_idx_t pt_id = i + n0; |
75 | 0 | int pt_level = hnsw.levels[pt_id] - 1; |
76 | 0 | while (pt_level >= hist.size()) { |
77 | 0 | hist.push_back(0); |
78 | 0 | } |
79 | 0 | hist[pt_level]++; |
80 | 0 | } |
81 | | |
82 | | // accumulate |
83 | 0 | std::vector<int> offsets(hist.size() + 1, 0); |
84 | 0 | for (int i = 0; i < hist.size() - 1; i++) { |
85 | 0 | offsets[i + 1] = offsets[i] + hist[i]; |
86 | 0 | } |
87 | | |
88 | | // bucket sort |
89 | 0 | for (int i = 0; i < n; i++) { |
90 | 0 | HNSW::storage_idx_t pt_id = i + n0; |
91 | 0 | int pt_level = hnsw.levels[pt_id] - 1; |
92 | 0 | order[offsets[pt_level]++] = pt_id; |
93 | 0 | } |
94 | 0 | } |
95 | |
|
96 | 0 | { // perform add |
97 | 0 | RandomGenerator rng2(789); |
98 | |
|
99 | 0 | int i1 = n; |
100 | |
|
101 | 0 | for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) { |
102 | 0 | int i0 = i1 - hist[pt_level]; |
103 | |
|
104 | 0 | if (verbose) { |
105 | 0 | printf("Adding %d elements at level %d\n", i1 - i0, pt_level); |
106 | 0 | } |
107 | | |
108 | | // random permutation to get rid of dataset order bias |
109 | 0 | for (int j = i0; j < i1; j++) { |
110 | 0 | std::swap(order[j], order[j + rng2.rand_int(i1 - j)]); |
111 | 0 | } |
112 | |
|
113 | 0 | #pragma omp parallel |
114 | 0 | { |
115 | 0 | VisitedTable vt(ntotal); |
116 | |
|
117 | 0 | std::unique_ptr<DistanceComputer> dis( |
118 | 0 | index_hnsw.get_distance_computer()); |
119 | 0 | int prev_display = |
120 | 0 | verbose && omp_get_thread_num() == 0 ? 0 : -1; |
121 | |
|
122 | 0 | #pragma omp for schedule(dynamic) |
123 | 0 | for (int i = i0; i < i1; i++) { |
124 | 0 | HNSW::storage_idx_t pt_id = order[i]; |
125 | 0 | dis->set_query( |
126 | 0 | (float*)(x + (pt_id - n0) * index_hnsw.code_size)); |
127 | |
|
128 | 0 | hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt); |
129 | |
|
130 | 0 | if (prev_display >= 0 && i - i0 > prev_display + 10000) { |
131 | 0 | prev_display = i - i0; |
132 | 0 | printf(" %d / %d\r", i - i0, i1 - i0); |
133 | 0 | fflush(stdout); |
134 | 0 | } |
135 | 0 | } |
136 | 0 | } |
137 | 0 | i1 = i0; |
138 | 0 | } |
139 | 0 | FAISS_ASSERT(i1 == 0); |
140 | 0 | } |
141 | 0 | if (verbose) { |
142 | 0 | printf("Done in %.3f ms\n", getmillisecs() - t0); |
143 | 0 | } |
144 | |
|
145 | 0 | for (int i = 0; i < ntotal; i++) |
146 | 0 | omp_destroy_lock(&locks[i]); |
147 | 0 | } |
148 | | |
149 | | } // anonymous namespace |
150 | | |
151 | | /************************************************************** |
152 | | * IndexBinaryHNSW implementation |
153 | | **************************************************************/ |
154 | | |
155 | 0 | IndexBinaryHNSW::IndexBinaryHNSW() { |
156 | 0 | is_trained = true; |
157 | 0 | } |
158 | | |
159 | | IndexBinaryHNSW::IndexBinaryHNSW(int d, int M) |
160 | 0 | : IndexBinary(d), |
161 | 0 | hnsw(M), |
162 | 0 | own_fields(true), |
163 | 0 | storage(new IndexBinaryFlat(d)) { |
164 | 0 | is_trained = true; |
165 | 0 | } |
166 | | |
167 | | IndexBinaryHNSW::IndexBinaryHNSW(IndexBinary* storage, int M) |
168 | 0 | : IndexBinary(storage->d), |
169 | 0 | hnsw(M), |
170 | 0 | own_fields(false), |
171 | 0 | storage(storage) { |
172 | 0 | is_trained = true; |
173 | 0 | } |
174 | | |
175 | 0 | IndexBinaryHNSW::~IndexBinaryHNSW() { |
176 | 0 | if (own_fields) { |
177 | 0 | delete storage; |
178 | 0 | } |
179 | 0 | } |
180 | | |
181 | 0 | void IndexBinaryHNSW::train(idx_t n, const uint8_t* x) { |
182 | | // hnsw structure does not require training |
183 | 0 | storage->train(n, x); |
184 | 0 | is_trained = true; |
185 | 0 | } |
186 | | |
187 | | void IndexBinaryHNSW::search( |
188 | | idx_t n, |
189 | | const uint8_t* x, |
190 | | idx_t k, |
191 | | int32_t* distances, |
192 | | idx_t* labels, |
193 | 0 | const SearchParameters* params) const { |
194 | 0 | FAISS_THROW_IF_NOT_MSG( |
195 | 0 | !params, "search params not supported for this index"); |
196 | 0 | FAISS_THROW_IF_NOT(k > 0); |
197 | | |
198 | | // we use the buffer for distances as float but convert them back |
199 | | // to int in the end |
200 | 0 | float* distances_f = (float*)distances; |
201 | |
|
202 | 0 | using RH = HeapBlockResultHandler<HNSW::C>; |
203 | 0 | RH bres(n, distances_f, labels, k); |
204 | |
|
205 | 0 | #pragma omp parallel |
206 | 0 | { |
207 | 0 | VisitedTable vt(ntotal); |
208 | 0 | std::unique_ptr<DistanceComputer> dis(get_distance_computer()); |
209 | 0 | RH::SingleResultHandler res(bres); |
210 | |
|
211 | 0 | #pragma omp for |
212 | 0 | for (idx_t i = 0; i < n; i++) { |
213 | 0 | res.begin(i); |
214 | 0 | dis->set_query((float*)(x + i * code_size)); |
215 | 0 | hnsw.search(*dis, res, vt); |
216 | 0 | res.end(); |
217 | 0 | } |
218 | 0 | } |
219 | |
|
220 | 0 | #pragma omp parallel for |
221 | 0 | for (int i = 0; i < n * k; ++i) { |
222 | 0 | distances[i] = std::round(distances_f[i]); |
223 | 0 | } |
224 | 0 | } |
225 | | |
226 | 0 | void IndexBinaryHNSW::add(idx_t n, const uint8_t* x) { |
227 | 0 | FAISS_THROW_IF_NOT(is_trained); |
228 | 0 | int n0 = ntotal; |
229 | 0 | storage->add(n, x); |
230 | 0 | ntotal = storage->ntotal; |
231 | |
|
232 | 0 | hnsw_add_vertices(*this, n0, n, x, verbose, hnsw.levels.size() == ntotal); |
233 | 0 | } |
234 | | |
235 | 0 | void IndexBinaryHNSW::reset() { |
236 | 0 | hnsw.reset(); |
237 | 0 | storage->reset(); |
238 | 0 | ntotal = 0; |
239 | 0 | } |
240 | | |
241 | 0 | void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t* recons) const { |
242 | 0 | storage->reconstruct(key, recons); |
243 | 0 | } |
244 | | |
245 | | namespace { |
246 | | |
247 | | template <class HammingComputer> |
248 | | struct FlatHammingDis : DistanceComputer { |
249 | | const int code_size; |
250 | | const uint8_t* b; |
251 | | size_t ndis; |
252 | | HammingComputer hc; |
253 | | |
254 | 0 | float operator()(idx_t i) override { |
255 | 0 | ndis++; |
256 | 0 | return hc.hamming(b + i * code_size); |
257 | 0 | } Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer4EEclEl Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer8EEclEl Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer16EEclEl Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer20EEclEl Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer32EEclEl Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer64EEclEl Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_22HammingComputerDefaultEEclEl |
258 | | |
259 | 0 | float symmetric_dis(idx_t i, idx_t j) override { |
260 | 0 | return HammingComputerDefault(b + j * code_size, code_size) |
261 | 0 | .hamming(b + i * code_size); |
262 | 0 | } Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer4EE13symmetric_disEll Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer8EE13symmetric_disEll Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer16EE13symmetric_disEll Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer20EE13symmetric_disEll Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer32EE13symmetric_disEll Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer64EE13symmetric_disEll Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_22HammingComputerDefaultEE13symmetric_disEll |
263 | | |
264 | | explicit FlatHammingDis(const IndexBinaryFlat& storage) |
265 | 0 | : code_size(storage.code_size), |
266 | 0 | b(storage.xb.data()), |
267 | 0 | ndis(0), |
268 | 0 | hc() {} Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer4EEC2ERKNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer8EEC2ERKNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer16EEC2ERKNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer20EEC2ERKNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer32EEC2ERKNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer64EEC2ERKNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_22HammingComputerDefaultEEC2ERKNS_15IndexBinaryFlatE |
269 | | |
270 | | // NOTE: Pointers are cast from float in order to reuse the floating-point |
271 | | // DistanceComputer. |
272 | 0 | void set_query(const float* x) override { |
273 | 0 | hc.set((uint8_t*)x, code_size); |
274 | 0 | } Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer4EE9set_queryEPKf Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer8EE9set_queryEPKf Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer16EE9set_queryEPKf Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer20EE9set_queryEPKf Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer32EE9set_queryEPKf Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer64EE9set_queryEPKf Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_22HammingComputerDefaultEE9set_queryEPKf |
275 | | |
276 | 0 | ~FlatHammingDis() override { |
277 | 0 | #pragma omp critical |
278 | 0 | { hnsw_stats.ndis += ndis; } |
279 | 0 | } Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer4EED2Ev Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_16HammingComputer8EED2Ev Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer16EED2Ev Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer20EED2Ev Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer32EED2Ev Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_17HammingComputer64EED2Ev Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_114FlatHammingDisINS_22HammingComputerDefaultEED2Ev |
280 | | }; |
281 | | |
282 | | struct BuildDistanceComputer { |
283 | | using T = DistanceComputer*; |
284 | | template <class HammingComputer> |
285 | 0 | DistanceComputer* f(IndexBinaryFlat* flat_storage) { |
286 | 0 | return new FlatHammingDis<HammingComputer>(*flat_storage); |
287 | 0 | } Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_121BuildDistanceComputer1fINS_16HammingComputer4EEEPNS_16DistanceComputerEPNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_121BuildDistanceComputer1fINS_16HammingComputer8EEEPNS_16DistanceComputerEPNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_121BuildDistanceComputer1fINS_17HammingComputer16EEEPNS_16DistanceComputerEPNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_121BuildDistanceComputer1fINS_17HammingComputer20EEEPNS_16DistanceComputerEPNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_121BuildDistanceComputer1fINS_17HammingComputer32EEEPNS_16DistanceComputerEPNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_121BuildDistanceComputer1fINS_17HammingComputer64EEEPNS_16DistanceComputerEPNS_15IndexBinaryFlatE Unexecuted instantiation: IndexBinaryHNSW.cpp:_ZN5faiss12_GLOBAL__N_121BuildDistanceComputer1fINS_22HammingComputerDefaultEEEPNS_16DistanceComputerEPNS_15IndexBinaryFlatE |
288 | | }; |
289 | | |
290 | | } // namespace |
291 | | |
292 | 0 | DistanceComputer* IndexBinaryHNSW::get_distance_computer() const { |
293 | 0 | IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage); |
294 | 0 | FAISS_ASSERT(flat_storage != nullptr); |
295 | 0 | BuildDistanceComputer bd; |
296 | 0 | return dispatch_HammingComputer(code_size, bd, flat_storage); |
297 | 0 | } |
298 | | |
299 | | } // namespace faiss |