starpu-blas-wrapper.c 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <semaphore.h>
  17. #include <core/jobs.h>
  18. #include <core/workers.h>
  19. #include <core/dependencies/tags.h>
  20. #include <string.h>
  21. #include <math.h>
  22. #include <sys/types.h>
  23. #include <ctype.h>
  24. #include <pthread.h>
  25. #include <signal.h>
  26. #include <cblas.h>
  27. #include <datawizard/datawizard.h>
  28. #include <task-models/blas_model.h>
  29. #include <common/fxt.h>
  30. #include <starpu.h>
  31. #ifdef USE_CUDA
  32. #include <cuda.h>
  33. #endif
  34. #define BLOCK 75
  35. #include "starpu-blas-wrapper.h"
  /* Operations table installed on BLAS-typed children created by the filter below. */
  extern struct data_interface_ops_t interface_blas_ops;

  /* Number of kernel invocations per backend, reported by STARPU_TERMINATE.
   * File-scope statics start at zero. */
  static int core_sgemm;
  static int cublas_sgemm;
  static int core_strsm;
  static int cublas_strsm;

  /* Non-zero once STARPU_INIT has called starpu_init(). */
  static int inited;
  42. void STARPU_INIT(void)
  43. {
  44. if (!inited) {
  45. inited = 1;
  46. starpu_init(NULL);
  47. }
  48. }
  49. void STARPU_TERMINATE(void)
  50. {
  51. starpu_shutdown();
  52. fprintf(stderr, "sgemm : core %d cublas %d\n", core_sgemm, cublas_sgemm);
  53. fprintf(stderr, "strsm : core %d cublas %d\n", core_strsm, cublas_strsm);
  54. }
  55. /*
  56. *
  57. * Specific to PaStiX !
  58. *
  59. */
  60. /*
  61. *
  62. * We "need" some custom filters
  63. *
  64. * VECTOR
  65. * (n)
  66. * / | \
  67. * VECTOR BLAS VECTOR
  68. * (n1) (n2)
  69. *
  70. * if n1 = 0 :
  71. * VECTOR
  72. * / \
  73. * BLAS VECTOR
  74. */
  /*
   * Arguments of divide_vector_in_blas_filter (lengths are in elements).
   */
  struct divide_vector_in_blas_filter_args {
      /* length of the leading vector child; 0 means no leading child */
      uint32_t n1;
      /* total length of the BLAS child; the filter asserts that it is a
       * multiple of `stride` and that n1 + n2 is smaller than the root length */
      uint32_t n2;
      /* used as both nx and ld of the BLAS child (its ny is n2 / stride) */
      uint32_t stride;
  };
  79. unsigned divide_vector_in_blas_filter(starpu_filter *f, starpu_data_handle root_data)
  80. {
  81. starpu_vector_interface_t *vector_root = &root_data->interface[0].vector;
  82. uint32_t nx = vector_root->nx;
  83. size_t elemsize = vector_root->elemsize;
  84. struct divide_vector_in_blas_filter_args *args = f->filter_arg_ptr;
  85. unsigned n1 = args->n1;
  86. unsigned n2 = args->n2;
  87. unsigned stride = args->stride;
  88. STARPU_ASSERT(n1 + n2 < nx);
  89. unsigned n3 = nx - n1 - n2;
  90. /* first allocate the children starpu_data_handle */
  91. root_data->children = calloc((n1==0)?2:3, sizeof(starpu_data_handle));
  92. STARPU_ASSERT(root_data->children);
  93. STARPU_ASSERT((n2 % args->stride) == 0);
  94. unsigned child = 0;
  95. unsigned node;
  96. if (n1 > 0)
  97. {
  98. for (node = 0; node < MAXNODES; node++)
  99. {
  100. starpu_vector_interface_t *local = &root_data->children[child].interface[node].vector;
  101. local->nx = n1;
  102. local->elemsize = elemsize;
  103. if (root_data->per_node[node].allocated) {
  104. local->ptr = root_data->interface[node].vector.ptr;
  105. }
  106. }
  107. child++;
  108. }
  109. for (node = 0; node < MAXNODES; node++)
  110. {
  111. starpu_blas_interface_t *local = &root_data->children[child].interface[node].blas;
  112. local->nx = stride;
  113. local->ny = n2/stride;
  114. local->ld = stride;
  115. local->elemsize = elemsize;
  116. if (root_data->per_node[node].allocated) {
  117. local->ptr = root_data->interface[node].vector.ptr + n1*elemsize;
  118. }
  119. struct starpu_data_state_t *state = &root_data->children[child];
  120. state->ops = &interface_blas_ops;
  121. }
  122. child++;
  123. for (node = 0; node < MAXNODES; node++)
  124. {
  125. starpu_vector_interface_t *local = &root_data->children[child].interface[node].vector;
  126. local->nx = n3;
  127. local->elemsize = elemsize;
  128. if (root_data->per_node[node].allocated) {
  129. local->ptr = root_data->interface[node].vector.ptr + (n1+n2)*elemsize;
  130. }
  131. }
  132. return (n1==0)?2:3;
  133. }
  134. static data_state *cblktab;
  /*
   * Job completion callback: `sem` is the address of a sem_t, which is
   * posted once so that the thread blocked in sem_wait() can continue.
   */
  static void _cublas_cblk_strsm_callback(void *sem)
  {
      sem_post((sem_t *)sem);
  }
  140. void STARPU_MONITOR_DATA(unsigned ncols)
  141. {
  142. cblktab = calloc(ncols, sizeof(data_state));
  143. }
  144. void STARPU_MONITOR_CBLK(unsigned col, float *data, unsigned stride, unsigned width)
  145. {
  146. //void starpu_monitor_blas_data(struct starpu_data_state_t *state, uint32_t home_node,
  147. // uintptr_t ptr, uint32_t ld, uint32_t nx,
  148. // uint32_t ny, size_t elemsize);
  149. //fprintf(stderr, "col %d data %p stride %d width %d\n", col, data, stride, width);
  150. starpu_monitor_blas_data(&cblktab[col], 0 /* home */,
  151. (uintptr_t) data, stride, stride, width, sizeof(float));
  152. }
  153. static data_state work_block_1;
  154. static data_state work_block_2;
  155. void allocate_maxbloktab_on_cublas(starpu_data_interface_t *descr __attribute__((unused)), void *arg __attribute__((unused)))
  156. {
  157. request_data_allocation(&work_block_1, 1);
  158. request_data_allocation(&work_block_2, 1);
  159. starpu_filter f1, f2;
  160. struct divide_vector_in_blas_filter_args args1, args2;
  161. f1.filter_func = divide_vector_in_blas_filter;
  162. args1.n1 = 1; /* XXX random ... */
  163. args1.n2 = 2;
  164. args1.stride = 1;
  165. f1.filter_arg_ptr = &args1;
  166. starpu_partition_data(&work_block_1, &f1);
  167. f2.filter_func = divide_vector_in_blas_filter;
  168. args2.n1 = 0;
  169. args2.n2 = 2;
  170. args2.stride = 1;
  171. f2.filter_arg_ptr = &args2;
  172. starpu_partition_data(&work_block_2, &f2);
  173. }
  174. void STARPU_DECLARE_WORK_BLOCKS(float *maxbloktab1, float *maxbloktab2, unsigned solv_coefmax)
  175. {
  176. starpu_monitor_vector_data(&work_block_1, 0 /* home */, (uintptr_t)maxbloktab1, solv_coefmax, sizeof(float));
  177. starpu_monitor_vector_data(&work_block_2, 0 /* home */, (uintptr_t)maxbloktab2, solv_coefmax, sizeof(float));
  178. starpu_codelet cl;
  179. job_t j;
  180. sem_t sem;
  181. /* initialize codelet */
  182. cl.where = CUBLAS;
  183. cl.cublas_func = allocate_maxbloktab_on_cublas;
  184. j = job_create();
  185. j->cb = _cublas_cblk_strsm_callback;
  186. j->argcb = &sem;
  187. j->cl = &cl;
  188. j->cl_arg = NULL;
  189. j->nbuffers = 0;
  190. j->cl->model = NULL;
  191. sem_init(&sem, 0, 0U);
  192. /* submit the codelet */
  193. submit_job(j);
  194. /* wait for its completion */
  195. sem_wait(&sem);
  196. sem_destroy(&sem);
  197. }
  198. void _core_cblk_strsm(starpu_data_interface_t *descr, void *arg __attribute__((unused)))
  199. {
  200. uint32_t nx, ny, ld;
  201. nx = descr[0].blas.nx;
  202. ny = descr[0].blas.ny;
  203. ld = descr[0].blas.ld;
  204. float *diag_cblkdata, *extra_cblkdata;
  205. diag_cblkdata = (float *)descr[0].blas.ptr;
  206. extra_cblkdata = diag_cblkdata + ny;
  207. unsigned m = nx - ny;
  208. unsigned n = ny;
  209. // SOPALIN_TRSM("R","L","T","U",dimb,dima,fun,ga,stride,gb,stride);
  210. core_strsm++;
  211. cblas_strsm(CblasColMajor, CblasRight, CblasLower, CblasTrans, CblasUnit, m, n, 1.0f,
  212. diag_cblkdata, ld, extra_cblkdata, ld);
  213. }
  214. void _cublas_cblk_strsm(starpu_data_interface_t *descr, void *arg __attribute__((unused)))
  215. {
  216. uint32_t nx, ny, ld;
  217. nx = descr[0].blas.nx;
  218. ny = descr[0].blas.ny;
  219. ld = descr[0].blas.ld;
  220. float *diag_cblkdata, *extra_cblkdata;
  221. diag_cblkdata = (float *)descr[0].blas.ptr;
  222. extra_cblkdata = diag_cblkdata + ny;
  223. unsigned m = nx - ny;
  224. unsigned n = ny;
  225. cublas_strsm++;
  226. cublasStrsm ('R', 'L', 'T', 'U', m, n, 1.0,
  227. diag_cblkdata, ld,
  228. extra_cblkdata, ld);
  229. cublasStatus st = cublasGetError();
  230. if (st) fprintf(stderr, "ERROR %d\n", st);
  231. STARPU_ASSERT(st == CUBLAS_STATUS_SUCCESS);
  232. }
  233. static struct starpu_perfmodel_t starpu_cblk_strsm = {
  234. .per_arch = {
  235. [STARPU_CORE_DEFAULT] = { .cost_model = starpu_cblk_strsm_core_cost },
  236. [STARPU_CUDA_DEFAULT] = { .cost_model = starpu_cblk_strsm_cuda_cost }
  237. },
  238. // .type = REGRESSION_BASED,
  239. .type = PER_ARCH,
  240. .symbol = "starpu_cblk_strsm"
  241. };
  242. void STARPU_CBLK_STRSM(unsigned col)
  243. {
  244. /* perform a strsm on the block column */
  245. starpu_codelet cl;
  246. job_t j;
  247. sem_t sem;
  248. /* initialize codelet */
  249. cl.where = CORE|CUBLAS;
  250. cl.core_func = _core_cblk_strsm;
  251. cl.cublas_func = _cublas_cblk_strsm;
  252. j = job_create();
  253. // j->where = (starpu_get_blas_nx(&cblktab[col]) > BLOCK && starpu_get_blas_ny(&cblktab[col]) > BLOCK)? CUBLAS:CORE;
  254. j->cb = _cublas_cblk_strsm_callback;
  255. j->argcb = &sem;
  256. j->cl = &cl;
  257. j->cl_arg = NULL;
  258. j->nbuffers = 1;
  259. /* we could be a little more precise actually */
  260. j->buffers[0].state = &cblktab[col];
  261. j->buffers[0].mode = RW;
  262. j->cl->model = &starpu_cblk_strsm;
  263. sem_init(&sem, 0, 0U);
  264. /* submit the codelet */
  265. submit_job(j);
  266. /* wait for its completion */
  267. sem_wait(&sem);
  268. sem_destroy(&sem);
  269. }
  /*
   * Codelet argument of the compact contribution kernels.
   */
  struct starpu_compute_contrib_compact_args {
      /* leading dimension given to sgemm for the first operand */
      unsigned stride;
      /* m of the sgemm call */
      int dimi;
      /* n of the sgemm call */
      int dimj;
      /* k of the sgemm call, also the element offset of the first operand
       * inside buffer 0 */
      int dima;
  };
  276. void _core_compute_contrib_compact(starpu_data_interface_t *descr, void *arg)
  277. {
  278. struct starpu_compute_contrib_compact_args *args = arg;
  279. float *gaik = (float *)descr[0].blas.ptr + args->dima;
  280. float *gb = (float *)descr[1].blas.ptr;
  281. unsigned strideb = (unsigned)descr[1].blas.ld;
  282. float *gc = (float *)descr[2].blas.ptr;
  283. unsigned stridec = (unsigned)descr[2].blas.ld;
  284. core_sgemm++;
  285. cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans,
  286. args->dimi, args->dimj, args->dima,
  287. 1.0f, gaik, args->stride,
  288. gb, strideb,
  289. 0.0 , gc, stridec);
  290. }
  291. void _cublas_compute_contrib_compact(starpu_data_interface_t *descr, void *arg)
  292. {
  293. struct starpu_compute_contrib_compact_args *args = arg;
  294. float *gaik = (float *)descr[0].blas.ptr + args->dima;
  295. float *gb = (float *)descr[1].blas.ptr;
  296. unsigned strideb = (unsigned)descr[1].blas.ld;
  297. float *gc = (float *)descr[2].blas.ptr;
  298. unsigned stridec = (unsigned)descr[2].blas.ld;
  299. cublas_sgemm++;
  300. cublasSgemm('N','T', args->dimi, args->dimj, args->dima,
  301. 1.0, gaik, args->stride,
  302. gb, strideb,
  303. 0.0, gc, stridec);
  304. cublasStatus st = cublasGetError();
  305. if (st) fprintf(stderr, "ERROR %d\n", st);
  306. STARPU_ASSERT(st == CUBLAS_STATUS_SUCCESS);
  307. }
  308. static struct starpu_perfmodel_t starpu_compute_contrib_compact = {
  309. .per_arch = {
  310. [STARPU_CORE_DEFAULT] = { .cost_model = starpu_compute_contrib_compact_core_cost },
  311. [STARPU_CUDA_DEFAULT] = { .cost_model = starpu_compute_contrib_compact_cuda_cost }
  312. },
  313. // .type = REGRESSION_BASED,
  314. .type = PER_ARCH,
  315. .symbol = "starpu_compute_contrib_compact"
  316. };
  317. int update_work_blocks(unsigned col, int dimi, int dimj, int dima, int stride)
  318. {
  319. /* be paranoid XXX */
  320. notify_data_modification(get_sub_data(&work_block_1, 1, 0), 0);
  321. notify_data_modification(get_sub_data(&work_block_1, 1, 1), 0);
  322. //notify_data_modification(get_sub_data(&work_block_1, 1, 2), 0);
  323. notify_data_modification(get_sub_data(&work_block_2, 1, 0), 0);
  324. notify_data_modification(get_sub_data(&work_block_2, 1, 1), 0);
  325. notify_data_modification(&cblktab[col], 0);
  326. starpu_unpartition_data(&work_block_1, 0);
  327. starpu_unpartition_data(&work_block_2, 0);
  328. starpu_filter f1, f2;
  329. struct divide_vector_in_blas_filter_args args1, args2;
  330. f1.filter_func = divide_vector_in_blas_filter;
  331. args1.n1 = stride - dima - dimi; //STARPU_ASSERT(args1.n1 != 0);
  332. args1.n2 = (stride - dima)*dima;
  333. args1.stride = (stride - dima);
  334. f1.filter_arg_ptr = &args1;
  335. starpu_partition_data(&work_block_1, &f1);
  336. f2.filter_func = divide_vector_in_blas_filter;
  337. args2.n1 = 0;
  338. args2.n2 = dimi*dimj;
  339. args2.stride = dimi;
  340. f2.filter_arg_ptr = &args2;
  341. starpu_partition_data(&work_block_2, &f2);
  342. return (args1.n1!=0)?3:2;
  343. }
  344. void STARPU_COMPUTE_CONTRIB_COMPACT(unsigned col, int dimi, int dimj, int dima, int stride)
  345. {
  346. // CUBLAS_SGEMM("N","T",dimi,dimj,dima, 1.0,gaik,stride,gb,stride-dima,
  347. // 0.0 ,gc,dimi);
  348. struct starpu_compute_contrib_compact_args args;
  349. args.stride = stride;
  350. args.dimi = dimi;
  351. args.dimj = dimj;
  352. args.dima = dima;
  353. starpu_codelet cl;
  354. job_t j;
  355. sem_t sem;
  356. /* initialize codelet */
  357. cl.where = CUBLAS|CORE;
  358. cl.core_func = _core_compute_contrib_compact;
  359. cl.cublas_func = _cublas_compute_contrib_compact;
  360. j = job_create();
  361. j->cb = _cublas_cblk_strsm_callback;
  362. j->argcb = &sem;
  363. j->cl = &cl;
  364. j->cl_arg = &args;
  365. j->cl->model = &starpu_compute_contrib_compact;
  366. int ret;
  367. ret = update_work_blocks(col, dimi, dimj, dima, stride);
  368. j->nbuffers = 3;
  369. /* we could be a little more precise actually */
  370. j->buffers[0].state = &cblktab[col]; // gaik
  371. j->buffers[0].mode = R;
  372. j->buffers[1].state = get_sub_data(&work_block_1, 1, (ret==2)?0:1);
  373. j->buffers[1].mode = R;
  374. j->buffers[2].state = get_sub_data(&work_block_2, 1, 0);;
  375. j->buffers[2].mode = RW; // XXX W
  376. sem_init(&sem, 0, 0U);
  377. /* submit the codelet */
  378. submit_job(j);
  379. /* wait for its completion */
  380. sem_wait(&sem);
  381. sem_destroy(&sem);
  382. }
  383. /*
  384. *
  385. * SGEMM
  386. *
  387. */
  /*
   * Scalar arguments of an sgemm call, handed to the codelet through cl_arg.
   */
  struct sgemm_args {
      /* transposition flags passed to cublasSgemm unchanged */
      char transa;
      char transb;
      /* problem dimensions */
      int m;
      int n;
      int k;
      /* scaling factors */
      float alpha;
      float beta;
  };
  395. void _cublas_sgemm(starpu_data_interface_t *descr, void *arg)
  396. {
  397. float *A, *B, *C;
  398. uint32_t nxA, nyA, ldA;
  399. uint32_t nxB, nyB, ldB;
  400. uint32_t nxC, nyC, ldC;
  401. A = (float *)descr[0].blas.ptr;
  402. nxA = descr[0].blas.nx;
  403. nyA = descr[0].blas.ny;
  404. ldA = descr[0].blas.ld;
  405. B = (float *)descr[1].blas.ptr;
  406. nxB = descr[1].blas.nx;
  407. nyB = descr[1].blas.ny;
  408. ldB = descr[1].blas.ld;
  409. C = (float *)descr[2].blas.ptr;
  410. nxC = descr[2].blas.nx;
  411. nyC = descr[2].blas.ny;
  412. ldC = descr[2].blas.ld;
  413. struct sgemm_args *args = arg;
  414. // fprintf(stderr, "CUBLAS SGEMM nxA %d nyA %d nxB %d nyB %d nxC %d nyC %d lda %d ldb %d ldc %d\n", nxA, nyA, nxB, nyB, nxC, nyC, ldA, ldB, ldC);
  415. // STARPU_ASSERT(nxA == nxC);
  416. // STARPU_ASSERT(nyA == nxB);
  417. // STARPU_ASSERT(nyB == nyC);
  418. //
  419. // STARPU_ASSERT(nxA <= ldA);
  420. // STARPU_ASSERT(nxB <= ldB);
  421. // STARPU_ASSERT(nxC <= ldC);
  422. cublasSgemm (args->transa, args->transb, args->m, args->n, args->k, args->alpha, A, (int)ldA,
  423. B, (int)ldB, args->beta, C, (int)ldC);
  424. cublasStatus st = cublasGetError();
  425. if (st) fprintf(stderr, "ERROR %d\n", st);
  426. STARPU_ASSERT(st == CUBLAS_STATUS_SUCCESS);
  427. }
  /*
   * Job completion callback: `sem` is the address of a sem_t, which is
   * posted once so that the thread blocked in sem_wait() can continue.
   */
  static void _cublas_sgemm_callback(void *sem)
  {
      sem_post((sem_t *)sem);
  }
  433. void STARPU_SGEMM (const char *transa, const char *transb, const int m,
  434. const int n, const int k, const float alpha,
  435. const float *A, const int lda, const float *B,
  436. const int ldb, const float beta, float *C, const int ldc)
  437. {
  438. struct sgemm_args args;
  439. args.transa = *transa;
  440. args.transb = *transb;
  441. args.alpha = alpha;
  442. args.beta = beta;
  443. args.m = m;
  444. args.n = n;
  445. args.k = k;
  446. data_state A_state;
  447. data_state B_state;
  448. data_state C_state;
  449. starpu_codelet cl;
  450. job_t j;
  451. sem_t sem;
  452. // fprintf(stderr, "STARPU - SGEMM - TRANSA %c TRANSB %c m %d n %d k %d lda %d ldb %d ldc %d \n", *transa, *transb, m, n, k, lda, ldb, ldc);
  453. if (toupper(*transa) == 'N')
  454. {
  455. starpu_monitor_blas_data(&A_state, 0, (uintptr_t)A, lda, m, k, sizeof(float));
  456. }
  457. else
  458. {
  459. starpu_monitor_blas_data(&A_state, 0, (uintptr_t)A, lda, k, m, sizeof(float));
  460. }
  461. if (toupper(*transb) == 'N')
  462. {
  463. starpu_monitor_blas_data(&B_state, 0, (uintptr_t)B, ldb, k, n, sizeof(float));
  464. }
  465. else
  466. {
  467. starpu_monitor_blas_data(&B_state, 0, (uintptr_t)B, ldb, n, k, sizeof(float));
  468. }
  469. starpu_monitor_blas_data(&C_state, 0, (uintptr_t)C, ldc, m, n, sizeof(float));
  470. /* initialize codelet */
  471. cl.where = CUBLAS;
  472. //cl.core_func = _core_strsm;
  473. cl.cublas_func = _cublas_sgemm;
  474. j = job_create();
  475. j->cb = _cublas_sgemm_callback;
  476. j->argcb = &sem;
  477. j->cl = &cl;
  478. j->cl_arg = &args;
  479. j->nbuffers = 3;
  480. j->buffers[0].state = &A_state;
  481. j->buffers[0].mode = R;
  482. j->buffers[1].state = &B_state;
  483. j->buffers[1].mode = R;
  484. j->buffers[2].state = &C_state;
  485. j->buffers[2].mode = RW;
  486. j->cl->model = NULL;
  487. sem_init(&sem, 0, 0U);
  488. /* submit the codelet */
  489. submit_job(j);
  490. /* wait for its completion */
  491. sem_wait(&sem);
  492. sem_destroy(&sem);
  493. /* make sure data are in memory again */
  494. starpu_unpartition_data(&A_state, 0);
  495. starpu_unpartition_data(&B_state, 0);
  496. starpu_unpartition_data(&C_state, 0);
  497. //starpu_delete_data(&A_state);
  498. //starpu_delete_data(&B_state);
  499. //starpu_delete_data(&C_state);
  500. // fprintf(stderr, "SGEMM done\n");
  501. }
  502. /*
  503. *
  504. * STRSM
  505. *
  506. */
  /*
   * Scalar arguments of an strsm call. No active code in this file uses the
   * structure; only a disabled _core_strsm kernel refers to it.
   */
  struct strsm_args {
      char side;
      char uplo;
      char transa;
      char diag;
      float alpha;
      int m;
      int n;
  };
  515. //
  516. //void _core_strsm(starpu_data_interface_t *descr, void *arg)
  517. //{
  518. // float *A, *B;
  519. // uint32_t nxA, nyA, ldA;
  520. // uint32_t nxB, nyB, ldB;
  521. //
  522. // A = (float *)descr[0].blas.ptr;
  523. // nxA = descr[0].blas.nx;
  524. // nyA = descr[0].blas.ny;
  525. // ldA = descr[0].blas.ld;
  526. //
  527. // B = (float *)descr[1].blas.ptr;
  528. // nxB = descr[1].blas.nx;
  529. // nyB = descr[1].blas.ny;
  530. // ldB = descr[1].blas.ld;
  531. //
  532. // struct strsm_args *args = arg;
  533. //
  534. // fprintf(stderr, "CORE STRSM nxA %d nyA %d nxB %d nyB %d lda %d ldb %d\n", nxA, nyA, nxB, nyB, ldA, ldB);
  535. //
  536. // SOPALIN_TRSM("R","L","T","U",dimb,dima,fun,ga,stride,gb,stride);
  537. //
  538. //}
  539. /*
  540. *
  541. *
  542. *
  543. */
  544. void CUBLAS_SGEMM (const char *transa, const char *transb, const int m,
  545. const int n, const int k, const float alpha,
  546. const float *A, const int lda, const float *B,
  547. const int ldb, const float beta, float *C, const int ldc)
  548. {
  549. int ka, kb;
  550. float *devPtrA, *devPtrB, *devPtrC;
  551. // printf("CUBLAS SGEMM : m %d n %d k %d lda %d ldb %d ldc %d\n", m, n, k, lda, ldb, ldc);
  552. /* A - REAL array of DIMENSION ( LDA, ka ), where ka is
  553. * k when TRANSA = 'N' or 'n', and is m otherwise.
  554. * Before entry with TRANSA = 'N' or 'n', the leading m by k
  555. * part of the array A must contain the matrix A, otherwise
  556. * the leading k by m part of the array A must contain the
  557. * matrix A.
  558. */
  559. ka = (toupper(transa[0]) == 'N') ? k : m;
  560. cublasAlloc (lda * ka, sizeof(devPtrA[0]), (void**)&devPtrA);
  561. if (toupper(transa[0]) == 'N') {
  562. cublasSetMatrix (STARPU_MIN(m,lda), k, sizeof(A[0]), A, lda, devPtrA,
  563. lda);
  564. } else {
  565. cublasSetMatrix (STARPU_MIN(k,lda), m, sizeof(A[0]), A, lda, devPtrA,
  566. lda);
  567. }
  568. /* B - REAL array of DIMENSION ( LDB, kb ), where kb is
  569. * n when TRANSB = 'N' or 'n', and is k otherwise.
  570. * Before entry with TRANSB = 'N' or 'n', the leading k by n
  571. * part of the array B must contain the matrix B, otherwise
  572. * the leading n by k part of the array B must contain the
  573. * matrix B.
  574. */
  575. kb = (toupper(transb[0]) == 'N') ? n : k;
  576. cublasAlloc (ldb * kb, sizeof(devPtrB[0]), (void**)&devPtrB);
  577. if (toupper(transb[0]) == 'N') {
  578. cublasSetMatrix (STARPU_MIN(k,ldb), n, sizeof(B[0]), B, ldb, devPtrB,
  579. ldb);
  580. } else {
  581. cublasSetMatrix (STARPU_MIN(n,ldb), k, sizeof(B[0]), B, ldb, devPtrB,
  582. ldb);
  583. }
  584. /* C - REAL array of DIMENSION ( LDC, n ).
  585. * Before entry, the leading m by n part of the array C must
  586. * contain the matrix C, except when beta is zero, in which
  587. * case C need not be set on entry.
  588. * On exit, the array C is overwritten by the m by n matrix
  589. */
  590. cublasAlloc ((ldc) * (n), sizeof(devPtrC[0]), (void**)&devPtrC);
  591. cublasSetMatrix (STARPU_MIN(m,ldc), n, sizeof(C[0]), C, ldc, devPtrC, ldc);
  592. cublasSgemm (transa[0], transb[0], m, n, k, alpha, devPtrA, lda,
  593. devPtrB, ldb, beta, devPtrC, ldc);
  594. cublasGetMatrix (STARPU_MIN(m,ldc), n, sizeof(C[0]), devPtrC, ldc, C, ldc);
  595. cublasFree (devPtrA);
  596. cublasFree (devPtrB);
  597. cublasFree (devPtrC);
  598. }