Coverage Report

Created: 2026-03-17 19:28

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
contrib/faiss/faiss/utils/transpose/transpose-avx2-inl.h
Line
Count
Source
1
/*
2
 * Copyright (c) Meta Platforms, Inc. and affiliates.
3
 *
4
 * This source code is licensed under the MIT license found in the
5
 * LICENSE file in the root directory of this source tree.
6
 */
7
8
#pragma once
9
10
// This file contains AVX2 kernels for transposing
11
// tiny float/int32 matrices, such as 8x2.
12
13
#ifdef __AVX2__
14
15
#include <immintrin.h>
16
17
namespace faiss {

// Each kernel below takes a small row-major matrix of floats spread across
// several __m256 registers and writes its transpose, one output register per
// column. In the diagrams, "rc" denotes the element at row r, column c.
//
// Note on _mm256_permute2f128_ps immediates: bits [1:0] select the source of
// the result's low 128-bit lane and bits [5:4] select the source of its high
// lane (0 = a.lo, 1 = a.hi, 2 = b.lo, 3 = b.hi). Hence:
//   0x20 -> { a.lo, b.lo }
//   0x31 -> { a.hi, b.hi }
// They are written as plain hex, not via _MM_SHUFFLE. _MM_SHUFFLE encodes
// four 2-bit selectors for _mm256_shuffle_ps and only happens to produce the
// same numbers here.

// 8x2 -> 2x8
//
// Input: an 8x2 matrix stored row-major across two registers:
//   i0:  00 01 10 11 20 21 30 31
//   i1:  40 41 50 51 60 61 70 71
// Output: o0 receives column 0, o1 receives column 1:
//   o0:  00 10 20 30 40 50 60 70
//   o1:  01 11 21 31 41 51 61 71
inline void transpose_8x2(
        const __m256 i0,
        const __m256 i1,
        __m256& o0,
        __m256& o1) {
    // Regroup 128-bit lanes so that each lane of r0/r1 holds two rows that
    // end up in the same output lane.
    // 00 01 10 11 40 41 50 51
    const __m256 r0 = _mm256_permute2f128_ps(i0, i1, 0x20);
    // 20 21 30 31 60 61 70 71
    const __m256 r1 = _mm256_permute2f128_ps(i0, i1, 0x31);

    // Per lane: the even (2,0) or odd (3,1) elements of r0, then those of r1.
    // 00 10 20 30 40 50 60 70
    o0 = _mm256_shuffle_ps(r0, r1, _MM_SHUFFLE(2, 0, 2, 0));
    // 01 11 21 31 41 51 61 71
    o1 = _mm256_shuffle_ps(r0, r1, _MM_SHUFFLE(3, 1, 3, 1));
}

// 8x4 -> 4x8
//
// Input: an 8x4 matrix stored row-major across four registers:
//   i0:  00 01 02 03 10 11 12 13
//   i1:  20 21 22 23 30 31 32 33
//   i2:  40 41 42 43 50 51 52 53
//   i3:  60 61 62 63 70 71 72 73
// Output: o<c> receives column c, i.e. o<c> = 0c 1c 2c 3c 4c 5c 6c 7c.
inline void transpose_8x4(
        const __m256 i0,
        const __m256 i1,
        const __m256 i2,
        const __m256 i3,
        __m256& o0,
        __m256& o1,
        __m256& o2,
        __m256& o3) {
    // Stage 1: move rows 0-3 into low lanes and rows 4-7 into high lanes.
    // 00 01 02 03 40 41 42 43
    const __m256 r0 = _mm256_permute2f128_ps(i0, i2, 0x20);
    // 20 21 22 23 60 61 62 63
    const __m256 r1 = _mm256_permute2f128_ps(i1, i3, 0x20);
    // 10 11 12 13 50 51 52 53
    const __m256 r2 = _mm256_permute2f128_ps(i0, i2, 0x31);
    // 30 31 32 33 70 71 72 73
    const __m256 r3 = _mm256_permute2f128_ps(i1, i3, 0x31);

    // Stage 2: per lane, split even and odd columns of each row pair.
    // 00 02 10 12 40 42 50 52
    const __m256 t0 = _mm256_shuffle_ps(r0, r2, _MM_SHUFFLE(2, 0, 2, 0));
    // 01 03 11 13 41 43 51 53
    const __m256 t1 = _mm256_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 1, 3, 1));
    // 20 22 30 32 60 62 70 72
    const __m256 t2 = _mm256_shuffle_ps(r1, r3, _MM_SHUFFLE(2, 0, 2, 0));
    // 21 23 31 33 61 63 71 73
    const __m256 t3 = _mm256_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 1, 3, 1));

    // Stage 3: per lane, split once more to isolate individual columns.
    // 00 10 20 30 40 50 60 70
    o0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(2, 0, 2, 0));
    // 01 11 21 31 41 51 61 71
    o1 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(2, 0, 2, 0));
    // 02 12 22 32 42 52 62 72
    o2 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 1, 3, 1));
    // 03 13 23 33 43 53 63 73
    o3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 1, 3, 1));
}

// 8x8 -> 8x8
//
// Input: i<r> holds row r of an 8x8 matrix (i<r> = r0 r1 r2 ... r7).
// Output: o<c> receives column c (o<c> = 0c 1c 2c ... 7c).
inline void transpose_8x8(
        const __m256 i0,
        const __m256 i1,
        const __m256 i2,
        const __m256 i3,
        const __m256 i4,
        const __m256 i5,
        const __m256 i6,
        const __m256 i7,
        __m256& o0,
        __m256& o1,
        __m256& o2,
        __m256& o3,
        __m256& o4,
        __m256& o5,
        __m256& o6,
        __m256& o7) {
    // Stage 1: interleave adjacent row pairs within each 128-bit lane.
    // 00 10 01 11 04 14 05 15
    const __m256 r0 = _mm256_unpacklo_ps(i0, i1);
    // 02 12 03 13 06 16 07 17
    const __m256 r1 = _mm256_unpackhi_ps(i0, i1);
    // 20 30 21 31 24 34 25 35
    const __m256 r2 = _mm256_unpacklo_ps(i2, i3);
    // 22 32 23 33 26 36 27 37
    const __m256 r3 = _mm256_unpackhi_ps(i2, i3);
    // 40 50 41 51 44 54 45 55
    const __m256 r4 = _mm256_unpacklo_ps(i4, i5);
    // 42 52 43 53 46 56 47 57
    const __m256 r5 = _mm256_unpackhi_ps(i4, i5);
    // 60 70 61 71 64 74 65 75
    const __m256 r6 = _mm256_unpacklo_ps(i6, i7);
    // 62 72 63 73 66 76 67 77
    const __m256 r7 = _mm256_unpackhi_ps(i6, i7);

    // Stage 2: per lane, combine element pairs so each lane holds one column
    // of four consecutive rows (columns c and c+4 share a register).
    // 00 10 20 30 04 14 24 34
    const __m256 rr0 = _mm256_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0));
    // 01 11 21 31 05 15 25 35
    const __m256 rr1 = _mm256_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2));
    // 02 12 22 32 06 16 26 36
    const __m256 rr2 = _mm256_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0));
    // 03 13 23 33 07 17 27 37
    const __m256 rr3 = _mm256_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2));
    // 40 50 60 70 44 54 64 74
    const __m256 rr4 = _mm256_shuffle_ps(r4, r6, _MM_SHUFFLE(1, 0, 1, 0));
    // 41 51 61 71 45 55 65 75
    const __m256 rr5 = _mm256_shuffle_ps(r4, r6, _MM_SHUFFLE(3, 2, 3, 2));
    // 42 52 62 72 46 56 66 76
    const __m256 rr6 = _mm256_shuffle_ps(r5, r7, _MM_SHUFFLE(1, 0, 1, 0));
    // 43 53 63 73 47 57 67 77
    const __m256 rr7 = _mm256_shuffle_ps(r5, r7, _MM_SHUFFLE(3, 2, 3, 2));

    // Stage 3: join rows 0-3 and rows 4-7 of each column across lanes.
    // 00 10 20 30 40 50 60 70
    o0 = _mm256_permute2f128_ps(rr0, rr4, 0x20);
    // 01 11 21 31 41 51 61 71
    o1 = _mm256_permute2f128_ps(rr1, rr5, 0x20);
    // 02 12 22 32 42 52 62 72
    o2 = _mm256_permute2f128_ps(rr2, rr6, 0x20);
    // 03 13 23 33 43 53 63 73
    o3 = _mm256_permute2f128_ps(rr3, rr7, 0x20);
    // 04 14 24 34 44 54 64 74
    o4 = _mm256_permute2f128_ps(rr0, rr4, 0x31);
    // 05 15 25 35 45 55 65 75
    o5 = _mm256_permute2f128_ps(rr1, rr5, 0x31);
    // 06 16 26 36 46 56 66 76
    o6 = _mm256_permute2f128_ps(rr2, rr6, 0x31);
    // 07 17 27 37 47 57 67 77
    o7 = _mm256_permute2f128_ps(rr3, rr7, 0x31);
}

} // namespace faiss
164
165
#endif