contrib/faiss/faiss/utils/extra_distances-inl.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | /** In this file are the implementations of extra metrics beyond L2 |
9 | | * and inner product */ |
10 | | |
11 | | #include <faiss/MetricType.h> |
12 | | #include <faiss/impl/FaissAssert.h> |
13 | | #include <faiss/utils/distances.h> |
14 | | #include <cmath> |
15 | | #include <type_traits> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | | template <MetricType mt> |
20 | | struct VectorDistance { |
21 | | size_t d; |
22 | | float metric_arg; |
23 | | static constexpr bool is_similarity = is_similarity_metric(mt); |
24 | | |
25 | | inline float operator()(const float* x, const float* y) const; |
26 | | |
27 | | // heap template to use for this type of metric |
28 | | using C = typename std::conditional< |
29 | | is_similarity_metric(mt), |
30 | | CMin<float, int64_t>, |
31 | | CMax<float, int64_t>>::type; |
32 | | }; |
33 | | |
34 | | template <> |
35 | | inline float VectorDistance<METRIC_L2>::operator()( |
36 | | const float* x, |
37 | 0 | const float* y) const { |
38 | 0 | return fvec_L2sqr(x, y, d); |
39 | 0 | } |
40 | | |
41 | | template <> |
42 | | inline float VectorDistance<METRIC_INNER_PRODUCT>::operator()( |
43 | | const float* x, |
44 | 0 | const float* y) const { |
45 | 0 | return fvec_inner_product(x, y, d); |
46 | 0 | } |
47 | | |
48 | | template <> |
49 | | inline float VectorDistance<METRIC_L1>::operator()( |
50 | | const float* x, |
51 | 0 | const float* y) const { |
52 | 0 | return fvec_L1(x, y, d); |
53 | 0 | } |
54 | | |
55 | | template <> |
56 | | inline float VectorDistance<METRIC_Linf>::operator()( |
57 | | const float* x, |
58 | 0 | const float* y) const { |
59 | 0 | return fvec_Linf(x, y, d); |
60 | | /* |
61 | | float vmax = 0; |
62 | | for (size_t i = 0; i < d; i++) { |
63 | | float diff = fabs (x[i] - y[i]); |
64 | | if (diff > vmax) vmax = diff; |
65 | | } |
66 | | return vmax;*/ |
67 | 0 | } |
68 | | |
69 | | template <> |
70 | | inline float VectorDistance<METRIC_Lp>::operator()( |
71 | | const float* x, |
72 | 0 | const float* y) const { |
73 | 0 | float accu = 0; |
74 | 0 | for (size_t i = 0; i < d; i++) { |
75 | 0 | float diff = fabs(x[i] - y[i]); |
76 | 0 | accu += powf(diff, metric_arg); |
77 | 0 | } |
78 | 0 | return accu; |
79 | 0 | } |
80 | | |
81 | | template <> |
82 | | inline float VectorDistance<METRIC_Canberra>::operator()( |
83 | | const float* x, |
84 | 0 | const float* y) const { |
85 | 0 | float accu = 0; |
86 | 0 | for (size_t i = 0; i < d; i++) { |
87 | 0 | float xi = x[i], yi = y[i]; |
88 | 0 | accu += fabs(xi - yi) / (fabs(xi) + fabs(yi)); |
89 | 0 | } |
90 | 0 | return accu; |
91 | 0 | } |
92 | | |
93 | | template <> |
94 | | inline float VectorDistance<METRIC_BrayCurtis>::operator()( |
95 | | const float* x, |
96 | 0 | const float* y) const { |
97 | 0 | float accu_num = 0, accu_den = 0; |
98 | 0 | for (size_t i = 0; i < d; i++) { |
99 | 0 | float xi = x[i], yi = y[i]; |
100 | 0 | accu_num += fabs(xi - yi); |
101 | 0 | accu_den += fabs(xi + yi); |
102 | 0 | } |
103 | 0 | return accu_num / accu_den; |
104 | 0 | } |
105 | | |
106 | | template <> |
107 | | inline float VectorDistance<METRIC_JensenShannon>::operator()( |
108 | | const float* x, |
109 | 0 | const float* y) const { |
110 | 0 | float accu = 0; |
111 | 0 | for (size_t i = 0; i < d; i++) { |
112 | 0 | float xi = x[i], yi = y[i]; |
113 | 0 | float mi = 0.5 * (xi + yi); |
114 | 0 | float kl1 = -xi * log(mi / xi); |
115 | 0 | float kl2 = -yi * log(mi / yi); |
116 | 0 | accu += kl1 + kl2; |
117 | 0 | } |
118 | 0 | return 0.5 * accu; |
119 | 0 | } |
120 | | |
121 | | template <> |
122 | | inline float VectorDistance<METRIC_Jaccard>::operator()( |
123 | | const float* x, |
124 | 0 | const float* y) const { |
125 | | // WARNING: this distance is defined only for positive input vectors. |
126 | | // Providing vectors with negative values would lead to incorrect results. |
127 | 0 | float accu_num = 0, accu_den = 0; |
128 | 0 | for (size_t i = 0; i < d; i++) { |
129 | 0 | accu_num += fmin(x[i], y[i]); |
130 | 0 | accu_den += fmax(x[i], y[i]); |
131 | 0 | } |
132 | 0 | return accu_num / accu_den; |
133 | 0 | } |
134 | | |
135 | | template <> |
136 | | inline float VectorDistance<METRIC_NaNEuclidean>::operator()( |
137 | | const float* x, |
138 | 0 | const float* y) const { |
139 | | // https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.nan_euclidean_distances.html |
140 | 0 | float accu = 0; |
141 | 0 | size_t present = 0; |
142 | 0 | for (size_t i = 0; i < d; i++) { |
143 | 0 | if (!std::isnan(x[i]) && !std::isnan(y[i])) { |
144 | 0 | float diff = x[i] - y[i]; |
145 | 0 | accu += diff * diff; |
146 | 0 | present++; |
147 | 0 | } |
148 | 0 | } |
149 | 0 | if (present == 0) { |
150 | 0 | return NAN; |
151 | 0 | } |
152 | 0 | return float(d) / float(present) * accu; |
153 | 0 | } |
154 | | |
155 | | template <> |
156 | | inline float VectorDistance<METRIC_ABS_INNER_PRODUCT>::operator()( |
157 | | const float* x, |
158 | 0 | const float* y) const { |
159 | 0 | float accu = 0; |
160 | 0 | for (size_t i = 0; i < d; i++) { |
161 | 0 | accu += fabs(x[i] * y[i]); |
162 | 0 | } |
163 | 0 | return accu; |
164 | 0 | } |
165 | | |
166 | | /*************************************************************************** |
167 | | * Dispatching function that takes a metric type and a consumer object |
168 | | * the consumer object should contain a retun type T and a operation template |
169 | | * function f() that is called to perform the operation. The first argument |
170 | | * of the function is the VectorDistance object. The rest are passed in as is. |
171 | | **************************************************************************/ |
172 | | |
173 | | template <class Consumer, class... Types> |
174 | | typename Consumer::T dispatch_VectorDistance( |
175 | | size_t d, |
176 | | MetricType metric, |
177 | | float metric_arg, |
178 | | Consumer& consumer, |
179 | 0 | Types... args) { |
180 | 0 | switch (metric) { |
181 | 0 | #define DISPATCH_VD(mt) \ |
182 | 0 | case mt: { \ |
183 | 0 | VectorDistance<mt> vd = {d, metric_arg}; \ |
184 | 0 | return consumer.template f<VectorDistance<mt>>(vd, args...); \ |
185 | 0 | } |
186 | 0 | DISPATCH_VD(METRIC_INNER_PRODUCT); |
187 | 0 | DISPATCH_VD(METRIC_L2); |
188 | 0 | DISPATCH_VD(METRIC_L1); |
189 | 0 | DISPATCH_VD(METRIC_Linf); |
190 | 0 | DISPATCH_VD(METRIC_Lp); |
191 | 0 | DISPATCH_VD(METRIC_Canberra); |
192 | 0 | DISPATCH_VD(METRIC_BrayCurtis); |
193 | 0 | DISPATCH_VD(METRIC_JensenShannon); |
194 | 0 | DISPATCH_VD(METRIC_Jaccard); |
195 | 0 | DISPATCH_VD(METRIC_NaNEuclidean); |
196 | 0 | DISPATCH_VD(METRIC_ABS_INNER_PRODUCT); |
197 | 0 | default: |
198 | 0 | FAISS_THROW_FMT("Invalid metric %d", metric); |
199 | 0 | } |
200 | 0 | #undef DISPATCH_VD |
201 | 0 | } Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_125Run_get_distance_computerEJPKNS_14IndexFlatCodesEEEENT_1TEmNS_10MetricTypeEfRS6_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_22Top1BlockResultHandlerINS_4CMinIflEELb1EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_22HeapBlockResultHandlerINS_4CMinIflEELb1EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_27ReservoirBlockResultHandlerINS_4CMinIflEELb1EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_22Top1BlockResultHandlerINS_4CMinIflEELb0EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_22HeapBlockResultHandlerINS_4CMinIflEELb0EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_27ReservoirBlockResultHandlerINS_4CMinIflEELb0EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_22Top1BlockResultHandlerINS_4CMaxIflEELb1EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_22HeapBlockResultHandlerINS_4CMaxIflEELb1EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_27ReservoirBlockResultHandlerINS_4CMaxIflEELb1EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_22Top1BlockResultHandlerINS_4CMaxIflEELb0EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_22HeapBlockResultHandlerINS_4CMaxIflEELb0EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_27ReservoirBlockResultHandlerINS_4CMaxIflEELb0EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_29RangeSearchBlockResultHandlerINS_4CMinIflEELb1EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_29RangeSearchBlockResultHandlerINS_4CMinIflEELb0EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb1EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: IndexFlatCodes.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_126Run_search_with_decompressINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb0EEEEEJPKNS_14IndexFlatCodesEPKfS6_EEENT_1TEmNS_10MetricTypeEfRSD_DpT0_ Unexecuted instantiation: extra_distances.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_128Run_pairwise_extra_distancesEJlPKflS4_PflllEEENT_1TEmNS_10MetricTypeEfRS6_DpT0_ Unexecuted instantiation: extra_distances.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_121Run_knn_extra_metricsEJPKfS4_mmmPfPlEEENT_1TEmNS_10MetricTypeEfRS7_DpT0_ Unexecuted instantiation: extra_distances.cpp:_ZN5faiss23dispatch_VectorDistanceINS_12_GLOBAL__N_125Run_get_distance_computerEJPKfmEEENT_1TEmNS_10MetricTypeEfRS5_DpT0_ |
202 | | |
203 | | } // namespace faiss |