/root/doris/contrib/openblas/interface/gemm.c
Line | Count | Source |
1 | | /*********************************************************************/ |
2 | | /* Copyright 2024, 2025 The OpenBLAS Project */ |
3 | | /* Copyright 2009, 2010 The University of Texas at Austin. */ |
4 | | /* All rights reserved. */ |
5 | | /* */ |
6 | | /* Redistribution and use in source and binary forms, with or */ |
7 | | /* without modification, are permitted provided that the following */ |
8 | | /* conditions are met: */ |
9 | | /* */ |
10 | | /* 1. Redistributions of source code must retain the above */ |
11 | | /* copyright notice, this list of conditions and the following */ |
12 | | /* disclaimer. */ |
13 | | /* */ |
14 | | /* 2. Redistributions in binary form must reproduce the above */ |
15 | | /* copyright notice, this list of conditions and the following */ |
16 | | /* disclaimer in the documentation and/or other materials */ |
17 | | /* provided with the distribution. */ |
18 | | /* */ |
19 | | /* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */ |
20 | | /* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */ |
21 | | /* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */ |
22 | | /* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */ |
23 | | /* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */ |
24 | | /* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */ |
25 | | /* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */ |
26 | | /* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */ |
27 | | /* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */ |
28 | | /* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */ |
29 | | /* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */ |
30 | | /* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */ |
31 | | /* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */ |
32 | | /* POSSIBILITY OF SUCH DAMAGE. */ |
33 | | /* */ |
34 | | /* The views and conclusions contained in the software and */ |
35 | | /* documentation are those of the authors and should not be */ |
36 | | /* interpreted as representing official policies, either expressed */ |
37 | | /* or implied, of The University of Texas at Austin. */ |
38 | | /*********************************************************************/ |
39 | | |
40 | | #include <stdio.h> |
41 | | #include <stdlib.h> |
42 | | #include <stdbool.h> |
43 | | #include "common.h" |
44 | | #ifdef FUNCTION_PROFILE |
45 | | #include "functable.h" |
46 | | #endif |
47 | | |
48 | | #ifndef COMPLEX /* real-valued builds */ |
49 | 0 | #define SMP_THRESHOLD_MIN 65536.0 /* base M*N*K amount used by get_gemm_optimal_nthreads() when deciding on threading */ |
50 | | #ifdef XDOUBLE |
51 | | #define ERROR_NAME "QGEMM " /* routine name handed to xerbla when an argument is invalid */ |
52 | | #define GEMV BLASFUNC(qgemv) /* GEMV entry point called when a GEMM is forwarded to GEMV */ |
53 | | #elif defined(DOUBLE) |
54 | 0 | #define ERROR_NAME "DGEMM " |
55 | | #define GEMV BLASFUNC(dgemv) |
56 | | #elif defined(BFLOAT16) |
57 | | #define ERROR_NAME "SBGEMM " |
58 | | #define GEMV BLASFUNC(sbgemv) |
59 | | #else |
60 | 0 | #define ERROR_NAME "SGEMM " |
61 | | #define GEMV BLASFUNC(sgemv) |
62 | | #endif |
63 | | #else /* complex-valued builds: smaller threshold, and no GEMV macro is defined */ |
64 | | #define SMP_THRESHOLD_MIN 8192.0 |
65 | | #ifndef GEMM3M |
66 | | #ifdef XDOUBLE |
67 | | #define ERROR_NAME "XGEMM " |
68 | | #elif defined(DOUBLE) |
69 | | #define ERROR_NAME "ZGEMM " |
70 | | #else |
71 | | #define ERROR_NAME "CGEMM " |
72 | | #endif |
73 | | #else /* GEMM3M builds report their own routine names */ |
74 | | #ifdef XDOUBLE |
75 | | #define ERROR_NAME "XGEMM3M " |
76 | | #elif defined(DOUBLE) |
77 | | #define ERROR_NAME "ZGEMM3M " |
78 | | #else |
79 | | #define ERROR_NAME "CGEMM3M " |
80 | | #endif |
81 | | #endif |
82 | | #endif |
83 | | /* Fallback multiplier applied to SMP_THRESHOLD_MIN when nothing has defined one already. */ |
84 | | #ifndef GEMM_MULTITHREAD_THRESHOLD |
85 | | #define GEMM_MULTITHREAD_THRESHOLD 4 |
86 | | #endif |
87 | | |
88 | | static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { /* kernel dispatch table, indexed by (transb << 2) | transa */ |
89 | | #if !defined(GEMM3M) || defined(GENERIC) |
90 | | GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, /* entries 0-15: transa varies fastest (N,T,R,C), one row per transb */ |
91 | | GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, |
92 | | GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR, |
93 | | GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC, |
94 | | #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3) |
95 | | GEMM_THREAD_NN, GEMM_THREAD_TN, GEMM_THREAD_RN, GEMM_THREAD_CN, /* entries 16-31: threaded variants, reached by OR-ing 16 into the index */ |
96 | | GEMM_THREAD_NT, GEMM_THREAD_TT, GEMM_THREAD_RT, GEMM_THREAD_CT, |
97 | | GEMM_THREAD_NR, GEMM_THREAD_TR, GEMM_THREAD_RR, GEMM_THREAD_CR, |
98 | | GEMM_THREAD_NC, GEMM_THREAD_TC, GEMM_THREAD_RC, GEMM_THREAD_CC, |
99 | | #endif |
100 | | #else /* GEMM3M (non-GENERIC) builds use the same layout with the GEMM3M kernels */ |
101 | | GEMM3M_NN, GEMM3M_TN, GEMM3M_RN, GEMM3M_CN, |
102 | | GEMM3M_NT, GEMM3M_TT, GEMM3M_RT, GEMM3M_CT, |
103 | | GEMM3M_NR, GEMM3M_TR, GEMM3M_RR, GEMM3M_CR, |
104 | | GEMM3M_NC, GEMM3M_TC, GEMM3M_RC, GEMM3M_CC, |
105 | | #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3) |
106 | | GEMM3M_THREAD_NN, GEMM3M_THREAD_TN, GEMM3M_THREAD_RN, GEMM3M_THREAD_CN, |
107 | | GEMM3M_THREAD_NT, GEMM3M_THREAD_TT, GEMM3M_THREAD_RT, GEMM3M_THREAD_CT, |
108 | | GEMM3M_THREAD_NR, GEMM3M_THREAD_TR, GEMM3M_THREAD_RR, GEMM3M_THREAD_CR, |
109 | | GEMM3M_THREAD_NC, GEMM3M_THREAD_TC, GEMM3M_THREAD_RC, GEMM3M_THREAD_CC, |
110 | | #endif |
111 | | #endif |
112 | | }; |
113 | | |
114 | | #if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) /* small-matrix fast path is compiled only for non-GEMM3M, non-XDOUBLE builds */ |
115 | | #define USE_SMALL_MATRIX_OPT 1 |
116 | | #else |
117 | | #define USE_SMALL_MATRIX_OPT 0 |
118 | | #endif |
119 | | |
120 | | #if USE_SMALL_MATRIX_OPT |
121 | | #ifndef DYNAMIC_ARCH |
122 | 0 | #define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx])) /* the table entry itself is converted to the kernel address */ |
123 | | #else |
124 | | #define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx])))) /* the table entry is a byte offset into *gotoblas where the kernel pointer is read */ |
125 | | #endif |
126 | | |
127 | | /* Real builds: tables cover two transb rows; the R/C columns hold 0. */ |
128 | | #ifndef COMPLEX |
129 | | static size_t gemm_small_kernel[] = { |
130 | | GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0, |
131 | | GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0, |
132 | | }; |
133 | | |
134 | | /* Same layout for the kernels used when beta == 0. */ |
135 | | static size_t gemm_small_kernel_b0[] = { |
136 | | GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0, |
137 | | GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0, |
138 | | }; |
139 | | /* Cast the selected entry to the kernel signature; the _B0 signature has no beta parameter. */ |
140 | 0 | #define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)) |
141 | 0 | #define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)) |
142 | | #else |
143 | | /* Complex builds: all 16 transa/transb combinations are present. */ |
144 | | static size_t zgemm_small_kernel[] = { |
145 | | GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, |
146 | | GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, |
147 | | GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, |
148 | | GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, |
149 | | }; |
150 | | |
151 | | static size_t zgemm_small_kernel_b0[] = { |
152 | | GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, |
153 | | GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, |
154 | | GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, |
155 | | GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, |
156 | | }; |
157 | | /* Complex kernel signatures take alpha (and beta, when present) as two FLOAT values each. */ |
158 | | #define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx)) |
159 | | #define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx)) |
160 | | #endif |
161 | | #endif |
162 | | |
163 | | #if defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16) |
164 | | #define XFEATURE_XTILEDATA 18 /* feature number passed with the arch_prctl request below */ |
165 | | #define ARCH_REQ_XCOMP_PERM 0x1023 /* arch_prctl request code used below */ |
166 | | static int openblas_amxtile_permission = 0; /* becomes 1 once the permission request has succeeded */ |
167 | | static int init_amxtile_permission(void) { /* asks the kernel for XTILEDATA permission; returns 0 on success, -1 if refused */ |
168 | | long status = |
169 | | syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); |
170 | | if (status != 0) { |
171 | | fprintf(stderr, "XTILEDATA permission not granted in your device(Linux, " |
172 | | "Intel Sapphire Rapids), skip sbgemm calculation\n"); |
173 | | return -1; |
174 | | } |
175 | | openblas_amxtile_permission = 1; /* remembered so callers can skip the syscall next time */ |
176 | | return 0; |
177 | | } |
178 | | #endif |
179 | | |
180 | | #ifdef SMP |
181 | | #ifdef DYNAMIC_ARCH |
182 | | extern char* gotoblas_corename(void); /* defined elsewhere; its result is compared against core names in get_gemm_optimal_nthreads() */ |
183 | | #endif |
184 | | |
185 | | #if defined(DYNAMIC_ARCH) || defined(NEOVERSEV1) /* thread-count table for Neoverse V1 */ |
186 | | static inline int get_gemm_optimal_nthreads_neoversev1(double MNK, int ncpu) { /* maps problem size MNK to a thread count; intermediate tiers are capped with MIN(ncpu, ...) */ |
187 | | return |
188 | | MNK < 262144L ? 1 /* smallest problems run on a single thread */ |
189 | | : MNK < 1124864L ? MIN(ncpu, 6) |
190 | | : MNK < 7880599L ? MIN(ncpu, 12) |
191 | | : MNK < 17173512L ? MIN(ncpu, 16) |
192 | | : MNK < 33386248L ? MIN(ncpu, 20) |
193 | | : MNK < 57066625L ? MIN(ncpu, 24) |
194 | | : MNK < 91733851L ? MIN(ncpu, 32) |
195 | | : MNK < 265847707L ? MIN(ncpu, 40) |
196 | | : MNK < 458314011L ? MIN(ncpu, 48) |
197 | | : MNK < 729000000L ? MIN(ncpu, 56) |
198 | | : ncpu; /* anything larger gets every CPU passed in */ |
199 | | } |
200 | | #endif |
201 | | |
202 | | #if defined(DYNAMIC_ARCH) || defined(NEOVERSEV2) /* thread-count table for Neoverse V2 */ |
203 | | static inline int get_gemm_optimal_nthreads_neoversev2(double MNK, int ncpu) { /* maps problem size MNK to a thread count; intermediate tiers are capped with MIN(ncpu, ...) */ |
204 | | return |
205 | | MNK < 125000L ? 1 /* smallest problems run on a single thread */ |
206 | | : MNK < 1092727L ? MIN(ncpu, 6) |
207 | | : MNK < 2628072L ? MIN(ncpu, 8) |
208 | | : MNK < 8000000L ? MIN(ncpu, 12) |
209 | | : MNK < 20346417L ? MIN(ncpu, 16) |
210 | | : MNK < 57066625L ? MIN(ncpu, 24) |
211 | | : MNK < 91125000L ? MIN(ncpu, 28) |
212 | | : MNK < 238328000L ? MIN(ncpu, 40) |
213 | | : MNK < 454756609L ? MIN(ncpu, 48) |
214 | | : MNK < 857375000L ? MIN(ncpu, 56) |
215 | | : MNK < 1073741824L ? MIN(ncpu, 64) |
216 | | : ncpu; /* anything larger gets every CPU passed in */ |
217 | | } |
218 | | #endif |
219 | | |
220 | 0 | static inline int get_gemm_optimal_nthreads(double MNK) { /* returns the thread count to use for a GEMM whose M*N*K product is MNK */ |
221 | 0 | int ncpu = num_cpu_avail(3); /* NOTE(review): meaning of the argument 3 is not visible in this file */ |
222 | | #if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) /* tuned tables apply only to single-precision real builds */ |
223 | | return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu); |
224 | | #elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) |
225 | | return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu); |
226 | | #elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) /* pick the table from the name returned by gotoblas_corename() */ |
227 | | if (strcmp(gotoblas_corename(), "neoversev1") == 0) { |
228 | | return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu); |
229 | | } |
230 | | if (strcmp(gotoblas_corename(), "neoversev2") == 0) { |
231 | | return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu); |
232 | | } |
233 | | #endif |
234 | 0 | if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) ) { /* generic heuristic: at or below one threshold of work, stay single-threaded */ |
235 | 0 | return 1; |
236 | 0 | } |
237 | 0 | else { |
238 | 0 | if (MNK/ncpu < SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD) { /* less than one threshold of work per CPU */ |
239 | 0 | return MNK/(SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD); /* one thread per threshold-sized share; the double is converted to int on return */ |
240 | 0 | } |
241 | 0 | else { |
242 | 0 | return ncpu; |
243 | 0 | } |
244 | 0 | } |
245 | 0 | } Unexecuted instantiation: sgemm.c:get_gemm_optimal_nthreads Unexecuted instantiation: dgemm.c:get_gemm_optimal_nthreads |
246 | | #endif |
247 | | |
248 | | #ifndef CBLAS |
249 | | /* Fortran-style GEMM entry point: every argument is passed by reference; bad arguments are reported through xerbla. */ |
250 | | void NAME(char *TRANSA, char *TRANSB, |
251 | | blasint *M, blasint *N, blasint *K, |
252 | | FLOAT *alpha, |
253 | | IFLOAT *a, blasint *ldA, |
254 | | IFLOAT *b, blasint *ldB, |
255 | | FLOAT *beta, |
256 | 0 | FLOAT *c, blasint *ldC){ |
257 |

258 | 0 | blas_arg_t args; |
259 |

260 | 0 | int transa, transb, nrowa, nrowb; |
261 | 0 | blasint info; |
262 |

263 | 0 | char transA, transB; |
264 | 0 | IFLOAT *buffer; |
265 | 0 | IFLOAT *sa, *sb; |
266 |

267 | 0 | #ifdef SMP |
268 | 0 | double MNK; |
269 | | #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY) |
270 | | #ifndef COMPLEX |
271 | | #ifdef XDOUBLE |
272 | | int mode = BLAS_XDOUBLE | BLAS_REAL; |
273 | | #elif defined(DOUBLE) |
274 | | int mode = BLAS_DOUBLE | BLAS_REAL; |
275 | | #else |
276 | | int mode = BLAS_SINGLE | BLAS_REAL; |
277 | | #endif |
278 | | #else |
279 | | #ifdef XDOUBLE |
280 | | int mode = BLAS_XDOUBLE | BLAS_COMPLEX; |
281 | | #elif defined(DOUBLE) |
282 | | int mode = BLAS_DOUBLE | BLAS_COMPLEX; |
283 | | #else |
284 | | int mode = BLAS_SINGLE | BLAS_COMPLEX; |
285 | | #endif |
286 | | #endif |
287 | | #endif |
288 | 0 | #endif |
289 |

290 | | #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3) |
291 | | int nodes; |
292 | | #endif |
293 |

294 | 0 | PRINT_DEBUG_NAME; |
295 |

296 | 0 | args.m = *M; |
297 | 0 | args.n = *N; |
298 | 0 | args.k = *K; |
299 |

300 | 0 | args.a = (void *)a; |
301 | 0 | args.b = (void *)b; |
302 | 0 | args.c = (void *)c; |
303 |

304 | 0 | args.lda = *ldA; |
305 | 0 | args.ldb = *ldB; |
306 | 0 | args.ldc = *ldC; |
307 |

308 | 0 | args.alpha = (void *)alpha; |
309 | 0 | args.beta = (void *)beta; |
310 |

311 | 0 | transA = *TRANSA; |
312 | 0 | transB = *TRANSB; |
313 |

314 | 0 | TOUPPER(transA); |
315 | 0 | TOUPPER(transB); |
316 |

317 | 0 | transa = -1; /* -1 marks an unrecognised transpose character */ |
318 | 0 | transb = -1; |
319 |

320 | 0 | if (transA == 'N') transa = 0; |
321 | 0 | if (transA == 'T') transa = 1; |
322 | 0 | #ifndef COMPLEX |
323 | 0 | if (transA == 'R') transa = 0; /* real builds map R and C onto N and T */ |
324 | 0 | if (transA == 'C') transa = 1; |
325 | | #else |
326 | | if (transA == 'R') transa = 2; |
327 | | if (transA == 'C') transa = 3; |
328 | | #endif |
329 |

330 | 0 | if (transB == 'N') transb = 0; |
331 | 0 | if (transB == 'T') transb = 1; |
332 | 0 | #ifndef COMPLEX |
333 | 0 | if (transB == 'R') transb = 0; |
334 | 0 | if (transB == 'C') transb = 1; |
335 | | #else |
336 | | if (transB == 'R') transb = 2; |
337 | | if (transB == 'C') transb = 3; |
338 | | #endif |
339 |

340 | 0 | nrowa = args.m; /* minimum leading dimensions for A and B, depending on transposition */ |
341 | 0 | if (transa & 1) nrowa = args.k; |
342 | 0 | nrowb = args.k; |
343 | 0 | if (transb & 1) nrowb = args.n; |
344 |

345 | 0 | info = 0; |
346 |

347 | 0 | if (args.ldc < args.m) info = 13; /* checks run from last argument to first, so the lowest-numbered bad argument is the one reported */ |
348 | 0 | if (args.ldb < nrowb) info = 10; |
349 | 0 | if (args.lda < nrowa) info = 8; |
350 | 0 | if (args.k < 0) info = 5; |
351 | 0 | if (args.n < 0) info = 4; |
352 | 0 | if (args.m < 0) info = 3; |
353 | 0 | if (transb < 0) info = 2; |
354 | 0 | if (transa < 0) info = 1; |
355 |

356 | 0 | if (info){ |
357 | 0 | BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); |
358 | 0 | return; |
359 | 0 | } |
360 | | |
361 | | #else |
362 | | /* CBLAS entry point: dimensions are passed by value; real alpha/beta by value, complex ones through void pointers. */ |
363 | | void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, |
364 | | blasint m, blasint n, blasint k, |
365 | | #ifndef COMPLEX |
366 | | FLOAT alpha, |
367 | | IFLOAT *a, blasint lda, |
368 | | IFLOAT *b, blasint ldb, |
369 | | FLOAT beta, |
370 | | FLOAT *c, blasint ldc) { |
371 | | #else |
372 | | void *valpha, |
373 | | void *va, blasint lda, |
374 | | void *vb, blasint ldb, |
375 | | void *vbeta, |
376 | | void *vc, blasint ldc) { |
377 | | FLOAT *alpha = (FLOAT*) valpha; |
378 | | FLOAT *beta = (FLOAT*) vbeta; |
379 | | FLOAT *a = (FLOAT*) va; |
380 | | FLOAT *b = (FLOAT*) vb; |
381 | | FLOAT *c = (FLOAT*) vc; |
382 | | #endif |
383 | | |
384 | | blas_arg_t args; |
385 | | int transa, transb; |
386 | | blasint nrowa, nrowb, info; |
387 | | |
388 | | XFLOAT *buffer; |
389 | | XFLOAT *sa, *sb; |
390 | | |
391 | | #ifdef SMP |
392 | | double MNK; |
393 | | #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY) |
394 | | #ifndef COMPLEX |
395 | | #ifdef XDOUBLE |
396 | | int mode = BLAS_XDOUBLE | BLAS_REAL; |
397 | | #elif defined(DOUBLE) |
398 | | int mode = BLAS_DOUBLE | BLAS_REAL; |
399 | | #else |
400 | | int mode = BLAS_SINGLE | BLAS_REAL; |
401 | | #endif |
402 | | #else |
403 | | #ifdef XDOUBLE |
404 | | int mode = BLAS_XDOUBLE | BLAS_COMPLEX; |
405 | | #elif defined(DOUBLE) |
406 | | int mode = BLAS_DOUBLE | BLAS_COMPLEX; |
407 | | #else |
408 | | int mode = BLAS_SINGLE | BLAS_COMPLEX; |
409 | | #endif |
410 | | #endif |
411 | | #endif |
412 | | #endif |
413 | | |
414 | | #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3) |
415 | | int nodes; |
416 | | #endif |
417 | | |
418 | | PRINT_DEBUG_CNAME; |
419 | | /* Single-precision real only: hand plain row-major, untransposed, alpha == 1, beta == 0 products to SGEMM_DIRECT. */ |
420 | | #if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) |
421 | | #if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH)) |
422 | | #if defined(DYNAMIC_ARCH) |
423 | | if (support_avx512() ) |
424 | | #endif |
425 | | if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) { |
426 | | SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); |
427 | | return; |
428 | | } |
429 | | #endif |
430 | | #if defined(ARCH_ARM64) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH)) |
431 | | #if defined(DYNAMIC_ARCH) |
432 | | if (support_sme1()) |
433 | | #endif |
434 | | if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { |
435 | | SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); |
436 | | return; |
437 | | } |
438 | | #endif |
439 | | #endif |
440 | | |
441 | | #ifndef COMPLEX /* real alpha/beta are by-value parameters here, so store their addresses */ |
442 | | args.alpha = (void *)&alpha; |
443 | | args.beta = (void *)&beta; |
444 | | #else |
445 | | args.alpha = (void *)alpha; |
446 | | args.beta = (void *)beta; |
447 | | #endif |
448 | | |
449 | | transa = -1; |
450 | | transb = -1; |
451 | | info = 0; /* stays 0, and is reported below, when order is neither CblasColMajor nor CblasRowMajor */ |
452 | | |
453 | | if (order == CblasColMajor) { |
454 | | args.m = m; |
455 | | args.n = n; |
456 | | args.k = k; |
457 | | |
458 | | args.a = (void *)a; |
459 | | args.b = (void *)b; |
460 | | args.c = (void *)c; |
461 | | |
462 | | args.lda = lda; |
463 | | args.ldb = ldb; |
464 | | args.ldc = ldc; |
465 | | |
466 | | if (TransA == CblasNoTrans) transa = 0; |
467 | | if (TransA == CblasTrans) transa = 1; |
468 | | #ifndef COMPLEX |
469 | | if (TransA == CblasConjNoTrans) transa = 0; |
470 | | if (TransA == CblasConjTrans) transa = 1; |
471 | | #else |
472 | | if (TransA == CblasConjNoTrans) transa = 2; |
473 | | if (TransA == CblasConjTrans) transa = 3; |
474 | | #endif |
475 | | if (TransB == CblasNoTrans) transb = 0; |
476 | | if (TransB == CblasTrans) transb = 1; |
477 | | #ifndef COMPLEX |
478 | | if (TransB == CblasConjNoTrans) transb = 0; |
479 | | if (TransB == CblasConjTrans) transb = 1; |
480 | | #else |
481 | | if (TransB == CblasConjNoTrans) transb = 2; |
482 | | if (TransB == CblasConjTrans) transb = 3; |
483 | | #endif |
484 | | |
485 | | nrowa = args.m; |
486 | | if (transa & 1) nrowa = args.k; |
487 | | nrowb = args.k; |
488 | | if (transb & 1) nrowb = args.n; |
489 | | |
490 | | info = -1; /* in the CBLAS path -1 means "no error"; any value >= 0 is reported */ |
491 | | |
492 | | if (args.ldc < args.m) info = 13; |
493 | | if (args.ldb < nrowb) info = 10; |
494 | | if (args.lda < nrowa) info = 8; |
495 | | if (args.k < 0) info = 5; |
496 | | if (args.n < 0) info = 4; |
497 | | if (args.m < 0) info = 3; |
498 | | if (transb < 0) info = 2; |
499 | | if (transa < 0) info = 1; |
500 | | } |
501 | | |
502 | | if (order == CblasRowMajor) { /* row-major input: A/B, M/N, lda/ldb and the transpose flags are swapped before dispatch */ |
503 | | args.m = n; |
504 | | args.n = m; |
505 | | args.k = k; |
506 | | |
507 | | args.a = (void *)b; |
508 | | args.b = (void *)a; |
509 | | args.c = (void *)c; |
510 | | |
511 | | args.lda = ldb; |
512 | | args.ldb = lda; |
513 | | args.ldc = ldc; |
514 | | |
515 | | if (TransB == CblasNoTrans) transa = 0; |
516 | | if (TransB == CblasTrans) transa = 1; |
517 | | #ifndef COMPLEX |
518 | | if (TransB == CblasConjNoTrans) transa = 0; |
519 | | if (TransB == CblasConjTrans) transa = 1; |
520 | | #else |
521 | | if (TransB == CblasConjNoTrans) transa = 2; |
522 | | if (TransB == CblasConjTrans) transa = 3; |
523 | | #endif |
524 | | if (TransA == CblasNoTrans) transb = 0; |
525 | | if (TransA == CblasTrans) transb = 1; |
526 | | #ifndef COMPLEX |
527 | | if (TransA == CblasConjNoTrans) transb = 0; |
528 | | if (TransA == CblasConjTrans) transb = 1; |
529 | | #else |
530 | | if (TransA == CblasConjNoTrans) transb = 2; |
531 | | if (TransA == CblasConjTrans) transb = 3; |
532 | | #endif |
533 | | |
534 | | nrowa = args.m; |
535 | | if (transa & 1) nrowa = args.k; |
536 | | nrowb = args.k; |
537 | | if (transb & 1) nrowb = args.n; |
538 | | |
539 | | info = -1; |
540 | | |
541 | | if (args.ldc < args.m) info = 13; |
542 | | if (args.ldb < nrowb) info = 10; |
543 | | if (args.lda < nrowa) info = 8; |
544 | | if (args.k < 0) info = 5; |
545 | | if (args.n < 0) info = 4; |
546 | | if (args.m < 0) info = 3; |
547 | | if (transb < 0) info = 2; |
548 | | if (transa < 0) info = 1; |
549 | | |
550 | | } |
551 | | |
552 | | if (info >= 0) { |
553 | | BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); |
554 | | return; |
555 | | } |
556 | | |
557 | | #endif |
558 | | /* From here on the body is shared by both entry points. */ |
559 | | #if defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16) |
560 | | #if defined(DYNAMIC_ARCH) |
561 | | if (gotoblas->need_amxtile_permission && |
562 | | openblas_amxtile_permission == 0 && init_amxtile_permission() == -1) { |
563 | | return; |
564 | | } |
565 | | #endif |
566 | | #if !defined(DYNAMIC_ARCH) && defined(SAPPHIRERAPIDS) |
567 | | if (openblas_amxtile_permission == 0 && init_amxtile_permission() == -1) { |
568 | | return; |
569 | | } |
570 | | #endif |
571 | | #endif // defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16) |
572 | | |
573 | 0 | if ((args.m == 0) || (args.n == 0)) return; /* empty result matrix: nothing to compute */ |
574 | | |
575 | | #if 0 |
576 | | fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n", |
577 | | args.m, args.n, args.k, args.lda, args.ldb, args.ldc); |
578 | | #endif |
579 | | |
580 | | #if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16)) |
581 | | #if defined(ARCH_ARM64) |
582 | | // The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c} |
583 | | // perform poorly in certain circumstances. We use the following boolean |
584 | | // variable along with the gemv argument values to avoid these inefficient |
585 | | // gemv cases, see github issue#4951. |
586 | | bool have_tuned_gemv = false; |
587 | | #else |
588 | | bool have_tuned_gemv = true; |
589 | | #endif |
590 | | // Check if we can convert GEMM -> GEMV |
591 | | if (args.k != 0) { |
592 | | if (args.n == 1) { |
593 | | blasint inc_x = 1; |
594 | | blasint inc_y = 1; |
595 | | // These were passed in as blasint, but the struct translates them to blaslong |
596 | | blasint m = args.m; |
597 | | blasint n = args.k; |
598 | | blasint lda = args.lda; |
599 | | // Create new transpose parameters |
600 | | char NT = 'N'; |
601 | | if (transa & 1) { |
602 | | NT = 'T'; |
603 | | m = args.k; |
604 | | n = args.m; |
605 | | } |
606 | | if (transb & 1) { |
607 | | inc_x = args.ldb; |
608 | | } |
609 | | bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1)); |
610 | | if (is_efficient_gemv) { |
611 | | GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y); |
612 | | return; |
613 | | } |
614 | | } |
615 | | if (args.m == 1) { |
616 | | blasint inc_x = args.lda; |
617 | | blasint inc_y = args.ldc; |
618 | | // These were passed in as blasint, but the struct translates them to blaslong |
619 | | blasint m = args.k; |
620 | | blasint n = args.n; |
621 | | blasint ldb = args.ldb; |
622 | | // Create new transpose parameters |
623 | | char NT = 'T'; |
624 | | if (transa & 1) { |
625 | | inc_x = 1; |
626 | | } |
627 | | if (transb & 1) { |
628 | | NT = 'N'; |
629 | | m = args.n; |
630 | | n = args.k; |
631 | | } |
632 | | bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1)); |
633 | | if (is_efficient_gemv) { |
634 | | GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y); |
635 | | return; |
636 | | } |
637 | | } |
638 | | } |
639 | | #endif |
640 | | |
641 | 0 | IDEBUG_START; |
642 |

643 | 0 | FUNCTION_PROFILE_START(); |
644 |

645 | 0 | #if USE_SMALL_MATRIX_OPT |
646 | 0 | #if !defined(COMPLEX) |
647 | 0 | if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, *(FLOAT *)(args.alpha), *(FLOAT *)(args.beta))){ |
648 | 0 | if(*(FLOAT *)(args.beta) == 0.0){ |
649 | 0 | (GEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); |
650 | 0 | }else{ |
651 | 0 | (GEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); |
652 | 0 | } |
653 | 0 | return; |
654 | 0 | } |
655 | | #else |
656 | | if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, alpha[0], alpha[1], beta[0], beta[1])){ |
657 | | if(beta[0] == 0.0 && beta[1] == 0.0){ |
658 | | (ZGEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc); |
659 | | }else{ |
660 | | (ZGEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc); |
661 | | } |
662 | | return; |
663 | | } |
664 | | #endif |
665 | 0 | #endif |
666 | | /* General path: take a scratch buffer and carve the sa / sb work areas out of it. */ |
667 | 0 | buffer = (XFLOAT *)blas_memory_alloc(0); |
668 | | |
669 | | //For LOONGARCH64, applying an offset to the buffer is essential |
670 | | //for minimizing cache conflicts and optimizing performance. |
671 | | #if defined(ARCH_LOONGARCH64) && !defined(NO_AFFINITY) |
672 | | sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A); |
673 | | #else |
674 | 0 | sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); |
675 | 0 | #endif |
676 | 0 | sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); |
677 |

678 | 0 | #ifdef SMP |
679 | | #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY) |
680 | | mode |= (transa << BLAS_TRANSA_SHIFT); |
681 | | mode |= (transb << BLAS_TRANSB_SHIFT); |
682 | | #endif |
683 |

684 | 0 | MNK = (double) args.m * (double) args.n * (double) args.k; /* each factor is converted to double before multiplying */ |
685 | 0 | args.nthreads = get_gemm_optimal_nthreads(MNK); |
686 |

687 | 0 | args.common = NULL; |
688 |

689 | 0 | if (args.nthreads == 1) { |
690 | 0 | #endif |
691 |

692 | 0 | (gemm[(transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0); /* single-threaded kernel, table entries 0-15 */ |
693 |

694 | 0 | #ifdef SMP |
695 |

696 | 0 | } else { |
697 |

698 | 0 | #ifndef USE_SIMPLE_THREADED_LEVEL3 |
699 |

700 | | #ifndef NO_AFFINITY |
701 | | nodes = get_num_nodes(); |
702 | | |
703 | | if ((nodes > 1) && get_node_equal()) { |
704 | | |
705 | | args.nthreads /= nodes; |
706 | | |
707 | | gemm_thread_mn(mode, &args, NULL, NULL, gemm[16 | (transb << 2) | transa], sa, sb, nodes); |
708 | | |
709 | | } else { |
710 | | #endif |
711 |

712 | 0 | (gemm[16 | (transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0); /* threaded kernel, table entries 16-31 */ |
713 |

714 | | #else |
715 | | |
716 | | GEMM_THREAD(mode, &args, NULL, NULL, gemm[(transb << 2) | transa], sa, sb, args.nthreads); |
717 | | |
718 | | #endif |
719 |

720 | 0 | #ifndef USE_SIMPLE_THREADED_LEVEL3 |
721 | | #ifndef NO_AFFINITY |
722 | | } |
723 | | #endif |
724 | 0 | #endif |
725 |

726 | 0 | #endif |
727 |

728 | 0 | #ifdef SMP |
729 | 0 | } |
730 | 0 | #endif |
731 |

732 | 0 | blas_memory_free(buffer); |
733 |

734 | 0 | FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, args.m * args.k + args.k * args.n + args.m * args.n, 2 * args.m * args.n * args.k); |
735 |

736 | 0 | IDEBUG_END; |
737 |

738 | 0 | return; |
739 | 0 | } Unexecuted instantiation: sgemm_ Unexecuted instantiation: dgemm_ |