cg.c

/* StarPU --- Runtime system for heterogeneous multicore architectures.
 *
 * Copyright (C) 2021 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
 *
 * StarPU is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2.1 of the License, or (at
 * your option) any later version.
 *
 * StarPU is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 *
 * See the GNU Lesser General Public License in COPYING.LGPL for more details.
 */

#include <math.h>
#include <assert.h>
#include <starpu.h>
#include <starpu_mpi.h>
#include <common/blas.h>

/*
 * Distributed version of the Conjugate Gradient solver implemented in
 * examples/cg/cg.c
 *
 * Use the -display-result option and compare with the non-distributed
 * version: the x vector should be the same.
 */
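
/*
 * A typical run, as a sketch only: the exact launcher invocation depends on
 * the local MPI installation, and "./cg" stands for whatever name the built
 * binary actually has. For example:
 *
 *   mpirun -np 4 ./cg -p 2 -n 16384 -nblocks 8 -display-result
 *
 * uses a 2x2 grid of MPI processes: -p sets the grid width, and the height is
 * deduced from the number of processes in main().
 */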

#include "../../../examples/cg/cg.h"

static int copy_handle(starpu_data_handle_t* dst, starpu_data_handle_t* src, unsigned nblocks);

#define HANDLE_TYPE_VECTOR starpu_data_handle_t*
#define HANDLE_TYPE_MATRIX starpu_data_handle_t**
#define TASK_INSERT(cl, ...) starpu_mpi_task_insert(MPI_COMM_WORLD, cl, ##__VA_ARGS__)
#define GET_VECTOR_BLOCK(v, i) v[i]
#define GET_MATRIX_BLOCK(m, i, j) m[i][j]
#define BARRIER() starpu_mpi_barrier(MPI_COMM_WORLD);
#define GET_DATA_HANDLE(handle) starpu_mpi_get_data_on_all_nodes_detached(MPI_COMM_WORLD, handle)

static int block_size;
static int rank;
static int nodes_p = 2;
static int nodes_q;

static TYPE ***A;
static TYPE **x;
static TYPE **b;
static TYPE **r;
static TYPE **d;
static TYPE **q;

#define FPRINTF_SERVER(ofile, fmt, ...) do { if (!getenv("STARPU_SSILENT") && rank == 0) {fprintf(ofile, fmt, ## __VA_ARGS__); }} while(0)

#include "../../../examples/cg/cg_kernels.c"
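
/*
 * 2D block distribution: block (y, x) is owned by MPI rank
 * (y % nodes_q) * nodes_p + (x % nodes_p), i.e. blocks are spread cyclically
 * over a nodes_p x nodes_q grid of MPI processes (nodes_p is set with -p,
 * nodes_q is deduced from the world size in main()).
 */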
static int my_distrib(const int y, const int x)
{
	return (y%nodes_q)*nodes_p + (x%nodes_p);
}
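
/*
 * Copy the nblocks blocks of src into dst; each rank only triggers the
 * (asynchronous) copy for the vector blocks it owns.
 */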
static int copy_handle(starpu_data_handle_t* dst, starpu_data_handle_t* src, unsigned nblocks)
{
	unsigned b;

	for (b = 0; b < nblocks; b++)
	{
		if (rank == my_distrib(b, 0))
		{
			starpu_data_cpy(dst[b], src[b], /* asynchronous */ 1, /* without callback */ NULL, NULL);
		}
	}

	return 0;
}

/*
 * Generate input data
 */
static void generate_random_problem(void)
{
	unsigned nn, mm, m, n, mpi_rank;

	A = malloc(nblocks * sizeof(TYPE **));
	x = malloc(nblocks * sizeof(TYPE *));
	b = malloc(nblocks * sizeof(TYPE *));
	r = malloc(nblocks * sizeof(TYPE *));
	d = malloc(nblocks * sizeof(TYPE *));
	q = malloc(nblocks * sizeof(TYPE *));

	for (m = 0; m < nblocks; m++)
	{
		A[m] = malloc(nblocks * sizeof(TYPE*));

		mpi_rank = my_distrib(m, 0);
		if (mpi_rank == rank || display_result)
		{
			starpu_malloc((void**) &x[m], block_size*sizeof(TYPE));
		}
		if (mpi_rank == rank)
		{
			starpu_malloc((void**) &b[m], block_size*sizeof(TYPE));
			starpu_malloc((void**) &r[m], block_size*sizeof(TYPE));
			starpu_malloc((void**) &d[m], block_size*sizeof(TYPE));
			starpu_malloc((void**) &q[m], block_size*sizeof(TYPE));

			for (mm = 0; mm < block_size; mm++)
			{
				x[m][mm] = (TYPE) 0.0;
				b[m][mm] = (TYPE) 1.0;
				r[m][mm] = (TYPE) 0.0;
				d[m][mm] = (TYPE) 0.0;
				q[m][mm] = (TYPE) 0.0;
			}
		}

		for (n = 0; n < nblocks; n++)
		{
			mpi_rank = my_distrib(m, n);
			if (mpi_rank == rank)
			{
				starpu_malloc((void**) &A[m][n], block_size*block_size*sizeof(TYPE));

				for (nn = 0; nn < block_size; nn++)
				{
					for (mm = 0; mm < block_size; mm++)
					{
						/* We use the Hilbert matrix, which is not well conditioned but is positive definite: H(i,j) = 1/(1+i+j) */
						A[m][n][mm + nn*block_size] = (TYPE) (1.0/(1.0+(nn+(m*block_size)+mm+(n*block_size))));
					}
				}
			}
		}
	}
}

static void free_data(void)
{
	unsigned nn, mm, m, n, mpi_rank;

	for (m = 0; m < nblocks; m++)
	{
		mpi_rank = my_distrib(m, 0);
		if (mpi_rank == rank || display_result)
		{
			starpu_free((void*) x[m]);
		}
		if (mpi_rank == rank)
		{
			starpu_free((void*) b[m]);
			starpu_free((void*) r[m]);
			starpu_free((void*) d[m]);
			starpu_free((void*) q[m]);
		}

		for (n = 0; n < nblocks; n++)
		{
			mpi_rank = my_distrib(m, n);
			if (mpi_rank == rank)
			{
				starpu_free((void*) A[m][n]);
			}
		}

		free(A[m]);
	}

	free(A);
	free(x);
	free(b);
	free(r);
	free(d);
	free(q);
}
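
/*
 * Register every block with StarPU and StarPU-MPI. Blocks owned by this rank
 * are registered from their home buffer in main memory, while remote blocks
 * are registered with home node -1 (no local copy; StarPU allocates a buffer
 * on demand when the data is transferred here). Each handle then gets a
 * unique MPI tag and its owner rank through starpu_mpi_data_register().
 */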
static void register_data(void)
{
	unsigned m, n;
	int mpi_rank;
	starpu_mpi_tag_t mpi_tag = 0;

	A_handle = malloc(nblocks*sizeof(starpu_data_handle_t*));
	x_handle = malloc(nblocks*sizeof(starpu_data_handle_t));
	b_handle = malloc(nblocks*sizeof(starpu_data_handle_t));
	r_handle = malloc(nblocks*sizeof(starpu_data_handle_t));
	d_handle = malloc(nblocks*sizeof(starpu_data_handle_t));
	q_handle = malloc(nblocks*sizeof(starpu_data_handle_t));

	for (m = 0; m < nblocks; m++)
	{
		mpi_rank = my_distrib(m, 0);
		A_handle[m] = malloc(nblocks*sizeof(starpu_data_handle_t));

		if (mpi_rank == rank || display_result)
		{
			starpu_vector_data_register(&x_handle[m], STARPU_MAIN_RAM, (uintptr_t) x[m], block_size, sizeof(TYPE));
		}
		else if (!display_result)
		{
			assert(mpi_rank != rank);
			starpu_vector_data_register(&x_handle[m], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
		}

		if (mpi_rank == rank)
		{
			starpu_vector_data_register(&b_handle[m], STARPU_MAIN_RAM, (uintptr_t) b[m], block_size, sizeof(TYPE));
			starpu_vector_data_register(&r_handle[m], STARPU_MAIN_RAM, (uintptr_t) r[m], block_size, sizeof(TYPE));
			starpu_vector_data_register(&d_handle[m], STARPU_MAIN_RAM, (uintptr_t) d[m], block_size, sizeof(TYPE));
			starpu_vector_data_register(&q_handle[m], STARPU_MAIN_RAM, (uintptr_t) q[m], block_size, sizeof(TYPE));
		}
		else
		{
			starpu_vector_data_register(&b_handle[m], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
			starpu_vector_data_register(&r_handle[m], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
			starpu_vector_data_register(&d_handle[m], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
			starpu_vector_data_register(&q_handle[m], -1, (uintptr_t) NULL, block_size, sizeof(TYPE));
		}

		starpu_data_set_coordinates(x_handle[m], 1, m);
		starpu_mpi_data_register(x_handle[m], ++mpi_tag, mpi_rank);
		starpu_data_set_coordinates(b_handle[m], 1, m);
		starpu_mpi_data_register(b_handle[m], ++mpi_tag, mpi_rank);
		starpu_data_set_coordinates(r_handle[m], 1, m);
		starpu_mpi_data_register(r_handle[m], ++mpi_tag, mpi_rank);
		starpu_data_set_coordinates(d_handle[m], 1, m);
		starpu_mpi_data_register(d_handle[m], ++mpi_tag, mpi_rank);
		starpu_data_set_coordinates(q_handle[m], 1, m);
		starpu_mpi_data_register(q_handle[m], ++mpi_tag, mpi_rank);

		if (use_reduction)
		{
			starpu_data_set_reduction_methods(q_handle[m], &accumulate_vector_cl, &bzero_vector_cl);
			starpu_data_set_reduction_methods(r_handle[m], &accumulate_vector_cl, &bzero_vector_cl);
		}

		for (n = 0; n < nblocks; n++)
		{
			mpi_rank = my_distrib(m, n);
			if (mpi_rank == rank)
			{
				starpu_matrix_data_register(&A_handle[m][n], STARPU_MAIN_RAM, (uintptr_t) A[m][n], block_size, block_size, block_size, sizeof(TYPE));
			}
			else
			{
				starpu_matrix_data_register(&A_handle[m][n], -1, (uintptr_t) NULL, block_size, block_size, block_size, sizeof(TYPE));
			}
			starpu_data_set_coordinates(A_handle[m][n], 2, n, m);
			starpu_mpi_data_register(A_handle[m][n], ++mpi_tag, mpi_rank);
		}
	}

	starpu_variable_data_register(&dtq_handle, STARPU_MAIN_RAM, (uintptr_t)&dtq, sizeof(TYPE));
	starpu_variable_data_register(&rtr_handle, STARPU_MAIN_RAM, (uintptr_t)&rtr, sizeof(TYPE));
	starpu_mpi_data_register(rtr_handle, ++mpi_tag, 0);
	starpu_mpi_data_register(dtq_handle, ++mpi_tag, 0);

	if (use_reduction)
	{
		starpu_data_set_reduction_methods(dtq_handle, &accumulate_variable_cl, &bzero_variable_cl);
		starpu_data_set_reduction_methods(rtr_handle, &accumulate_variable_cl, &bzero_variable_cl);
	}
}

static void unregister_data(void)
{
	unsigned m, n;

	for (m = 0; m < nblocks; m++)
	{
		starpu_data_unregister(x_handle[m]);
		starpu_data_unregister(b_handle[m]);
		starpu_data_unregister(r_handle[m]);
		starpu_data_unregister(d_handle[m]);
		starpu_data_unregister(q_handle[m]);

		for (n = 0; n < nblocks; n++)
		{
			starpu_data_unregister(A_handle[m][n]);
		}

		free(A_handle[m]);
	}

	starpu_data_unregister(dtq_handle);
	starpu_data_unregister(rtr_handle);

	free(A_handle);
	free(x_handle);
	free(b_handle);
	free(r_handle);
	free(d_handle);
	free(q_handle);
}

static void display_x_result(void)
{
	int j, i;

	for (j = 0; j < nblocks; j++)
	{
		starpu_mpi_get_data_on_node(MPI_COMM_WORLD, x_handle[j], 0);
	}

	if (rank == 0)
	{
		FPRINTF_SERVER(stderr, "Computed X vector:\n");
		for (j = 0; j < nblocks; j++)
		{
			starpu_data_acquire(x_handle[j], STARPU_R);
			for (i = 0; i < block_size; i++)
			{
				FPRINTF(stderr, "% 02.2e\n", x[j][i]);
			}
			starpu_data_release(x_handle[j]);
		}
	}
}

static void parse_args(int argc, char **argv)
{
	int i;

	for (i = 1; i < argc; i++)
	{
		if (strcmp(argv[i], "-p") == 0)
		{
			nodes_p = atoi(argv[++i]);
			continue;
		}

		if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-help") == 0)
		{
			FPRINTF_SERVER(stderr, "usage: %s [-h] [-nblocks #blocks] [-display-result] [-p node_grid_width] [-n problem_size] [-no-reduction] [-maxiter i]\n", argv[0]);
			exit(-1);
		}
	}

	parse_common_args(argc, argv);
}
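
/*
 * Driver: initialize StarPU-MPI, set up the nodes_p x nodes_q process grid,
 * generate and register the distributed problem, run the CG solver (cg()
 * comes from the included cg_kernels.c), optionally gather and display the
 * x vector, then clean up.
 */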
int main(int argc, char **argv)
{
	int worldsize, ret;
	double start, end;

	/* Not supported yet */
	if (starpu_get_env_number_default("STARPU_GLOBAL_ARBITER", 0) > 0)
		return 77;

	ret = starpu_mpi_init_conf(&argc, &argv, 1, MPI_COMM_WORLD, NULL);
	if (ret == -ENODEV)
		return 77;
	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init_conf");

	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
	starpu_mpi_comm_size(MPI_COMM_WORLD, &worldsize);

	parse_args(argc, argv);

	if (worldsize % nodes_p != 0)
	{
		FPRINTF_SERVER(stderr, "Node grid width must divide the number of nodes.\n");
		starpu_mpi_shutdown();
		return 1;
	}
	nodes_q = worldsize / nodes_p;

	if (n % nblocks != 0)
	{
		FPRINTF_SERVER(stderr, "The number of blocks must divide the matrix size.\n");
		starpu_mpi_shutdown();
		return 1;
	}
	block_size = n / nblocks;

	starpu_cublas_init();

	FPRINTF_SERVER(stderr, "************** PARAMETERS ***************\n");
	FPRINTF_SERVER(stderr, "%d nodes (%dx%d)\n", worldsize, nodes_p, nodes_q);
	FPRINTF_SERVER(stderr, "Problem size (-n): %lld\n", n);
	FPRINTF_SERVER(stderr, "Maximum number of iterations (-maxiter): %d\n", i_max);
	FPRINTF_SERVER(stderr, "Number of blocks (-nblocks): %d\n", nblocks);
	FPRINTF_SERVER(stderr, "Reduction (-no-reduction): %s\n", use_reduction ? "enabled" : "disabled");

	starpu_mpi_barrier(MPI_COMM_WORLD);
	start = starpu_timing_now();
	generate_random_problem();
	register_data();
	starpu_mpi_barrier(MPI_COMM_WORLD);
	end = starpu_timing_now();
	/* starpu_timing_now() returns microseconds, hence the division by 1e6 */
	FPRINTF_SERVER(stderr, "Problem initialization timing : %2.2f seconds\n", (end-start)/1e6);

	ret = cg();
	if (ret == -ENODEV)
	{
		ret = 77;
		goto enodev;
	}

	starpu_task_wait_for_all();

	if (display_result)
	{
		display_x_result();
	}

enodev:
	unregister_data();
	free_data();

	starpu_cublas_shutdown();
	starpu_mpi_shutdown();

	return ret;
}