/root/doris/contrib/faiss/faiss/utils/Heap.cpp
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | /* Function for soft heap */ |
11 | | |
12 | | #include <faiss/impl/FaissAssert.h> |
13 | | #include <faiss/utils/Heap.h> |
14 | | |
15 | | namespace faiss { |
16 | | |
17 | | template <typename C> |
18 | 0 | void HeapArray<C>::heapify() { |
19 | 0 | #pragma omp parallel for |
20 | 0 | for (int64_t j = 0; j < nh; j++) |
21 | 0 | heap_heapify<C>(k, val + j * k, ids + j * k); Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIflEEE7heapifyEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIflEEE7heapifyEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIfiEEE7heapifyEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIfiEEE7heapifyEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIilEEE7heapifyEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIilEEE7heapifyEv.omp_outlined_debug__ |
22 | 0 | } Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIflEEE7heapifyEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIflEEE7heapifyEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIfiEEE7heapifyEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIfiEEE7heapifyEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIilEEE7heapifyEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIilEEE7heapifyEv |
23 | | |
24 | | template <typename C> |
25 | 0 | void HeapArray<C>::reorder() { |
26 | 0 | #pragma omp parallel for |
27 | 0 | for (int64_t j = 0; j < nh; j++) |
28 | 0 | heap_reorder<C>(k, val + j * k, ids + j * k); Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIflEEE7reorderEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIflEEE7reorderEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIfiEEE7reorderEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIfiEEE7reorderEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIilEEE7reorderEv.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIilEEE7reorderEv.omp_outlined_debug__ |
29 | 0 | } Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIflEEE7reorderEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIflEEE7reorderEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIfiEEE7reorderEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIfiEEE7reorderEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIilEEE7reorderEv Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIilEEE7reorderEv |
30 | | |
31 | | template <typename C> |
32 | 0 | void HeapArray<C>::addn(size_t nj, const T* vin, TI j0, size_t i0, int64_t ni) { |
33 | 0 | if (ni == -1) |
34 | 0 | ni = nh; |
35 | 0 | assert(i0 >= 0 && i0 + ni <= nh); |
36 | 0 | #pragma omp parallel for if (ni * nj > 100000) |
37 | 0 | for (int64_t i = i0; i < i0 + ni; i++) { |
38 | 0 | T* __restrict simi = get_val(i); |
39 | 0 | TI* __restrict idxi = get_ids(i); |
40 | 0 | const T* ip_line = vin + (i - i0) * nj; |
41 | |
|
42 | 0 | for (size_t j = 0; j < nj; j++) { |
43 | 0 | T ip = ip_line[j]; |
44 | 0 | if (C::cmp(simi[0], ip)) { |
45 | 0 | heap_replace_top<C>(k, simi, idxi, ip, j + j0); |
46 | 0 | } |
47 | 0 | } |
48 | 0 | } Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIflEEE4addnEmPKflml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIflEEE4addnEmPKflml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIfiEEE4addnEmPKfiml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIfiEEE4addnEmPKfiml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIilEEE4addnEmPKilml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIilEEE4addnEmPKilml.omp_outlined_debug__ |
49 | 0 | } Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIflEEE4addnEmPKflml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIflEEE4addnEmPKflml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIfiEEE4addnEmPKfiml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIfiEEE4addnEmPKfiml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIilEEE4addnEmPKilml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIilEEE4addnEmPKilml |
50 | | |
51 | | template <typename C> |
52 | | void HeapArray<C>::addn_with_ids( |
53 | | size_t nj, |
54 | | const T* vin, |
55 | | const TI* id_in, |
56 | | int64_t id_stride, |
57 | | size_t i0, |
58 | 0 | int64_t ni) { |
59 | 0 | if (id_in == nullptr) { |
60 | 0 | addn(nj, vin, 0, i0, ni); |
61 | 0 | return; |
62 | 0 | } |
63 | 0 | if (ni == -1) |
64 | 0 | ni = nh; |
65 | 0 | assert(i0 >= 0 && i0 + ni <= nh); |
66 | 0 | #pragma omp parallel for if (ni * nj > 100000) |
67 | 0 | for (int64_t i = i0; i < i0 + ni; i++) { |
68 | 0 | T* __restrict simi = get_val(i); |
69 | 0 | TI* __restrict idxi = get_ids(i); |
70 | 0 | const T* ip_line = vin + (i - i0) * nj; |
71 | 0 | const TI* id_line = id_in + (i - i0) * id_stride; |
72 | |
|
73 | 0 | for (size_t j = 0; j < nj; j++) { |
74 | 0 | T ip = ip_line[j]; |
75 | 0 | if (C::cmp(simi[0], ip)) { |
76 | 0 | heap_replace_top<C>(k, simi, idxi, ip, id_line[j]); |
77 | 0 | } |
78 | 0 | } |
79 | 0 | } Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIflEEE13addn_with_idsEmPKfPKllml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIflEEE13addn_with_idsEmPKfPKllml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIfiEEE13addn_with_idsEmPKfPKilml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIfiEEE13addn_with_idsEmPKfPKilml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIilEEE13addn_with_idsEmPKiPKllml.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIilEEE13addn_with_idsEmPKiPKllml.omp_outlined_debug__ |
80 | 0 | } Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIflEEE13addn_with_idsEmPKfPKllml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIflEEE13addn_with_idsEmPKfPKllml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIfiEEE13addn_with_idsEmPKfPKilml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIfiEEE13addn_with_idsEmPKfPKilml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIilEEE13addn_with_idsEmPKiPKllml Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIilEEE13addn_with_idsEmPKiPKllml |
81 | | |
82 | | template <typename C> |
83 | | void HeapArray<C>::addn_query_subset_with_ids( |
84 | | size_t nsubset, |
85 | | const TI* subset, |
86 | | size_t nj, |
87 | | const T* vin, |
88 | | const TI* id_in, |
89 | 0 | int64_t id_stride) { |
90 | 0 | FAISS_THROW_IF_NOT_MSG(id_in, "anonymous ids not supported"); |
91 | 0 | if (id_stride < 0) { |
92 | 0 | id_stride = nj; |
93 | 0 | } |
94 | 0 | #pragma omp parallel for if (nsubset * nj > 100000) |
95 | 0 | for (int64_t si = 0; si < nsubset; si++) { |
96 | 0 | TI i = subset[si]; |
97 | 0 | T* __restrict simi = get_val(i); |
98 | 0 | TI* __restrict idxi = get_ids(i); |
99 | 0 | const T* ip_line = vin + si * nj; |
100 | 0 | const TI* id_line = id_in + si * id_stride; |
101 | |
|
102 | 0 | for (size_t j = 0; j < nj; j++) { |
103 | 0 | T ip = ip_line[j]; |
104 | 0 | if (C::cmp(simi[0], ip)) { |
105 | 0 | heap_replace_top<C>(k, simi, idxi, ip, id_line[j]); |
106 | 0 | } |
107 | 0 | } |
108 | 0 | } Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIflEEE26addn_query_subset_with_idsEmPKlmPKfS5_l.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIflEEE26addn_query_subset_with_idsEmPKlmPKfS5_l.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIfiEEE26addn_query_subset_with_idsEmPKimPKfS5_l.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIfiEEE26addn_query_subset_with_idsEmPKimPKfS5_l.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMinIilEEE26addn_query_subset_with_idsEmPKlmPKiS5_l.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss9HeapArrayINS_4CMaxIilEEE26addn_query_subset_with_idsEmPKlmPKiS5_l.omp_outlined_debug__ |
109 | 0 | } Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIflEEE26addn_query_subset_with_idsEmPKlmPKfS5_l Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIflEEE26addn_query_subset_with_idsEmPKlmPKfS5_l Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIfiEEE26addn_query_subset_with_idsEmPKimPKfS5_l Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIfiEEE26addn_query_subset_with_idsEmPKimPKfS5_l Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMinIilEEE26addn_query_subset_with_idsEmPKlmPKiS5_l Unexecuted instantiation: _ZN5faiss9HeapArrayINS_4CMaxIilEEE26addn_query_subset_with_idsEmPKlmPKiS5_l |
110 | | |
111 | | template <typename C> |
112 | 0 | void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const { |
113 | 0 | #pragma omp parallel for if (nh * k > 100000) |
114 | 0 | for (int64_t j = 0; j < nh; j++) { |
115 | 0 | int64_t imin = -1; |
116 | 0 | typename C::T xval = C::Crev::neutral(); |
117 | 0 | const typename C::T* x_ = val + j * k; |
118 | 0 | for (size_t i = 0; i < k; i++) |
119 | 0 | if (C::cmp(x_[i], xval)) { |
120 | 0 | xval = x_[i]; |
121 | 0 | imin = i; |
122 | 0 | } |
123 | 0 | if (out_val) |
124 | 0 | out_val[j] = xval; |
125 | |
|
126 | 0 | if (out_ids) { |
127 | 0 | if (ids && imin != -1) |
128 | 0 | out_ids[j] = ids[j * k + imin]; |
129 | 0 | else |
130 | 0 | out_ids[j] = imin; |
131 | 0 | } |
132 | 0 | } Unexecuted instantiation: Heap.cpp:_ZNK5faiss9HeapArrayINS_4CMinIflEEE16per_line_extremaEPfPl.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZNK5faiss9HeapArrayINS_4CMaxIflEEE16per_line_extremaEPfPl.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZNK5faiss9HeapArrayINS_4CMinIfiEEE16per_line_extremaEPfPi.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZNK5faiss9HeapArrayINS_4CMaxIfiEEE16per_line_extremaEPfPi.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZNK5faiss9HeapArrayINS_4CMinIilEEE16per_line_extremaEPiPl.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZNK5faiss9HeapArrayINS_4CMaxIilEEE16per_line_extremaEPiPl.omp_outlined_debug__ |
133 | 0 | } Unexecuted instantiation: _ZNK5faiss9HeapArrayINS_4CMinIflEEE16per_line_extremaEPfPl Unexecuted instantiation: _ZNK5faiss9HeapArrayINS_4CMaxIflEEE16per_line_extremaEPfPl Unexecuted instantiation: _ZNK5faiss9HeapArrayINS_4CMinIfiEEE16per_line_extremaEPfPi Unexecuted instantiation: _ZNK5faiss9HeapArrayINS_4CMaxIfiEEE16per_line_extremaEPfPi Unexecuted instantiation: _ZNK5faiss9HeapArrayINS_4CMinIilEEE16per_line_extremaEPiPl Unexecuted instantiation: _ZNK5faiss9HeapArrayINS_4CMaxIilEEE16per_line_extremaEPiPl |
134 | | |
// explicit instantiations
// The HeapArray member functions above are defined in this translation unit,
// so each comparator / value type / label type combination listed here is
// instantiated explicitly.

template struct HeapArray<CMin<float, int64_t>>;
template struct HeapArray<CMax<float, int64_t>>;
template struct HeapArray<CMin<float, int32_t>>;
template struct HeapArray<CMax<float, int32_t>>;
template struct HeapArray<CMin<int, int64_t>>;
template struct HeapArray<CMax<int, int64_t>>;
143 | | |
144 | | /********************************************************** |
145 | | * merge knn search results |
146 | | **********************************************************/ |
147 | | |
148 | | /** Merge result tables from several shards. The per-shard results are assumed |
149 | | * to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k |
150 | | * element heap because we want the best (ie. lowest for L2) result to be on |
151 | | * top, not the worst. |
152 | | * |
153 | | * @param all_distances size (nshard, n, k) |
154 | | * @param all_labels size (nshard, n, k) |
155 | | * @param distances output distances, size (n, k) |
156 | | * @param labels output labels, size (n, k) |
157 | | */ |
158 | | template <class idx_t, class C> |
159 | | void merge_knn_results( |
160 | | size_t n, |
161 | | size_t k, |
162 | | typename C::TI nshard, |
163 | | const typename C::T* all_distances, |
164 | | const idx_t* all_labels, |
165 | | typename C::T* distances, |
166 | 0 | idx_t* labels) { |
167 | 0 | using distance_t = typename C::T; |
168 | 0 | if (k == 0) { |
169 | 0 | return; |
170 | 0 | } |
171 | 0 | long stride = n * k; |
172 | 0 | #pragma omp parallel if (n * nshard * k > 100000) |
173 | 0 | { |
174 | 0 | std::vector<int> buf(2 * nshard); |
175 | | // index in each shard's result list |
176 | 0 | int* pointer = buf.data(); |
177 | | // (shard_ids, heap_vals): heap that indexes |
178 | | // shard -> current distance for this shard |
179 | 0 | int* shard_ids = pointer + nshard; |
180 | 0 | std::vector<distance_t> buf2(nshard); |
181 | 0 | distance_t* heap_vals = buf2.data(); |
182 | 0 | #pragma omp for |
183 | 0 | for (long i = 0; i < n; i++) { |
184 | | // the heap maps values to the shard where they are |
185 | | // produced. |
186 | 0 | const distance_t* D_in = all_distances + i * k; |
187 | 0 | const idx_t* I_in = all_labels + i * k; |
188 | 0 | int heap_size = 0; |
189 | | |
190 | | // push the first element of each shard (if not -1) |
191 | 0 | for (long s = 0; s < nshard; s++) { |
192 | 0 | pointer[s] = 0; |
193 | 0 | if (I_in[stride * s] >= 0) { |
194 | 0 | heap_push<C>( |
195 | 0 | ++heap_size, |
196 | 0 | heap_vals, |
197 | 0 | shard_ids, |
198 | 0 | D_in[stride * s], |
199 | 0 | s); |
200 | 0 | } |
201 | 0 | } |
202 | |
|
203 | 0 | distance_t* D = distances + i * k; |
204 | 0 | idx_t* I = labels + i * k; |
205 | |
|
206 | 0 | int j; |
207 | 0 | for (j = 0; j < k && heap_size > 0; j++) { |
208 | | // pop element from best shard |
209 | 0 | int s = shard_ids[0]; // top of heap |
210 | 0 | int& p = pointer[s]; |
211 | 0 | D[j] = heap_vals[0]; |
212 | 0 | I[j] = I_in[stride * s + p]; |
213 | | |
214 | | // pop from shard, advance pointer for this shard |
215 | 0 | heap_pop<C>(heap_size--, heap_vals, shard_ids); |
216 | 0 | p++; |
217 | 0 | if (p < k && I_in[stride * s + p] >= 0) { |
218 | 0 | heap_push<C>( |
219 | 0 | ++heap_size, |
220 | 0 | heap_vals, |
221 | 0 | shard_ids, |
222 | 0 | D_in[stride * s + p], |
223 | 0 | s); |
224 | 0 | } |
225 | 0 | } |
226 | 0 | for (; j < k; j++) { |
227 | 0 | I[j] = -1; |
228 | 0 | D[j] = C::Crev::neutral(); |
229 | 0 | } |
230 | 0 | } |
231 | 0 | } Unexecuted instantiation: Heap.cpp:_ZN5faiss17merge_knn_resultsIlNS_4CMinIfiEEEEvmmNT0_2TIEPKNS3_1TEPKT_PS5_PS8_.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss17merge_knn_resultsIlNS_4CMaxIfiEEEEvmmNT0_2TIEPKNS3_1TEPKT_PS5_PS8_.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss17merge_knn_resultsIlNS_4CMinIiiEEEEvmmNT0_2TIEPKNS3_1TEPKT_PS5_PS8_.omp_outlined_debug__ Unexecuted instantiation: Heap.cpp:_ZN5faiss17merge_knn_resultsIlNS_4CMaxIiiEEEEvmmNT0_2TIEPKNS3_1TEPKT_PS5_PS8_.omp_outlined_debug__ |
232 | 0 | } Unexecuted instantiation: _ZN5faiss17merge_knn_resultsIlNS_4CMinIfiEEEEvmmNT0_2TIEPKNS3_1TEPKT_PS5_PS8_ Unexecuted instantiation: _ZN5faiss17merge_knn_resultsIlNS_4CMaxIfiEEEEvmmNT0_2TIEPKNS3_1TEPKT_PS5_PS8_ Unexecuted instantiation: _ZN5faiss17merge_knn_resultsIlNS_4CMinIiiEEEEvmmNT0_2TIEPKNS3_1TEPKT_PS5_PS8_ Unexecuted instantiation: _ZN5faiss17merge_knn_resultsIlNS_4CMaxIiiEEEEvmmNT0_2TIEPKNS3_1TEPKT_PS5_PS8_ |
233 | | |
234 | | // explicit instanciations |
235 | | #define INSTANTIATE(C, distance_t) \ |
236 | | template void merge_knn_results<int64_t, C<distance_t, int>>( \ |
237 | | size_t, \ |
238 | | size_t, \ |
239 | | int, \ |
240 | | const distance_t*, \ |
241 | | const int64_t*, \ |
242 | | distance_t*, \ |
243 | | int64_t*); |
244 | | |
245 | | INSTANTIATE(CMin, float); |
246 | | INSTANTIATE(CMax, float); |
247 | | INSTANTIATE(CMin, int32_t); |
248 | | INSTANTIATE(CMax, int32_t); |
249 | | |
250 | | } // namespace faiss |