datatypes.c 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2013, 2014, 2015, 2016, 2017 CNRS
  4. *
  5. * StarPU is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * StarPU is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu_mpi.h>
  17. #include <stdlib.h>
  18. #include "helper.h"
  19. typedef void (*check_func)(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error);
  20. void check_void(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  21. {
  22. FPRINTF_MPI(stderr, "Success with void value\n");
  23. }
  24. void check_variable(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  25. {
  26. float *v_s, *v_r;
  27. STARPU_ASSERT(starpu_variable_get_elemsize(handle_s) == starpu_variable_get_elemsize(handle_r));
  28. v_s = (float *)starpu_variable_get_local_ptr(handle_s);
  29. v_r = (float *)starpu_variable_get_local_ptr(handle_r);
  30. if (*v_s == *v_r)
  31. {
  32. FPRINTF_MPI(stderr, "Success with variable value: %f == %f\n", *v_s, *v_r);
  33. }
  34. else
  35. {
  36. *error = 1;
  37. FPRINTF_MPI(stderr, "Error with variable value: %f != %f\n", *v_s, *v_r);
  38. }
  39. }
  40. void check_vector(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  41. {
  42. int i;
  43. int nx;
  44. int *v_r, *v_s;
  45. STARPU_ASSERT(starpu_vector_get_elemsize(handle_s) == starpu_vector_get_elemsize(handle_r));
  46. STARPU_ASSERT(starpu_vector_get_nx(handle_s) == starpu_vector_get_nx(handle_r));
  47. nx = starpu_vector_get_nx(handle_r);
  48. v_r = (int *)starpu_vector_get_local_ptr(handle_r);
  49. v_s = (int *)starpu_vector_get_local_ptr(handle_s);
  50. for(i=0 ; i<nx ; i++)
  51. {
  52. if (v_s[i] == v_r[i])
  53. {
  54. FPRINTF_MPI(stderr, "Success with vector[%d] value: %d == %d\n", i, v_s[i], v_r[i]);
  55. }
  56. else
  57. {
  58. *error = 1;
  59. FPRINTF_MPI(stderr, "Error with vector[%d] value: %d != %d\n", i, v_s[i], v_r[i]);
  60. }
  61. }
  62. }
  63. void check_matrix(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  64. {
  65. STARPU_ASSERT(starpu_matrix_get_elemsize(handle_s) == starpu_matrix_get_elemsize(handle_r));
  66. STARPU_ASSERT(starpu_matrix_get_nx(handle_s) == starpu_matrix_get_nx(handle_r));
  67. STARPU_ASSERT(starpu_matrix_get_ny(handle_s) == starpu_matrix_get_ny(handle_r));
  68. STARPU_ASSERT(starpu_matrix_get_local_ld(handle_s) == starpu_matrix_get_local_ld(handle_r));
  69. char *matrix_s = (char *)starpu_matrix_get_local_ptr(handle_s);
  70. char *matrix_r = (char *)starpu_matrix_get_local_ptr(handle_r);
  71. int nx = starpu_matrix_get_nx(handle_s);
  72. int ny = starpu_matrix_get_ny(handle_s);
  73. int ldy = starpu_matrix_get_local_ld(handle_s);
  74. int x, y;
  75. for(y=0 ; y<ny ; y++)
  76. {
  77. for(x=0 ; x<nx ; x++)
  78. {
  79. int index=(y*ldy)+x;
  80. if (matrix_s[index] == matrix_r[index])
  81. {
  82. FPRINTF_MPI(stderr, "Success with matrix[%d,%d --> %d] value: %c == %c\n", x, y, index, matrix_s[index], matrix_r[index]);
  83. }
  84. else
  85. {
  86. *error = 1;
  87. FPRINTF_MPI(stderr, "Error with matrix[%d,%d --> %d] value: %c != %c\n", x, y, index, matrix_s[index], matrix_r[index]);
  88. }
  89. }
  90. }
  91. }
  92. void check_block(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  93. {
  94. STARPU_ASSERT(starpu_block_get_elemsize(handle_s) == starpu_block_get_elemsize(handle_r));
  95. STARPU_ASSERT(starpu_block_get_nx(handle_s) == starpu_block_get_nx(handle_r));
  96. STARPU_ASSERT(starpu_block_get_ny(handle_s) == starpu_block_get_ny(handle_r));
  97. STARPU_ASSERT(starpu_block_get_nz(handle_s) == starpu_block_get_nz(handle_r));
  98. STARPU_ASSERT(starpu_block_get_local_ldy(handle_s) == starpu_block_get_local_ldy(handle_r));
  99. STARPU_ASSERT(starpu_block_get_local_ldz(handle_s) == starpu_block_get_local_ldz(handle_r));
  100. starpu_data_acquire(handle_s, STARPU_R);
  101. starpu_data_acquire(handle_r, STARPU_R);
  102. float *block_s = (float *)starpu_block_get_local_ptr(handle_s);
  103. float *block_r = (float *)starpu_block_get_local_ptr(handle_r);
  104. int nx = starpu_block_get_nx(handle_s);
  105. int ny = starpu_block_get_ny(handle_s);
  106. int nz = starpu_block_get_nz(handle_s);
  107. int ldy = starpu_block_get_local_ldy(handle_s);
  108. int ldz = starpu_block_get_local_ldz(handle_s);
  109. int x, y, z;
  110. for(z=0 ; z<nz ; z++)
  111. {
  112. for(y=0 ; y<ny ; y++)
  113. for(x=0 ; x<nx ; x++)
  114. {
  115. int index=(z*ldz)+(y*ldy)+x;
  116. if (block_s[index] == block_r[index])
  117. {
  118. FPRINTF_MPI(stderr, "Success with block[%d,%d,%d --> %d] value: %f == %f\n", x, y, z, index, block_s[index], block_r[index]);
  119. }
  120. else
  121. {
  122. *error = 1;
  123. FPRINTF_MPI(stderr, "Error with block[%d,%d,%d --> %d] value: %f != %f\n", x, y, z, index, block_s[index], block_r[index]);
  124. }
  125. }
  126. }
  127. starpu_data_release(handle_s);
  128. starpu_data_release(handle_r);
  129. }
  130. void check_bcsr(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  131. {
  132. STARPU_ASSERT(starpu_bcsr_get_elemsize(handle_s) == starpu_bcsr_get_elemsize(handle_r));
  133. STARPU_ASSERT(starpu_bcsr_get_nnz(handle_s) == starpu_bcsr_get_nnz(handle_r));
  134. STARPU_ASSERT(starpu_bcsr_get_nrow(handle_s) == starpu_bcsr_get_nrow(handle_r));
  135. STARPU_ASSERT(starpu_bcsr_get_firstentry(handle_s) == starpu_bcsr_get_firstentry(handle_r));
  136. STARPU_ASSERT(starpu_bcsr_get_r(handle_s) == starpu_bcsr_get_r(handle_r));
  137. STARPU_ASSERT(starpu_bcsr_get_c(handle_s) == starpu_bcsr_get_c(handle_r));
  138. starpu_data_acquire(handle_s, STARPU_R);
  139. starpu_data_acquire(handle_r, STARPU_R);
  140. uint32_t *colind_s = starpu_bcsr_get_local_colind(handle_s);
  141. uint32_t *colind_r = starpu_bcsr_get_local_colind(handle_r);
  142. uint32_t *rowptr_s = starpu_bcsr_get_local_rowptr(handle_s);
  143. uint32_t *rowptr_r = starpu_bcsr_get_local_rowptr(handle_r);
  144. int *bcsr_s = (int *)starpu_bcsr_get_local_nzval(handle_s);
  145. int *bcsr_r = (int *)starpu_bcsr_get_local_nzval(handle_r);
  146. int r = starpu_bcsr_get_r(handle_s);
  147. int c = starpu_bcsr_get_c(handle_s);
  148. int nnz = starpu_bcsr_get_nnz(handle_s);
  149. int nrows = starpu_bcsr_get_nrow(handle_s);
  150. int x;
  151. for(x=0 ; x<nnz ; x++)
  152. {
  153. if (colind_s[x] == colind_r[x])
  154. {
  155. FPRINTF_MPI(stderr, "Success with colind[%d] value: %u == %u\n", x, colind_s[x], colind_r[x]);
  156. }
  157. else
  158. {
  159. *error = 1;
  160. FPRINTF_MPI(stderr, "Error with colind[%d] value: %u != %u\n", x, colind_s[x], colind_r[x]);
  161. }
  162. }
  163. for(x=0 ; x<nrows+1 ; x++)
  164. {
  165. if (rowptr_s[x] == rowptr_r[x])
  166. {
  167. FPRINTF_MPI(stderr, "Success with rowptr[%d] value: %u == %u\n", x, rowptr_s[x], rowptr_r[x]);
  168. }
  169. else
  170. {
  171. *error = 1;
  172. FPRINTF_MPI(stderr, "Error with rowptr[%d] value: %u != %u\n", x, rowptr_s[x], rowptr_r[x]);
  173. }
  174. }
  175. for(x=0 ; x<r*c*nnz ; x++)
  176. {
  177. if (bcsr_s[x] == bcsr_r[x])
  178. {
  179. FPRINTF_MPI(stderr, "Success with bcsr[%d] value: %d == %d\n", x, bcsr_s[x], bcsr_r[x]);
  180. }
  181. else
  182. {
  183. *error = 1;
  184. FPRINTF_MPI(stderr, "Error with bcsr[%d] value: %d != %d\n", x, bcsr_s[x], bcsr_r[x]);
  185. }
  186. }
  187. starpu_data_release(handle_s);
  188. starpu_data_release(handle_r);
  189. }
  190. void send_recv_and_check(int rank, int node, starpu_data_handle_t handle_s, int tag_s, starpu_data_handle_t handle_r, int tag_r, int *error, check_func func)
  191. {
  192. int ret;
  193. MPI_Status status;
  194. if (rank == 0)
  195. {
  196. ret = starpu_mpi_send(handle_s, node, tag_s, MPI_COMM_WORLD);
  197. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_send");
  198. ret = starpu_mpi_recv(handle_r, node, tag_r, MPI_COMM_WORLD, &status);
  199. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_recv");
  200. assert(func);
  201. func(handle_s, handle_r, error);
  202. }
  203. else if (rank == 1)
  204. {
  205. ret = starpu_mpi_recv(handle_s, node, tag_s, MPI_COMM_WORLD, &status);
  206. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_recv");
  207. ret = starpu_mpi_send(handle_s, node, tag_r, MPI_COMM_WORLD);
  208. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_send");
  209. }
  210. }
  211. void exchange_void(int rank, int *error)
  212. {
  213. STARPU_SKIP_IF_VALGRIND;
  214. if (rank == 0)
  215. {
  216. starpu_data_handle_t void_handle[2];
  217. starpu_void_data_register(&void_handle[0]);
  218. starpu_void_data_register(&void_handle[1]);
  219. send_recv_and_check(rank, 1, void_handle[0], 0x42, void_handle[1], 0x1337, error, check_void);
  220. starpu_data_unregister(void_handle[0]);
  221. starpu_data_unregister(void_handle[1]);
  222. }
  223. else if (rank == 1)
  224. {
  225. starpu_data_handle_t void_handle;
  226. starpu_void_data_register(&void_handle);
  227. send_recv_and_check(rank, 0, void_handle, 0x42, NULL, 0x1337, NULL, NULL);
  228. starpu_data_unregister(void_handle);
  229. }
  230. }
  231. void exchange_variable(int rank, int *error)
  232. {
  233. if (rank == 0)
  234. {
  235. float v = 42.12;
  236. starpu_data_handle_t variable_handle[2];
  237. starpu_variable_data_register(&variable_handle[0], STARPU_MAIN_RAM, (uintptr_t)&v, sizeof(v));
  238. starpu_variable_data_register(&variable_handle[1], -1, (uintptr_t)NULL, sizeof(v));
  239. send_recv_and_check(rank, 1, variable_handle[0], 0x42, variable_handle[1], 0x1337, error, check_variable);
  240. starpu_data_unregister(variable_handle[0]);
  241. starpu_data_unregister(variable_handle[1]);
  242. }
  243. else if (rank == 1)
  244. {
  245. starpu_data_handle_t variable_handle;
  246. starpu_variable_data_register(&variable_handle, -1, (uintptr_t)NULL, sizeof(float));
  247. send_recv_and_check(rank, 0, variable_handle, 0x42, NULL, 0x1337, NULL, NULL);
  248. starpu_data_unregister(variable_handle);
  249. }
  250. }
  251. void exchange_vector(int rank, int *error)
  252. {
  253. if (rank == 0)
  254. {
  255. int vector[4] = {1, 2, 3, 4};
  256. starpu_data_handle_t vector_handle[2];
  257. starpu_vector_data_register(&vector_handle[0], STARPU_MAIN_RAM, (uintptr_t)vector, 4, sizeof(vector[0]));
  258. starpu_vector_data_register(&vector_handle[1], -1, (uintptr_t)NULL, 4, sizeof(vector[0]));
  259. send_recv_and_check(rank, 1, vector_handle[0], 0x43, vector_handle[1], 0x2337, error, check_vector);
  260. starpu_data_unregister(vector_handle[0]);
  261. starpu_data_unregister(vector_handle[1]);
  262. }
  263. else if (rank == 1)
  264. {
  265. starpu_data_handle_t vector_handle;
  266. starpu_vector_data_register(&vector_handle, -1, (uintptr_t)NULL, 4, sizeof(int));
  267. send_recv_and_check(rank, 0, vector_handle, 0x43, NULL, 0x2337, NULL, NULL);
  268. starpu_data_unregister(vector_handle);
  269. }
  270. }
  271. void exchange_matrix(int rank, int *error)
  272. {
  273. int nx=3;
  274. int ny=2;
  275. if (rank == 0)
  276. {
  277. char *matrix, n='a';
  278. int x, y;
  279. starpu_data_handle_t matrix_handle[2];
  280. matrix = (char*)malloc(nx*ny*sizeof(char));
  281. assert(matrix);
  282. for(y=0 ; y<ny ; y++)
  283. {
  284. for(x=0 ; x<nx ; x++)
  285. {
  286. matrix[(y*nx)+x] = n++;
  287. }
  288. }
  289. starpu_matrix_data_register(&matrix_handle[0], STARPU_MAIN_RAM, (uintptr_t)matrix, nx, nx, ny, sizeof(char));
  290. starpu_matrix_data_register(&matrix_handle[1], -1, (uintptr_t)NULL, nx, nx, ny, sizeof(char));
  291. send_recv_and_check(rank, 1, matrix_handle[0], 0x75, matrix_handle[1], 0x8555, error, check_matrix);
  292. starpu_data_unregister(matrix_handle[0]);
  293. starpu_data_unregister(matrix_handle[1]);
  294. free(matrix);
  295. }
  296. else if (rank == 1)
  297. {
  298. starpu_data_handle_t matrix_handle;
  299. starpu_matrix_data_register(&matrix_handle, -1, (uintptr_t)NULL, nx, nx, ny, sizeof(char));
  300. send_recv_and_check(rank, 0, matrix_handle, 0x75, NULL, 0x8555, NULL, NULL);
  301. starpu_data_unregister(matrix_handle);
  302. }
  303. }
  304. void exchange_block(int rank, int *error)
  305. {
  306. int nx=3;
  307. int ny=2;
  308. int nz=4;
  309. if (rank == 0)
  310. {
  311. float *block, n=1.0;
  312. int x, y, z;
  313. starpu_data_handle_t block_handle[2];
  314. block = (float*)malloc(nx*ny*nz*sizeof(float));
  315. assert(block);
  316. for(z=0 ; z<nz ; z++)
  317. {
  318. for(y=0 ; y<ny ; y++)
  319. {
  320. for(x=0 ; x<nx ; x++)
  321. {
  322. block[(z*nx*ny)+(y*nx)+x] = n++;
  323. }
  324. }
  325. }
  326. starpu_block_data_register(&block_handle[0], STARPU_MAIN_RAM, (uintptr_t)block, nx, nx*ny, nx, ny, nz, sizeof(float));
  327. starpu_block_data_register(&block_handle[1], -1, (uintptr_t)NULL, nx, nx*ny, nx, ny, nz, sizeof(float));
  328. send_recv_and_check(rank, 1, block_handle[0], 0x73, block_handle[1], 0x8337, error, check_block);
  329. starpu_data_unregister(block_handle[0]);
  330. starpu_data_unregister(block_handle[1]);
  331. free(block);
  332. }
  333. else if (rank == 1)
  334. {
  335. starpu_data_handle_t block_handle;
  336. starpu_block_data_register(&block_handle, -1, (uintptr_t)NULL, nx, nx*ny, nx, ny, nz, sizeof(float));
  337. send_recv_and_check(rank, 0, block_handle, 0x73, NULL, 0x8337, NULL, NULL);
  338. starpu_data_unregister(block_handle);
  339. }
  340. }
  341. void exchange_bcsr(int rank, int *error)
  342. {
  343. /*
  344. * We use the following matrix:
  345. *
  346. * +----------------+
  347. * | 0 1 0 0 |
  348. * | 2 3 0 0 |
  349. * | 4 5 8 9 |
  350. * | 6 7 10 11 |
  351. * +----------------+
  352. *
  353. * nzval = [0, 1, 2, 3] ++ [4, 5, 6, 7] ++ [8, 9, 10, 11]
  354. * colind = [0, 0, 1]
  355. * rowptr = [0, 1, 3]
  356. * r = c = 2
  357. */
  358. /* Size of the blocks */
  359. #define BCSR_R 2
  360. #define BCSR_C 2
  361. #define BCSR_NROWS 2
  362. #define BCSR_NNZ_BLOCKS 3 /* out of 4 */
  363. #define BCSR_NZVAL_SIZE (BCSR_R*BCSR_C*BCSR_NNZ_BLOCKS)
  364. if (rank == 0)
  365. {
  366. starpu_data_handle_t bcsr_handle[2];
  367. uint32_t colind[BCSR_NNZ_BLOCKS] = {0, 0, 1};
  368. uint32_t rowptr[BCSR_NROWS+1] = {0, 1, BCSR_NNZ_BLOCKS};
  369. int nzval[BCSR_NZVAL_SIZE] =
  370. {
  371. 0, 1, 2, 3, /* First block */
  372. 4, 5, 6, 7, /* Second block */
  373. 8, 9, 10, 11 /* Third block */
  374. };
  375. starpu_bcsr_data_register(&bcsr_handle[0], STARPU_MAIN_RAM, BCSR_NNZ_BLOCKS, BCSR_NROWS, (uintptr_t) nzval, colind, rowptr, 0, BCSR_R, BCSR_C, sizeof(nzval[0]));
  376. starpu_bcsr_data_register(&bcsr_handle[1], -1, BCSR_NNZ_BLOCKS, BCSR_NROWS, (uintptr_t) NULL, (uint32_t *) NULL, (uint32_t *) NULL, 0, BCSR_R, BCSR_C, sizeof(nzval[0]));
  377. send_recv_and_check(rank, 1, bcsr_handle[0], 0x73, bcsr_handle[1], 0x8337, error, check_bcsr);
  378. starpu_data_unregister(bcsr_handle[0]);
  379. starpu_data_unregister(bcsr_handle[1]);
  380. }
  381. else if (rank == 1)
  382. {
  383. starpu_data_handle_t bcsr_handle;
  384. starpu_bcsr_data_register(&bcsr_handle, -1, BCSR_NNZ_BLOCKS, BCSR_NROWS, (uintptr_t) NULL, (uint32_t *) NULL, (uint32_t *) NULL, 0, BCSR_R, BCSR_C, sizeof(int));
  385. send_recv_and_check(rank, 0, bcsr_handle, 0x73, NULL, 0x8337, NULL, NULL);
  386. starpu_data_unregister(bcsr_handle);
  387. }
  388. }
  389. int main(int argc, char **argv)
  390. {
  391. int ret, rank, size;
  392. int error=0;
  393. int mpi_init;
  394. MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
  395. ret = starpu_init(NULL);
  396. STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
  397. ret = starpu_mpi_init(NULL, NULL, mpi_init);
  398. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
  399. starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
  400. starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
  401. if (size < 2)
  402. {
  403. if (rank == 0)
  404. FPRINTF(stderr, "We need at least 2 processes.\n");
  405. starpu_mpi_shutdown();
  406. starpu_shutdown();
  407. if (!mpi_init)
  408. MPI_Finalize();
  409. return STARPU_TEST_SKIPPED;
  410. }
  411. exchange_void(rank, &error);
  412. exchange_variable(rank, &error);
  413. exchange_vector(rank, &error);
  414. exchange_matrix(rank, &error);
  415. exchange_block(rank, &error);
  416. exchange_bcsr(rank, &error);
  417. starpu_mpi_shutdown();
  418. starpu_shutdown();
  419. if (!mpi_init)
  420. MPI_Finalize();
  421. return rank == 0 ? error : 0;
  422. }