/root/doris/contrib/faiss/faiss/utils/utils.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | /* |
11 | | * A few utility functions for similarity search: |
12 | | * - optimized exhaustive distance and knn search functions |
13 | | * - some functions reimplemented from torch for speed |
14 | | */ |
15 | | |
16 | | #ifndef FAISS_utils_h |
17 | | #define FAISS_utils_h |
18 | | |
19 | | #include <stdint.h> |
20 | | #include <set> |
21 | | #include <string> |
22 | | #include <vector> |
23 | | |
24 | | #include <faiss/impl/platform_macros.h> |
25 | | #include <faiss/utils/Heap.h> |
26 | | |
27 | | namespace faiss { |
28 | | |
29 | | /**************************************************************************** |
30 | | * Get compile specific variables |
31 | | ***************************************************************************/ |
32 | | |
33 | | /// get compile options |
34 | | std::string get_compile_options(); |
35 | | |
36 | | /************************************************** |
37 | | * Get some stats about the system |
38 | | **************************************************/ |
39 | | |
40 | | // Expose Faiss version as a string |
41 | | std::string get_version(); |
42 | | |
43 | | /// ms elapsed since some arbitrary epoch |
44 | | double getmillisecs(); |
45 | | |
46 | | /// get current RSS usage in kB |
47 | | size_t get_mem_usage_kb(); |
48 | | |
49 | | uint64_t get_cycles(); |
50 | | |
51 | | /*************************************************************************** |
52 | | * Misc matrix and vector manipulation functions |
53 | | ***************************************************************************/ |
54 | | |
55 | | /* perform a reflection (not an efficient implementation, just for tests) */ |
56 | | void reflection(const float* u, float* x, size_t n, size_t d, size_t nu); |
57 | | |
58 | | /** compute the Q of the QR decomposition for m > n |
59 | | * @param a size n * m: input matrix and output Q |
60 | | */ |
61 | | void matrix_qr(int m, int n, float* a); |
62 | | |
63 | | /** distances are assumed to be sorted. Sorts indices that have the same distance */ |
64 | | void ranklist_handle_ties(int k, int64_t* idx, const float* dis); |
65 | | |
66 | | /** count the number of common elements between v1 and v2 |
67 | | * algorithm = sorting + bisection to avoid double-counting duplicates |
68 | | */ |
69 | | size_t ranklist_intersection_size( |
70 | | size_t k1, |
71 | | const int64_t* v1, |
72 | | size_t k2, |
73 | | const int64_t* v2); |
74 | | |
75 | | /** merge a result table into another one |
76 | | * |
77 | | * @param I0, D0 first result table, size (n, k) |
78 | | * @param I1, D1 second result table, size (n, k) |
79 | | * @param keep_min if true, keep min values, otherwise keep max |
80 | | * @param translation add this value to all I1's indexes |
81 | | * @return nb of values that were taken from the second table |
82 | | */ |
83 | | size_t merge_result_table_with( |
84 | | size_t n, |
85 | | size_t k, |
86 | | int64_t* I0, |
87 | | float* D0, |
88 | | const int64_t* I1, |
89 | | const float* D1, |
90 | | bool keep_min = true, |
91 | | int64_t translation = 0); |
92 | | |
93 | | /// a balanced assignment has an IF of 1, a completely unbalanced assignment has |
94 | | /// an IF of k. |
95 | | double imbalance_factor(int64_t n, int k, const int64_t* assign); |
96 | | |
97 | | /// same, takes a histogram as input |
98 | | double imbalance_factor(int k, const int64_t* hist); |
99 | | |
100 | | /// compute histogram on v |
101 | | int ivec_hist(size_t n, const int* v, int vmax, int* hist); |
102 | | |
103 | | /** Compute histogram of bits on a code array |
104 | | * |
105 | | * @param codes size(n, nbits / 8) |
106 | | * @param hist size(nbits): nb of 1s in the array of codes |
107 | | */ |
108 | | void bincode_hist(size_t n, size_t nbits, const uint8_t* codes, int* hist); |
109 | | |
110 | | /// compute a checksum on a table. |
111 | | uint64_t ivec_checksum(size_t n, const int32_t* a); |
112 | | |
113 | | /// compute a checksum on a table. |
114 | | uint64_t bvec_checksum(size_t n, const uint8_t* a); |
115 | | |
116 | | /** compute checksums for the rows of a matrix |
117 | | * |
118 | | * @param n number of rows |
119 | | * @param d size per row |
120 | | * @param a matrix to handle, size n * d |
121 | | * @param cs output checksums, size n |
122 | | */ |
123 | | void bvecs_checksum(size_t n, size_t d, const uint8_t* a, uint64_t* cs); |
124 | | |
125 | | /** randomly subsamples a set of vectors if there are too many of them |
126 | | * |
127 | | * @param d dimension of the vectors |
128 | | * @param n on input: nb of input vectors, output: nb of output vectors |
129 | | * @param nmax max nb of vectors to keep |
130 | | * @param x input array, size *n-by-d |
131 | | * @param seed random seed to use for sampling |
132 | | * @return x or an array allocated with new [] with *n vectors |
133 | | */ |
134 | | const float* fvecs_maybe_subsample( |
135 | | size_t d, |
136 | | size_t* n, |
137 | | size_t nmax, |
138 | | const float* x, |
139 | | bool verbose = false, |
140 | | int64_t seed = 1234); |
141 | | |
142 | | /** Convert binary vector to +1/-1 valued float vector. |
143 | | * |
144 | | * @param d dimension of the vector (multiple of 8) |
145 | | * @param x_in input binary vector (uint8_t table of size d / 8) |
146 | | * @param x_out output float vector (float table of size d) |
147 | | */ |
148 | | void binary_to_real(size_t d, const uint8_t* x_in, float* x_out); |
149 | | |
150 | | /** Convert float vector to binary vector. Components > 0 are converted to 1, |
151 | | * others to 0. |
152 | | * |
153 | | * @param d dimension of the vector (multiple of 8) |
154 | | * @param x_in input float vector (float table of size d) |
155 | | * @param x_out output binary vector (uint8_t table of size d / 8) |
156 | | */ |
157 | | void real_to_binary(size_t d, const float* x_in, uint8_t* x_out); |
158 | | |
159 | | /** A reasonable hashing function */ |
160 | | uint64_t hash_bytes(const uint8_t* bytes, int64_t n); |
161 | | |
162 | | /** Whether OpenMP annotations were respected. */ |
163 | | bool check_openmp(); |
164 | | |
165 | | /** This class is used to combine range and knn search results |
166 | | * in contrib.exhaustive_search.range_search_gpu */ |
167 | | |
168 | | template <typename T> |
169 | | struct CombinerRangeKNN { |
170 | | int64_t nq; ///< nb of queries |
171 | | size_t k; ///< number of neighbors for the knn search part |
172 | | T r2; ///< range search radius |
173 | | bool keep_max; ///< whether to keep max values instead of min. |
174 | | |
175 | | CombinerRangeKNN(int64_t nq, size_t k, T r2, bool keep_max) |
176 | 0 | : nq(nq), k(k), r2(r2), keep_max(keep_max) {}Unexecuted instantiation: _ZN5faiss16CombinerRangeKNNIfEC2Elmfb Unexecuted instantiation: _ZN5faiss16CombinerRangeKNNIsEC2Elmsb |
177 | | |
178 | | /// Knn search results |
179 | | const int64_t* I = nullptr; ///< size nq * k |
180 | | const T* D = nullptr; ///< size nq * k |
181 | | |
182 | | /// optional: range search results (ignored if mask is NULL) |
183 | | const bool* mask = |
184 | | nullptr; ///< mask for where knn results are valid, size nq |
185 | | // range search results for the remaining entries, with nrange = sum(mask) |
186 | | const int64_t* lim_remain = nullptr; ///< size nrange + 1 |
187 | | const T* D_remain = nullptr; ///< size lim_remain[nrange] |
188 | | const int64_t* I_remain = nullptr; ///< size lim_remain[nrange] |
189 | | |
190 | | const int64_t* L_res = nullptr; ///< size nq + 1 |
191 | | /// Phase 1: compute sizes into limits array (of size nq + 1); NOTE(review): the parameter below shadows the member L_res |
192 | | void compute_sizes(int64_t* L_res); |
193 | | |
194 | | /// Phase 2: caller allocates D_res and I_res (size L_res[nq]) |
195 | | /// Phase 3: fill in D_res and I_res |
196 | | void write_result(T* D_res, int64_t* I_res); |
197 | | }; |
198 | | /** Set of codes, each stored as a std::vector<uint8_t> inside a std::set */ |
199 | | struct CodeSet { |
200 | | size_t d; /**< set by the constructor; NOTE(review): presumably the size of one code in bytes -- confirm */ |
201 | | std::set<std::vector<uint8_t>> s; /**< the stored codes */ |
202 | | |
203 | 0 | explicit CodeSet(size_t d) : d(d) {} |
204 | | void insert(size_t n, const uint8_t* codes, bool* inserted); /**< NOTE(review): presumably inserts n codes and reports per code in `inserted` whether it was new -- confirm against the implementation */ |
205 | | }; |
206 | | |
207 | | } // namespace faiss |
208 | | |
209 | | #endif /* FAISS_utils_h */ |