/root/doris/contrib/faiss/faiss/MetaIndexes.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/MetaIndexes.h> |
11 | | |
12 | | #include <cinttypes> |
13 | | #include <cstdint> |
14 | | #include <cstdio> |
15 | | #include <limits> |
16 | | |
17 | | #include <faiss/impl/AuxIndexStructures.h> |
18 | | #include <faiss/impl/FaissAssert.h> |
19 | | #include <faiss/utils/Heap.h> |
20 | | #include <faiss/utils/WorkerThread.h> |
21 | | #include <faiss/utils/random.h> |
22 | | #include <faiss/utils/utils.h> |
23 | | |
24 | | namespace faiss { |
25 | | |
26 | | /***************************************************** |
27 | | * IndexSplitVectors implementation |
28 | | *******************************************************/ |
29 | | |
30 | | IndexSplitVectors::IndexSplitVectors(idx_t d, bool threaded) |
31 | 0 | : Index(d), own_fields(false), threaded(threaded), sum_d(0) {} |
32 | | |
33 | 0 | void IndexSplitVectors::add_sub_index(Index* index) { |
34 | 0 | sub_indexes.push_back(index); |
35 | 0 | sync_with_sub_indexes(); |
36 | 0 | } |
37 | | |
38 | 0 | void IndexSplitVectors::sync_with_sub_indexes() { |
39 | 0 | if (sub_indexes.empty()) |
40 | 0 | return; |
41 | 0 | Index* index0 = sub_indexes[0]; |
42 | 0 | sum_d = index0->d; |
43 | 0 | metric_type = index0->metric_type; |
44 | 0 | is_trained = index0->is_trained; |
45 | 0 | ntotal = index0->ntotal; |
46 | 0 | for (int i = 1; i < sub_indexes.size(); i++) { |
47 | 0 | Index* index = sub_indexes[i]; |
48 | 0 | FAISS_THROW_IF_NOT(metric_type == index->metric_type); |
49 | 0 | FAISS_THROW_IF_NOT(ntotal == index->ntotal); |
50 | 0 | sum_d += index->d; |
51 | 0 | } |
52 | 0 | } |
53 | | |
54 | 0 | void IndexSplitVectors::add(idx_t /*n*/, const float* /*x*/) { |
55 | 0 | FAISS_THROW_MSG("not implemented"); |
56 | 0 | } |
57 | | |
58 | | void IndexSplitVectors::search( |
59 | | idx_t n, |
60 | | const float* x, |
61 | | idx_t k, |
62 | | float* distances, |
63 | | idx_t* labels, |
64 | 0 | const SearchParameters* params) const { |
65 | 0 | FAISS_THROW_IF_NOT_MSG( |
66 | 0 | !params, "search params not supported for this index"); |
67 | 0 | FAISS_THROW_IF_NOT_MSG(k == 1, "search implemented only for k=1"); |
68 | 0 | FAISS_THROW_IF_NOT_MSG( |
69 | 0 | sum_d == d, "not enough indexes compared to # dimensions"); |
70 | | |
71 | 0 | int64_t nshard = sub_indexes.size(); |
72 | |
|
73 | 0 | std::unique_ptr<float[]> all_distances(new float[nshard * k * n]); |
74 | 0 | std::unique_ptr<idx_t[]> all_labels(new idx_t[nshard * k * n]); |
75 | |
|
76 | 0 | auto query_func = |
77 | 0 | [n, x, k, distances, labels, &all_distances, &all_labels, this]( |
78 | 0 | int no) { |
79 | 0 | const IndexSplitVectors* index = this; |
80 | 0 | float* distances1 = |
81 | 0 | no == 0 ? distances : all_distances.get() + no * k * n; |
82 | 0 | idx_t* labels1 = |
83 | 0 | no == 0 ? labels : all_labels.get() + no * k * n; |
84 | 0 | if (index->verbose) |
85 | 0 | printf("begin query shard %d on %" PRId64 " points\n", |
86 | 0 | no, |
87 | 0 | n); |
88 | 0 | const Index* sub_index = index->sub_indexes[no]; |
89 | 0 | int64_t sub_d = sub_index->d, d = index->d; |
90 | 0 | idx_t ofs = 0; |
91 | 0 | for (int i = 0; i < no; i++) |
92 | 0 | ofs += index->sub_indexes[i]->d; |
93 | |
|
94 | 0 | std::unique_ptr<float[]> sub_x(new float[sub_d * n]); |
95 | 0 | for (idx_t i = 0; i < n; i++) |
96 | 0 | memcpy(sub_x.get() + i * sub_d, |
97 | 0 | x + ofs + i * d, |
98 | 0 | sub_d * sizeof(float)); |
99 | 0 | sub_index->search(n, sub_x.get(), k, distances1, labels1); |
100 | 0 | if (index->verbose) |
101 | 0 | printf("end query shard %d\n", no); |
102 | 0 | }; |
103 | |
|
104 | 0 | if (!threaded) { |
105 | 0 | for (int i = 0; i < nshard; i++) { |
106 | 0 | query_func(i); |
107 | 0 | } |
108 | 0 | } else { |
109 | 0 | std::vector<std::unique_ptr<WorkerThread>> threads; |
110 | 0 | std::vector<std::future<bool>> v; |
111 | |
|
112 | 0 | for (int i = 0; i < nshard; i++) { |
113 | 0 | threads.emplace_back(new WorkerThread()); |
114 | 0 | WorkerThread* wt = threads.back().get(); |
115 | 0 | v.emplace_back(wt->add([i, query_func]() { query_func(i); })); |
116 | 0 | } |
117 | | |
118 | | // Blocking wait for completion |
119 | 0 | for (auto& func : v) { |
120 | 0 | func.get(); |
121 | 0 | } |
122 | 0 | } |
123 | |
|
124 | 0 | int64_t factor = 1; |
125 | 0 | for (int i = 0; i < nshard; i++) { |
126 | 0 | if (i > 0) { // results of 0 are already in the table |
127 | 0 | const float* distances_i = all_distances.get() + i * k * n; |
128 | 0 | const idx_t* labels_i = all_labels.get() + i * k * n; |
129 | 0 | for (int64_t j = 0; j < n; j++) { |
130 | 0 | if (labels[j] >= 0 && labels_i[j] >= 0) { |
131 | 0 | labels[j] += labels_i[j] * factor; |
132 | 0 | distances[j] += distances_i[j]; |
133 | 0 | } else { |
134 | 0 | labels[j] = -1; |
135 | 0 | distances[j] = std::numeric_limits<float>::quiet_NaN(); |
136 | 0 | } |
137 | 0 | } |
138 | 0 | } |
139 | 0 | factor *= sub_indexes[i]->ntotal; |
140 | 0 | } |
141 | 0 | } |
142 | | |
143 | 0 | void IndexSplitVectors::train(idx_t /*n*/, const float* /*x*/) { |
144 | 0 | FAISS_THROW_MSG("not implemented"); |
145 | 0 | } |
146 | | |
147 | 0 | void IndexSplitVectors::reset() { |
148 | 0 | FAISS_THROW_MSG("not implemented"); |
149 | 0 | } |
150 | | |
151 | 0 | IndexSplitVectors::~IndexSplitVectors() { |
152 | 0 | if (own_fields) { |
153 | 0 | for (int s = 0; s < sub_indexes.size(); s++) |
154 | 0 | delete sub_indexes[s]; |
155 | 0 | } |
156 | 0 | } |
157 | | |
158 | | /******************************************************** |
159 | | * IndexRandom implementation |
160 | | */ |
161 | | |
162 | | IndexRandom::IndexRandom( |
163 | | idx_t d, |
164 | | idx_t ntotal, |
165 | | int64_t seed, |
166 | | MetricType metric_type) |
167 | 0 | : Index(d, metric_type), seed(seed) { |
168 | 0 | this->ntotal = ntotal; |
169 | 0 | is_trained = true; |
170 | 0 | } |
171 | | |
172 | 0 | void IndexRandom::add(idx_t n, const float*) { |
173 | 0 | ntotal += n; |
174 | 0 | } |
175 | | |
176 | | void IndexRandom::search( |
177 | | idx_t n, |
178 | | const float* x, |
179 | | idx_t k, |
180 | | float* distances, |
181 | | idx_t* labels, |
182 | 0 | const SearchParameters* params) const { |
183 | 0 | FAISS_THROW_IF_NOT_MSG( |
184 | 0 | !params, "search params not supported for this index"); |
185 | 0 | FAISS_THROW_IF_NOT(k <= ntotal); |
186 | 0 | #pragma omp parallel for if (n > 1000) |
187 | 0 | for (idx_t i = 0; i < n; i++) { |
188 | 0 | RandomGenerator rng( |
189 | 0 | seed + ivec_checksum(d, (const int32_t*)(x + i * d))); |
190 | 0 | idx_t* I = labels + i * k; |
191 | 0 | float* D = distances + i * k; |
192 | | // assumes k << ntotal |
193 | 0 | if (k < 100 * ntotal) { |
194 | 0 | std::unordered_set<idx_t> map; |
195 | 0 | for (int j = 0; j < k; j++) { |
196 | 0 | idx_t ii; |
197 | 0 | for (;;) { |
198 | | // yes I know it's not strictly uniform... |
199 | 0 | ii = rng.rand_int64() % ntotal; |
200 | 0 | if (map.count(ii) == 0) { |
201 | 0 | break; |
202 | 0 | } |
203 | 0 | } |
204 | 0 | I[j] = ii; |
205 | 0 | map.insert(ii); |
206 | 0 | } |
207 | 0 | } else { |
208 | 0 | std::vector<idx_t> perm(ntotal); |
209 | 0 | for (idx_t j = 0; j < ntotal; j++) { |
210 | 0 | perm[j] = j; |
211 | 0 | } |
212 | 0 | for (int j = 0; j < k; j++) { |
213 | 0 | std::swap(perm[j], perm[rng.rand_int(ntotal)]); |
214 | 0 | I[j] = perm[j]; |
215 | 0 | } |
216 | 0 | } |
217 | 0 | float dprev = 0; |
218 | 0 | for (int j = 0; j < k; j++) { |
219 | 0 | float step = rng.rand_float(); |
220 | 0 | if (is_similarity_metric(metric_type)) { |
221 | 0 | step = -step; |
222 | 0 | } |
223 | 0 | dprev += step; |
224 | 0 | D[j] = dprev; |
225 | 0 | } |
226 | 0 | } |
227 | 0 | } |
228 | | |
229 | 0 | void IndexRandom::reconstruct(idx_t key, float* recons) const { |
230 | 0 | RandomGenerator rng(seed + 123332 + key); |
231 | 0 | for (size_t i = 0; i < d; i++) { |
232 | 0 | recons[i] = rng.rand_float(); |
233 | 0 | } |
234 | 0 | } |
235 | | |
236 | 0 | void IndexRandom::reset() { |
237 | 0 | ntotal = 0; |
238 | 0 | } |
239 | | |
240 | 0 | IndexRandom::~IndexRandom() = default; |
241 | | |
242 | | } // namespace faiss |