pxlu.c 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2010 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include "pxlu.h"
  17. #include "pxlu_kernels.h"
  /*
   * MPI message tags.  The block kind (1 = 11, 2 = 12, 3 = 21) is stored from
   * bit 16 upwards and the low bits hold the block index.  For 12 and 21 blocks
   * the index is the linearised pair j*nblocks + k (resp. i*nblocks + k).
   *
   * The two terms must be ADDED: OR-ing them makes distinct pairs collide
   * whenever j*nblocks and k share set bits (e.g. with nblocks = 5,
   * 15|1 == 15|2 == 15).
   *
   * The index only stays below 1 << 16 while nblocks * nblocks <= 65536.
   * MPI_TAG12 and MPI_TAG21 read the file-scope variable `nblocks`.
   */
  #define MPI_TAG11(k) ((1U << 16) | (k))
  #define MPI_TAG12(k, j) ((2U << 16) | ((j)*nblocks + (k)))
  #define MPI_TAG21(k, i) ((3U << 16) | ((i)*nblocks + (k)))

  /*
   * StarPU tags.  The tag kind is stored from bit 50 upwards, the step k from
   * bit 32 upwards, and the block indices in the low bits (for TAG22, i is
   * shifted by 16 and j occupies the lowest bits).
   */
  #define TAG11(k) ((starpu_tag_t)( (1ULL<<50) | (unsigned long long)(k)))
  #define TAG12(k,i) ((starpu_tag_t)(((2ULL<<50) | (((unsigned long long)(k))<<32) \
      | (unsigned long long)(i))))
  #define TAG21(k,j) ((starpu_tag_t)(((3ULL<<50) | (((unsigned long long)(k))<<32) \
      | (unsigned long long)(j))))
  #define TAG22(k,i,j) ((starpu_tag_t)(((4ULL<<50) | ((unsigned long long)(k)<<32) \
      | ((unsigned long long)(i)<<16) \
      | (unsigned long long)(j))))

  /* Tags unlocked once the corresponding block is available to its readers. */
  #define TAG11_SAVE(k) ((starpu_tag_t)( (5ULL<<50) | (unsigned long long)(k)))
  #define TAG12_SAVE(k,i) ((starpu_tag_t)(((6ULL<<50) | (((unsigned long long)(k))<<32) \
      | (unsigned long long)(i))))
  #define TAG21_SAVE(k,j) ((starpu_tag_t)(((7ULL<<50) | (((unsigned long long)(k))<<32) \
      | (unsigned long long)(j))))

  /* Tags of the intermediate synchronisation tasks that trigger a receive. */
  #define TAG11_SAVE_PARTIAL(k) ((starpu_tag_t)( (8ULL<<50) | (unsigned long long)(k)))
  #define TAG12_SAVE_PARTIAL(k,i) ((starpu_tag_t)(((9ULL<<50) | (((unsigned long long)(k))<<32) \
      | (unsigned long long)(i))))
  #define TAG21_SAVE_PARTIAL(k,j) ((starpu_tag_t)(((10ULL<<50) | (((unsigned long long)(k))<<32) \
      | (unsigned long long)(j))))

  /* Tag the first diagonal task depends on; notified to start the factorization. */
  #define STARPU_TAG_INIT ((starpu_tag_t)(11ULL<<50))

  /* When non-zero, no task is given MAX_PRIO. */
  static unsigned no_prio = 0;
  /* Number of blocks per matrix dimension; set by plu_main. */
  static unsigned nblocks = 0;
  /* MPI rank of this process; set by plu_main. */
  static int rank = -1;
  /* Number of MPI processes; set by plu_main. */
  static int world_size = -1;
  /* Block coordinates handed to a task's completion callback.  Only the
   * fields relevant to the kind of task are filled in by its creator. */
  struct callback_arg {
      unsigned i;
      unsigned j;
      unsigned k;
  };
  47. /*
  48. * Various
  49. */
  50. static struct starpu_task *create_task(starpu_tag_t id)
  51. {
  52. struct starpu_task *task = starpu_task_create();
  53. task->cl_arg = NULL;
  54. task->use_tag = 1;
  55. task->tag_id = id;
  56. return task;
  57. }
  58. /* Send handle to every node appearing in the mask, and unlock tag once the
  59. * transfers are done. */
  60. static void send_data_to_mask(starpu_data_handle handle, int *rank_mask, int mpi_tag, starpu_tag_t tag)
  61. {
  62. unsigned cnt = 0;
  63. STARPU_ASSERT(handle != STARPU_POISON_PTR);
  64. int rank_array[world_size];
  65. int comm_array[world_size];
  66. int mpi_tag_array[world_size];
  67. starpu_data_handle handle_array[world_size];
  68. unsigned r;
  69. for (r = 0; r < world_size; r++)
  70. {
  71. if (rank_mask[r])
  72. rank_array[cnt++] = r;
  73. comm_array[r] = MPI_COMM_WORLD;
  74. mpi_tag_array[r] = mpi_tag;
  75. handle_array[r] = handle;
  76. }
  77. if (cnt == 0)
  78. {
  79. /* In case there is no message to send, we release the tag at
  80. * once */
  81. starpu_tag_notify_from_apps(tag);
  82. }
  83. else {
  84. starpu_mpi_isend_array_detached_unlock_tag(cnt, handle_array,
  85. rank_array, mpi_tag_array, comm_array, tag);
  86. }
  87. }
  88. /* Initiate a receive request once all dependencies are fulfilled and unlock
  89. * tag 'unlocked_tag' once it's done. */
  90. struct recv_when_done_callback_arg {
  91. int source;
  92. int mpi_tag;
  93. starpu_data_handle handle;
  94. starpu_tag_t unlocked_tag;
  95. };
  96. static void callback_receive_when_done(void *_arg)
  97. {
  98. struct recv_when_done_callback_arg *arg = _arg;
  99. starpu_mpi_irecv_detached_unlock_tag(arg->handle, arg->source,
  100. arg->mpi_tag, MPI_COMM_WORLD, arg->unlocked_tag);
  101. free(arg);
  102. }
  103. static void receive_when_deps_are_done(unsigned ndeps, starpu_tag_t *deps_tags,
  104. int source, int mpi_tag,
  105. starpu_data_handle handle,
  106. starpu_tag_t partial_tag,
  107. starpu_tag_t unlocked_tag)
  108. {
  109. STARPU_ASSERT(handle != STARPU_POISON_PTR);
  110. struct recv_when_done_callback_arg *arg =
  111. malloc(sizeof(struct recv_when_done_callback_arg));
  112. arg->source = source;
  113. arg->mpi_tag = mpi_tag;
  114. arg->handle = handle;
  115. arg->unlocked_tag = unlocked_tag;
  116. if (ndeps == 0)
  117. {
  118. callback_receive_when_done(arg);
  119. return;
  120. }
  121. starpu_create_sync_task(partial_tag, ndeps, deps_tags,
  122. callback_receive_when_done, arg);
  123. }
  124. /*
  125. * Task 11 (diagonal factorization)
  126. */
  127. static void create_task_11_recv(unsigned k)
  128. {
  129. unsigned i, j;
  130. /* The current node is not computing that task, so we receive the block
  131. * with MPI */
  132. /* We don't issue a MPI receive request until everyone using the
  133. * temporary buffer is done : 11_(k-1) can be used by 12_(k-1)j and
  134. * 21(k-1)i with i,j >= k */
  135. unsigned ndeps = 0;
  136. starpu_tag_t tag_array[2*nblocks];
  137. if (k > 0)
  138. for (i = k; i < nblocks; i++)
  139. {
  140. if (rank == get_block_rank(i, k))
  141. tag_array[ndeps++] = TAG21(k-1, i);
  142. }
  143. if (k > 0)
  144. for (j = k; j < nblocks; j++)
  145. {
  146. if (rank == get_block_rank(k, j))
  147. tag_array[ndeps++] = TAG12(k-1, j);
  148. }
  149. int source = get_block_rank(k, k);
  150. starpu_data_handle block_handle = STARPU_PLU(get_tmp_11_block_handle)();
  151. int mpi_tag = MPI_TAG11(k);
  152. starpu_tag_t partial_tag = TAG11_SAVE_PARTIAL(k);
  153. starpu_tag_t unlocked_tag = TAG11_SAVE(k);
  154. // fprintf(stderr, "NODE %d - 11 (%d) - recv when done ndeps %d - tag array %lx\n", rank, k, ndeps, tag_array[0]);
  155. receive_when_deps_are_done(ndeps, tag_array, source, mpi_tag, block_handle, partial_tag, unlocked_tag);
  156. }
  157. static void find_nodes_using_11(unsigned k, int *rank_mask)
  158. {
  159. memset(rank_mask, 0, world_size*sizeof(int));
  160. /* Block 11_k is used to compute 12_kj + 12ki with i,j > k */
  161. unsigned i;
  162. for (i = k; i < nblocks; i++)
  163. {
  164. int r = get_block_rank(i, k);
  165. rank_mask[r] = 1;
  166. }
  167. unsigned j;
  168. for (j = k; j < nblocks; j++)
  169. {
  170. int r = get_block_rank(k, j);
  171. rank_mask[r] = 1;
  172. }
  173. }
  174. static void callback_task_11_real(void *_arg)
  175. {
  176. struct callback_arg *arg = _arg;
  177. unsigned k = arg->k;
  178. /* Find all the nodes potentially requiring this block */
  179. int rank_mask[world_size];
  180. find_nodes_using_11(k, rank_mask);
  181. rank_mask[rank] = 0;
  182. /* Send the block to those nodes */
  183. starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(k, k);
  184. starpu_tag_t tag = TAG11_SAVE(k);
  185. int mpi_tag = MPI_TAG11(k);
  186. send_data_to_mask(block_handle, rank_mask, mpi_tag, tag);
  187. free(arg);
  188. }
  189. static void create_task_11_real(unsigned k)
  190. {
  191. struct starpu_task *task = create_task(TAG11(k));
  192. task->cl = &STARPU_PLU(cl11);
  193. /* which sub-data is manipulated ? */
  194. task->buffers[0].handle = STARPU_PLU(get_block_handle)(k, k);
  195. task->buffers[0].mode = STARPU_RW;
  196. struct callback_arg *arg = malloc(sizeof(struct callback_arg));
  197. arg->k = k;
  198. task->callback_func = callback_task_11_real;
  199. task->callback_arg = arg;
  200. /* this is an important task */
  201. if (!no_prio)
  202. task->priority = MAX_PRIO;
  203. /* enforce dependencies ... */
  204. if (k > 0) {
  205. starpu_tag_declare_deps(TAG11(k), 1, TAG22(k-1, k, k));
  206. }
  207. else {
  208. starpu_tag_declare_deps(TAG11(k), 1, STARPU_TAG_INIT);
  209. }
  210. starpu_submit_task(task);
  211. }
  212. static void create_task_11(unsigned k)
  213. {
  214. if (get_block_rank(k, k) == rank)
  215. {
  216. // fprintf(stderr, "CREATE real task 11(%d) (TAG11_SAVE(%d) = %lx) on node %d\n", k, k, TAG11_SAVE(k), rank);
  217. create_task_11_real(k);
  218. }
  219. else {
  220. /* We don't handle the task, but perhaps we have to generate MPI transfers. */
  221. int rank_mask[world_size];
  222. find_nodes_using_11(k, rank_mask);
  223. if (rank_mask[rank])
  224. {
  225. // fprintf(stderr, "create RECV task 11(%d) on node %d\n", k, rank);
  226. create_task_11_recv(k);
  227. }
  228. else {
  229. // fprintf(stderr, "Node %d needs not 11(%d)\n", rank, k);
  230. }
  231. }
  232. }
  233. /*
  234. * Task 12 (Update lower left (TRSM))
  235. */
  236. static void create_task_12_recv(unsigned k, unsigned j)
  237. {
  238. unsigned i;
  239. /* The current node is not computing that task, so we receive the block
  240. * with MPI */
  241. /* We don't issue a MPI receive request until everyone using the
  242. * temporary buffer is done : 12_(k-1)j can be used by 22_(k-1)ij with
  243. * i >= k */
  244. unsigned ndeps = 0;
  245. starpu_tag_t tag_array[nblocks];
  246. if (k > 0)
  247. for (i = k; i < nblocks; i++)
  248. {
  249. if (rank == get_block_rank(i, j))
  250. tag_array[ndeps++] = TAG22(k-1, i, j);
  251. }
  252. int source = get_block_rank(k, j);
  253. starpu_data_handle block_handle = STARPU_PLU(get_tmp_12_block_handle)(j);
  254. int mpi_tag = MPI_TAG12(k, j);
  255. starpu_tag_t partial_tag = TAG12_SAVE_PARTIAL(k, j);
  256. starpu_tag_t unlocked_tag = TAG12_SAVE(k, j);
  257. receive_when_deps_are_done(ndeps, tag_array, source, mpi_tag, block_handle, partial_tag, unlocked_tag);
  258. }
  259. static void find_nodes_using_12(unsigned k, unsigned j, int *rank_mask)
  260. {
  261. memset(rank_mask, 0, world_size*sizeof(int));
  262. /* Block 12_kj is used to compute 22_kij with i > k */
  263. unsigned i;
  264. for (i = k; i < nblocks; i++)
  265. {
  266. int r = get_block_rank(i, j);
  267. rank_mask[r] = 1;
  268. }
  269. }
  270. static void callback_task_12_real(void *_arg)
  271. {
  272. struct callback_arg *arg = _arg;
  273. unsigned k = arg->k;
  274. unsigned j = arg->j;
  275. /* Find all the nodes potentially requiring this block */
  276. int rank_mask[world_size];
  277. find_nodes_using_12(k, j, rank_mask);
  278. rank_mask[rank] = 0;
  279. /* Send the block to those nodes */
  280. // starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(j, k);
  281. starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(k, j);
  282. starpu_tag_t tag = TAG12_SAVE(k, j);
  283. int mpi_tag = MPI_TAG12(k, j);
  284. send_data_to_mask(block_handle, rank_mask, mpi_tag, tag);
  285. free(arg);
  286. }
  287. static void create_task_12_real(unsigned k, unsigned j)
  288. {
  289. struct starpu_task *task = create_task(TAG12(k, j));
  290. task->cl = &STARPU_PLU(cl12);
  291. // task->cl = &STARPU_PLU(cl21);
  292. int myrank;
  293. MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
  294. STARPU_ASSERT(myrank == rank);
  295. /* which sub-data is manipulated ? */
  296. starpu_data_handle diag_block;
  297. if (get_block_rank(k, k) == rank)
  298. diag_block = STARPU_PLU(get_block_handle)(k, k);
  299. else
  300. diag_block = STARPU_PLU(get_tmp_11_block_handle)();
  301. task->buffers[0].handle = diag_block;
  302. task->buffers[0].mode = STARPU_R;
  303. task->buffers[1].handle = STARPU_PLU(get_block_handle)(k, j);
  304. task->buffers[1].mode = STARPU_RW;
  305. STARPU_ASSERT(get_block_rank(k, j) == rank);
  306. STARPU_ASSERT(STARPU_PLU(get_tmp_11_block_handle)() != STARPU_POISON_PTR);
  307. STARPU_ASSERT(task->buffers[0].handle != STARPU_POISON_PTR);
  308. STARPU_ASSERT(task->buffers[1].handle != STARPU_POISON_PTR);
  309. struct callback_arg *arg = malloc(sizeof(struct callback_arg));
  310. arg->j = j;
  311. arg->k = k;
  312. task->callback_func = callback_task_12_real;
  313. task->callback_arg = arg;
  314. if (!no_prio && (j == k+1)) {
  315. task->priority = MAX_PRIO;
  316. }
  317. /* enforce dependencies ... */
  318. if (k > 0) {
  319. starpu_tag_declare_deps(TAG12(k, j), 2, TAG11_SAVE(k), TAG22(k-1, k, j));
  320. }
  321. else {
  322. starpu_tag_declare_deps(TAG12(k, j), 1, TAG11_SAVE(k));
  323. }
  324. starpu_submit_task(task);
  325. }
  326. static void create_task_12(unsigned k, unsigned j)
  327. {
  328. if (get_block_rank(k, j) == rank)
  329. {
  330. // fprintf(stderr, "CREATE real task 12(k = %d, j = %d) on node %d\n", k, j, rank);
  331. create_task_12_real(k, j);
  332. }
  333. else {
  334. /* We don't handle the task, but perhaps we have to generate MPI transfers. */
  335. int rank_mask[world_size];
  336. find_nodes_using_12(k, j, rank_mask);
  337. if (rank_mask[rank])
  338. {
  339. // fprintf(stderr, "create RECV task 12(k = %d, j = %d) on node %d\n", k, j, rank);
  340. create_task_12_recv(k, j);
  341. }
  342. // else {
  343. // fprintf(stderr, "Node %d needs not 12(k=%d, i=%d)\n", rank, k, j);
  344. // }
  345. }
  346. }
  347. /*
  348. * Task 21 (Update upper right (TRSM))
  349. */
  350. static void create_task_21_recv(unsigned k, unsigned i)
  351. {
  352. unsigned j;
  353. /* The current node is not computing that task, so we receive the block
  354. * with MPI */
  355. /* We don't issue a MPI receive request until everyone using the
  356. * temporary buffer is done : 21_(k-1)i can be used by 22_(k-1)ij with
  357. * j >= k */
  358. unsigned ndeps = 0;
  359. starpu_tag_t tag_array[nblocks];
  360. if (k > 0)
  361. for (j = k; j < nblocks; j++)
  362. {
  363. if (rank == get_block_rank(i, j))
  364. tag_array[ndeps++] = TAG22(k-1, i, j);
  365. }
  366. int source = get_block_rank(i, k);
  367. starpu_data_handle block_handle = STARPU_PLU(get_tmp_21_block_handle)(i);
  368. int mpi_tag = MPI_TAG21(k, i);
  369. starpu_tag_t partial_tag = TAG21_SAVE_PARTIAL(k, i);
  370. starpu_tag_t unlocked_tag = TAG21_SAVE(k, i);
  371. // fprintf(stderr, "NODE %d - 21 (%d, %d) - recv when done ndeps %d - tag array %lx\n", rank, k, i, ndeps, tag_array[0]);
  372. receive_when_deps_are_done(ndeps, tag_array, source, mpi_tag, block_handle, partial_tag, unlocked_tag);
  373. }
  374. static void find_nodes_using_21(unsigned k, unsigned i, int *rank_mask)
  375. {
  376. memset(rank_mask, 0, world_size*sizeof(int));
  377. /* Block 21_ki is used to compute 22_kij with j > k */
  378. unsigned j;
  379. for (j = k; j < nblocks; j++)
  380. {
  381. int r = get_block_rank(i, j);
  382. rank_mask[r] = 1;
  383. }
  384. }
  385. static void callback_task_21_real(void *_arg)
  386. {
  387. struct callback_arg *arg = _arg;
  388. unsigned k = arg->k;
  389. unsigned i = arg->i;
  390. /* Find all the nodes potentially requiring this block */
  391. int rank_mask[world_size];
  392. find_nodes_using_21(k, i, rank_mask);
  393. rank_mask[rank] = 0;
  394. /* Send the block to those nodes */
  395. starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(i, k);
  396. starpu_tag_t tag = TAG21_SAVE(k, i);
  397. int mpi_tag = MPI_TAG21(k, i);
  398. send_data_to_mask(block_handle, rank_mask, mpi_tag, tag);
  399. free(arg);
  400. }
  401. static void create_task_21_real(unsigned k, unsigned i)
  402. {
  403. struct starpu_task *task = create_task(TAG21(k, i));
  404. task->cl = &STARPU_PLU(cl21);
  405. // task->cl = &STARPU_PLU(cl12);
  406. /* which sub-data is manipulated ? */
  407. starpu_data_handle diag_block;
  408. if (get_block_rank(k, k) == rank)
  409. diag_block = STARPU_PLU(get_block_handle)(k, k);
  410. else
  411. diag_block = STARPU_PLU(get_tmp_11_block_handle)();
  412. task->buffers[0].handle = diag_block;
  413. task->buffers[0].mode = STARPU_R;
  414. task->buffers[1].handle = STARPU_PLU(get_block_handle)(i, k);
  415. task->buffers[1].mode = STARPU_RW;
  416. STARPU_ASSERT(STARPU_PLU(get_tmp_11_block_handle)() != STARPU_POISON_PTR);
  417. STARPU_ASSERT(task->buffers[0].handle != STARPU_POISON_PTR);
  418. STARPU_ASSERT(task->buffers[1].handle != STARPU_POISON_PTR);
  419. struct callback_arg *arg = malloc(sizeof(struct callback_arg));
  420. arg->i = i;
  421. arg->k = k;
  422. task->callback_func = callback_task_21_real;
  423. task->callback_arg = arg;
  424. if (!no_prio && (i == k+1)) {
  425. task->priority = MAX_PRIO;
  426. }
  427. /* enforce dependencies ... */
  428. if (k > 0) {
  429. starpu_tag_declare_deps(TAG21(k, i), 2, TAG11_SAVE(k), TAG22(k-1, i, k));
  430. }
  431. else {
  432. starpu_tag_declare_deps(TAG21(k, i), 1, TAG11_SAVE(k));
  433. }
  434. starpu_submit_task(task);
  435. }
  436. static void create_task_21(unsigned k, unsigned i)
  437. {
  438. if (get_block_rank(i, k) == rank)
  439. {
  440. // fprintf(stderr, "CREATE real task 21(k = %d, i = %d) on node %d\n", k, i, rank);
  441. create_task_21_real(k, i);
  442. }
  443. else {
  444. /* We don't handle the task, but perhaps we have to generate MPI transfers. */
  445. int rank_mask[world_size];
  446. find_nodes_using_21(k, i, rank_mask);
  447. if (rank_mask[rank])
  448. {
  449. // fprintf(stderr, "create RECV task 21(k = %d, i = %d) on node %d\n", k, i, rank);
  450. create_task_21_recv(k, i);
  451. }
  452. // else {
  453. // fprintf(stderr, "Node %d needs not 21(k=%d, i=%d)\n", rank, k,i);
  454. // }
  455. }
  456. }
  457. /*
  458. * Task 22 (GEMM)
  459. */
  460. static void create_task_22_real(unsigned k, unsigned i, unsigned j)
  461. {
  462. // printf("task 22 k,i,j = %d,%d,%d TAG = %llx\n", k,i,j, TAG22(k,i,j));
  463. struct starpu_task *task = create_task(TAG22(k, i, j));
  464. task->cl = &STARPU_PLU(cl22);
  465. /* which sub-data is manipulated ? */
  466. /* produced by TAG21_SAVE(k, i) */
  467. starpu_data_handle block21;
  468. if (get_block_rank(i, k) == rank)
  469. {
  470. block21 = STARPU_PLU(get_block_handle)(i, k);
  471. }
  472. else
  473. block21 = STARPU_PLU(get_tmp_21_block_handle)(i);
  474. /* produced by TAG12_SAVE(k, j) */
  475. starpu_data_handle block12;
  476. if (get_block_rank(k, j) == rank)
  477. {
  478. // block12 = STARPU_PLU(get_block_handle)(j, k);
  479. block12 = STARPU_PLU(get_block_handle)(k, j);
  480. }
  481. else
  482. block12 = STARPU_PLU(get_tmp_12_block_handle)(j);
  483. task->buffers[0].handle = block21;
  484. task->buffers[0].mode = STARPU_R;
  485. task->buffers[1].handle = block12;
  486. task->buffers[1].mode = STARPU_R;
  487. /* produced by TAG22(k-1, i, j) */
  488. task->buffers[2].handle = STARPU_PLU(get_block_handle)(i, j);
  489. task->buffers[2].mode = STARPU_RW;
  490. STARPU_ASSERT(task->buffers[0].handle != STARPU_POISON_PTR);
  491. STARPU_ASSERT(task->buffers[1].handle != STARPU_POISON_PTR);
  492. STARPU_ASSERT(task->buffers[2].handle != STARPU_POISON_PTR);
  493. if (!no_prio && (i == k + 1) && (j == k +1) ) {
  494. task->priority = MAX_PRIO;
  495. }
  496. /* enforce dependencies ... */
  497. if (k > 0) {
  498. starpu_tag_declare_deps(TAG22(k, i, j), 3, TAG22(k-1, i, j), TAG12_SAVE(k, j), TAG21_SAVE(k, i));
  499. }
  500. else {
  501. starpu_tag_declare_deps(TAG22(k, i, j), 2, TAG12_SAVE(k, j), TAG21_SAVE(k, i));
  502. }
  503. starpu_submit_task(task);
  504. }
  505. static void create_task_22(unsigned k, unsigned i, unsigned j)
  506. {
  507. if (get_block_rank(i, j) == rank)
  508. {
  509. // fprintf(stderr, "CREATE real task 22(k = %d, i = %d, j = %d) on node %d\n", k, i, j, rank);
  510. create_task_22_real(k, i, j);
  511. }
  512. // else {
  513. // fprintf(stderr, "Node %d needs not 22(k=%d, i=%d, j = %d)\n", rank, k,i,j);
  514. // }
  515. }
  516. static void wait_tag_and_fetch_handle(starpu_tag_t tag, starpu_data_handle handle)
  517. {
  518. STARPU_ASSERT(handle != STARPU_POISON_PTR);
  519. starpu_tag_wait(tag);
  520. // fprintf(stderr, "Rank %d : tag %lx is done\n", rank, tag);
  521. starpu_sync_data_with_mem(handle, STARPU_R);
  522. // starpu_delete_data(handle);
  523. }
  524. static void wait_termination(void)
  525. {
  526. starpu_wait_all_tasks();
  527. unsigned k, i, j;
  528. for (k = 0; k < nblocks; k++)
  529. {
  530. /* Wait task 11k if needed */
  531. if (get_block_rank(k, k) == rank)
  532. {
  533. starpu_data_handle diag_block = STARPU_PLU(get_block_handle)(k, k);
  534. wait_tag_and_fetch_handle(TAG11_SAVE(k), diag_block);
  535. }
  536. for (i = k + 1; i < nblocks; i++)
  537. {
  538. /* Wait task 21ki is needed */
  539. if (get_block_rank(i, k) == rank)
  540. {
  541. starpu_data_handle block21 = STARPU_PLU(get_block_handle)(i, k);
  542. //starpu_data_handle block21 = STARPU_PLU(get_block_handle)(k, i);
  543. fprintf(stderr, "BLOCK21 i %d k %d -> handle %p\n", i, k, block21);
  544. wait_tag_and_fetch_handle(TAG21_SAVE(k, i), block21);
  545. }
  546. }
  547. for (j = k + 1; j < nblocks; j++)
  548. {
  549. /* Wait task 12kj is needed */
  550. if (get_block_rank(k, j) == rank)
  551. {
  552. //starpu_data_handle block12 = STARPU_PLU(get_block_handle)(j, k);
  553. starpu_data_handle block12 = STARPU_PLU(get_block_handle)(k, j);
  554. fprintf(stderr, "BLOCK12 j %d k %d -> handle %p\n", j, k, block12);
  555. wait_tag_and_fetch_handle(TAG12_SAVE(k, j), block12);
  556. }
  557. }
  558. }
  559. }
  560. /*
  561. * code to bootstrap the factorization
  562. */
  563. double STARPU_PLU(plu_main)(unsigned _nblocks, int _rank, int _world_size)
  564. {
  565. struct timeval start;
  566. struct timeval end;
  567. nblocks = _nblocks;
  568. rank = _rank;
  569. world_size = _world_size;
  570. struct starpu_task *entry_task = NULL;
  571. /* create all the DAG nodes */
  572. unsigned i,j,k;
  573. for (k = 0; k < nblocks; k++)
  574. {
  575. create_task_11(k);
  576. for (i = k+1; i<nblocks; i++)
  577. {
  578. create_task_12(k, i);
  579. create_task_21(k, i);
  580. }
  581. for (i = k+1; i<nblocks; i++)
  582. {
  583. for (j = k+1; j<nblocks; j++)
  584. {
  585. create_task_22(k, i, j);
  586. }
  587. }
  588. }
  589. int barrier_ret = MPI_Barrier(MPI_COMM_WORLD);
  590. STARPU_ASSERT(barrier_ret == MPI_SUCCESS);
  591. /* schedule the codelet */
  592. gettimeofday(&start, NULL);
  593. starpu_tag_notify_from_apps(STARPU_TAG_INIT);
  594. wait_termination();
  595. gettimeofday(&end, NULL);
  596. double timing = (double)((end.tv_sec - start.tv_sec)*1000000 + (end.tv_usec - start.tv_usec));
  597. // fprintf(stderr, "RANK %d -> took %lf ms\n", rank, timing/1000);
  598. return timing;
  599. }