starpu-blas-wrapper.c 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <semaphore.h>
  17. #include <core/jobs.h>
  18. #include <core/workers.h>
  19. #include <core/dependencies/tags.h>
  20. #include <string.h>
  21. #include <math.h>
  22. #include <sys/types.h>
  23. #include <ctype.h>
  24. #include <pthread.h>
  25. #include <signal.h>
  26. #include <cblas.h>
  27. #include <datawizard/datawizard.h>
  28. #include <task-models/blas_model.h>
  29. #include <common/fxt.h>
  30. #include <starpu.h>
  31. #ifdef USE_CUDA
  32. #include <cuda.h>
  33. #endif
  34. #define BLOCK 75
  35. #include "starpu-blas-wrapper.h"
  /* Operations table of the BLAS (matrix) data interface; only declared here. */
  extern struct data_interface_ops_t interface_blas_ops;

  /* Number of times each kernel ran on a core or through CUBLAS; printed by
   * STARPU_TERMINATE(). Objects with static storage duration start at zero. */
  static int core_sgemm;
  static int cublas_sgemm;
  static int core_strsm;
  static int cublas_strsm;

  /* Non-zero once STARPU_INIT() has called starpu_init(). */
  static int inited;
  42. void STARPU_INIT(void)
  43. {
  44. if (!inited) {
  45. inited = 1;
  46. starpu_init(NULL);
  47. }
  48. }
  49. void STARPU_TERMINATE(void)
  50. {
  51. starpu_shutdown();
  52. fprintf(stderr, "sgemm : core %d cublas %d\n", core_sgemm, cublas_sgemm);
  53. fprintf(stderr, "strsm : core %d cublas %d\n", core_strsm, cublas_strsm);
  54. }
  55. /*
  56. *
  57. * Specific to PaStiX !
  58. *
  59. */
  /*
   * We "need" some custom filters:
   *
   *             VECTOR
   *              (n)
   *           /   |   \
   *     VECTOR  BLAS  VECTOR
   *      (n1)   (n2)
   *
   * If n1 = 0:
   *
   *         VECTOR
   *         /    \
   *      BLAS   VECTOR
   */
  /* Arguments of divide_vector_in_blas_filter(). */
  struct divide_vector_in_blas_filter_args {
      /* number of elements in the leading vector child (0: no such child) */
      uint32_t n1;
      /* number of elements in the BLAS child; n1 + n2 must be smaller than
       * the root vector's length */
      uint32_t n2;
      /* nx and leading dimension of the BLAS child; n2 must be a multiple of it */
      uint32_t stride;
  };
  79. void divide_vector_in_blas_filter(starpu_filter *f, starpu_data_handle root_data)
  80. {
  81. starpu_vector_interface_t *vector_root = &root_data->interface[0].vector;
  82. uint32_t nx = vector_root->nx;
  83. size_t elemsize = vector_root->elemsize;
  84. struct divide_vector_in_blas_filter_args *args = f->filter_arg_ptr;
  85. unsigned n1 = args->n1;
  86. unsigned n2 = args->n2;
  87. unsigned stride = args->stride;
  88. STARPU_ASSERT(n1 + n2 < nx);
  89. unsigned n3 = nx - n1 - n2;
  90. /* first allocate the children starpu_data_handle */
  91. starpu_data_create_children(root_data, (n1==0)?2:3, root_data->ops);
  92. STARPU_ASSERT((n2 % args->stride) == 0);
  93. unsigned child = 0;
  94. unsigned node;
  95. if (n1 > 0)
  96. {
  97. for (node = 0; node < STARPU_MAXNODES; node++)
  98. {
  99. starpu_vector_interface_t *local = &root_data->children[child].interface[node].vector;
  100. local->nx = n1;
  101. local->elemsize = elemsize;
  102. if (root_data->per_node[node].allocated) {
  103. local->ptr = root_data->interface[node].vector.ptr;
  104. }
  105. }
  106. child++;
  107. }
  108. for (node = 0; node < STARPU_MAXNODES; node++)
  109. {
  110. starpu_blas_interface_t *local = &root_data->children[child].interface[node].blas;
  111. local->nx = stride;
  112. local->ny = n2/stride;
  113. local->ld = stride;
  114. local->elemsize = elemsize;
  115. if (root_data->per_node[node].allocated) {
  116. local->ptr = root_data->interface[node].vector.ptr + n1*elemsize;
  117. }
  118. struct starpu_data_state_t *state = &root_data->children[child];
  119. state->ops = &interface_blas_ops;
  120. }
  121. child++;
  122. for (node = 0; node < STARPU_MAXNODES; node++)
  123. {
  124. starpu_vector_interface_t *local = &root_data->children[child].interface[node].vector;
  125. local->nx = n3;
  126. local->elemsize = elemsize;
  127. if (root_data->per_node[node].allocated) {
  128. local->ptr = root_data->interface[node].vector.ptr + (n1+n2)*elemsize;
  129. }
  130. }
  131. }
  132. static data_state *cblktab;
  /* Job completion callback: `sem` points to the sem_t the submitter waits on. */
  static void _cublas_cblk_strsm_callback(void *sem)
  {
      sem_post((sem_t *)sem);
  }
  138. void STARPU_MONITOR_DATA(unsigned ncols)
  139. {
  140. cblktab = calloc(ncols, sizeof(data_state));
  141. }
  142. void STARPU_MONITOR_CBLK(unsigned col, float *data, unsigned stride, unsigned width)
  143. {
  144. //void starpu_register_blas_data(struct starpu_data_state_t *state, uint32_t home_node,
  145. // uintptr_t ptr, uint32_t ld, uint32_t nx,
  146. // uint32_t ny, size_t elemsize);
  147. //fprintf(stderr, "col %d data %p stride %d width %d\n", col, data, stride, width);
  148. starpu_register_blas_data(&cblktab[col], 0 /* home */,
  149. (uintptr_t) data, stride, stride, width, sizeof(float));
  150. }
  151. static data_state work_block_1;
  152. static data_state work_block_2;
  153. void allocate_maxbloktab_on_cublas(void *descr[] __attribute__((unused)), void *arg __attribute__((unused)))
  154. {
  155. starpu_request_data_allocation(&work_block_1, 1);
  156. starpu_request_data_allocation(&work_block_2, 1);
  157. starpu_filter f1, f2;
  158. struct divide_vector_in_blas_filter_args args1, args2;
  159. f1.filter_func = divide_vector_in_blas_filter;
  160. args1.n1 = 1; /* XXX random ... */
  161. args1.n2 = 2;
  162. args1.stride = 1;
  163. f1.filter_arg_ptr = &args1;
  164. starpu_partition_data(&work_block_1, &f1);
  165. f2.filter_func = divide_vector_in_blas_filter;
  166. args2.n1 = 0;
  167. args2.n2 = 2;
  168. args2.stride = 1;
  169. f2.filter_arg_ptr = &args2;
  170. starpu_partition_data(&work_block_2, &f2);
  171. }
  172. void STARPU_DECLARE_WORK_BLOCKS(float *maxbloktab1, float *maxbloktab2, unsigned solv_coefmax)
  173. {
  174. starpu_register_vector_data(&work_block_1, 0 /* home */, (uintptr_t)maxbloktab1, solv_coefmax, sizeof(float));
  175. starpu_register_vector_data(&work_block_2, 0 /* home */, (uintptr_t)maxbloktab2, solv_coefmax, sizeof(float));
  176. starpu_codelet cl;
  177. job_t j;
  178. sem_t sem;
  179. /* initialize codelet */
  180. cl.where = CUDA;
  181. cl.cuda_func = allocate_maxbloktab_on_cublas;
  182. j = _starpu_job_create();
  183. j->cb = _cublas_cblk_strsm_callback;
  184. j->argcb = &sem;
  185. j->cl = &cl;
  186. j->cl_arg = NULL;
  187. j->nbuffers = 0;
  188. j->cl->model = NULL;
  189. sem_init(&sem, 0, 0U);
  190. /* submit the codelet */
  191. submit_job(j);
  192. /* wait for its completion */
  193. sem_wait(&sem);
  194. sem_destroy(&sem);
  195. }
  196. void _core_cblk_strsm(void *descr[], void *arg __attribute__((unused)))
  197. {
  198. uint32_t nx, ny, ld;
  199. nx = GET_BLAS_NX(descr[0]);
  200. ny = GET_BLAS_NY(descr[0]);
  201. ld = GET_BLAS_LD(descr[0]);
  202. float *diag_cblkdata, *extra_cblkdata;
  203. diag_cblkdata = (float *)GET_BLAS_PTR(descr[0]);
  204. extra_cblkdata = diag_cblkdata + ny;
  205. unsigned m = nx - ny;
  206. unsigned n = ny;
  207. // SOPALIN_TRSM("R","L","T","U",dimb,dima,fun,ga,stride,gb,stride);
  208. core_strsm++;
  209. cblas_strsm(CblasColMajor, CblasRight, CblasLower, CblasTrans, CblasUnit, m, n, 1.0f,
  210. diag_cblkdata, ld, extra_cblkdata, ld);
  211. }
  212. void _cublas_cblk_strsm(void *descr[], void *arg __attribute__((unused)))
  213. {
  214. uint32_t nx, ny, ld;
  215. nx = GET_BLAS_NX(descr[0]);
  216. ny = GET_BLAS_NY(descr[0]);
  217. ld = GET_BLAS_LD(descr[0]);
  218. float *diag_cblkdata, *extra_cblkdata;
  219. diag_cblkdata = (float *)GET_BLAS_PTR(descr[0]);
  220. extra_cblkdata = diag_cblkdata + ny;
  221. unsigned m = nx - ny;
  222. unsigned n = ny;
  223. cublas_strsm++;
  224. cublasStrsm ('R', 'L', 'T', 'U', m, n, 1.0,
  225. diag_cblkdata, ld,
  226. extra_cblkdata, ld);
  227. cublasStatus st = cublasGetError();
  228. if (st) fprintf(stderr, "ERROR %d\n", st);
  229. STARPU_ASSERT(st == CUBLAS_STATUS_SUCCESS);
  230. }
  231. static struct starpu_perfmodel_t starpu_cblk_strsm = {
  232. .per_arch = {
  233. [STARPU_CORE_DEFAULT] = { .cost_model = starpu_cblk_strsm_core_cost },
  234. [STARPU_CUDA_DEFAULT] = { .cost_model = starpu_cblk_strsm_cuda_cost }
  235. },
  236. // .type = REGRESSION_BASED,
  237. .type = PER_ARCH,
  238. .symbol = "starpu_cblk_strsm"
  239. };
  240. void STARPU_CBLK_STRSM(unsigned col)
  241. {
  242. /* perform a strsm on the block column */
  243. starpu_codelet cl;
  244. job_t j;
  245. sem_t sem;
  246. /* initialize codelet */
  247. cl.where = CORE|CUDA;
  248. cl.core_func = _core_cblk_strsm;
  249. cl.cuda_func = _cublas_cblk_strsm;
  250. j = _starpu_job_create();
  251. // j->where = (starpu_get_blas_nx(&cblktab[col]) > BLOCK && starpu_get_blas_ny(&cblktab[col]) > BLOCK)? CUBLAS:CORE;
  252. j->cb = _cublas_cblk_strsm_callback;
  253. j->argcb = &sem;
  254. j->cl = &cl;
  255. j->cl_arg = NULL;
  256. j->nbuffers = 1;
  257. /* we could be a little more precise actually */
  258. j->buffers[0].handle = &cblktab[col];
  259. j->buffers[0].mode = STARPU_RW;
  260. j->cl->model = &starpu_cblk_strsm;
  261. sem_init(&sem, 0, 0U);
  262. /* submit the codelet */
  263. submit_job(j);
  264. /* wait for its completion */
  265. sem_wait(&sem);
  266. sem_destroy(&sem);
  267. }
  /* Codelet argument of the compute_contrib_compact kernels. */
  struct starpu_compute_contrib_compact_args {
      unsigned stride;    /* leading dimension of the first sgemm operand */
      int dimi;           /* sgemm m */
      int dimj;           /* sgemm n */
      int dima;           /* sgemm k, and element offset of the first operand */
  };
  274. void _core_compute_contrib_compact(void *descr[], void *arg)
  275. {
  276. struct starpu_compute_contrib_compact_args *args = arg;
  277. float *gaik = (float *)GET_BLAS_PTR(descr[0]) + args->dima;
  278. float *gb = (float *)GET_BLAS_PTR(descr[1]);
  279. unsigned strideb = (unsigned)GET_BLAS_LD(descr[1]);
  280. float *gc = (float *)GET_BLAS_PTR(descr[2]);
  281. unsigned stridec = (unsigned)GET_BLAS_LD(descr[2]);
  282. core_sgemm++;
  283. cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans,
  284. args->dimi, args->dimj, args->dima,
  285. 1.0f, gaik, args->stride,
  286. gb, strideb,
  287. 0.0 , gc, stridec);
  288. }
  289. void _cublas_compute_contrib_compact(void *descr[], void *arg)
  290. {
  291. struct starpu_compute_contrib_compact_args *args = arg;
  292. float *gaik = (float *)GET_BLAS_PTR(descr[0]) + args->dima;
  293. float *gb = (float *)GET_BLAS_PTR(descr[1]);
  294. unsigned strideb = (unsigned)GET_BLAS_LD(descr[1]);
  295. float *gc = (float *)GET_BLAS_PTR(descr[2]);
  296. unsigned stridec = (unsigned)GET_BLAS_LD(descr[2]);
  297. cublas_sgemm++;
  298. cublasSgemm('N','T', args->dimi, args->dimj, args->dima,
  299. 1.0, gaik, args->stride,
  300. gb, strideb,
  301. 0.0, gc, stridec);
  302. cublasStatus st = cublasGetError();
  303. if (st) fprintf(stderr, "ERROR %d\n", st);
  304. STARPU_ASSERT(st == CUBLAS_STATUS_SUCCESS);
  305. }
  306. static struct starpu_perfmodel_t starpu_compute_contrib_compact = {
  307. .per_arch = {
  308. [STARPU_CORE_DEFAULT] = { .cost_model = starpu_compute_contrib_compact_core_cost },
  309. [STARPU_CUDA_DEFAULT] = { .cost_model = starpu_compute_contrib_compact_cuda_cost }
  310. },
  311. // .type = REGRESSION_BASED,
  312. .type = PER_ARCH,
  313. .symbol = "starpu_compute_contrib_compact"
  314. };
  315. int update_work_blocks(unsigned col, int dimi, int dimj, int dima, int stride)
  316. {
  317. /* be paranoid XXX */
  318. notify_data_modification(get_sub_data(&work_block_1, 1, 0), 0);
  319. notify_data_modification(get_sub_data(&work_block_1, 1, 1), 0);
  320. //notify_data_modification(get_sub_data(&work_block_1, 1, 2), 0);
  321. notify_data_modification(get_sub_data(&work_block_2, 1, 0), 0);
  322. notify_data_modification(get_sub_data(&work_block_2, 1, 1), 0);
  323. notify_data_modification(&cblktab[col], 0);
  324. starpu_unpartition_data(&work_block_1, 0);
  325. starpu_unpartition_data(&work_block_2, 0);
  326. starpu_filter f1, f2;
  327. struct divide_vector_in_blas_filter_args args1, args2;
  328. f1.filter_func = divide_vector_in_blas_filter;
  329. args1.n1 = stride - dima - dimi; //STARPU_ASSERT(args1.n1 != 0);
  330. args1.n2 = (stride - dima)*dima;
  331. args1.stride = (stride - dima);
  332. f1.filter_arg_ptr = &args1;
  333. starpu_partition_data(&work_block_1, &f1);
  334. f2.filter_func = divide_vector_in_blas_filter;
  335. args2.n1 = 0;
  336. args2.n2 = dimi*dimj;
  337. args2.stride = dimi;
  338. f2.filter_arg_ptr = &args2;
  339. starpu_partition_data(&work_block_2, &f2);
  340. return (args1.n1!=0)?3:2;
  341. }
  342. void STARPU_COMPUTE_CONTRIB_COMPACT(unsigned col, int dimi, int dimj, int dima, int stride)
  343. {
  344. // CUBLAS_SGEMM("N","T",dimi,dimj,dima, 1.0,gaik,stride,gb,stride-dima,
  345. // 0.0 ,gc,dimi);
  346. struct starpu_compute_contrib_compact_args args;
  347. args.stride = stride;
  348. args.dimi = dimi;
  349. args.dimj = dimj;
  350. args.dima = dima;
  351. starpu_codelet cl;
  352. job_t j;
  353. sem_t sem;
  354. /* initialize codelet */
  355. cl.where = CUDA|CORE;
  356. cl.core_func = _core_compute_contrib_compact;
  357. cl.cuda_func = _cublas_compute_contrib_compact;
  358. j = _starpu_job_create();
  359. j->cb = _cublas_cblk_strsm_callback;
  360. j->argcb = &sem;
  361. j->cl = &cl;
  362. j->cl_arg = &args;
  363. j->cl->model = &starpu_compute_contrib_compact;
  364. int ret;
  365. ret = update_work_blocks(col, dimi, dimj, dima, stride);
  366. j->nbuffers = 3;
  367. /* we could be a little more precise actually */
  368. j->buffers[0].handle = &cblktab[col]; // gaik
  369. j->buffers[0].mode = STARPU_R;
  370. j->buffers[1].handle = get_sub_data(&work_block_1, 1, (ret==2)?0:1);
  371. j->buffers[1].mode = STARPU_R;
  372. j->buffers[2].handle = get_sub_data(&work_block_2, 1, 0);;
  373. j->buffers[2].mode = STARPU_RW; // XXX STARPU_W
  374. sem_init(&sem, 0, 0U);
  375. /* submit the codelet */
  376. submit_job(j);
  377. /* wait for its completion */
  378. sem_wait(&sem);
  379. sem_destroy(&sem);
  380. }
  381. /*
  382. *
  383. * SGEMM
  384. *
  385. */
  /* Codelet argument of _cublas_sgemm(): the scalar parameters of an sgemm call. */
  struct sgemm_args {
      char transa;    /* transpose flag of A, as passed to cublasSgemm */
      char transb;    /* transpose flag of B, as passed to cublasSgemm */
      int m;
      int n;
      int k;
      float alpha;
      float beta;
  };
  393. void _cublas_sgemm(void *descr[], void *arg)
  394. {
  395. float *A, *B, *C;
  396. uint32_t nxA, nyA, ldA;
  397. uint32_t nxB, nyB, ldB;
  398. uint32_t nxC, nyC, ldC;
  399. A = (float *)GET_BLAS_PTR(descr[0]);
  400. nxA = GET_BLAS_NX(descr[0]);
  401. nyA = GET_BLAS_NY(descr[0]);
  402. ldA = GET_BLAS_LD(descr[0]);
  403. B = (float *)GET_BLAS_PTR(descr[1]);
  404. nxB = GET_BLAS_NX(descr[1]);
  405. nyB = GET_BLAS_NY(descr[1]);
  406. ldB = GET_BLAS_LD(descr[1]);
  407. C = (float *)GET_BLAS_PTR(descr[2]);
  408. nxC = GET_BLAS_NX(descr[2]);
  409. nyC = GET_BLAS_NY(descr[2]);
  410. ldC = GET_BLAS_LD(descr[2]);
  411. struct sgemm_args *args = arg;
  412. // fprintf(stderr, "CUBLAS SGEMM nxA %d nyA %d nxB %d nyB %d nxC %d nyC %d lda %d ldb %d ldc %d\n", nxA, nyA, nxB, nyB, nxC, nyC, ldA, ldB, ldC);
  413. // STARPU_ASSERT(nxA == nxC);
  414. // STARPU_ASSERT(nyA == nxB);
  415. // STARPU_ASSERT(nyB == nyC);
  416. //
  417. // STARPU_ASSERT(nxA <= ldA);
  418. // STARPU_ASSERT(nxB <= ldB);
  419. // STARPU_ASSERT(nxC <= ldC);
  420. cublasSgemm (args->transa, args->transb, args->m, args->n, args->k, args->alpha, A, (int)ldA,
  421. B, (int)ldB, args->beta, C, (int)ldC);
  422. cublasStatus st = cublasGetError();
  423. if (st) fprintf(stderr, "ERROR %d\n", st);
  424. STARPU_ASSERT(st == CUBLAS_STATUS_SUCCESS);
  425. }
  /* Job completion callback: `sem` points to the sem_t the submitter waits on. */
  static void _cublas_sgemm_callback(void *sem)
  {
      sem_post((sem_t *)sem);
  }
  431. void STARPU_SGEMM (const char *transa, const char *transb, const int m,
  432. const int n, const int k, const float alpha,
  433. const float *A, const int lda, const float *B,
  434. const int ldb, const float beta, float *C, const int ldc)
  435. {
  436. struct sgemm_args args;
  437. args.transa = *transa;
  438. args.transb = *transb;
  439. args.alpha = alpha;
  440. args.beta = beta;
  441. args.m = m;
  442. args.n = n;
  443. args.k = k;
  444. data_state A_state;
  445. data_state B_state;
  446. data_state C_state;
  447. starpu_codelet cl;
  448. job_t j;
  449. sem_t sem;
  450. // fprintf(stderr, "STARPU - SGEMM - TRANSA %c TRANSB %c m %d n %d k %d lda %d ldb %d ldc %d \n", *transa, *transb, m, n, k, lda, ldb, ldc);
  451. if (toupper(*transa) == 'N')
  452. {
  453. starpu_register_blas_data(&A_state, 0, (uintptr_t)A, lda, m, k, sizeof(float));
  454. }
  455. else
  456. {
  457. starpu_register_blas_data(&A_state, 0, (uintptr_t)A, lda, k, m, sizeof(float));
  458. }
  459. if (toupper(*transb) == 'N')
  460. {
  461. starpu_register_blas_data(&B_state, 0, (uintptr_t)B, ldb, k, n, sizeof(float));
  462. }
  463. else
  464. {
  465. starpu_register_blas_data(&B_state, 0, (uintptr_t)B, ldb, n, k, sizeof(float));
  466. }
  467. starpu_register_blas_data(&C_state, 0, (uintptr_t)C, ldc, m, n, sizeof(float));
  468. /* initialize codelet */
  469. cl.where = CUDA;
  470. //cl.core_func = _core_strsm;
  471. cl.cuda_func = _cublas_sgemm;
  472. j = _starpu_job_create();
  473. j->cb = _cublas_sgemm_callback;
  474. j->argcb = &sem;
  475. j->cl = &cl;
  476. j->cl_arg = &args;
  477. j->nbuffers = 3;
  478. j->buffers[0].handle = &A_state;
  479. j->buffers[0].mode = STARPU_R;
  480. j->buffers[1].handle = &B_state;
  481. j->buffers[1].mode = STARPU_R;
  482. j->buffers[2].handle = &C_state;
  483. j->buffers[2].mode = STARPU_RW;
  484. j->cl->model = NULL;
  485. sem_init(&sem, 0, 0U);
  486. /* submit the codelet */
  487. submit_job(j);
  488. /* wait for its completion */
  489. sem_wait(&sem);
  490. sem_destroy(&sem);
  491. /* make sure data are in memory again */
  492. starpu_unpartition_data(&A_state, 0);
  493. starpu_unpartition_data(&B_state, 0);
  494. starpu_unpartition_data(&C_state, 0);
  495. //starpu_delete_data(&A_state);
  496. //starpu_delete_data(&B_state);
  497. //starpu_delete_data(&C_state);
  498. // fprintf(stderr, "SGEMM done\n");
  499. }
  500. /*
  501. *
  502. * STRSM
  503. *
  504. */
  /* Scalar parameters of a strsm call; not referenced by any active code in this file. */
  struct strsm_args {
      char side;
      char uplo;
      char transa;
      char diag;
      float alpha;
      int m;
      int n;
  };
  513. //
  514. //void _core_strsm(void *descr[], void *arg)
  515. //{
  516. // float *A, *B;
  517. // uint32_t nxA, nyA, ldA;
  518. // uint32_t nxB, nyB, ldB;
  519. //
  520. // A = (float *)GET_BLAS_PTR(descr[0]);
  521. // nxA = GET_BLAS_NX(descr[0]);
  522. // nyA = GET_BLAS_NY(descr[0]);
  523. // ldA = GET_BLAS_LD(descr[0]);
  524. //
  525. // B = (float *)GET_BLAS_PTR(descr[1]);
  526. // nxB = GET_BLAS_NX(descr[1]);
  527. // nyB = GET_BLAS_NY(descr[1]);
  528. // ldB = GET_BLAS_LD(descr[1]);
  529. //
  530. // struct strsm_args *args = arg;
  531. //
  532. // fprintf(stderr, "CORE STRSM nxA %d nyA %d nxB %d nyB %d lda %d ldb %d\n", nxA, nyA, nxB, nyB, ldA, ldB);
  533. //
  534. // SOPALIN_TRSM("R","L","T","U",dimb,dima,fun,ga,stride,gb,stride);
  535. //
  536. //}
  537. /*
  538. *
  539. *
  540. *
  541. */
  542. void CUBLAS_SGEMM (const char *transa, const char *transb, const int m,
  543. const int n, const int k, const float alpha,
  544. const float *A, const int lda, const float *B,
  545. const int ldb, const float beta, float *C, const int ldc)
  546. {
  547. int ka, kb;
  548. float *devPtrA, *devPtrB, *devPtrC;
  549. // printf("CUBLAS SGEMM : m %d n %d k %d lda %d ldb %d ldc %d\n", m, n, k, lda, ldb, ldc);
  550. /* A - REAL array of DIMENSION ( LDA, ka ), where ka is
  551. * k when TRANSA = 'N' or 'n', and is m otherwise.
  552. * Before entry with TRANSA = 'N' or 'n', the leading m by k
  553. * part of the array A must contain the matrix A, otherwise
  554. * the leading k by m part of the array A must contain the
  555. * matrix A.
  556. */
  557. ka = (toupper(transa[0]) == 'N') ? k : m;
  558. cublasAlloc (lda * ka, sizeof(devPtrA[0]), (void**)&devPtrA);
  559. if (toupper(transa[0]) == 'N') {
  560. cublasSetMatrix (STARPU_MIN(m,lda), k, sizeof(A[0]), A, lda, devPtrA,
  561. lda);
  562. } else {
  563. cublasSetMatrix (STARPU_MIN(k,lda), m, sizeof(A[0]), A, lda, devPtrA,
  564. lda);
  565. }
  566. /* B - REAL array of DIMENSION ( LDB, kb ), where kb is
  567. * n when TRANSB = 'N' or 'n', and is k otherwise.
  568. * Before entry with TRANSB = 'N' or 'n', the leading k by n
  569. * part of the array B must contain the matrix B, otherwise
  570. * the leading n by k part of the array B must contain the
  571. * matrix B.
  572. */
  573. kb = (toupper(transb[0]) == 'N') ? n : k;
  574. cublasAlloc (ldb * kb, sizeof(devPtrB[0]), (void**)&devPtrB);
  575. if (toupper(transb[0]) == 'N') {
  576. cublasSetMatrix (STARPU_MIN(k,ldb), n, sizeof(B[0]), B, ldb, devPtrB,
  577. ldb);
  578. } else {
  579. cublasSetMatrix (STARPU_MIN(n,ldb), k, sizeof(B[0]), B, ldb, devPtrB,
  580. ldb);
  581. }
  582. /* C - REAL array of DIMENSION ( LDC, n ).
  583. * Before entry, the leading m by n part of the array C must
  584. * contain the matrix C, except when beta is zero, in which
  585. * case C need not be set on entry.
  586. * On exit, the array C is overwritten by the m by n matrix
  587. */
  588. cublasAlloc ((ldc) * (n), sizeof(devPtrC[0]), (void**)&devPtrC);
  589. cublasSetMatrix (STARPU_MIN(m,ldc), n, sizeof(C[0]), C, ldc, devPtrC, ldc);
  590. cublasSgemm (transa[0], transb[0], m, n, k, alpha, devPtrA, lda,
  591. devPtrB, ldb, beta, devPtrC, ldc);
  592. cublasGetMatrix (STARPU_MIN(m,ldc), n, sizeof(C[0]), devPtrC, ldc, C, ldc);
  593. cublasFree (devPtrA);
  594. cublasFree (devPtrB);
  595. cublasFree (devPtrC);
  596. }