driver_mpi_source.c 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2016,2017 Inria
  4. * Copyright (C) 2017, 2019 CNRS
  5. * Copyright (C) 2017 Université de Bordeaux
  6. * Copyright (C) 2015 Mathieu Lirzin
  7. *
  8. * StarPU is free software; you can redistribute it and/or modify
  9. * it under the terms of the GNU Lesser General Public License as published by
  10. * the Free Software Foundation; either version 2.1 of the License, or (at
  11. * your option) any later version.
  12. *
  13. * StarPU is distributed in the hope that it will be useful, but
  14. * WITHOUT ANY WARRANTY; without even the implied warranty of
  15. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  16. *
  17. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  18. */
  19. #include <mpi.h>
  20. #include <errno.h>
  21. #include <starpu.h>
  22. #include <drivers/mpi/driver_mpi_source.h>
  23. #include <drivers/mpi/driver_mpi_common.h>
  24. #include <datawizard/memory_nodes.h>
  25. #include <drivers/driver_common/driver_common.h>
  26. #include <drivers/mp_common/source_common.h>
/* Protects the `kernels` hash table: starpu_mpi_ms_register_kernel() holds it
 * around every lookup and insertion in that table.
 * NOTE(review): starpu_mpi_ms_get_kernel() reads and fills kernel->func[]
 * without taking this mutex. */
starpu_pthread_mutex_t htbl_mutex = STARPU_PTHREAD_MUTEX_INITIALIZER;
/* Host-side record of a kernel that can run on MPI master-slave devices:
 * its symbol name and, for each device, the address resolved on that device.
 * A NULL func[devid] means no lookup has been done yet for that device, so the
 * host still has to ask the device for it (see starpu_mpi_ms_get_kernel()).
 */
struct _starpu_mpi_ms_kernel
{
	UT_hash_handle hh;	/* uthash linkage; the table is keyed by `name` */
	char *name;		/* heap-allocated, NUL-terminated symbol name */
	starpu_mpi_ms_kernel_t func[STARPU_MAXMPIDEVS];	/* resolved address, indexed by devid */
} *kernels;	/* head of the hash table of registered kernels, guarded by htbl_mutex */
/* Per-device communication state the host uses to send messages to and
 * receive messages from each MPI master-slave device, indexed by device id
 * (devid). */
struct _starpu_mp_node *_starpu_mpi_ms_nodes[STARPU_MAXMPIDEVS];
  44. struct _starpu_mp_node *_starpu_mpi_ms_src_get_actual_thread_mp_node()
  45. {
  46. struct _starpu_worker *actual_worker = _starpu_get_local_worker_key();
  47. STARPU_ASSERT(actual_worker);
  48. int devid = actual_worker->devid;
  49. STARPU_ASSERT(devid >= 0 && devid < STARPU_MAXMPIDEVS);
  50. return _starpu_mpi_ms_nodes[devid];
  51. }
  52. void _starpu_mpi_source_init(struct _starpu_mp_node *node)
  53. {
  54. _starpu_mpi_common_mp_initialize_src_sink(node);
  55. //TODO
  56. }
  57. void _starpu_mpi_source_deinit(struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED)
  58. {
  59. }
  60. struct _starpu_mp_node *_starpu_mpi_src_get_mp_node_from_memory_node(int memory_node)
  61. {
  62. int devid = _starpu_memory_node_get_devid(memory_node);
  63. STARPU_ASSERT_MSG(devid >= 0 && devid < STARPU_MAXMPIDEVS, "bogus devid %d for memory node %d\n", devid, memory_node);
  64. return _starpu_mpi_ms_nodes[devid];
  65. }
  66. int _starpu_mpi_src_allocate_memory(void ** addr, size_t size, unsigned memory_node)
  67. {
  68. const struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(memory_node);
  69. return _starpu_src_common_allocate(mp_node, addr, size);
  70. }
  71. void _starpu_mpi_source_free_memory(void *addr, unsigned memory_node)
  72. {
  73. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(memory_node);
  74. _starpu_src_common_free(mp_node, addr);
  75. }
  76. /* Transfert SIZE bytes from the address pointed by SRC in the SRC_NODE memory
  77. * node to the address pointed by DST in the DST_NODE memory node
  78. */
  79. int _starpu_mpi_copy_ram_to_mpi_sync(void *src, unsigned src_node STARPU_ATTRIBUTE_UNUSED, void *dst, unsigned dst_node, size_t size)
  80. {
  81. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(dst_node);
  82. return _starpu_src_common_copy_host_to_sink_sync(mp_node, src, dst, size);
  83. }
  84. /* Transfert SIZE bytes from the address pointed by SRC in the SRC_NODE memory
  85. * node to the address pointed by DST in the DST_NODE memory node
  86. */
  87. int _starpu_mpi_copy_mpi_to_ram_sync(void *src, unsigned src_node, void *dst, unsigned dst_node STARPU_ATTRIBUTE_UNUSED, size_t size)
  88. {
  89. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(src_node);
  90. return _starpu_src_common_copy_sink_to_host_sync(mp_node, src, dst, size);
  91. }
  92. int _starpu_mpi_copy_sink_to_sink_sync(void *src, unsigned src_node, void *dst, unsigned dst_node, size_t size)
  93. {
  94. return _starpu_src_common_copy_sink_to_sink_sync(_starpu_mpi_src_get_mp_node_from_memory_node(src_node),
  95. _starpu_mpi_src_get_mp_node_from_memory_node(dst_node),
  96. src, dst, size);
  97. }
  98. int _starpu_mpi_copy_mpi_to_ram_async(void *src, unsigned src_node, void *dst, unsigned dst_node STARPU_ATTRIBUTE_UNUSED, size_t size, void * event)
  99. {
  100. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(src_node);
  101. return _starpu_src_common_copy_sink_to_host_async(mp_node, src, dst, size, event);
  102. }
  103. int _starpu_mpi_copy_ram_to_mpi_async(void *src, unsigned src_node STARPU_ATTRIBUTE_UNUSED, void *dst, unsigned dst_node, size_t size, void * event)
  104. {
  105. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(dst_node);
  106. return _starpu_src_common_copy_host_to_sink_async(mp_node, src, dst, size, event);
  107. }
  108. int _starpu_mpi_copy_sink_to_sink_async(void *src, unsigned src_node, void *dst, unsigned dst_node, size_t size, void * event)
  109. {
  110. return _starpu_src_common_copy_sink_to_sink_async(_starpu_mpi_src_get_mp_node_from_memory_node(src_node),
  111. _starpu_mpi_src_get_mp_node_from_memory_node(dst_node),
  112. src, dst, size, event);
  113. }
  114. int starpu_mpi_ms_register_kernel(starpu_mpi_ms_func_symbol_t *symbol, const char *func_name)
  115. {
  116. unsigned int func_name_size = (strlen(func_name) + 1) * sizeof(char);
  117. STARPU_PTHREAD_MUTEX_LOCK(&htbl_mutex);
  118. struct _starpu_mpi_ms_kernel *kernel;
  119. HASH_FIND_STR(kernels, func_name, kernel);
  120. if (kernel != NULL)
  121. {
  122. STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
  123. // Function already in the table.
  124. *symbol = kernel;
  125. return 0;
  126. }
  127. kernel = malloc(sizeof(*kernel));
  128. if (kernel == NULL)
  129. {
  130. STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
  131. return -ENOMEM;
  132. }
  133. kernel->name = malloc(func_name_size);
  134. if (kernel->name == NULL)
  135. {
  136. STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
  137. free(kernel);
  138. return -ENOMEM;
  139. }
  140. memcpy(kernel->name, func_name, func_name_size);
  141. HASH_ADD_STR(kernels, name, kernel);
  142. unsigned int nb_mpi_devices = _starpu_mpi_src_get_device_count();
  143. unsigned int i;
  144. for (i = 0; i < nb_mpi_devices; ++i)
  145. kernel->func[i] = NULL;
  146. STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
  147. *symbol = kernel;
  148. return 0;
  149. }
  150. starpu_mpi_ms_kernel_t starpu_mpi_ms_get_kernel(starpu_mpi_ms_func_symbol_t symbol)
  151. {
  152. int workerid = starpu_worker_get_id();
  153. /* This function has to be called in the codelet only, by the thread
  154. * which will handle the task */
  155. if (workerid < 0)
  156. return NULL;
  157. int devid = starpu_worker_get_devid(workerid);
  158. struct _starpu_mpi_ms_kernel *kernel = symbol;
  159. if (kernel->func[devid] == NULL)
  160. {
  161. struct _starpu_mp_node *node = _starpu_mpi_ms_nodes[devid];
  162. int ret = _starpu_src_common_lookup(node, (void (**)(void))&kernel->func[devid], kernel->name);
  163. if (ret)
  164. return NULL;
  165. }
  166. return kernel->func[devid];
  167. }
  168. starpu_mpi_ms_kernel_t _starpu_mpi_ms_src_get_kernel_from_codelet(struct starpu_codelet *cl, unsigned nimpl)
  169. {
  170. starpu_mpi_ms_kernel_t kernel = NULL;
  171. starpu_mpi_ms_func_t func = _starpu_task_get_mpi_ms_nth_implementation(cl, nimpl);
  172. if (func)
  173. {
  174. /* We execute the function contained in the codelet, it must return a
  175. * pointer to the function to execute on the device, either specified
  176. * directly by the user or by a call to starpu_mic_get_func().
  177. */
  178. kernel = func();
  179. }
  180. else
  181. {
  182. /* If user dont define any starpu_mic_fun_t in cl->mic_func we try to use
  183. * cpu_func_name.
  184. */
  185. const char *func_name = _starpu_task_get_cpu_name_nth_implementation(cl, nimpl);
  186. if (func_name)
  187. {
  188. starpu_mpi_ms_func_symbol_t symbol;
  189. starpu_mpi_ms_register_kernel(&symbol, func_name);
  190. kernel = starpu_mpi_ms_get_kernel(symbol);
  191. }
  192. }
  193. STARPU_ASSERT_MSG(kernel, "when STARPU_MPI_MS is defined in 'where', mpi_ms_funcs or cpu_funcs_name has to be defined and the function be non-static");
  194. return kernel;
  195. }
  196. void(* _starpu_mpi_ms_src_get_kernel_from_job(const struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED, struct _starpu_job *j))(void)
  197. {
  198. starpu_mpi_ms_kernel_t kernel = NULL;
  199. starpu_mpi_ms_func_t func = _starpu_task_get_mpi_ms_nth_implementation(j->task->cl, j->nimpl);
  200. if (func)
  201. {
  202. /* We execute the function contained in the codelet, it must return a
  203. * pointer to the function to execute on the device, either specified
  204. * directly by the user or by a call to starpu_mpi_ms_get_func().
  205. */
  206. kernel = func();
  207. }
  208. else
  209. {
  210. /* If user dont define any starpu_mpi_ms_fun_t in cl->mpi_ms_func we try to use
  211. * cpu_func_name.
  212. */
  213. const char *func_name = _starpu_task_get_cpu_name_nth_implementation(j->task->cl, j->nimpl);
  214. if (func_name)
  215. {
  216. starpu_mpi_ms_func_symbol_t symbol;
  217. starpu_mpi_ms_register_kernel(&symbol, func_name);
  218. kernel = starpu_mpi_ms_get_kernel(symbol);
  219. }
  220. }
  221. STARPU_ASSERT(kernel);
  222. return (void (*)(void))kernel;
  223. }
  224. unsigned _starpu_mpi_src_get_device_count()
  225. {
  226. int nb_mpi_devices;
  227. if (!_starpu_mpi_common_is_mp_initialized())
  228. return 0;
  229. MPI_Comm_size(MPI_COMM_WORLD, &nb_mpi_devices);
  230. //Remove one for master
  231. nb_mpi_devices = nb_mpi_devices - 1;
  232. return nb_mpi_devices;
  233. }
/* Host-side driver thread body for MPI master-slave devices. Always returns
 * NULL.
 *
 * Without STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD, one thread drives every
 * device:
 *   - ARG is an array of worker sets, one per sink node
 *     (_starpu_mpi_src_get_device_count() entries).
 *   - A first loop initializes each set.
 *   - A second loop marks each set initialized and signals its ready_cond.
 *   - The thread then calls _starpu_src_common_workers_set().
 *
 * With STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD, there is one thread per
 * device. ARG is that device's single worker set, and the same steps run
 * once before calling _starpu_src_common_worker().
 *
 * NOTE: the #ifdef blocks below open and close the loop braces, so the same
 * statements are either looped over all worker sets or executed once.
 */
void *_starpu_mpi_src_worker(void *arg)
{
#ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
	struct _starpu_worker_set *worker_set_mpi = (struct _starpu_worker_set *) arg;
	int nbsinknodes = _starpu_mpi_src_get_device_count();
	int workersetnum;
	/* First pass: per-device initialization of each worker set. */
	for (workersetnum = 0; workersetnum < nbsinknodes; workersetnum++)
	{
		struct _starpu_worker_set * worker_set = &worker_set_mpi[workersetnum];
#else
	struct _starpu_worker_set *worker_set = arg;
#endif
		/* As all workers of a set share common data, we just use the first
		 * one for initializing the following things. */
		struct _starpu_worker *baseworker = &worker_set->workers[0];
		struct _starpu_machine_config *config = baseworker->config;
		/* Index of the set's first worker in config->workers. */
		unsigned baseworkerid = baseworker - config->workers;
		unsigned devid = baseworker->devid;
		unsigned i;
		/* unsigned memnode = baseworker->memory_node; */
		_starpu_driver_start(baseworker, _STARPU_FUT_MPI_KEY, 0);
#ifdef STARPU_USE_FXT
		/* Only the base worker went through _starpu_driver_start(); with
		 * FxT enabled, also start the remaining workers of the set. */
		for (i = 1; i < worker_set->nworkers; i++)
			_starpu_worker_start(&worker_set->workers[i], _STARPU_FUT_MPI_KEY, 0);
#endif
		// A "current task" makes no sense for a thread managing a whole worker set.
		_starpu_set_current_task(NULL);
		/* Name this device's workers, one per MPI core, starting at
		 * baseworkerid in config->workers. */
		for (i = 0; i < config->topology.nmpicores[devid]; i++)
		{
			struct _starpu_worker *worker = &config->workers[baseworkerid+i];
			snprintf(worker->name, sizeof(worker->name), "MPI_MS %u core %u", devid, i);
			snprintf(worker->short_name, sizeof(worker->short_name), "MPI_MS %u.%u", devid, i);
		}
#ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
		/* Single driver thread for all devices: use a device-independent name. */
		{
			char thread_name[16];
			snprintf(thread_name, sizeof(thread_name), "MPI_MS");
			starpu_pthread_setname(thread_name);
		}
#else
		/* One thread per device: include the device id in the thread name. */
		{
			char thread_name[16];
			snprintf(thread_name, sizeof(thread_name), "MPI_MS %u", devid);
			starpu_pthread_setname(thread_name);
		}
#endif
		for (i = 0; i < worker_set->nworkers; i++)
		{
			struct _starpu_worker *worker = &worker_set->workers[i];
			_STARPU_TRACE_WORKER_INIT_END(worker->workerid);
		}
#ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
		_starpu_src_common_init_switch_env(workersetnum);
	} /* for */
	/* set the worker zero for the main thread */
	/* Second pass: mark each set initialized and signal its ready_cond. */
	for (workersetnum = 0; workersetnum < nbsinknodes; workersetnum++)
	{
		struct _starpu_worker_set * worker_set = &worker_set_mpi[workersetnum];
		struct _starpu_worker *baseworker = &worker_set->workers[0];
#endif
		/* tell the main thread that this one is ready */
		STARPU_PTHREAD_MUTEX_LOCK(&worker_set->mutex);
		baseworker->status = STATUS_UNKNOWN;
		worker_set->set_is_initialized = 1;
		STARPU_PTHREAD_COND_SIGNAL(&worker_set->ready_cond);
		STARPU_PTHREAD_MUTEX_UNLOCK(&worker_set->mutex);
#ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
	}
#endif
	/* Hand control to the common source-side driver code (its loop is not
	 * visible here). */
#ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
	_starpu_src_common_workers_set(worker_set_mpi, nbsinknodes, _starpu_mpi_ms_nodes);
#else
	_starpu_src_common_worker(worker_set, baseworkerid, _starpu_mpi_ms_nodes[devid]);
#endif
	return NULL;
}