Coverage Report

Created: 2025-08-27 05:46

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/contrib/openblas/interface/gemm.c
Line
Count
Source
1
/*********************************************************************/
2
/* Copyright 2024, 2025 The OpenBLAS Project                         */
3
/* Copyright 2009, 2010 The University of Texas at Austin.           */
4
/* All rights reserved.                                              */
5
/*                                                                   */
6
/* Redistribution and use in source and binary forms, with or        */
7
/* without modification, are permitted provided that the following   */
8
/* conditions are met:                                               */
9
/*                                                                   */
10
/*   1. Redistributions of source code must retain the above         */
11
/*      copyright notice, this list of conditions and the following  */
12
/*      disclaimer.                                                  */
13
/*                                                                   */
14
/*   2. Redistributions in binary form must reproduce the above      */
15
/*      copyright notice, this list of conditions and the following  */
16
/*      disclaimer in the documentation and/or other materials       */
17
/*      provided with the distribution.                              */
18
/*                                                                   */
19
/*    THIS  SOFTWARE IS PROVIDED  BY THE  UNIVERSITY OF  TEXAS AT    */
20
/*    AUSTIN  ``AS IS''  AND ANY  EXPRESS OR  IMPLIED WARRANTIES,    */
21
/*    INCLUDING, BUT  NOT LIMITED  TO, THE IMPLIED  WARRANTIES OF    */
22
/*    MERCHANTABILITY  AND FITNESS FOR  A PARTICULAR  PURPOSE ARE    */
23
/*    DISCLAIMED.  IN  NO EVENT SHALL THE UNIVERSITY  OF TEXAS AT    */
24
/*    AUSTIN OR CONTRIBUTORS BE  LIABLE FOR ANY DIRECT, INDIRECT,    */
25
/*    INCIDENTAL,  SPECIAL, EXEMPLARY,  OR  CONSEQUENTIAL DAMAGES    */
26
/*    (INCLUDING, BUT  NOT LIMITED TO,  PROCUREMENT OF SUBSTITUTE    */
27
/*    GOODS  OR  SERVICES; LOSS  OF  USE,  DATA,  OR PROFITS;  OR    */
28
/*    BUSINESS INTERRUPTION) HOWEVER CAUSED  AND ON ANY THEORY OF    */
29
/*    LIABILITY, WHETHER  IN CONTRACT, STRICT  LIABILITY, OR TORT    */
30
/*    (INCLUDING NEGLIGENCE OR OTHERWISE)  ARISING IN ANY WAY OUT    */
31
/*    OF  THE  USE OF  THIS  SOFTWARE,  EVEN  IF ADVISED  OF  THE    */
32
/*    POSSIBILITY OF SUCH DAMAGE.                                    */
33
/*                                                                   */
34
/* The views and conclusions contained in the software and           */
35
/* documentation are those of the authors and should not be          */
36
/* interpreted as representing official policies, either expressed   */
37
/* or implied, of The University of Texas at Austin.                 */
38
/*********************************************************************/
39
40
#include <stdio.h>
41
#include <stdlib.h>
42
#include <stdbool.h>
43
#include "common.h"
44
#ifdef FUNCTION_PROFILE
45
#include "functable.h"
46
#endif
47
48
/*
 * Per-variant constants for this translation unit:
 *   SMP_THRESHOLD_MIN  - base unit of M*N*K work used by the threading
 *                        heuristic (65536 for real types, 8192 for complex).
 *   ERROR_NAME         - routine name reported to xerbla.
 *   GEMV               - matching GEMV routine used by GEMM->GEMV forwarding
 *                        (defined for real types only).
 *   GEMM_MULTITHREAD_THRESHOLD - multiplier applied to SMP_THRESHOLD_MIN;
 *                        defaults to 4 unless the build overrides it.
 */
#if defined(COMPLEX)
#  define SMP_THRESHOLD_MIN 8192.0
#  if defined(GEMM3M)
#    if defined(XDOUBLE)
#      define ERROR_NAME "XGEMM3M "
#    elif defined(DOUBLE)
#      define ERROR_NAME "ZGEMM3M "
#    else
#      define ERROR_NAME "CGEMM3M "
#    endif
#  else
#    if defined(XDOUBLE)
#      define ERROR_NAME "XGEMM "
#    elif defined(DOUBLE)
#      define ERROR_NAME "ZGEMM "
#    else
#      define ERROR_NAME "CGEMM "
#    endif
#  endif
#else
#  define SMP_THRESHOLD_MIN 65536.0
#  if defined(XDOUBLE)
#    define ERROR_NAME "QGEMM "
#    define GEMV BLASFUNC(qgemv)
#  elif defined(DOUBLE)
#    define ERROR_NAME "DGEMM "
#    define GEMV BLASFUNC(dgemv)
#  elif defined(BFLOAT16)
#    define ERROR_NAME "SBGEMM "
#    define GEMV BLASFUNC(sbgemv)
#  else
#    define ERROR_NAME "SGEMM "
#    define GEMV BLASFUNC(sgemv)
#  endif
#endif

#if !defined(GEMM_MULTITHREAD_THRESHOLD)
#  define GEMM_MULTITHREAD_THRESHOLD 4
#endif
87
88
/*
 * Level-3 driver dispatch table.  Callers index it with
 * (transb << 2) | transa, where each transpose code is 0 (N), 1 (T),
 * 2 (R) or 3 (C); codes 2 and 3 are only produced for complex types.
 * Each row of four entries below therefore corresponds to one transb
 * value, and the position within the row to transa.
 * In SMP builds without USE_SIMPLE_THREADED_LEVEL3, entries 16..31 hold the
 * threaded drivers in the same order and are selected as gemm[16 | index].
 * GEMM3M builds use the GEMM3M_* drivers unless GENERIC is defined.
 */
static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
89
#if !defined(GEMM3M) || defined(GENERIC)
90
  GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
91
  GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
92
  GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
93
  GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
94
#if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
95
  GEMM_THREAD_NN, GEMM_THREAD_TN, GEMM_THREAD_RN, GEMM_THREAD_CN,
96
  GEMM_THREAD_NT, GEMM_THREAD_TT, GEMM_THREAD_RT, GEMM_THREAD_CT,
97
  GEMM_THREAD_NR, GEMM_THREAD_TR, GEMM_THREAD_RR, GEMM_THREAD_CR,
98
  GEMM_THREAD_NC, GEMM_THREAD_TC, GEMM_THREAD_RC, GEMM_THREAD_CC,
99
#endif
100
#else
101
  GEMM3M_NN, GEMM3M_TN, GEMM3M_RN, GEMM3M_CN,
102
  GEMM3M_NT, GEMM3M_TT, GEMM3M_RT, GEMM3M_CT,
103
  GEMM3M_NR, GEMM3M_TR, GEMM3M_RR, GEMM3M_CR,
104
  GEMM3M_NC, GEMM3M_TC, GEMM3M_RC, GEMM3M_CC,
105
#if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
106
  GEMM3M_THREAD_NN, GEMM3M_THREAD_TN, GEMM3M_THREAD_RN, GEMM3M_THREAD_CN,
107
  GEMM3M_THREAD_NT, GEMM3M_THREAD_TT, GEMM3M_THREAD_RT, GEMM3M_THREAD_CT,
108
  GEMM3M_THREAD_NR, GEMM3M_THREAD_TR, GEMM3M_THREAD_RR, GEMM3M_THREAD_CR,
109
  GEMM3M_THREAD_NC, GEMM3M_THREAD_TC, GEMM3M_THREAD_RC, GEMM3M_THREAD_CC,
110
#endif
111
#endif
112
};
113
114
/* Small-matrix fast path: enabled only when the build defines
   SMALL_MATRIX_OPT and this is neither a GEMM3M nor an XDOUBLE variant. */
#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE)
115
#define USE_SMALL_MATRIX_OPT 1
116
#else
117
#define USE_SMALL_MATRIX_OPT 0
118
#endif
119
120
#if USE_SMALL_MATRIX_OPT
121
/* SMALL_KERNEL_ADDR(table, idx) yields the kernel address for slot idx.
   Without DYNAMIC_ARCH the table entry is used directly as the address.
   With DYNAMIC_ARCH the entry is treated as a byte offset into *gotoblas,
   and the pointer-sized value stored at that offset is the address. */
#ifndef DYNAMIC_ARCH
122
0
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx]))
123
#else
124
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx]))))
125
#endif
126
127
128
/* Kernel tables are indexed by (transb << 2) | transa at the call sites.
   Real types only produce transpose codes 0 and 1, so slots 2, 3, 6 and 7
   of the real tables are unused and hold 0.  The *_b0 tables hold the
   variants selected when beta == 0; their signatures take no beta. */
#ifndef COMPLEX
129
static size_t gemm_small_kernel[] = {
130
  GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0,
131
  GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0,
132
};
133
134
135
static size_t gemm_small_kernel_b0[] = {
136
  GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0,
137
  GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0,
138
};
139
140
0
#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx))
141
0
#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx))
142
#else
143
144
/* Complex tables: all 16 (transa, transb) combinations are populated, and
   alpha/beta are passed as separate real and imaginary FLOAT arguments. */
static size_t zgemm_small_kernel[] = {
145
  GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN,
146
  GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT,
147
  GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR,
148
  GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC,
149
};
150
151
static size_t zgemm_small_kernel_b0[] = {
152
  GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN,
153
  GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT,
154
  GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR,
155
  GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC,
156
};
157
158
#define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx))
159
#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx))
160
#endif
161
#endif
162
163
#if defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16)
/* arch_prctl() request code and feature number used to ask the kernel for
   permission to use AMX tile data (XTILEDATA). */
#define XFEATURE_XTILEDATA 18
#define ARCH_REQ_XCOMP_PERM 0x1023

/* Becomes 1 once the kernel has granted XTILEDATA permission. */
static int openblas_amxtile_permission = 0;

/*
 * Request XTILEDATA permission from the kernel.
 *
 * Returns 0 and sets openblas_amxtile_permission on success.  On failure it
 * prints a diagnostic to stderr and returns -1; the flag stays 0, so a later
 * call will issue the request again.
 */
static int init_amxtile_permission(void) {
  long status =
      syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
  if (status != 0) {
    fprintf(stderr, "XTILEDATA permission not granted in your device(Linux, "
                    "Intel Sapphire Rapids), skip sbgemm calculation\n");
    return -1;
  }
  openblas_amxtile_permission = 1;
  return 0;
}
#endif
179
180
#ifdef SMP
181
#ifdef DYNAMIC_ARCH
182
extern char* gotoblas_corename(void);
183
#endif
184
185
#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV1)
186
/*
 * Thread-count heuristic tuned for Neoverse V1.
 *
 * MNK is the problem size M*N*K and ncpu the number of available CPUs.
 * Problems below the first bound run on one thread; each following bucket
 * caps the thread count (never exceeding ncpu), and anything at or beyond
 * the last bound uses all ncpu CPUs.
 */
static inline int get_gemm_optimal_nthreads_neoversev1(double MNK, int ncpu) {
  /* Exclusive upper MNK bound of each bucket and the thread cap it imposes. */
  static const struct { double bound; int cap; } tiers[] = {
    {   1124864.0,  6 },
    {   7880599.0, 12 },
    {  17173512.0, 16 },
    {  33386248.0, 20 },
    {  57066625.0, 24 },
    {  91733851.0, 32 },
    { 265847707.0, 40 },
    { 458314011.0, 48 },
    { 729000000.0, 56 },
  };
  size_t i;

  if (MNK < 262144.0)
    return 1;

  for (i = 0; i < sizeof tiers / sizeof tiers[0]; i++) {
    if (MNK < tiers[i].bound)
      return ncpu < tiers[i].cap ? ncpu : tiers[i].cap;
  }
  return ncpu;
}
200
#endif
201
202
#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV2)
203
/*
 * Thread-count heuristic tuned for Neoverse V2.
 *
 * MNK is the problem size M*N*K and ncpu the number of available CPUs.
 * Problems below the first bound run on one thread; each following bucket
 * caps the thread count (never exceeding ncpu), and anything at or beyond
 * the last bound uses all ncpu CPUs.
 */
static inline int get_gemm_optimal_nthreads_neoversev2(double MNK, int ncpu) {
  /* Exclusive upper MNK bound of each bucket and the thread cap it imposes. */
  static const struct { double bound; int cap; } tiers[] = {
    {    1092727.0,  6 },
    {    2628072.0,  8 },
    {    8000000.0, 12 },
    {   20346417.0, 16 },
    {   57066625.0, 24 },
    {   91125000.0, 28 },
    {  238328000.0, 40 },
    {  454756609.0, 48 },
    {  857375000.0, 56 },
    { 1073741824.0, 64 },
  };
  size_t i;

  if (MNK < 125000.0)
    return 1;

  for (i = 0; i < sizeof tiers / sizeof tiers[0]; i++) {
    if (MNK < tiers[i].bound)
      return ncpu < tiers[i].cap ? ncpu : tiers[i].cap;
  }
  return ncpu;
}
218
#endif
219
220
0
static inline int get_gemm_optimal_nthreads(double MNK) {
221
0
  int ncpu = num_cpu_avail(3);
222
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
223
  return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
224
#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
225
  return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu);
226
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
227
  if (strcmp(gotoblas_corename(), "neoversev1") == 0) {
228
    return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
229
  }
230
  if (strcmp(gotoblas_corename(), "neoversev2") == 0) {
231
    return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu);
232
  }
233
#endif
234
0
  if ( MNK <= (SMP_THRESHOLD_MIN  * (double) GEMM_MULTITHREAD_THRESHOLD) ) {
235
0
    return 1;
236
0
  }
237
0
  else {
238
0
    if (MNK/ncpu < SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD) {
239
0
      return MNK/(SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD);
240
0
    }
241
0
    else {
242
0
      return ncpu;
243
0
    }
244
0
  }
245
0
}
Unexecuted instantiation: sgemm.c:get_gemm_optimal_nthreads
Unexecuted instantiation: dgemm.c:get_gemm_optimal_nthreads
246
#endif
247
248
/*
 * GEMM entry point.  Without CBLAS this builds the Fortran-style NAME
 * interface (every argument passed by pointer); with CBLAS it builds CNAME,
 * which takes arguments by value plus a storage order.  Each front end
 * fills `args` (m, n, k, a, b, c, lda, ldb, ldc, alpha, beta) and the
 * transpose codes transa/transb, reports invalid arguments via xerbla, and
 * then falls through to the shared dispatch code after the closing #endif.
 * The argument checks treat op(A) as m x k, op(B) as k x n and C as m x n.
 */
#ifndef CBLAS
249
250
void NAME(char *TRANSA, char *TRANSB,
251
    blasint *M, blasint *N, blasint *K,
252
    FLOAT *alpha,
253
    IFLOAT *a, blasint *ldA,
254
    IFLOAT *b, blasint *ldB,
255
    FLOAT *beta,
256
0
    FLOAT *c, blasint *ldC){
257
258
0
  blas_arg_t args;
259
260
0
  int transa, transb, nrowa, nrowb;
261
0
  blasint info;
262
263
0
  char transA, transB;
264
0
  IFLOAT *buffer;
265
0
  IFLOAT *sa, *sb;
266
267
0
#ifdef SMP
268
0
  double MNK;
269
#if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
270
#ifndef COMPLEX
271
#ifdef XDOUBLE
272
  int mode  =  BLAS_XDOUBLE | BLAS_REAL;
273
#elif defined(DOUBLE)
274
  int mode  =  BLAS_DOUBLE  | BLAS_REAL;
275
#else
276
  int mode  =  BLAS_SINGLE  | BLAS_REAL;
277
#endif
278
#else
279
#ifdef XDOUBLE
280
  int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
281
#elif defined(DOUBLE)
282
  int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
283
#else
284
  int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
285
#endif
286
#endif
287
#endif
288
0
#endif
289
290
#if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
291
  int nodes;
292
#endif
293
294
0
  PRINT_DEBUG_NAME;
295
296
0
  args.m = *M;
297
0
  args.n = *N;
298
0
  args.k = *K;
299
300
0
  args.a = (void *)a;
301
0
  args.b = (void *)b;
302
0
  args.c = (void *)c;
303
304
0
  args.lda = *ldA;
305
0
  args.ldb = *ldB;
306
0
  args.ldc = *ldC;
307
308
0
  args.alpha = (void *)alpha;
309
0
  args.beta  = (void *)beta;
310
311
0
  transA = *TRANSA;
312
0
  transB = *TRANSB;
313
314
0
  TOUPPER(transA);
315
0
  TOUPPER(transB);
316
/*
 * Transpose codes: 'N' -> 0, 'T' -> 1.  Real types fold 'R' onto 0 and
 * 'C' onto 1; complex types keep them as 2 and 3.  Any other character
 * leaves the code at -1, which the checks below report.  (TOUPPER above
 * presumably makes the match case-insensitive.)
 */
317
0
  transa = -1;
318
0
  transb = -1;
319
320
0
  if (transA == 'N') transa = 0;
321
0
  if (transA == 'T') transa = 1;
322
0
#ifndef COMPLEX
323
0
  if (transA == 'R') transa = 0;
324
0
  if (transA == 'C') transa = 1;
325
#else
326
  if (transA == 'R') transa = 2;
327
  if (transA == 'C') transa = 3;
328
#endif
329
330
0
  if (transB == 'N') transb = 0;
331
0
  if (transB == 'T') transb = 1;
332
0
#ifndef COMPLEX
333
0
  if (transB == 'R') transb = 0;
334
0
  if (transB == 'C') transb = 1;
335
#else
336
  if (transB == 'R') transb = 2;
337
  if (transB == 'C') transb = 3;
338
#endif
339
/* Minimum leading dimensions: rows of A as stored (k when op(A) is a
   transpose) and rows of B as stored (n when op(B) is a transpose).
   Codes 1 and 3 both have bit 0 set, so "& 1" covers plain and
   conjugate transposes. */
340
0
  nrowa = args.m;
341
0
  if (transa & 1) nrowa = args.k;
342
0
  nrowb = args.k;
343
0
  if (transb & 1) nrowb = args.n;
344
345
0
  info = 0;
346
/* Checks run from the highest argument position down, so when several
   arguments are invalid `info` ends up holding the lowest position. */
347
0
  if (args.ldc < args.m) info = 13;
348
0
  if (args.ldb < nrowb)  info = 10;
349
0
  if (args.lda < nrowa)  info =  8;
350
0
  if (args.k < 0)        info =  5;
351
0
  if (args.n < 0)        info =  4;
352
0
  if (args.m < 0)        info =  3;
353
0
  if (transb < 0)        info =  2;
354
0
  if (transa < 0)        info =  1;
355
356
0
  if (info){
357
0
    BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
358
0
    return;
359
0
  }
360
361
#else
/* CBLAS front end: arguments by value plus a storage order.  For complex
   types alpha, beta and the matrices arrive as void * and are re-typed to
   FLOAT * below. */
362
363
void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
364
     blasint m, blasint n, blasint k,
365
#ifndef COMPLEX
366
     FLOAT alpha,
367
     IFLOAT *a, blasint lda,
368
     IFLOAT *b, blasint ldb,
369
     FLOAT beta,
370
     FLOAT *c, blasint ldc) {
371
#else
372
     void *valpha,
373
     void *va, blasint lda,
374
     void *vb, blasint ldb,
375
     void *vbeta,
376
     void *vc, blasint ldc) {
377
  FLOAT *alpha = (FLOAT*) valpha;
378
  FLOAT *beta  = (FLOAT*) vbeta;
379
  FLOAT *a = (FLOAT*) va;
380
  FLOAT *b = (FLOAT*) vb;
381
  FLOAT *c = (FLOAT*) vc;
382
#endif
383
384
  blas_arg_t args;
385
  int transa, transb;
386
  blasint nrowa, nrowb, info;
387
388
  XFLOAT *buffer;
389
  XFLOAT *sa, *sb;
390
391
#ifdef SMP
392
  double MNK;
393
#if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
394
#ifndef COMPLEX
395
#ifdef XDOUBLE
396
  int mode  =  BLAS_XDOUBLE | BLAS_REAL;
397
#elif defined(DOUBLE)
398
  int mode  =  BLAS_DOUBLE  | BLAS_REAL;
399
#else
400
  int mode  =  BLAS_SINGLE  | BLAS_REAL;
401
#endif
402
#else
403
#ifdef XDOUBLE
404
  int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
405
#elif defined(DOUBLE)
406
  int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
407
#else
408
  int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
409
#endif
410
#endif
411
#endif
412
#endif
413
414
#if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
415
  int nodes;
416
#endif
417
418
  PRINT_DEBUG_CNAME;
419
420
/* Single-precision real only: optional direct SGEMM kernel for row-major,
   non-transposed calls with alpha == 1 and beta == 0.  It is taken before
   any argument validation.  Under DYNAMIC_ARCH it is additionally gated on
   a runtime CPU check (support_avx512 / support_sme1).
   NOTE(review): the ARM64 variant does not consult
   SGEMM_DIRECT_PERFORMANT(m,n,k) as the x86 variant does -- confirm this
   is intended. */
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) 
421
#if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
422
#if defined(DYNAMIC_ARCH)
423
  if (support_avx512() )
424
#endif
425
  if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
426
  SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
427
  return;
428
  }
429
#endif
430
#if defined(ARCH_ARM64) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
431
#if defined(DYNAMIC_ARCH)
432
 if (support_sme1())
433
#endif
434
  if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
435
  SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
436
  return;
437
  }
438
#endif
439
#endif
440
441
#ifndef COMPLEX
442
  args.alpha = (void *)&alpha;
443
  args.beta  = (void *)&beta;
444
#else
445
  args.alpha = (void *)alpha;
446
  args.beta  = (void *)beta;
447
#endif
448
449
  transa = -1;
450
  transb = -1;
451
  info   =  0;
/* Each order branch below resets info to -1 ("arguments valid") before its
   checks.  NOTE(review): if order is neither CblasColMajor nor
   CblasRowMajor, info stays 0 and xerbla is called with info == 0. */
452
453
  if (order == CblasColMajor) {
454
    args.m = m;
455
    args.n = n;
456
    args.k = k;
457
458
    args.a = (void *)a;
459
    args.b = (void *)b;
460
    args.c = (void *)c;
461
462
    args.lda = lda;
463
    args.ldb = ldb;
464
    args.ldc = ldc;
465
466
    if (TransA == CblasNoTrans)     transa = 0;
467
    if (TransA == CblasTrans)       transa = 1;
468
#ifndef COMPLEX
469
    if (TransA == CblasConjNoTrans) transa = 0;
470
    if (TransA == CblasConjTrans)   transa = 1;
471
#else
472
    if (TransA == CblasConjNoTrans) transa = 2;
473
    if (TransA == CblasConjTrans)   transa = 3;
474
#endif
475
    if (TransB == CblasNoTrans)     transb = 0;
476
    if (TransB == CblasTrans)       transb = 1;
477
#ifndef COMPLEX
478
    if (TransB == CblasConjNoTrans) transb = 0;
479
    if (TransB == CblasConjTrans)   transb = 1;
480
#else
481
    if (TransB == CblasConjNoTrans) transb = 2;
482
    if (TransB == CblasConjTrans)   transb = 3;
483
#endif
484
485
    nrowa = args.m;
486
    if (transa & 1) nrowa = args.k;
487
    nrowb = args.k;
488
    if (transb & 1) nrowb = args.n;
489
490
    info = -1;
491
492
    if (args.ldc < args.m) info = 13;
493
    if (args.ldb < nrowb)  info = 10;
494
    if (args.lda < nrowa)  info =  8;
495
    if (args.k < 0)        info =  5;
496
    if (args.n < 0)        info =  4;
497
    if (args.m < 0)        info =  3;
498
    if (transb < 0)        info =  2;
499
    if (transa < 0)        info =  1;
500
  }
501
502
/* Row-major: describe the same product to the column-major drivers by
   swapping m/n, A/B (with their leading dimensions) and the roles of
   TransA/TransB; C and ldc are passed unchanged. */
  if (order == CblasRowMajor) {
503
    args.m = n;
504
    args.n = m;
505
    args.k = k;
506
507
    args.a = (void *)b;
508
    args.b = (void *)a;
509
    args.c = (void *)c;
510
511
    args.lda = ldb;
512
    args.ldb = lda;
513
    args.ldc = ldc;
514
515
    if (TransB == CblasNoTrans)     transa = 0;
516
    if (TransB == CblasTrans)       transa = 1;
517
#ifndef COMPLEX
518
    if (TransB == CblasConjNoTrans) transa = 0;
519
    if (TransB == CblasConjTrans)   transa = 1;
520
#else
521
    if (TransB == CblasConjNoTrans) transa = 2;
522
    if (TransB == CblasConjTrans)   transa = 3;
523
#endif
524
    if (TransA == CblasNoTrans)     transb = 0;
525
    if (TransA == CblasTrans)       transb = 1;
526
#ifndef COMPLEX
527
    if (TransA == CblasConjNoTrans) transb = 0;
528
    if (TransA == CblasConjTrans)   transb = 1;
529
#else
530
    if (TransA == CblasConjNoTrans) transb = 2;
531
    if (TransA == CblasConjTrans)   transb = 3;
532
#endif
533
534
    nrowa = args.m;
535
    if (transa & 1) nrowa = args.k;
536
    nrowb = args.k;
537
    if (transb & 1) nrowb = args.n;
538
539
    info = -1;
540
541
    if (args.ldc < args.m) info = 13;
542
    if (args.ldb < nrowb)  info = 10;
543
    if (args.lda < nrowa)  info =  8;
544
    if (args.k < 0)        info =  5;
545
    if (args.n < 0)        info =  4;
546
    if (args.m < 0)        info =  3;
547
    if (transb < 0)        info =  2;
548
    if (transa < 0)        info =  1;
549
550
  }
551
552
  if (info >= 0) {
553
    BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
554
    return;
555
  }
556
557
#endif
558
559
/* BFLOAT16 on x86_64 Linux: request AMX tile-data permission until it has
   been granted (under DYNAMIC_ARCH only if the selected core needs it).
   If the request fails, return before any computation, leaving C as is. */
#if defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16)
560
#if defined(DYNAMIC_ARCH)
561
  if (gotoblas->need_amxtile_permission &&
562
      openblas_amxtile_permission == 0 && init_amxtile_permission() == -1) {
563
    return;
564
  }
565
#endif
566
#if !defined(DYNAMIC_ARCH) && defined(SAPPHIRERAPIDS)
567
  if (openblas_amxtile_permission == 0 && init_amxtile_permission() == -1) {
568
    return;
569
  }
570
#endif
571
#endif  // defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16)
572
573
/* Empty output: nothing to do.  k == 0 is not short-circuited here and
   falls through to the code below. */
0
  if ((args.m == 0) || (args.n == 0)) return;
574
575
#if 0
576
  fprintf(stderr, "m = %4d  n = %d  k = %d  lda = %4d  ldb = %4d  ldc = %4d\n",
577
   args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
578
#endif
579
580
/* GEMM -> GEMV forwarding (real, non-3M; BFLOAT16 only with
   GEMM_GEMV_FORWARD_BF16): when n == 1 or m == 1 and k != 0, the product
   is handed to GEMV.  On ARM64 (have_tuned_gemv == false) only the stride
   combinations accepted by is_efficient_gemv are forwarded; the rest fall
   through to the GEMM drivers. */
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
581
#if defined(ARCH_ARM64)
582
  // The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
583
  // perform poorly in certain circumstances. We use the following boolean
584
  // variable along with the gemv argument values to avoid these inefficient
585
  // gemv cases, see github issue#4951.
586
  bool have_tuned_gemv = false;
587
#else
588
  bool have_tuned_gemv = true;
589
#endif
590
  // Check if we can convert GEMM -> GEMV
591
  if (args.k != 0) {
592
    if (args.n == 1) {
593
      blasint inc_x = 1;
594
      blasint inc_y = 1;
595
      // These were passed in as blasint, but the struct translates them to blaslong
596
      blasint m = args.m;
597
      blasint n = args.k;
598
      blasint lda = args.lda;
599
      // Create new transpose parameters
600
      char NT = 'N';
601
      if (transa & 1) {
602
        NT = 'T';
603
        m = args.k;
604
        n = args.m;
605
      }
606
      if (transb & 1) {
607
        inc_x = args.ldb;
608
      }
609
      bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1));
610
      if (is_efficient_gemv) {
611
        GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
612
        return;
613
      }
614
    }
615
    if (args.m == 1) {
616
      blasint inc_x = args.lda;
617
      blasint inc_y = args.ldc;
618
      // These were passed in as blasint, but the struct translates them to blaslong
619
      blasint m = args.k;
620
      blasint n = args.n;
621
      blasint ldb = args.ldb;
622
      // Create new transpose parameters
623
      char NT = 'T';
624
      if (transa & 1) {
625
        inc_x = 1;
626
      }
627
      if (transb & 1) {
628
        NT = 'N';
629
        m = args.n;
630
        n = args.k;
631
      }
632
      bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1));
633
      if (is_efficient_gemv) {
634
        GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
635
        return;
636
      }
637
    }
638
  }
639
#endif
640
641
0
  IDEBUG_START;
642
643
0
  FUNCTION_PROFILE_START();
644
645
/* Small-matrix fast path: if GEMM_SMALL_MATRIX_PERMIT accepts the shape and
   scalars, call the small kernel for (transb << 2) | transa -- the _B0
   variant when beta is zero -- and return without allocating a buffer. */
0
#if USE_SMALL_MATRIX_OPT
646
0
#if !defined(COMPLEX)
647
0
  if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, *(FLOAT *)(args.alpha), *(FLOAT *)(args.beta))){
648
0
    if(*(FLOAT *)(args.beta) == 0.0){
649
0
    (GEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc);
650
0
    }else{
651
0
    (GEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc);
652
0
    }
653
0
    return;
654
0
  }
655
#else
656
  if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, alpha[0], alpha[1], beta[0], beta[1])){
657
    if(beta[0] == 0.0 && beta[1] == 0.0){
658
    (ZGEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc);
659
    }else{
660
    (ZGEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc);
661
    }
662
    return;
663
  }
664
#endif
665
0
#endif
666
667
/* Workspace: sa starts GEMM_OFFSET_A into the buffer (LoongArch instead
   scales GEMM_OFFSET_A by WhereAmI() & 0xf, presumably the current CPU
   index), and sb follows a GEMM_P * GEMM_Q * COMPSIZE * SIZE byte region
   aligned with the GEMM_ALIGN mask, plus GEMM_OFFSET_B.  Presumably these
   hold packed panels of A and B for the drivers -- confirm.
   NOTE(review): the result of blas_memory_alloc is not checked for NULL
   here -- confirm the allocator cannot return NULL. */
0
  buffer = (XFLOAT *)blas_memory_alloc(0);
668
669
//For LOONGARCH64, applying an offset to the buffer is essential
670
//for minimizing cache conflicts and optimizing performance.
671
#if defined(ARCH_LOONGARCH64) && !defined(NO_AFFINITY)
672
  sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A);
673
#else
674
0
  sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
675
0
#endif
676
0
  sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
677
678
/* SMP: pick a thread count from M*N*K.  One thread runs the serial driver
   gemm[index]; otherwise the threaded driver gemm[16 | index] is used (or
   GEMM_THREAD with the serial driver under USE_SIMPLE_THREADED_LEVEL3).
   Without SMP the serial driver is always called. */
0
#ifdef SMP
679
#if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
680
  mode |= (transa << BLAS_TRANSA_SHIFT);
681
  mode |= (transb << BLAS_TRANSB_SHIFT);
682
#endif
683
684
0
  MNK = (double) args.m * (double) args.n * (double) args.k;
685
0
  args.nthreads = get_gemm_optimal_nthreads(MNK);
686
687
0
  args.common = NULL;
688
689
0
 if (args.nthreads == 1) {
690
0
#endif
691
692
0
    (gemm[(transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
693
694
0
#ifdef SMP
695
696
0
  } else {
697
698
0
#ifndef USE_SIMPLE_THREADED_LEVEL3
699
700
#ifndef NO_AFFINITY
701
      nodes = get_num_nodes();
702
703
      if ((nodes > 1) && get_node_equal()) {
704
705
/* Split the threads evenly across nodes (presumably NUMA nodes).
   NOTE(review): integer division yields 0 when nthreads < nodes --
   confirm gemm_thread_mn copes with that. */
  args.nthreads /= nodes;
706
707
  gemm_thread_mn(mode, &args, NULL, NULL, gemm[16 | (transb << 2) | transa], sa, sb, nodes);
708
709
      } else {
710
#endif
711
712
0
  (gemm[16 | (transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
713
714
#else
715
716
  GEMM_THREAD(mode, &args, NULL, NULL, gemm[(transb << 2) | transa], sa, sb, args.nthreads);
717
718
#endif
719
720
0
#ifndef USE_SIMPLE_THREADED_LEVEL3
721
#ifndef NO_AFFINITY
722
      }
723
#endif
724
0
#endif
725
726
0
#endif
727
728
0
#ifdef SMP
729
0
  }
730
0
#endif
731
732
0
 blas_memory_free(buffer);
733
734
0
  FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, args.m * args.k + args.k * args.n + args.m * args.n, 2 * args.m * args.n * args.k);
735
736
0
  IDEBUG_END;
737
738
0
  return;
739
0
}
Unexecuted instantiation: sgemm_
Unexecuted instantiation: dgemm_