strassen2.c

/*
 * StarPU
 * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2.1 of the License, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 *
 * See the GNU Lesser General Public License in COPYING.LGPL for more details.
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <assert.h>
#include <math.h>
#include <sys/types.h>
#include <sys/time.h>
#include <pthread.h>
#include <signal.h>
#include <starpu.h>

#define MAXDEPS 4

uint64_t current_tag = 1024;
uint64_t used_mem = 0;
/*
 * Strassen:
 *
 *   M1 = (A11 + A22)(B11 + B22)
 *   M2 = (A21 + A22)B11
 *   M3 = A11(B12 - B22)
 *   M4 = A22(B21 - B11)
 *   M5 = (A11 + A12)B22
 *   M6 = (A21 - A11)(B11 + B12)
 *   M7 = (A12 - A22)(B21 + B22)
 *
 *   C11 = M1 + M4 - M5 + M7
 *   C12 = M3 + M5
 *   C21 = M2 + M4
 *   C22 = M1 - M2 + M3 + M6
 *
 * 7 recursive calls to the Strassen algorithm (one per Mi computation)
 * 10 + 7 temporary buffers (to compute the terms of Mi = Mia x Mib, and to store the Mi)
 *
 * Complexity:
 *   M(n): multiplication complexity
 *   A(n): add/sub complexity
 *   M(n) = (10 + 8) A(n/2) + 7 M(n/2)
 *
 * NB: we assume Fortran (column-major) ordering, hence we compute
 * M3t = (B12t - B22t) A11t, for instance.
 */
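
/*
 * Worked example (illustrative): with the implementation below, each recursion
 * node with reclevel > 0 submits 10 add/sub tasks to build the Mia/Mib terms,
 * 12 in-place accumulations into C11/C12/C21/C22, and 1 cleanup task, while
 * each leaf (reclevel == 0) submits a single block multiplication. For
 * reclevel = 3 this gives 1 + 7 + 49 = 57 inner nodes and 7^3 = 343 leaf
 * multiplications, i.e. 57 * 23 + 343 = 1654 tasks (plus the two dummy
 * start/end tasks), each leaf GEMM operating on blocks of
 * (size / 2^3) x (size / 2^3) elements.
 */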
static unsigned size = 2048;
static unsigned reclevel = 3;
static unsigned norandom = 0;
static unsigned pin = 0;

extern void mult_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
extern void sub_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
extern void add_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
extern void self_add_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
extern void self_sub_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);

#ifdef USE_CUDA
extern void mult_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
extern void sub_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
extern void add_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
extern void self_add_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
extern void self_sub_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg);
#endif

extern void null_codelet(__attribute__((unused)) starpu_data_interface_t *descr,
			__attribute__((unused)) void *arg);

extern void display_perf(double timing, unsigned size);

struct starpu_perfmodel_t strassen_model_mult = {
	.type = HISTORY_BASED,
	.symbol = "strassen_model_mult"
};

struct starpu_perfmodel_t strassen_model_add = {
	.type = HISTORY_BASED,
	.symbol = "strassen_model_add"
};

struct starpu_perfmodel_t strassen_model_sub = {
	.type = HISTORY_BASED,
	.symbol = "strassen_model_sub"
};

struct starpu_perfmodel_t strassen_model_self_add = {
	.type = HISTORY_BASED,
	.symbol = "strassen_model_self_add"
};

struct starpu_perfmodel_t strassen_model_self_sub = {
	.type = HISTORY_BASED,
	.symbol = "strassen_model_self_sub"
};

struct data_deps_t {
	unsigned ndeps;
	starpu_tag_t deps[MAXDEPS];
};

struct strassen_iter {
	unsigned reclevel;
	struct strassen_iter *children[7];

	starpu_data_handle A, B, C;

	/* temporary buffers: Mi = Mia * Mib */
	starpu_data_handle Mia_data[7];
	starpu_data_handle Mib_data[7];
	starpu_data_handle Mi_data[7];

	/* input deps */
	struct data_deps_t A_deps;
	struct data_deps_t B_deps;

	/* output deps */
	struct data_deps_t C_deps;
};

static starpu_filter f = {
	.filter_func = starpu_block_filter_func,
	.filter_arg = 2
};

static starpu_filter f2 = {
	.filter_func = starpu_vertical_block_filter_func,
	.filter_arg = 2
};
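
/*
 * Note: mapping the two filters f and f2 (each splitting in 2) on a handle
 * yields a 2x2 grid of sub-blocks that can be retrieved with
 * get_sub_data(handle, 2, x, y); the code maps them reclevel times to build a
 * filter tree of depth reclevel, so each recursion level can address its own
 * 2x2 sub-blocks.
 */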
starpu_data_handle allocate_tmp_matrix(unsigned size, unsigned reclevel)
{
	starpu_data_handle data;
	float *buffer;

#ifdef USE_CUDA
	if (pin) {
		starpu_malloc_pinned_if_possible(&buffer, size*size*sizeof(float));
	} else
#endif
	{
#ifdef HAVE_POSIX_MEMALIGN
		posix_memalign((void **)&buffer, 4096, size*size*sizeof(float));
#else
		buffer = malloc(size*size*sizeof(float));
#endif
	}

	assert(buffer);

	used_mem += size*size*sizeof(float);

	memset(buffer, 0, size*size*sizeof(float));

	starpu_monitor_blas_data(&data, 0, (uintptr_t)buffer, size, size, size, sizeof(float));

	/* we construct a starpu_filter tree of depth reclevel */
	unsigned rec;
	for (rec = 0; rec < reclevel; rec++)
		starpu_map_filters(data, 2, &f, &f2);

	return data;
}
enum operation {
	ADD,
	SUB,
	MULT
};

static starpu_codelet cl_add = {
	.where = ANY,
	.model = &strassen_model_add,
	.core_func = add_core_codelet,
#ifdef USE_CUDA
	.cublas_func = add_cublas_codelet,
#endif
	.nbuffers = 3
};

static starpu_codelet cl_sub = {
	.where = ANY,
	.model = &strassen_model_sub,
	.core_func = sub_core_codelet,
#ifdef USE_CUDA
	.cublas_func = sub_cublas_codelet,
#endif
	.nbuffers = 3
};

static starpu_codelet cl_mult = {
	.where = ANY,
	.model = &strassen_model_mult,
	.core_func = mult_core_codelet,
#ifdef USE_CUDA
	.cublas_func = mult_cublas_codelet,
#endif
	.nbuffers = 3
};
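
/*
 * Note: each codelet above pairs a CPU implementation (core_func) with a
 * CUBLAS one (cublas_func, when built with USE_CUDA) and is attached to a
 * HISTORY_BASED performance model, so the scheduler may run any of the three
 * operations on whichever processing unit it sees fit (.where = ANY).
 */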
/* C = A op B */
struct starpu_task *compute_add_sub_op(starpu_data_handle C, enum operation op,
					starpu_data_handle A, starpu_data_handle B)
{
	struct starpu_task *task = starpu_task_create();
	uint64_t j_tag = current_tag++;

	task->buffers[0].state = C;
	task->buffers[0].mode = W;
	task->buffers[1].state = A;
	task->buffers[1].mode = R;
	task->buffers[2].state = B;
	task->buffers[2].mode = R;

	task->callback_func = NULL;

	switch (op) {
		case ADD:
			task->cl = &cl_add;
			break;
		case SUB:
			task->cl = &cl_sub;
			break;
		case MULT:
			task->cl = &cl_mult;
			break;
		default:
			assert(0);
	}

	task->use_tag = 1;
	task->tag_id = (starpu_tag_t)j_tag;

	return task;
}
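
/*
 * Usage sketch (illustrative): the task is created with a fresh tag but is not
 * submitted here; the caller first declares its tag dependencies and then
 * submits it, as strassen_mult() does below:
 *
 *   struct starpu_task *t = compute_add_sub_op(C, ADD, A, B);
 *   starpu_tag_declare_deps_array(t->tag_id, ndeps, deps);
 *   starpu_submit_task(t);
 */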
static starpu_codelet cl_self_add = {
	.where = ANY,
	.model = &strassen_model_self_add,
	.core_func = self_add_core_codelet,
#ifdef USE_CUDA
	.cublas_func = self_add_cublas_codelet,
#endif
	.nbuffers = 2
};

static starpu_codelet cl_self_sub = {
	.where = ANY,
	.model = &strassen_model_self_sub,
	.core_func = self_sub_core_codelet,
#ifdef USE_CUDA
	.cublas_func = self_sub_cublas_codelet,
#endif
	.nbuffers = 2
};

/* C = C op A */
struct starpu_task *compute_self_add_sub_op(starpu_data_handle C, enum operation op, starpu_data_handle A)
{
	struct starpu_task *task = starpu_task_create();
	uint64_t j_tag = current_tag++;

	task->buffers[0].state = C;
	task->buffers[0].mode = RW;
	task->buffers[1].state = A;
	task->buffers[1].mode = R;

	task->callback_func = NULL;

	switch (op) {
		case ADD:
			task->cl = &cl_self_add;
			break;
		case SUB:
			task->cl = &cl_self_sub;
			break;
		default:
			assert(0);
	}

	task->use_tag = 1;
	task->tag_id = (starpu_tag_t)j_tag;

	return task;
}
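
/*
 * Note: the in-place accumulations rely on C being zero-initialized (both the
 * temporary matrices and the result matrix are memset to 0 before use) and on
 * the tag chains declared in strassen_mult() to serialize the successive
 * updates of the same Cij block.
 */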
struct cleanup_arg {
	unsigned ndeps;
	starpu_tag_t tags[8];
	unsigned ndata;
	starpu_data_handle data[32];
};

void cleanup_callback(void *_arg)
{
	//fprintf(stderr, "cleanup callback\n");
	struct cleanup_arg *arg = _arg;

	unsigned i;
	for (i = 0; i < arg->ndata; i++)
		starpu_advise_if_data_is_important(arg->data[i], 0);

	free(arg);
}

static starpu_codelet cleanup_codelet = {
	.where = ANY,
	.model = NULL,
	.core_func = null_codelet,
#ifdef USE_CUDA
	.cublas_func = null_codelet,
#endif
	.nbuffers = 0
};
/* This creates a task that tells StarPU that the specified data are no longer
 * important once the tasks corresponding to the given tags have completed. */
void create_cleanup_task(struct cleanup_arg *cleanup_arg)
{
	struct starpu_task *task = starpu_task_create();
	uint64_t j_tag = current_tag++;

	task->cl = &cleanup_codelet;
	task->callback_func = cleanup_callback;
	task->callback_arg = cleanup_arg;

	task->use_tag = 1;
	task->tag_id = j_tag;

	starpu_tag_declare_deps_array(j_tag, cleanup_arg->ndeps, cleanup_arg->tags);

	starpu_submit_task(task);
}
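
/*
 * In strassen_mult() below, each recursion node fills a cleanup_arg with the 4
 * tags of its final C11/C12/C21/C22 updates and the 17 temporary handles
 * (10 Mia/Mib terms + 7 Mi results) it allocated, so the advice is only given
 * once the whole sub-product has been consumed.
 */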
void strassen_mult(struct strassen_iter *iter)
{
	if (iter->reclevel == 0)
	{
		struct starpu_task *task_mult =
			compute_add_sub_op(iter->C, MULT, iter->A, iter->B);
		starpu_tag_t tag_mult = task_mult->tag_id;

		starpu_tag_t deps_array[10];

		unsigned indexA, indexB;
		for (indexA = 0; indexA < iter->A_deps.ndeps; indexA++)
		{
			deps_array[indexA] = iter->A_deps.deps[indexA];
		}

		for (indexB = 0; indexB < iter->B_deps.ndeps; indexB++)
		{
			deps_array[indexB + indexA] = iter->B_deps.deps[indexB];
		}

		starpu_tag_declare_deps_array(tag_mult, indexA + indexB, deps_array);

		iter->C_deps.ndeps = 1;
		iter->C_deps.deps[0] = tag_mult;

		starpu_submit_task(task_mult);

		return;
	}
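
	/*
	 * Recursive case: the 2x2 sub-blocks of A, B and C are retrieved from
	 * the filter tree built when the data were registered; following the
	 * Fortran-ordering remark above, get_sub_data(handle, 2, x, y) is
	 * indexed so that, e.g., A12 is (1, 0) and A21 is (0, 1).
	 */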
	starpu_data_handle A11 = get_sub_data(iter->A, 2, 0, 0);
	starpu_data_handle A12 = get_sub_data(iter->A, 2, 1, 0);
	starpu_data_handle A21 = get_sub_data(iter->A, 2, 0, 1);
	starpu_data_handle A22 = get_sub_data(iter->A, 2, 1, 1);

	starpu_data_handle B11 = get_sub_data(iter->B, 2, 0, 0);
	starpu_data_handle B12 = get_sub_data(iter->B, 2, 1, 0);
	starpu_data_handle B21 = get_sub_data(iter->B, 2, 0, 1);
	starpu_data_handle B22 = get_sub_data(iter->B, 2, 1, 1);

	starpu_data_handle C11 = get_sub_data(iter->C, 2, 0, 0);
	starpu_data_handle C12 = get_sub_data(iter->C, 2, 1, 0);
	starpu_data_handle C21 = get_sub_data(iter->C, 2, 0, 1);
	starpu_data_handle C22 = get_sub_data(iter->C, 2, 1, 1);

	unsigned size = starpu_get_blas_nx(A11);
	/* M1a = (A11 + A22) */
	iter->Mia_data[0] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_1a = compute_add_sub_op(iter->Mia_data[0], ADD, A11, A22);
	starpu_tag_t tag_1a = task_1a->tag_id;
	starpu_tag_declare_deps_array(tag_1a, iter->A_deps.ndeps, iter->A_deps.deps);
	starpu_submit_task(task_1a);

	/* M1b = (B11 + B22) */
	iter->Mib_data[0] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_1b = compute_add_sub_op(iter->Mib_data[0], ADD, B11, B22);
	starpu_tag_t tag_1b = task_1b->tag_id;
	starpu_tag_declare_deps_array(tag_1b, iter->B_deps.ndeps, iter->B_deps.deps);
	starpu_submit_task(task_1b);

	/* M2a = (A21 + A22) */
	iter->Mia_data[1] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_2a = compute_add_sub_op(iter->Mia_data[1], ADD, A21, A22);
	starpu_tag_t tag_2a = task_2a->tag_id;
	starpu_tag_declare_deps_array(tag_2a, iter->A_deps.ndeps, iter->A_deps.deps);
	starpu_submit_task(task_2a);

	/* M3b = (B12 - B22) */
	iter->Mib_data[2] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_3b = compute_add_sub_op(iter->Mib_data[2], SUB, B12, B22);
	starpu_tag_t tag_3b = task_3b->tag_id;
	starpu_tag_declare_deps_array(tag_3b, iter->B_deps.ndeps, iter->B_deps.deps);
	starpu_submit_task(task_3b);

	/* M4b = (B21 - B11) */
	iter->Mib_data[3] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_4b = compute_add_sub_op(iter->Mib_data[3], SUB, B21, B11);
	starpu_tag_t tag_4b = task_4b->tag_id;
	starpu_tag_declare_deps_array(tag_4b, iter->B_deps.ndeps, iter->B_deps.deps);
	starpu_submit_task(task_4b);

	/* M5a = (A11 + A12) */
	iter->Mia_data[4] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_5a = compute_add_sub_op(iter->Mia_data[4], ADD, A11, A12);
	starpu_tag_t tag_5a = task_5a->tag_id;
	starpu_tag_declare_deps_array(tag_5a, iter->A_deps.ndeps, iter->A_deps.deps);
	starpu_submit_task(task_5a);

	/* M6a = (A21 - A11) */
	iter->Mia_data[5] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_6a = compute_add_sub_op(iter->Mia_data[5], SUB, A21, A11);
	starpu_tag_t tag_6a = task_6a->tag_id;
	starpu_tag_declare_deps_array(tag_6a, iter->A_deps.ndeps, iter->A_deps.deps);
	starpu_submit_task(task_6a);
	/* M6b = (B11 + B12) */
	iter->Mib_data[5] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_6b = compute_add_sub_op(iter->Mib_data[5], ADD, B11, B12);
	starpu_tag_t tag_6b = task_6b->tag_id;
	starpu_tag_declare_deps_array(tag_6b, iter->B_deps.ndeps, iter->B_deps.deps);
	starpu_submit_task(task_6b);
	/* M7a = (A12 - A22) */
	iter->Mia_data[6] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_7a = compute_add_sub_op(iter->Mia_data[6], SUB, A12, A22);
	starpu_tag_t tag_7a = task_7a->tag_id;
	starpu_tag_declare_deps_array(tag_7a, iter->A_deps.ndeps, iter->A_deps.deps);
	starpu_submit_task(task_7a);

	/* M7b = (B21 + B22) */
	iter->Mib_data[6] = allocate_tmp_matrix(size, iter->reclevel);
	struct starpu_task *task_7b = compute_add_sub_op(iter->Mib_data[6], ADD, B21, B22);
	starpu_tag_t tag_7b = task_7b->tag_id;
	starpu_tag_declare_deps_array(tag_7b, iter->B_deps.ndeps, iter->B_deps.deps);
	starpu_submit_task(task_7b);

	iter->Mi_data[0] = allocate_tmp_matrix(size, iter->reclevel);
	iter->Mi_data[1] = allocate_tmp_matrix(size, iter->reclevel);
	iter->Mi_data[2] = allocate_tmp_matrix(size, iter->reclevel);
	iter->Mi_data[3] = allocate_tmp_matrix(size, iter->reclevel);
	iter->Mi_data[4] = allocate_tmp_matrix(size, iter->reclevel);
	iter->Mi_data[5] = allocate_tmp_matrix(size, iter->reclevel);
	iter->Mi_data[6] = allocate_tmp_matrix(size, iter->reclevel);

	/* M1 = M1a * M1b */
	iter->children[0] = malloc(sizeof(struct strassen_iter));
	iter->children[0]->reclevel = iter->reclevel - 1;
	iter->children[0]->A_deps.ndeps = 1;
	iter->children[0]->A_deps.deps[0] = tag_1a;
	iter->children[0]->B_deps.ndeps = 1;
	iter->children[0]->B_deps.deps[0] = tag_1b;
	iter->children[0]->A = iter->Mia_data[0];
	iter->children[0]->B = iter->Mib_data[0];
	iter->children[0]->C = iter->Mi_data[0];
	strassen_mult(iter->children[0]);

	/* M2 = M2a * B11 */
	iter->children[1] = malloc(sizeof(struct strassen_iter));
	iter->children[1]->reclevel = iter->reclevel - 1;
	iter->children[1]->A_deps.ndeps = 1;
	iter->children[1]->A_deps.deps[0] = tag_2a;
	iter->children[1]->B_deps.ndeps = iter->B_deps.ndeps;
	memcpy(iter->children[1]->B_deps.deps, iter->B_deps.deps, iter->B_deps.ndeps*sizeof(starpu_tag_t));
	iter->children[1]->A = iter->Mia_data[1];
	iter->children[1]->B = B11;
	iter->children[1]->C = iter->Mi_data[1];
	strassen_mult(iter->children[1]);
	/* M3 = A11 * M3b */
	iter->children[2] = malloc(sizeof(struct strassen_iter));
	iter->children[2]->reclevel = iter->reclevel - 1;
	iter->children[2]->A_deps.ndeps = iter->A_deps.ndeps;
	memcpy(iter->children[2]->A_deps.deps, iter->A_deps.deps, iter->A_deps.ndeps*sizeof(starpu_tag_t));
	iter->children[2]->B_deps.ndeps = 1;
	iter->children[2]->B_deps.deps[0] = tag_3b;
	iter->children[2]->A = A11;
	iter->children[2]->B = iter->Mib_data[2];
	iter->children[2]->C = iter->Mi_data[2];
	strassen_mult(iter->children[2]);
	/* M4 = A22 * M4b */
	iter->children[3] = malloc(sizeof(struct strassen_iter));
	iter->children[3]->reclevel = iter->reclevel - 1;
	iter->children[3]->A_deps.ndeps = iter->A_deps.ndeps;
	memcpy(iter->children[3]->A_deps.deps, iter->A_deps.deps, iter->A_deps.ndeps*sizeof(starpu_tag_t));
	iter->children[3]->B_deps.ndeps = 1;
	iter->children[3]->B_deps.deps[0] = tag_4b;
	iter->children[3]->A = A22;
	iter->children[3]->B = iter->Mib_data[3];
	iter->children[3]->C = iter->Mi_data[3];
	strassen_mult(iter->children[3]);
	/* M5 = M5a * B22 */
	iter->children[4] = malloc(sizeof(struct strassen_iter));
	iter->children[4]->reclevel = iter->reclevel - 1;
	iter->children[4]->A_deps.ndeps = 1;
	iter->children[4]->A_deps.deps[0] = tag_5a;
	iter->children[4]->B_deps.ndeps = iter->B_deps.ndeps;
	memcpy(iter->children[4]->B_deps.deps, iter->B_deps.deps, iter->B_deps.ndeps*sizeof(starpu_tag_t));
	iter->children[4]->A = iter->Mia_data[4];
	iter->children[4]->B = B22;
	iter->children[4]->C = iter->Mi_data[4];
	strassen_mult(iter->children[4]);

	/* M6 = M6a * M6b */
	iter->children[5] = malloc(sizeof(struct strassen_iter));
	iter->children[5]->reclevel = iter->reclevel - 1;
	iter->children[5]->A_deps.ndeps = 1;
	iter->children[5]->A_deps.deps[0] = tag_6a;
	iter->children[5]->B_deps.ndeps = 1;
	iter->children[5]->B_deps.deps[0] = tag_6b;
	iter->children[5]->A = iter->Mia_data[5];
	iter->children[5]->B = iter->Mib_data[5];
	iter->children[5]->C = iter->Mi_data[5];
	strassen_mult(iter->children[5]);

	/* M7 = M7a * M7b */
	iter->children[6] = malloc(sizeof(struct strassen_iter));
	iter->children[6]->reclevel = iter->reclevel - 1;
	iter->children[6]->A_deps.ndeps = 1;
	iter->children[6]->A_deps.deps[0] = tag_7a;
	iter->children[6]->B_deps.ndeps = 1;
	iter->children[6]->B_deps.deps[0] = tag_7b;
	iter->children[6]->A = iter->Mia_data[6];
	iter->children[6]->B = iter->Mib_data[6];
	iter->children[6]->C = iter->Mi_data[6];
	strassen_mult(iter->children[6]);
	starpu_tag_t *tag_m1 = iter->children[0]->C_deps.deps;
	starpu_tag_t *tag_m2 = iter->children[1]->C_deps.deps;
	starpu_tag_t *tag_m3 = iter->children[2]->C_deps.deps;
	starpu_tag_t *tag_m4 = iter->children[3]->C_deps.deps;
	starpu_tag_t *tag_m5 = iter->children[4]->C_deps.deps;
	starpu_tag_t *tag_m6 = iter->children[5]->C_deps.deps;
	starpu_tag_t *tag_m7 = iter->children[6]->C_deps.deps;

	/* C11 = M1 + M4 - M5 + M7 */
	struct starpu_task *task_c11_a = compute_self_add_sub_op(C11, ADD, iter->Mi_data[0]);
	struct starpu_task *task_c11_b = compute_self_add_sub_op(C11, ADD, iter->Mi_data[3]);
	struct starpu_task *task_c11_c = compute_self_add_sub_op(C11, SUB, iter->Mi_data[4]);
	struct starpu_task *task_c11_d = compute_self_add_sub_op(C11, ADD, iter->Mi_data[6]);
	starpu_tag_t tag_c11_a = task_c11_a->tag_id;
	starpu_tag_t tag_c11_b = task_c11_b->tag_id;
	starpu_tag_t tag_c11_c = task_c11_c->tag_id;
	starpu_tag_t tag_c11_d = task_c11_d->tag_id;

	/* C12 = M3 + M5 */
	struct starpu_task *task_c12_a = compute_self_add_sub_op(C12, ADD, iter->Mi_data[2]);
	struct starpu_task *task_c12_b = compute_self_add_sub_op(C12, ADD, iter->Mi_data[4]);
	starpu_tag_t tag_c12_a = task_c12_a->tag_id;
	starpu_tag_t tag_c12_b = task_c12_b->tag_id;

	/* C21 = M2 + M4 */
	struct starpu_task *task_c21_a = compute_self_add_sub_op(C21, ADD, iter->Mi_data[1]);
	struct starpu_task *task_c21_b = compute_self_add_sub_op(C21, ADD, iter->Mi_data[3]);
	starpu_tag_t tag_c21_a = task_c21_a->tag_id;
	starpu_tag_t tag_c21_b = task_c21_b->tag_id;

	/* C22 = M1 - M2 + M3 + M6 */
	struct starpu_task *task_c22_a = compute_self_add_sub_op(C22, ADD, iter->Mi_data[0]);
	struct starpu_task *task_c22_b = compute_self_add_sub_op(C22, SUB, iter->Mi_data[1]);
	struct starpu_task *task_c22_c = compute_self_add_sub_op(C22, ADD, iter->Mi_data[3]);
	struct starpu_task *task_c22_d = compute_self_add_sub_op(C22, ADD, iter->Mi_data[5]);
	starpu_tag_t tag_c22_a = task_c22_a->tag_id;
	starpu_tag_t tag_c22_b = task_c22_b->tag_id;
	starpu_tag_t tag_c22_c = task_c22_c->tag_id;
	starpu_tag_t tag_c22_d = task_c22_d->tag_id;
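
	/*
	 * A leaf child exposes a single tag (its block multiplication) in
	 * C_deps, whereas a non-leaf child exposes the 4 tags of its final
	 * C11/C12/C21/C22 accumulations; hence the two branches below, which
	 * also chain the successive updates of each Cij block.
	 */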
	if (iter->reclevel == 1)
	{
		starpu_tag_declare_deps(tag_c11_a, 1, tag_m1[0]);
		starpu_tag_declare_deps(tag_c11_b, 2, tag_m4[0], tag_c11_a);
		starpu_tag_declare_deps(tag_c11_c, 2, tag_m5[0], tag_c11_b);
		starpu_tag_declare_deps(tag_c11_d, 2, tag_m7[0], tag_c11_c);

		starpu_tag_declare_deps(tag_c12_a, 1, tag_m3[0]);
		starpu_tag_declare_deps(tag_c12_b, 2, tag_m5[0], tag_c12_a);

		starpu_tag_declare_deps(tag_c21_a, 1, tag_m2[0]);
		starpu_tag_declare_deps(tag_c21_b, 2, tag_m4[0], tag_c21_a);

		starpu_tag_declare_deps(tag_c22_a, 1, tag_m1[0]);
		starpu_tag_declare_deps(tag_c22_b, 2, tag_m2[0], tag_c22_a);
		starpu_tag_declare_deps(tag_c22_c, 2, tag_m3[0], tag_c22_b);
		starpu_tag_declare_deps(tag_c22_d, 2, tag_m6[0], tag_c22_c);
	}
	else
	{
		starpu_tag_declare_deps(tag_c11_a, 4, tag_m1[0], tag_m1[1], tag_m1[2], tag_m1[3]);
		starpu_tag_declare_deps(tag_c11_b, 5, tag_m4[0], tag_m4[1], tag_m4[2], tag_m4[3], tag_c11_a);
		starpu_tag_declare_deps(tag_c11_c, 5, tag_m5[0], tag_m5[1], tag_m5[2], tag_m5[3], tag_c11_b);
		starpu_tag_declare_deps(tag_c11_d, 5, tag_m7[0], tag_m7[1], tag_m7[2], tag_m7[3], tag_c11_c);

		starpu_tag_declare_deps(tag_c12_a, 4, tag_m3[0], tag_m3[1], tag_m3[2], tag_m3[3]);
		starpu_tag_declare_deps(tag_c12_b, 5, tag_m5[0], tag_m5[1], tag_m5[2], tag_m5[3], tag_c12_a);

		starpu_tag_declare_deps(tag_c21_a, 4, tag_m2[0], tag_m2[1], tag_m2[2], tag_m2[3]);
		starpu_tag_declare_deps(tag_c21_b, 5, tag_m4[0], tag_m4[1], tag_m4[2], tag_m4[3], tag_c21_a);

		starpu_tag_declare_deps(tag_c22_a, 4, tag_m1[0], tag_m1[1], tag_m1[2], tag_m1[3]);
		starpu_tag_declare_deps(tag_c22_b, 5, tag_m2[0], tag_m2[1], tag_m2[2], tag_m2[3], tag_c22_a);
		starpu_tag_declare_deps(tag_c22_c, 5, tag_m3[0], tag_m3[1], tag_m3[2], tag_m3[3], tag_c22_b);
		starpu_tag_declare_deps(tag_c22_d, 5, tag_m6[0], tag_m6[1], tag_m6[2], tag_m6[3], tag_c22_c);
	}
	starpu_submit_task(task_c11_a);
	starpu_submit_task(task_c11_b);
	starpu_submit_task(task_c11_c);
	starpu_submit_task(task_c11_d);

	starpu_submit_task(task_c12_a);
	starpu_submit_task(task_c12_b);

	starpu_submit_task(task_c21_a);
	starpu_submit_task(task_c21_b);

	starpu_submit_task(task_c22_a);
	starpu_submit_task(task_c22_b);
	starpu_submit_task(task_c22_c);
	starpu_submit_task(task_c22_d);

	iter->C_deps.ndeps = 4;
	iter->C_deps.deps[0] = tag_c11_d;
	iter->C_deps.deps[1] = tag_c12_b;
	iter->C_deps.deps[2] = tag_c21_b;
	iter->C_deps.deps[3] = tag_c22_d;

	struct cleanup_arg *clean_struct = malloc(sizeof(struct cleanup_arg));
	clean_struct->ndeps = 4;
	clean_struct->tags[0] = tag_c11_d;
	clean_struct->tags[1] = tag_c12_b;
	clean_struct->tags[2] = tag_c21_b;
	clean_struct->tags[3] = tag_c22_d;

	clean_struct->ndata = 17;
	clean_struct->data[0] = iter->Mia_data[0];
	clean_struct->data[1] = iter->Mib_data[0];
	clean_struct->data[2] = iter->Mia_data[1];
	clean_struct->data[3] = iter->Mib_data[2];
	clean_struct->data[4] = iter->Mib_data[3];
	clean_struct->data[5] = iter->Mia_data[4];
	clean_struct->data[6] = iter->Mia_data[5];
	clean_struct->data[7] = iter->Mib_data[5];
	clean_struct->data[8] = iter->Mia_data[6];
	clean_struct->data[9] = iter->Mib_data[6];
	clean_struct->data[10] = iter->Mi_data[0];
	clean_struct->data[11] = iter->Mi_data[1];
	clean_struct->data[12] = iter->Mi_data[2];
	clean_struct->data[13] = iter->Mi_data[3];
	clean_struct->data[14] = iter->Mi_data[4];
	clean_struct->data[15] = iter->Mi_data[5];
	clean_struct->data[16] = iter->Mi_data[6];

	create_cleanup_task(clean_struct);
}
static void dummy_codelet_func(__attribute__((unused)) starpu_data_interface_t *descr,
				__attribute__((unused)) void *arg)
{
}

static starpu_codelet dummy_codelet = {
	.where = ANY,
	.model = NULL,
	.core_func = dummy_codelet_func,
#ifdef USE_CUDA
	.cublas_func = dummy_codelet_func,
#endif
	.nbuffers = 0
};

static struct starpu_task *dummy_task(starpu_tag_t tag)
{
	struct starpu_task *task = starpu_task_create();
	task->callback_func = NULL;
	task->cl = &dummy_codelet;
	task->cl_arg = NULL;

	task->use_tag = 1;
	task->tag_id = tag;

	return task;
}
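
/*
 * The two dummy tasks created in main() bracket the whole task graph: the
 * start task carries tag 42, on which the initial A/B dependencies of the
 * top-level strassen_mult() call are declared, and the end task carries tag
 * 10, which is made to depend on the final C updates. Since the end task is
 * submitted synchronously, the interval measured between submitting the start
 * task and the return of the end task covers the execution of the whole
 * Strassen task graph.
 */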
void parse_args(int argc, char **argv)
{
	int i;
	for (i = 1; i < argc; i++) {
		if (strcmp(argv[i], "-size") == 0) {
			char *argptr;
			size = strtol(argv[++i], &argptr, 10);
		}

		if (strcmp(argv[i], "-rec") == 0) {
			char *argptr;
			reclevel = strtol(argv[++i], &argptr, 10);
		}

		if (strcmp(argv[i], "-no-random") == 0) {
			norandom = 1;
		}

		if (strcmp(argv[i], "-pin") == 0) {
			pin = 1;
		}
	}
}
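
/*
 * Example invocation (the binary name is just an assumption here):
 *
 *   ./strassen2 -size 4096 -rec 4 -pin
 *
 * multiplies two 4096x4096 single-precision matrices with 4 levels of Strassen
 * recursion, using pinned host memory when CUDA support is enabled.
 */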
int main(int argc, char **argv)
{
	starpu_data_handle data_A, data_B, data_C;
	float *A, *B, *C;

	struct timeval start;
	struct timeval end;

	parse_args(argc, argv);

	starpu_init(NULL);

#ifdef USE_CUDA
	if (pin) {
		starpu_malloc_pinned_if_possible(&A, size*size*sizeof(float));
		starpu_malloc_pinned_if_possible(&B, size*size*sizeof(float));
		starpu_malloc_pinned_if_possible(&C, size*size*sizeof(float));
	} else
#endif
	{
#ifdef HAVE_POSIX_MEMALIGN
		posix_memalign((void **)&A, 4096, size*size*sizeof(float));
		posix_memalign((void **)&B, 4096, size*size*sizeof(float));
		posix_memalign((void **)&C, 4096, size*size*sizeof(float));
#else
		A = malloc(size*size*sizeof(float));
		B = malloc(size*size*sizeof(float));
		C = malloc(size*size*sizeof(float));
#endif
	}

	assert(A);
	assert(B);
	assert(C);

	used_mem += 3*size*size*sizeof(float);

	memset(A, 0, size*size*sizeof(float));
	memset(B, 0, size*size*sizeof(float));
	memset(C, 0, size*size*sizeof(float));

	starpu_monitor_blas_data(&data_A, 0, (uintptr_t)A, size, size, size, sizeof(float));
	starpu_monitor_blas_data(&data_B, 0, (uintptr_t)B, size, size, size, sizeof(float));
	starpu_monitor_blas_data(&data_C, 0, (uintptr_t)C, size, size, size, sizeof(float));

	unsigned rec;
	for (rec = 0; rec < reclevel; rec++)
	{
		starpu_map_filters(data_A, 2, &f, &f2);
		starpu_map_filters(data_B, 2, &f, &f2);
		starpu_map_filters(data_C, 2, &f, &f2);
	}

	struct strassen_iter iter;
	iter.reclevel = reclevel;
	iter.A = data_A;
	iter.B = data_B;
	iter.C = data_C;
	iter.A_deps.ndeps = 1;
	iter.A_deps.deps[0] = 42;
	iter.B_deps.ndeps = 1;
	iter.B_deps.deps[0] = 42;

	strassen_mult(&iter);

	starpu_tag_declare_deps_array(10, iter.C_deps.ndeps, iter.C_deps.deps);
	fprintf(stderr, "Using %lu MB of memory\n", (unsigned long)(used_mem/(1024*1024)));
	struct starpu_task *task_start = dummy_task(42);

	gettimeofday(&start, NULL);
	starpu_submit_task(task_start);

	struct starpu_task *task_end = dummy_task(10);
	task_end->synchronous = 1;
	starpu_submit_task(task_end);
	gettimeofday(&end, NULL);

	starpu_shutdown();

	double timing = (double)((end.tv_sec - start.tv_sec)*1000000 + (end.tv_usec - start.tv_usec));
	display_perf(timing, size);

	return 0;
}