blas.c 18 KB


  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2009-2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
  4. *
  5. * StarPU is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * StarPU is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <ctype.h>
  17. #include <stdio.h>
  18. #include <starpu.h>
  19. #include "blas.h"
/*
This file contains BLAS wrappers for the different BLAS implementations
(e.g. REFBLAS, ATLAS, GOTOBLAS ...). We assume a Fortran (column-major)
orientation, as most libraries do not supply C-based (row-major) ordering.
*/
  25. #ifdef STARPU_ATLAS
  26. inline void STARPU_SGEMM(char *transa, char *transb, int M, int N, int K,
  27. float alpha, const float *A, int lda, const float *B, int ldb,
  28. float beta, float *C, int ldc)
  29. {
  30. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  31. enum CBLAS_TRANSPOSE tb = (toupper(transb[0]) == 'N')?CblasNoTrans:CblasTrans;
  32. cblas_sgemm(CblasColMajor, ta, tb,
  33. M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  34. }
  35. inline void STARPU_DGEMM(char *transa, char *transb, int M, int N, int K,
  36. double alpha, double *A, int lda, double *B, int ldb,
  37. double beta, double *C, int ldc)
  38. {
  39. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  40. enum CBLAS_TRANSPOSE tb = (toupper(transb[0]) == 'N')?CblasNoTrans:CblasTrans;
  41. cblas_dgemm(CblasColMajor, ta, tb,
  42. M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  43. }
  44. inline void STARPU_SGEMV(char *transa, int M, int N, float alpha, float *A, int lda, float *X, int incX, float beta, float *Y, int incY)
  45. {
  46. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  47. cblas_sgemv(CblasColMajor, ta, M, N, alpha, A, lda,
  48. X, incX, beta, Y, incY);
  49. }
  50. inline void STARPU_DGEMV(char *transa, int M, int N, double alpha, double *A, int lda, double *X, int incX, double beta, double *Y, int incY)
  51. {
  52. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  53. cblas_dgemv(CblasColMajor, ta, M, N, alpha, A, lda,
  54. X, incX, beta, Y, incY);
  55. }
  56. inline float STARPU_SASUM(int N, float *X, int incX)
  57. {
  58. return cblas_sasum(N, X, incX);
  59. }
  60. inline double STARPU_DASUM(int N, double *X, int incX)
  61. {
  62. return cblas_dasum(N, X, incX);
  63. }
  64. void STARPU_SSCAL(int N, float alpha, float *X, int incX)
  65. {
  66. cblas_sscal(N, alpha, X, incX);
  67. }
  68. void STARPU_DSCAL(int N, double alpha, double *X, int incX)
  69. {
  70. cblas_dscal(N, alpha, X, incX);
  71. }
  72. void STARPU_STRSM (const char *side, const char *uplo, const char *transa,
  73. const char *diag, const int m, const int n,
  74. const float alpha, const float *A, const int lda,
  75. float *B, const int ldb)
  76. {
  77. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  78. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  79. enum CBLAS_TRANSPOSE transa_ = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  80. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  81. cblas_strsm(CblasColMajor, side_, uplo_, transa_, diag_, m, n, alpha, A, lda, B, ldb);
  82. }
  83. void STARPU_DTRSM (const char *side, const char *uplo, const char *transa,
  84. const char *diag, const int m, const int n,
  85. const double alpha, const double *A, const int lda,
  86. double *B, const int ldb)
  87. {
  88. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  89. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  90. enum CBLAS_TRANSPOSE transa_ = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  91. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  92. cblas_dtrsm(CblasColMajor, side_, uplo_, transa_, diag_, m, n, alpha, A, lda, B, ldb);
  93. }
  94. void STARPU_SSYR (const char *uplo, const int n, const float alpha,
  95. const float *x, const int incx, float *A, const int lda)
  96. {
  97. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  98. cblas_ssyr(CblasColMajor, uplo_, n, alpha, x, incx, A, lda);
  99. }
  100. void STARPU_SSYRK (const char *uplo, const char *trans, const int n,
  101. const int k, const float alpha, const float *A,
  102. const int lda, const float beta, float *C,
  103. const int ldc)
  104. {
  105. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  106. enum CBLAS_TRANSPOSE trans_ = (toupper(trans[0]) == 'N')?CblasNoTrans:CblasTrans;
  107. cblas_ssyrk(CblasColMajor, uplo_, trans_, n, k, alpha, A, lda, beta, C, ldc);
  108. }
  109. void STARPU_SGER(const int m, const int n, const float alpha,
  110. const float *x, const int incx, const float *y,
  111. const int incy, float *A, const int lda)
  112. {
  113. cblas_sger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
  114. }
  115. void STARPU_DGER(const int m, const int n, const double alpha,
  116. const double *x, const int incx, const double *y,
  117. const int incy, double *A, const int lda)
  118. {
  119. cblas_dger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
  120. }
  121. void STARPU_STRSV (const char *uplo, const char *trans, const char *diag,
  122. const int n, const float *A, const int lda, float *x,
  123. const int incx)
  124. {
  125. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  126. enum CBLAS_TRANSPOSE trans_ = (toupper(trans[0]) == 'N')?CblasNoTrans:CblasTrans;
  127. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  128. cblas_strsv(CblasColMajor, uplo_, trans_, diag_, n, A, lda, x, incx);
  129. }
  130. void STARPU_STRMM(const char *side, const char *uplo, const char *transA,
  131. const char *diag, const int m, const int n,
  132. const float alpha, const float *A, const int lda,
  133. float *B, const int ldb)
  134. {
  135. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  136. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  137. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  138. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  139. cblas_strmm(CblasColMajor, side_, uplo_, transA_, diag_, m, n, alpha, A, lda, B, ldb);
  140. }
  141. void STARPU_DTRMM(const char *side, const char *uplo, const char *transA,
  142. const char *diag, const int m, const int n,
  143. const double alpha, const double *A, const int lda,
  144. double *B, const int ldb)
  145. {
  146. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  147. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  148. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  149. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  150. cblas_dtrmm(CblasColMajor, side_, uplo_, transA_, diag_, m, n, alpha, A, lda, B, ldb);
  151. }
  152. void STARPU_STRMV(const char *uplo, const char *transA, const char *diag,
  153. const int n, const float *A, const int lda, float *X,
  154. const int incX)
  155. {
  156. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  157. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  158. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  159. cblas_strmv(CblasColMajor, uplo_, transA_, diag_, n, A, lda, X, incX);
  160. }
  161. void STARPU_SAXPY(const int n, const float alpha, float *X, const int incX, float *Y, const int incY)
  162. {
  163. cblas_saxpy(n, alpha, X, incX, Y, incY);
  164. }
  165. void STARPU_DAXPY(const int n, const double alpha, double *X, const int incX, double *Y, const int incY)
  166. {
  167. cblas_daxpy(n, alpha, X, incX, Y, incY);
  168. }
  169. int STARPU_ISAMAX (const int n, float *X, const int incX)
  170. {
  171. int retVal;
  172. retVal = cblas_isamax(n, X, incX);
  173. return retVal;
  174. }
  175. int STARPU_IDAMAX (const int n, double *X, const int incX)
  176. {
  177. int retVal;
  178. retVal = cblas_idamax(n, X, incX);
  179. return retVal;
  180. }
  181. float STARPU_SDOT(const int n, const float *x, const int incx, const float *y, const int incy)
  182. {
  183. return cblas_sdot(n, x, incx, y, incy);
  184. }
  185. double STARPU_DDOT(const int n, const double *x, const int incx, const double *y, const int incy)
  186. {
  187. return cblas_ddot(n, x, incx, y, incy);
  188. }
  189. void STARPU_SSWAP(const int n, float *x, const int incx, float *y, const int incy)
  190. {
  191. cblas_sswap(n, x, incx, y, incy);
  192. }
  193. void STARPU_DSWAP(const int n, double *x, const int incx, double *y, const int incy)
  194. {
  195. cblas_dswap(n, x, incx, y, incy);
  196. }
  197. #elif defined(STARPU_GOTO) || defined(STARPU_OPENBLAS) || defined(STARPU_SYSTEM_BLAS) || defined(STARPU_MKL) || defined(STARPU_ARMPL)
  198. inline void STARPU_SGEMM(char *transa, char *transb, int M, int N, int K,
  199. float alpha, const float *A, int lda, const float *B, int ldb,
  200. float beta, float *C, int ldc)
  201. {
  202. sgemm_(transa, transb, &M, &N, &K, &alpha,
  203. A, &lda, B, &ldb,
  204. &beta, C, &ldc);
  205. }
  206. inline void STARPU_DGEMM(char *transa, char *transb, int M, int N, int K,
  207. double alpha, double *A, int lda, double *B, int ldb,
  208. double beta, double *C, int ldc)
  209. {
  210. dgemm_(transa, transb, &M, &N, &K, &alpha,
  211. A, &lda, B, &ldb,
  212. &beta, C, &ldc);
  213. }
  214. inline void STARPU_SGEMV(char *transa, int M, int N, float alpha, float *A, int lda,
  215. float *X, int incX, float beta, float *Y, int incY)
  216. {
  217. sgemv_(transa, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
  218. }
  219. inline void STARPU_DGEMV(char *transa, int M, int N, double alpha, double *A, int lda,
  220. double *X, int incX, double beta, double *Y, int incY)
  221. {
  222. dgemv_(transa, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
  223. }
  224. inline float STARPU_SASUM(int N, float *X, int incX)
  225. {
  226. return sasum_(&N, X, &incX);
  227. }
  228. inline double STARPU_DASUM(int N, double *X, int incX)
  229. {
  230. return dasum_(&N, X, &incX);
  231. }
  232. void STARPU_SSCAL(int N, float alpha, float *X, int incX)
  233. {
  234. sscal_(&N, &alpha, X, &incX);
  235. }
  236. void STARPU_DSCAL(int N, double alpha, double *X, int incX)
  237. {
  238. dscal_(&N, &alpha, X, &incX);
  239. }
  240. void STARPU_STRSM (const char *side, const char *uplo, const char *transa,
  241. const char *diag, const int m, const int n,
  242. const float alpha, const float *A, const int lda,
  243. float *B, const int ldb)
  244. {
  245. strsm_(side, uplo, transa, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  246. }
  247. void STARPU_DTRSM (const char *side, const char *uplo, const char *transa,
  248. const char *diag, const int m, const int n,
  249. const double alpha, const double *A, const int lda,
  250. double *B, const int ldb)
  251. {
  252. dtrsm_(side, uplo, transa, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  253. }
  254. void STARPU_SSYR (const char *uplo, const int n, const float alpha,
  255. const float *x, const int incx, float *A, const int lda)
  256. {
  257. ssyr_(uplo, &n, &alpha, x, &incx, A, &lda);
  258. }
  259. void STARPU_SSYRK (const char *uplo, const char *trans, const int n,
  260. const int k, const float alpha, const float *A,
  261. const int lda, const float beta, float *C,
  262. const int ldc)
  263. {
  264. ssyrk_(uplo, trans, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
  265. }
  266. void STARPU_SGER(const int m, const int n, const float alpha,
  267. const float *x, const int incx, const float *y,
  268. const int incy, float *A, const int lda)
  269. {
  270. sger_(&m, &n, &alpha, x, &incx, y, &incy, A, &lda);
  271. }
  272. void STARPU_DGER(const int m, const int n, const double alpha,
  273. const double *x, const int incx, const double *y,
  274. const int incy, double *A, const int lda)
  275. {
  276. dger_(&m, &n, &alpha, x, &incx, y, &incy, A, &lda);
  277. }
  278. void STARPU_STRSV (const char *uplo, const char *trans, const char *diag,
  279. const int n, const float *A, const int lda, float *x,
  280. const int incx)
  281. {
  282. strsv_(uplo, trans, diag, &n, A, &lda, x, &incx);
  283. }
  284. void STARPU_STRMM(const char *side, const char *uplo, const char *transA,
  285. const char *diag, const int m, const int n,
  286. const float alpha, const float *A, const int lda,
  287. float *B, const int ldb)
  288. {
  289. strmm_(side, uplo, transA, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  290. }
  291. void STARPU_DTRMM(const char *side, const char *uplo, const char *transA,
  292. const char *diag, const int m, const int n,
  293. const double alpha, const double *A, const int lda,
  294. double *B, const int ldb)
  295. {
  296. dtrmm_(side, uplo, transA, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  297. }
  298. void STARPU_STRMV(const char *uplo, const char *transA, const char *diag,
  299. const int n, const float *A, const int lda, float *X,
  300. const int incX)
  301. {
  302. strmv_(uplo, transA, diag, &n, A, &lda, X, &incX);
  303. }
  304. void STARPU_SAXPY(const int n, const float alpha, float *X, const int incX, float *Y, const int incY)
  305. {
  306. saxpy_(&n, &alpha, X, &incX, Y, &incY);
  307. }
  308. void STARPU_DAXPY(const int n, const double alpha, double *X, const int incX, double *Y, const int incY)
  309. {
  310. daxpy_(&n, &alpha, X, &incX, Y, &incY);
  311. }
  312. int STARPU_ISAMAX (const int n, float *X, const int incX)
  313. {
  314. int retVal;
  315. retVal = isamax_ (&n, X, &incX);
  316. return retVal;
  317. }
  318. int STARPU_IDAMAX (const int n, double *X, const int incX)
  319. {
  320. int retVal;
  321. retVal = idamax_ (&n, X, &incX);
  322. return retVal;
  323. }
  324. float STARPU_SDOT(const int n, const float *x, const int incx, const float *y, const int incy)
  325. {
  326. float retVal = 0;
  327. /* GOTOBLAS will return a FLOATRET which is a double, not a float */
  328. retVal = (float)sdot_(&n, x, &incx, y, &incy);
  329. return retVal;
  330. }
  331. double STARPU_DDOT(const int n, const double *x, const int incx, const double *y, const int incy)
  332. {
  333. return ddot_(&n, x, &incx, y, &incy);
  334. }
  335. void STARPU_SSWAP(const int n, float *X, const int incX, float *Y, const int incY)
  336. {
  337. sswap_(&n, X, &incX, Y, &incY);
  338. }
  339. void STARPU_DSWAP(const int n, double *X, const int incX, double *Y, const int incY)
  340. {
  341. dswap_(&n, X, &incX, Y, &incY);
  342. }
  343. #if defined(STARPU_MKL) || defined(STARPU_ARMPL)
  344. void STARPU_SPOTRF(const char*uplo, const int n, float *a, const int lda)
  345. {
  346. int info = 0;
  347. spotrf_(uplo, &n, a, &lda, &info);
  348. }
  349. void STARPU_DPOTRF(const char*uplo, const int n, double *a, const int lda)
  350. {
  351. int info = 0;
  352. dpotrf_(uplo, &n, a, &lda, &info);
  353. }
  354. #endif
  355. #elif defined(STARPU_SIMGRID)
  356. inline void STARPU_SGEMM(char *transa, char *transb, int M, int N, int K,
  357. float alpha, const float *A, int lda, const float *B, int ldb,
  358. float beta, float *C, int ldc) { }
  359. inline void STARPU_DGEMM(char *transa, char *transb, int M, int N, int K,
  360. double alpha, double *A, int lda, double *B, int ldb,
  361. double beta, double *C, int ldc) { }
  362. inline void STARPU_SGEMV(char *transa, int M, int N, float alpha, float *A, int lda,
  363. float *X, int incX, float beta, float *Y, int incY) { }
  364. inline void STARPU_DGEMV(char *transa, int M, int N, double alpha, double *A, int lda,
  365. double *X, int incX, double beta, double *Y, int incY) { }
  366. inline float STARPU_SASUM(int N, float *X, int incX) { return 0.; }
  367. inline double STARPU_DASUM(int N, double *X, int incX) { return 0.; }
  368. void STARPU_SSCAL(int N, float alpha, float *X, int incX) { }
  369. void STARPU_DSCAL(int N, double alpha, double *X, int incX) { }
  370. void STARPU_STRSM (const char *side, const char *uplo, const char *transa,
  371. const char *diag, const int m, const int n,
  372. const float alpha, const float *A, const int lda,
  373. float *B, const int ldb) { }
  374. void STARPU_DTRSM (const char *side, const char *uplo, const char *transa,
  375. const char *diag, const int m, const int n,
  376. const double alpha, const double *A, const int lda,
  377. double *B, const int ldb) { }
  378. void STARPU_SSYR (const char *uplo, const int n, const float alpha,
  379. const float *x, const int incx, float *A, const int lda) { }
  380. void STARPU_SSYRK (const char *uplo, const char *trans, const int n,
  381. const int k, const float alpha, const float *A,
  382. const int lda, const float beta, float *C,
  383. const int ldc) { }
  384. void STARPU_SGER(const int m, const int n, const float alpha,
  385. const float *x, const int incx, const float *y,
  386. const int incy, float *A, const int lda) { }
  387. void STARPU_DGER(const int m, const int n, const double alpha,
  388. const double *x, const int incx, const double *y,
  389. const int incy, double *A, const int lda) { }
  390. void STARPU_STRSV (const char *uplo, const char *trans, const char *diag,
  391. const int n, const float *A, const int lda, float *x,
  392. const int incx) { }
  393. void STARPU_STRMM(const char *side, const char *uplo, const char *transA,
  394. const char *diag, const int m, const int n,
  395. const float alpha, const float *A, const int lda,
  396. float *B, const int ldb) { }
  397. void STARPU_DTRMM(const char *side, const char *uplo, const char *transA,
  398. const char *diag, const int m, const int n,
  399. const double alpha, const double *A, const int lda,
  400. double *B, const int ldb) { }
  401. void STARPU_STRMV(const char *uplo, const char *transA, const char *diag,
  402. const int n, const float *A, const int lda, float *X,
  403. const int incX) { }
  404. void STARPU_SAXPY(const int n, const float alpha, float *X, const int incX, float *Y, const int incY) { }
  405. void STARPU_DAXPY(const int n, const double alpha, double *X, const int incX, double *Y, const int incY) { }
  406. int STARPU_ISAMAX (const int n, float *X, const int incX) { return 0; }
  407. int STARPU_IDAMAX (const int n, double *X, const int incX) { return 0; }
  408. float STARPU_SDOT(const int n, const float *x, const int incx, const float *y, const int incy) { return 0.; }
  409. double STARPU_DDOT(const int n, const double *x, const int incx, const double *y, const int incy) { return 0.; }
  410. void STARPU_SSWAP(const int n, float *X, const int incX, float *Y, const int incY) { }
  411. void STARPU_DSWAP(const int n, double *X, const int incX, double *Y, const int incY) { }
  412. void STARPU_SPOTRF(const char*uplo, const int n, float *a, const int lda) { }
  413. void STARPU_DPOTRF(const char*uplo, const int n, double *a, const int lda) { }
  414. #else
  415. #error "no BLAS lib available..."
  416. #endif