/root/doris/be/src/olap/collection_similarity.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "olap/collection_similarity.h" |
19 | | |
20 | | #include "vec/columns/column_nullable.h" |
21 | | #include "vec/columns/column_vector.h" |
22 | | |
23 | | namespace doris { |
24 | | #include "common/compile_check_begin.h" |
25 | | |
26 | 671k | void CollectionSimilarity::collect(segment_v2::rowid_t row_id, float score) { |
27 | 671k | _bm25_scores[row_id] += score; |
28 | 671k | } |
29 | | |
30 | | void CollectionSimilarity::get_bm25_scores(roaring::Roaring* row_bitmap, |
31 | | vectorized::IColumn::MutablePtr& scores, |
32 | | std::unique_ptr<std::vector<uint64_t>>& row_ids, |
33 | 12 | const ScoreRangeFilterPtr& filter) const { |
34 | 12 | std::vector<float> filtered_scores; |
35 | 12 | filtered_scores.reserve(row_bitmap->cardinality()); |
36 | | |
37 | 12 | roaring::Roaring new_bitmap; |
38 | | |
39 | 330k | for (uint32_t row_id : *row_bitmap) { |
40 | 330k | auto it = _bm25_scores.find(row_id); |
41 | 330k | float score = (it != _bm25_scores.end()) ? it->second : 0.0F; |
42 | 330k | if (filter && !filter->pass(score)) { |
43 | 110k | continue; |
44 | 110k | } |
45 | 220k | row_ids->push_back(row_id); |
46 | 220k | filtered_scores.push_back(score); |
47 | 220k | new_bitmap.add(row_id); |
48 | 220k | } |
49 | | |
50 | 12 | size_t num_results = row_ids->size(); |
51 | 12 | auto score_column = vectorized::ColumnFloat32::create(num_results); |
52 | 12 | if (num_results > 0) { |
53 | 11 | memcpy(score_column->get_data().data(), filtered_scores.data(), |
54 | 11 | num_results * sizeof(float)); |
55 | 11 | } |
56 | | |
57 | 12 | *row_bitmap = std::move(new_bitmap); |
58 | 12 | auto null_map = vectorized::ColumnUInt8::create(num_results, 0); |
59 | 12 | scores = vectorized::ColumnNullable::create(std::move(score_column), std::move(null_map)); |
60 | 12 | } |
61 | | |
62 | | void CollectionSimilarity::get_topn_bm25_scores(roaring::Roaring* row_bitmap, |
63 | | vectorized::IColumn::MutablePtr& scores, |
64 | | std::unique_ptr<std::vector<uint64_t>>& row_ids, |
65 | | OrderType order_type, size_t top_k, |
66 | 16 | const ScoreRangeFilterPtr& filter) const { |
67 | 16 | std::vector<std::pair<uint32_t, float>> top_k_results; |
68 | | |
69 | 16 | if (order_type == OrderType::DESC) { |
70 | 12 | find_top_k_scores<OrderType::DESC>(row_bitmap, _bm25_scores, top_k, top_k_results, filter); |
71 | 12 | } else { |
72 | 4 | find_top_k_scores<OrderType::ASC>(row_bitmap, _bm25_scores, top_k, top_k_results, filter); |
73 | 4 | } |
74 | | |
75 | 16 | size_t num_results = top_k_results.size(); |
76 | 16 | auto score_column = vectorized::ColumnFloat32::create(num_results); |
77 | 16 | auto& score_data = score_column->get_data(); |
78 | | |
79 | 16 | row_ids->resize(num_results); |
80 | 16 | roaring::Roaring new_bitmap; |
81 | | |
82 | 1.39k | for (size_t i = 0; i < num_results; ++i) { |
83 | 1.37k | (*row_ids)[i] = top_k_results[i].first; |
84 | 1.37k | score_data[i] = top_k_results[i].second; |
85 | 1.37k | new_bitmap.add(top_k_results[i].first); |
86 | 1.37k | } |
87 | | |
88 | 16 | *row_bitmap = std::move(new_bitmap); |
89 | 16 | auto null_map = vectorized::ColumnUInt8::create(num_results, 0); |
90 | 16 | scores = vectorized::ColumnNullable::create(std::move(score_column), std::move(null_map)); |
91 | 16 | } |
92 | | |
93 | | template <OrderType order> |
94 | | void CollectionSimilarity::find_top_k_scores(const roaring::Roaring* row_bitmap, |
95 | | const ScoreMap& all_scores, size_t top_k, |
96 | | std::vector<std::pair<uint32_t, float>>& top_k_results, |
97 | 16 | const ScoreRangeFilterPtr& filter) const { |
98 | 16 | if (top_k <= 0) { |
99 | 2 | return; |
100 | 2 | } |
101 | | |
102 | 1.14M | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { |
103 | 1.14M | if constexpr (order == OrderType::DESC) { |
104 | 1.04M | return a.second > b.second; |
105 | 1.04M | } else { |
106 | 100k | return a.second < b.second; |
107 | 100k | } |
108 | 1.14M | }; _ZZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE1EEEvPKN7roaring7RoaringERKN5phmap13flat_hash_mapIjfNS7_4HashIjEENS7_7EqualToIjEESaISt4pairIKjfEEEEmRSt6vectorISD_IjfESaISL_EERKSt10shared_ptrINS_16ScoreRangeFilterEEENKUlRKSL_SV_E_clESV_SV_ Line | Count | Source | 102 | 1.04M | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { | 103 | 1.04M | if constexpr (order == OrderType::DESC) { | 104 | 1.04M | return a.second > b.second; | 105 | | } else { | 106 | | return a.second < b.second; | 107 | | } | 108 | 1.04M | }; |
_ZZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE0EEEvPKN7roaring7RoaringERKN5phmap13flat_hash_mapIjfNS7_4HashIjEENS7_7EqualToIjEESaISt4pairIKjfEEEEmRSt6vectorISD_IjfESaISL_EERKSt10shared_ptrINS_16ScoreRangeFilterEEENKUlRKSL_SV_E_clESV_SV_ Line | Count | Source | 102 | 100k | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { | 103 | | if constexpr (order == OrderType::DESC) { | 104 | | return a.second > b.second; | 105 | 100k | } else { | 106 | 100k | return a.second < b.second; | 107 | 100k | } | 108 | 100k | }; |
|
109 | | |
110 | 14 | std::priority_queue<std::pair<uint32_t, float>, std::vector<std::pair<uint32_t, float>>, |
111 | 14 | decltype(pair_comp)> |
112 | 14 | top_k_heap(pair_comp); |
113 | | |
114 | 14 | std::vector<uint32_t> zero_score_ids; |
115 | | |
116 | 410k | for (uint32_t row_id : *row_bitmap) { |
117 | 410k | auto it = all_scores.find(row_id); |
118 | 410k | float score = (it != all_scores.end()) ? it->second : 0.0F; |
119 | | |
120 | 410k | if (filter && !filter->pass(score)) { |
121 | 188k | continue; |
122 | 188k | } |
123 | | |
124 | 222k | if (score == 0.0F) { |
125 | 6 | zero_score_ids.push_back(row_id); |
126 | 6 | continue; |
127 | 6 | } |
128 | | |
129 | 222k | if (top_k_heap.size() < top_k) { |
130 | 1.37k | top_k_heap.emplace(row_id, score); |
131 | 220k | } else if (pair_comp({row_id, score}, top_k_heap.top())) { |
132 | 120k | top_k_heap.pop(); |
133 | 120k | top_k_heap.emplace(row_id, score); |
134 | 120k | } |
135 | 222k | } |
136 | | |
137 | 14 | top_k_results.reserve(top_k); |
138 | 1.38k | while (!top_k_heap.empty()) { |
139 | 1.37k | top_k_results.push_back(top_k_heap.top()); |
140 | 1.37k | top_k_heap.pop(); |
141 | 1.37k | } |
142 | 14 | std::ranges::reverse(top_k_results); |
143 | | |
144 | 14 | if constexpr (order == OrderType::DESC) { |
145 | | // DESC: high scores first, then zeros at the end |
146 | 10 | size_t remaining = top_k - top_k_results.size(); |
147 | 13 | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { |
148 | 3 | top_k_results.emplace_back(zero_score_ids[i], 0.0F); |
149 | 3 | } |
150 | 10 | } else { |
151 | | // ASC: zeros first, then low scores |
152 | 4 | std::vector<std::pair<uint32_t, float>> final_results; |
153 | 4 | final_results.reserve(top_k); |
154 | | |
155 | 4 | size_t zero_count = std::min(top_k, zero_score_ids.size()); |
156 | 6 | for (size_t i = 0; i < zero_count; ++i) { |
157 | 2 | final_results.emplace_back(zero_score_ids[i], 0.0F); |
158 | 2 | } |
159 | | |
160 | 4 | size_t remaining = top_k - final_results.size(); |
161 | 111 | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { |
162 | 107 | final_results.push_back(top_k_results[i]); |
163 | 107 | } |
164 | | |
165 | 4 | top_k_results = std::move(final_results); |
166 | 4 | } |
167 | 14 | } _ZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE1EEEvPKN7roaring7RoaringERKN5phmap13flat_hash_mapIjfNS7_4HashIjEENS7_7EqualToIjEESaISt4pairIKjfEEEEmRSt6vectorISD_IjfESaISL_EERKSt10shared_ptrINS_16ScoreRangeFilterEE Line | Count | Source | 97 | 12 | const ScoreRangeFilterPtr& filter) const { | 98 | 12 | if (top_k <= 0) { | 99 | 2 | return; | 100 | 2 | } | 101 | | | 102 | 10 | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { | 103 | 10 | if constexpr (order == OrderType::DESC) { | 104 | 10 | return a.second > b.second; | 105 | 10 | } else { | 106 | 10 | return a.second < b.second; | 107 | 10 | } | 108 | 10 | }; | 109 | | | 110 | 10 | std::priority_queue<std::pair<uint32_t, float>, std::vector<std::pair<uint32_t, float>>, | 111 | 10 | decltype(pair_comp)> | 112 | 10 | top_k_heap(pair_comp); | 113 | | | 114 | 10 | std::vector<uint32_t> zero_score_ids; | 115 | | | 116 | 310k | for (uint32_t row_id : *row_bitmap) { | 117 | 310k | auto it = all_scores.find(row_id); | 118 | 310k | float score = (it != all_scores.end()) ? it->second : 0.0F; | 119 | | | 120 | 310k | if (filter && !filter->pass(score)) { | 121 | 188k | continue; | 122 | 188k | } | 123 | | | 124 | 122k | if (score == 0.0F) { | 125 | 4 | zero_score_ids.push_back(row_id); | 126 | 4 | continue; | 127 | 4 | } | 128 | | | 129 | 122k | if (top_k_heap.size() < top_k) { | 130 | 1.26k | top_k_heap.emplace(row_id, score); | 131 | 120k | } else if (pair_comp({row_id, score}, top_k_heap.top())) { | 132 | 120k | top_k_heap.pop(); | 133 | 120k | top_k_heap.emplace(row_id, score); | 134 | 120k | } | 135 | 122k | } | 136 | | | 137 | 10 | top_k_results.reserve(top_k); | 138 | 1.27k | while (!top_k_heap.empty()) { | 139 | 1.26k | top_k_results.push_back(top_k_heap.top()); | 140 | 1.26k | top_k_heap.pop(); | 141 | 1.26k | } | 142 | 10 | std::ranges::reverse(top_k_results); | 143 | | | 144 | 10 | if constexpr (order == OrderType::DESC) { | 145 | | // DESC: high scores first, then zeros at the end | 146 | 10 | size_t remaining = top_k - top_k_results.size(); | 147 | 13 | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { | 148 | 3 | top_k_results.emplace_back(zero_score_ids[i], 0.0F); | 149 | 3 | } | 150 | | } else { | 151 | | // ASC: zeros first, then low scores | 152 | | std::vector<std::pair<uint32_t, float>> final_results; | 153 | | final_results.reserve(top_k); | 154 | | | 155 | | size_t zero_count = std::min(top_k, zero_score_ids.size()); | 156 | | for (size_t i = 0; i < zero_count; ++i) { | 157 | | final_results.emplace_back(zero_score_ids[i], 0.0F); | 158 | | } | 159 | | | 160 | | size_t remaining = top_k - final_results.size(); | 161 | | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { | 162 | | final_results.push_back(top_k_results[i]); | 163 | | } | 164 | | | 165 | | top_k_results = std::move(final_results); | 166 | | } | 167 | 10 | } |
_ZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE0EEEvPKN7roaring7RoaringERKN5phmap13flat_hash_mapIjfNS7_4HashIjEENS7_7EqualToIjEESaISt4pairIKjfEEEEmRSt6vectorISD_IjfESaISL_EERKSt10shared_ptrINS_16ScoreRangeFilterEE Line | Count | Source | 97 | 4 | const ScoreRangeFilterPtr& filter) const { | 98 | 4 | if (top_k <= 0) { | 99 | 0 | return; | 100 | 0 | } | 101 | | | 102 | 4 | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { | 103 | 4 | if constexpr (order == OrderType::DESC) { | 104 | 4 | return a.second > b.second; | 105 | 4 | } else { | 106 | 4 | return a.second < b.second; | 107 | 4 | } | 108 | 4 | }; | 109 | | | 110 | 4 | std::priority_queue<std::pair<uint32_t, float>, std::vector<std::pair<uint32_t, float>>, | 111 | 4 | decltype(pair_comp)> | 112 | 4 | top_k_heap(pair_comp); | 113 | | | 114 | 4 | std::vector<uint32_t> zero_score_ids; | 115 | | | 116 | 100k | for (uint32_t row_id : *row_bitmap) { | 117 | 100k | auto it = all_scores.find(row_id); | 118 | 100k | float score = (it != all_scores.end()) ? it->second : 0.0F; | 119 | | | 120 | 100k | if (filter && !filter->pass(score)) { | 121 | 2 | continue; | 122 | 2 | } | 123 | | | 124 | 100k | if (score == 0.0F) { | 125 | 2 | zero_score_ids.push_back(row_id); | 126 | 2 | continue; | 127 | 2 | } | 128 | | | 129 | 100k | if (top_k_heap.size() < top_k) { | 130 | 109 | top_k_heap.emplace(row_id, score); | 131 | 99.9k | } else if (pair_comp({row_id, score}, top_k_heap.top())) { | 132 | 1 | top_k_heap.pop(); | 133 | 1 | top_k_heap.emplace(row_id, score); | 134 | 1 | } | 135 | 100k | } | 136 | | | 137 | 4 | top_k_results.reserve(top_k); | 138 | 113 | while (!top_k_heap.empty()) { | 139 | 109 | top_k_results.push_back(top_k_heap.top()); | 140 | 109 | top_k_heap.pop(); | 141 | 109 | } | 142 | 4 | std::ranges::reverse(top_k_results); | 143 | | | 144 | | if constexpr (order == OrderType::DESC) { | 145 | | // DESC: high scores first, then zeros at the end | 146 | | size_t remaining = top_k - top_k_results.size(); | 147 | | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { | 148 | | top_k_results.emplace_back(zero_score_ids[i], 0.0F); | 149 | | } | 150 | 4 | } else { | 151 | | // ASC: zeros first, then low scores | 152 | 4 | std::vector<std::pair<uint32_t, float>> final_results; | 153 | 4 | final_results.reserve(top_k); | 154 | | | 155 | 4 | size_t zero_count = std::min(top_k, zero_score_ids.size()); | 156 | 6 | for (size_t i = 0; i < zero_count; ++i) { | 157 | 2 | final_results.emplace_back(zero_score_ids[i], 0.0F); | 158 | 2 | } | 159 | | | 160 | 4 | size_t remaining = top_k - final_results.size(); | 161 | 111 | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { | 162 | 107 | final_results.push_back(top_k_results[i]); | 163 | 107 | } | 164 | | | 165 | 4 | top_k_results = std::move(final_results); | 166 | 4 | } | 167 | 4 | } |
|
168 | | |
169 | | #include "common/compile_check_end.h" |
170 | | } // namespace doris |