blas.c 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2012 Inria
  4. * Copyright (C) 2009-2011,2014-2015, 2018 Université de Bordeaux
  5. * Copyright (C) 2010,2015,2017 CNRS
  6. *
  7. * StarPU is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU Lesser General Public License as published by
  9. * the Free Software Foundation; either version 2.1 of the License, or (at
  10. * your option) any later version.
  11. *
  12. * StarPU is distributed in the hope that it will be useful, but
  13. * WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  15. *
  16. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  17. */
  18. #include <ctype.h>
  19. #include <stdio.h>
  20. #include <starpu.h>
  21. #include "blas.h"
/*
 * This file contains BLAS wrappers for the different BLAS implementations
 * (e.g. REFBLAS, ATLAS, GOTOBLAS, ...). We assume a Fortran orientation, as
 * most libraries do not supply C-based ordering.
 */
  27. #ifdef STARPU_ATLAS
  28. inline void STARPU_SGEMM(char *transa, char *transb, int M, int N, int K,
  29. float alpha, const float *A, int lda, const float *B, int ldb,
  30. float beta, float *C, int ldc)
  31. {
  32. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  33. enum CBLAS_TRANSPOSE tb = (toupper(transb[0]) == 'N')?CblasNoTrans:CblasTrans;
  34. cblas_sgemm(CblasColMajor, ta, tb,
  35. M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  36. }
  37. inline void STARPU_DGEMM(char *transa, char *transb, int M, int N, int K,
  38. double alpha, double *A, int lda, double *B, int ldb,
  39. double beta, double *C, int ldc)
  40. {
  41. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  42. enum CBLAS_TRANSPOSE tb = (toupper(transb[0]) == 'N')?CblasNoTrans:CblasTrans;
  43. cblas_dgemm(CblasColMajor, ta, tb,
  44. M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  45. }
  46. inline void STARPU_SGEMV(char *transa, int M, int N, float alpha, float *A, int lda, float *X, int incX, float beta, float *Y, int incY)
  47. {
  48. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  49. cblas_sgemv(CblasColMajor, ta, M, N, alpha, A, lda,
  50. X, incX, beta, Y, incY);
  51. }
  52. inline void STARPU_DGEMV(char *transa, int M, int N, double alpha, double *A, int lda, double *X, int incX, double beta, double *Y, int incY)
  53. {
  54. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  55. cblas_dgemv(CblasColMajor, ta, M, N, alpha, A, lda,
  56. X, incX, beta, Y, incY);
  57. }
  58. inline float STARPU_SASUM(int N, float *X, int incX)
  59. {
  60. return cblas_sasum(N, X, incX);
  61. }
  62. inline double STARPU_DASUM(int N, double *X, int incX)
  63. {
  64. return cblas_dasum(N, X, incX);
  65. }
  66. void STARPU_SSCAL(int N, float alpha, float *X, int incX)
  67. {
  68. cblas_sscal(N, alpha, X, incX);
  69. }
  70. void STARPU_DSCAL(int N, double alpha, double *X, int incX)
  71. {
  72. cblas_dscal(N, alpha, X, incX);
  73. }
  74. void STARPU_STRSM (const char *side, const char *uplo, const char *transa,
  75. const char *diag, const int m, const int n,
  76. const float alpha, const float *A, const int lda,
  77. float *B, const int ldb)
  78. {
  79. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  80. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  81. enum CBLAS_TRANSPOSE transa_ = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  82. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  83. cblas_strsm(CblasColMajor, side_, uplo_, transa_, diag_, m, n, alpha, A, lda, B, ldb);
  84. }
  85. void STARPU_DTRSM (const char *side, const char *uplo, const char *transa,
  86. const char *diag, const int m, const int n,
  87. const double alpha, const double *A, const int lda,
  88. double *B, const int ldb)
  89. {
  90. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  91. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  92. enum CBLAS_TRANSPOSE transa_ = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  93. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  94. cblas_dtrsm(CblasColMajor, side_, uplo_, transa_, diag_, m, n, alpha, A, lda, B, ldb);
  95. }
  96. void STARPU_SSYR (const char *uplo, const int n, const float alpha,
  97. const float *x, const int incx, float *A, const int lda)
  98. {
  99. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  100. cblas_ssyr(CblasColMajor, uplo_, n, alpha, x, incx, A, lda);
  101. }
  102. void STARPU_SSYRK (const char *uplo, const char *trans, const int n,
  103. const int k, const float alpha, const float *A,
  104. const int lda, const float beta, float *C,
  105. const int ldc)
  106. {
  107. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  108. enum CBLAS_TRANSPOSE trans_ = (toupper(trans[0]) == 'N')?CblasNoTrans:CblasTrans;
  109. cblas_ssyrk(CblasColMajor, uplo_, trans_, n, k, alpha, A, lda, beta, C, ldc);
  110. }
  111. void STARPU_SGER(const int m, const int n, const float alpha,
  112. const float *x, const int incx, const float *y,
  113. const int incy, float *A, const int lda)
  114. {
  115. cblas_sger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
  116. }
  117. void STARPU_DGER(const int m, const int n, const double alpha,
  118. const double *x, const int incx, const double *y,
  119. const int incy, double *A, const int lda)
  120. {
  121. cblas_dger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
  122. }
  123. void STARPU_STRSV (const char *uplo, const char *trans, const char *diag,
  124. const int n, const float *A, const int lda, float *x,
  125. const int incx)
  126. {
  127. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  128. enum CBLAS_TRANSPOSE trans_ = (toupper(trans[0]) == 'N')?CblasNoTrans:CblasTrans;
  129. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  130. cblas_strsv(CblasColMajor, uplo_, trans_, diag_, n, A, lda, x, incx);
  131. }
  132. void STARPU_STRMM(const char *side, const char *uplo, const char *transA,
  133. const char *diag, const int m, const int n,
  134. const float alpha, const float *A, const int lda,
  135. float *B, const int ldb)
  136. {
  137. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  138. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  139. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  140. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  141. cblas_strmm(CblasColMajor, side_, uplo_, transA_, diag_, m, n, alpha, A, lda, B, ldb);
  142. }
  143. void STARPU_DTRMM(const char *side, const char *uplo, const char *transA,
  144. const char *diag, const int m, const int n,
  145. const double alpha, const double *A, const int lda,
  146. double *B, const int ldb)
  147. {
  148. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  149. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  150. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  151. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  152. cblas_dtrmm(CblasColMajor, side_, uplo_, transA_, diag_, m, n, alpha, A, lda, B, ldb);
  153. }
  154. void STARPU_STRMV(const char *uplo, const char *transA, const char *diag,
  155. const int n, const float *A, const int lda, float *X,
  156. const int incX)
  157. {
  158. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  159. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  160. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  161. cblas_strmv(CblasColMajor, uplo_, transA_, diag_, n, A, lda, X, incX);
  162. }
  163. void STARPU_SAXPY(const int n, const float alpha, float *X, const int incX, float *Y, const int incY)
  164. {
  165. cblas_saxpy(n, alpha, X, incX, Y, incY);
  166. }
  167. void STARPU_DAXPY(const int n, const double alpha, double *X, const int incX, double *Y, const int incY)
  168. {
  169. cblas_daxpy(n, alpha, X, incX, Y, incY);
  170. }
  171. int STARPU_ISAMAX (const int n, float *X, const int incX)
  172. {
  173. int retVal;
  174. retVal = cblas_isamax(n, X, incX);
  175. return retVal;
  176. }
  177. int STARPU_IDAMAX (const int n, double *X, const int incX)
  178. {
  179. int retVal;
  180. retVal = cblas_idamax(n, X, incX);
  181. return retVal;
  182. }
  183. float STARPU_SDOT(const int n, const float *x, const int incx, const float *y, const int incy)
  184. {
  185. return cblas_sdot(n, x, incx, y, incy);
  186. }
  187. double STARPU_DDOT(const int n, const double *x, const int incx, const double *y, const int incy)
  188. {
  189. return cblas_ddot(n, x, incx, y, incy);
  190. }
  191. void STARPU_SSWAP(const int n, float *x, const int incx, float *y, const int incy)
  192. {
  193. cblas_sswap(n, x, incx, y, incy);
  194. }
  195. void STARPU_DSWAP(const int n, double *x, const int incx, double *y, const int incy)
  196. {
  197. cblas_dswap(n, x, incx, y, incy);
  198. }
  199. #elif defined(STARPU_GOTO) || defined(STARPU_OPENBLAS) || defined(STARPU_SYSTEM_BLAS) || defined(STARPU_MKL)
  200. inline void STARPU_SGEMM(char *transa, char *transb, int M, int N, int K,
  201. float alpha, const float *A, int lda, const float *B, int ldb,
  202. float beta, float *C, int ldc)
  203. {
  204. sgemm_(transa, transb, &M, &N, &K, &alpha,
  205. A, &lda, B, &ldb,
  206. &beta, C, &ldc);
  207. }
  208. inline void STARPU_DGEMM(char *transa, char *transb, int M, int N, int K,
  209. double alpha, double *A, int lda, double *B, int ldb,
  210. double beta, double *C, int ldc)
  211. {
  212. dgemm_(transa, transb, &M, &N, &K, &alpha,
  213. A, &lda, B, &ldb,
  214. &beta, C, &ldc);
  215. }
  216. inline void STARPU_SGEMV(char *transa, int M, int N, float alpha, float *A, int lda,
  217. float *X, int incX, float beta, float *Y, int incY)
  218. {
  219. sgemv_(transa, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
  220. }
  221. inline void STARPU_DGEMV(char *transa, int M, int N, double alpha, double *A, int lda,
  222. double *X, int incX, double beta, double *Y, int incY)
  223. {
  224. dgemv_(transa, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
  225. }
  226. inline float STARPU_SASUM(int N, float *X, int incX)
  227. {
  228. return sasum_(&N, X, &incX);
  229. }
  230. inline double STARPU_DASUM(int N, double *X, int incX)
  231. {
  232. return dasum_(&N, X, &incX);
  233. }
  234. void STARPU_SSCAL(int N, float alpha, float *X, int incX)
  235. {
  236. sscal_(&N, &alpha, X, &incX);
  237. }
  238. void STARPU_DSCAL(int N, double alpha, double *X, int incX)
  239. {
  240. dscal_(&N, &alpha, X, &incX);
  241. }
  242. void STARPU_STRSM (const char *side, const char *uplo, const char *transa,
  243. const char *diag, const int m, const int n,
  244. const float alpha, const float *A, const int lda,
  245. float *B, const int ldb)
  246. {
  247. strsm_(side, uplo, transa, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  248. }
  249. void STARPU_DTRSM (const char *side, const char *uplo, const char *transa,
  250. const char *diag, const int m, const int n,
  251. const double alpha, const double *A, const int lda,
  252. double *B, const int ldb)
  253. {
  254. dtrsm_(side, uplo, transa, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  255. }
  256. void STARPU_SSYR (const char *uplo, const int n, const float alpha,
  257. const float *x, const int incx, float *A, const int lda)
  258. {
  259. ssyr_(uplo, &n, &alpha, x, &incx, A, &lda);
  260. }
  261. void STARPU_SSYRK (const char *uplo, const char *trans, const int n,
  262. const int k, const float alpha, const float *A,
  263. const int lda, const float beta, float *C,
  264. const int ldc)
  265. {
  266. ssyrk_(uplo, trans, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
  267. }
  268. void STARPU_SGER(const int m, const int n, const float alpha,
  269. const float *x, const int incx, const float *y,
  270. const int incy, float *A, const int lda)
  271. {
  272. sger_(&m, &n, &alpha, x, &incx, y, &incy, A, &lda);
  273. }
  274. void STARPU_DGER(const int m, const int n, const double alpha,
  275. const double *x, const int incx, const double *y,
  276. const int incy, double *A, const int lda)
  277. {
  278. dger_(&m, &n, &alpha, x, &incx, y, &incy, A, &lda);
  279. }
  280. void STARPU_STRSV (const char *uplo, const char *trans, const char *diag,
  281. const int n, const float *A, const int lda, float *x,
  282. const int incx)
  283. {
  284. strsv_(uplo, trans, diag, &n, A, &lda, x, &incx);
  285. }
  286. void STARPU_STRMM(const char *side, const char *uplo, const char *transA,
  287. const char *diag, const int m, const int n,
  288. const float alpha, const float *A, const int lda,
  289. float *B, const int ldb)
  290. {
  291. strmm_(side, uplo, transA, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  292. }
  293. void STARPU_DTRMM(const char *side, const char *uplo, const char *transA,
  294. const char *diag, const int m, const int n,
  295. const double alpha, const double *A, const int lda,
  296. double *B, const int ldb)
  297. {
  298. dtrmm_(side, uplo, transA, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  299. }
  300. void STARPU_STRMV(const char *uplo, const char *transA, const char *diag,
  301. const int n, const float *A, const int lda, float *X,
  302. const int incX)
  303. {
  304. strmv_(uplo, transA, diag, &n, A, &lda, X, &incX);
  305. }
  306. void STARPU_SAXPY(const int n, const float alpha, float *X, const int incX, float *Y, const int incY)
  307. {
  308. saxpy_(&n, &alpha, X, &incX, Y, &incY);
  309. }
  310. void STARPU_DAXPY(const int n, const double alpha, double *X, const int incX, double *Y, const int incY)
  311. {
  312. daxpy_(&n, &alpha, X, &incX, Y, &incY);
  313. }
  314. int STARPU_ISAMAX (const int n, float *X, const int incX)
  315. {
  316. int retVal;
  317. retVal = isamax_ (&n, X, &incX);
  318. return retVal;
  319. }
  320. int STARPU_IDAMAX (const int n, double *X, const int incX)
  321. {
  322. int retVal;
  323. retVal = idamax_ (&n, X, &incX);
  324. return retVal;
  325. }
  326. float STARPU_SDOT(const int n, const float *x, const int incx, const float *y, const int incy)
  327. {
  328. float retVal = 0;
  329. /* GOTOBLAS will return a FLOATRET which is a double, not a float */
  330. retVal = (float)sdot_(&n, x, &incx, y, &incy);
  331. return retVal;
  332. }
  333. double STARPU_DDOT(const int n, const double *x, const int incx, const double *y, const int incy)
  334. {
  335. return ddot_(&n, x, &incx, y, &incy);
  336. }
  337. void STARPU_SSWAP(const int n, float *X, const int incX, float *Y, const int incY)
  338. {
  339. sswap_(&n, X, &incX, Y, &incY);
  340. }
  341. void STARPU_DSWAP(const int n, double *X, const int incX, double *Y, const int incY)
  342. {
  343. dswap_(&n, X, &incX, Y, &incY);
  344. }
  345. #ifdef STARPU_MKL
  346. void STARPU_SPOTRF(const char*uplo, const int n, float *a, const int lda)
  347. {
  348. int info = 0;
  349. spotrf_(uplo, &n, a, &lda, &info);
  350. }
  351. void STARPU_DPOTRF(const char*uplo, const int n, double *a, const int lda)
  352. {
  353. int info = 0;
  354. dpotrf_(uplo, &n, a, &lda, &info);
  355. }
  356. #endif
  357. #elif defined(STARPU_SIMGRID)
  358. inline void STARPU_SGEMM(char *transa, char *transb, int M, int N, int K,
  359. float alpha, const float *A, int lda, const float *B, int ldb,
  360. float beta, float *C, int ldc) { }
  361. inline void STARPU_DGEMM(char *transa, char *transb, int M, int N, int K,
  362. double alpha, double *A, int lda, double *B, int ldb,
  363. double beta, double *C, int ldc) { }
  364. inline void STARPU_SGEMV(char *transa, int M, int N, float alpha, float *A, int lda,
  365. float *X, int incX, float beta, float *Y, int incY) { }
  366. inline void STARPU_DGEMV(char *transa, int M, int N, double alpha, double *A, int lda,
  367. double *X, int incX, double beta, double *Y, int incY) { }
  368. inline float STARPU_SASUM(int N, float *X, int incX) { return 0.; }
  369. inline double STARPU_DASUM(int N, double *X, int incX) { return 0.; }
  370. void STARPU_SSCAL(int N, float alpha, float *X, int incX) { }
  371. void STARPU_DSCAL(int N, double alpha, double *X, int incX) { }
  372. void STARPU_STRSM (const char *side, const char *uplo, const char *transa,
  373. const char *diag, const int m, const int n,
  374. const float alpha, const float *A, const int lda,
  375. float *B, const int ldb) { }
  376. void STARPU_DTRSM (const char *side, const char *uplo, const char *transa,
  377. const char *diag, const int m, const int n,
  378. const double alpha, const double *A, const int lda,
  379. double *B, const int ldb) { }
  380. void STARPU_SSYR (const char *uplo, const int n, const float alpha,
  381. const float *x, const int incx, float *A, const int lda) { }
  382. void STARPU_SSYRK (const char *uplo, const char *trans, const int n,
  383. const int k, const float alpha, const float *A,
  384. const int lda, const float beta, float *C,
  385. const int ldc) { }
  386. void STARPU_SGER(const int m, const int n, const float alpha,
  387. const float *x, const int incx, const float *y,
  388. const int incy, float *A, const int lda) { }
  389. void STARPU_DGER(const int m, const int n, const double alpha,
  390. const double *x, const int incx, const double *y,
  391. const int incy, double *A, const int lda) { }
  392. void STARPU_STRSV (const char *uplo, const char *trans, const char *diag,
  393. const int n, const float *A, const int lda, float *x,
  394. const int incx) { }
  395. void STARPU_STRMM(const char *side, const char *uplo, const char *transA,
  396. const char *diag, const int m, const int n,
  397. const float alpha, const float *A, const int lda,
  398. float *B, const int ldb) { }
  399. void STARPU_DTRMM(const char *side, const char *uplo, const char *transA,
  400. const char *diag, const int m, const int n,
  401. const double alpha, const double *A, const int lda,
  402. double *B, const int ldb) { }
  403. void STARPU_STRMV(const char *uplo, const char *transA, const char *diag,
  404. const int n, const float *A, const int lda, float *X,
  405. const int incX) { }
  406. void STARPU_SAXPY(const int n, const float alpha, float *X, const int incX, float *Y, const int incY) { }
  407. void STARPU_DAXPY(const int n, const double alpha, double *X, const int incX, double *Y, const int incY) { }
  408. int STARPU_ISAMAX (const int n, float *X, const int incX) { return 0; }
  409. int STARPU_IDAMAX (const int n, double *X, const int incX) { return 0; }
  410. float STARPU_SDOT(const int n, const float *x, const int incx, const float *y, const int incy) { return 0.; }
  411. double STARPU_DDOT(const int n, const double *x, const int incx, const double *y, const int incy) { return 0.; }
  412. void STARPU_SSWAP(const int n, float *X, const int incX, float *Y, const int incY) { }
  413. void STARPU_DSWAP(const int n, double *X, const int incX, double *Y, const int incY) { }
  414. void STARPU_SPOTRF(const char*uplo, const int n, float *a, const int lda) { }
  415. void STARPU_DPOTRF(const char*uplo, const int n, double *a, const int lda) { }
  416. #else
  417. #error "no BLAS lib available..."
  418. #endif