load_heat_propagation.c

/* StarPU --- Runtime system for heterogeneous multicore architectures.
 *
 * Copyright (C) 2016-2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
 *
 * StarPU is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2.1 of the License, or (at
 * your option) any later version.
 *
 * StarPU is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 *
 * See the GNU Lesser General Public License in COPYING.LGPL for more details.
 */
#include <starpu_mpi.h>
#include <mpi/starpu_mpi_tag.h>
#include <common/uthash.h>
#include <common/utils.h>
#include <math.h>
#include <starpu_mpi_private.h>
#include "load_balancer_policy.h"
#include "data_movements_interface.h"
#include "load_data_interface.h"
#include <common/config.h>

#if defined(STARPU_USE_MPI_MPI)
static starpu_mpi_tag_t TAG_LOAD(int n)
{
	return ((starpu_mpi_tag_t) n+1) << 24;
}

static starpu_mpi_tag_t TAG_MOV(int n)
{
	return ((starpu_mpi_tag_t) n+1) << 20;
}
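/* Illustrative tag arithmetic (explanatory note, not part of the original
 * code): on rank 3, TAG_LOAD(3) = 4 << 24 = 0x4000000 and
 * TAG_MOV(3) = 4 << 20 = 0x400000, so each rank gets its own slot in two
 * separate tag ranges for small world sizes. Note that the two ranges meet
 * for large ranks: TAG_MOV(15) = 16 << 20 equals TAG_LOAD(0) = 1 << 24. */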
/* Hash table of local pieces of data that have been moved out of the local MPI
 * node by the load balancer. All of these pieces of data must be migrated back
 * to the local node at the end of the execution. */
struct moved_data_entry
{
	UT_hash_handle hh;
	starpu_data_handle_t handle;
};
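/* Sketch of how this table is manipulated with uthash (it mirrors the calls
 * used further below; entries are keyed by the handle pointer itself):
 *
 *	struct moved_data_entry *md;
 *	_STARPU_MPI_MALLOC(md, sizeof(struct moved_data_entry));
 *	md->handle = handle;
 *	HASH_ADD_PTR(mdh, handle, md);   // index md by its handle field
 *	...
 *	HASH_FIND_PTR(mdh, &handle, md); // lookup takes the key's address
 */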
static struct moved_data_entry *mdh = NULL;

static starpu_pthread_mutex_t load_data_mutex;
static starpu_pthread_cond_t load_data_cond;

/* MPI infos */
static int my_rank;
static int world_size;

/* Number of neighbours of the local MPI node and their IDs. These are given by
 * the get_neighbors() method, and can thus easily be changed. */
static int *neighbor_ids = NULL;
static int nneighbors = 0;

/* Local load data */
static starpu_data_handle_t *load_data_handle = NULL;
static starpu_data_handle_t *load_data_handle_cpy = NULL;

/* Load data of neighbours */
static starpu_data_handle_t *neighbor_load_data_handles = NULL;

/* Table which contains a data_movements_handle for each MPI node of
 * MPI_COMM_WORLD. Since all the MPI nodes must be informed of any data
 * movement, this table is used to communicate the data movements handles
 * following an all-to-all model. */
static starpu_data_handle_t *data_movements_handles = NULL;

/* Load balancer interface which contains the application-specific methods for
 * the load balancer to use. */
static struct starpu_mpi_lb_conf *user_itf = NULL;

static double time_threshold = 20000;
/******************************************************************************
 *                                 Balancing                                  *
 *****************************************************************************/
/* Decides which data have to move where, and fills the
 * data_movements_handles[my_rank] data handle from that.
 * In data:
 * - local load_data_handle
 * - nneighbors
 * - neighbor_ids[nneighbors]
 * - neighbor_load_data_handles[nneighbors]
 * Out data:
 * - data_movements_handles[my_rank]
 */
static void balance(starpu_data_handle_t load_data_cpy)
{
	int less_loaded = -1;
	int n;
	double ref_elapsed_time;
	double my_elapsed_time = load_data_get_elapsed_time(load_data_cpy);

	/* Search for the least loaded neighbour */
	ref_elapsed_time = my_elapsed_time;
	for (n = 0; n < nneighbors; n++)
	{
		double elapsed_time = load_data_get_elapsed_time(neighbor_load_data_handles[n]);
		if (ref_elapsed_time > elapsed_time)
		{
			//fprintf(stderr,"Node%d: ref local time %lf vs neighbour%d time %lf\n", my_rank, ref_elapsed_time, neighbor_ids[n], elapsed_time);
			less_loaded = neighbor_ids[n];
			ref_elapsed_time = elapsed_time;
		}
	}

	/* We found one */
	if (less_loaded >= 0)
	{
		_STARPU_DEBUG("Less loaded found on node %d : %d\n", my_rank, less_loaded);
		double diff_time = my_elapsed_time - ref_elapsed_time;

		/* If the difference is higher than a time threshold, we move
		 * one data unit to the less loaded neighbour. */
		/* TODO: How to decide the time threshold? */
		if ((time_threshold > 0) && (diff_time >= time_threshold))
		{
			starpu_data_handle_t *handles = NULL;
			int nhandles = 0;
			user_itf->get_data_unit_to_migrate(&handles, &nhandles, less_loaded);

			data_movements_reallocate_tables(data_movements_handles[my_rank], nhandles);

			if (nhandles)
			{
				starpu_mpi_tag_t *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
				int *ranks = data_movements_get_ranks_table(data_movements_handles[my_rank]);

				for (n = 0; n < nhandles; n++)
				{
					tags[n] = starpu_mpi_data_get_tag(handles[n]);
					ranks[n] = less_loaded;
				}

				free(handles);
			}
		}
		else
			data_movements_reallocate_tables(data_movements_handles[my_rank], 0);
	}
	else
		data_movements_reallocate_tables(data_movements_handles[my_rank], 0);
}
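/* Worked example of the decision above (illustrative values): if the local
 * node measured an elapsed time of 50000 and its least loaded neighbour
 * reports 25000, then diff_time = 25000; if that exceeds time_threshold,
 * get_data_unit_to_migrate() is asked for data units to hand over to that
 * neighbour, otherwise the movements table is left empty for this phase. */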
static void exchange_load_data_infos(starpu_data_handle_t load_data_cpy)
{
	int i;

	/* Allocate all requests and statuses for point-to-point communications */
	starpu_mpi_req load_send_req[nneighbors];
	starpu_mpi_req load_recv_req[nneighbors];
	MPI_Status load_send_status[nneighbors];
	MPI_Status load_recv_status[nneighbors];
	int flag;

	/* Send the local load data to neighbour nodes, and receive the remote
	 * load data from neighbour nodes */
	for (i = 0; i < nneighbors; i++)
	{
		//_STARPU_DEBUG("[node %d] sending and receiving with %i-th neighbor %i\n", my_rank, i, neighbor_ids[i]);
		starpu_mpi_isend(load_data_cpy, &load_send_req[i], neighbor_ids[i], TAG_LOAD(my_rank), MPI_COMM_WORLD);
		starpu_mpi_irecv(neighbor_load_data_handles[i], &load_recv_req[i], neighbor_ids[i], TAG_LOAD(neighbor_ids[i]), MPI_COMM_WORLD);
	}

	/* Wait for completion of all send requests */
	for (i = 0; i < nneighbors; i++)
	{
		flag = 0;
		while (!flag)
			starpu_mpi_test(&load_send_req[i], &flag, &load_send_status[i]);
	}

	/* Wait for completion of all receive requests */
	for (i = 0; i < nneighbors; i++)
	{
		flag = 0;
		while (!flag)
			starpu_mpi_test(&load_recv_req[i], &flag, &load_recv_status[i]);
	}
}
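/* Note: the loops above poll starpu_mpi_test() until each request completes.
 * A blocking equivalent (a sketch, assuming nothing else needs to progress
 * between tests) would be:
 *
 *	for (i = 0; i < nneighbors; i++)
 *		starpu_mpi_wait(&load_send_req[i], &load_send_status[i]);
 *	for (i = 0; i < nneighbors; i++)
 *		starpu_mpi_wait(&load_recv_req[i], &load_recv_status[i]);
 */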
static void exchange_data_movements_infos()
{
	int i;

	/* Allocate all requests and statuses for point-to-point communications */
	starpu_mpi_req data_movements_send_req[world_size];
	starpu_mpi_req data_movements_recv_req[world_size];
	MPI_Status data_movements_send_status[world_size];
	MPI_Status data_movements_recv_status[world_size];
	int flag;

	/* Send the new ranks of local data to all other nodes, and receive the
	 * new ranks of all remote data from all other nodes */
	for (i = 0; i < world_size; i++)
	{
		if (i != my_rank)
		{
			//_STARPU_DEBUG("[node %d] Send and receive data movement with %d\n", my_rank, i);
			starpu_mpi_isend(data_movements_handles[my_rank], &data_movements_send_req[i], i, TAG_MOV(my_rank), MPI_COMM_WORLD);
			starpu_mpi_irecv(data_movements_handles[i], &data_movements_recv_req[i], i, TAG_MOV(i), MPI_COMM_WORLD);
		}
	}

	/* Wait for completion of all send requests */
	for (i = 0; i < world_size; i++)
	{
		if (i != my_rank)
		{
			//fprintf(stderr,"Wait for sending data movement of %d to %d\n", my_rank, i);
			flag = 0;
			while (!flag)
				starpu_mpi_test(&data_movements_send_req[i], &flag, &data_movements_send_status[i]);
		}
	}

	/* Wait for completion of all receive requests */
	for (i = 0; i < world_size; i++)
	{
		if (i != my_rank)
		{
			//fprintf(stderr,"Wait for receiving data movement from %d on %d\n", i, my_rank);
			flag = 0;
			while (!flag)
				starpu_mpi_test(&data_movements_recv_req[i], &flag, &data_movements_recv_status[i]);
		}
	}
}
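/* Illustration: with world_size = 4, rank 0 posts 3 isend/irecv pairs (to
 * and from ranks 1, 2 and 3); every rank thus learns the movements decided
 * by every other rank, at a cost of O(world_size) requests per node. */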
static void update_data_ranks()
{
	int i,j;

	/* Update the new ranks for all concerned data */
	for (i = 0; i < world_size; i++)
	{
		int ndata_to_update = data_movements_get_size_tables(data_movements_handles[i]);
		if (ndata_to_update)
		{
			//fprintf(stderr,"Update %d data from table %d on node %d\n", ndata_to_update, i, my_rank);
			for (j = 0; j < ndata_to_update; j++)
			{
				starpu_data_handle_t handle = _starpu_mpi_tag_get_data_handle_from_tag((data_movements_get_tags_table(data_movements_handles[i]))[j]);
				STARPU_ASSERT(handle);
				int dst_rank = (data_movements_get_ranks_table(data_movements_handles[i]))[j];

				/* Save the fact that the data has been moved out of this node */
				if (i == my_rank)
				{
					struct moved_data_entry *md;
					_STARPU_MPI_MALLOC(md, sizeof(struct moved_data_entry));
					md->handle = handle;
					HASH_ADD_PTR(mdh, handle, md);
				}
				else if (dst_rank == my_rank)
				{
					/* The data had been moved out, and is now moved back, so
					 * update the state of the moved_data hash table to
					 * reflect this change */
					struct moved_data_entry *md = NULL;
					HASH_FIND_PTR(mdh, &handle, md);
					if (md)
					{
						HASH_DEL(mdh, md);
						free(md);
					}
				}

				//if (i == my_rank)
				//{
				//	if (dst_rank != my_rank)
				//		fprintf(stderr,"Move data %p (tag %d) from node %d to node %d\n", handle, (data_movements_get_tags_table(data_movements_handles[i]))[j], my_rank, dst_rank);
				//	else
				//		fprintf(stderr,"Bring back data %p (tag %d) from node %d on node %d\n", handle, (data_movements_get_tags_table(data_movements_handles[i]))[j], starpu_mpi_data_get_rank(handle), my_rank);
				//}

				_STARPU_DEBUG("Call of starpu_mpi_get_data_on_node(%"PRIi64",%d) on node %d\n", starpu_mpi_data_get_tag(handle), dst_rank, my_rank);

				/* Migrate the data handle */
				starpu_mpi_get_data_on_node_detached(MPI_COMM_WORLD, handle, dst_rank, NULL, NULL);

				_STARPU_DEBUG("New rank (%d) of data %"PRIi64" updated on node %d\n", dst_rank, starpu_mpi_data_get_tag(handle), my_rank);
				starpu_mpi_data_set_rank_comm(handle, dst_rank, MPI_COMM_WORLD);
			}
		}
	}
}
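/* The migration above follows the usual StarPU-MPI idiom:
 * starpu_mpi_get_data_on_node_detached() replicates the data on the
 * destination node (detached, so the transfer is asynchronous), then
 * starpu_mpi_data_set_rank_comm() declares that node as the new owner, so
 * tasks subsequently inserted on this handle execute where the data now
 * lives. */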
static void clean_balance()
{
	int i;
	starpu_mpi_cache_flush(MPI_COMM_WORLD, *load_data_handle_cpy);
	for (i = 0; i < nneighbors; i++)
		starpu_mpi_cache_flush(MPI_COMM_WORLD, neighbor_load_data_handles[i]);
	for (i = 0; i < world_size; i++)
		starpu_mpi_cache_flush(MPI_COMM_WORLD, data_movements_handles[i]);
}
/* Core function of the load balancer. Computes from the load_data_cpy handle a
 * load balancing of the work to come (if needed), performs the necessary data
 * communications and negotiates the rebalancing with the other nodes. */
static void heat_balance(starpu_data_handle_t load_data_cpy)
{
	/* Exchange load data handles with neighbouring nodes */
	exchange_load_data_infos(load_data_cpy);

	/* Determine whether this node should send data to other nodes:
	 * which ones, how much data */
	balance(load_data_cpy);

	/* Exchange data movements with neighbouring nodes */
	exchange_data_movements_infos();

	/* Perform data movements */
	update_data_ranks();

	/* Clean the data handles to properly launch the next balance phase */
	clean_balance();
}
/******************************************************************************
 *                     Heat Load Balancer Entry Points                        *
 *****************************************************************************/
static void submitted_task_heat(struct starpu_task *task)
{
	load_data_inc_nsubmitted_tasks(*load_data_handle);
	//if (load_data_get_nsubmitted_tasks(*load_data_handle) > task->tag_id)
	//{
	//	fprintf(stderr,"Error : nsubmitted_tasks (%d) > tag_id (%lld) ! \n", load_data_get_nsubmitted_tasks(*load_data_handle), (long long int)task->tag_id);
	//	STARPU_ASSERT(0);
	//}

	int phase = load_data_get_current_phase(*load_data_handle);

	/* The numbering of tasks in StarPU-MPI must for now be given by the
	 * application with the STARPU_TAG_ONLY insert task option. */
	/* TODO: Properly implement a solution for numbering tasks in StarPU-MPI */
	if (((int)task->tag_id / load_data_get_sleep_threshold(*load_data_handle)) > phase)
	{
		STARPU_PTHREAD_MUTEX_LOCK(&load_data_mutex);
		load_data_update_wakeup_cond(*load_data_handle);
		//fprintf(stderr,"Node %d sleep on tag %lld\n", my_rank, (long long int)task->tag_id);
		//if (load_data_get_nsubmitted_tasks(*load_data_handle) < load_data_get_wakeup_threshold(*load_data_handle))
		//{
		//	fprintf(stderr,"Error : nsubmitted_tasks (%d) lower than wakeup_threshold (%d) !\n", load_data_get_nsubmitted_tasks(*load_data_handle), load_data_get_wakeup_threshold(*load_data_handle));
		//	STARPU_ASSERT(0);
		//}

		if (load_data_get_wakeup_threshold(*load_data_handle) > load_data_get_nfinished_tasks(*load_data_handle))
			STARPU_PTHREAD_COND_WAIT(&load_data_cond, &load_data_mutex);

		load_data_next_phase(*load_data_handle);

		/* Snapshot the load data at this moment, so that the heat balance
		 * can be computed without locking the load data during the whole
		 * balance step, which could cause all the workers to wait on the
		 * lock to update the data. */
		struct starpu_data_interface_ops *itf_load_data = starpu_data_get_interface_ops(*load_data_handle);
		void *itf_src = starpu_data_get_interface_on_node(*load_data_handle, STARPU_MAIN_RAM);
		void *itf_dst = starpu_data_get_interface_on_node(*load_data_handle_cpy, STARPU_MAIN_RAM);
		memcpy(itf_dst, itf_src, itf_load_data->interface_size);

		_STARPU_DEBUG("[node %d] Balance phase %d\n", my_rank, load_data_get_current_phase(*load_data_handle));
		STARPU_PTHREAD_MUTEX_UNLOCK(&load_data_mutex);

		heat_balance(*load_data_handle_cpy);
	}
}
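/* Worked example of the phase trigger above: with the default
 * sleep_task_threshold of 2000, tasks numbered 0..1999 map to phase 0; the
 * first submitted task whose tag is 2000 or above makes
 * tag_id / sleep_threshold exceed the current phase, so the submission
 * thread snapshots the load data, possibly sleeps until enough tasks have
 * finished, and runs a balance step. */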
static void finished_task_heat()
{
	//fprintf(stderr,"Try to decrement nsubmitted_tasks...");
	STARPU_PTHREAD_MUTEX_LOCK(&load_data_mutex);

	load_data_inc_nfinished_tasks(*load_data_handle);
	//fprintf(stderr,"Decrement nsubmitted_tasks, now %d\n", load_data_get_nsubmitted_tasks(*load_data_handle));
	if (load_data_wakeup_cond(*load_data_handle))
	{
		//fprintf(stderr,"Wakeup ! nfinished_tasks = %d, wakeup_threshold = %d\n", load_data_get_nfinished_tasks(*load_data_handle), load_data_get_wakeup_threshold(*load_data_handle));
		load_data_update_elapsed_time(*load_data_handle);
		STARPU_PTHREAD_COND_SIGNAL(&load_data_cond);
		STARPU_PTHREAD_MUTEX_UNLOCK(&load_data_mutex);
	}
	else
		STARPU_PTHREAD_MUTEX_UNLOCK(&load_data_mutex);
}
/******************************************************************************
 *                   Initialization / Deinitialization                        *
 *****************************************************************************/
static int init_heat(struct starpu_mpi_lb_conf *itf)
{
	int i;
	int sleep_task_threshold;
	double wakeup_ratio;

	starpu_mpi_comm_size(MPI_COMM_WORLD, &world_size);
	starpu_mpi_comm_rank(MPI_COMM_WORLD, &my_rank);

	/* Immediately return if the starpu_mpi_lb_conf is invalid. */
	if (!(itf && itf->get_neighbors && itf->get_data_unit_to_migrate))
	{
		_STARPU_MSG("Error: struct starpu_mpi_lb_conf %p invalid\n", itf);
		return 1;
	}

	_STARPU_MPI_MALLOC(user_itf, sizeof(struct starpu_mpi_lb_conf));
	memcpy(user_itf, itf, sizeof(struct starpu_mpi_lb_conf));

	/* Get the neighbours of the local MPI node */
	user_itf->get_neighbors(&neighbor_ids, &nneighbors);
	if (nneighbors == 0)
	{
		_STARPU_MSG("Error: function get_neighbors returned 0 neighbours\n");
		free(user_itf);
		user_itf = NULL;
		return 2;
	}

	/* The sleep threshold is deduced from the numbering of tasks by the
	 * application. For example, with this threshold, the submission thread
	 * will stop when a task whose number is 2000 or above is submitted to
	 * StarPU-MPI. However, far fewer tasks may actually have been submitted
	 * to the local MPI node: the sleeping of the submission threads checks
	 * the numbering of the tasks, not how many tasks have been submitted to
	 * the local MPI node, which are two different things. */
	char *sleep_env = starpu_getenv("LB_HEAT_SLEEP_THRESHOLD");
	if (sleep_env)
		sleep_task_threshold = atoi(sleep_env);
	else
		sleep_task_threshold = 2000;

	char *wakeup_env = starpu_getenv("LB_HEAT_WAKEUP_RATIO");
	if (wakeup_env)
		wakeup_ratio = atof(wakeup_env);
	else
		wakeup_ratio = 0.5;

	char *time_env = starpu_getenv("LB_HEAT_TIME_THRESHOLD");
	if (time_env)
		time_threshold = atoi(time_env);
	else
		time_threshold = 2000;
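
	/* Example environment configuration (illustrative values; the reading of
	 * wakeup_ratio as a fraction of tasks to wait for is an assumption based
	 * on the wakeup threshold checks above):
	 *
	 *	export LB_HEAT_SLEEP_THRESHOLD=4000  # balance every 4000 task numbers
	 *	export LB_HEAT_WAKEUP_RATIO=0.75     # fraction of tasks to wait for
	 *	export LB_HEAT_TIME_THRESHOLD=5000   # minimum load gap before moving data
	 */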
	STARPU_PTHREAD_MUTEX_INIT(&load_data_mutex, NULL);
	STARPU_PTHREAD_COND_INIT(&load_data_cond, NULL);

	/* Allocate, initialize and register all the data handles that will be
	 * needed by the load balancer, to avoid reallocating them at each
	 * balance step. */

	/* Local load data */
	_STARPU_MPI_CALLOC(load_data_handle, 1, sizeof(starpu_data_handle_t));
	load_data_data_register(load_data_handle, STARPU_MAIN_RAM, sleep_task_threshold, wakeup_ratio);

	/* Copy of the local load data to enable parallel update of the load data
	 * with communications to neighbour nodes */
	_STARPU_MPI_CALLOC(load_data_handle_cpy, 1, sizeof(starpu_data_handle_t));
	void *local_interface = starpu_data_get_interface_on_node(*load_data_handle, STARPU_MAIN_RAM);
	struct starpu_data_interface_ops *itf_load_data = starpu_data_get_interface_ops(*load_data_handle);
	starpu_data_register(load_data_handle_cpy, STARPU_MAIN_RAM, local_interface, itf_load_data);
	starpu_mpi_data_register(*load_data_handle_cpy, TAG_LOAD(my_rank), my_rank);

	/* Remote load data */
	_STARPU_MPI_CALLOC(neighbor_load_data_handles, nneighbors, sizeof(starpu_data_handle_t));
	for (i = 0; i < nneighbors; i++)
	{
		load_data_data_register(&neighbor_load_data_handles[i], STARPU_MAIN_RAM, sleep_task_threshold, wakeup_ratio);
		starpu_mpi_data_register(neighbor_load_data_handles[i], TAG_LOAD(neighbor_ids[i]), neighbor_ids[i]);
	}

	/* Data movements handles */
	_STARPU_MPI_MALLOC(data_movements_handles, world_size*sizeof(starpu_data_handle_t));
	for (i = 0; i < world_size; i++)
	{
		data_movements_data_register(&data_movements_handles[i], STARPU_MAIN_RAM, NULL, NULL, 0);
		starpu_mpi_data_register(data_movements_handles[i], TAG_MOV(i), i);
	}

	/* Hash table of moved data that will be brought back on the node at
	 * termination time */
	mdh = NULL;

	return 0;
}
/* Move back all the data that has been migrated out of this node, at
 * deinitialization time of the load balancer, to ensure consistency with
 * the ranks of the data originally registered by the application. */
static void move_back_data()
{
	int i,j;

	/* Update the new ranks for all concerned data */
	for (i = 0; i < world_size; i++)
	{
		/* In this case, each data_movements_handles contains the handles
		 * to move back on the specific node */
		int ndata_to_update = data_movements_get_size_tables(data_movements_handles[i]);
		if (ndata_to_update)
		{
			_STARPU_DEBUG("Move back %d data from table %d on node %d\n", ndata_to_update, i, my_rank);

			for (j = 0; j < ndata_to_update; j++)
			{
				starpu_data_handle_t handle = _starpu_mpi_tag_get_data_handle_from_tag((data_movements_get_tags_table(data_movements_handles[i]))[j]);
				STARPU_ASSERT(handle);

				int dst_rank = (data_movements_get_ranks_table(data_movements_handles[i]))[j];
				STARPU_ASSERT(i == dst_rank);

				if (i == my_rank)
				{
					/* The data is moved back, so update the state of the
					 * moved_data hash table to reflect this change */
					struct moved_data_entry *md = NULL;
					HASH_FIND_PTR(mdh, &handle, md);
					if (md)
					{
						HASH_DEL(mdh, md);
						free(md);
					}
				}

				//fprintf(stderr,"Call of starpu_mpi_get_data_on_node(%d,%d) on node %d\n", starpu_mpi_data_get_tag(handle), dst_rank, my_rank);

				/* Migrate the data handle */
				starpu_mpi_get_data_on_node_detached(MPI_COMM_WORLD, handle, dst_rank, NULL, NULL);

				//fprintf(stderr,"New rank (%d) of data %d updated on node %d\n", dst_rank, starpu_mpi_data_get_tag(handle), my_rank);
				starpu_mpi_data_set_rank_comm(handle, dst_rank, MPI_COMM_WORLD);
			}
		}
	}
}
static int deinit_heat()
{
	int i;

	if ((!user_itf) || (nneighbors == 0))
		return 1;

	_STARPU_DEBUG("Shutting down heat lb policy\n");

	unsigned int ndata_to_move_back = HASH_COUNT(mdh);

	if (ndata_to_move_back)
	{
		_STARPU_DEBUG("Move back %u data on node %d ..\n", ndata_to_move_back, my_rank);
		data_movements_reallocate_tables(data_movements_handles[my_rank], ndata_to_move_back);

		starpu_mpi_tag_t *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
		int *ranks = data_movements_get_ranks_table(data_movements_handles[my_rank]);

		int n = 0;
		struct moved_data_entry *md=NULL, *tmp=NULL;
		HASH_ITER(hh, mdh, md, tmp)
		{
			tags[n] = starpu_mpi_data_get_tag(md->handle);
			ranks[n] = my_rank;
			n++;
		}
	}
	else
		data_movements_reallocate_tables(data_movements_handles[my_rank], 0);

	exchange_data_movements_infos();
	move_back_data();

	/* This assert ensures that all nodes have properly gotten back all the
	 * data that had been moved out of them. */
	STARPU_ASSERT(HASH_COUNT(mdh) == 0);
	free(mdh);
	mdh = NULL;

	starpu_data_unregister(*load_data_handle);
	free(load_data_handle);
	load_data_handle = NULL;

	starpu_mpi_cache_flush(MPI_COMM_WORLD, *load_data_handle_cpy);
	starpu_data_unregister(*load_data_handle_cpy);
	free(load_data_handle_cpy);
	load_data_handle_cpy = NULL;

	for (i = 0; i < nneighbors; i++)
	{
		starpu_mpi_cache_flush(MPI_COMM_WORLD, neighbor_load_data_handles[i]);
		starpu_data_unregister(neighbor_load_data_handles[i]);
	}
	free(neighbor_load_data_handles);
	neighbor_load_data_handles = NULL;

	nneighbors = 0;
	free(neighbor_ids);
	neighbor_ids = NULL;

	for (i = 0; i < world_size; i++)
	{
		starpu_mpi_cache_flush(MPI_COMM_WORLD, data_movements_handles[i]);
		data_movements_reallocate_tables(data_movements_handles[i], 0);
		starpu_data_unregister(data_movements_handles[i]);
	}
	free(data_movements_handles);
	data_movements_handles = NULL;

	STARPU_PTHREAD_MUTEX_DESTROY(&load_data_mutex);
	STARPU_PTHREAD_COND_DESTROY(&load_data_cond);

	free(user_itf);
	user_itf = NULL;

	return 0;
}
/******************************************************************************
 *                                   Policy                                   *
 *****************************************************************************/
struct load_balancer_policy load_heat_propagation_policy =
{
	.init = init_heat,
	.deinit = deinit_heat,
	.submitted_task_entry_point = submitted_task_heat,
	.finished_task_entry_point = finished_task_heat,
	.policy_name = "heat"
};
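/* Hypothetical usage from an application (a sketch: the exact public entry
 * points live in the load balancer API, not in this file, and the callback
 * names below are placeholders). The policy is selected by its name, "heat":
 *
 *	struct starpu_mpi_lb_conf conf =
 *	{
 *		.get_neighbors = my_get_neighbors,
 *		.get_data_unit_to_migrate = my_get_data_unit_to_migrate,
 *	};
 *	starpu_mpi_lb_init("heat", &conf);
 *	// ... insert tasks numbered with the STARPU_TAG_ONLY option ...
 *	starpu_mpi_lb_shutdown();
 */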
#endif