starpu_mpi.c 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu_mpi.h>
  17. #include <starpu_mpi_datatype.h>
  18. static void submit_mpi_req(struct starpu_mpi_req_s *req);
  19. void handle_request_termination(struct starpu_mpi_req_s *req);
  20. static starpu_mpi_req_list_t new_requests;
  21. static starpu_mpi_req_list_t pending_requests;
  22. static pthread_cond_t cond;
  23. static pthread_mutex_t mutex;
  24. static pthread_t progress_thread;
  25. static int running = 0;
  26. static void _handle_new_mpi_isend(struct starpu_mpi_req_s *req)
  27. {
  28. void *ptr = starpu_mpi_handle_to_ptr(req->data_handle);
  29. starpu_mpi_handle_to_datatype(req->data_handle, &req->datatype);
  30. MPI_Isend(ptr, 1, req->datatype, req->dst, req->mpi_tag, req->comm, &req->request);
  31. }
  32. int starpu_mpi_isend(starpu_data_handle data_handle, struct starpu_mpi_req_s *req,
  33. int dest, int mpi_tag, MPI_Comm comm,
  34. void (*callback)(void *))
  35. {
  36. req->submitted = 0;
  37. pthread_mutex_init(&req->req_mutex, NULL);
  38. pthread_cond_init(&req->req_cond, NULL);
  39. req->data_handle = data_handle;
  40. req->dst = dest;
  41. req->mpi_tag = mpi_tag;
  42. req->comm = comm;
  43. req->mode = STARPU_R;
  44. req->handle_new = _handle_new_mpi_isend;
  45. submit_mpi_req(req);
  46. return 0;
  47. }
  48. static void _handle_new_mpi_irecv(struct starpu_mpi_req_s *req)
  49. {
  50. void *ptr = starpu_mpi_handle_to_ptr(req->data_handle);
  51. starpu_mpi_handle_to_datatype(req->data_handle, &req->datatype);
  52. MPI_Irecv(ptr, 1, req->datatype, req->src, req->mpi_tag, req->comm, &req->request);
  53. }
  54. int starpu_mpi_irecv(starpu_data_handle data_handle, struct starpu_mpi_req_s *req,
  55. int source, int mpi_tag, MPI_Comm comm,
  56. void (*callback)(void *))
  57. {
  58. req->submitted = 0;
  59. pthread_mutex_init(&req->req_mutex, NULL);
  60. pthread_cond_init(&req->req_cond, NULL);
  61. req->data_handle = data_handle;
  62. req->mode = STARPU_W;
  63. req->src = source;
  64. req->mpi_tag = mpi_tag;
  65. req->comm = comm;
  66. req->handle_new = _handle_new_mpi_irecv;
  67. submit_mpi_req(req);
  68. return 0;
  69. }
  70. int starpu_mpi_recv(starpu_data_handle data_handle,
  71. int source, int mpi_tag, MPI_Comm comm, MPI_Status *status)
  72. {
  73. /* test if we are blocking in a callback .. */
  74. int ret = starpu_sync_data_with_mem(data_handle, STARPU_W);
  75. if (ret)
  76. return ret;
  77. void *ptr = starpu_mpi_handle_to_ptr(data_handle);
  78. MPI_Datatype datatype;
  79. starpu_mpi_handle_to_datatype(data_handle, &datatype);
  80. MPI_Recv(ptr, 1, datatype, source, mpi_tag, comm, status);
  81. starpu_release_data_from_mem(data_handle);
  82. return 0;
  83. }
  84. int starpu_mpi_send(starpu_data_handle data_handle,
  85. int dest, int mpi_tag, MPI_Comm comm)
  86. {
  87. /* test if we are blocking in a callback .. */
  88. int ret = starpu_sync_data_with_mem(data_handle, STARPU_R);
  89. if (ret)
  90. return ret;
  91. void *ptr = starpu_mpi_handle_to_ptr(data_handle);
  92. MPI_Status status;
  93. MPI_Datatype datatype;
  94. starpu_mpi_handle_to_datatype(data_handle, &datatype);
  95. MPI_Send(ptr, 1, datatype, dest, mpi_tag, comm);
  96. starpu_release_data_from_mem(data_handle);
  97. return 0;
  98. }
  99. static void _handle_new_mpi_wait(struct starpu_mpi_req_s *req)
  100. {
  101. req->ret = MPI_Wait(&req->request, req->status);
  102. handle_request_termination(req);
  103. }
  104. int starpu_mpi_wait(struct starpu_mpi_req_s *req, MPI_Status *status)
  105. {
  106. int ret;
  107. pthread_mutex_lock(&req->req_mutex);
  108. req->status = status;
  109. /* we don't submit a wait request until the previous mpi request was
  110. * actually submitted */
  111. while (!req->submitted)
  112. pthread_cond_wait(&req->req_cond, &req->req_mutex);
  113. req->submitted = 0;
  114. req->handle_new = _handle_new_mpi_wait;
  115. req->status = status;
  116. submit_mpi_req(req);
  117. while (!req->submitted)
  118. pthread_cond_wait(&req->req_cond, &req->req_mutex);
  119. ret = req->ret;
  120. pthread_mutex_unlock(&req->req_mutex);
  121. return ret;
  122. }
  123. int starpu_mpi_test(struct starpu_mpi_req_s *req, int *flag, MPI_Status *status)
  124. {
  125. int ret = 0;
  126. pthread_mutex_lock(&req->req_mutex);
  127. if (req->submitted)
  128. {
  129. ret = MPI_Test(&req->request, flag, status);
  130. if (*flag)
  131. handle_request_termination(req);
  132. }
  133. else {
  134. *flag = 0;
  135. }
  136. pthread_mutex_unlock(&req->req_mutex);
  137. return ret;
  138. }
  139. /*
  140. * Requests
  141. */
  142. void handle_request_termination(struct starpu_mpi_req_s *req)
  143. {
  144. MPI_Type_free(&req->datatype);
  145. starpu_release_data_from_mem(req->data_handle);
  146. }
  147. void handle_request(struct starpu_mpi_req_s *req)
  148. {
  149. STARPU_ASSERT(req);
  150. pthread_mutex_lock(&req->req_mutex);
  151. starpu_sync_data_with_mem(req->data_handle, req->mode);
  152. /* submit the request to MPI */
  153. req->handle_new(req);
  154. /* perhaps somebody is waiting or trying to test */
  155. req->submitted = 1;
  156. pthread_cond_broadcast(&req->req_cond);
  157. pthread_mutex_unlock(&req->req_mutex);
  158. }
  159. static void submit_mpi_req(struct starpu_mpi_req_s *req)
  160. {
  161. pthread_mutex_lock(&mutex);
  162. pthread_mutex_lock(&req->req_mutex);
  163. starpu_mpi_req_list_push_front(new_requests, req);
  164. pthread_cond_broadcast(&cond);
  165. pthread_cond_broadcast(&req->req_cond);
  166. pthread_mutex_unlock(&req->req_mutex);
  167. pthread_mutex_unlock(&mutex);
  168. }
  169. /*
  170. * Progression loop
  171. */
  172. void *progress_thread_func(void *arg __attribute__((unused)))
  173. {
  174. /* notify the main thread that the progression thread is ready */
  175. pthread_mutex_lock(&mutex);
  176. running = 1;
  177. pthread_cond_signal(&cond);
  178. pthread_mutex_unlock(&mutex);
  179. pthread_mutex_lock(&mutex);
  180. while (running) {
  181. pthread_cond_wait(&cond, &mutex);
  182. if (!running)
  183. break;
  184. while (!starpu_mpi_req_list_empty(new_requests))
  185. {
  186. /* get one request */
  187. struct starpu_mpi_req_s *req;
  188. req = starpu_mpi_req_list_pop_back(new_requests);
  189. /* handling a request is likely to block for a while
  190. * (on a sync_data_with_mem call), we want to let the
  191. * application submit requests in the meantime, so we
  192. * release the lock. */
  193. pthread_mutex_unlock(&mutex);
  194. handle_request(req);
  195. pthread_mutex_lock(&mutex);
  196. }
  197. pthread_mutex_unlock(&mutex);
  198. }
  199. pthread_mutex_unlock(&mutex);
  200. return NULL;
  201. }
  202. /*
  203. * (De)Initialization methods
  204. */
  205. int starpu_mpi_initialize(void)
  206. {
  207. pthread_mutex_init(&mutex, NULL);
  208. pthread_cond_init(&cond, NULL);
  209. /* requests that have not be submitted to MPI yet */
  210. new_requests = starpu_mpi_req_list_new();
  211. /* requests that are already submitted and which are not completed yet */
  212. pending_requests = starpu_mpi_req_list_new();
  213. int ret = pthread_create(&progress_thread, NULL, progress_thread_func, NULL);
  214. pthread_mutex_lock(&mutex);
  215. if (!running)
  216. pthread_cond_wait(&cond, &mutex);
  217. pthread_mutex_unlock(&mutex);
  218. return 0;
  219. }
  220. int starpu_mpi_shutdown(void)
  221. {
  222. /* kill the progression thread */
  223. pthread_mutex_lock(&mutex);
  224. running = 0;
  225. pthread_cond_signal(&cond);
  226. pthread_mutex_unlock(&mutex);
  227. void *value;
  228. pthread_join(progress_thread, &value);
  229. return 0;
  230. }