/root/doris/contrib/faiss/faiss/IndexFastScan.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
10 | | #include <faiss/Index.h> |
11 | | #include <faiss/utils/AlignedTable.h> |
12 | | |
13 | | namespace faiss { |
14 | | |
15 | | struct CodePacker; |
16 | | struct NormTableScaler; |
17 | | |
18 | | /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now. |
19 | | * |
20 | | * The codes are not stored sequentially but grouped in blocks of size bbs. |
21 | | * This makes it possible to compute distances quickly with SIMD instructions. |
22 | | * The trailing codes (padding codes that are added to complete the last code) |
23 | | * are garbage. |
24 | | * |
25 | | * Implementations: |
26 | | * 12: blocked loop with internal loop on Q with qbs |
27 | | * 13: same with reservoir accumulator to store results |
28 | | * 14: no qbs with heap accumulator |
29 | | * 15: no qbs with reservoir accumulator |
30 | | */ |
31 | | struct IndexFastScan : Index { |
32 | | // implementation to select |
33 | | int implem = 0; |
34 | | // skip some parts of the computation (for timing) |
35 | | int skip = 0; |
36 | | |
37 | | // size of the kernel |
38 | | int bbs; // set at build time |
39 | | int qbs = 0; // query block size 0 = use default |
40 | | |
41 | | // vector quantizer |
42 | | size_t M; |
43 | | size_t nbits; |
44 | | size_t ksub; |
45 | | size_t code_size; |
46 | | |
47 | | // packed version of the codes |
48 | | size_t ntotal2; |
49 | | size_t M2; |
50 | | |
51 | | AlignedTable<uint8_t> codes; |
52 | | |
53 | | // this is for testing purposes only |
54 | | // (set when initialized by IndexPQ or IndexAQ) |
55 | | const uint8_t* orig_codes = nullptr; |
56 | | |
57 | | void init_fastscan( |
58 | | int d, |
59 | | size_t M, |
60 | | size_t nbits, |
61 | | MetricType metric, |
62 | | int bbs); |
63 | | |
64 | | IndexFastScan(); |
65 | | |
66 | | void reset() override; |
67 | | |
68 | | void search( |
69 | | idx_t n, |
70 | | const float* x, |
71 | | idx_t k, |
72 | | float* distances, |
73 | | idx_t* labels, |
74 | | const SearchParameters* params = nullptr) const override; |
75 | | |
76 | | void add(idx_t n, const float* x) override; |
77 | | |
78 | | virtual void compute_codes(uint8_t* codes, idx_t n, const float* x) |
79 | | const = 0; |
80 | | |
81 | | virtual void compute_float_LUT(float* lut, idx_t n, const float* x) |
82 | | const = 0; |
83 | | |
84 | | // called by search function |
85 | | void compute_quantized_LUT( |
86 | | idx_t n, |
87 | | const float* x, |
88 | | uint8_t* lut, |
89 | | float* normalizers) const; |
90 | | |
91 | | template <bool is_max> |
92 | | void search_dispatch_implem( |
93 | | idx_t n, |
94 | | const float* x, |
95 | | idx_t k, |
96 | | float* distances, |
97 | | idx_t* labels, |
98 | | const NormTableScaler* scaler) const; |
99 | | |
100 | | template <class Cfloat> |
101 | | void search_implem_234( |
102 | | idx_t n, |
103 | | const float* x, |
104 | | idx_t k, |
105 | | float* distances, |
106 | | idx_t* labels, |
107 | | const NormTableScaler* scaler) const; |
108 | | |
109 | | template <class C> |
110 | | void search_implem_12( |
111 | | idx_t n, |
112 | | const float* x, |
113 | | idx_t k, |
114 | | float* distances, |
115 | | idx_t* labels, |
116 | | int impl, |
117 | | const NormTableScaler* scaler) const; |
118 | | |
119 | | template <class C> |
120 | | void search_implem_14( |
121 | | idx_t n, |
122 | | const float* x, |
123 | | idx_t k, |
124 | | float* distances, |
125 | | idx_t* labels, |
126 | | int impl, |
127 | | const NormTableScaler* scaler) const; |
128 | | |
129 | | void reconstruct(idx_t key, float* recons) const override; |
130 | | size_t remove_ids(const IDSelector& sel) override; |
131 | | |
132 | | CodePacker* get_CodePacker() const; |
133 | | |
134 | | void merge_from(Index& otherIndex, idx_t add_id = 0) override; |
135 | | void check_compatible_for_merge(const Index& otherIndex) const override; |
136 | | |
137 | | /// standalone codes interface (but the codes are flattened) |
138 | 0 | size_t sa_code_size() const override { |
139 | 0 | return code_size; |
140 | 0 | } |
141 | | |
142 | 0 | void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override { |
143 | 0 | compute_codes(bytes, n, x); |
144 | 0 | } |
145 | | }; |
146 | | |
/// Global counters for the fast-scan code path.
/// NOTE(review): t0..t3 are presumably timing accumulators; what each one
/// measures is decided by the .cpp that updates them — confirm there.
///
/// Fields are zeroed with default member initializers instead of
/// memset(this, ...): this needs no <cstring> (previously only reached
/// transitively) and stays well-defined if a non-trivial member is ever added.
/// With a defaulted constructor and constant initializers, a global instance
/// is also eligible for constant (static) initialization.
struct FastScanStats {
    uint64_t t0 = 0;
    uint64_t t1 = 0;
    uint64_t t2 = 0;
    uint64_t t3 = 0;

    FastScanStats() = default;

    /// Reset every counter to zero.
    void reset() {
        t0 = 0;
        t1 = 0;
        t2 = 0;
        t3 = 0;
    }
};
156 | | |
157 | | FAISS_API extern FastScanStats FastScan_stats; |
158 | | |
159 | | } // namespace faiss |