be/src/exprs/function/array/function_array_distance.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #pragma once |
19 | | |
20 | | #include <faiss/impl/platform_macros.h> |
21 | | #include <faiss/utils/distances.h> |
22 | | #include <gen_cpp/Types_types.h> |
23 | | |
24 | | #include "common/exception.h" |
25 | | #include "common/status.h" |
26 | | #include "core/assert_cast.h" |
27 | | #include "core/column/column.h" |
28 | | #include "core/column/column_array.h" |
29 | | #include "core/column/column_array_view.h" |
30 | | #include "core/column/column_const.h" |
31 | | #include "core/column/column_nullable.h" |
32 | | #include "core/data_type/data_type.h" |
33 | | #include "core/data_type/data_type_array.h" |
34 | | #include "core/data_type/data_type_nullable.h" |
35 | | #include "core/data_type/data_type_number.h" |
36 | | #include "core/data_type/primitive_type.h" |
37 | | #include "core/types.h" |
38 | | #include "exec/common/util.hpp" |
39 | | #include "exprs/function/function.h" |
40 | | |
41 | | namespace doris { |
42 | | |
43 | | class L1Distance { |
44 | | public: |
45 | | static constexpr auto name = "l1_distance"; |
46 | 0 | static float distance(const float* x, const float* y, size_t d) { |
47 | 0 | return faiss::fvec_L1(x, y, d); |
48 | 0 | } |
49 | | }; |
50 | | |
51 | | class L2Distance { |
52 | | public: |
53 | | static constexpr auto name = "l2_distance"; |
54 | 0 | static float distance(const float* x, const float* y, size_t d) { |
55 | 0 | return std::sqrt(faiss::fvec_L2sqr(x, y, d)); |
56 | 0 | } |
57 | | }; |
58 | | |
59 | | class InnerProduct { |
60 | | public: |
61 | | static constexpr auto name = "inner_product"; |
62 | 0 | static float distance(const float* x, const float* y, size_t d) { |
63 | 0 | return faiss::fvec_inner_product(x, y, d); |
64 | 0 | } |
65 | | }; |
66 | | |
67 | | class CosineDistance { |
68 | | public: |
69 | | static constexpr auto name = "cosine_distance"; |
70 | | static float distance(const float* x, const float* y, size_t d); |
71 | | }; |
72 | | |
73 | | class CosineSimilarity { |
74 | | public: |
75 | | static constexpr auto name = "cosine_similarity"; |
76 | | static float distance(const float* x, const float* y, size_t d); |
77 | | }; |
78 | | |
79 | | class L2DistanceApproximate : public L2Distance { |
80 | | public: |
81 | | static constexpr auto name = "l2_distance_approximate"; |
82 | | }; |
83 | | |
84 | | class InnerProductApproximate : public InnerProduct { |
85 | | public: |
86 | | static constexpr auto name = "inner_product_approximate"; |
87 | | }; |
88 | | |
89 | | template <typename DistanceImpl> |
90 | | class FunctionArrayDistance : public IFunction { |
91 | | public: |
92 | | using DataType = PrimitiveTypeTraits<TYPE_FLOAT>::DataType; |
93 | | using ColumnType = PrimitiveTypeTraits<TYPE_FLOAT>::ColumnType; |
94 | | |
95 | | static constexpr auto name = DistanceImpl::name; |
96 | 27 | String get_name() const override { return name; }_ZNK5doris21FunctionArrayDistanceINS_10L1DistanceEE8get_nameB5cxx11Ev Line | Count | Source | 96 | 1 | String get_name() const override { return name; } |
_ZNK5doris21FunctionArrayDistanceINS_10L2DistanceEE8get_nameB5cxx11Ev Line | Count | Source | 96 | 1 | String get_name() const override { return name; } |
_ZNK5doris21FunctionArrayDistanceINS_14CosineDistanceEE8get_nameB5cxx11Ev Line | Count | Source | 96 | 1 | String get_name() const override { return name; } |
_ZNK5doris21FunctionArrayDistanceINS_16CosineSimilarityEE8get_nameB5cxx11Ev Line | Count | Source | 96 | 21 | String get_name() const override { return name; } |
_ZNK5doris21FunctionArrayDistanceINS_12InnerProductEE8get_nameB5cxx11Ev Line | Count | Source | 96 | 1 | String get_name() const override { return name; } |
_ZNK5doris21FunctionArrayDistanceINS_21L2DistanceApproximateEE8get_nameB5cxx11Ev Line | Count | Source | 96 | 1 | String get_name() const override { return name; } |
_ZNK5doris21FunctionArrayDistanceINS_23InnerProductApproximateEE8get_nameB5cxx11Ev Line | Count | Source | 96 | 1 | String get_name() const override { return name; } |
|
97 | 33 | static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); }_ZN5doris21FunctionArrayDistanceINS_10L1DistanceEE6createEv Line | Count | Source | 97 | 2 | static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); } |
_ZN5doris21FunctionArrayDistanceINS_10L2DistanceEE6createEv Line | Count | Source | 97 | 2 | static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); } |
_ZN5doris21FunctionArrayDistanceINS_14CosineDistanceEE6createEv Line | Count | Source | 97 | 2 | static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); } |
_ZN5doris21FunctionArrayDistanceINS_16CosineSimilarityEE6createEv Line | Count | Source | 97 | 12 | static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); } |
_ZN5doris21FunctionArrayDistanceINS_12InnerProductEE6createEv Line | Count | Source | 97 | 2 | static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); } |
_ZN5doris21FunctionArrayDistanceINS_21L2DistanceApproximateEE6createEv Line | Count | Source | 97 | 11 | static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); } |
_ZN5doris21FunctionArrayDistanceINS_23InnerProductApproximateEE6createEv Line | Count | Source | 97 | 2 | static FunctionPtr create() { return std::make_shared<FunctionArrayDistance<DistanceImpl>>(); } |
|
98 | 19 | size_t get_number_of_arguments() const override { return 2; }Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_10L1DistanceEE23get_number_of_argumentsEv Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_10L2DistanceEE23get_number_of_argumentsEv Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_14CosineDistanceEE23get_number_of_argumentsEv _ZNK5doris21FunctionArrayDistanceINS_16CosineSimilarityEE23get_number_of_argumentsEv Line | Count | Source | 98 | 10 | size_t get_number_of_arguments() const override { return 2; } |
Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_12InnerProductEE23get_number_of_argumentsEv _ZNK5doris21FunctionArrayDistanceINS_21L2DistanceApproximateEE23get_number_of_argumentsEv Line | Count | Source | 98 | 9 | size_t get_number_of_arguments() const override { return 2; } |
Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_23InnerProductApproximateEE23get_number_of_argumentsEv |
99 | | |
100 | 19 | DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { |
101 | 19 | if (arguments.size() != 2) { |
102 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, "Invalid number of arguments"); |
103 | 0 | } |
104 | | |
105 | | // primitive_type of Nullable is its nested type. |
106 | 19 | if (arguments[0]->get_primitive_type() != TYPE_ARRAY || |
107 | 19 | arguments[1]->get_primitive_type() != TYPE_ARRAY) { |
108 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
109 | 0 | "Arguments for function {} must be arrays", get_name()); |
110 | 0 | } |
111 | | |
112 | 19 | return std::make_shared<DataTypeFloat32>(); |
113 | 19 | } Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_10L1DistanceEE20get_return_type_implERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS7_EE Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_10L2DistanceEE20get_return_type_implERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS7_EE Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_14CosineDistanceEE20get_return_type_implERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS7_EE _ZNK5doris21FunctionArrayDistanceINS_16CosineSimilarityEE20get_return_type_implERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS7_EE Line | Count | Source | 100 | 10 | DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { | 101 | 10 | if (arguments.size() != 2) { | 102 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, "Invalid number of arguments"); | 103 | 0 | } | 104 | | | 105 | | // primitive_type of Nullable is its nested type. | 106 | 10 | if (arguments[0]->get_primitive_type() != TYPE_ARRAY || | 107 | 10 | arguments[1]->get_primitive_type() != TYPE_ARRAY) { | 108 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, | 109 | 0 | "Arguments for function {} must be arrays", get_name()); | 110 | 0 | } | 111 | | | 112 | 10 | return std::make_shared<DataTypeFloat32>(); | 113 | 10 | } |
Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_12InnerProductEE20get_return_type_implERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS7_EE _ZNK5doris21FunctionArrayDistanceINS_21L2DistanceApproximateEE20get_return_type_implERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS7_EE Line | Count | Source | 100 | 9 | DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { | 101 | 9 | if (arguments.size() != 2) { | 102 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, "Invalid number of arguments"); | 103 | 0 | } | 104 | | | 105 | | // primitive_type of Nullable is its nested type. | 106 | 9 | if (arguments[0]->get_primitive_type() != TYPE_ARRAY || | 107 | 9 | arguments[1]->get_primitive_type() != TYPE_ARRAY) { | 108 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, | 109 | 0 | "Arguments for function {} must be arrays", get_name()); | 110 | 0 | } | 111 | | | 112 | 9 | return std::make_shared<DataTypeFloat32>(); | 113 | 9 | } |
Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_23InnerProductApproximateEE20get_return_type_implERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS7_EE |
114 | | |
115 | | // All array distance functions has always not nullable return type. |
116 | | // We want to make sure throw exception if input columns contain NULL. |
117 | 29 | bool use_default_implementation_for_nulls() const override { return false; }Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_10L1DistanceEE36use_default_implementation_for_nullsEv Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_10L2DistanceEE36use_default_implementation_for_nullsEv Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_14CosineDistanceEE36use_default_implementation_for_nullsEv _ZNK5doris21FunctionArrayDistanceINS_16CosineSimilarityEE36use_default_implementation_for_nullsEv Line | Count | Source | 117 | 20 | bool use_default_implementation_for_nulls() const override { return false; } |
Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_12InnerProductEE36use_default_implementation_for_nullsEv _ZNK5doris21FunctionArrayDistanceINS_21L2DistanceApproximateEE36use_default_implementation_for_nullsEv Line | Count | Source | 117 | 9 | bool use_default_implementation_for_nulls() const override { return false; } |
Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_23InnerProductApproximateEE36use_default_implementation_for_nullsEv |
118 | | |
119 | | // Validate that neither outer column nor inner array elements contain NULL. |
120 | | // Distance functions always throw on NULL input. |
121 | | static void _validate_no_nulls(const ColumnPtr& col, const char* arg_name, |
122 | 20 | const String& func_name) { |
123 | 20 | const IColumn* raw = col.get(); |
124 | | |
125 | | // Unwrap const |
126 | 20 | if (is_column_const(*raw)) { |
127 | 0 | raw = assert_cast<const ColumnConst*>(raw)->get_data_column_ptr().get(); |
128 | 0 | } |
129 | | |
130 | | // Check outer nullable |
131 | 20 | if (raw->is_nullable()) { |
132 | 20 | if (raw->has_null()) { |
133 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
134 | 0 | "{} for function {} cannot be null", arg_name, func_name); |
135 | 0 | } |
136 | 20 | raw = assert_cast<const ColumnNullable*>(raw)->get_nested_column_ptr().get(); |
137 | 20 | } |
138 | | |
139 | | // Check inner nullable (array elements) |
140 | 20 | const auto& array_col = assert_cast<const ColumnArray&>(*raw); |
141 | 20 | if (array_col.get_data_ptr()->is_nullable() && array_col.get_data_ptr()->has_null()) { |
142 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
143 | 0 | "{} for function {} cannot have null", arg_name, func_name); |
144 | 0 | } |
145 | 20 | } Unexecuted instantiation: _ZN5doris21FunctionArrayDistanceINS_10L1DistanceEE18_validate_no_nullsERKNS_3COWINS_7IColumnEE13immutable_ptrIS4_EEPKcRKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE Unexecuted instantiation: _ZN5doris21FunctionArrayDistanceINS_10L2DistanceEE18_validate_no_nullsERKNS_3COWINS_7IColumnEE13immutable_ptrIS4_EEPKcRKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE Unexecuted instantiation: _ZN5doris21FunctionArrayDistanceINS_14CosineDistanceEE18_validate_no_nullsERKNS_3COWINS_7IColumnEE13immutable_ptrIS4_EEPKcRKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE _ZN5doris21FunctionArrayDistanceINS_16CosineSimilarityEE18_validate_no_nullsERKNS_3COWINS_7IColumnEE13immutable_ptrIS4_EEPKcRKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE Line | Count | Source | 122 | 20 | const String& func_name) { | 123 | 20 | const IColumn* raw = col.get(); | 124 | | | 125 | | // Unwrap const | 126 | 20 | if (is_column_const(*raw)) { | 127 | 0 | raw = assert_cast<const ColumnConst*>(raw)->get_data_column_ptr().get(); | 128 | 0 | } | 129 | | | 130 | | // Check outer nullable | 131 | 20 | if (raw->is_nullable()) { | 132 | 20 | if (raw->has_null()) { | 133 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, | 134 | 0 | "{} for function {} cannot be null", arg_name, func_name); | 135 | 0 | } | 136 | 20 | raw = assert_cast<const ColumnNullable*>(raw)->get_nested_column_ptr().get(); | 137 | 20 | } | 138 | | | 139 | | // Check inner nullable (array elements) | 140 | 20 | const auto& array_col = assert_cast<const ColumnArray&>(*raw); | 141 | 20 | if (array_col.get_data_ptr()->is_nullable() && array_col.get_data_ptr()->has_null()) { | 142 | 0 | throw doris::Exception(ErrorCode::INVALID_ARGUMENT, | 143 | 0 | "{} for function {} cannot have null", arg_name, func_name); | 144 | 0 | } | 145 | 20 | } |
Unexecuted instantiation: _ZN5doris21FunctionArrayDistanceINS_12InnerProductEE18_validate_no_nullsERKNS_3COWINS_7IColumnEE13immutable_ptrIS4_EEPKcRKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE Unexecuted instantiation: _ZN5doris21FunctionArrayDistanceINS_21L2DistanceApproximateEE18_validate_no_nullsERKNS_3COWINS_7IColumnEE13immutable_ptrIS4_EEPKcRKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE Unexecuted instantiation: _ZN5doris21FunctionArrayDistanceINS_23InnerProductApproximateEE18_validate_no_nullsERKNS_3COWINS_7IColumnEE13immutable_ptrIS4_EEPKcRKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE |
146 | | |
147 | | Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, |
148 | 10 | uint32_t result, size_t input_rows_count) const override { |
149 | 10 | const auto& col1 = block.get_by_position(arguments[0]).column; |
150 | 10 | const auto& col2 = block.get_by_position(arguments[1]).column; |
151 | | |
152 | | // Validate no NULLs (distance functions always throw on NULL input) |
153 | 10 | _validate_no_nulls(col1, "First argument", get_name()); |
154 | 10 | _validate_no_nulls(col2, "Second argument", get_name()); |
155 | | |
156 | | // Create views — handles Const/Nullable unwrapping automatically |
157 | 10 | auto view1 = ColumnArrayView<TYPE_FLOAT>::create(col1); |
158 | 10 | auto view2 = ColumnArrayView<TYPE_FLOAT>::create(col2); |
159 | | |
160 | 10 | auto dst = ColumnType::create(input_rows_count); |
161 | 10 | auto& dst_data = dst->get_data(); |
162 | | |
163 | 22 | for (size_t row = 0; row < input_rows_count; ++row) { |
164 | 12 | auto a1 = view1[row]; |
165 | 12 | auto a2 = view2[row]; |
166 | 12 | const float* p1 = a1.get_data(); |
167 | 12 | const float* p2 = a2.get_data(); |
168 | 12 | auto dim1 = a1.size(); |
169 | 12 | auto dim2 = a2.size(); |
170 | | |
171 | 12 | if (dim1 != dim2) [[unlikely]] { |
172 | 0 | return Status::InvalidArgument( |
173 | 0 | "function {} have different input element sizes of array: {} and {}", |
174 | 0 | get_name(), dim1, dim2); |
175 | 0 | } |
176 | 12 | dst_data[row] = DistanceImpl::distance(p1, p2, dim1); |
177 | 12 | } |
178 | | |
179 | 10 | block.replace_by_position(result, std::move(dst)); |
180 | 10 | return Status::OK(); |
181 | 10 | } Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_10L1DistanceEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_10L2DistanceEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_14CosineDistanceEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm _ZNK5doris21FunctionArrayDistanceINS_16CosineSimilarityEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm Line | Count | Source | 148 | 10 | uint32_t result, size_t input_rows_count) const override { | 149 | 10 | const auto& col1 = block.get_by_position(arguments[0]).column; | 150 | 10 | const auto& col2 = block.get_by_position(arguments[1]).column; | 151 | | | 152 | | // Validate no NULLs (distance functions always throw on NULL input) | 153 | 10 | _validate_no_nulls(col1, "First argument", get_name()); | 154 | 10 | _validate_no_nulls(col2, "Second argument", get_name()); | 155 | | | 156 | | // Create views — handles Const/Nullable unwrapping automatically | 157 | 10 | auto view1 = ColumnArrayView<TYPE_FLOAT>::create(col1); | 158 | 10 | auto view2 = ColumnArrayView<TYPE_FLOAT>::create(col2); | 159 | | | 160 | 10 | auto dst = ColumnType::create(input_rows_count); | 161 | 10 | auto& dst_data = dst->get_data(); | 162 | | | 163 | 22 | for (size_t row = 0; row < input_rows_count; ++row) { | 164 | 12 | auto a1 = view1[row]; | 165 | 12 | auto a2 = view2[row]; | 166 | 12 | const float* p1 = a1.get_data(); | 167 | 12 | const float* p2 = a2.get_data(); | 168 | 12 | auto dim1 = a1.size(); | 169 | 12 | auto dim2 = a2.size(); | 170 | | | 171 | 12 | if (dim1 != dim2) [[unlikely]] { | 172 | 0 | return Status::InvalidArgument( | 173 | 0 | "function {} have different input element sizes of array: {} and {}", | 174 | 0 | get_name(), dim1, dim2); | 175 | 0 | } | 176 | 12 | dst_data[row] = DistanceImpl::distance(p1, p2, dim1); | 177 | 12 | } | 178 | | | 179 | 10 | block.replace_by_position(result, std::move(dst)); | 180 | 10 | return Status::OK(); | 181 | 10 | } |
Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_12InnerProductEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_21L2DistanceApproximateEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm Unexecuted instantiation: _ZNK5doris21FunctionArrayDistanceINS_23InnerProductApproximateEE12execute_implEPNS_15FunctionContextERNS_5BlockERKSt6vectorIjSaIjEEjm |
182 | | }; |
183 | | |
184 | | } // namespace doris |