/root/doris/contrib/faiss/faiss/utils/fp16-inl.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
#include <algorithm>
#include <cstdint>
#include <cstring>
12 | | |
namespace faiss {

// non-intrinsic FP16 <-> FP32 code adapted from
// https://github.com/ispc/ispc/blob/master/stdlib.ispc

namespace {

/// Reinterpret the 32 bits of `x` as a `float` (bit pattern is preserved).
///
/// The copy goes through std::memcpy rather than a pointer cast: reading a
/// uint32_t object through a float lvalue violates the strict-aliasing rule
/// and is undefined behaviour. Compilers lower this memcpy to a plain
/// register move.
inline float floatbits(uint32_t x) {
    static_assert(
            sizeof(float) == sizeof(uint32_t),
            "floatbits requires a 32-bit float");
    float f;
    std::memcpy(&f, &x, sizeof(f));
    return f;
}

/// Return the raw 32-bit pattern of `f` as an unsigned integer.
///
/// Uses std::memcpy for the same strict-aliasing reason as floatbits().
inline uint32_t intbits(float f) {
    static_assert(
            sizeof(float) == sizeof(uint32_t),
            "intbits requires a 32-bit float");
    uint32_t x;
    std::memcpy(&x, &f, sizeof(x));
    return x;
}

} // namespace

/// Convert a 32-bit float to a 16-bit half-precision bit pattern.
///
/// @param f value to convert
/// @return  sign (bit 15), 5 exponent bits and 10 mantissa bits. NaN inputs
///          yield the quiet NaN pattern 0x7e00 (plus sign), infinities and
///          values too large for a half yield 0x7c00 (plus sign).
inline uint16_t encode_fp16(float f) {
    // via Fabian "ryg" Giesen.
    // https://gist.github.com/2156668
    const uint32_t sign_mask = 0x80000000u;
    const uint32_t f32infty = 255u << 23;
    const uint32_t round_mask = ~0xfffu;
    const uint32_t magic = 15u << 23;

    // Split off the sign; everything below works on the magnitude only.
    uint32_t fint = intbits(f);
    const uint32_t sign = fint & sign_mask;
    fint ^= sign;

    // NOTE all the integer compares in this function can be safely
    // compiled into signed compares since all operands are below
    // 0x80000000. Important if you want fast straight SSE2 code (since
    // there's no unsigned PCMPGTD).

    // Inf or NaN (all exponent bits set)
    // NaN->qNaN and Inf->Inf
    // unconditional assignment here, will override with right value for
    // the regular case below.
    uint32_t half_bits = (fint > f32infty) ? 0x7e00u : 0x7c00u;

    // (De)normalized number or zero.
    //
    // Shift exponent down, denormalize if necessary.
    // NOTE This represents half-float denormals using single
    // precision denormals. The main reason to do this is that
    // there's no shift with per-lane variable shifts in SSE*, which
    // we'd otherwise need. It has some funky side effects though:
    //  - This conversion will actually respect the FTZ (Flush To Zero)
    //    flag in MXCSR - if it's set, no half-float denormals will be
    //    generated. I'm honestly not sure whether this is good or
    //    bad. It's definitely interesting.
    //  - If the underlying HW doesn't support denormals (not an issue
    //    with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
    //    you will always get flush-to-zero behavior. This is bad,
    //    unless you're on a CPU where you don't care.
    //  - Denormals tend to be slow. FP32 denormals are rare in
    //    practice outside of things like recursive filters in DSP -
    //    not a typical half-float application. Whether FP16 denormals
    //    are rare in practice, I don't know. Whatever slow path your
    //    HW may or may not have for denormals, this may well hit it.
    float fscale = floatbits(fint & round_mask) * floatbits(magic);
    // Clamp so that anything too large rounds up to exactly half-infinity.
    fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
    // Subtracting round_mask (== adding 0x1000 modulo 2^32) adds half of a
    // half-precision ulp, so the shift below rounds instead of truncating.
    const uint32_t rounded = intbits(fscale) - round_mask;

    if (fint < f32infty) {
        half_bits = rounded >> 13; // Take the bits!
    }

    return static_cast<uint16_t>(half_bits | (sign >> 16));
}

/// Convert a 16-bit half-precision bit pattern to a 32-bit float.
///
/// @param h half-precision bits (sign, 5 exponent bits, 10 mantissa bits)
/// @return  the float with the same value; zeros, denormals, infinities
///          and NaNs are each handled explicitly below.
inline float decode_fp16(uint16_t h) {
    // https://gist.github.com/2144712
    // Fabian "ryg" Giesen.

    // All bit assembly is done in uint32_t: shifting a set bit into the sign
    // position of a signed int is not portable, and no intermediate value
    // here exceeds 32 bits.
    const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift

    uint32_t o = static_cast<uint32_t>(h & 0x7fffu) << 13; // exponent/mantissa bits
    const uint32_t exp = shifted_exp & o; // just the exponent
    o += static_cast<uint32_t>(127 - 15) << 23; // exponent adjust

    // Inf/NaN: push the exponent the rest of the way up to all-ones.
    const uint32_t infnan_val = o + (static_cast<uint32_t>(128 - 16) << 23);
    // Zero/denormal: build 2^-14 * (1 + m / 1024) and subtract 2^-14, which
    // leaves m * 2^-24, the value of a half denormal with mantissa m.
    const uint32_t zerodenorm_val =
            intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
    const uint32_t reg_val = (exp == 0) ? zerodenorm_val : o;

    const uint32_t sign_bit = static_cast<uint32_t>(h & 0x8000u) << 16;
    return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
}

} // namespace faiss