be/src/storage/compaction/collection_similarity.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "storage/compaction/collection_similarity.h" |
19 | | |
20 | | #include "core/column/column_nullable.h" |
21 | | #include "core/column/column_vector.h" |
22 | | |
23 | | namespace doris { |
24 | | |
25 | 673k | void CollectionSimilarity::collect(segment_v2::rowid_t row_id, float score) { |
26 | 673k | _bm25_scores[row_id] += score; |
27 | 673k | } |
28 | | |
29 | | void CollectionSimilarity::get_bm25_scores(roaring::Roaring* row_bitmap, |
30 | | IColumn::MutablePtr& scores, |
31 | | std::unique_ptr<std::vector<uint64_t>>& row_ids, |
32 | 12 | const ScoreRangeFilterPtr& filter) const { |
33 | 12 | std::vector<float> filtered_scores; |
34 | 12 | filtered_scores.reserve(row_bitmap->cardinality()); |
35 | | |
36 | 12 | roaring::Roaring new_bitmap; |
37 | | |
38 | 330k | for (uint32_t row_id : *row_bitmap) { |
39 | 330k | auto it = _bm25_scores.find(row_id); |
40 | 330k | float score = (it != _bm25_scores.end()) ? it->second : 0.0F; |
41 | 330k | if (filter && !filter->pass(score)) { |
42 | 110k | continue; |
43 | 110k | } |
44 | 220k | row_ids->push_back(row_id); |
45 | 220k | filtered_scores.push_back(score); |
46 | 220k | new_bitmap.add(row_id); |
47 | 220k | } |
48 | | |
49 | 12 | size_t num_results = row_ids->size(); |
50 | 12 | auto score_column = ColumnFloat32::create(num_results); |
51 | 12 | if (num_results > 0) { |
52 | 11 | memcpy(score_column->get_data().data(), filtered_scores.data(), |
53 | 11 | num_results * sizeof(float)); |
54 | 11 | } |
55 | | |
56 | 12 | *row_bitmap = std::move(new_bitmap); |
57 | 12 | auto null_map = ColumnUInt8::create(num_results, 0); |
58 | 12 | scores = ColumnNullable::create(std::move(score_column), std::move(null_map)); |
59 | 12 | } |
60 | | |
61 | | void CollectionSimilarity::get_topn_bm25_scores(roaring::Roaring* row_bitmap, |
62 | | IColumn::MutablePtr& scores, |
63 | | std::unique_ptr<std::vector<uint64_t>>& row_ids, |
64 | | OrderType order_type, size_t top_k, |
65 | 51 | const ScoreRangeFilterPtr& filter) const { |
66 | 51 | std::vector<std::pair<uint32_t, float>> top_k_results; |
67 | | |
68 | 51 | if (order_type == OrderType::DESC) { |
69 | 37 | find_top_k_scores<OrderType::DESC>(row_bitmap, _bm25_scores, top_k, top_k_results, filter); |
70 | 37 | } else { |
71 | 14 | find_top_k_scores<OrderType::ASC>(row_bitmap, _bm25_scores, top_k, top_k_results, filter); |
72 | 14 | } |
73 | | |
74 | 51 | size_t num_results = top_k_results.size(); |
75 | 51 | auto score_column = ColumnFloat32::create(num_results); |
76 | 51 | auto& score_data = score_column->get_data(); |
77 | | |
78 | 51 | row_ids->resize(num_results); |
79 | 51 | roaring::Roaring new_bitmap; |
80 | | |
81 | 1.47k | for (size_t i = 0; i < num_results; ++i) { |
82 | 1.42k | (*row_ids)[i] = top_k_results[i].first; |
83 | 1.42k | score_data[i] = top_k_results[i].second; |
84 | 1.42k | new_bitmap.add(top_k_results[i].first); |
85 | 1.42k | } |
86 | | |
87 | 51 | *row_bitmap = std::move(new_bitmap); |
88 | 51 | auto null_map = ColumnUInt8::create(num_results, 0); |
89 | 51 | scores = ColumnNullable::create(std::move(score_column), std::move(null_map)); |
90 | 51 | } |
91 | | |
92 | | template <OrderType order> |
93 | | void CollectionSimilarity::find_top_k_scores(const roaring::Roaring* row_bitmap, |
94 | | const ScoreMap& all_scores, size_t top_k, |
95 | | std::vector<std::pair<uint32_t, float>>& top_k_results, |
96 | 51 | const ScoreRangeFilterPtr& filter) const { |
97 | 51 | if (top_k <= 0) { |
98 | 2 | return; |
99 | 2 | } |
100 | | |
101 | 1.14M | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { |
102 | 1.14M | if constexpr (order == OrderType::DESC) { |
103 | 1.04M | return a.second > b.second; |
104 | 1.04M | } else { |
105 | 102k | return a.second < b.second; |
106 | 102k | } |
107 | 1.14M | }; _ZZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE1EEEvPKN7roaring7RoaringERKN5phmap13flat_hash_mapIjfNS7_4HashIjEENS7_7EqualToIjEESaISt4pairIKjfEEEEmRSt6vectorISD_IjfESaISL_EERKSt10shared_ptrINS_16ScoreRangeFilterEEENKUlRKSL_SV_E_clESV_SV_ Line | Count | Source | 101 | 1.04M | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { | 102 | 1.04M | if constexpr (order == OrderType::DESC) { | 103 | 1.04M | return a.second > b.second; | 104 | | } else { | 105 | | return a.second < b.second; | 106 | | } | 107 | 1.04M | }; |
_ZZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE0EEEvPKN7roaring7RoaringERKN5phmap13flat_hash_mapIjfNS7_4HashIjEENS7_7EqualToIjEESaISt4pairIKjfEEEEmRSt6vectorISD_IjfESaISL_EERKSt10shared_ptrINS_16ScoreRangeFilterEEENKUlRKSL_SV_E_clESV_SV_ Line | Count | Source | 101 | 102k | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { | 102 | | if constexpr (order == OrderType::DESC) { | 103 | | return a.second > b.second; | 104 | 102k | } else { | 105 | 102k | return a.second < b.second; | 106 | 102k | } | 107 | 102k | }; |
|
108 | | |
109 | 49 | std::priority_queue<std::pair<uint32_t, float>, std::vector<std::pair<uint32_t, float>>, |
110 | 49 | decltype(pair_comp)> |
111 | 49 | top_k_heap(pair_comp); |
112 | | |
113 | 49 | std::vector<uint32_t> zero_score_ids; |
114 | | |
115 | 412k | for (uint32_t row_id : *row_bitmap) { |
116 | 412k | auto it = all_scores.find(row_id); |
117 | 412k | float score = (it != all_scores.end()) ? it->second : 0.0F; |
118 | | |
119 | 412k | if (filter && !filter->pass(score)) { |
120 | 188k | continue; |
121 | 188k | } |
122 | | |
123 | 223k | if (score == 0.0F) { |
124 | 6 | zero_score_ids.push_back(row_id); |
125 | 6 | continue; |
126 | 6 | } |
127 | | |
128 | 223k | if (top_k_heap.size() < top_k) { |
129 | 1.42k | top_k_heap.emplace(row_id, score); |
130 | 222k | } else if (pair_comp({row_id, score}, top_k_heap.top())) { |
131 | 120k | top_k_heap.pop(); |
132 | 120k | top_k_heap.emplace(row_id, score); |
133 | 120k | } |
134 | 223k | } |
135 | | |
136 | 49 | top_k_results.reserve(top_k); |
137 | 1.46k | while (!top_k_heap.empty()) { |
138 | 1.42k | top_k_results.push_back(top_k_heap.top()); |
139 | 1.42k | top_k_heap.pop(); |
140 | 1.42k | } |
141 | 49 | std::ranges::reverse(top_k_results); |
142 | | |
143 | 49 | if constexpr (order == OrderType::DESC) { |
144 | | // DESC: high scores first, then zeros at the end |
145 | 35 | size_t remaining = top_k - top_k_results.size(); |
146 | 38 | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { |
147 | 3 | top_k_results.emplace_back(zero_score_ids[i], 0.0F); |
148 | 3 | } |
149 | 35 | } else { |
150 | | // ASC: zeros first, then low scores |
151 | 14 | std::vector<std::pair<uint32_t, float>> final_results; |
152 | 14 | final_results.reserve(top_k); |
153 | | |
154 | 14 | size_t zero_count = std::min(top_k, zero_score_ids.size()); |
155 | 16 | for (size_t i = 0; i < zero_count; ++i) { |
156 | 2 | final_results.emplace_back(zero_score_ids[i], 0.0F); |
157 | 2 | } |
158 | | |
159 | 14 | size_t remaining = top_k - final_results.size(); |
160 | 155 | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { |
161 | 141 | final_results.push_back(top_k_results[i]); |
162 | 141 | } |
163 | | |
164 | 14 | top_k_results = std::move(final_results); |
165 | 14 | } |
166 | 49 | } _ZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE1EEEvPKN7roaring7RoaringERKN5phmap13flat_hash_mapIjfNS7_4HashIjEENS7_7EqualToIjEESaISt4pairIKjfEEEEmRSt6vectorISD_IjfESaISL_EERKSt10shared_ptrINS_16ScoreRangeFilterEE Line | Count | Source | 96 | 37 | const ScoreRangeFilterPtr& filter) const { | 97 | 37 | if (top_k <= 0) { | 98 | 2 | return; | 99 | 2 | } | 100 | | | 101 | 35 | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { | 102 | 35 | if constexpr (order == OrderType::DESC) { | 103 | 35 | return a.second > b.second; | 104 | 35 | } else { | 105 | 35 | return a.second < b.second; | 106 | 35 | } | 107 | 35 | }; | 108 | | | 109 | 35 | std::priority_queue<std::pair<uint32_t, float>, std::vector<std::pair<uint32_t, float>>, | 110 | 35 | decltype(pair_comp)> | 111 | 35 | top_k_heap(pair_comp); | 112 | | | 113 | 35 | std::vector<uint32_t> zero_score_ids; | 114 | | | 115 | 310k | for (uint32_t row_id : *row_bitmap) { | 116 | 310k | auto it = all_scores.find(row_id); | 117 | 310k | float score = (it != all_scores.end()) ? it->second : 0.0F; | 118 | | | 119 | 310k | if (filter && !filter->pass(score)) { | 120 | 188k | continue; | 121 | 188k | } | 122 | | | 123 | 122k | if (score == 0.0F) { | 124 | 4 | zero_score_ids.push_back(row_id); | 125 | 4 | continue; | 126 | 4 | } | 127 | | | 128 | 122k | if (top_k_heap.size() < top_k) { | 129 | 1.27k | top_k_heap.emplace(row_id, score); | 130 | 120k | } else if (pair_comp({row_id, score}, top_k_heap.top())) { | 131 | 120k | top_k_heap.pop(); | 132 | 120k | top_k_heap.emplace(row_id, score); | 133 | 120k | } | 134 | 122k | } | 135 | | | 136 | 35 | top_k_results.reserve(top_k); | 137 | 1.31k | while (!top_k_heap.empty()) { | 138 | 1.27k | top_k_results.push_back(top_k_heap.top()); | 139 | 1.27k | top_k_heap.pop(); | 140 | 1.27k | } | 141 | 35 | std::ranges::reverse(top_k_results); | 142 | | | 143 | 35 | if constexpr (order == OrderType::DESC) { | 144 | | // DESC: high scores first, then zeros at the end | 145 | 35 | size_t remaining = top_k - top_k_results.size(); | 146 | 38 | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { | 147 | 3 | top_k_results.emplace_back(zero_score_ids[i], 0.0F); | 148 | 3 | } | 149 | | } else { | 150 | | // ASC: zeros first, then low scores | 151 | | std::vector<std::pair<uint32_t, float>> final_results; | 152 | | final_results.reserve(top_k); | 153 | | | 154 | | size_t zero_count = std::min(top_k, zero_score_ids.size()); | 155 | | for (size_t i = 0; i < zero_count; ++i) { | 156 | | final_results.emplace_back(zero_score_ids[i], 0.0F); | 157 | | } | 158 | | | 159 | | size_t remaining = top_k - final_results.size(); | 160 | | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { | 161 | | final_results.push_back(top_k_results[i]); | 162 | | } | 163 | | | 164 | | top_k_results = std::move(final_results); | 165 | | } | 166 | 35 | } |
_ZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE0EEEvPKN7roaring7RoaringERKN5phmap13flat_hash_mapIjfNS7_4HashIjEENS7_7EqualToIjEESaISt4pairIKjfEEEEmRSt6vectorISD_IjfESaISL_EERKSt10shared_ptrINS_16ScoreRangeFilterEE Line | Count | Source | 96 | 14 | const ScoreRangeFilterPtr& filter) const { | 97 | 14 | if (top_k <= 0) { | 98 | 0 | return; | 99 | 0 | } | 100 | | | 101 | 14 | auto pair_comp = [](const std::pair<uint32_t, float>& a, const std::pair<uint32_t, float>& b) { | 102 | 14 | if constexpr (order == OrderType::DESC) { | 103 | 14 | return a.second > b.second; | 104 | 14 | } else { | 105 | 14 | return a.second < b.second; | 106 | 14 | } | 107 | 14 | }; | 108 | | | 109 | 14 | std::priority_queue<std::pair<uint32_t, float>, std::vector<std::pair<uint32_t, float>>, | 110 | 14 | decltype(pair_comp)> | 111 | 14 | top_k_heap(pair_comp); | 112 | | | 113 | 14 | std::vector<uint32_t> zero_score_ids; | 114 | | | 115 | 102k | for (uint32_t row_id : *row_bitmap) { | 116 | 102k | auto it = all_scores.find(row_id); | 117 | 102k | float score = (it != all_scores.end()) ? it->second : 0.0F; | 118 | | | 119 | 102k | if (filter && !filter->pass(score)) { | 120 | 690 | continue; | 121 | 690 | } | 122 | | | 123 | 101k | if (score == 0.0F) { | 124 | 2 | zero_score_ids.push_back(row_id); | 125 | 2 | continue; | 126 | 2 | } | 127 | | | 128 | 101k | if (top_k_heap.size() < top_k) { | 129 | 143 | top_k_heap.emplace(row_id, score); | 130 | 101k | } else if (pair_comp({row_id, score}, top_k_heap.top())) { | 131 | 32 | top_k_heap.pop(); | 132 | 32 | top_k_heap.emplace(row_id, score); | 133 | 32 | } | 134 | 101k | } | 135 | | | 136 | 14 | top_k_results.reserve(top_k); | 137 | 157 | while (!top_k_heap.empty()) { | 138 | 143 | top_k_results.push_back(top_k_heap.top()); | 139 | 143 | top_k_heap.pop(); | 140 | 143 | } | 141 | 14 | std::ranges::reverse(top_k_results); | 142 | | | 143 | | if constexpr (order == OrderType::DESC) { | 144 | | // DESC: high scores first, then zeros at the end | 145 | | size_t remaining = top_k - top_k_results.size(); | 146 | | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { | 147 | | top_k_results.emplace_back(zero_score_ids[i], 0.0F); | 148 | | } | 149 | 14 | } else { | 150 | | // ASC: zeros first, then low scores | 151 | 14 | std::vector<std::pair<uint32_t, float>> final_results; | 152 | 14 | final_results.reserve(top_k); | 153 | | | 154 | 14 | size_t zero_count = std::min(top_k, zero_score_ids.size()); | 155 | 16 | for (size_t i = 0; i < zero_count; ++i) { | 156 | 2 | final_results.emplace_back(zero_score_ids[i], 0.0F); | 157 | 2 | } | 158 | | | 159 | 14 | size_t remaining = top_k - final_results.size(); | 160 | 155 | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { | 161 | 141 | final_results.push_back(top_k_results[i]); | 162 | 141 | } | 163 | | | 164 | 14 | top_k_results = std::move(final_results); | 165 | 14 | } | 166 | 14 | } |
|
167 | | |
168 | | } // namespace doris |