/root/doris/contrib/faiss/faiss/impl/kmeans1d.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cstring> |
10 | | #include <functional> |
11 | | #include <numeric> |
12 | | #include <string> |
13 | | #include <unordered_map> |
14 | | #include <vector> |
15 | | |
16 | | #include <faiss/Index.h> |
17 | | #include <faiss/impl/FaissAssert.h> |
18 | | #include <faiss/impl/kmeans1d.h> |
19 | | |
20 | | namespace faiss { |
21 | | |
22 | | using LookUpFunc = std::function<float(idx_t, idx_t)>; |
23 | | |
24 | | void reduce( |
25 | | const std::vector<idx_t>& rows, |
26 | | const std::vector<idx_t>& input_cols, |
27 | | const LookUpFunc& lookup, |
28 | 0 | std::vector<idx_t>& output_cols) { |
29 | 0 | for (idx_t col : input_cols) { |
30 | 0 | while (!output_cols.empty()) { |
31 | 0 | idx_t row = rows[output_cols.size() - 1]; |
32 | 0 | float a = lookup(row, col); |
33 | 0 | float b = lookup(row, output_cols.back()); |
34 | 0 | if (a >= b) { // defeated |
35 | 0 | break; |
36 | 0 | } |
37 | 0 | output_cols.pop_back(); |
38 | 0 | } |
39 | 0 | if (output_cols.size() < rows.size()) { |
40 | 0 | output_cols.push_back(col); |
41 | 0 | } |
42 | 0 | } |
43 | 0 | } |
44 | | |
45 | | void interpolate( |
46 | | const std::vector<idx_t>& rows, |
47 | | const std::vector<idx_t>& cols, |
48 | | const LookUpFunc& lookup, |
49 | 0 | idx_t* argmins) { |
50 | 0 | std::unordered_map<idx_t, idx_t> idx_to_col; |
51 | 0 | for (idx_t idx = 0; idx < cols.size(); ++idx) { |
52 | 0 | idx_to_col[cols[idx]] = idx; |
53 | 0 | } |
54 | |
|
55 | 0 | idx_t start = 0; |
56 | 0 | for (idx_t r = 0; r < rows.size(); r += 2) { |
57 | 0 | idx_t row = rows[r]; |
58 | 0 | idx_t end = cols.size() - 1; |
59 | 0 | if (r < rows.size() - 1) { |
60 | 0 | idx_t idx = argmins[rows[r + 1]]; |
61 | 0 | end = idx_to_col[idx]; |
62 | 0 | } |
63 | 0 | idx_t argmin = cols[start]; |
64 | 0 | float min = lookup(row, argmin); |
65 | 0 | for (idx_t c = start + 1; c <= end; c++) { |
66 | 0 | float value = lookup(row, cols[c]); |
67 | 0 | if (value < min) { |
68 | 0 | argmin = cols[c]; |
69 | 0 | min = value; |
70 | 0 | } |
71 | 0 | } |
72 | 0 | argmins[row] = argmin; |
73 | 0 | start = end; |
74 | 0 | } |
75 | 0 | } |
76 | | |
77 | | /** SMAWK algo. Find the row minima of a monotone matrix. |
78 | | * |
79 | | * References: |
80 | | * 1. http://web.cs.unlv.edu/larmore/Courses/CSC477/monge.pdf |
81 | | * 2. https://gist.github.com/dstein64/8e94a6a25efc1335657e910ff525f405 |
82 | | * 3. https://github.com/dstein64/kmeans1d |
83 | | */ |
84 | | void smawk_impl( |
85 | | const std::vector<idx_t>& rows, |
86 | | const std::vector<idx_t>& input_cols, |
87 | | const LookUpFunc& lookup, |
88 | 0 | idx_t* argmins) { |
89 | 0 | if (rows.size() == 0) { |
90 | 0 | return; |
91 | 0 | } |
92 | | |
93 | | /********************************** |
94 | | * REDUCE |
95 | | **********************************/ |
96 | 0 | auto ptr = &input_cols; |
97 | 0 | std::vector<idx_t> survived_cols; // survived columns |
98 | 0 | if (rows.size() < input_cols.size()) { |
99 | 0 | reduce(rows, input_cols, lookup, survived_cols); |
100 | 0 | ptr = &survived_cols; |
101 | 0 | } |
102 | 0 | auto& cols = *ptr; // avoid memory copy |
103 | | |
104 | | /********************************** |
105 | | * INTERPOLATE |
106 | | **********************************/ |
107 | | |
108 | | // call recursively on odd-indexed rows |
109 | 0 | std::vector<idx_t> odd_rows; |
110 | 0 | for (idx_t i = 1; i < rows.size(); i += 2) { |
111 | 0 | odd_rows.push_back(rows[i]); |
112 | 0 | } |
113 | 0 | smawk_impl(odd_rows, cols, lookup, argmins); |
114 | | |
115 | | // interpolate the even-indexed rows |
116 | 0 | interpolate(rows, cols, lookup, argmins); |
117 | 0 | } |
118 | | |
119 | | void smawk( |
120 | | const idx_t nrows, |
121 | | const idx_t ncols, |
122 | | const LookUpFunc& lookup, |
123 | 0 | idx_t* argmins) { |
124 | 0 | std::vector<idx_t> rows(nrows); |
125 | 0 | std::vector<idx_t> cols(ncols); |
126 | 0 | std::iota(std::begin(rows), std::end(rows), 0); |
127 | 0 | std::iota(std::begin(cols), std::end(cols), 0); |
128 | |
|
129 | 0 | smawk_impl(rows, cols, lookup, argmins); |
130 | 0 | } |
131 | | |
132 | | void smawk( |
133 | | const idx_t nrows, |
134 | | const idx_t ncols, |
135 | | const float* x, |
136 | 0 | idx_t* argmins) { |
137 | 0 | auto lookup = [&x, &ncols](idx_t i, idx_t j) { return x[i * ncols + j]; }; |
138 | 0 | smawk(nrows, ncols, lookup, argmins); |
139 | 0 | } |
140 | | |
141 | | namespace { |
142 | | |
143 | | class CostCalculator { |
144 | | // The reuslt would be inaccurate if we use float |
145 | | std::vector<double> cumsum; |
146 | | std::vector<double> cumsum2; |
147 | | |
148 | | public: |
149 | 0 | CostCalculator(const std::vector<float>& vec, idx_t n) { |
150 | 0 | cumsum.push_back(0.0); |
151 | 0 | cumsum2.push_back(0.0); |
152 | 0 | for (idx_t i = 0; i < n; ++i) { |
153 | 0 | float x = vec[i]; |
154 | 0 | cumsum.push_back(x + cumsum[i]); |
155 | 0 | cumsum2.push_back(x * x + cumsum2[i]); |
156 | 0 | } |
157 | 0 | } |
158 | | |
159 | 0 | float operator()(idx_t i, idx_t j) { |
160 | 0 | if (j < i) { |
161 | 0 | return 0.0f; |
162 | 0 | } |
163 | 0 | auto mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1); |
164 | 0 | auto result = cumsum2[j + 1] - cumsum2[i]; |
165 | 0 | result += (j - i + 1) * (mu * mu); |
166 | 0 | result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]); |
167 | 0 | return float(result); |
168 | 0 | } |
169 | | }; |
170 | | |
171 | | template <class T> |
172 | | class Matrix { |
173 | | std::vector<T> data; |
174 | | idx_t nrows; |
175 | | idx_t ncols; |
176 | | |
177 | | public: |
178 | 0 | Matrix(idx_t nrows, idx_t ncols) { |
179 | 0 | this->nrows = nrows; |
180 | 0 | this->ncols = ncols; |
181 | 0 | data.resize(nrows * ncols); |
182 | 0 | } Unexecuted instantiation: kmeans1d.cpp:_ZN5faiss12_GLOBAL__N_16MatrixIfEC2Ell Unexecuted instantiation: kmeans1d.cpp:_ZN5faiss12_GLOBAL__N_16MatrixIlEC2Ell |
183 | | |
184 | 0 | inline T& at(idx_t i, idx_t j) { |
185 | 0 | return data[i * ncols + j]; |
186 | 0 | } Unexecuted instantiation: kmeans1d.cpp:_ZN5faiss12_GLOBAL__N_16MatrixIfE2atEll Unexecuted instantiation: kmeans1d.cpp:_ZN5faiss12_GLOBAL__N_16MatrixIlE2atEll |
187 | | }; |
188 | | |
189 | | } // anonymous namespace |
190 | | |
191 | 0 | double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) { |
192 | 0 | FAISS_THROW_IF_NOT(n >= nclusters); |
193 | | |
194 | | // corner case |
195 | 0 | if (n == nclusters) { |
196 | 0 | memcpy(centroids, x, n * sizeof(*x)); |
197 | 0 | return 0.0f; |
198 | 0 | } |
199 | | |
200 | | /*************************************************** |
201 | | * sort in ascending order, O(NlogN) in time |
202 | | ***************************************************/ |
203 | 0 | std::vector<float> arr(x, x + n); |
204 | 0 | std::sort(arr.begin(), arr.end()); |
205 | | |
206 | | /*************************************************** |
207 | | dynamic programming algorithm |
208 | | |
209 | | Reference: https://arxiv.org/abs/1701.07204 |
210 | | ------------------------------- |
211 | | |
212 | | Assume x is already sorted in ascending order. |
213 | | |
214 | | N: number of points |
215 | | K: number of clusters |
216 | | |
217 | | CC(i, j): the cost of grouping xi,...,xj into one cluster |
218 | | D[k][m]: the cost of optimally clustering x1,...,xm into k clusters |
219 | | T[k][m]: the start index of the k-th cluster |
220 | | |
221 | | The DP process is as follow: |
222 | | D[k][m] = min_i D[k − 1][i − 1] + CC(i, m) |
223 | | T[k][m] = argmin_i D[k − 1][i − 1] + CC(i, m) |
224 | | |
225 | | This could be solved in O(KN^2) time and O(KN) space. |
226 | | |
227 | | To further reduce the time complexity, we use SMAWK algo to |
228 | | solve the argmin problem as follow: |
229 | | |
230 | | For each k: |
231 | | C[m][i] = D[k − 1][i − 1] + CC(i, m) |
232 | | |
233 | | Here C is a n x n totally monotone matrix. |
234 | | We could find the row minima by SMAWK in O(N) time. |
235 | | |
236 | | Now the time complexity is reduced from O(kN^2) to O(KN). |
237 | | ****************************************************/ |
238 | |
|
239 | 0 | CostCalculator CC(arr, n); |
240 | 0 | Matrix<float> D(nclusters, n); |
241 | 0 | Matrix<idx_t> T(nclusters, n); |
242 | |
|
243 | 0 | for (idx_t m = 0; m < n; m++) { |
244 | 0 | D.at(0, m) = CC(0, m); |
245 | 0 | T.at(0, m) = 0; |
246 | 0 | } |
247 | |
|
248 | 0 | std::vector<idx_t> indices(nclusters, 0); |
249 | |
|
250 | 0 | for (idx_t k = 1; k < nclusters; ++k) { |
251 | | // we define C here |
252 | 0 | auto C = [&D, &CC, &k](idx_t m, idx_t i) { |
253 | 0 | if (i == 0) { |
254 | 0 | return CC(i, m); |
255 | 0 | } |
256 | 0 | idx_t col = std::min(m, i - 1); |
257 | 0 | return D.at(k - 1, col) + CC(i, m); |
258 | 0 | }; |
259 | |
|
260 | 0 | std::vector<idx_t> argmins(n); // argmin of each row |
261 | 0 | smawk(n, n, C, argmins.data()); |
262 | 0 | for (idx_t m = 0; m < argmins.size(); m++) { |
263 | 0 | idx_t idx = argmins[m]; |
264 | 0 | D.at(k, m) = C(m, idx); |
265 | 0 | T.at(k, m) = idx; |
266 | 0 | } |
267 | 0 | } |
268 | | |
269 | | /*************************************************** |
270 | | compute centroids by backtracking |
271 | | |
272 | | T[K - 1][T[K][N] - 1] T[K][N] N |
273 | | --------------|------------------------|-----------| |
274 | | | cluster K - 1 | cluster K | |
275 | | |
276 | | ****************************************************/ |
277 | | |
278 | | // for imbalance factor |
279 | 0 | double tot = 0.0; |
280 | 0 | double uf = 0.0; |
281 | |
|
282 | 0 | idx_t end = n; |
283 | 0 | for (idx_t k = nclusters - 1; k >= 0; k--) { |
284 | 0 | const idx_t start = T.at(k, end - 1); |
285 | 0 | const float sum = |
286 | 0 | std::accumulate(arr.data() + start, arr.data() + end, 0.0f); |
287 | 0 | const idx_t size = end - start; |
288 | 0 | FAISS_THROW_IF_NOT_FMT( |
289 | 0 | size > 0, "Cluster %d: size %d", int(k), int(size)); |
290 | 0 | centroids[k] = sum / size; |
291 | 0 | end = start; |
292 | |
|
293 | 0 | tot += size; |
294 | 0 | uf += size * double(size); |
295 | 0 | } |
296 | | |
297 | 0 | uf = uf * nclusters / (tot * tot); |
298 | 0 | return uf; |
299 | 0 | } |
300 | | |
301 | | } // namespace faiss |