contrib/faiss/faiss/utils/transpose/transpose-avx2-inl.h
Line | Count | Source |
1 | | /* |
2 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | | * |
4 | | * This source code is licensed under the MIT license found in the |
5 | | * LICENSE file in the root directory of this source tree. |
6 | | */ |
7 | | |
8 | | #pragma once |
9 | | |
10 | | // This file contains transposing kernels for AVX2 for |
11 | | // tiny float/int32 matrices, such as 8x2. |
12 | | |
13 | | #ifdef __AVX2__ |
14 | | |
15 | | #include <immintrin.h> |
16 | | |
17 | | namespace faiss { |
18 | | |
19 | | // 8x2 -> 2x8 |
20 | | inline void transpose_8x2( |
21 | | const __m256 i0, |
22 | | const __m256 i1, |
23 | | __m256& o0, |
24 | 0 | __m256& o1) { |
25 | | // say, we have the following as in input: |
26 | | // i0: 00 01 10 11 20 21 30 31 |
27 | | // i1: 40 41 50 51 60 61 70 71 |
28 | | |
29 | | // 00 01 10 11 40 41 50 51 |
30 | 0 | const __m256 r0 = _mm256_permute2f128_ps(i0, i1, _MM_SHUFFLE(0, 2, 0, 0)); |
31 | | // 20 21 30 31 60 61 70 71 |
32 | 0 | const __m256 r1 = _mm256_permute2f128_ps(i0, i1, _MM_SHUFFLE(0, 3, 0, 1)); |
33 | | |
34 | | // 00 10 20 30 40 50 60 70 |
35 | 0 | o0 = _mm256_shuffle_ps(r0, r1, _MM_SHUFFLE(2, 0, 2, 0)); |
36 | | // 01 11 21 31 41 51 61 71 |
37 | 0 | o1 = _mm256_shuffle_ps(r0, r1, _MM_SHUFFLE(3, 1, 3, 1)); |
38 | 0 | } |
39 | | |
40 | | // 8x4 -> 4x8 |
41 | | inline void transpose_8x4( |
42 | | const __m256 i0, |
43 | | const __m256 i1, |
44 | | const __m256 i2, |
45 | | const __m256 i3, |
46 | | __m256& o0, |
47 | | __m256& o1, |
48 | | __m256& o2, |
49 | 0 | __m256& o3) { |
50 | | // say, we have the following as an input: |
51 | | // i0: 00 01 02 03 10 11 12 13 |
52 | | // i1: 20 21 22 23 30 31 32 33 |
53 | | // i2: 40 41 42 43 50 51 52 53 |
54 | | // i3: 60 61 62 63 70 71 72 73 |
55 | | |
56 | | // 00 01 02 03 40 41 42 43 |
57 | 0 | const __m256 r0 = _mm256_permute2f128_ps(i0, i2, _MM_SHUFFLE(0, 2, 0, 0)); |
58 | | // 20 21 22 23 60 61 62 63 |
59 | 0 | const __m256 r1 = _mm256_permute2f128_ps(i1, i3, _MM_SHUFFLE(0, 2, 0, 0)); |
60 | | // 10 11 12 13 50 51 52 53 |
61 | 0 | const __m256 r2 = _mm256_permute2f128_ps(i0, i2, _MM_SHUFFLE(0, 3, 0, 1)); |
62 | | // 30 31 32 33 70 71 72 73 |
63 | 0 | const __m256 r3 = _mm256_permute2f128_ps(i1, i3, _MM_SHUFFLE(0, 3, 0, 1)); |
64 | | |
65 | | // 00 02 10 12 40 42 50 52 |
66 | 0 | const __m256 t0 = _mm256_shuffle_ps(r0, r2, _MM_SHUFFLE(2, 0, 2, 0)); |
67 | | // 01 03 11 13 41 43 51 53 |
68 | 0 | const __m256 t1 = _mm256_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 1, 3, 1)); |
69 | | // 20 22 30 32 60 62 70 72 |
70 | 0 | const __m256 t2 = _mm256_shuffle_ps(r1, r3, _MM_SHUFFLE(2, 0, 2, 0)); |
71 | | // 21 23 31 33 61 63 71 73 |
72 | 0 | const __m256 t3 = _mm256_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 1, 3, 1)); |
73 | | |
74 | | // 00 10 20 30 40 50 60 70 |
75 | 0 | o0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(2, 0, 2, 0)); |
76 | | // 01 11 21 31 41 51 61 71 |
77 | 0 | o1 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(2, 0, 2, 0)); |
78 | | // 02 12 22 32 42 52 62 72 |
79 | 0 | o2 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 1, 3, 1)); |
80 | | // 03 13 23 33 43 53 63 73 |
81 | 0 | o3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 1, 3, 1)); |
82 | 0 | } |
83 | | |
84 | | inline void transpose_8x8( |
85 | | const __m256 i0, |
86 | | const __m256 i1, |
87 | | const __m256 i2, |
88 | | const __m256 i3, |
89 | | const __m256 i4, |
90 | | const __m256 i5, |
91 | | const __m256 i6, |
92 | | const __m256 i7, |
93 | | __m256& o0, |
94 | | __m256& o1, |
95 | | __m256& o2, |
96 | | __m256& o3, |
97 | | __m256& o4, |
98 | | __m256& o5, |
99 | | __m256& o6, |
100 | 0 | __m256& o7) { |
101 | | // say, we have the following as an input: |
102 | | // i0: 00 01 02 03 04 05 06 07 |
103 | | // i1: 10 11 12 13 14 15 16 17 |
104 | | // i2: 20 21 22 23 24 25 26 27 |
105 | | // i3: 30 31 32 33 34 35 36 37 |
106 | | // i4: 40 41 42 43 44 45 46 47 |
107 | | // i5: 50 51 52 53 54 55 56 57 |
108 | | // i6: 60 61 62 63 64 65 66 67 |
109 | | // i7: 70 71 72 73 74 75 76 77 |
110 | | |
111 | | // 00 10 01 11 04 14 05 15 |
112 | 0 | const __m256 r0 = _mm256_unpacklo_ps(i0, i1); |
113 | | // 02 12 03 13 06 16 07 17 |
114 | 0 | const __m256 r1 = _mm256_unpackhi_ps(i0, i1); |
115 | | // 20 30 21 31 24 34 25 35 |
116 | 0 | const __m256 r2 = _mm256_unpacklo_ps(i2, i3); |
117 | | // 22 32 23 33 26 36 27 37 |
118 | 0 | const __m256 r3 = _mm256_unpackhi_ps(i2, i3); |
119 | | // 40 50 41 51 44 54 45 55 |
120 | 0 | const __m256 r4 = _mm256_unpacklo_ps(i4, i5); |
121 | | // 42 52 43 53 46 56 47 57 |
122 | 0 | const __m256 r5 = _mm256_unpackhi_ps(i4, i5); |
123 | | // 60 70 61 71 64 74 65 75 |
124 | 0 | const __m256 r6 = _mm256_unpacklo_ps(i6, i7); |
125 | | // 62 72 63 73 66 76 67 77 |
126 | 0 | const __m256 r7 = _mm256_unpackhi_ps(i6, i7); |
127 | | |
128 | | // 00 10 20 30 04 14 24 34 |
129 | 0 | const __m256 rr0 = _mm256_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); |
130 | | // 01 11 21 31 05 15 25 35 |
131 | 0 | const __m256 rr1 = _mm256_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); |
132 | | // 02 12 22 32 06 16 26 36 |
133 | 0 | const __m256 rr2 = _mm256_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); |
134 | | // 03 13 23 33 07 17 27 37 |
135 | 0 | const __m256 rr3 = _mm256_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); |
136 | | // 40 50 60 70 44 54 64 74 |
137 | 0 | const __m256 rr4 = _mm256_shuffle_ps(r4, r6, _MM_SHUFFLE(1, 0, 1, 0)); |
138 | | // 41 51 61 71 45 55 65 75 |
139 | 0 | const __m256 rr5 = _mm256_shuffle_ps(r4, r6, _MM_SHUFFLE(3, 2, 3, 2)); |
140 | | // 42 52 62 72 46 56 66 76 |
141 | 0 | const __m256 rr6 = _mm256_shuffle_ps(r5, r7, _MM_SHUFFLE(1, 0, 1, 0)); |
142 | | // 43 53 63 73 47 57 67 77 |
143 | 0 | const __m256 rr7 = _mm256_shuffle_ps(r5, r7, _MM_SHUFFLE(3, 2, 3, 2)); |
144 | | |
145 | | // 00 10 20 30 40 50 60 70 |
146 | 0 | o0 = _mm256_permute2f128_ps(rr0, rr4, 0x20); |
147 | | // 01 11 21 31 41 51 61 71 |
148 | 0 | o1 = _mm256_permute2f128_ps(rr1, rr5, 0x20); |
149 | | // 02 12 22 32 42 52 62 72 |
150 | 0 | o2 = _mm256_permute2f128_ps(rr2, rr6, 0x20); |
151 | | // 03 13 23 33 43 53 63 73 |
152 | 0 | o3 = _mm256_permute2f128_ps(rr3, rr7, 0x20); |
153 | | // 04 14 24 34 44 54 64 74 |
154 | 0 | o4 = _mm256_permute2f128_ps(rr0, rr4, 0x31); |
155 | | // 05 15 25 35 45 55 65 75 |
156 | 0 | o5 = _mm256_permute2f128_ps(rr1, rr5, 0x31); |
157 | | // 06 16 26 36 46 56 66 76 |
158 | 0 | o6 = _mm256_permute2f128_ps(rr2, rr6, 0x31); |
159 | | // 07 17 27 37 47 57 67 77 |
160 | | o7 = _mm256_permute2f128_ps(rr3, rr7, 0x31); |
161 | 0 | } |
162 | | |
163 | | } // namespace faiss |
164 | | |
165 | | #endif |