blas.c 12 KB


  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <ctype.h>
  17. #include <stdio.h>
  18. #include <starpu.h>
  19. #include "blas.h"
  /*
  This file contains BLAS wrappers for the different BLAS implementations
  (e.g. REFBLAS, ATLAS, GOTOBLAS, ...). We assume Fortran (column-major)
  ordering, as most libraries do not supply C-based (row-major) ordering.
  */
  25. #ifdef ATLAS
  26. inline void SGEMM(char *transa, char *transb, int M, int N, int K,
  27. float alpha, float *A, int lda, float *B, int ldb,
  28. float beta, float *C, int ldc)
  29. {
  30. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  31. enum CBLAS_TRANSPOSE tb = (toupper(transb[0]) == 'N')?CblasNoTrans:CblasTrans;
  32. cblas_sgemm(CblasColMajor, ta, tb,
  33. M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  34. }
  35. inline void DGEMM(char *transa, char *transb, int M, int N, int K,
  36. double alpha, double *A, int lda, double *B, int ldb,
  37. double beta, double *C, int ldc)
  38. {
  39. enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  40. enum CBLAS_TRANSPOSE tb = (toupper(transb[0]) == 'N')?CblasNoTrans:CblasTrans;
  41. cblas_dgemm(CblasColMajor, ta, tb,
  42. M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  43. }
  /*
   * Sum of absolute values of the N elements of X (stride incX), via CBLAS.
   * Plain external definition (no `inline`), consistent with the other
   * wrappers, so a callable symbol is always emitted.
   */
  float SASUM(int N, float *X, int incX)
  {
      return cblas_sasum(N, X, incX);
  }
  /*
   * Sum of absolute values of the N elements of X (stride incX), via CBLAS.
   * Plain external definition (no `inline`), consistent with the other
   * wrappers, so a callable symbol is always emitted.
   */
  double DASUM(int N, double *X, int incX)
  {
      return cblas_dasum(N, X, incX);
  }
  /* In-place scaling X := alpha * X over N elements with stride incX (CBLAS). */
  void SSCAL(int N, float alpha, float *X, int incX)
  {
      cblas_sscal(N,
                  alpha,
                  X, incX);
  }
  /* In-place scaling X := alpha * X over N elements with stride incX (CBLAS). */
  void DSCAL(int N, double alpha, double *X, int incX)
  {
      cblas_dscal(N,
                  alpha,
                  X, incX);
  }
  60. void STRSM (const char *side, const char *uplo, const char *transa,
  61. const char *diag, const int m, const int n,
  62. const float alpha, const float *A, const int lda,
  63. float *B, const int ldb)
  64. {
  65. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  66. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  67. enum CBLAS_TRANSPOSE transa_ = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  68. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  69. cblas_strsm(CblasColMajor, side_, uplo_, transa_, diag_, m, n, alpha, A, lda, B, ldb);
  70. }
  71. void DTRSM (const char *side, const char *uplo, const char *transa,
  72. const char *diag, const int m, const int n,
  73. const double alpha, const double *A, const int lda,
  74. double *B, const int ldb)
  75. {
  76. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  77. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  78. enum CBLAS_TRANSPOSE transa_ = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
  79. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  80. cblas_dtrsm(CblasColMajor, side_, uplo_, transa_, diag_, m, n, alpha, A, lda, B, ldb);
  81. }
  82. void SSYR (const char *uplo, const int n, const float alpha,
  83. const float *x, const int incx, float *A, const int lda)
  84. {
  85. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  86. cblas_ssyr(CblasColMajor, uplo_, n, alpha, x, incx, A, lda);
  87. }
  88. void SSYRK (const char *uplo, const char *trans, const int n,
  89. const int k, const float alpha, const float *A,
  90. const int lda, const float beta, float *C,
  91. const int ldc)
  92. {
  93. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  94. enum CBLAS_TRANSPOSE trans_ = (toupper(trans[0]) == 'N')?CblasNoTrans:CblasTrans;
  95. cblas_ssyrk(CblasColMajor, uplo_, trans_, n, k, alpha, A, lda, beta, C, ldc);
  96. }
  97. void SGER(const int m, const int n, const float alpha,
  98. const float *x, const int incx, const float *y,
  99. const int incy, float *A, const int lda)
  100. {
  101. cblas_sger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
  102. }
  103. void DGER(const int m, const int n, const double alpha,
  104. const double *x, const int incx, const double *y,
  105. const int incy, double *A, const int lda)
  106. {
  107. cblas_dger(CblasColMajor, m, n, alpha, x, incx, y, incy, A, lda);
  108. }
  109. void STRSV (const char *uplo, const char *trans, const char *diag,
  110. const int n, const float *A, const int lda, float *x,
  111. const int incx)
  112. {
  113. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  114. enum CBLAS_TRANSPOSE trans_ = (toupper(trans[0]) == 'N')?CblasNoTrans:CblasTrans;
  115. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  116. cblas_strsv(CblasColMajor, uplo_, trans_, diag_, n, A, lda, x, incx);
  117. }
  118. void STRMM(const char *side, const char *uplo, const char *transA,
  119. const char *diag, const int m, const int n,
  120. const float alpha, const float *A, const int lda,
  121. float *B, const int ldb)
  122. {
  123. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  124. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  125. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  126. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  127. cblas_strmm(CblasColMajor, side_, uplo_, transA_, diag_, m, n, alpha, A, lda, B, ldb);
  128. }
  129. void DTRMM(const char *side, const char *uplo, const char *transA,
  130. const char *diag, const int m, const int n,
  131. const double alpha, const double *A, const int lda,
  132. double *B, const int ldb)
  133. {
  134. enum CBLAS_SIDE side_ = (toupper(side[0]) == 'L')?CblasLeft:CblasRight;
  135. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  136. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  137. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  138. cblas_dtrmm(CblasColMajor, side_, uplo_, transA_, diag_, m, n, alpha, A, lda, B, ldb);
  139. }
  140. void STRMV(const char *uplo, const char *transA, const char *diag,
  141. const int n, const float *A, const int lda, float *X,
  142. const int incX)
  143. {
  144. enum CBLAS_UPLO uplo_ = (toupper(uplo[0]) == 'U')?CblasUpper:CblasLower;
  145. enum CBLAS_TRANSPOSE transA_ = (toupper(transA[0]) == 'N')?CblasNoTrans:CblasTrans;
  146. enum CBLAS_DIAG diag_ = (toupper(diag[0]) == 'N')?CblasNonUnit:CblasUnit;
  147. cblas_strmv(CblasColMajor, uplo_, transA_, diag_, n, A, lda, X, incX);
  148. }
  /* Single-precision Y := alpha * X + Y over n elements (strides incX, incY), via CBLAS. */
  void SAXPY(const int n, const float alpha, float *X, const int incX,
             float *Y, const int incY)
  {
      cblas_saxpy(n, alpha,
                  X, incX,
                  Y, incY);
  }
  /* Double-precision Y := alpha * X + Y over n elements (strides incX, incY), via CBLAS. */
  void DAXPY(const int n, const double alpha, double *X, const int incX,
             double *Y, const int incY)
  {
      cblas_daxpy(n, alpha,
                  X, incX,
                  Y, incY);
  }
  /*
   * Index of the element of X (n elements, stride incX) with the largest
   * absolute value, as returned by cblas_isamax and converted to int.
   * NOTE(review): the CBLAS i?amax index is zero-based, while the Fortran
   * isamax_ used by the GOTO/SYSTEM_BLAS branch is one-based. Confirm which
   * convention callers expect.
   */
  int ISAMAX(const int n, float *X, const int incX)
  {
      return (int)cblas_isamax(n, X, incX);
  }
  /*
   * Index of the element of X (n elements, stride incX) with the largest
   * absolute value, as returned by cblas_idamax and converted to int.
   * NOTE(review): the CBLAS i?amax index is zero-based, while the Fortran
   * idamax_ used by the GOTO/SYSTEM_BLAS branch is one-based. Confirm which
   * convention callers expect.
   */
  int IDAMAX(const int n, double *X, const int incX)
  {
      return (int)cblas_idamax(n, X, incX);
  }
  /* Single-precision dot product of x and y (n elements, strides incx/incy), via CBLAS. */
  float SDOT(const int n, const float *x, const int incx, const float *y, const int incy)
  {
      float dot = cblas_sdot(n, x, incx, y, incy);
      return dot;
  }
  /* Exchange the n elements of x and y (strides incx, incy), via CBLAS. */
  void SSWAP(const int n, float *x, const int incx, float *y, const int incy)
  {
      cblas_sswap(n,
                  x, incx,
                  y, incy);
  }
  /* Exchange the n elements of x and y (strides incx, incy), via CBLAS. */
  void DSWAP(const int n, double *x, const int incx, double *y, const int incy)
  {
      cblas_dswap(n,
                  x, incx,
                  y, incy);
  }
  181. #elif defined(GOTO) || defined(SYSTEM_BLAS)
  /*
   * Single-precision GEMM through the Fortran BLAS symbol sgemm_:
   * C = alpha * op(A) * op(B) + beta * C (column-major).
   * The Fortran interface takes every argument by reference, so the by-value
   * scalars are forwarded by address.
   *
   * Defined without `inline`, like the other wrappers in this file: an
   * external-linkage `inline` definition only emits a callable symbol in C99
   * if some other file-scope declaration lacks `inline`.
   * NOTE(review): no hidden Fortran string-length arguments are passed for
   * transa/transb; this relies on the sgemm_ prototype in blas.h and the
   * Fortran compiler's calling convention. Confirm.
   */
  void SGEMM(char *transa, char *transb, int M, int N, int K,
             float alpha, float *A, int lda, float *B, int ldb,
             float beta, float *C, int ldc)
  {
      sgemm_(transa, transb, &M, &N, &K, &alpha,
             A, &lda, B, &ldb,
             &beta, C, &ldc);
  }
  /*
   * Double-precision GEMM through the Fortran BLAS symbol dgemm_:
   * C = alpha * op(A) * op(B) + beta * C (column-major).
   * The Fortran interface takes every argument by reference, so the by-value
   * scalars are forwarded by address.
   *
   * Defined without `inline`, like the other wrappers in this file: an
   * external-linkage `inline` definition only emits a callable symbol in C99
   * if some other file-scope declaration lacks `inline`.
   * NOTE(review): no hidden Fortran string-length arguments are passed for
   * transa/transb; this relies on the dgemm_ prototype in blas.h. Confirm.
   */
  void DGEMM(char *transa, char *transb, int M, int N, int K,
             double alpha, double *A, int lda, double *B, int ldb,
             double beta, double *C, int ldc)
  {
      dgemm_(transa, transb, &M, &N, &K, &alpha,
             A, &lda, B, &ldb,
             &beta, C, &ldc);
  }
  /*
   * Sum of absolute values of the N elements of X (stride incX), through the
   * Fortran BLAS symbol sasum_ (scalars passed by address).
   * Plain external definition (no `inline`), consistent with the other
   * wrappers, so a callable symbol is always emitted.
   */
  float SASUM(int N, float *X, int incX)
  {
      return sasum_(&N, X, &incX);
  }
  /*
   * Sum of absolute values of the N elements of X (stride incX), through the
   * Fortran BLAS symbol dasum_ (scalars passed by address).
   * Plain external definition (no `inline`), consistent with the other
   * wrappers, so a callable symbol is always emitted.
   */
  double DASUM(int N, double *X, int incX)
  {
      return dasum_(&N, X, &incX);
  }
  /* In-place X := alpha * X (N elements, stride incX) via Fortran sscal_; scalars by address. */
  void SSCAL(int N, float alpha, float *X, int incX)
  {
      sscal_(&N,
             &alpha,
             X, &incX);
  }
  /* In-place X := alpha * X (N elements, stride incX) via Fortran dscal_; scalars by address. */
  void DSCAL(int N, double alpha, double *X, int incX)
  {
      dscal_(&N,
             &alpha,
             X, &incX);
  }
  /*
   * Single-precision triangular solve (TRSM) via the Fortran symbol strsm_.
   * Flag strings are forwarded unchanged; scalars are passed by address.
   * NOTE(review): no hidden Fortran string-length arguments are passed;
   * this relies on the strsm_ prototype in blas.h. Confirm.
   */
  void STRSM(const char *side, const char *uplo, const char *transa,
             const char *diag, const int m, const int n,
             const float alpha, const float *A, const int lda,
             float *B, const int ldb)
  {
      strsm_(side, uplo, transa, diag,
             &m, &n,
             &alpha,
             A, &lda,
             B, &ldb);
  }
  /*
   * Double-precision triangular solve (TRSM) via the Fortran symbol dtrsm_.
   * Flag strings are forwarded unchanged; scalars are passed by address.
   * NOTE(review): no hidden Fortran string-length arguments are passed;
   * this relies on the dtrsm_ prototype in blas.h. Confirm.
   */
  void DTRSM(const char *side, const char *uplo, const char *transa,
             const char *diag, const int m, const int n,
             const double alpha, const double *A, const int lda,
             double *B, const int ldb)
  {
      dtrsm_(side, uplo, transa, diag,
             &m, &n,
             &alpha,
             A, &lda,
             B, &ldb);
  }
  /*
   * Single-precision symmetric rank-1 update (SYR) via the Fortran symbol
   * ssyr_; scalars are passed by address.
   * NOTE(review): no hidden string-length argument is passed for uplo. Confirm.
   */
  void SSYR(const char *uplo, const int n, const float alpha,
            const float *x, const int incx, float *A, const int lda)
  {
      ssyr_(uplo,
            &n, &alpha,
            x, &incx,
            A, &lda);
  }
  /*
   * Single-precision symmetric rank-k update (SYRK) via the Fortran symbol
   * ssyrk_; scalars are passed by address.
   * NOTE(review): no hidden string-length arguments are passed. Confirm.
   */
  void SSYRK(const char *uplo, const char *trans, const int n,
             const int k, const float alpha, const float *A,
             const int lda, const float beta, float *C,
             const int ldc)
  {
      ssyrk_(uplo, trans,
             &n, &k,
             &alpha, A, &lda,
             &beta, C, &ldc);
  }
  /* Single-precision rank-1 update (GER) via the Fortran symbol sger_; scalars by address. */
  void SGER(const int m, const int n, const float alpha,
            const float *x, const int incx, const float *y,
            const int incy, float *A, const int lda)
  {
      sger_(&m, &n, &alpha,
            x, &incx,
            y, &incy,
            A, &lda);
  }
  /* Double-precision rank-1 update (GER) via the Fortran symbol dger_; scalars by address. */
  void DGER(const int m, const int n, const double alpha,
            const double *x, const int incx, const double *y,
            const int incy, double *A, const int lda)
  {
      dger_(&m, &n, &alpha,
            x, &incx,
            y, &incy,
            A, &lda);
  }
  /*
   * Single-precision triangular solve with one right-hand side (TRSV) via
   * the Fortran symbol strsv_; scalars are passed by address.
   * NOTE(review): no hidden string-length arguments are passed. Confirm.
   */
  void STRSV(const char *uplo, const char *trans, const char *diag,
             const int n, const float *A, const int lda, float *x,
             const int incx)
  {
      strsv_(uplo, trans, diag,
             &n,
             A, &lda,
             x, &incx);
  }
  /*
   * Single-precision triangular matrix-matrix multiply (TRMM) via the
   * Fortran symbol strmm_; scalars are passed by address.
   * NOTE(review): no hidden string-length arguments are passed. Confirm.
   */
  void STRMM(const char *side, const char *uplo, const char *transA,
             const char *diag, const int m, const int n,
             const float alpha, const float *A, const int lda,
             float *B, const int ldb)
  {
      strmm_(side, uplo, transA, diag,
             &m, &n,
             &alpha,
             A, &lda,
             B, &ldb);
  }
  /*
   * Double-precision triangular matrix-matrix multiply (TRMM) via the
   * Fortran symbol dtrmm_; scalars are passed by address.
   * NOTE(review): no hidden string-length arguments are passed. Confirm.
   */
  void DTRMM(const char *side, const char *uplo, const char *transA,
             const char *diag, const int m, const int n,
             const double alpha, const double *A, const int lda,
             double *B, const int ldb)
  {
      dtrmm_(side, uplo, transA, diag,
             &m, &n,
             &alpha,
             A, &lda,
             B, &ldb);
  }
  /*
   * Single-precision triangular matrix-vector multiply (TRMV) via the
   * Fortran symbol strmv_; scalars are passed by address.
   * NOTE(review): no hidden string-length arguments are passed. Confirm.
   */
  void STRMV(const char *uplo, const char *transA, const char *diag,
             const int n, const float *A, const int lda, float *X,
             const int incX)
  {
      strmv_(uplo, transA, diag,
             &n,
             A, &lda,
             X, &incX);
  }
  /* Single-precision Y := alpha * X + Y via the Fortran symbol saxpy_; scalars by address. */
  void SAXPY(const int n, const float alpha, float *X, const int incX,
             float *Y, const int incY)
  {
      saxpy_(&n, &alpha,
             X, &incX,
             Y, &incY);
  }
  /* Double-precision Y := alpha * X + Y via the Fortran symbol daxpy_; scalars by address. */
  void DAXPY(const int n, const double alpha, double *X, const int incX,
             double *Y, const int incY)
  {
      daxpy_(&n, &alpha,
             X, &incX,
             Y, &incY);
  }
  /*
   * Zero-based index of the element of X (n elements, stride incX) with the
   * largest absolute value, using the Fortran BLAS symbol isamax_.
   *
   * Fortran i?amax_ returns a one-based index, with 0 signalling an empty or
   * invalid input. The ATLAS branch of this file returns the zero-based
   * cblas_isamax value instead. To keep ISAMAX meaning the same thing for
   * every backend, convert here exactly as the reference CBLAS does
   * (iamax ? iamax - 1 : 0).
   */
  int ISAMAX(const int n, float *X, const int incX)
  {
      int idx = isamax_(&n, X, &incX);
      return (idx > 0) ? idx - 1 : 0;
  }
  /*
   * Zero-based index of the element of X (n elements, stride incX) with the
   * largest absolute value, using the Fortran BLAS symbol idamax_.
   *
   * Fortran i?amax_ returns a one-based index, with 0 signalling an empty or
   * invalid input. The ATLAS branch of this file returns the zero-based
   * cblas_idamax value instead. To keep IDAMAX meaning the same thing for
   * every backend, convert here exactly as the reference CBLAS does
   * (iamax ? iamax - 1 : 0).
   */
  int IDAMAX(const int n, double *X, const int incX)
  {
      int idx = idamax_(&n, X, &incX);
      return (idx > 0) ? idx - 1 : 0;
  }
  /*
   * Single-precision dot product of x and y (n elements, strides incx/incy)
   * via the Fortran symbol sdot_; scalars are passed by address.
   */
  float SDOT(const int n, const float *x, const int incx, const float *y, const int incy)
  {
      /* GOTOBLAS will return a FLOATRET which is a double, not a float */
      return (float)sdot_(&n, x, &incx, y, &incy);
  }
  /* Exchange the n elements of X and Y via the Fortran symbol sswap_; scalars by address. */
  void SSWAP(const int n, float *X, const int incX, float *Y, const int incY)
  {
      sswap_(&n,
             X, &incX,
             Y, &incY);
  }
  /* Exchange the n elements of X and Y via the Fortran symbol dswap_; scalars by address. */
  void DSWAP(const int n, double *X, const int incX, double *Y, const int incY)
  {
      dswap_(&n,
             X, &incX,
             Y, &incY);
  }
  313. #else
  314. #error "no BLAS lib available..."
  315. #endif