starpu_mpi_datatype.c 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2009, 2010 Université de Bordeaux 1
  4. * Copyright (C) 2010, 2011 Centre National de la Recherche Scientifique
  5. *
  6. * StarPU is free software; you can redistribute it and/or modify
  7. * it under the terms of the GNU Lesser General Public License as published by
  8. * the Free Software Foundation; either version 2.1 of the License, or (at
  9. * your option) any later version.
  10. *
  11. * StarPU is distributed in the hope that it will be useful, but
  12. * WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  14. *
  15. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  16. */
  17. #include <starpu_mpi_datatype.h>
  18. /*
  19. * MPI_* functions usually requires both a pointer to the first element of
  20. * a datatype and the datatype itself, so we need to provide both.
  21. */
  22. typedef int (*handle_to_datatype_func)(starpu_data_handle, MPI_Datatype *);
  23. typedef void *(*handle_to_ptr_func)(starpu_data_handle);
  24. /*
  25. * Matrix
  26. */
  27. static int handle_to_datatype_matrix(starpu_data_handle data_handle, MPI_Datatype *datatype)
  28. {
  29. int ret;
  30. unsigned nx = starpu_matrix_get_nx(data_handle);
  31. unsigned ny = starpu_matrix_get_ny(data_handle);
  32. unsigned ld = starpu_matrix_get_local_ld(data_handle);
  33. size_t elemsize = starpu_matrix_get_elemsize(data_handle);
  34. ret = MPI_Type_vector(ny, nx*elemsize, ld*elemsize, MPI_BYTE, datatype);
  35. STARPU_ASSERT(ret == MPI_SUCCESS);
  36. ret = MPI_Type_commit(datatype);
  37. STARPU_ASSERT(ret == MPI_SUCCESS);
  38. return 0;
  39. }
  40. static void *handle_to_ptr_matrix(starpu_data_handle data_handle)
  41. {
  42. return (void *)starpu_matrix_get_local_ptr(data_handle);
  43. }
  44. /*
  45. * Block
  46. */
  47. static int handle_to_datatype_block(starpu_data_handle data_handle, MPI_Datatype *datatype)
  48. {
  49. int ret;
  50. unsigned nx = starpu_block_get_nx(data_handle);
  51. unsigned ny = starpu_block_get_ny(data_handle);
  52. unsigned nz = starpu_block_get_nz(data_handle);
  53. unsigned ldy = starpu_block_get_local_ldy(data_handle);
  54. unsigned ldz = starpu_block_get_local_ldz(data_handle);
  55. size_t elemsize = starpu_block_get_elemsize(data_handle);
  56. MPI_Datatype datatype_2dlayer;
  57. ret = MPI_Type_vector(ny, nx*elemsize, ldy*elemsize, MPI_BYTE, &datatype_2dlayer);
  58. STARPU_ASSERT(ret == MPI_SUCCESS);
  59. ret = MPI_Type_commit(&datatype_2dlayer);
  60. STARPU_ASSERT(ret == MPI_SUCCESS);
  61. ret = MPI_Type_hvector(nz, 1, ldz*elemsize, datatype_2dlayer, datatype);
  62. STARPU_ASSERT(ret == MPI_SUCCESS);
  63. ret = MPI_Type_commit(datatype);
  64. STARPU_ASSERT(ret == MPI_SUCCESS);
  65. return 0;
  66. }
  67. static void *handle_to_ptr_block(starpu_data_handle data_handle)
  68. {
  69. return (void *)starpu_block_get_local_ptr(data_handle);
  70. }
  71. /*
  72. * Vector
  73. */
  74. static int handle_to_datatype_vector(starpu_data_handle data_handle, MPI_Datatype *datatype)
  75. {
  76. int ret;
  77. unsigned nx = starpu_vector_get_nx(data_handle);
  78. size_t elemsize = starpu_vector_get_elemsize(data_handle);
  79. ret = MPI_Type_contiguous(nx*elemsize, MPI_BYTE, datatype);
  80. STARPU_ASSERT(ret == MPI_SUCCESS);
  81. ret = MPI_Type_commit(datatype);
  82. STARPU_ASSERT(ret == MPI_SUCCESS);
  83. return 0;
  84. }
  85. static void *handle_to_ptr_vector(starpu_data_handle data_handle)
  86. {
  87. return (void *)starpu_vector_get_local_ptr(data_handle);
  88. }
  89. /*
  90. * Variable
  91. */
  92. static int handle_to_datatype_variable(starpu_data_handle data_handle, MPI_Datatype *datatype)
  93. {
  94. int ret;
  95. size_t elemsize = starpu_variable_get_elemsize(data_handle);
  96. ret = MPI_Type_contiguous(elemsize, MPI_BYTE, datatype);
  97. STARPU_ASSERT(ret == MPI_SUCCESS);
  98. ret = MPI_Type_commit(datatype);
  99. STARPU_ASSERT(ret == MPI_SUCCESS);
  100. return 0;
  101. }
  102. static void *handle_to_ptr_variable(starpu_data_handle data_handle)
  103. {
  104. return (void *)starpu_variable_get_local_ptr(data_handle);
  105. }
  106. /*
  107. * Generic
  108. */
  109. static handle_to_datatype_func handle_to_datatype_funcs[STARPU_NINTERFACES_ID] = {
  110. [STARPU_MATRIX_INTERFACE_ID] = handle_to_datatype_matrix,
  111. [STARPU_BLOCK_INTERFACE_ID] = handle_to_datatype_block,
  112. [STARPU_VECTOR_INTERFACE_ID] = handle_to_datatype_vector,
  113. [STARPU_CSR_INTERFACE_ID] = NULL,
  114. [STARPU_BCSR_INTERFACE_ID] = NULL,
  115. [STARPU_VARIABLE_INTERFACE_ID] = handle_to_datatype_variable,
  116. };
  117. static handle_to_ptr_func handle_to_ptr_funcs[STARPU_NINTERFACES_ID] = {
  118. [STARPU_MATRIX_INTERFACE_ID] = handle_to_ptr_matrix,
  119. [STARPU_BLOCK_INTERFACE_ID] = handle_to_ptr_block,
  120. [STARPU_VECTOR_INTERFACE_ID] = handle_to_ptr_vector,
  121. [STARPU_CSR_INTERFACE_ID] = NULL,
  122. [STARPU_BCSR_INTERFACE_ID] = NULL,
  123. [STARPU_VARIABLE_INTERFACE_ID] = handle_to_ptr_variable,
  124. };
  125. int starpu_mpi_handle_to_datatype(starpu_data_handle data_handle, MPI_Datatype *datatype)
  126. {
  127. unsigned id = starpu_get_handle_interface_id(data_handle);
  128. handle_to_datatype_func func = handle_to_datatype_funcs[id];
  129. STARPU_ASSERT(func);
  130. return func(data_handle, datatype);
  131. }
  132. void *starpu_mpi_handle_to_ptr(starpu_data_handle data_handle)
  133. {
  134. unsigned id = starpu_get_handle_interface_id(data_handle);
  135. handle_to_ptr_func func = handle_to_ptr_funcs[id];
  136. STARPU_ASSERT(func);
  137. return func(data_handle);
  138. }