blas.c 14 KB


  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2009, 2010 Université de Bordeaux 1
  4. * Copyright (C) 2010 Centre National de la Recherche Scientifique
  5. *
  6. * StarPU is free software; you can redistribute it and/or modify
  7. * it under the terms of the GNU Lesser General Public License as published by
  8. * the Free Software Foundation; either version 2.1 of the License, or (at
  9. * your option) any later version.
  10. *
  11. * StarPU is distributed in the hope that it will be useful, but
  12. * WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  14. *
  15. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  16. */
  17. #include <ctype.h>
  18. #include <stdio.h>
  19. #include <starpu.h>
  20. #include "blas.h"
  21. /*
This file contains BLAS wrappers for the different BLAS implementations
(e.g. REFBLAS, ATLAS, GOTOBLAS, MKL ...). We assume Fortran (column-major)
ordering, as most libraries do not supply C-based (row-major) ordering.
  25. */
  26. #ifdef STARPU_ATLAS
  27. inline void SGEMM(char *transa, char *transb, int M, int N, int K,
  28. float alpha, const float *A, int lda, const float *B, int ldb,
  29. float beta, float *C, int ldc)
  30. {
  31. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  32. enum CBLAS_TRANSPOSE tb = (toupper(transb[0]) == 'N')?CblasNoTrans:CblasTrans;
  33. cblas_sgemm(CblasColMajor, ta, tb,
  34. M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  35. }
  36. inline void DGEMM(char *transa, char *transb, int M, int N, int K,
  37. double alpha, double *A, int lda, double *B, int ldb,
  38. double beta, double *C, int ldc)
  39. {
  40. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  41. enum CBLAS_TRANSPOSE tb = (toupper(transb[0]) == 'N')?CblasNoTrans:CblasTrans;
  42. cblas_dgemm(CblasColMajor, ta, tb,
  43. M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  44. }
  45. inline void SGEMV(char *transa, int M, int N, float alpha, float *A, int lda, float *X, int incX, float beta, float *Y, int incY)
  46. {
  47. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  48. cblas_sgemv(CblasColMajor, ta, M, N, alpha, A, lda,
  49. X, incX, beta, Y, incY);
  50. }
  51. inline void DGEMV(char *transa, int M, int N, double alpha, double *A, int lda, double *X, int incX, double beta, double *Y, int incY)
  52. {
  53. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  54. cblas_dgemv(CblasColMajor, ta, M, N, alpha, A, lda,
  55. X, incX, beta, Y, incY);
  56. }
  57. inline float SASUM(int N, float *X, int incX)
  58. {
  59. return cblas_sasum(N, X, incX);
  60. }
  61. inline double DASUM(int N, double *X, int incX)
  62. {
  63. return cblas_dasum(N, X, incX);
  64. }
  65. void SSCAL(int N, float alpha, float *X, int incX)
  66. {
  67. cblas_sscal(N, alpha, X, incX);
  68. }
  69. void DSCAL(int N, double alpha, double *X, int incX)
  70. {
  71. cblas_dscal(N, alpha, X, incX);
  72. }
  73. void STRSM (const char *side, const char *uplo, const char *transa,
  74. const char *diag, const int m, const int n,
  75. const float alpha, const float *A, const int lda,
  76. float *B, const int ldb)
  77. {
  78. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  79. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  80. enum CBLAS_TRANSPOSE transa_ = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  81. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  82. cblas_strsm(CblasColMajor, side_, uplo_, transa_, diag_, m, n, alpha, A, lda, B, ldb);
  83. }
  84. void DTRSM (const char *side, const char *uplo, const char *transa,
  85. const char *diag, const int m, const int n,
  86. const double alpha, const double *A, const int lda,
  87. double *B, const int ldb)
  88. {
  89. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  90. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  91. enum CBLAS_TRANSPOSE transa_ = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  92. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  93. cblas_dtrsm(CblasColMajor, side_, uplo_, transa_, diag_, m, n, alpha, A, lda, B, ldb);
  94. }
  95. void SSYR (const char *uplo, const int n, const float alpha,
  96. const float *x, const int incx, float *A, const int lda)
  97. {
  98. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  99. cblas_ssyr(CblasColMajor, uplo_, n, alpha, x, incx, A, lda);
  100. }
  101. void SSYRK (const char *uplo, const char *trans, const int n,
  102. const int k, const float alpha, const float *A,
  103. const int lda, const float beta, float *C,
  104. const int ldc)
  105. {
  106. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  107. enum CBLAS_TRANSPOSE trans_ = (toupper(trans[0]) == 'N')?CblasNoTrans:CblasTrans;
  108. cblas_ssyrk(CblasColMajor, uplo_, trans_, n, k, alpha, A, lda, beta, C, ldc);
  109. }
  110. void SGER(const int m, const int n, const float alpha,
  111. const float *x, const int incx, const float *y,
  112. const int incy, float *A, const int lda)
  113. {
  114. cblas_sger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
  115. }
  116. void DGER(const int m, const int n, const double alpha,
  117. const double *x, const int incx, const double *y,
  118. const int incy, double *A, const int lda)
  119. {
  120. cblas_dger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
  121. }
  122. void STRSV (const char *uplo, const char *trans, const char *diag,
  123. const int n, const float *A, const int lda, float *x,
  124. const int incx)
  125. {
  126. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  127. enum CBLAS_TRANSPOSE trans_ = (toupper(trans[0]) == 'N')?CblasNoTrans:CblasTrans;
  128. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  129. cblas_strsv(CblasColMajor, uplo_, trans_, diag_, n, A, lda, x, incx);
  130. }
  131. void STRMM(const char *side, const char *uplo, const char *transA,
  132. const char *diag, const int m, const int n,
  133. const float alpha, const float *A, const int lda,
  134. float *B, const int ldb)
  135. {
  136. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  137. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  138. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  139. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  140. cblas_strmm(CblasColMajor, side_, uplo_, transA_, diag_, m, n, alpha, A, lda, B, ldb);
  141. }
  142. void DTRMM(const char *side, const char *uplo, const char *transA,
  143. const char *diag, const int m, const int n,
  144. const double alpha, const double *A, const int lda,
  145. double *B, const int ldb)
  146. {
  147. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  148. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  149. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  150. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  151. cblas_dtrmm(CblasColMajor, side_, uplo_, transA_, diag_, m, n, alpha, A, lda, B, ldb);
  152. }
  153. void STRMV(const char *uplo, const char *transA, const char *diag,
  154. const int n, const float *A, const int lda, float *X,
  155. const int incX)
  156. {
  157. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  158. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  159. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  160. cblas_strmv(CblasColMajor, uplo_, transA_, diag_, n, A, lda, X, incX);
  161. }
  162. void SAXPY(const int n, const float alpha, float *X, const int incX, float *Y, const int incY)
  163. {
  164. cblas_saxpy(n, alpha, X, incX, Y, incY);
  165. }
  166. void DAXPY(const int n, const double alpha, double *X, const int incX, double *Y, const int incY)
  167. {
  168. cblas_daxpy(n, alpha, X, incX, Y, incY);
  169. }
  170. int ISAMAX (const int n, float *X, const int incX)
  171. {
  172. int retVal;
  173. retVal = cblas_isamax(n, X, incX);
  174. return retVal;
  175. }
  176. int IDAMAX (const int n, double *X, const int incX)
  177. {
  178. int retVal;
  179. retVal = cblas_idamax(n, X, incX);
  180. return retVal;
  181. }
  182. float SDOT(const int n, const float *x, const int incx, const float *y, const int incy)
  183. {
  184. return cblas_sdot(n, x, incx, y, incy);
  185. }
  186. double DDOT(const int n, const double *x, const int incx, const double *y, const int incy)
  187. {
  188. return cblas_ddot(n, x, incx, y, incy);
  189. }
  190. void SSWAP(const int n, float *x, const int incx, float *y, const int incy)
  191. {
  192. cblas_sswap(n, x, incx, y, incy);
  193. }
  194. void DSWAP(const int n, double *x, const int incx, double *y, const int incy)
  195. {
  196. cblas_dswap(n, x, incx, y, incy);
  197. }
  198. #elif defined(STARPU_GOTO) || defined(STARPU_SYSTEM_BLAS) || defined(STARPU_MKL)
  199. inline void SGEMM(char *transa, char *transb, int M, int N, int K,
  200. float alpha, const float *A, int lda, const float *B, int ldb,
  201. float beta, float *C, int ldc)
  202. {
  203. sgemm_(transa, transb, &M, &N, &K, &alpha,
  204. A, &lda, B, &ldb,
  205. &beta, C, &ldc);
  206. }
  207. inline void DGEMM(char *transa, char *transb, int M, int N, int K,
  208. double alpha, double *A, int lda, double *B, int ldb,
  209. double beta, double *C, int ldc)
  210. {
  211. dgemm_(transa, transb, &M, &N, &K, &alpha,
  212. A, &lda, B, &ldb,
  213. &beta, C, &ldc);
  214. }
  215. inline void SGEMV(char *transa, int M, int N, float alpha, float *A, int lda,
  216. float *X, int incX, float beta, float *Y, int incY)
  217. {
  218. sgemv_(transa, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
  219. }
  220. inline void DGEMV(char *transa, int M, int N, double alpha, double *A, int lda,
  221. double *X, int incX, double beta, double *Y, int incY)
  222. {
  223. dgemv_(transa, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
  224. }
  225. inline float SASUM(int N, float *X, int incX)
  226. {
  227. return sasum_(&N, X, &incX);
  228. }
  229. inline double DASUM(int N, double *X, int incX)
  230. {
  231. return dasum_(&N, X, &incX);
  232. }
  233. void SSCAL(int N, float alpha, float *X, int incX)
  234. {
  235. sscal_(&N, &alpha, X, &incX);
  236. }
  237. void DSCAL(int N, double alpha, double *X, int incX)
  238. {
  239. dscal_(&N, &alpha, X, &incX);
  240. }
  241. void STRSM (const char *side, const char *uplo, const char *transa,
  242. const char *diag, const int m, const int n,
  243. const float alpha, const float *A, const int lda,
  244. float *B, const int ldb)
  245. {
  246. strsm_(side, uplo, transa, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  247. }
  248. void DTRSM (const char *side, const char *uplo, const char *transa,
  249. const char *diag, const int m, const int n,
  250. const double alpha, const double *A, const int lda,
  251. double *B, const int ldb)
  252. {
  253. dtrsm_(side, uplo, transa, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  254. }
  255. void SSYR (const char *uplo, const int n, const float alpha,
  256. const float *x, const int incx, float *A, const int lda)
  257. {
  258. ssyr_(uplo, &n, &alpha, x, &incx, A, &lda);
  259. }
  260. void SSYRK (const char *uplo, const char *trans, const int n,
  261. const int k, const float alpha, const float *A,
  262. const int lda, const float beta, float *C,
  263. const int ldc)
  264. {
  265. ssyrk_(uplo, trans, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
  266. }
  267. void SGER(const int m, const int n, const float alpha,
  268. const float *x, const int incx, const float *y,
  269. const int incy, float *A, const int lda)
  270. {
  271. sger_(&m, &n, &alpha, x, &incx, y, &incy, A, &lda);
  272. }
  273. void DGER(const int m, const int n, const double alpha,
  274. const double *x, const int incx, const double *y,
  275. const int incy, double *A, const int lda)
  276. {
  277. dger_(&m, &n, &alpha, x, &incx, y, &incy, A, &lda);
  278. }
  279. void STRSV (const char *uplo, const char *trans, const char *diag,
  280. const int n, const float *A, const int lda, float *x,
  281. const int incx)
  282. {
  283. strsv_(uplo, trans, diag, &n, A, &lda, x, &incx);
  284. }
  285. void STRMM(const char *side, const char *uplo, const char *transA,
  286. const char *diag, const int m, const int n,
  287. const float alpha, const float *A, const int lda,
  288. float *B, const int ldb)
  289. {
  290. strmm_(side, uplo, transA, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  291. }
  292. void DTRMM(const char *side, const char *uplo, const char *transA,
  293. const char *diag, const int m, const int n,
  294. const double alpha, const double *A, const int lda,
  295. double *B, const int ldb)
  296. {
  297. dtrmm_(side, uplo, transA, diag, &m, &n, &alpha, A, &lda, B, &ldb);
  298. }
  299. void STRMV(const char *uplo, const char *transA, const char *diag,
  300. const int n, const float *A, const int lda, float *X,
  301. const int incX)
  302. {
  303. strmv_(uplo, transA, diag, &n, A, &lda, X, &incX);
  304. }
  305. void SAXPY(const int n, const float alpha, float *X, const int incX, float *Y, const int incY)
  306. {
  307. saxpy_(&n, &alpha, X, &incX, Y, &incY);
  308. }
  309. void DAXPY(const int n, const double alpha, double *X, const int incX, double *Y, const int incY)
  310. {
  311. daxpy_(&n, &alpha, X, &incX, Y, &incY);
  312. }
  313. int ISAMAX (const int n, float *X, const int incX)
  314. {
  315. int retVal;
  316. retVal = isamax_ (&n, X, &incX);
  317. return retVal;
  318. }
  319. int IDAMAX (const int n, double *X, const int incX)
  320. {
  321. int retVal;
  322. retVal = idamax_ (&n, X, &incX);
  323. return retVal;
  324. }
  325. float SDOT(const int n, const float *x, const int incx, const float *y, const int incy)
  326. {
  327. float retVal = 0;
  328. /* GOTOBLAS will return a FLOATRET which is a double, not a float */
  329. retVal = (float)sdot_(&n, x, &incx, y, &incy);
  330. return retVal;
  331. }
  332. double DDOT(const int n, const double *x, const int incx, const double *y, const int incy)
  333. {
  334. return ddot_(&n, x, &incx, y, &incy);
  335. }
  336. void SSWAP(const int n, float *X, const int incX, float *Y, const int incY)
  337. {
  338. sswap_(&n, X, &incX, Y, &incY);
  339. }
  340. void DSWAP(const int n, double *X, const int incX, double *Y, const int incY)
  341. {
  342. dswap_(&n, X, &incX, Y, &incY);
  343. }
  344. #else
  345. #error "no BLAS lib available..."
  346. #endif