cg.c

/* StarPU --- Runtime system for heterogeneous multicore architectures.
 *
 * Copyright (C) 2021 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
 *
 * StarPU is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2.1 of the License, or (at
 * your option) any later version.
 *
 * StarPU is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 *
 * See the GNU Lesser General Public License in COPYING.LGPL for more details.
 */
#include <math.h>
#include <assert.h>
#include <starpu.h>
#include <starpu_mpi.h>
#include <common/blas.h>

/*
 * Distributed version of the Conjugate Gradient solver implemented in
 * examples/cg/cg.c.
 *
 * Use the -display-result option and compare with the non-distributed
 * version: the x vector should be the same.
 */
#include "../../../examples/cg/cg.h"

static int copy_handle(starpu_data_handle_t *dst, starpu_data_handle_t *src, unsigned nblocks);

#define HANDLE_TYPE_VECTOR starpu_data_handle_t*
#define HANDLE_TYPE_MATRIX starpu_data_handle_t**
#define TASK_INSERT(cl, ...) starpu_mpi_task_insert(MPI_COMM_WORLD, cl, ##__VA_ARGS__)
#define GET_VECTOR_BLOCK(v, i) v[i]
#define GET_MATRIX_BLOCK(m, i, j) m[i][j]
#define BARRIER() starpu_mpi_barrier(MPI_COMM_WORLD);
#define GET_DATA_HANDLE(handle) starpu_mpi_get_data_on_all_nodes_detached(MPI_COMM_WORLD, handle)

static unsigned block_size;

static int rank;
static int nodes_p = 2;
static int nodes_q;

static TYPE ***A;
static TYPE **x;
static TYPE **b;
static TYPE **r;
static TYPE **d;
static TYPE **q;

#define FPRINTF_SERVER(ofile, fmt, ...) do { if (!getenv("STARPU_SSILENT") && rank == 0) { fprintf(ofile, fmt, ## __VA_ARGS__); } } while(0)

#include "../../../examples/cg/cg_kernels.c"

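/* Map block (yy, xx) onto the nodes_p x nodes_q process grid, cyclically in
 * both dimensions. For instance, with nodes_p = 2 and nodes_q = 2, block
 * (3, 1) lands on rank (3%2)*2 + (1%2) = 3. */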
static int my_distrib(const int yy, const int xx)
{
    return (yy % nodes_q)*nodes_p + (xx % nodes_p);
}

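/* Copy the vector handles of src into those of dst, block by block, on the
 * rank that owns each block; the copies are submitted asynchronously and
 * ordered with the rest of the task graph by StarPU's data dependencies. */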
static int copy_handle(starpu_data_handle_t *dst, starpu_data_handle_t *src, unsigned nb)
{
    unsigned block;

    for (block = 0; block < nb; block++)
    {
        if (rank == my_distrib(block, 0))
        {
            starpu_data_cpy(dst[block], src[block], /* asynchronous */ 1, /* without callback */ NULL, NULL);
        }
    }

    return 0;
}

/*
 * Generate input data
 */
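/* Each rank allocates only the blocks that my_distrib() assigns to it. The
 * x blocks are additionally allocated on every rank when -display-result is
 * given, so that rank 0 can later fetch and print the full solution. */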
static void generate_random_problem(void)
{
    unsigned ii, jj, j, i;
    int mpi_rank;

    A = malloc(nblocks * sizeof(TYPE **));
    x = malloc(nblocks * sizeof(TYPE *));
    b = malloc(nblocks * sizeof(TYPE *));
    r = malloc(nblocks * sizeof(TYPE *));
    d = malloc(nblocks * sizeof(TYPE *));
    q = malloc(nblocks * sizeof(TYPE *));

    for (j = 0; j < nblocks; j++)
    {
        A[j] = malloc(nblocks * sizeof(TYPE *));

        mpi_rank = my_distrib(j, 0);
        if (mpi_rank == rank || display_result)
        {
            starpu_malloc((void**) &x[j], block_size*sizeof(TYPE));
        }
        if (mpi_rank == rank)
        {
            starpu_malloc((void**) &b[j], block_size*sizeof(TYPE));
            starpu_malloc((void**) &r[j], block_size*sizeof(TYPE));
            starpu_malloc((void**) &d[j], block_size*sizeof(TYPE));
            starpu_malloc((void**) &q[j], block_size*sizeof(TYPE));

            for (jj = 0; jj < block_size; jj++)
            {
                x[j][jj] = (TYPE) 0.0;
                b[j][jj] = (TYPE) 1.0;
                r[j][jj] = (TYPE) 0.0;
                d[j][jj] = (TYPE) 0.0;
                q[j][jj] = (TYPE) 0.0;
            }
        }

        for (i = 0; i < nblocks; i++)
        {
            mpi_rank = my_distrib(j, i);
            if (mpi_rank == rank)
            {
                starpu_malloc((void**) &A[j][i], block_size*block_size*sizeof(TYPE));

                for (ii = 0; ii < block_size; ii++)
                {
                    for (jj = 0; jj < block_size; jj++)
                    {
                        /* We use the Hilbert matrix, which is ill-conditioned but positive definite: H(i,j) = 1/(1+i+j) */
                        A[j][i][jj + ii*block_size] = (TYPE) (1.0/(1.0+(ii+(j*block_size)+jj+(i*block_size))));
                    }
                }
            }
        }
    }
}

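/* Free the local buffers, mirroring the allocation pattern of
 * generate_random_problem(). */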
static void free_data(void)
{
    unsigned j, i;
    int mpi_rank;

    for (j = 0; j < nblocks; j++)
    {
        mpi_rank = my_distrib(j, 0);
        if (mpi_rank == rank || display_result)
        {
            starpu_free_noflag((void*) x[j], block_size*sizeof(TYPE));
        }
        if (mpi_rank == rank)
        {
            starpu_free_noflag((void*) b[j], block_size*sizeof(TYPE));
            starpu_free_noflag((void*) r[j], block_size*sizeof(TYPE));
            starpu_free_noflag((void*) d[j], block_size*sizeof(TYPE));
            starpu_free_noflag((void*) q[j], block_size*sizeof(TYPE));
        }
        for (i = 0; i < nblocks; i++)
        {
            mpi_rank = my_distrib(j, i);
            if (mpi_rank == rank)
            {
                starpu_free_noflag((void*) A[j][i], block_size*block_size*sizeof(TYPE));
            }
        }
        free(A[j]);
    }

    free(A);
    free(x);
    free(b);
    free(r);
    free(d);
    free(q);
}

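/* Register every block with StarPU and with its MPI layer. Blocks owned by
 * this rank are registered with their home buffer in STARPU_MAIN_RAM; remote
 * blocks are registered with home node -1 so StarPU allocates them on demand
 * when they are transferred. Each handle gets a unique MPI tag and its owner
 * rank, and q/r get reduction methods when reductions are enabled. */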
static void register_data(void)
{
    unsigned j, i;
    int mpi_rank;
    starpu_mpi_tag_t mpi_tag = 0;

    A_handle = malloc(nblocks*sizeof(starpu_data_handle_t *));
    x_handle = malloc(nblocks*sizeof(starpu_data_handle_t));
    b_handle = malloc(nblocks*sizeof(starpu_data_handle_t));
    r_handle = malloc(nblocks*sizeof(starpu_data_handle_t));
    d_handle = malloc(nblocks*sizeof(starpu_data_handle_t));
    q_handle = malloc(nblocks*sizeof(starpu_data_handle_t));

    for (j = 0; j < nblocks; j++)
    {
        mpi_rank = my_distrib(j, 0);
        A_handle[j] = malloc(nblocks*sizeof(starpu_data_handle_t));

        if (mpi_rank == rank || display_result)
        {
            starpu_vector_data_register(&x_handle[j], STARPU_MAIN_RAM, (uintptr_t) x[j], block_size, sizeof(TYPE));
        }
        else if (!display_result)
        {
            assert(mpi_rank != rank);
            starpu_vector_data_register(&x_handle[j], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
        }

        if (mpi_rank == rank)
        {
            starpu_vector_data_register(&b_handle[j], STARPU_MAIN_RAM, (uintptr_t) b[j], block_size, sizeof(TYPE));
            starpu_vector_data_register(&r_handle[j], STARPU_MAIN_RAM, (uintptr_t) r[j], block_size, sizeof(TYPE));
            starpu_vector_data_register(&d_handle[j], STARPU_MAIN_RAM, (uintptr_t) d[j], block_size, sizeof(TYPE));
            starpu_vector_data_register(&q_handle[j], STARPU_MAIN_RAM, (uintptr_t) q[j], block_size, sizeof(TYPE));
        }
        else
        {
            starpu_vector_data_register(&b_handle[j], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
            starpu_vector_data_register(&r_handle[j], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
            starpu_vector_data_register(&d_handle[j], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
            starpu_vector_data_register(&q_handle[j], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
        }

        starpu_data_set_coordinates(x_handle[j], 1, j);
        starpu_mpi_data_register(x_handle[j], ++mpi_tag, mpi_rank);
        starpu_data_set_coordinates(b_handle[j], 1, j);
        starpu_mpi_data_register(b_handle[j], ++mpi_tag, mpi_rank);
        starpu_data_set_coordinates(r_handle[j], 1, j);
        starpu_mpi_data_register(r_handle[j], ++mpi_tag, mpi_rank);
        starpu_data_set_coordinates(d_handle[j], 1, j);
        starpu_mpi_data_register(d_handle[j], ++mpi_tag, mpi_rank);
        starpu_data_set_coordinates(q_handle[j], 1, j);
        starpu_mpi_data_register(q_handle[j], ++mpi_tag, mpi_rank);

        if (use_reduction)
        {
            starpu_data_set_reduction_methods(q_handle[j], &accumulate_vector_cl, &bzero_vector_cl);
            starpu_data_set_reduction_methods(r_handle[j], &accumulate_vector_cl, &bzero_vector_cl);
        }

        for (i = 0; i < nblocks; i++)
        {
            mpi_rank = my_distrib(j, i);
            if (mpi_rank == rank)
            {
                starpu_matrix_data_register(&A_handle[j][i], STARPU_MAIN_RAM, (uintptr_t) A[j][i], block_size, block_size, block_size, sizeof(TYPE));
            }
            else
            {
                starpu_matrix_data_register(&A_handle[j][i], -1, (uintptr_t) NULL, block_size, block_size, block_size, sizeof(TYPE));
            }
            starpu_data_set_coordinates(A_handle[j][i], 2, i, j);
            starpu_mpi_data_register(A_handle[j][i], ++mpi_tag, mpi_rank);
        }
    }

    starpu_variable_data_register(&dtq_handle, STARPU_MAIN_RAM, (uintptr_t)&dtq, sizeof(TYPE));
    starpu_variable_data_register(&rtr_handle, STARPU_MAIN_RAM, (uintptr_t)&rtr, sizeof(TYPE));
    starpu_mpi_data_register(rtr_handle, ++mpi_tag, 0);
    starpu_mpi_data_register(dtq_handle, ++mpi_tag, 0);

    if (use_reduction)
    {
        starpu_data_set_reduction_methods(dtq_handle, &accumulate_variable_cl, &bzero_variable_cl);
        starpu_data_set_reduction_methods(rtr_handle, &accumulate_variable_cl, &bzero_variable_cl);
    }
}

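/* Unregister all handles and free the handle arrays. */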
static void unregister_data(void)
{
    unsigned j, i;

    for (j = 0; j < nblocks; j++)
    {
        starpu_data_unregister(x_handle[j]);
        starpu_data_unregister(b_handle[j]);
        starpu_data_unregister(r_handle[j]);
        starpu_data_unregister(d_handle[j]);
        starpu_data_unregister(q_handle[j]);

        for (i = 0; i < nblocks; i++)
        {
            starpu_data_unregister(A_handle[j][i]);
        }
        free(A_handle[j]);
    }
    starpu_data_unregister(dtq_handle);
    starpu_data_unregister(rtr_handle);

    free(A_handle);
    free(x_handle);
    free(b_handle);
    free(r_handle);
    free(d_handle);
    free(q_handle);
}

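/* Bring every block of x to rank 0, then print the whole vector from there. */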
static void display_x_result(void)
{
    unsigned j, i;

    for (j = 0; j < nblocks; j++)
    {
        starpu_mpi_get_data_on_node(MPI_COMM_WORLD, x_handle[j], 0);
    }

    if (rank == 0)
    {
        FPRINTF_SERVER(stderr, "Computed X vector:\n");
        for (j = 0; j < nblocks; j++)
        {
            starpu_data_acquire(x_handle[j], STARPU_R);
            for (i = 0; i < block_size; i++)
            {
                FPRINTF(stderr, "% 02.2e\n", x[j][i]);
            }
            starpu_data_release(x_handle[j]);
        }
    }
}

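/* Handle the options specific to this distributed version (-p, -h); the
 * remaining options are handled by parse_common_args() from the shared CG
 * sources included above. */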
static void parse_args(int argc, char **argv)
{
    int i;
    for (i = 1; i < argc; i++)
    {
        if (strcmp(argv[i], "-p") == 0)
        {
            nodes_p = atoi(argv[++i]);
            continue;
        }

        if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-help") == 0)
        {
            FPRINTF_SERVER(stderr, "usage: %s [-h] [-nblocks #blocks] [-display-result] [-p node_grid_width] [-n problem_size] [-no-reduction] [-maxiter i]\n", argv[0]);
            exit(-1);
        }
    }
    parse_common_args(argc, argv);
}

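/* Initialize StarPU-MPI, check that the process grid and the block size are
 * consistent, set up the distributed problem, run the conjugate gradient
 * iterations (cg() is provided by the shared CG sources included above),
 * optionally display the result, and tear everything down. Returning 77 is
 * the conventional Automake exit code for a skipped test. */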
int main(int argc, char **argv)
{
    int worldsize, ret;
    double start, end;

    /* Not supported yet */
    if (starpu_get_env_number_default("STARPU_GLOBAL_ARBITER", 0) > 0)
        return 77;

    ret = starpu_mpi_init_conf(&argc, &argv, 1, MPI_COMM_WORLD, NULL);
    if (ret == -ENODEV)
        return 77;
    STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init_conf");

    starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
    starpu_mpi_comm_size(MPI_COMM_WORLD, &worldsize);

    parse_args(argc, argv);

    if (worldsize % nodes_p != 0)
    {
        FPRINTF_SERVER(stderr, "The node grid width (%d) must divide the number of nodes (%d).\n", nodes_p, worldsize);
        starpu_mpi_shutdown();
        return 1;
    }
    nodes_q = worldsize / nodes_p;

    if (n % nblocks != 0)
    {
        FPRINTF_SERVER(stderr, "The number of blocks (%u) must divide the matrix size (%lld).\n", nblocks, n);
        starpu_mpi_shutdown();
        return 1;
    }
    block_size = n / nblocks;

    starpu_cublas_init();

    FPRINTF_SERVER(stderr, "************** PARAMETERS ***************\n");
    FPRINTF_SERVER(stderr, "%d nodes (%dx%d)\n", worldsize, nodes_p, nodes_q);
    FPRINTF_SERVER(stderr, "Problem size (-n): %lld\n", n);
    FPRINTF_SERVER(stderr, "Maximum number of iterations (-maxiter): %d\n", i_max);
    FPRINTF_SERVER(stderr, "Number of blocks (-nblocks): %u\n", nblocks);
    FPRINTF_SERVER(stderr, "Reduction (-no-reduction): %s\n", use_reduction ? "enabled" : "disabled");

    starpu_mpi_barrier(MPI_COMM_WORLD);
    start = starpu_timing_now();
    generate_random_problem();
    register_data();
    starpu_mpi_barrier(MPI_COMM_WORLD);
    end = starpu_timing_now();

    FPRINTF_SERVER(stderr, "Problem initialization timing: %2.2f seconds\n", (end-start)/1e6);

    ret = cg();
    if (ret == -ENODEV)
    {
        ret = 77;
        goto enodev;
    }

    starpu_task_wait_for_all();

    if (display_result)
    {
        display_x_result();
    }

enodev:
    unregister_data();
    free_data();

    starpu_cublas_shutdown();
    starpu_mpi_shutdown();

    return ret;
}