mpi_like_async.c 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. /*
  2. * StarPU
  3. * Copyright (C) Université Bordeaux 1, CNRS 2008-2010 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu.h>
  17. #include <pthread.h>
  18. #define NTHREADS 16
  19. #define NITER 128
  20. //#define DEBUG_MESSAGES 1
  21. static pthread_cond_t cond;
  22. static pthread_mutex_t mutex;
  23. struct thread_data {
  24. unsigned index;
  25. unsigned val;
  26. starpu_data_handle handle;
  27. pthread_t thread;
  28. pthread_mutex_t recv_mutex;
  29. unsigned recv_flag; // set when a message is received
  30. unsigned recv_buf;
  31. struct thread_data *neighbour;
  32. };
/* A pending asynchronous request, kept in a singly linked list and polled by
 * the progress thread until its test function reports completion. */
struct data_req {
	/* Polling function: a non-zero return means the request is finished. */
	int (*test_func)(void *);
	/* Argument handed to test_func. */
	void *test_arg;
	/* Next request in the list, or NULL. */
	struct data_req *next;
};
  38. static pthread_mutex_t data_req_mutex;
  39. static pthread_cond_t data_req_cond;
  40. struct data_req *data_req_list;
  41. unsigned progress_thread_running;
  42. static struct thread_data problem_data[NTHREADS];
/* We implement a ring transfer: every thread tries to receive a piece of data
 * from its predecessor, increments it, and then transmits it to its
 * successor. */
  46. #ifdef STARPU_USE_CUDA
  47. void cuda_codelet_unsigned_inc(void *descr[], __attribute__ ((unused)) void *cl_arg);
  48. #endif
  49. static void increment_handle_cpu_kernel(void *descr[], void *cl_arg __attribute__((unused)))
  50. {
  51. unsigned *val = (unsigned *)STARPU_VARIABLE_GET_PTR(descr[0]);
  52. *val += 1;
  53. // fprintf(stderr, "VAL %d (&val = %p)\n", *val, val);
  54. }
  55. static starpu_codelet increment_handle_cl = {
  56. .where = STARPU_CPU|STARPU_CUDA,
  57. .cpu_func = increment_handle_cpu_kernel,
  58. #ifdef STARPU_USE_CUDA
  59. .cuda_func = cuda_codelet_unsigned_inc,
  60. #endif
  61. .nbuffers = 1
  62. };
  63. static void increment_handle_async(struct thread_data *thread_data)
  64. {
  65. struct starpu_task *task = starpu_task_create();
  66. task->cl = &increment_handle_cl;
  67. task->buffers[0].handle = thread_data->handle;
  68. task->buffers[0].mode = STARPU_RW;
  69. task->detach = 1;
  70. task->destroy = 1;
  71. int ret = starpu_task_submit(task);
  72. STARPU_ASSERT(!ret);
  73. }
  74. static int test_recv_handle_async(void *arg)
  75. {
  76. // fprintf(stderr, "test_recv_handle_async\n");
  77. int ret;
  78. struct thread_data *thread_data = arg;
  79. pthread_mutex_lock(&thread_data->recv_mutex);
  80. ret = (thread_data->recv_flag == 1);
  81. if (ret)
  82. {
  83. thread_data->recv_flag = 0;
  84. thread_data->val = thread_data->recv_buf;
  85. }
  86. pthread_mutex_unlock(&thread_data->recv_mutex);
  87. if (ret)
  88. {
  89. #ifdef DEBUG_MESSAGES
  90. fprintf(stderr, "Thread %d received value %d from thread %d\n",
  91. thread_data->index, thread_data->val, (thread_data->index - 1)%NTHREADS);
  92. #endif
  93. starpu_data_release(thread_data->handle);
  94. }
  95. return ret;
  96. }
  97. static void recv_handle_async(void *_thread_data)
  98. {
  99. struct thread_data *thread_data = _thread_data;
  100. struct data_req *req = malloc(sizeof(struct data_req));
  101. req->test_func = test_recv_handle_async;
  102. req->test_arg = thread_data;
  103. req->next = NULL;
  104. pthread_mutex_lock(&data_req_mutex);
  105. req->next = data_req_list;
  106. data_req_list = req;
  107. pthread_cond_signal(&data_req_cond);
  108. pthread_mutex_unlock(&data_req_mutex);
  109. }
  110. static int test_send_handle_async(void *arg)
  111. {
  112. int ret;
  113. struct thread_data *thread_data = arg;
  114. struct thread_data *neighbour_data = thread_data->neighbour;
  115. pthread_mutex_lock(&neighbour_data->recv_mutex);
  116. ret = (neighbour_data->recv_flag == 0);
  117. pthread_mutex_unlock(&neighbour_data->recv_mutex);
  118. if (ret)
  119. {
  120. #ifdef DEBUG_MESSAGES
  121. fprintf(stderr, "Thread %d sends value %d to thread %d\n", thread_data->index, thread_data->val, neighbour_data->index);
  122. #endif
  123. starpu_data_release(thread_data->handle);
  124. }
  125. return ret;
  126. }
  127. static void send_handle_async(void *_thread_data)
  128. {
  129. struct thread_data *thread_data = _thread_data;
  130. struct thread_data *neighbour_data = thread_data->neighbour;
  131. // fprintf(stderr, "send_handle_async\n");
  132. /* send the message */
  133. pthread_mutex_lock(&neighbour_data->recv_mutex);
  134. neighbour_data->recv_buf = thread_data->val;
  135. neighbour_data->recv_flag = 1;
  136. pthread_mutex_unlock(&neighbour_data->recv_mutex);
  137. struct data_req *req = malloc(sizeof(struct data_req));
  138. req->test_func = test_send_handle_async;
  139. req->test_arg = thread_data;
  140. req->next = NULL;
  141. pthread_mutex_lock(&data_req_mutex);
  142. req->next = data_req_list;
  143. data_req_list = req;
  144. pthread_cond_signal(&data_req_cond);
  145. pthread_mutex_unlock(&data_req_mutex);
  146. }
  147. static void *progress_func(void *arg)
  148. {
  149. pthread_mutex_lock(&data_req_mutex);
  150. progress_thread_running = 1;
  151. pthread_cond_signal(&data_req_cond);
  152. while (progress_thread_running) {
  153. struct data_req *req;
  154. if (data_req_list == NULL)
  155. pthread_cond_wait(&data_req_cond, &data_req_mutex);
  156. req = data_req_list;
  157. if (req)
  158. {
  159. data_req_list = req->next;
  160. req->next = NULL;
  161. pthread_mutex_unlock(&data_req_mutex);
  162. int ret = req->test_func(req->test_arg);
  163. if (ret)
  164. {
  165. free(req);
  166. pthread_mutex_lock(&data_req_mutex);
  167. }
  168. else {
  169. /* ret = 0 : the request is not finished, we put it back at the end of the list */
  170. pthread_mutex_lock(&data_req_mutex);
  171. struct data_req *req_aux = data_req_list;
  172. if (!req_aux)
  173. {
  174. /* The list is empty */
  175. data_req_list = req;
  176. }
  177. else {
  178. while (req_aux)
  179. {
  180. if (req_aux->next == NULL)
  181. {
  182. req_aux->next = req;
  183. break;
  184. }
  185. req_aux = req_aux->next;
  186. }
  187. }
  188. }
  189. }
  190. }
  191. pthread_mutex_unlock(&data_req_mutex);
  192. return NULL;
  193. }
  194. static void *thread_func(void *arg)
  195. {
  196. unsigned iter;
  197. struct thread_data *thread_data = arg;
  198. unsigned index = thread_data->index;
  199. starpu_variable_data_register(&thread_data->handle, 0, (uintptr_t)&thread_data->val, sizeof(unsigned));
  200. for (iter = 0; iter < NITER; iter++)
  201. {
  202. /* The first thread initiates the first transfer */
  203. if (!((index == 0) && (iter == 0)))
  204. {
  205. starpu_data_acquire_cb(
  206. thread_data->handle, STARPU_W,
  207. recv_handle_async, thread_data
  208. );
  209. }
  210. increment_handle_async(thread_data);
  211. if (!((index == (NTHREADS - 1)) && (iter == (NITER - 1))))
  212. {
  213. starpu_data_acquire_cb(
  214. thread_data->handle, STARPU_R,
  215. send_handle_async, thread_data
  216. );
  217. }
  218. }
  219. starpu_task_wait_for_all();
  220. return NULL;
  221. }
  222. int main(int argc, char **argv)
  223. {
  224. int ret;
  225. void *retval;
  226. starpu_init(NULL);
  227. /* Create a thread to perform blocking calls */
  228. pthread_t progress_thread;
  229. pthread_mutex_init(&data_req_mutex, NULL);
  230. pthread_cond_init(&data_req_cond, NULL);
  231. data_req_list = NULL;
  232. progress_thread_running = 0;
  233. unsigned t;
  234. for (t = 0; t < NTHREADS; t++)
  235. {
  236. problem_data[t].index = t;
  237. problem_data[t].val = 0;
  238. pthread_mutex_init(&problem_data[t].recv_mutex, NULL);
  239. problem_data[t].recv_flag = 0;
  240. problem_data[t].neighbour = &problem_data[(t+1)%NTHREADS];
  241. }
  242. pthread_create(&progress_thread, NULL, progress_func, NULL);
  243. pthread_mutex_lock(&data_req_mutex);
  244. while (!progress_thread_running)
  245. pthread_cond_wait(&data_req_cond, &data_req_mutex);
  246. pthread_mutex_unlock(&data_req_mutex);
  247. for (t = 0; t < NTHREADS; t++)
  248. {
  249. ret = pthread_create(&problem_data[t].thread, NULL, thread_func, &problem_data[t]);
  250. STARPU_ASSERT(!ret);
  251. }
  252. for (t = 0; t < NTHREADS; t++)
  253. {
  254. ret = pthread_join(problem_data[t].thread, &retval);
  255. STARPU_ASSERT(retval == NULL);
  256. }
  257. pthread_mutex_lock(&data_req_mutex);
  258. progress_thread_running = 0;
  259. pthread_cond_signal(&data_req_cond);
  260. pthread_mutex_unlock(&data_req_mutex);
  261. ret = pthread_join(progress_thread, &retval);
  262. STARPU_ASSERT(retval == NULL);
  263. /* We check that the value in the "last" thread is valid */
  264. starpu_data_handle last_handle = problem_data[NTHREADS - 1].handle;
  265. starpu_data_acquire(last_handle, STARPU_R);
  266. if (problem_data[NTHREADS - 1].val != (NTHREADS * NITER))
  267. {
  268. fprintf(stderr, "Final value : %d should be %d\n", problem_data[NTHREADS - 1].val, (NTHREADS * NITER));
  269. STARPU_ABORT();
  270. }
  271. starpu_data_release(last_handle);
  272. starpu_shutdown();
  273. return 0;
  274. }