Coverage Report

Created: 2025-10-16 17:20

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/IndexRaBitQ.cpp
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#include <faiss/IndexRaBitQ.h>
9
10
#include <faiss/impl/FaissAssert.h>
11
#include <faiss/impl/ResultHandler.h>
12
13
namespace faiss {
14
15
0
// Defaulted default constructor: every member is initialized as declared in
// the class definition. NOTE(review): presumably used when the index is
// reconstructed by deserialization — confirm against index_read.
IndexRaBitQ::IndexRaBitQ() = default;
16
17
IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric)
        : IndexFlatCodes(0, d, metric), rabitq(d, metric) {
    // The centroid is only known after train(); sa_encode/sa_decode refuse to
    // run until then.
    is_trained = false;

    // The flat-code storage holds exactly the bytes the quantizer emits per
    // vector, so the sizes must agree.
    code_size = rabitq.code_size;
}
23
24
0
/// Trains the index on `n` row-major vectors of dimension `d` stored in `x`.
///
/// Computes the centroid (per-dimension mean) of the training set and stores
/// it in `center`, then trains the underlying RaBitQ quantizer on the same
/// data. With `n <= 0` the centroid is the zero vector.
void IndexRaBitQ::train(idx_t n, const float* x) {
    const size_t dim = d;

    // Accumulate in double: summing many floats in single precision loses
    // significant digits once the running total dominates each addend.
    std::vector<double> sum(dim, 0.0);
    // Signed index so the comparison with the signed `n` cannot wrap around.
    for (idx_t i = 0; i < n; i++) {
        const float* row = x + static_cast<size_t>(i) * dim;
        for (size_t j = 0; j < dim; j++) {
            sum[j] += row[j];
        }
    }

    std::vector<float> centroid(dim, 0.0f);
    if (n > 0) {
        const double inv_count = 1.0 / static_cast<double>(n);
        for (size_t j = 0; j < dim; j++) {
            centroid[j] = static_cast<float>(sum[j] * inv_count);
        }
    }

    center = std::move(centroid);

    rabitq.train(n, x);
    is_trained = true;
}
45
46
0
void IndexRaBitQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
    // Codes are expressed relative to the centroid computed by train().
    FAISS_THROW_IF_NOT(is_trained);
    const float* const centroid = center.data();
    rabitq.compute_codes_core(x, bytes, n, centroid);
}
50
51
0
void IndexRaBitQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
    // Decoding needs the same centroid that was used at encoding time.
    FAISS_THROW_IF_NOT(is_trained);
    const float* const centroid = center.data();
    rabitq.decode_core(bytes, x, n, centroid);
}
55
56
0
/// Returns a newly created distance computer over this index's stored codes,
/// using the index-wide default number of query bits (`qb`).
///
/// Delegates to get_quantized_distance_computer so the construction logic
/// (centroid, code size, code pointer) lives in one place.
FlatCodesDistanceComputer* IndexRaBitQ::get_FlatCodesDistanceComputer() const {
    return get_quantized_distance_computer(qb);
}
63
64
FlatCodesDistanceComputer* IndexRaBitQ::get_quantized_distance_computer(
        const uint8_t qb) const {
    // Builds a distance computer for `qb` query bits (this parameter shadows
    // the member of the same name) around the trained centroid, then points it
    // at the codes stored in this index. The caller takes ownership.
    FlatCodesDistanceComputer* computer =
            rabitq.get_distance_computer(qb, center.data());
    computer->codes = codes.data();
    computer->code_size = rabitq.code_size;
    return computer;
}
72
73
namespace {
74
75
struct Run_search_with_dc_res {
76
    using T = void;
77
78
    uint8_t qb = 0;
79
80
    template <class BlockResultHandler>
81
0
    void f(BlockResultHandler& res, const IndexRaBitQ* index, const float* xq) {
82
0
        size_t ntotal = index->ntotal;
83
0
        using SingleResultHandler =
84
0
                typename BlockResultHandler::SingleResultHandler;
85
0
        const int d = index->d;
86
87
0
#pragma omp parallel // if (res.nq > 100)
88
0
        {
89
0
            std::unique_ptr<FlatCodesDistanceComputer> dc(
90
0
                    index->get_quantized_distance_computer(qb));
91
0
            SingleResultHandler resi(res);
92
0
#pragma omp for
93
0
            for (int64_t q = 0; q < res.nq; q++) {
94
0
                resi.begin(q);
95
0
                dc->set_query(xq + d * q);
96
0
                for (size_t i = 0; i < ntotal; i++) {
97
0
                    if (res.is_in_selection(i)) {
98
0
                        float dis = (*dc)(i);
99
0
                        resi.add_result(dis, i);
100
0
                    }
101
0
                }
102
0
                resi.end();
103
0
            }
104
0
        }
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22Top1BlockResultHandlerINS_4CMinIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22HeapBlockResultHandlerINS_4CMinIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_27ReservoirBlockResultHandlerINS_4CMinIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22Top1BlockResultHandlerINS_4CMinIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22HeapBlockResultHandlerINS_4CMinIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_27ReservoirBlockResultHandlerINS_4CMinIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22Top1BlockResultHandlerINS_4CMaxIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22HeapBlockResultHandlerINS_4CMaxIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_27ReservoirBlockResultHandlerINS_4CMaxIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22Top1BlockResultHandlerINS_4CMaxIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22HeapBlockResultHandlerINS_4CMaxIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_27ReservoirBlockResultHandlerINS_4CMaxIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_29RangeSearchBlockResultHandlerINS_4CMinIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_29RangeSearchBlockResultHandlerINS_4CMinIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf.omp_outlined_debug__
105
0
    }
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22Top1BlockResultHandlerINS_4CMinIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22HeapBlockResultHandlerINS_4CMinIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_27ReservoirBlockResultHandlerINS_4CMinIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22Top1BlockResultHandlerINS_4CMinIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22HeapBlockResultHandlerINS_4CMinIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_27ReservoirBlockResultHandlerINS_4CMinIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22Top1BlockResultHandlerINS_4CMaxIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22HeapBlockResultHandlerINS_4CMaxIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_27ReservoirBlockResultHandlerINS_4CMaxIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22Top1BlockResultHandlerINS_4CMaxIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_22HeapBlockResultHandlerINS_4CMaxIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_27ReservoirBlockResultHandlerINS_4CMaxIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_29RangeSearchBlockResultHandlerINS_4CMinIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_29RangeSearchBlockResultHandlerINS_4CMinIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb1EEEEEvRT_PKNS_11IndexRaBitQEPKf
Unexecuted instantiation: IndexRaBitQ.cpp:_ZN5faiss12_GLOBAL__N_122Run_search_with_dc_res1fINS_29RangeSearchBlockResultHandlerINS_4CMaxIflEELb0EEEEEvRT_PKNS_11IndexRaBitQEPKf
106
};
107
108
} // namespace
109
110
void IndexRaBitQ::search(
111
        idx_t n,
112
        const float* x,
113
        idx_t k,
114
        float* distances,
115
        idx_t* labels,
116
0
        const SearchParameters* params_in) const {
117
0
    uint8_t used_qb = qb;
118
0
    if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
119
0
        used_qb = params->qb;
120
0
    }
121
122
0
    const IDSelector* sel = (params_in != nullptr) ? params_in->sel : nullptr;
123
0
    Run_search_with_dc_res r;
124
0
    r.qb = used_qb;
125
126
0
    dispatch_knn_ResultHandler(
127
0
            n, distances, labels, k, metric_type, sel, r, this, x);
128
0
}
129
130
void IndexRaBitQ::range_search(
131
        idx_t n,
132
        const float* x,
133
        float radius,
134
        RangeSearchResult* result,
135
0
        const SearchParameters* params_in) const {
136
0
    uint8_t used_qb = qb;
137
0
    if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
138
0
        used_qb = params->qb;
139
0
    }
140
141
0
    const IDSelector* sel = (params_in != nullptr) ? params_in->sel : nullptr;
142
0
    Run_search_with_dc_res r;
143
0
    r.qb = used_qb;
144
145
0
    dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x);
146
0
}
147
148
} // namespace faiss