/root/doris/be/src/olap/collection_similarity.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "olap/collection_similarity.h" |
19 | | |
20 | | #include "vec/columns/column_nullable.h" |
21 | | #include "vec/columns/column_vector.h" |
22 | | |
23 | | namespace doris { |
24 | | #include "common/compile_check_begin.h" |
25 | | |
26 | 49 | void CollectionSimilarity::collect(segment_v2::rowid_t row_id, float score) { |
27 | 49 | _bm25_scores[row_id] += score; |
28 | 49 | } |
29 | | |
30 | | void CollectionSimilarity::get_bm25_scores(roaring::Roaring* row_bitmap, |
31 | | vectorized::IColumn::MutablePtr& scores, |
32 | 2 | std::unique_ptr<std::vector<uint64_t>>& row_ids) const { |
33 | 2 | size_t num_results = row_bitmap->cardinality(); |
34 | 2 | auto score_column = vectorized::ColumnFloat32::create(num_results); |
35 | 2 | auto& score_data = score_column->get_data(); |
36 | | |
37 | 2 | row_ids->resize(num_results); |
38 | | |
39 | 2 | int32_t i = 0; |
40 | 6 | for (uint32_t row_id : *row_bitmap) { |
41 | 6 | (*row_ids)[i] = row_id; |
42 | 6 | auto it = _bm25_scores.find(row_id); |
43 | 6 | if (it != _bm25_scores.end()) { |
44 | 4 | score_data[i] = it->second; |
45 | 4 | } else { |
46 | 2 | score_data[i] = 0.0; |
47 | 2 | } |
48 | 6 | i++; |
49 | 6 | } |
50 | | |
51 | 2 | auto null_map = vectorized::ColumnUInt8::create(num_results, 0); |
52 | 2 | scores = vectorized::ColumnNullable::create(std::move(score_column), std::move(null_map)); |
53 | 2 | } |
54 | | |
55 | | void CollectionSimilarity::get_topn_bm25_scores(roaring::Roaring* row_bitmap, |
56 | | vectorized::IColumn::MutablePtr& scores, |
57 | | std::unique_ptr<std::vector<uint64_t>>& row_ids, |
58 | 7 | OrderType order_type, size_t top_k) const { |
59 | 7 | std::vector<std::pair<uint32_t, float>> top_k_results; |
60 | | |
61 | 7 | if (order_type == OrderType::DESC) { |
62 | 6 | find_top_k_scores<OrderType::DESC>( |
63 | 6 | row_bitmap, _bm25_scores, top_k, |
64 | 19 | [](const ScoreMapIterator& a, const ScoreMapIterator& b) { |
65 | 19 | return a->second > b->second; |
66 | 19 | }, |
67 | 6 | top_k_results); |
68 | 6 | } else { |
69 | 1 | find_top_k_scores<OrderType::ASC>( |
70 | 1 | row_bitmap, _bm25_scores, top_k, |
71 | 7 | [](const ScoreMapIterator& a, const ScoreMapIterator& b) { |
72 | 7 | return a->second < b->second; |
73 | 7 | }, |
74 | 1 | top_k_results); |
75 | 1 | } |
76 | | |
77 | 7 | size_t num_results = top_k_results.size(); |
78 | 7 | auto score_column = vectorized::ColumnFloat32::create(num_results); |
79 | 7 | auto& score_data = score_column->get_data(); |
80 | | |
81 | 7 | row_ids->resize(num_results); |
82 | 7 | roaring::Roaring new_bitmap; |
83 | | |
84 | 23 | for (size_t i = 0; i < num_results; ++i) { |
85 | 16 | (*row_ids)[i] = top_k_results[i].first; |
86 | 16 | score_data[i] = top_k_results[i].second; |
87 | 16 | new_bitmap.add(top_k_results[i].first); |
88 | 16 | } |
89 | | |
90 | 7 | *row_bitmap = std::move(new_bitmap); |
91 | 7 | auto null_map = vectorized::ColumnUInt8::create(num_results, 0); |
92 | 7 | scores = vectorized::ColumnNullable::create(std::move(score_column), std::move(null_map)); |
93 | 7 | } |
94 | | |
95 | | template <OrderType order, typename Compare> |
96 | | void CollectionSimilarity::find_top_k_scores( |
97 | | const roaring::Roaring* row_bitmap, const ScoreMap& all_scores, size_t top_k, Compare comp, |
98 | 7 | std::vector<std::pair<uint32_t, float>>& top_k_results) const { |
99 | 7 | if (top_k <= 0) { |
100 | 2 | return; |
101 | 2 | } |
102 | | |
103 | 5 | std::priority_queue<ScoreMapIterator, std::vector<ScoreMapIterator>, Compare> top_k_heap(comp); |
104 | | |
105 | 5 | std::vector<uint32_t> zero_score_ids; |
106 | 23 | for (uint32_t row_id : *row_bitmap) { |
107 | 23 | auto it = all_scores.find(row_id); |
108 | 23 | if (it == all_scores.end()) { |
109 | 3 | zero_score_ids.push_back(row_id); |
110 | 3 | continue; |
111 | 3 | } |
112 | 20 | if (top_k_heap.size() < top_k) { |
113 | 13 | top_k_heap.push(it); |
114 | 13 | } else if (comp(it, top_k_heap.top())) { |
115 | 4 | top_k_heap.pop(); |
116 | 4 | top_k_heap.push(it); |
117 | 4 | } |
118 | 20 | } |
119 | | |
120 | 5 | top_k_results.reserve(top_k_heap.size()); |
121 | 18 | while (!top_k_heap.empty()) { |
122 | 13 | auto top = top_k_heap.top(); |
123 | 13 | top_k_results.push_back({top->first, top->second}); |
124 | 13 | top_k_heap.pop(); |
125 | 13 | } |
126 | | |
127 | 5 | if constexpr (order == OrderType::DESC) { |
128 | 4 | std::ranges::reverse(top_k_results); |
129 | | |
130 | 4 | size_t remaining = top_k - top_k_results.size(); |
131 | 7 | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { |
132 | 3 | top_k_results.emplace_back(zero_score_ids[i], 0.0F); |
133 | 3 | } |
134 | 4 | } else { |
135 | 1 | std::vector<std::pair<uint32_t, float>> final_results; |
136 | 1 | final_results.reserve(top_k); |
137 | | |
138 | 1 | size_t zero_count = std::min(top_k, zero_score_ids.size()); |
139 | 1 | for (size_t i = 0; i < zero_count; ++i) { |
140 | 0 | final_results.emplace_back(zero_score_ids[i], 0.0F); |
141 | 0 | } |
142 | | |
143 | 1 | std::ranges::reverse(top_k_results); |
144 | 1 | size_t remaining = top_k - final_results.size(); |
145 | 4 | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { |
146 | 3 | final_results.push_back(top_k_results[i]); |
147 | 3 | } |
148 | | |
149 | 1 | top_k_results = std::move(final_results); |
150 | 1 | } |
151 | 5 | } collection_similarity.cpp:_ZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE1EZNKS0_20get_topn_bm25_scoresEPN7roaring7RoaringERNS_3COWINS_10vectorized7IColumnEE11mutable_ptrIS8_EERSt10unique_ptrISt6vectorImSaImEESt14default_deleteISG_EES2_mE3$_0EEvPKS4_RKN5phmap13flat_hash_mapIjfNSO_4HashIjEENSO_7EqualToIjEESaISt4pairIKjfEEEEmT0_RSE_ISU_IjfESaIS12_EE Line | Count | Source | 98 | 6 | std::vector<std::pair<uint32_t, float>>& top_k_results) const { | 99 | 6 | if (top_k <= 0) { | 100 | 2 | return; | 101 | 2 | } | 102 | | | 103 | 4 | std::priority_queue<ScoreMapIterator, std::vector<ScoreMapIterator>, Compare> top_k_heap(comp); | 104 | | | 105 | 4 | std::vector<uint32_t> zero_score_ids; | 106 | 18 | for (uint32_t row_id : *row_bitmap) { | 107 | 18 | auto it = all_scores.find(row_id); | 108 | 18 | if (it == all_scores.end()) { | 109 | 3 | zero_score_ids.push_back(row_id); | 110 | 3 | continue; | 111 | 3 | } | 112 | 15 | if (top_k_heap.size() < top_k) { | 113 | 10 | top_k_heap.push(it); | 114 | 10 | } else if (comp(it, top_k_heap.top())) { | 115 | 3 | top_k_heap.pop(); | 116 | 3 | top_k_heap.push(it); | 117 | 3 | } | 118 | 15 | } | 119 | | | 120 | 4 | top_k_results.reserve(top_k_heap.size()); | 121 | 14 | while (!top_k_heap.empty()) { | 122 | 10 | auto top = top_k_heap.top(); | 123 | 10 | top_k_results.push_back({top->first, top->second}); | 124 | 10 | top_k_heap.pop(); | 125 | 10 | } | 126 | | | 127 | 4 | if constexpr (order == OrderType::DESC) { | 128 | 4 | std::ranges::reverse(top_k_results); | 129 | | | 130 | 4 | size_t remaining = top_k - top_k_results.size(); | 131 | 7 | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { | 132 | 3 | top_k_results.emplace_back(zero_score_ids[i], 0.0F); | 133 | 3 | } | 134 | | } else { | 135 | | std::vector<std::pair<uint32_t, float>> final_results; | 136 | | final_results.reserve(top_k); | 137 | | | 138 | | size_t zero_count = std::min(top_k, zero_score_ids.size()); | 139 | | for (size_t i = 0; i < zero_count; ++i) { | 140 | | final_results.emplace_back(zero_score_ids[i], 0.0F); | 141 | | } | 142 | | | 143 | | std::ranges::reverse(top_k_results); | 144 | | size_t remaining = top_k - final_results.size(); | 145 | | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { | 146 | | final_results.push_back(top_k_results[i]); | 147 | | } | 148 | | | 149 | | top_k_results = std::move(final_results); | 150 | | } | 151 | 4 | } |
collection_similarity.cpp:_ZNK5doris20CollectionSimilarity17find_top_k_scoresILNS_9OrderTypeE0EZNKS0_20get_topn_bm25_scoresEPN7roaring7RoaringERNS_3COWINS_10vectorized7IColumnEE11mutable_ptrIS8_EERSt10unique_ptrISt6vectorImSaImEESt14default_deleteISG_EES2_mE3$_1EEvPKS4_RKN5phmap13flat_hash_mapIjfNSO_4HashIjEENSO_7EqualToIjEESaISt4pairIKjfEEEEmT0_RSE_ISU_IjfESaIS12_EE Line | Count | Source | 98 | 1 | std::vector<std::pair<uint32_t, float>>& top_k_results) const { | 99 | 1 | if (top_k <= 0) { | 100 | 0 | return; | 101 | 0 | } | 102 | | | 103 | 1 | std::priority_queue<ScoreMapIterator, std::vector<ScoreMapIterator>, Compare> top_k_heap(comp); | 104 | | | 105 | 1 | std::vector<uint32_t> zero_score_ids; | 106 | 5 | for (uint32_t row_id : *row_bitmap) { | 107 | 5 | auto it = all_scores.find(row_id); | 108 | 5 | if (it == all_scores.end()) { | 109 | 0 | zero_score_ids.push_back(row_id); | 110 | 0 | continue; | 111 | 0 | } | 112 | 5 | if (top_k_heap.size() < top_k) { | 113 | 3 | top_k_heap.push(it); | 114 | 3 | } else if (comp(it, top_k_heap.top())) { | 115 | 1 | top_k_heap.pop(); | 116 | 1 | top_k_heap.push(it); | 117 | 1 | } | 118 | 5 | } | 119 | | | 120 | 1 | top_k_results.reserve(top_k_heap.size()); | 121 | 4 | while (!top_k_heap.empty()) { | 122 | 3 | auto top = top_k_heap.top(); | 123 | 3 | top_k_results.push_back({top->first, top->second}); | 124 | 3 | top_k_heap.pop(); | 125 | 3 | } | 126 | | | 127 | | if constexpr (order == OrderType::DESC) { | 128 | | std::ranges::reverse(top_k_results); | 129 | | | 130 | | size_t remaining = top_k - top_k_results.size(); | 131 | | for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) { | 132 | | top_k_results.emplace_back(zero_score_ids[i], 0.0F); | 133 | | } | 134 | 1 | } else { | 135 | 1 | std::vector<std::pair<uint32_t, float>> final_results; | 136 | 1 | final_results.reserve(top_k); | 137 | | | 138 | 1 | size_t zero_count = std::min(top_k, zero_score_ids.size()); | 139 | 1 | for (size_t i = 0; i < zero_count; ++i) { | 140 | 0 | final_results.emplace_back(zero_score_ids[i], 0.0F); | 141 | 0 | } | 142 | | | 143 | 1 | std::ranges::reverse(top_k_results); | 144 | 1 | size_t remaining = top_k - final_results.size(); | 145 | 4 | for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) { | 146 | 3 | final_results.push_back(top_k_results[i]); | 147 | 3 | } | 148 | | | 149 | 1 | top_k_results = std::move(final_results); | 150 | 1 | } | 151 | 1 | } |
|
152 | | |
153 | | #include "common/compile_check_end.h" |
154 | | } // namespace doris |