/root/doris/contrib/faiss/faiss/impl/ScalarQuantizer.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | // -*- c++ -*- |
9 | | |
10 | | #pragma once |
11 | | |
12 | | #include <faiss/impl/AuxIndexStructures.h> |
13 | | #include <faiss/impl/DistanceComputer.h> |
14 | | #include <faiss/impl/Quantizer.h> |
15 | | |
16 | | namespace faiss { |
17 | | |
18 | | struct InvertedListScanner; |
19 | | |
20 | | /** |
21 | | * The uniform quantizer has a range [vmin, vmax]. The range can be |
22 | | * the same for all dimensions (uniform) or specific per dimension |
23 | | * (default). |
24 | | */ |
25 | | |
26 | | struct ScalarQuantizer : Quantizer { |
27 | | enum QuantizerType { |
28 | | QT_8bit, ///< 8 bits per component |
29 | | QT_4bit, ///< 4 bits per component |
30 | | QT_8bit_uniform, ///< 8 bits per component, shared range for all dimensions |
31 | | QT_4bit_uniform, |
32 | | QT_fp16, |
33 | | QT_8bit_direct, ///< fast indexing of uint8s |
34 | | QT_6bit, ///< 6 bits per component |
35 | | QT_bf16, |
36 | | QT_8bit_direct_signed, ///< fast indexing of signed int8s in the range |
37 | | ///< [-128, 127] |
38 | | }; |
39 | | |
40 | | QuantizerType qtype = QT_8bit; |
41 | | |
42 | | /** The uniform encoder can estimate the range of representable |
43 | | * values of the uniform encoder using different statistics. Here |
44 | | * rs = rangestat_arg */ |
45 | | |
46 | | // Statistic used to estimate the range; rs below stands for rangestat_arg. |
47 | | enum RangeStat { |
48 | | RS_minmax, ///< [min - rs*(max-min), max + rs*(max-min)] |
49 | | RS_meanstd, ///< [mean - std * rs, mean + std * rs] |
50 | | RS_quantiles, ///< [Q(rs), Q(1-rs)] |
51 | | RS_optim, ///< alternate optimization of reconstruction error |
52 | | }; |
53 | | |
54 | | RangeStat rangestat = RS_minmax; |
55 | | float rangestat_arg = 0; |
56 | | |
57 | | /// bits per scalar code |
58 | | size_t bits = 0; |
59 | | |
60 | | /// trained values (including the range) |
61 | | std::vector<float> trained; |
62 | | |
63 | | ScalarQuantizer(size_t d, QuantizerType qtype); |
64 | | ScalarQuantizer(); |
65 | | |
66 | | /// updates internal values based on qtype and d |
67 | | void set_derived_sizes(); |
68 | | |
69 | | void train(size_t n, const float* x) override; |
70 | | |
71 | | /** Encode a set of vectors |
72 | | * |
73 | | * @param x vectors to encode, size n * d |
74 | | * @param codes output codes, size n * code_size |
75 | | */ |
76 | | void compute_codes(const float* x, uint8_t* codes, size_t n) const override; |
77 | | |
78 | | /** Decode a set of vectors |
79 | | * |
80 | | * @param code codes to decode, size n * code_size |
81 | | * @param x output vectors, size n * d |
82 | | */ |
83 | | void decode(const uint8_t* code, float* x, size_t n) const override; |
84 | | |
85 | | /***************************************************** |
86 | | * Objects that provide methods for encoding/decoding, distance |
87 | | * computation and inverted list scanning |
88 | | *****************************************************/ |
89 | | |
90 | | struct SQuantizer { |
91 | | // encodes one vector. Assumes code is filled with 0s on input! |
92 | | virtual void encode_vector(const float* x, uint8_t* code) const = 0; |
93 | | virtual void decode_vector(const uint8_t* code, float* x) const = 0; |
94 | | |
95 | 24 | virtual ~SQuantizer() {} |
96 | | }; |
97 | | |
98 | | SQuantizer* select_quantizer() const; |
99 | | |
100 | | struct SQDistanceComputer : FlatCodesDistanceComputer { |
101 | | const float* q; |
102 | | |
103 | 20 | SQDistanceComputer() : q(nullptr) {} |
104 | | |
105 | | virtual float query_to_code(const uint8_t* code) const = 0; |
106 | | |
107 | 12.2k | float distance_to_code(const uint8_t* code) final { |
108 | 12.2k | return query_to_code(code); |
109 | 12.2k | } |
110 | | }; |
111 | | |
112 | | SQDistanceComputer* get_distance_computer( |
113 | | MetricType metric = METRIC_L2) const; |
114 | | |
115 | | InvertedListScanner* select_InvertedListScanner( |
116 | | MetricType mt, |
117 | | const Index* quantizer, |
118 | | bool store_pairs, |
119 | | const IDSelector* sel, |
120 | | bool by_residual = false) const; |
121 | | }; |
122 | | |
123 | | } // namespace faiss |