datatypes.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2013, 2014, 2015, 2016 CNRS
  4. *
  5. * StarPU is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * StarPU is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu_mpi.h>
  17. #include <stdlib.h>
  18. #include "helper.h"
  19. typedef void (*check_func)(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error);
  20. void check_void(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  21. {
  22. FPRINTF_MPI(stderr, "Success with void value\n");
  23. }
  24. void check_variable(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  25. {
  26. int ret;
  27. float *v_s, *v_r;
  28. STARPU_ASSERT(starpu_variable_get_elemsize(handle_s) == starpu_variable_get_elemsize(handle_r));
  29. v_s = (float *)starpu_variable_get_local_ptr(handle_s);
  30. v_r = (float *)starpu_variable_get_local_ptr(handle_r);
  31. if (*v_s == *v_r)
  32. {
  33. FPRINTF_MPI(stderr, "Success with variable value: %f == %f\n", *v_s, *v_r);
  34. }
  35. else
  36. {
  37. *error = 1;
  38. FPRINTF_MPI(stderr, "Error with variable value: %f != %f\n", *v_s, *v_r);
  39. }
  40. }
  41. void check_vector(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  42. {
  43. int ret, i;
  44. int nx;
  45. int *v_r, *v_s;
  46. STARPU_ASSERT(starpu_vector_get_elemsize(handle_s) == starpu_vector_get_elemsize(handle_r));
  47. STARPU_ASSERT(starpu_vector_get_nx(handle_s) == starpu_vector_get_nx(handle_r));
  48. nx = starpu_vector_get_nx(handle_r);
  49. v_r = (int *)starpu_vector_get_local_ptr(handle_r);
  50. v_s = (int *)starpu_vector_get_local_ptr(handle_s);
  51. for(i=0 ; i<nx ; i++)
  52. {
  53. if (v_s[i] == v_r[i])
  54. {
  55. FPRINTF_MPI(stderr, "Success with vector[%d] value: %d == %d\n", i, v_s[i], v_r[i]);
  56. }
  57. else
  58. {
  59. *error = 1;
  60. FPRINTF_MPI(stderr, "Error with vector[%d] value: %d != %d\n", i, v_s[i], v_r[i]);
  61. }
  62. }
  63. }
  64. void check_matrix(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  65. {
  66. STARPU_ASSERT(starpu_matrix_get_elemsize(handle_s) == starpu_matrix_get_elemsize(handle_r));
  67. STARPU_ASSERT(starpu_matrix_get_nx(handle_s) == starpu_matrix_get_nx(handle_r));
  68. STARPU_ASSERT(starpu_matrix_get_ny(handle_s) == starpu_matrix_get_ny(handle_r));
  69. STARPU_ASSERT(starpu_matrix_get_local_ld(handle_s) == starpu_matrix_get_local_ld(handle_r));
  70. char *matrix_s = (char *)starpu_matrix_get_local_ptr(handle_s);
  71. char *matrix_r = (char *)starpu_matrix_get_local_ptr(handle_r);
  72. int nx = starpu_matrix_get_nx(handle_s);
  73. int ny = starpu_matrix_get_ny(handle_s);
  74. int ldy = starpu_matrix_get_local_ld(handle_s);
  75. int x, y;
  76. for(y=0 ; y<ny ; y++)
  77. {
  78. for(x=0 ; x<nx ; x++)
  79. {
  80. int index=(y*ldy)+x;
  81. if (matrix_s[index] == matrix_r[index])
  82. {
  83. FPRINTF_MPI(stderr, "Success with matrix[%d,%d --> %d] value: %c == %c\n", x, y, index, matrix_s[index], matrix_r[index]);
  84. }
  85. else
  86. {
  87. *error = 1;
  88. FPRINTF_MPI(stderr, "Error with matrix[%d,%d --> %d] value: %c != %c\n", x, y, index, matrix_s[index], matrix_r[index]);
  89. }
  90. }
  91. }
  92. }
  93. void check_block(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
  94. {
  95. STARPU_ASSERT(starpu_block_get_elemsize(handle_s) == starpu_block_get_elemsize(handle_r));
  96. STARPU_ASSERT(starpu_block_get_nx(handle_s) == starpu_block_get_nx(handle_r));
  97. STARPU_ASSERT(starpu_block_get_ny(handle_s) == starpu_block_get_ny(handle_r));
  98. STARPU_ASSERT(starpu_block_get_nz(handle_s) == starpu_block_get_nz(handle_r));
  99. STARPU_ASSERT(starpu_block_get_local_ldy(handle_s) == starpu_block_get_local_ldy(handle_r));
  100. STARPU_ASSERT(starpu_block_get_local_ldz(handle_s) == starpu_block_get_local_ldz(handle_r));
  101. starpu_data_acquire(handle_s, STARPU_R);
  102. starpu_data_acquire(handle_r, STARPU_R);
  103. float *block_s = (float *)starpu_block_get_local_ptr(handle_s);
  104. float *block_r = (float *)starpu_block_get_local_ptr(handle_r);
  105. int nx = starpu_block_get_nx(handle_s);
  106. int ny = starpu_block_get_ny(handle_s);
  107. int nz = starpu_block_get_nz(handle_s);
  108. int ldy = starpu_block_get_local_ldy(handle_s);
  109. int ldz = starpu_block_get_local_ldz(handle_s);
  110. int x, y, z;
  111. for(z=0 ; z<nz ; z++)
  112. {
  113. for(y=0 ; y<ny ; y++)
  114. for(x=0 ; x<nx ; x++)
  115. {
  116. int index=(z*ldz)+(y*ldy)+x;
  117. if (block_s[index] == block_r[index])
  118. {
  119. FPRINTF_MPI(stderr, "Success with block[%d,%d,%d --> %d] value: %f == %f\n", x, y, z, index, block_s[index], block_r[index]);
  120. }
  121. else
  122. {
  123. *error = 1;
  124. FPRINTF_MPI(stderr, "Error with block[%d,%d,%d --> %d] value: %f != %f\n", x, y, z, index, block_s[index], block_r[index]);
  125. }
  126. }
  127. }
  128. starpu_data_release(handle_s);
  129. starpu_data_release(handle_r);
  130. }
  131. void send_recv_and_check(int rank, int node, starpu_data_handle_t handle_s, int tag_s, starpu_data_handle_t handle_r, int tag_r, int *error, check_func func)
  132. {
  133. int ret;
  134. MPI_Status status;
  135. if (rank == 0)
  136. {
  137. ret = starpu_mpi_send(handle_s, node, tag_s, MPI_COMM_WORLD);
  138. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_send");
  139. ret = starpu_mpi_recv(handle_r, node, tag_r, MPI_COMM_WORLD, &status);
  140. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_recv");
  141. func(handle_s, handle_r, error);
  142. }
  143. else if (rank == 1)
  144. {
  145. ret = starpu_mpi_recv(handle_s, node, tag_s, MPI_COMM_WORLD, &status);
  146. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_recv");
  147. ret = starpu_mpi_send(handle_s, node, tag_r, MPI_COMM_WORLD);
  148. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_send");
  149. }
  150. }
  151. void exchange_void(int rank, int *error)
  152. {
  153. STARPU_SKIP_IF_VALGRIND;
  154. if (rank == 0)
  155. {
  156. starpu_data_handle_t void_handle[2];
  157. starpu_void_data_register(&void_handle[0]);
  158. starpu_void_data_register(&void_handle[1]);
  159. send_recv_and_check(rank, 1, void_handle[0], 0x42, void_handle[1], 0x1337, error, check_void);
  160. starpu_data_unregister(void_handle[0]);
  161. starpu_data_unregister(void_handle[1]);
  162. }
  163. else if (rank == 1)
  164. {
  165. starpu_data_handle_t void_handle;
  166. starpu_void_data_register(&void_handle);
  167. send_recv_and_check(rank, 0, void_handle, 0x42, NULL, 0x1337, NULL, NULL);
  168. starpu_data_unregister(void_handle);
  169. }
  170. }
  171. void exchange_variable(int rank, int *error)
  172. {
  173. if (rank == 0)
  174. {
  175. float v = 42.12;
  176. starpu_data_handle_t variable_handle[2];
  177. starpu_variable_data_register(&variable_handle[0], STARPU_MAIN_RAM, (uintptr_t)&v, sizeof(v));
  178. starpu_variable_data_register(&variable_handle[1], -1, (uintptr_t)NULL, sizeof(v));
  179. send_recv_and_check(rank, 1, variable_handle[0], 0x42, variable_handle[1], 0x1337, error, check_variable);
  180. starpu_data_unregister(variable_handle[0]);
  181. starpu_data_unregister(variable_handle[1]);
  182. }
  183. else if (rank == 1)
  184. {
  185. starpu_data_handle_t variable_handle;
  186. starpu_variable_data_register(&variable_handle, -1, (uintptr_t)NULL, sizeof(float));
  187. send_recv_and_check(rank, 0, variable_handle, 0x42, NULL, 0x1337, NULL, NULL);
  188. starpu_data_unregister(variable_handle);
  189. }
  190. }
  191. void exchange_vector(int rank, int *error)
  192. {
  193. if (rank == 0)
  194. {
  195. int vector[4] = {1, 2, 3, 4};
  196. starpu_data_handle_t vector_handle[2];
  197. starpu_vector_data_register(&vector_handle[0], STARPU_MAIN_RAM, (uintptr_t)vector, 4, sizeof(vector[0]));
  198. starpu_vector_data_register(&vector_handle[1], -1, (uintptr_t)NULL, 4, sizeof(vector[0]));
  199. send_recv_and_check(rank, 1, vector_handle[0], 0x43, vector_handle[1], 0x2337, error, check_vector);
  200. starpu_data_unregister(vector_handle[0]);
  201. starpu_data_unregister(vector_handle[1]);
  202. }
  203. else if (rank == 1)
  204. {
  205. starpu_data_handle_t vector_handle;
  206. starpu_vector_data_register(&vector_handle, -1, (uintptr_t)NULL, 4, sizeof(int));
  207. send_recv_and_check(rank, 0, vector_handle, 0x43, NULL, 0x2337, NULL, NULL);
  208. starpu_data_unregister(vector_handle);
  209. }
  210. }
  211. void exchange_matrix(int rank, int *error)
  212. {
  213. int nx=3;
  214. int ny=2;
  215. if (rank == 0)
  216. {
  217. char *matrix, n='a';
  218. int x, y;
  219. starpu_data_handle_t matrix_handle[2];
  220. matrix = (char*)malloc(nx*ny*sizeof(char));
  221. assert(matrix);
  222. for(y=0 ; y<ny ; y++)
  223. {
  224. for(x=0 ; x<nx ; x++)
  225. {
  226. matrix[(y*nx)+x] = n++;
  227. }
  228. }
  229. starpu_matrix_data_register(&matrix_handle[0], STARPU_MAIN_RAM, (uintptr_t)matrix, nx, nx, ny, sizeof(char));
  230. starpu_matrix_data_register(&matrix_handle[1], -1, (uintptr_t)NULL, nx, nx, ny, sizeof(char));
  231. send_recv_and_check(rank, 1, matrix_handle[0], 0x75, matrix_handle[1], 0x8555, error, check_matrix);
  232. starpu_data_unregister(matrix_handle[0]);
  233. starpu_data_unregister(matrix_handle[1]);
  234. free(matrix);
  235. }
  236. else if (rank == 1)
  237. {
  238. starpu_data_handle_t matrix_handle;
  239. starpu_matrix_data_register(&matrix_handle, -1, (uintptr_t)NULL, nx, nx, ny, sizeof(char));
  240. send_recv_and_check(rank, 0, matrix_handle, 0x75, NULL, 0x8555, NULL, NULL);
  241. starpu_data_unregister(matrix_handle);
  242. }
  243. }
  244. void exchange_block(int rank, int *error)
  245. {
  246. int nx=3;
  247. int ny=2;
  248. int nz=4;
  249. if (rank == 0)
  250. {
  251. float *block, n=1.0;
  252. int x, y, z;
  253. starpu_data_handle_t block_handle[2];
  254. block = (float*)malloc(nx*ny*nz*sizeof(float));
  255. assert(block);
  256. for(z=0 ; z<nz ; z++)
  257. {
  258. for(y=0 ; y<ny ; y++)
  259. {
  260. for(x=0 ; x<nx ; x++)
  261. {
  262. block[(z*nx*ny)+(y*nx)+x] = n++;
  263. }
  264. }
  265. }
  266. starpu_block_data_register(&block_handle[0], STARPU_MAIN_RAM, (uintptr_t)block, nx, nx*ny, nx, ny, nz, sizeof(float));
  267. starpu_block_data_register(&block_handle[1], -1, (uintptr_t)NULL, nx, nx*ny, nx, ny, nz, sizeof(float));
  268. send_recv_and_check(rank, 1, block_handle[0], 0x73, block_handle[1], 0x8337, error, check_block);
  269. starpu_data_unregister(block_handle[0]);
  270. starpu_data_unregister(block_handle[1]);
  271. free(block);
  272. }
  273. else if (rank == 1)
  274. {
  275. starpu_data_handle_t block_handle;
  276. starpu_block_data_register(&block_handle, -1, (uintptr_t)NULL, nx, nx*ny, nx, ny, nz, sizeof(float));
  277. send_recv_and_check(rank, 0, block_handle, 0x73, NULL, 0x8337, NULL, NULL);
  278. starpu_data_unregister(block_handle);
  279. }
  280. }
  281. int main(int argc, char **argv)
  282. {
  283. int ret, rank, size;
  284. int error=0;
  285. MPI_Init(&argc, &argv);
  286. starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
  287. starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
  288. ret = starpu_init(NULL);
  289. STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
  290. ret = starpu_mpi_init(NULL, NULL, 0);
  291. STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
  292. if (size < 2)
  293. {
  294. if (rank == 0)
  295. FPRINTF(stderr, "We need at least 2 processes.\n");
  296. starpu_mpi_shutdown();
  297. starpu_shutdown();
  298. MPI_Finalize();
  299. return STARPU_TEST_SKIPPED;
  300. }
  301. exchange_void(rank, &error);
  302. exchange_variable(rank, &error);
  303. exchange_vector(rank, &error);
  304. exchange_matrix(rank, &error);
  305. exchange_block(rank, &error);
  306. starpu_mpi_shutdown();
  307. starpu_shutdown();
  308. MPI_Finalize();
  309. return rank == 0 ? error : 0;
  310. }