/root/doris/contrib/faiss/faiss/IndexFastScan.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <faiss/IndexFastScan.h> |
9 | | |
10 | | #include <cassert> |
11 | | #include <climits> |
12 | | #include <memory> |
13 | | |
14 | | #include <omp.h> |
15 | | |
16 | | #include <faiss/impl/FaissAssert.h> |
17 | | #include <faiss/impl/IDSelector.h> |
18 | | #include <faiss/impl/LookupTableScaler.h> |
19 | | #include <faiss/impl/ResultHandler.h> |
20 | | #include <faiss/utils/hamming.h> |
21 | | |
22 | | #include <faiss/impl/pq4_fast_scan.h> |
23 | | #include <faiss/impl/simd_result_handlers.h> |
24 | | #include <faiss/utils/quantize_lut.h> |
25 | | |
26 | | namespace faiss { |
27 | | |
28 | | using namespace simd_result_handlers; |
29 | | |
/**
 * Round `a` up to the nearest multiple of `b`.
 *
 * @param a value to round
 * @param b granularity; when 0 there is no multiple to round to, so `a` is
 *          returned unchanged instead of dividing by zero
 * @return the smallest multiple of `b` that is >= `a` (or `a` when b == 0)
 */
inline size_t roundup(size_t a, size_t b) {
    if (b == 0) {
        return a;
    }
    // Work from the remainder so that a value that is already a multiple is
    // returned as-is without forming the intermediate a + b - 1.
    const size_t rem = a % b;
    return rem == 0 ? a : a + (b - rem);
}
33 | | |
34 | | void IndexFastScan::init_fastscan( |
35 | | int d, |
36 | | size_t M_init, |
37 | | size_t nbits_init, |
38 | | MetricType metric, |
39 | 0 | int bbs) { |
40 | 0 | FAISS_THROW_IF_NOT(nbits_init == 4); |
41 | 0 | FAISS_THROW_IF_NOT(bbs % 32 == 0); |
42 | 0 | this->d = d; |
43 | 0 | this->M = M_init; |
44 | 0 | this->nbits = nbits_init; |
45 | 0 | this->metric_type = metric; |
46 | 0 | this->bbs = bbs; |
47 | 0 | ksub = (1 << nbits_init); |
48 | |
|
49 | 0 | code_size = (M_init * nbits_init + 7) / 8; |
50 | 0 | ntotal = ntotal2 = 0; |
51 | 0 | M2 = roundup(M_init, 2); |
52 | 0 | is_trained = false; |
53 | 0 | } |
54 | | |
55 | | IndexFastScan::IndexFastScan() |
56 | 0 | : bbs(0), M(0), code_size(0), ntotal2(0), M2(0) {} |
57 | | |
58 | 0 | void IndexFastScan::reset() { |
59 | 0 | codes.resize(0); |
60 | 0 | ntotal = 0; |
61 | 0 | } |
62 | | |
63 | 0 | void IndexFastScan::add(idx_t n, const float* x) { |
64 | 0 | FAISS_THROW_IF_NOT(is_trained); |
65 | | |
66 | | // do some blocking to avoid excessive allocs |
67 | 0 | constexpr idx_t bs = 65536; |
68 | 0 | if (n > bs) { |
69 | 0 | for (idx_t i0 = 0; i0 < n; i0 += bs) { |
70 | 0 | idx_t i1 = std::min(n, i0 + bs); |
71 | 0 | if (verbose) { |
72 | 0 | printf("IndexFastScan::add %zd/%zd\n", size_t(i1), size_t(n)); |
73 | 0 | } |
74 | 0 | add(i1 - i0, x + i0 * d); |
75 | 0 | } |
76 | 0 | return; |
77 | 0 | } |
78 | 0 | InterruptCallback::check(); |
79 | |
|
80 | 0 | AlignedTable<uint8_t> tmp_codes(n * code_size); |
81 | 0 | compute_codes(tmp_codes.get(), n, x); |
82 | |
|
83 | 0 | ntotal2 = roundup(ntotal + n, bbs); |
84 | 0 | size_t new_size = ntotal2 * M2 / 2; // assume nbits = 4 |
85 | 0 | size_t old_size = codes.size(); |
86 | 0 | if (new_size > old_size) { |
87 | 0 | codes.resize(new_size); |
88 | 0 | memset(codes.get() + old_size, 0, new_size - old_size); |
89 | 0 | } |
90 | |
|
91 | 0 | pq4_pack_codes_range( |
92 | 0 | tmp_codes.get(), M, ntotal, ntotal + n, bbs, M2, codes.get()); |
93 | |
|
94 | 0 | ntotal += n; |
95 | 0 | } |
96 | | |
97 | 0 | CodePacker* IndexFastScan::get_CodePacker() const { |
98 | 0 | return new CodePackerPQ4(M, bbs); |
99 | 0 | } |
100 | | |
101 | 0 | size_t IndexFastScan::remove_ids(const IDSelector& sel) { |
102 | 0 | idx_t j = 0; |
103 | 0 | std::vector<uint8_t> buffer(code_size); |
104 | 0 | CodePackerPQ4 packer(M, bbs); |
105 | 0 | for (idx_t i = 0; i < ntotal; i++) { |
106 | 0 | if (sel.is_member(i)) { |
107 | | // should be removed |
108 | 0 | } else { |
109 | 0 | if (i > j) { |
110 | 0 | packer.unpack_1(codes.data(), i, buffer.data()); |
111 | 0 | packer.pack_1(buffer.data(), j, codes.data()); |
112 | 0 | } |
113 | 0 | j++; |
114 | 0 | } |
115 | 0 | } |
116 | 0 | size_t nremove = ntotal - j; |
117 | 0 | if (nremove > 0) { |
118 | 0 | ntotal = j; |
119 | 0 | ntotal2 = roundup(ntotal, bbs); |
120 | 0 | size_t new_size = ntotal2 * M2 / 2; |
121 | 0 | codes.resize(new_size); |
122 | 0 | } |
123 | 0 | return nremove; |
124 | 0 | } |
125 | | |
126 | 0 | void IndexFastScan::check_compatible_for_merge(const Index& otherIndex) const { |
127 | 0 | const IndexFastScan* other = |
128 | 0 | dynamic_cast<const IndexFastScan*>(&otherIndex); |
129 | 0 | FAISS_THROW_IF_NOT(other); |
130 | 0 | FAISS_THROW_IF_NOT(other->M == M); |
131 | 0 | FAISS_THROW_IF_NOT(other->bbs == bbs); |
132 | 0 | FAISS_THROW_IF_NOT(other->d == d); |
133 | 0 | FAISS_THROW_IF_NOT(other->code_size == code_size); |
134 | 0 | FAISS_THROW_IF_NOT_MSG( |
135 | 0 | typeid(*this) == typeid(*other), |
136 | 0 | "can only merge indexes of the same type"); |
137 | 0 | } |
138 | | |
139 | 0 | void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) { |
140 | 0 | check_compatible_for_merge(otherIndex); |
141 | 0 | IndexFastScan* other = static_cast<IndexFastScan*>(&otherIndex); |
142 | 0 | ntotal2 = roundup(ntotal + other->ntotal, bbs); |
143 | 0 | codes.resize(ntotal2 * M2 / 2); |
144 | 0 | std::vector<uint8_t> buffer(code_size); |
145 | 0 | CodePackerPQ4 packer(M, bbs); |
146 | |
|
147 | 0 | for (int i = 0; i < other->ntotal; i++) { |
148 | 0 | packer.unpack_1(other->codes.data(), i, buffer.data()); |
149 | 0 | packer.pack_1(buffer.data(), ntotal + i, codes.data()); |
150 | 0 | } |
151 | 0 | ntotal += other->ntotal; |
152 | 0 | other->reset(); |
153 | 0 | } |
154 | | |
155 | | namespace { |
156 | | |
157 | | template <class C, typename dis_t> |
158 | | void estimators_from_tables_generic( |
159 | | const IndexFastScan& index, |
160 | | const uint8_t* codes, |
161 | | size_t ncodes, |
162 | | const dis_t* dis_table, |
163 | | size_t k, |
164 | | typename C::T* heap_dis, |
165 | | int64_t* heap_ids, |
166 | 0 | const NormTableScaler* scaler) { |
167 | 0 | using accu_t = typename C::T; |
168 | |
|
169 | 0 | for (size_t j = 0; j < ncodes; ++j) { |
170 | 0 | BitstringReader bsr(codes + j * index.code_size, index.code_size); |
171 | 0 | accu_t dis = 0; |
172 | 0 | const dis_t* dt = dis_table; |
173 | 0 | int nscale = scaler ? scaler->nscale : 0; |
174 | |
|
175 | 0 | for (size_t m = 0; m < index.M - nscale; m++) { |
176 | 0 | uint64_t c = bsr.read(index.nbits); |
177 | 0 | dis += dt[c]; |
178 | 0 | dt += index.ksub; |
179 | 0 | } |
180 | |
|
181 | 0 | if (nscale) { |
182 | 0 | for (size_t m = 0; m < nscale; m++) { |
183 | 0 | uint64_t c = bsr.read(index.nbits); |
184 | 0 | dis += scaler->scale_one(dt[c]); |
185 | 0 | dt += index.ksub; |
186 | 0 | } |
187 | 0 | } |
188 | |
|
189 | 0 | if (C::cmp(heap_dis[0], dis)) { |
190 | 0 | heap_pop<C>(k, heap_dis, heap_ids); |
191 | 0 | heap_push<C>(k, heap_dis, heap_ids, dis, j); |
192 | 0 | } |
193 | 0 | } |
194 | 0 | } Unexecuted instantiation: IndexFastScan.cpp:_ZN5faiss12_GLOBAL__N_130estimators_from_tables_genericINS_4CMaxIflEEfEEvRKNS_13IndexFastScanEPKhmPKT0_mPNT_1TEPlPKNS_15NormTableScalerE Unexecuted instantiation: IndexFastScan.cpp:_ZN5faiss12_GLOBAL__N_130estimators_from_tables_genericINS_4CMinIflEEfEEvRKNS_13IndexFastScanEPKhmPKT0_mPNT_1TEPlPKNS_15NormTableScalerE |
195 | | |
196 | | template <class C> |
197 | | ResultHandlerCompare<C, false>* make_knn_handler( |
198 | | int impl, |
199 | | idx_t n, |
200 | | idx_t k, |
201 | | size_t ntotal, |
202 | | float* distances, |
203 | | idx_t* labels, |
204 | 0 | const IDSelector* sel = nullptr) { |
205 | 0 | using HeapHC = HeapHandler<C, false>; |
206 | 0 | using ReservoirHC = ReservoirHandler<C, false>; |
207 | 0 | using SingleResultHC = SingleResultHandler<C, false>; |
208 | |
|
209 | 0 | if (k == 1) { |
210 | 0 | return new SingleResultHC(n, ntotal, distances, labels, sel); |
211 | 0 | } else if (impl % 2 == 0) { |
212 | 0 | return new HeapHC(n, ntotal, k, distances, labels, sel); |
213 | 0 | } else /* if (impl % 2 == 1) */ { |
214 | 0 | return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel); |
215 | 0 | } |
216 | 0 | } Unexecuted instantiation: IndexFastScan.cpp:_ZN5faiss12_GLOBAL__N_116make_knn_handlerINS_4CMaxItiEEEEPNS_20simd_result_handlers20ResultHandlerCompareIT_Lb0EEEillmPfPlPKNS_10IDSelectorE Unexecuted instantiation: IndexFastScan.cpp:_ZN5faiss12_GLOBAL__N_116make_knn_handlerINS_4CMinItiEEEEPNS_20simd_result_handlers20ResultHandlerCompareIT_Lb0EEEillmPfPlPKNS_10IDSelectorE |
217 | | |
218 | | } // anonymous namespace |
219 | | |
220 | | using namespace quantize_lut; |
221 | | |
222 | | void IndexFastScan::compute_quantized_LUT( |
223 | | idx_t n, |
224 | | const float* x, |
225 | | uint8_t* lut, |
226 | 0 | float* normalizers) const { |
227 | 0 | size_t dim12 = ksub * M; |
228 | 0 | std::unique_ptr<float[]> dis_tables(new float[n * dim12]); |
229 | 0 | compute_float_LUT(dis_tables.get(), n, x); |
230 | |
|
231 | 0 | for (uint64_t i = 0; i < n; i++) { |
232 | 0 | round_uint8_per_column( |
233 | 0 | dis_tables.get() + i * dim12, |
234 | 0 | M, |
235 | 0 | ksub, |
236 | 0 | &normalizers[2 * i], |
237 | 0 | &normalizers[2 * i + 1]); |
238 | 0 | } |
239 | |
|
240 | 0 | for (uint64_t i = 0; i < n; i++) { |
241 | 0 | const float* t_in = dis_tables.get() + i * dim12; |
242 | 0 | uint8_t* t_out = lut + i * M2 * ksub; |
243 | |
|
244 | 0 | for (int j = 0; j < dim12; j++) { |
245 | 0 | t_out[j] = int(t_in[j]); |
246 | 0 | } |
247 | 0 | memset(t_out + dim12, 0, (M2 - M) * ksub); |
248 | 0 | } |
249 | 0 | } |
250 | | |
251 | | /****************************************************************************** |
252 | | * Search driver routine |
253 | | ******************************************************************************/ |
254 | | |
255 | | void IndexFastScan::search( |
256 | | idx_t n, |
257 | | const float* x, |
258 | | idx_t k, |
259 | | float* distances, |
260 | | idx_t* labels, |
261 | 0 | const SearchParameters* params) const { |
262 | 0 | FAISS_THROW_IF_NOT_MSG( |
263 | 0 | !params, "search params not supported for this index"); |
264 | 0 | FAISS_THROW_IF_NOT(k > 0); |
265 | | |
266 | 0 | if (metric_type == METRIC_L2) { |
267 | 0 | search_dispatch_implem<true>(n, x, k, distances, labels, nullptr); |
268 | 0 | } else { |
269 | 0 | search_dispatch_implem<false>(n, x, k, distances, labels, nullptr); |
270 | 0 | } |
271 | 0 | } |
272 | | |
273 | | template <bool is_max> |
274 | | void IndexFastScan::search_dispatch_implem( |
275 | | idx_t n, |
276 | | const float* x, |
277 | | idx_t k, |
278 | | float* distances, |
279 | | idx_t* labels, |
280 | 0 | const NormTableScaler* scaler) const { |
281 | 0 | using Cfloat = typename std::conditional< |
282 | 0 | is_max, |
283 | 0 | CMax<float, int64_t>, |
284 | 0 | CMin<float, int64_t>>::type; |
285 | |
|
286 | 0 | using C = typename std:: |
287 | 0 | conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type; |
288 | |
|
289 | 0 | if (n == 0) { |
290 | 0 | return; |
291 | 0 | } |
292 | | |
293 | | // actual implementation used |
294 | 0 | int impl = implem; |
295 | |
|
296 | 0 | if (impl == 0) { |
297 | 0 | if (bbs == 32) { |
298 | 0 | impl = 12; |
299 | 0 | } else { |
300 | 0 | impl = 14; |
301 | 0 | } |
302 | 0 | if (k > 20) { |
303 | 0 | impl++; |
304 | 0 | } |
305 | 0 | } |
306 | |
|
307 | 0 | if (implem == 1) { |
308 | 0 | FAISS_THROW_MSG("not implemented"); |
309 | 0 | } else if (implem == 2 || implem == 3 || implem == 4) { |
310 | 0 | FAISS_THROW_IF_NOT(orig_codes != nullptr); |
311 | 0 | search_implem_234<Cfloat>(n, x, k, distances, labels, scaler); |
312 | 0 | } else if (impl >= 12 && impl <= 15) { |
313 | 0 | FAISS_THROW_IF_NOT(ntotal < INT_MAX); |
314 | 0 | int nt = std::min(omp_get_max_threads(), int(n)); |
315 | 0 | if (nt < 2) { |
316 | 0 | if (impl == 12 || impl == 13) { |
317 | 0 | search_implem_12<C>(n, x, k, distances, labels, impl, scaler); |
318 | 0 | } else { |
319 | 0 | search_implem_14<C>(n, x, k, distances, labels, impl, scaler); |
320 | 0 | } |
321 | 0 | } else { |
322 | | // explicitly slice over threads |
323 | 0 | #pragma omp parallel for num_threads(nt) |
324 | 0 | for (int slice = 0; slice < nt; slice++) { |
325 | 0 | idx_t i0 = n * slice / nt; |
326 | 0 | idx_t i1 = n * (slice + 1) / nt; |
327 | 0 | float* dis_i = distances + i0 * k; |
328 | 0 | idx_t* lab_i = labels + i0 * k; |
329 | 0 | if (impl == 12 || impl == 13) { |
330 | 0 | search_implem_12<C>( |
331 | 0 | i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler); |
332 | 0 | } else { |
333 | 0 | search_implem_14<C>( |
334 | 0 | i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler); |
335 | 0 | } |
336 | 0 | } Unexecuted instantiation: IndexFastScan.cpp:_ZNK5faiss13IndexFastScan22search_dispatch_implemILb1EEEvlPKflPfPlPKNS_15NormTableScalerE.omp_outlined_debug__ Unexecuted instantiation: IndexFastScan.cpp:_ZNK5faiss13IndexFastScan22search_dispatch_implemILb0EEEvlPKflPfPlPKNS_15NormTableScalerE.omp_outlined_debug__ |
337 | 0 | } |
338 | 0 | } else { |
339 | 0 | FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl); |
340 | 0 | } |
341 | 0 | } Unexecuted instantiation: _ZNK5faiss13IndexFastScan22search_dispatch_implemILb1EEEvlPKflPfPlPKNS_15NormTableScalerE Unexecuted instantiation: _ZNK5faiss13IndexFastScan22search_dispatch_implemILb0EEEvlPKflPfPlPKNS_15NormTableScalerE |
342 | | |
343 | | template <class Cfloat> |
344 | | void IndexFastScan::search_implem_234( |
345 | | idx_t n, |
346 | | const float* x, |
347 | | idx_t k, |
348 | | float* distances, |
349 | | idx_t* labels, |
350 | 0 | const NormTableScaler* scaler) const { |
351 | 0 | FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4); |
352 | | |
353 | 0 | const size_t dim12 = ksub * M; |
354 | 0 | std::unique_ptr<float[]> dis_tables(new float[n * dim12]); |
355 | 0 | compute_float_LUT(dis_tables.get(), n, x); |
356 | |
|
357 | 0 | std::vector<float> normalizers(n * 2); |
358 | |
|
359 | 0 | if (implem == 2) { |
360 | | // default float |
361 | 0 | } else if (implem == 3 || implem == 4) { |
362 | 0 | for (uint64_t i = 0; i < n; i++) { |
363 | 0 | round_uint8_per_column( |
364 | 0 | dis_tables.get() + i * dim12, |
365 | 0 | M, |
366 | 0 | ksub, |
367 | 0 | &normalizers[2 * i], |
368 | 0 | &normalizers[2 * i + 1]); |
369 | 0 | } |
370 | 0 | } |
371 | |
|
372 | 0 | #pragma omp parallel for if (n > 1000) |
373 | 0 | for (int64_t i = 0; i < n; i++) { |
374 | 0 | int64_t* heap_ids = labels + i * k; |
375 | 0 | float* heap_dis = distances + i * k; |
376 | |
|
377 | 0 | heap_heapify<Cfloat>(k, heap_dis, heap_ids); |
378 | |
|
379 | 0 | estimators_from_tables_generic<Cfloat>( |
380 | 0 | *this, |
381 | 0 | orig_codes, |
382 | 0 | ntotal, |
383 | 0 | dis_tables.get() + i * dim12, |
384 | 0 | k, |
385 | 0 | heap_dis, |
386 | 0 | heap_ids, |
387 | 0 | scaler); |
388 | |
|
389 | 0 | heap_reorder<Cfloat>(k, heap_dis, heap_ids); |
390 | |
|
391 | 0 | if (implem == 4) { |
392 | 0 | float a = normalizers[2 * i]; |
393 | 0 | float b = normalizers[2 * i + 1]; |
394 | |
|
395 | 0 | for (int j = 0; j < k; j++) { |
396 | 0 | heap_dis[j] = heap_dis[j] / a + b; |
397 | 0 | } |
398 | 0 | } |
399 | 0 | } Unexecuted instantiation: IndexFastScan.cpp:_ZNK5faiss13IndexFastScan17search_implem_234INS_4CMaxIflEEEEvlPKflPfPlPKNS_15NormTableScalerE.omp_outlined_debug__ Unexecuted instantiation: IndexFastScan.cpp:_ZNK5faiss13IndexFastScan17search_implem_234INS_4CMinIflEEEEvlPKflPfPlPKNS_15NormTableScalerE.omp_outlined_debug__ |
400 | 0 | } Unexecuted instantiation: _ZNK5faiss13IndexFastScan17search_implem_234INS_4CMaxIflEEEEvlPKflPfPlPKNS_15NormTableScalerE Unexecuted instantiation: _ZNK5faiss13IndexFastScan17search_implem_234INS_4CMinIflEEEEvlPKflPfPlPKNS_15NormTableScalerE |
401 | | |
402 | | template <class C> |
403 | | void IndexFastScan::search_implem_12( |
404 | | idx_t n, |
405 | | const float* x, |
406 | | idx_t k, |
407 | | float* distances, |
408 | | idx_t* labels, |
409 | | int impl, |
410 | 0 | const NormTableScaler* scaler) const { |
411 | 0 | using RH = ResultHandlerCompare<C, false>; |
412 | 0 | FAISS_THROW_IF_NOT(bbs == 32); |
413 | | |
414 | | // handle qbs2 blocking by recursive call |
415 | 0 | int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs); |
416 | 0 | if (n > qbs2) { |
417 | 0 | for (int64_t i0 = 0; i0 < n; i0 += qbs2) { |
418 | 0 | int64_t i1 = std::min(i0 + qbs2, n); |
419 | 0 | search_implem_12<C>( |
420 | 0 | i1 - i0, |
421 | 0 | x + d * i0, |
422 | 0 | k, |
423 | 0 | distances + i0 * k, |
424 | 0 | labels + i0 * k, |
425 | 0 | impl, |
426 | 0 | scaler); |
427 | 0 | } |
428 | 0 | return; |
429 | 0 | } |
430 | | |
431 | 0 | size_t dim12 = ksub * M2; |
432 | 0 | AlignedTable<uint8_t> quantized_dis_tables(n * dim12); |
433 | 0 | std::unique_ptr<float[]> normalizers(new float[2 * n]); |
434 | |
|
435 | 0 | if (skip & 1) { |
436 | 0 | quantized_dis_tables.clear(); |
437 | 0 | } else { |
438 | 0 | compute_quantized_LUT( |
439 | 0 | n, x, quantized_dis_tables.get(), normalizers.get()); |
440 | 0 | } |
441 | |
|
442 | 0 | AlignedTable<uint8_t> LUT(n * dim12); |
443 | | |
444 | | // block sizes are encoded in qbs, 4 bits at a time |
445 | | |
446 | | // caution: we override an object field |
447 | 0 | int qbs = this->qbs; |
448 | |
|
449 | 0 | if (n != pq4_qbs_to_nq(qbs)) { |
450 | 0 | qbs = pq4_preferred_qbs(n); |
451 | 0 | } |
452 | |
|
453 | 0 | int LUT_nq = |
454 | 0 | pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get()); |
455 | 0 | FAISS_THROW_IF_NOT(LUT_nq == n); |
456 | | |
457 | 0 | std::unique_ptr<RH> handler( |
458 | 0 | make_knn_handler<C>(impl, n, k, ntotal, distances, labels)); |
459 | 0 | handler->disable = bool(skip & 2); |
460 | 0 | handler->normalizers = normalizers.get(); |
461 | |
|
462 | 0 | if (skip & 4) { |
463 | | // pass |
464 | 0 | } else { |
465 | 0 | pq4_accumulate_loop_qbs( |
466 | 0 | qbs, |
467 | 0 | ntotal2, |
468 | 0 | M2, |
469 | 0 | codes.get(), |
470 | 0 | LUT.get(), |
471 | 0 | *handler.get(), |
472 | 0 | scaler); |
473 | 0 | } |
474 | 0 | if (!(skip & 8)) { |
475 | 0 | handler->end(); |
476 | 0 | } |
477 | 0 | } Unexecuted instantiation: _ZNK5faiss13IndexFastScan16search_implem_12INS_4CMaxItiEEEEvlPKflPfPliPKNS_15NormTableScalerE Unexecuted instantiation: _ZNK5faiss13IndexFastScan16search_implem_12INS_4CMinItiEEEEvlPKflPfPliPKNS_15NormTableScalerE |
478 | | |
// Definition of the global FastScan_stats object (type FastScanStats).
// Nothing in the code visible in this file reads or updates it.
// NOTE(review): presumably it is declared in a header and used by other
// translation units -- confirm before changing.
FastScanStats FastScan_stats;
480 | | |
481 | | template <class C> |
482 | | void IndexFastScan::search_implem_14( |
483 | | idx_t n, |
484 | | const float* x, |
485 | | idx_t k, |
486 | | float* distances, |
487 | | idx_t* labels, |
488 | | int impl, |
489 | 0 | const NormTableScaler* scaler) const { |
490 | 0 | using RH = ResultHandlerCompare<C, false>; |
491 | 0 | FAISS_THROW_IF_NOT(bbs % 32 == 0); |
492 | | |
493 | 0 | int qbs2 = qbs == 0 ? 4 : qbs; |
494 | | |
495 | | // handle qbs2 blocking by recursive call |
496 | 0 | if (n > qbs2) { |
497 | 0 | for (int64_t i0 = 0; i0 < n; i0 += qbs2) { |
498 | 0 | int64_t i1 = std::min(i0 + qbs2, n); |
499 | 0 | search_implem_14<C>( |
500 | 0 | i1 - i0, |
501 | 0 | x + d * i0, |
502 | 0 | k, |
503 | 0 | distances + i0 * k, |
504 | 0 | labels + i0 * k, |
505 | 0 | impl, |
506 | 0 | scaler); |
507 | 0 | } |
508 | 0 | return; |
509 | 0 | } |
510 | | |
511 | 0 | size_t dim12 = ksub * M2; |
512 | 0 | AlignedTable<uint8_t> quantized_dis_tables(n * dim12); |
513 | 0 | std::unique_ptr<float[]> normalizers(new float[2 * n]); |
514 | |
|
515 | 0 | if (skip & 1) { |
516 | 0 | quantized_dis_tables.clear(); |
517 | 0 | } else { |
518 | 0 | compute_quantized_LUT( |
519 | 0 | n, x, quantized_dis_tables.get(), normalizers.get()); |
520 | 0 | } |
521 | |
|
522 | 0 | AlignedTable<uint8_t> LUT(n * dim12); |
523 | 0 | pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get()); |
524 | |
|
525 | 0 | std::unique_ptr<RH> handler( |
526 | 0 | make_knn_handler<C>(impl, n, k, ntotal, distances, labels)); |
527 | 0 | handler->disable = bool(skip & 2); |
528 | 0 | handler->normalizers = normalizers.get(); |
529 | |
|
530 | 0 | if (skip & 4) { |
531 | | // pass |
532 | 0 | } else { |
533 | 0 | pq4_accumulate_loop( |
534 | 0 | n, |
535 | 0 | ntotal2, |
536 | 0 | bbs, |
537 | 0 | M2, |
538 | 0 | codes.get(), |
539 | 0 | LUT.get(), |
540 | 0 | *handler.get(), |
541 | 0 | scaler); |
542 | 0 | } |
543 | 0 | if (!(skip & 8)) { |
544 | 0 | handler->end(); |
545 | 0 | } |
546 | 0 | } Unexecuted instantiation: _ZNK5faiss13IndexFastScan16search_implem_14INS_4CMaxItiEEEEvlPKflPfPliPKNS_15NormTableScalerE Unexecuted instantiation: _ZNK5faiss13IndexFastScan16search_implem_14INS_4CMinItiEEEEvlPKflPfPliPKNS_15NormTableScalerE |
547 | | |
548 | | template void IndexFastScan::search_dispatch_implem<true>( |
549 | | idx_t n, |
550 | | const float* x, |
551 | | idx_t k, |
552 | | float* distances, |
553 | | idx_t* labels, |
554 | | const NormTableScaler* scaler) const; |
555 | | |
556 | | template void IndexFastScan::search_dispatch_implem<false>( |
557 | | idx_t n, |
558 | | const float* x, |
559 | | idx_t k, |
560 | | float* distances, |
561 | | idx_t* labels, |
562 | | const NormTableScaler* scaler) const; |
563 | | |
564 | 0 | void IndexFastScan::reconstruct(idx_t key, float* recons) const { |
565 | 0 | std::vector<uint8_t> code(code_size, 0); |
566 | 0 | BitstringWriter bsw(code.data(), code_size); |
567 | 0 | for (size_t m = 0; m < M; m++) { |
568 | 0 | uint8_t c = pq4_get_packed_element(codes.data(), bbs, M2, key, m); |
569 | 0 | bsw.write(c, nbits); |
570 | 0 | } |
571 | 0 | sa_decode(1, code.data(), recons); |
572 | 0 | } |
573 | | |
574 | | } // namespace faiss |