Coverage Report

Created: 2026-01-05 15:17

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/faiss/faiss/impl/ScalarQuantizer.h
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
// -*- c++ -*-
9
10
#pragma once
11
12
#include <faiss/impl/AuxIndexStructures.h>
13
#include <faiss/impl/DistanceComputer.h>
14
#include <faiss/impl/Quantizer.h>
15
16
namespace faiss {
17
18
struct InvertedListScanner; // forward declaration: only used as a pointer return type below
19
20
/**
21
 * The uniform quantizer has a range [vmin, vmax]. The range can be
22
 * the same for all dimensions (uniform) or specific per dimension
23
 * (default).
24
 */
25
26
struct ScalarQuantizer : Quantizer {
27
    enum QuantizerType {
28
        QT_8bit,         ///< 8 bits per component
29
        QT_4bit,         ///< 4 bits per component
30
        QT_8bit_uniform, ///< 8 bits, shared range for all dimensions
31
        QT_4bit_uniform, ///< presumably 4 bits with a shared range, by analogy with QT_8bit_uniform
32
        QT_fp16,
33
        QT_8bit_direct, ///< fast indexing of uint8s
34
        QT_6bit,        ///< 6 bits per component
35
        QT_bf16,
36
        QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
37
                               ///< [-128, 127]
38
    };
39
40
    QuantizerType qtype = QT_8bit;
41
42
    /** The uniform encoder can estimate the range of representable
43
     * values of the uniform encoder using different statistics. Here
44
     * rs = rangestat_arg */
45
46
    // In the formulas below, rs denotes rangestat_arg.
47
    enum RangeStat {
48
        RS_minmax,    ///< [min - rs*(max-min), max + rs*(max-min)]
49
        RS_meanstd,   ///< [mean - std * rs, mean + std * rs]
50
        RS_quantiles, ///< [Q(rs), Q(1-rs)]
51
        RS_optim,     ///< alternate optimization of reconstruction error
52
    };
53
54
    RangeStat rangestat = RS_minmax; ///< which RangeStat formula to use
55
    float rangestat_arg = 0; ///< the "rs" value in the RangeStat formulas
56
57
    /// bits per scalar code
58
    size_t bits = 0;
59
60
    /// trained values (including the range)
61
    std::vector<float> trained;
62
63
    ScalarQuantizer(size_t d, QuantizerType qtype);
64
    ScalarQuantizer();
65
66
    /// updates internal values based on qtype and d
67
    void set_derived_sizes();
68
69
    void train(size_t n, const float* x) override;
70
71
    /** Encode a set of vectors
72
     *
73
     * @param x      vectors to encode, size n * d
74
     * @param codes  output codes, size n * code_size
75
     */
76
    void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
77
78
    /** Decode a set of vectors
79
     *
80
     * @param code   codes to decode, size n * code_size
81
     * @param x      output vectors, size n * d
82
     */
83
    void decode(const uint8_t* code, float* x, size_t n) const override;
84
85
    /*****************************************************
86
     * Objects that provide methods for encoding/decoding, distance
87
     * computation and inverted list scanning
88
     *****************************************************/
89
90
    struct SQuantizer {
91
        // encodes one vector. Assumes code is filled with 0s on input!
92
        virtual void encode_vector(const float* x, uint8_t* code) const = 0;
93
        virtual void decode_vector(const uint8_t* code, float* x) const = 0; // presumably the inverse: one code -> one vector
94
95
24
        virtual ~SQuantizer() {}
96
    };
97
98
    SQuantizer* select_quantizer() const;
99
100
    struct SQDistanceComputer : FlatCodesDistanceComputer {
101
        const float* q; // presumably the current query vector; starts out as nullptr
102
103
20
        SQDistanceComputer() : q(nullptr) {}
104
105
        virtual float query_to_code(const uint8_t* code) const = 0;
106
107
12.2k
        float distance_to_code(const uint8_t* code) final { // simply forwards to query_to_code()
108
12.2k
            return query_to_code(code);
109
12.2k
        }
110
    };
111
112
    SQDistanceComputer* get_distance_computer(
113
            MetricType metric = METRIC_L2) const;
114
115
    InvertedListScanner* select_InvertedListScanner(
116
            MetricType mt,
117
            const Index* quantizer,
118
            bool store_pairs,
119
            const IDSelector* sel,
120
            bool by_residual = false) const;
121
};
122
123
} // namespace faiss