/root/doris/contrib/faiss/faiss/Index.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #include <faiss/Index.h> |
11 | | |
12 | | #include <faiss/impl/AuxIndexStructures.h> |
13 | | #include <faiss/impl/DistanceComputer.h> |
14 | | #include <faiss/impl/FaissAssert.h> |
15 | | #include <faiss/utils/distances.h> |
16 | | |
17 | | #include <cstring> |
18 | | |
19 | | namespace faiss { |
20 | | |
21 | 257 | Index::~Index() = default; |
22 | | |
23 | 15 | void Index::train(idx_t /*n*/, const float* /*x*/) { |
24 | | // does nothing by default |
25 | 15 | } |
26 | | |
27 | | void Index::range_search( |
28 | | idx_t, |
29 | | const float*, |
30 | | float, |
31 | | RangeSearchResult*, |
32 | 0 | const SearchParameters* params) const { |
33 | 0 | FAISS_THROW_MSG("range search not implemented"); |
34 | 0 | } |
35 | | |
36 | 0 | void Index::assign(idx_t n, const float* x, idx_t* labels, idx_t k) const { |
37 | 0 | std::vector<float> distances(n * k); |
38 | 0 | search(n, x, k, distances.data(), labels); |
39 | 0 | } |
40 | | |
41 | | void Index::add_with_ids( |
42 | | idx_t /*n*/, |
43 | | const float* /*x*/, |
44 | 0 | const idx_t* /*xids*/) { |
45 | 0 | FAISS_THROW_MSG("add_with_ids not implemented for this type of index"); |
46 | 0 | } |
47 | | |
48 | 0 | size_t Index::remove_ids(const IDSelector& /*sel*/) { |
49 | 0 | FAISS_THROW_MSG("remove_ids not implemented for this type of index"); |
50 | 0 | return -1; |
51 | 0 | } |
52 | | |
53 | 0 | void Index::reconstruct(idx_t, float*) const { |
54 | 0 | FAISS_THROW_MSG("reconstruct not implemented for this type of index"); |
55 | 0 | } |
56 | | |
57 | 0 | void Index::reconstruct_batch(idx_t n, const idx_t* keys, float* recons) const { |
58 | 0 | std::mutex exception_mutex; |
59 | 0 | std::string exception_string; |
60 | 0 | #pragma omp parallel for if (n > 1000) |
61 | 0 | for (idx_t i = 0; i < n; i++) { |
62 | 0 | try { |
63 | 0 | reconstruct(keys[i], &recons[i * d]); |
64 | 0 | } catch (const std::exception& e) { |
65 | 0 | std::lock_guard<std::mutex> lock(exception_mutex); |
66 | 0 | exception_string = e.what(); |
67 | 0 | } |
68 | 0 | } |
69 | 0 | if (!exception_string.empty()) { |
70 | 0 | FAISS_THROW_MSG(exception_string.c_str()); |
71 | 0 | } |
72 | 0 | } |
73 | | |
74 | 0 | void Index::reconstruct_n(idx_t i0, idx_t ni, float* recons) const { |
75 | 0 | #pragma omp parallel for if (ni > 1000) |
76 | 0 | for (idx_t i = 0; i < ni; i++) { |
77 | 0 | reconstruct(i0 + i, recons + i * d); |
78 | 0 | } |
79 | 0 | } |
80 | | |
81 | | void Index::search_and_reconstruct( |
82 | | idx_t n, |
83 | | const float* x, |
84 | | idx_t k, |
85 | | float* distances, |
86 | | idx_t* labels, |
87 | | float* recons, |
88 | 0 | const SearchParameters* params) const { |
89 | 0 | FAISS_THROW_IF_NOT(k > 0); |
90 | | |
91 | 0 | search(n, x, k, distances, labels, params); |
92 | 0 | for (idx_t i = 0; i < n; ++i) { |
93 | 0 | for (idx_t j = 0; j < k; ++j) { |
94 | 0 | idx_t ij = i * k + j; |
95 | 0 | idx_t key = labels[ij]; |
96 | 0 | float* reconstructed = recons + ij * d; |
97 | 0 | if (key < 0) { |
98 | | // Fill with NaNs |
99 | 0 | memset(reconstructed, -1, sizeof(*reconstructed) * d); |
100 | 0 | } else { |
101 | 0 | reconstruct(key, reconstructed); |
102 | 0 | } |
103 | 0 | } |
104 | 0 | } |
105 | 0 | } |
106 | | |
107 | 0 | void Index::compute_residual(const float* x, float* residual, idx_t key) const { |
108 | 0 | reconstruct(key, residual); |
109 | 0 | for (size_t i = 0; i < d; i++) { |
110 | 0 | residual[i] = x[i] - residual[i]; |
111 | 0 | } |
112 | 0 | } |
113 | | |
114 | | void Index::compute_residual_n( |
115 | | idx_t n, |
116 | | const float* xs, |
117 | | float* residuals, |
118 | 0 | const idx_t* keys) const { |
119 | 0 | #pragma omp parallel for |
120 | 0 | for (idx_t i = 0; i < n; ++i) { |
121 | 0 | compute_residual(&xs[i * d], &residuals[i * d], keys[i]); |
122 | 0 | } |
123 | 0 | } |
124 | | |
125 | 0 | size_t Index::sa_code_size() const { |
126 | 0 | FAISS_THROW_MSG("standalone codec not implemented for this type of index"); |
127 | 0 | } |
128 | | |
129 | 0 | void Index::sa_encode(idx_t, const float*, uint8_t*) const { |
130 | 0 | FAISS_THROW_MSG("standalone codec not implemented for this type of index"); |
131 | 0 | } |
132 | | |
133 | 0 | void Index::sa_decode(idx_t, const uint8_t*, float*) const { |
134 | 0 | FAISS_THROW_MSG("standalone codec not implemented for this type of index"); |
135 | 0 | } |
136 | | |
137 | 0 | void Index::add_sa_codes(idx_t, const uint8_t*, const idx_t*) { |
138 | 0 | FAISS_THROW_MSG("add_sa_codes not implemented for this type of index"); |
139 | 0 | } |
140 | | |
141 | | namespace { |
142 | | |
143 | | // storage that explicitly reconstructs vectors before computing distances |
144 | | struct GenericDistanceComputer : DistanceComputer { |
145 | | size_t d; |
146 | | const Index& storage; |
147 | | std::vector<float> buf; |
148 | | const float* q; |
149 | | |
150 | 0 | explicit GenericDistanceComputer(const Index& storage) : storage(storage) { |
151 | 0 | d = storage.d; |
152 | 0 | buf.resize(d * 2); |
153 | 0 | } |
154 | | |
155 | 0 | float operator()(idx_t i) override { |
156 | 0 | storage.reconstruct(i, buf.data()); |
157 | 0 | return fvec_L2sqr(q, buf.data(), d); |
158 | 0 | } |
159 | | |
160 | 0 | float symmetric_dis(idx_t i, idx_t j) override { |
161 | 0 | storage.reconstruct(i, buf.data()); |
162 | 0 | storage.reconstruct(j, buf.data() + d); |
163 | 0 | return fvec_L2sqr(buf.data() + d, buf.data(), d); |
164 | 0 | } |
165 | | |
166 | 0 | void set_query(const float* x) override { |
167 | 0 | q = x; |
168 | 0 | } |
169 | | }; |
170 | | |
171 | | } // namespace |
172 | | |
173 | 0 | DistanceComputer* Index::get_distance_computer() const { |
174 | 0 | if (metric_type == METRIC_L2) { |
175 | 0 | return new GenericDistanceComputer(*this); |
176 | 0 | } else { |
177 | 0 | FAISS_THROW_MSG("get_distance_computer() not implemented"); |
178 | 0 | } |
179 | 0 | } |
180 | | |
181 | 0 | void Index::merge_from(Index& /* otherIndex */, idx_t /* add_id */) { |
182 | 0 | FAISS_THROW_MSG("merge_from() not implemented"); |
183 | 0 | } |
184 | | |
185 | 0 | void Index::check_compatible_for_merge(const Index& /* otherIndex */) const { |
186 | 0 | FAISS_THROW_MSG("check_compatible_for_merge() not implemented"); |
187 | 0 | } |
188 | | |
189 | | } // namespace faiss |