strassen.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include "strassen.h"
  17. #include "strassen_models.h"
  18. static starpu_data_handle create_tmp_matrix(starpu_data_handle M)
  19. {
  20. float *data;
  21. starpu_data_handle state = malloc(sizeof(starpu_data_handle));
  22. /* create a matrix with the same dimensions as M */
  23. uint32_t nx = starpu_matrix_get_nx(M);
  24. uint32_t ny = starpu_matrix_get_nx(M);
  25. STARPU_ASSERT(state);
  26. data = malloc(nx*ny*sizeof(float));
  27. STARPU_ASSERT(data);
  28. starpu_matrix_data_register(&state, 0, (uintptr_t)data, nx, nx, ny, sizeof(float));
  29. return state;
  30. }
  31. static void free_tmp_matrix(starpu_data_handle matrix)
  32. {
  33. starpu_data_unregister(matrix);
  34. free(matrix);
  35. }
  36. static void partition_matrices(strassen_iter_state_t *iter)
  37. {
  38. starpu_data_handle A = iter->A;
  39. starpu_data_handle B = iter->B;
  40. starpu_data_handle C = iter->C;
  41. starpu_filter f;
  42. f.filter_func = starpu_block_filter_func;
  43. f.filter_arg = 2;
  44. starpu_filter f2;
  45. f2.filter_func = starpu_vertical_block_filter_func;
  46. f2.filter_arg = 2;
  47. starpu_map_filters(A, 2, &f, &f2);
  48. starpu_map_filters(B, 2, &f, &f2);
  49. starpu_map_filters(C, 2, &f, &f2);
  50. iter->A11 = starpu_data_get_sub_data(A, 2, 0, 0);
  51. iter->A12 = starpu_data_get_sub_data(A, 2, 1, 0);
  52. iter->A21 = starpu_data_get_sub_data(A, 2, 0, 1);
  53. iter->A22 = starpu_data_get_sub_data(A, 2, 1, 1);
  54. iter->B11 = starpu_data_get_sub_data(B, 2, 0, 0);
  55. iter->B12 = starpu_data_get_sub_data(B, 2, 1, 0);
  56. iter->B21 = starpu_data_get_sub_data(B, 2, 0, 1);
  57. iter->B22 = starpu_data_get_sub_data(B, 2, 1, 1);
  58. iter->C11 = starpu_data_get_sub_data(C, 2, 0, 0);
  59. iter->C12 = starpu_data_get_sub_data(C, 2, 1, 0);
  60. iter->C21 = starpu_data_get_sub_data(C, 2, 0, 1);
  61. iter->C22 = starpu_data_get_sub_data(C, 2, 1, 1);
  62. /* TODO check that all sub-matrices have the same size */
  63. }
  64. static void unpartition_matrices(strassen_iter_state_t *iter)
  65. {
  66. /* TODO there is no need to actually gather those results ... */
  67. starpu_data_unpartition(iter->A, 0);
  68. starpu_data_unpartition(iter->B, 0);
  69. starpu_data_unpartition(iter->C, 0);
  70. }
  71. static starpu_codelet cl_add = {
  72. .where = STARPU_CPU|STARPU_CUDA,
  73. .model = &strassen_model_add_sub,
  74. .cpu_func = add_cpu_codelet,
  75. #ifdef STARPU_USE_CUDA
  76. .cuda_func = add_cublas_codelet,
  77. #endif
  78. .nbuffers = 3
  79. };
  80. static starpu_codelet cl_sub = {
  81. .where = STARPU_CPU|STARPU_CUDA,
  82. .model = &strassen_model_add_sub,
  83. .cpu_func = sub_cpu_codelet,
  84. #ifdef STARPU_USE_CUDA
  85. .cuda_func = sub_cublas_codelet,
  86. #endif
  87. .nbuffers = 3
  88. };
  89. static starpu_codelet cl_mult = {
  90. .where = STARPU_CPU|STARPU_CUDA,
  91. .model = &strassen_model_mult,
  92. .cpu_func = mult_cpu_codelet,
  93. #ifdef STARPU_USE_CUDA
  94. .cuda_func = mult_cublas_codelet,
  95. #endif
  96. .nbuffers = 3
  97. };
  98. static starpu_codelet cl_self_add = {
  99. .where = STARPU_CPU|STARPU_CUDA,
  100. .model = &strassen_model_self_add_sub,
  101. .cpu_func = self_add_cpu_codelet,
  102. #ifdef STARPU_USE_CUDA
  103. .cuda_func = self_add_cublas_codelet,
  104. #endif
  105. .nbuffers = 2
  106. };
  107. static starpu_codelet cl_self_sub = {
  108. .where = STARPU_CPU|STARPU_CUDA,
  109. .model = &strassen_model_self_add_sub,
  110. .cpu_func = self_sub_cpu_codelet,
  111. #ifdef STARPU_USE_CUDA
  112. .cuda_func = self_sub_cublas_codelet,
  113. #endif
  114. .nbuffers = 2
  115. };
  116. static void compute_add_sub_op(starpu_data_handle A1, operation op,
  117. starpu_data_handle A2, starpu_data_handle C,
  118. void (*callback)(void *), void *argcallback)
  119. {
  120. /* performs C = (A op B) */
  121. struct starpu_task *task = starpu_task_create();
  122. task->cl_arg = NULL;
  123. task->use_tag = 0;
  124. task->buffers[0].handle = C;
  125. task->buffers[0].mode = STARPU_W;
  126. task->buffers[1].handle = A1;
  127. task->buffers[1].mode = STARPU_R;
  128. task->buffers[2].handle = A2;
  129. task->buffers[2].mode = STARPU_R;
  130. task->callback_func = callback;
  131. task->callback_arg = argcallback;
  132. switch (op) {
  133. case ADD:
  134. STARPU_ASSERT(A1);
  135. STARPU_ASSERT(A2);
  136. STARPU_ASSERT(C);
  137. task->cl = &cl_add;
  138. break;
  139. case SUB:
  140. STARPU_ASSERT(A1);
  141. STARPU_ASSERT(A2);
  142. STARPU_ASSERT(C);
  143. task->cl = &cl_sub;
  144. break;
  145. case MULT:
  146. STARPU_ASSERT(A1);
  147. STARPU_ASSERT(A2);
  148. STARPU_ASSERT(C);
  149. task->cl = &cl_mult;
  150. break;
  151. case SELFADD:
  152. task->buffers[0].mode = STARPU_RW;
  153. task->cl = &cl_self_add;
  154. break;
  155. case SELFSUB:
  156. task->buffers[0].mode = STARPU_RW;
  157. task->cl = &cl_self_sub;
  158. break;
  159. default:
  160. STARPU_ABORT();
  161. }
  162. starpu_task_submit(task);
  163. }
  164. /* Cij +=/-= Ek is done */
  165. void phase_3_callback_function(void *_arg)
  166. {
  167. unsigned cnt, use_cnt;
  168. phase3_t *arg = _arg;
  169. unsigned i = arg->i;
  170. strassen_iter_state_t *iter = arg->iter;
  171. free(arg);
  172. use_cnt = STARPU_ATOMIC_ADD(&iter->Ei_remaining_use[i], -1);
  173. if (use_cnt == 0)
  174. {
  175. /* no one needs Ei anymore : free it */
  176. switch (i) {
  177. case 0:
  178. free_tmp_matrix(iter->E1);
  179. break;
  180. case 1:
  181. free_tmp_matrix(iter->E2);
  182. break;
  183. case 2:
  184. free_tmp_matrix(iter->E3);
  185. break;
  186. case 3:
  187. free_tmp_matrix(iter->E4);
  188. break;
  189. case 4:
  190. free_tmp_matrix(iter->E5);
  191. break;
  192. case 5:
  193. free_tmp_matrix(iter->E6);
  194. break;
  195. case 6:
  196. free_tmp_matrix(iter->E7);
  197. break;
  198. default:
  199. STARPU_ABORT();
  200. }
  201. }
  202. cnt = STARPU_ATOMIC_ADD(&iter->counter, -1);
  203. if (cnt == 0)
  204. {
  205. /* the entire strassen iteration is done ! */
  206. unpartition_matrices(iter);
  207. // XXX free the Ei
  208. STARPU_ASSERT(iter->strassen_iter_callback);
  209. iter->strassen_iter_callback(iter->argcb);
  210. free(iter);
  211. }
  212. }
  213. /* Ei is computed */
  214. void phase_2_callback_function(void *_arg)
  215. {
  216. phase2_t *arg = _arg;
  217. strassen_iter_state_t *iter = arg->iter;
  218. unsigned i = arg->i;
  219. free(arg);
  220. phase3_t *arg1, *arg2;
  221. arg1 = malloc(sizeof(phase3_t));
  222. arg2 = malloc(sizeof(phase3_t));
  223. arg1->iter = iter;
  224. arg2->iter = iter;
  225. arg1->i = i;
  226. arg2->i = i;
  227. switch (i) {
  228. case 0:
  229. free(arg2); // will not be needed ..
  230. free_tmp_matrix(iter->E11);
  231. free_tmp_matrix(iter->E12);
  232. /* C11 += E1 */
  233. compute_add_sub_op(iter->E1, SELFADD, NULL, iter->C11, phase_3_callback_function, arg1);
  234. break;
  235. case 1:
  236. free_tmp_matrix(iter->E21);
  237. free_tmp_matrix(iter->E22);
  238. /* C11 += E2 */
  239. compute_add_sub_op(iter->E2, SELFADD, NULL, iter->C11, phase_3_callback_function, arg1);
  240. /* C22 += E2 */
  241. compute_add_sub_op(iter->E2, SELFADD, NULL, iter->C22, phase_3_callback_function, arg2);
  242. break;
  243. case 2:
  244. free(arg2); // will not be needed ..
  245. free_tmp_matrix(iter->E31);
  246. free_tmp_matrix(iter->E32);
  247. /* C22 -= E3 */
  248. compute_add_sub_op(iter->E3, SELFSUB, NULL, iter->C22, phase_3_callback_function, arg1);
  249. break;
  250. case 3:
  251. free_tmp_matrix(iter->E41);
  252. /* C11 -= E4 */
  253. compute_add_sub_op(iter->E4, SELFSUB, NULL, iter->C11, phase_3_callback_function, arg1);
  254. /* C12 += E4 */
  255. compute_add_sub_op(iter->E4, SELFADD, NULL, iter->C12, phase_3_callback_function, arg2);
  256. break;
  257. case 4:
  258. free_tmp_matrix(iter->E52);
  259. /* C12 += E5 */
  260. compute_add_sub_op(iter->E5, SELFADD, NULL, iter->C12, phase_3_callback_function, arg1);
  261. /* C22 += E5 */
  262. compute_add_sub_op(iter->E5, SELFADD, NULL, iter->C22, phase_3_callback_function, arg2);
  263. break;
  264. case 5:
  265. free_tmp_matrix(iter->E62);
  266. /* C11 += E6 */
  267. compute_add_sub_op(iter->E6, SELFADD, NULL, iter->C11, phase_3_callback_function, arg1);
  268. /* C21 += E6 */
  269. compute_add_sub_op(iter->E6, SELFADD, NULL, iter->C21, phase_3_callback_function, arg2);
  270. break;
  271. case 6:
  272. free_tmp_matrix(iter->E71);
  273. /* C21 += E7 */
  274. compute_add_sub_op(iter->E7, SELFADD, NULL, iter->C21, phase_3_callback_function, arg1);
  275. /* C22 -= E7 */
  276. compute_add_sub_op(iter->E7, SELFSUB, NULL, iter->C22, phase_3_callback_function, arg2);
  277. break;
  278. default:
  279. STARPU_ABORT();
  280. }
  281. }
  282. /* computes Ei */
  283. static void _strassen_phase_2(strassen_iter_state_t *iter, unsigned i)
  284. {
  285. phase2_t *phase_2_arg = malloc(sizeof(phase2_t));
  286. phase_2_arg->iter = iter;
  287. phase_2_arg->i = i;
  288. /* XXX */
  289. starpu_data_handle A;
  290. starpu_data_handle B;
  291. starpu_data_handle C;
  292. switch (i) {
  293. case 0:
  294. A = iter->E11; B = iter->E12;
  295. iter->E1 = create_tmp_matrix(A);
  296. C = iter->E1;
  297. break;
  298. case 1:
  299. A = iter->E21; B = iter->E22;
  300. iter->E2 = create_tmp_matrix(A);
  301. C = iter->E2;
  302. break;
  303. case 2:
  304. A = iter->E31; B = iter->E32;
  305. iter->E3 = create_tmp_matrix(A);
  306. C = iter->E3;
  307. break;
  308. case 3:
  309. A = iter->E41; B = iter->E42;
  310. iter->E4 = create_tmp_matrix(A);
  311. C = iter->E4;
  312. break;
  313. case 4:
  314. A = iter->E51; B = iter->E52;
  315. iter->E5 = create_tmp_matrix(A);
  316. C = iter->E5;
  317. break;
  318. case 5:
  319. A = iter->E61; B = iter->E62;
  320. iter->E6 = create_tmp_matrix(A);
  321. C = iter->E6;
  322. break;
  323. case 6:
  324. A = iter->E71; B = iter->E72;
  325. iter->E7 = create_tmp_matrix(A);
  326. C = iter->E7;
  327. break;
  328. default:
  329. STARPU_ABORT();
  330. }
  331. STARPU_ASSERT(A);
  332. STARPU_ASSERT(B);
  333. STARPU_ASSERT(C);
  334. // DEBUG XXX
  335. //compute_add_sub_op(A, MULT, B, C, phase_2_callback_function, phase_2_arg);
  336. strassen(A, B, C, phase_2_callback_function, phase_2_arg, iter->reclevel-1);
  337. }
  338. #define THRESHHOLD 128
  339. static void phase_1_callback_function(void *_arg)
  340. {
  341. phase1_t *arg = _arg;
  342. strassen_iter_state_t *iter = arg->iter;
  343. unsigned i = arg->i;
  344. free(arg);
  345. unsigned cnt = STARPU_ATOMIC_ADD(&iter->Ei12[i], +1);
  346. if (cnt == 2) {
  347. /* Ei1 and Ei2 are ready, compute Ei */
  348. _strassen_phase_2(iter, i);
  349. }
  350. }
  351. /* computes Ei1 or Ei2 with i in 0-6 */
  352. static void _strassen_phase_1(starpu_data_handle A1, operation opA, starpu_data_handle A2,
  353. starpu_data_handle C, strassen_iter_state_t *iter, unsigned i)
  354. {
  355. phase1_t *phase_1_arg = malloc(sizeof(phase1_t));
  356. phase_1_arg->iter = iter;
  357. phase_1_arg->i = i;
  358. compute_add_sub_op(A1, opA, A2, C, phase_1_callback_function, phase_1_arg);
  359. }
  360. strassen_iter_state_t *init_strassen_iter_state(starpu_data_handle A, starpu_data_handle B, starpu_data_handle C, void (*strassen_iter_callback)(void *), void *argcb)
  361. {
  362. strassen_iter_state_t *iter_state = malloc(sizeof(strassen_iter_state_t));
  363. iter_state->Ei12[0] = 0;
  364. iter_state->Ei12[1] = 0;
  365. iter_state->Ei12[2] = 0;
  366. iter_state->Ei12[3] = 1; // E42 = B22
  367. iter_state->Ei12[4] = 1; // E51 = A11
  368. iter_state->Ei12[5] = 1; // E61 = A22
  369. iter_state->Ei12[6] = 1; // E72 = B11
  370. iter_state->Ei_remaining_use[0] = 1;
  371. iter_state->Ei_remaining_use[1] = 2;
  372. iter_state->Ei_remaining_use[2] = 1;
  373. iter_state->Ei_remaining_use[3] = 2;
  374. iter_state->Ei_remaining_use[4] = 2;
  375. iter_state->Ei_remaining_use[5] = 2;
  376. iter_state->Ei_remaining_use[6] = 2;
  377. unsigned i;
  378. for (i = 0; i < 6; i++)
  379. {
  380. iter_state->Ei[i] = 0;
  381. }
  382. for (i = 0; i < 4; i++)
  383. {
  384. iter_state->Cij[i] = 0;
  385. }
  386. iter_state->strassen_iter_callback = strassen_iter_callback;
  387. iter_state->argcb = argcb;
  388. iter_state->A = A;
  389. iter_state->B = B;
  390. iter_state->C = C;
  391. iter_state->counter = 12;
  392. return iter_state;
  393. }
  394. static void _do_strassen(starpu_data_handle A, starpu_data_handle B, starpu_data_handle C, void (*strassen_iter_callback)(void *), void *argcb, unsigned reclevel)
  395. {
  396. /* do one level of recursion in the strassen algorithm */
  397. strassen_iter_state_t *iter = init_strassen_iter_state(A, B, C, strassen_iter_callback, argcb);
  398. partition_matrices(iter);
  399. iter->reclevel = reclevel;
  400. /* some Eij are already known */
  401. iter->E11 = create_tmp_matrix(iter->A11);
  402. iter->E12 = create_tmp_matrix(iter->B21);
  403. iter->E21 = create_tmp_matrix(iter->A11);
  404. iter->E22 = create_tmp_matrix(iter->B11);
  405. iter->E31 = create_tmp_matrix(iter->A11);
  406. iter->E32 = create_tmp_matrix(iter->B11);
  407. iter->E41 = create_tmp_matrix(iter->A11);
  408. iter->E42 = iter->B22;
  409. iter->E51 = iter->A11;
  410. iter->E52 = create_tmp_matrix(iter->B12);
  411. iter->E61 = iter->A22;
  412. iter->E62 = create_tmp_matrix(iter->B21);
  413. iter->E71 = create_tmp_matrix(iter->A21);
  414. iter->E72 = iter->B11;
  415. /* compute all Eij */
  416. _strassen_phase_1(iter->A11, SUB, iter->A22, iter->E11, iter, 0);
  417. _strassen_phase_1(iter->B21, ADD, iter->B22, iter->E12, iter, 0);
  418. _strassen_phase_1(iter->A11, ADD, iter->A22, iter->E21, iter, 1);
  419. _strassen_phase_1(iter->B11, ADD, iter->B22, iter->E22, iter, 1);
  420. _strassen_phase_1(iter->A11, SUB, iter->A21, iter->E31, iter, 2);
  421. _strassen_phase_1(iter->B11, ADD, iter->B12, iter->E32, iter, 2);
  422. _strassen_phase_1(iter->A11, ADD, iter->A12, iter->E41, iter, 3);
  423. _strassen_phase_1(iter->B12, SUB, iter->B22, iter->E52, iter, 4);
  424. _strassen_phase_1(iter->B21, SUB, iter->B11, iter->E62, iter, 5);
  425. _strassen_phase_1(iter->A21, ADD, iter->A22, iter->E71, iter, 6);
  426. }
  427. void strassen(starpu_data_handle A, starpu_data_handle B, starpu_data_handle C, void (*callback)(void *), void *argcb, unsigned reclevel)
  428. {
  429. /* C = A * B */
  430. if ( reclevel == 0 )
  431. {
  432. /* don't use Strassen but a simple sequential multiplication
  433. * provided this is small enough */
  434. compute_add_sub_op(A, MULT, B, C, callback, argcb);
  435. }
  436. else {
  437. _do_strassen(A, B, C, callback, argcb, reclevel);
  438. }
  439. }