  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2016 Inria
  4. * Copyright (C) 2017 CNRS
  5. *
  6. * StarPU is free software; you can redistribute it and/or modify
  7. * it under the terms of the GNU Lesser General Public License as published by
  8. * the Free Software Foundation; either version 2.1 of the License, or (at
  9. * your option) any later version.
  10. *
  11. * StarPU is distributed in the hope that it will be useful, but
  12. * WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  14. *
  15. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  16. */
  17. #include <starpu_mpi.h>
  18. #include <starpu_mpi_tag.h>
  19. #include <common/uthash.h>
  20. #include <common/utils.h>
  21. #include <math.h>
  22. #include "load_balancer_policy.h"
  23. #include "data_movements_interface.h"
  24. #include "load_data_interface.h"
  25. static int TAG_LOAD(int n)
  26. {
  27. return ((n+1) << 24);
  28. }
  29. static int TAG_MOV(int n)
  30. {
  31. return ((n+1) << 20);
  32. }
/* Hash table of local pieces of data that has been moved out of the local MPI
 * node by the load balancer. All of these pieces of data must be migrated back
 * to the local node at the end of the execution. */
struct moved_data_entry
{
	UT_hash_handle hh;           /* uthash bookkeeping; entries are keyed on 'handle' (HASH_ADD_PTR) */
	starpu_data_handle_t handle; /* the data handle that was migrated away from this node */
};
/* Head of the moved-data hash table (NULL when empty) */
static struct moved_data_entry *mdh = NULL;

/* Mutex and condition used to synchronize the submission thread (which may
 * sleep in submitted_task_heat) with task-completion callbacks */
static starpu_pthread_mutex_t load_data_mutex;
static starpu_pthread_cond_t load_data_cond;

/* MPI infos */
static int my_rank;
static int world_size;

/* Number of neighbours of the local MPI node and their IDs. These are given by
 * the get_neighbors() method, and thus can be easily changed. */
static int *neighbor_ids = NULL;
static int nneighbors = 0;

/* Local load data */
static starpu_data_handle_t *load_data_handle = NULL;
static starpu_data_handle_t *load_data_handle_cpy = NULL;

/* Load data of neighbours */
static starpu_data_handle_t *neighbor_load_data_handles = NULL;

/* Table which contains a data_movements_handle for each MPI node of
 * MPI_COMM_WORLD. Since all the MPI nodes must be advised of any data
 * movement, this table will be used to perform communications of data
 * movements handles following an all-to-all model. */
static starpu_data_handle_t *data_movements_handles = NULL;

/* Load balancer interface which contains the application-specific methods for
 * the load balancer to use. */
static struct starpu_mpi_lb_conf *user_itf = NULL;

/* Minimum elapsed-time difference with the least-loaded neighbour before a
 * migration is triggered (unit presumably microseconds — TODO confirm);
 * overridable through the LB_HEAT_TIME_THRESHOLD environment variable. */
static double time_threshold = 20000;
  65. /******************************************************************************
  66. * Balancing *
  67. *****************************************************************************/
  68. /* Decides which data has to move where, and fills the
  69. * data_movements_handles[my_rank] data handle from that.
  70. * In data :
  71. * - local load_data_handle
  72. * - nneighbors
  73. * - neighbor_ids[nneighbors]
  74. * - neighbor_load_data_handles[nneighbors]
  75. * Out data :
  76. * - data_movements_handles[my_rank]
  77. */
  78. static void balance(starpu_data_handle_t load_data_cpy)
  79. {
  80. int less_loaded = -1;
  81. int n;
  82. double elapsed_time, ref_elapsed_time;
  83. double my_elapsed_time = load_data_get_elapsed_time(load_data_cpy);
  84. /* Search for the less loaded neighbor */
  85. ref_elapsed_time = my_elapsed_time;
  86. for (n = 0; n < nneighbors; n++)
  87. {
  88. elapsed_time = load_data_get_elapsed_time(neighbor_load_data_handles[n]);
  89. if (ref_elapsed_time > elapsed_time)
  90. {
  91. //fprintf(stderr,"Node%d: ref local time %lf vs neighbour%d time %lf\n", my_rank, ref_elapsed_time, neighbor_ids[n], elapsed_time);
  92. less_loaded = neighbor_ids[n];
  93. ref_elapsed_time = elapsed_time;
  94. }
  95. }
  96. /* We found it */
  97. if (less_loaded >= 0)
  98. {
  99. _STARPU_DEBUG("Less loaded found on node %d : %d\n", my_rank, less_loaded);
  100. double diff_time = my_elapsed_time - ref_elapsed_time;
  101. /* If the difference is higher than a time threshold, we move
  102. * one data to the less loaded neighbour. */
  103. /* TODO: How to decide the time threshold ? */
  104. if ((time_threshold > 0) && (diff_time >= time_threshold))
  105. {
  106. starpu_data_handle_t *handles = NULL;
  107. int nhandles = 0;
  108. user_itf->get_data_unit_to_migrate(&handles, &nhandles, less_loaded);
  109. data_movements_reallocate_tables(data_movements_handles[my_rank], nhandles);
  110. if (nhandles)
  111. {
  112. int *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
  113. int *ranks = data_movements_get_ranks_table(data_movements_handles[my_rank]);
  114. for (n = 0; n < nhandles; n++)
  115. {
  116. tags[n] = starpu_mpi_data_get_tag(handles[n]);
  117. ranks[n] = less_loaded;
  118. }
  119. free(handles);
  120. }
  121. }
  122. else
  123. data_movements_reallocate_tables(data_movements_handles[my_rank], 0);
  124. }
  125. else
  126. data_movements_reallocate_tables(data_movements_handles[my_rank], 0);
  127. }
  128. static void exchange_load_data_infos(starpu_data_handle_t load_data_cpy)
  129. {
  130. int i;
  131. /* Allocate all requests and status for point-to-point communications */
  132. starpu_mpi_req load_send_req[nneighbors];
  133. starpu_mpi_req load_recv_req[nneighbors];
  134. MPI_Status load_send_status[nneighbors];
  135. MPI_Status load_recv_status[nneighbors];
  136. int flag;
  137. /* Send the local load data to neighbour nodes, and receive the remote load
  138. * data from neighbour nodes */
  139. for (i = 0; i < nneighbors; i++)
  140. {
  141. //_STARPU_DEBUG("[node %d] sending and receiving with %i-th neighbor %i\n", my_rank, i, neighbor_ids[i]);
  142. starpu_mpi_isend(load_data_cpy, &load_send_req[i], neighbor_ids[i], TAG_LOAD(my_rank), MPI_COMM_WORLD);
  143. starpu_mpi_irecv(neighbor_load_data_handles[i], &load_recv_req[i], neighbor_ids[i], TAG_LOAD(neighbor_ids[i]), MPI_COMM_WORLD);
  144. }
  145. /* Wait for completion of all send requests */
  146. for (i = 0; i < nneighbors; i++)
  147. {
  148. flag = 0;
  149. while (!flag)
  150. starpu_mpi_test(&load_send_req[i], &flag, &load_send_status[i]);
  151. }
  152. /* Wait for completion of all receive requests */
  153. for (i = 0; i < nneighbors; i++)
  154. {
  155. flag = 0;
  156. while (!flag)
  157. starpu_mpi_test(&load_recv_req[i], &flag, &load_recv_status[i]);
  158. }
  159. }
  160. static void exchange_data_movements_infos()
  161. {
  162. int i;
  163. /* Allocate all requests and status for point-to-point communications */
  164. starpu_mpi_req data_movements_send_req[world_size];
  165. starpu_mpi_req data_movements_recv_req[world_size];
  166. MPI_Status data_movements_send_status[world_size];
  167. MPI_Status data_movements_recv_status[world_size];
  168. int flag;
  169. /* Send the new ranks of local data to all other nodes, and receive the new
  170. * ranks of all remote data from all other nodes */
  171. for (i = 0; i < world_size; i++)
  172. {
  173. if (i != my_rank)
  174. {
  175. //_STARPU_DEBUG("[node %d] Send and receive data movement with %d\n", my_rank, i);
  176. starpu_mpi_isend(data_movements_handles[my_rank], &data_movements_send_req[i], i, TAG_MOV(my_rank), MPI_COMM_WORLD);
  177. starpu_mpi_irecv(data_movements_handles[i], &data_movements_recv_req[i], i, TAG_MOV(i), MPI_COMM_WORLD);
  178. }
  179. }
  180. /* Wait for completion of all send requests */
  181. for (i = 0; i < world_size; i++)
  182. {
  183. if (i != my_rank)
  184. {
  185. //fprintf(stderr,"Wait for sending data movement of %d to %d\n", my_rank, i);
  186. flag = 0;
  187. while (!flag)
  188. starpu_mpi_test(&data_movements_send_req[i], &flag, &data_movements_send_status[i]);
  189. }
  190. }
  191. /* Wait for completion of all receive requests */
  192. for (i = 0; i < world_size; i++)
  193. {
  194. if (i != my_rank)
  195. {
  196. //fprintf(stderr,"Wait for recieving data movement from %d on %d\n", i, my_rank);
  197. flag = 0;
  198. while (!flag)
  199. starpu_mpi_test(&data_movements_recv_req[i], &flag, &data_movements_recv_status[i]);
  200. }
  201. }
  202. }
  203. static void update_data_ranks()
  204. {
  205. int i,j;
  206. /* Update the new ranks for all concerned data */
  207. for (i = 0; i < world_size; i++)
  208. {
  209. int ndata_to_update = data_movements_get_size_tables(data_movements_handles[i]);
  210. if (ndata_to_update)
  211. {
  212. //fprintf(stderr,"Update %d data from table %d on node %d\n", ndata_to_update, i, my_rank);
  213. for (j = 0; j < ndata_to_update; j++)
  214. {
  215. starpu_data_handle_t handle = _starpu_mpi_data_get_data_handle_from_tag((data_movements_get_tags_table(data_movements_handles[i]))[j]);
  216. STARPU_ASSERT(handle);
  217. int dst_rank = (data_movements_get_ranks_table(data_movements_handles[i]))[j];
  218. /* Save the fact that the data has been moved out of this node */
  219. if (i == my_rank)
  220. {
  221. struct moved_data_entry *md = (struct moved_data_entry *)malloc(sizeof(struct moved_data_entry));
  222. md->handle = handle;
  223. HASH_ADD_PTR(mdh, handle, md);
  224. }
  225. else if (dst_rank == my_rank)
  226. {
  227. /* The data has been moved out, and now is moved back, so
  228. * update the state of the moved_data hash table to reflect
  229. * this change */
  230. struct moved_data_entry *md = NULL;
  231. HASH_FIND_PTR(mdh, &handle, md);
  232. if (md)
  233. {
  234. HASH_DEL(mdh, md);
  235. free(md);
  236. }
  237. }
  238. //if (i == my_rank)
  239. //{
  240. // if (dst_rank != my_rank)
  241. // fprintf(stderr,"Move data %p (tag %d) from node %d to node %d\n", handle, (data_movements_get_tags_table(data_movements_handles[i]))[j], my_rank, dst_rank);
  242. // else
  243. // fprintf(stderr,"Bring back data %p (tag %d) from node %d on node %d\n", handle, (data_movements_get_tags_table(data_movements_handles[i]))[j], starpu_mpi_data_get_rank(handle), my_rank);
  244. //}
  245. _STARPU_DEBUG("Call of starpu_mpi_get_data_on_node(%d,%d) on node %d\n", starpu_mpi_data_get_tag(handle), dst_rank, my_rank);
  246. /* Migrate the data handle */
  247. starpu_mpi_get_data_on_node_detached(MPI_COMM_WORLD, handle, dst_rank, NULL, NULL);
  248. _STARPU_DEBUG("New rank (%d) of data %d upgraded on node %d\n", dst_rank, starpu_mpi_data_get_tag(handle), my_rank);
  249. starpu_mpi_data_set_rank_comm(handle, dst_rank, MPI_COMM_WORLD);
  250. }
  251. }
  252. }
  253. }
  254. static void clean_balance()
  255. {
  256. int i;
  257. starpu_mpi_cache_flush(MPI_COMM_WORLD, *load_data_handle_cpy);
  258. for (i = 0; i < nneighbors; i++)
  259. starpu_mpi_cache_flush(MPI_COMM_WORLD, neighbor_load_data_handles[i]);
  260. for (i = 0; i < world_size; i++)
  261. starpu_mpi_cache_flush(MPI_COMM_WORLD, data_movements_handles[i]);
  262. }
/* Core function of the load balancer. Computes from the load_data_cpy handle a
 * load balancing of the work to come (if needed), perform the necessary data
 * communications and negociate with the other nodes the rebalancing.
 * The five steps below are strictly ordered: each one consumes the data
 * produced by the previous one. */
static void heat_balance(starpu_data_handle_t load_data_cpy)
{
	/* Exchange load data handles with neighboring nodes */
	exchange_load_data_infos(load_data_cpy);

	/* Determine if this node should sent data to other nodes :
	 * which ones, how much data */
	balance(load_data_cpy);

	/* Exchange data movements with neighboring nodes */
	exchange_data_movements_infos();

	/* Perform data movements */
	update_data_ranks();

	/* Clean the data handles to properly launch the next balance phase */
	clean_balance();
}
  280. /******************************************************************************
  281. * Heat Load Balancer Entry Points *
  282. *****************************************************************************/
/* Submission-side entry point, called for each task submitted by the
 * application. Counts the task and, when the task numbering crosses the
 * current phase's sleep threshold, blocks the submission thread until enough
 * tasks have finished, then snapshots the load data and runs a balance phase. */
static void submitted_task_heat(struct starpu_task *task)
{
	load_data_inc_nsubmitted_tasks(*load_data_handle);
	int phase = load_data_get_current_phase(*load_data_handle);
	/* Numbering of tasks in StarPU-MPI should be given by the application with
	 * the STARPU_TAG_ONLY insert task option for now. */
	/* TODO: Properly implement a solution for numbering tasks in StarPU-MPI */
	if ((task->tag_id / load_data_get_sleep_threshold(*load_data_handle)) > phase)
	{
		STARPU_PTHREAD_MUTEX_LOCK(&load_data_mutex);
		load_data_update_wakeup_cond(*load_data_handle);
		/* NOTE(review): the condition wait below is not re-checked in a
		 * loop, so a spurious wakeup would let the thread proceed early —
		 * confirm whether load_data_wakeup_cond() guarantees the guard
		 * before converting this 'if' into a 'while'. */
		if (load_data_get_wakeup_threshold(*load_data_handle) > load_data_get_nfinished_tasks(*load_data_handle))
			STARPU_PTHREAD_COND_WAIT(&load_data_cond, &load_data_mutex);
		load_data_next_phase(*load_data_handle);
		/* Register a copy of the load data at this moment, to allow to compute
		 * the heat balance while not locking the load data during the whole
		 * balance step, which could cause all the workers to wait on the lock
		 * to update the data. */
		struct starpu_data_interface_ops *itf_load_data = starpu_data_get_interface_ops(*load_data_handle);
		void* itf_src = starpu_data_get_interface_on_node(*load_data_handle, STARPU_MAIN_RAM);
		void* itf_dst = starpu_data_get_interface_on_node(*load_data_handle_cpy, STARPU_MAIN_RAM);
		memcpy(itf_dst, itf_src, itf_load_data->interface_size);
		_STARPU_DEBUG("[node %d] Balance phase %d\n", my_rank, load_data_get_current_phase(*load_data_handle));
		STARPU_PTHREAD_MUTEX_UNLOCK(&load_data_mutex);
		/* Balance outside the lock, on the snapshot taken above */
		heat_balance(*load_data_handle_cpy);
	}
}
  321. static void finished_task_heat()
  322. {
  323. //fprintf(stderr,"Try to decrement nsubmitted_tasks...");
  324. STARPU_PTHREAD_MUTEX_LOCK(&load_data_mutex);
  325. load_data_inc_nfinished_tasks(*load_data_handle);
  326. //fprintf(stderr,"Decrement nsubmitted_tasks, now %d\n", load_data_get_nsubmitted_tasks(*load_data_handle));
  327. if (load_data_wakeup_cond(*load_data_handle))
  328. {
  329. //fprintf(stderr,"Wakeup ! nfinished_tasks = %d, wakeup_threshold = %d\n", load_data_get_nfinished_tasks(*load_data_handle), load_data_get_wakeup_threshold(*load_data_handle));
  330. load_data_update_elapsed_time(*load_data_handle);
  331. STARPU_PTHREAD_COND_SIGNAL(&load_data_cond);
  332. STARPU_PTHREAD_MUTEX_UNLOCK(&load_data_mutex);
  333. }
  334. else
  335. STARPU_PTHREAD_MUTEX_UNLOCK(&load_data_mutex);
  336. }
  337. /******************************************************************************
  338. * Initialization / Deinitialization *
  339. *****************************************************************************/
  340. static int init_heat(struct starpu_mpi_lb_conf *itf)
  341. {
  342. int i;
  343. int sleep_task_threshold;
  344. double wakeup_ratio;
  345. starpu_mpi_comm_size(MPI_COMM_WORLD, &world_size);
  346. starpu_mpi_comm_rank(MPI_COMM_WORLD, &my_rank);
  347. /* Immediately return if the starpu_mpi_lb_conf is invalid. */
  348. if (!(itf && itf->get_neighbors && itf->get_data_unit_to_migrate))
  349. {
  350. _STARPU_MSG("Error: struct starpu_mpi_lb_conf %p invalid\n", itf);
  351. return 1;
  352. }
  353. user_itf = malloc(sizeof(struct starpu_mpi_lb_conf));
  354. memcpy(user_itf, itf, sizeof(struct starpu_mpi_lb_conf));;
  355. /* Get the neighbors of the local MPI node */
  356. user_itf->get_neighbors(&neighbor_ids, &nneighbors);
  357. if (nneighbors == 0)
  358. {
  359. _STARPU_MSG("Error: Function get_neighbors returning 0 neighbor\n");
  360. free(user_itf);
  361. user_itf = NULL;
  362. return 2;
  363. }
  364. /* The sleep threshold is deducted from the numbering of tasks by the
  365. * application. For example, with this threshold, the submission thread
  366. * will stop when a task for which the numbering is 2000 or above will be
  367. * submitted to StarPU-MPI. However, much less tasks can be really
  368. * submitted to the local MPI node: the sleeping of the submission threads
  369. * checks the numbering of the tasks, not how many tasks have been
  370. * submitted to the local MPI node, which are two different things. */
  371. char *sleep_env = starpu_getenv("LB_HEAT_SLEEP_THRESHOLD");
  372. if (sleep_env)
  373. sleep_task_threshold = atoi(sleep_env);
  374. else
  375. sleep_task_threshold = 2000;
  376. char *wakeup_env = starpu_getenv("LB_HEAT_WAKEUP_RATIO");
  377. if (wakeup_env)
  378. wakeup_ratio = atof(wakeup_env);
  379. else
  380. wakeup_ratio = 0.5;
  381. char *time_env = starpu_getenv("LB_HEAT_TIME_THRESHOLD");
  382. if (time_env)
  383. time_threshold = atoi(time_env);
  384. else
  385. time_threshold = 2000;
  386. STARPU_PTHREAD_MUTEX_INIT(&load_data_mutex, NULL);
  387. STARPU_PTHREAD_COND_INIT(&load_data_cond, NULL);
  388. /* Allocate, initialize and register all the data handles that will be
  389. * needed for the load balancer, to not reallocate them at each balance
  390. * step. */
  391. /* Local load data */
  392. load_data_handle = malloc(sizeof(starpu_data_handle_t));
  393. memset(load_data_handle, 0, sizeof(starpu_data_handle_t));
  394. load_data_data_register(load_data_handle, STARPU_MAIN_RAM, sleep_task_threshold, wakeup_ratio);
  395. /* Copy of the local load data to enable parallel update of the load data
  396. * with communications to neighbor nodes */
  397. load_data_handle_cpy = malloc(sizeof(starpu_data_handle_t));
  398. memset(load_data_handle_cpy, 0, sizeof(starpu_data_handle_t));
  399. void *local_interface = starpu_data_get_interface_on_node(*load_data_handle, STARPU_MAIN_RAM);
  400. struct starpu_data_interface_ops *itf_load_data = starpu_data_get_interface_ops(*load_data_handle);
  401. starpu_data_register(load_data_handle_cpy, STARPU_MAIN_RAM, local_interface, itf_load_data);
  402. starpu_mpi_data_register(*load_data_handle_cpy, TAG_LOAD(my_rank), my_rank);
  403. /* Remote load data */
  404. neighbor_load_data_handles = malloc(nneighbors*sizeof(starpu_data_handle_t));
  405. memset(neighbor_load_data_handles, 0, nneighbors*sizeof(starpu_data_handle_t));
  406. for (i = 0; i < nneighbors; i++)
  407. {
  408. load_data_data_register(&neighbor_load_data_handles[i], STARPU_MAIN_RAM, sleep_task_threshold, wakeup_ratio);
  409. starpu_mpi_data_register(neighbor_load_data_handles[i], TAG_LOAD(neighbor_ids[i]), neighbor_ids[i]);
  410. }
  411. /* Data movements handles */
  412. data_movements_handles = malloc(world_size*sizeof(starpu_data_handle_t));
  413. for (i = 0; i < world_size; i++)
  414. {
  415. data_movements_data_register(&data_movements_handles[i], STARPU_MAIN_RAM, NULL, NULL, 0);
  416. starpu_mpi_data_register(data_movements_handles[i], TAG_MOV(i), i);
  417. }
  418. /* Hash table of moved data that will be brought back on the node at
  419. * termination time */
  420. mdh = NULL;
  421. return 0;
  422. }
  423. /* Move back all the data that has been migrated out of this node at
  424. * denitialization time of the load balancer, to ensure the consistency with
  425. * the ranks of data originally registered by the application. */
  426. static void move_back_data()
  427. {
  428. int i,j;
  429. /* Update the new ranks for all concerned data */
  430. for (i = 0; i < world_size; i++)
  431. {
  432. /* In this case, each data_movements_handles contains the handles to move back on the specific node */
  433. int ndata_to_update = data_movements_get_size_tables(data_movements_handles[i]);
  434. if (ndata_to_update)
  435. {
  436. _STARPU_DEBUG("Move back %d data from table %d on node %d\n", ndata_to_update, i, my_rank);
  437. for (j = 0; j < ndata_to_update; j++)
  438. {
  439. starpu_data_handle_t handle = _starpu_mpi_data_get_data_handle_from_tag((data_movements_get_tags_table(data_movements_handles[i]))[j]);
  440. STARPU_ASSERT(handle);
  441. int dst_rank = (data_movements_get_ranks_table(data_movements_handles[i]))[j];
  442. STARPU_ASSERT(i == dst_rank);
  443. if (i == my_rank)
  444. {
  445. /* The data is moved back, so update the state of the
  446. * moved_data hash table to reflect this change */
  447. struct moved_data_entry *md = NULL;
  448. HASH_FIND_PTR(mdh, &handle, md);
  449. if (md)
  450. {
  451. HASH_DEL(mdh, md);
  452. free(md);
  453. }
  454. }
  455. //fprintf(stderr,"Call of starpu_mpi_get_data_on_node(%d,%d) on node %d\n", starpu_mpi_data_get_tag(handle), dst_rank, my_rank);
  456. /* Migrate the data handle */
  457. starpu_mpi_get_data_on_node_detached(MPI_COMM_WORLD, handle, dst_rank, NULL, NULL);
  458. //fprintf(stderr,"New rank (%d) of data %d upgraded on node %d\n", dst_rank, starpu_mpi_data_get_tag(handle), my_rank);
  459. starpu_mpi_data_set_rank_comm(handle, dst_rank, MPI_COMM_WORLD);
  460. }
  461. }
  462. }
  463. }
  464. static int deinit_heat()
  465. {
  466. int i;
  467. if ((!user_itf) || (nneighbors == 0))
  468. return 1;
  469. _STARPU_DEBUG("Shutting down heat lb policy\n");
  470. unsigned int ndata_to_move_back = HASH_COUNT(mdh);
  471. if (ndata_to_move_back)
  472. {
  473. _STARPU_DEBUG("Move back %u data on node %d ..\n", ndata_to_move_back, my_rank);
  474. data_movements_reallocate_tables(data_movements_handles[my_rank], ndata_to_move_back);
  475. int *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
  476. int *ranks = data_movements_get_ranks_table(data_movements_handles[my_rank]);
  477. int n = 0;
  478. struct moved_data_entry *md, *tmp;
  479. HASH_ITER(hh, mdh, md, tmp)
  480. {
  481. tags[n] = starpu_mpi_data_get_tag(md->handle);
  482. ranks[n] = my_rank;
  483. n++;
  484. }
  485. }
  486. else
  487. data_movements_reallocate_tables(data_movements_handles[my_rank], 0);
  488. exchange_data_movements_infos();
  489. move_back_data();
  490. /* This assert ensures that all nodes have properly gotten back all the
  491. * data that has been moven out of the node. */
  492. STARPU_ASSERT(HASH_COUNT(mdh) == 0);
  493. free(mdh);
  494. mdh = NULL;
  495. starpu_data_unregister(*load_data_handle);
  496. free(load_data_handle);
  497. load_data_handle = NULL;
  498. starpu_mpi_cache_flush(MPI_COMM_WORLD, *load_data_handle_cpy);
  499. starpu_data_unregister(*load_data_handle_cpy);
  500. free(load_data_handle_cpy);
  501. load_data_handle_cpy = NULL;
  502. for (i = 0; i < nneighbors; i++)
  503. {
  504. starpu_mpi_cache_flush(MPI_COMM_WORLD, neighbor_load_data_handles[i]);
  505. starpu_data_unregister(neighbor_load_data_handles[i]);
  506. }
  507. free(neighbor_load_data_handles);
  508. neighbor_load_data_handles = NULL;
  509. nneighbors = 0;
  510. free(neighbor_ids);
  511. neighbor_ids = NULL;
  512. for (i = 0; i < world_size; i++)
  513. {
  514. starpu_mpi_cache_flush(MPI_COMM_WORLD, data_movements_handles[i]);
  515. data_movements_reallocate_tables(data_movements_handles[i], 0);
  516. starpu_data_unregister(data_movements_handles[i]);
  517. }
  518. free(data_movements_handles);
  519. data_movements_handles = NULL;
  520. STARPU_PTHREAD_MUTEX_DESTROY(&load_data_mutex);
  521. STARPU_PTHREAD_COND_DESTROY(&load_data_cond);
  522. free(user_itf);
  523. user_itf = NULL;
  524. return 0;
  525. }
  526. /******************************************************************************
  527. * Policy *
  528. *****************************************************************************/
/* Exported descriptor of the "heat" load-balancing policy: entry points
 * wired into the generic load balancer. */
struct load_balancer_policy load_heat_propagation_policy =
{
	.init = init_heat,
	.deinit = deinit_heat,
	.submitted_task_entry_point = submitted_task_heat,
	.finished_task_entry_point = finished_task_heat,
	.policy_name = "heat"
};