driver_mpi_source.c 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2016-2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
  4. *
  5. * StarPU is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * StarPU is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <mpi.h>
  17. #include <errno.h>
  18. #include <starpu.h>
  19. #include <drivers/mpi/driver_mpi_source.h>
  20. #include <drivers/mpi/driver_mpi_common.h>
  21. #include <datawizard/memory_nodes.h>
  22. #include <drivers/driver_common/driver_common.h>
  23. #include <drivers/mp_common/source_common.h>
  24. /* Mutex for concurrent access to the table.
  25. */
  26. starpu_pthread_mutex_t htbl_mutex = STARPU_PTHREAD_MUTEX_INITIALIZER;
  27. /* Structure used by host to store informations about a kernel executable on
  28. * a MPI MS device : its name, and its address on each device.
  29. * If a kernel has been initialized, then a lookup has already been achieved and the
  30. * device knows how to call it, else the host still needs to do a lookup.
  31. */
  32. struct _starpu_mpi_ms_kernel
  33. {
  34. UT_hash_handle hh;
  35. char *name;
  36. starpu_mpi_ms_kernel_t func[STARPU_MAXMPIDEVS];
  37. } *kernels;
  38. /* Array of structures containing all the informations useful to send
  39. * and receive informations with devices */
  40. struct _starpu_mp_node *_starpu_mpi_ms_nodes[STARPU_MAXMPIDEVS];
  41. struct _starpu_mp_node *_starpu_mpi_ms_src_get_actual_thread_mp_node()
  42. {
  43. struct _starpu_worker *actual_worker = _starpu_get_local_worker_key();
  44. STARPU_ASSERT(actual_worker);
  45. int devid = actual_worker->devid;
  46. STARPU_ASSERT(devid >= 0 && devid < STARPU_MAXMPIDEVS);
  47. return _starpu_mpi_ms_nodes[devid];
  48. }
  /* Initialize the source side of an MPI MS mp node.
   * Everything is currently delegated to the common MPI src/sink setup. */
void _starpu_mpi_source_init(struct _starpu_mp_node *node)
{
	_starpu_mpi_common_mp_initialize_src_sink(node);

	/* TODO: source-specific initialization, if ever needed. */
}
  54. void _starpu_mpi_source_deinit(struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED)
  55. {
  56. }
  57. struct _starpu_mp_node *_starpu_mpi_src_get_mp_node_from_memory_node(int memory_node)
  58. {
  59. int devid = starpu_memory_node_get_devid(memory_node);
  60. STARPU_ASSERT_MSG(devid >= 0 && devid < STARPU_MAXMPIDEVS, "bogus devid %d for memory node %d\n", devid, memory_node);
  61. return _starpu_mpi_ms_nodes[devid];
  62. }
  /* Allocate SIZE bytes on the MPI MS device backing MEMORY_NODE and store the
   * device address in *ADDR. Returns the status of
   * _starpu_src_common_allocate (callers in this file treat nonzero as
   * failure). */
int _starpu_mpi_src_allocate_memory(void ** addr, size_t size, unsigned memory_node)
{
	return _starpu_src_common_allocate(_starpu_mpi_src_get_mp_node_from_memory_node(memory_node),
					   addr, size);
}
  /* Release the device buffer at ADDR on the MPI MS device backing
   * MEMORY_NODE. */
void _starpu_mpi_source_free_memory(void *addr, unsigned memory_node)
{
	_starpu_src_common_free(_starpu_mpi_src_get_mp_node_from_memory_node(memory_node), addr);
}
  73. /* Transfert SIZE bytes from the address pointed by SRC in the SRC_NODE memory
  74. * node to the address pointed by DST in the DST_NODE memory node
  75. */
  76. int _starpu_mpi_copy_ram_to_mpi_sync(void *src, unsigned src_node STARPU_ATTRIBUTE_UNUSED, void *dst, unsigned dst_node, size_t size)
  77. {
  78. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(dst_node);
  79. return _starpu_src_common_copy_host_to_sink_sync(mp_node, src, dst, size);
  80. }
  81. /* Transfert SIZE bytes from the address pointed by SRC in the SRC_NODE memory
  82. * node to the address pointed by DST in the DST_NODE memory node
  83. */
  84. int _starpu_mpi_copy_mpi_to_ram_sync(void *src, unsigned src_node, void *dst, unsigned dst_node STARPU_ATTRIBUTE_UNUSED, size_t size)
  85. {
  86. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(src_node);
  87. return _starpu_src_common_copy_sink_to_host_sync(mp_node, src, dst, size);
  88. }
  /* Synchronously copy SIZE bytes from SRC on the MPI MS device of SRC_NODE to
   * DST on the MPI MS device of DST_NODE. */
int _starpu_mpi_copy_sink_to_sink_sync(void *src, unsigned src_node, void *dst, unsigned dst_node, size_t size)
{
	struct _starpu_mp_node *from_node = _starpu_mpi_src_get_mp_node_from_memory_node(src_node);
	struct _starpu_mp_node *to_node = _starpu_mpi_src_get_mp_node_from_memory_node(dst_node);

	return _starpu_src_common_copy_sink_to_sink_sync(from_node, to_node, src, dst, size);
}
  95. int _starpu_mpi_copy_mpi_to_ram_async(void *src, unsigned src_node, void *dst, unsigned dst_node STARPU_ATTRIBUTE_UNUSED, size_t size, void * event)
  96. {
  97. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(src_node);
  98. return _starpu_src_common_copy_sink_to_host_async(mp_node, src, dst, size, event);
  99. }
  100. int _starpu_mpi_copy_ram_to_mpi_async(void *src, unsigned src_node STARPU_ATTRIBUTE_UNUSED, void *dst, unsigned dst_node, size_t size, void * event)
  101. {
  102. struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(dst_node);
  103. return _starpu_src_common_copy_host_to_sink_async(mp_node, src, dst, size, event);
  104. }
  /* Start an asynchronous copy of SIZE bytes from SRC on the MPI MS device of
   * SRC_NODE to DST on the MPI MS device of DST_NODE, tracked through EVENT. */
int _starpu_mpi_copy_sink_to_sink_async(void *src, unsigned src_node, void *dst, unsigned dst_node, size_t size, void * event)
{
	struct _starpu_mp_node *from_node = _starpu_mpi_src_get_mp_node_from_memory_node(src_node);
	struct _starpu_mp_node *to_node = _starpu_mpi_src_get_mp_node_from_memory_node(dst_node);

	return _starpu_src_common_copy_sink_to_sink_async(from_node, to_node, src, dst, size, event);
}
  111. int starpu_mpi_ms_register_kernel(starpu_mpi_ms_func_symbol_t *symbol, const char *func_name)
  112. {
  113. unsigned int func_name_size = (strlen(func_name) + 1) * sizeof(char);
  114. STARPU_PTHREAD_MUTEX_LOCK(&htbl_mutex);
  115. struct _starpu_mpi_ms_kernel *kernel;
  116. HASH_FIND_STR(kernels, func_name, kernel);
  117. if (kernel != NULL)
  118. {
  119. STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
  120. // Function already in the table.
  121. *symbol = kernel;
  122. return 0;
  123. }
  124. kernel = malloc(sizeof(*kernel));
  125. if (kernel == NULL)
  126. {
  127. STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
  128. return -ENOMEM;
  129. }
  130. kernel->name = malloc(func_name_size);
  131. if (kernel->name == NULL)
  132. {
  133. STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
  134. free(kernel);
  135. return -ENOMEM;
  136. }
  137. memcpy(kernel->name, func_name, func_name_size);
  138. HASH_ADD_STR(kernels, name, kernel);
  139. unsigned int nb_mpi_devices = _starpu_mpi_src_get_device_count();
  140. unsigned int i;
  141. for (i = 0; i < nb_mpi_devices; ++i)
  142. kernel->func[i] = NULL;
  143. STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
  144. *symbol = kernel;
  145. return 0;
  146. }
  147. starpu_mpi_ms_kernel_t starpu_mpi_ms_get_kernel(starpu_mpi_ms_func_symbol_t symbol)
  148. {
  149. int workerid = starpu_worker_get_id();
  150. /* This function has to be called in the codelet only, by the thread
  151. * which will handle the task */
  152. if (workerid < 0)
  153. return NULL;
  154. int devid = starpu_worker_get_devid(workerid);
  155. struct _starpu_mpi_ms_kernel *kernel = symbol;
  156. if (kernel->func[devid] == NULL)
  157. {
  158. struct _starpu_mp_node *node = _starpu_mpi_ms_nodes[devid];
  159. int ret = _starpu_src_common_lookup(node, (void (**)(void))&kernel->func[devid], kernel->name);
  160. if (ret)
  161. return NULL;
  162. }
  163. return kernel->func[devid];
  164. }
  165. starpu_mpi_ms_kernel_t _starpu_mpi_ms_src_get_kernel_from_codelet(struct starpu_codelet *cl, unsigned nimpl)
  166. {
  167. starpu_mpi_ms_kernel_t kernel = NULL;
  168. starpu_mpi_ms_func_t func = _starpu_task_get_mpi_ms_nth_implementation(cl, nimpl);
  169. if (func)
  170. {
  171. /* We execute the function contained in the codelet, it must return a
  172. * pointer to the function to execute on the device, either specified
  173. * directly by the user or by a call to starpu_mic_get_func().
  174. */
  175. kernel = func();
  176. }
  177. else
  178. {
  179. /* If user dont define any starpu_mic_fun_t in cl->mic_func we try to use
  180. * cpu_func_name.
  181. */
  182. const char *func_name = _starpu_task_get_cpu_name_nth_implementation(cl, nimpl);
  183. if (func_name)
  184. {
  185. starpu_mpi_ms_func_symbol_t symbol;
  186. starpu_mpi_ms_register_kernel(&symbol, func_name);
  187. kernel = starpu_mpi_ms_get_kernel(symbol);
  188. }
  189. }
  190. STARPU_ASSERT_MSG(kernel, "when STARPU_MPI_MS is defined in 'where', mpi_ms_funcs or cpu_funcs_name has to be defined and the function be non-static");
  191. return kernel;
  192. }
  193. void(* _starpu_mpi_ms_src_get_kernel_from_job(const struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED, struct _starpu_job *j))(void)
  194. {
  195. starpu_mpi_ms_kernel_t kernel = NULL;
  196. starpu_mpi_ms_func_t func = _starpu_task_get_mpi_ms_nth_implementation(j->task->cl, j->nimpl);
  197. if (func)
  198. {
  199. /* We execute the function contained in the codelet, it must return a
  200. * pointer to the function to execute on the device, either specified
  201. * directly by the user or by a call to starpu_mpi_ms_get_func().
  202. */
  203. kernel = func();
  204. }
  205. else
  206. {
  207. /* If user dont define any starpu_mpi_ms_fun_t in cl->mpi_ms_func we try to use
  208. * cpu_func_name.
  209. */
  210. const char *func_name = _starpu_task_get_cpu_name_nth_implementation(j->task->cl, j->nimpl);
  211. if (func_name)
  212. {
  213. starpu_mpi_ms_func_symbol_t symbol;
  214. starpu_mpi_ms_register_kernel(&symbol, func_name);
  215. kernel = starpu_mpi_ms_get_kernel(symbol);
  216. }
  217. }
  218. STARPU_ASSERT(kernel);
  219. return (void (*)(void))kernel;
  220. }
  221. unsigned _starpu_mpi_src_get_device_count()
  222. {
  223. int nb_mpi_devices;
  224. if (!_starpu_mpi_common_is_mp_initialized())
  225. return 0;
  226. MPI_Comm_size(MPI_COMM_WORLD, &nb_mpi_devices);
  227. //Remove one for master
  228. nb_mpi_devices = nb_mpi_devices - 1;
  229. return nb_mpi_devices;
  230. }
  231. void *_starpu_mpi_src_worker(void *arg)
  232. {
  233. #ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
  234. struct _starpu_worker_set *worker_set_mpi = (struct _starpu_worker_set *) arg;
  235. int nbsinknodes = _starpu_mpi_src_get_device_count();
  236. int workersetnum;
  237. for (workersetnum = 0; workersetnum < nbsinknodes; workersetnum++)
  238. {
  239. struct _starpu_worker_set * worker_set = &worker_set_mpi[workersetnum];
  240. #else
  241. struct _starpu_worker_set *worker_set = arg;
  242. #endif
  243. /* As all workers of a set share common data, we just use the first
  244. * one for intializing the following stuffs. */
  245. struct _starpu_worker *baseworker = &worker_set->workers[0];
  246. struct _starpu_machine_config *config = baseworker->config;
  247. unsigned baseworkerid = baseworker - config->workers;
  248. unsigned devid = baseworker->devid;
  249. unsigned i;
  250. /* unsigned memnode = baseworker->memory_node; */
  251. _starpu_driver_start(baseworker, STARPU_CPU_WORKER, 0);
  252. #ifdef STARPU_USE_FXT
  253. for (i = 1; i < worker_set->nworkers; i++)
  254. _starpu_worker_start(&worker_set->workers[i], STARPU_MPI_WORKER, 0);
  255. #endif
  256. // Current task for a thread managing a worker set has no sense.
  257. _starpu_set_current_task(NULL);
  258. for (i = 0; i < config->topology.nworker[STARPU_MPI_MS_WORKER][devid]; i++)
  259. {
  260. struct _starpu_worker *worker = &config->workers[baseworkerid+i];
  261. snprintf(worker->name, sizeof(worker->name), "MPI_MS %u core %u", devid, i);
  262. snprintf(worker->short_name, sizeof(worker->short_name), "MPI_MS %u.%u", devid, i);
  263. }
  264. #ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
  265. {
  266. char thread_name[16];
  267. snprintf(thread_name, sizeof(thread_name), "MPI_MS");
  268. starpu_pthread_setname(thread_name);
  269. }
  270. #else
  271. {
  272. char thread_name[16];
  273. snprintf(thread_name, sizeof(thread_name), "MPI_MS %u", devid);
  274. starpu_pthread_setname(thread_name);
  275. }
  276. #endif
  277. for (i = 0; i < worker_set->nworkers; i++)
  278. {
  279. struct _starpu_worker *worker = &worker_set->workers[i];
  280. _STARPU_TRACE_WORKER_INIT_END(worker->workerid);
  281. }
  282. #ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
  283. _starpu_src_common_init_switch_env(workersetnum);
  284. } /* for */
  285. /* set the worker zero for the main thread */
  286. for (workersetnum = 0; workersetnum < nbsinknodes; workersetnum++)
  287. {
  288. struct _starpu_worker_set * worker_set = &worker_set_mpi[workersetnum];
  289. struct _starpu_worker *baseworker = &worker_set->workers[0];
  290. #endif
  291. /* tell the main thread that this one is ready */
  292. STARPU_PTHREAD_MUTEX_LOCK(&worker_set->mutex);
  293. baseworker->status = STATUS_UNKNOWN;
  294. worker_set->set_is_initialized = 1;
  295. STARPU_PTHREAD_COND_SIGNAL(&worker_set->ready_cond);
  296. STARPU_PTHREAD_MUTEX_UNLOCK(&worker_set->mutex);
  297. #ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
  298. }
  299. #endif
  300. #ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
  301. _starpu_src_common_workers_set(worker_set_mpi, nbsinknodes, _starpu_mpi_ms_nodes);
  302. #else
  303. _starpu_src_common_worker(worker_set, baseworkerid, _starpu_mpi_ms_nodes[devid]);
  304. #endif
  305. return NULL;
  306. }
  307. int _starpu_mpi_copy_interface_from_mpi_to_cpu(starpu_data_handle_t handle, void *src_interface, unsigned src_node, void *dst_interface, unsigned dst_node, struct _starpu_data_request *req)
  308. {
  309. int src_kind = starpu_node_get_kind(src_node);
  310. int dst_kind = starpu_node_get_kind(dst_node);
  311. STARPU_ASSERT(src_kind == STARPU_MPI_MS_RAM && dst_kind == STARPU_CPU_RAM);
  312. int ret = 0;
  313. const struct starpu_data_copy_methods *copy_methods = handle->ops->copy_methods;
  314. if (!req || starpu_asynchronous_copy_disabled() || starpu_asynchronous_mpi_ms_copy_disabled() || !(copy_methods->mpi_ms_to_ram_async || copy_methods->any_to_any))
  315. {
  316. /* this is not associated to a request so it's synchronous */
  317. STARPU_ASSERT(copy_methods->mpi_ms_to_ram || copy_methods->any_to_any);
  318. if (copy_methods->mpi_ms_to_ram)
  319. copy_methods->mpi_ms_to_ram(src_interface, src_node, dst_interface, dst_node);
  320. else
  321. copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, NULL);
  322. }
  323. else
  324. {
  325. req->async_channel.node_ops = &_starpu_driver_mpi_node_ops;
  326. if(copy_methods->mpi_ms_to_ram_async)
  327. ret = copy_methods->mpi_ms_to_ram_async(src_interface, src_node, dst_interface, dst_node, &req->async_channel);
  328. else
  329. {
  330. STARPU_ASSERT(copy_methods->any_to_any);
  331. ret = copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, &req->async_channel);
  332. }
  333. }
  334. return ret;
  335. }
  336. int _starpu_mpi_copy_interface_from_mpi_to_mpi(starpu_data_handle_t handle, void *src_interface, unsigned src_node, void *dst_interface, unsigned dst_node, struct _starpu_data_request *req)
  337. {
  338. int src_kind = starpu_node_get_kind(src_node);
  339. int dst_kind = starpu_node_get_kind(dst_node);
  340. STARPU_ASSERT(src_kind == STARPU_MPI_MS_RAM && dst_kind == STARPU_MPI_MS_RAM);
  341. int ret = 0;
  342. const struct starpu_data_copy_methods *copy_methods = handle->ops->copy_methods;
  343. if (!req || starpu_asynchronous_copy_disabled() || starpu_asynchronous_mpi_ms_copy_disabled() || !(copy_methods->mpi_ms_to_mpi_ms_async || copy_methods->any_to_any))
  344. {
  345. /* this is not associated to a request so it's synchronous */
  346. STARPU_ASSERT(copy_methods->mpi_ms_to_mpi_ms || copy_methods->any_to_any);
  347. if (copy_methods->mpi_ms_to_mpi_ms)
  348. copy_methods->mpi_ms_to_mpi_ms(src_interface, src_node, dst_interface, dst_node);
  349. else
  350. copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, NULL);
  351. }
  352. else
  353. {
  354. req->async_channel.node_ops = &_starpu_driver_mpi_node_ops;
  355. if(copy_methods->mpi_ms_to_mpi_ms_async)
  356. ret = copy_methods->mpi_ms_to_mpi_ms_async(src_interface, src_node, dst_interface, dst_node, &req->async_channel);
  357. else
  358. {
  359. STARPU_ASSERT(copy_methods->any_to_any);
  360. ret = copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, &req->async_channel);
  361. }
  362. }
  363. return ret;
  364. }
  365. int _starpu_mpi_copy_interface_from_cpu_to_mpi(starpu_data_handle_t handle, void *src_interface, unsigned src_node, void *dst_interface, unsigned dst_node, struct _starpu_data_request *req)
  366. {
  367. int src_kind = starpu_node_get_kind(src_node);
  368. int dst_kind = starpu_node_get_kind(dst_node);
  369. STARPU_ASSERT(src_kind == STARPU_CPU_RAM && dst_kind == STARPU_MPI_MS_RAM);
  370. int ret = 0;
  371. const struct starpu_data_copy_methods *copy_methods = handle->ops->copy_methods;
  372. if (!req || starpu_asynchronous_copy_disabled() || starpu_asynchronous_mpi_ms_copy_disabled() || !(copy_methods->ram_to_mpi_ms_async || copy_methods->any_to_any))
  373. {
  374. /* this is not associated to a request so it's synchronous */
  375. STARPU_ASSERT(copy_methods->ram_to_mpi_ms || copy_methods->any_to_any);
  376. if (copy_methods->ram_to_mpi_ms)
  377. copy_methods->ram_to_mpi_ms(src_interface, src_node, dst_interface, dst_node);
  378. else
  379. copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, NULL);
  380. }
  381. else
  382. {
  383. req->async_channel.node_ops = &_starpu_driver_mpi_node_ops;
  384. if(copy_methods->ram_to_mpi_ms_async)
  385. ret = copy_methods->ram_to_mpi_ms_async(src_interface, src_node, dst_interface, dst_node, &req->async_channel);
  386. else
  387. {
  388. STARPU_ASSERT(copy_methods->any_to_any);
  389. ret = copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, &req->async_channel);
  390. }
  391. }
  392. return ret;
  393. }
  394. int _starpu_mpi_copy_data_from_mpi_to_cpu(uintptr_t src, size_t src_offset, unsigned src_node, uintptr_t dst, size_t dst_offset, unsigned dst_node, size_t size, struct _starpu_async_channel *async_channel)
  395. {
  396. int src_kind = starpu_node_get_kind(src_node);
  397. int dst_kind = starpu_node_get_kind(dst_node);
  398. STARPU_ASSERT(src_kind == STARPU_MPI_MS_RAM && dst_kind == STARPU_CPU_RAM);
  399. if (async_channel)
  400. return _starpu_mpi_copy_mpi_to_ram_async((void*) (src + src_offset), src_node,
  401. (void*) (dst + dst_offset), dst_node,
  402. size, async_channel);
  403. else
  404. return _starpu_mpi_copy_mpi_to_ram_sync((void*) (src + src_offset), src_node,
  405. (void*) (dst + dst_offset), dst_node,
  406. size);
  407. }
  408. int _starpu_mpi_copy_data_from_mpi_to_mpi(uintptr_t src, size_t src_offset, unsigned src_node, uintptr_t dst, size_t dst_offset, unsigned dst_node, size_t size, struct _starpu_async_channel *async_channel)
  409. {
  410. int src_kind = starpu_node_get_kind(src_node);
  411. int dst_kind = starpu_node_get_kind(dst_node);
  412. STARPU_ASSERT(src_kind == STARPU_MPI_MS_RAM && dst_kind == STARPU_MPI_MS_RAM);
  413. if (async_channel)
  414. return _starpu_mpi_copy_sink_to_sink_async((void*) (src + src_offset), src_node,
  415. (void*) (dst + dst_offset), dst_node,
  416. size, async_channel);
  417. else
  418. return _starpu_mpi_copy_sink_to_sink_sync((void*) (src + src_offset), src_node,
  419. (void*) (dst + dst_offset), dst_node,
  420. size);
  421. }
  422. int _starpu_mpi_copy_data_from_cpu_to_mpi(uintptr_t src, size_t src_offset, unsigned src_node, uintptr_t dst, size_t dst_offset, unsigned dst_node, size_t size, struct _starpu_async_channel *async_channel)
  423. {
  424. int src_kind = starpu_node_get_kind(src_node);
  425. int dst_kind = starpu_node_get_kind(dst_node);
  426. STARPU_ASSERT(src_kind == STARPU_CPU_RAM && dst_kind == STARPU_MPI_MS_RAM);
  427. if (async_channel)
  428. return _starpu_mpi_copy_ram_to_mpi_async((void*) (src + src_offset), src_node,
  429. (void*) (dst + dst_offset), dst_node,
  430. size, async_channel);
  431. else
  432. return _starpu_mpi_copy_ram_to_mpi_sync((void*) (src + src_offset), src_node,
  433. (void*) (dst + dst_offset), dst_node,
  434. size);
  435. }
  436. int _starpu_mpi_is_direct_access_supported(unsigned node, unsigned handling_node)
  437. {
  438. (void) node;
  439. enum starpu_node_kind kind = starpu_node_get_kind(handling_node);
  440. return (kind == STARPU_MPI_MS_RAM);
  441. }
  /* Allocate SIZE bytes on MPI MS memory node DST_NODE.
   * Returns the device address, or 0 when the allocation fails.
   * FLAGS is ignored. */
uintptr_t _starpu_mpi_malloc_on_node(unsigned dst_node, size_t size, int flags)
{
	uintptr_t addr = 0;
	(void) flags;

	return _starpu_mpi_src_allocate_memory((void **)(&addr), size, dst_node) ? 0 : addr;
}
  /* Free the device buffer at ADDR on MPI MS memory node DST_NODE.
   * SIZE and FLAGS are not needed by this driver and are ignored. */
void _starpu_mpi_free_on_node(unsigned dst_node, uintptr_t addr, size_t size, int flags)
{
	(void) size;
	(void) flags;

	_starpu_mpi_source_free_memory((void *) addr, dst_node);
}
  456. struct _starpu_node_ops _starpu_driver_mpi_node_ops =
  457. {
  458. .copy_interface_to[STARPU_CPU_RAM] = _starpu_mpi_copy_interface_from_mpi_to_cpu,
  459. .copy_interface_to[STARPU_MPI_MS_RAM] = _starpu_mpi_copy_interface_from_mpi_to_mpi,
  460. .copy_data_to[STARPU_CPU_RAM] = _starpu_mpi_copy_data_from_mpi_to_cpu,
  461. .copy_data_to[STARPU_MPI_MS_RAM] = _starpu_mpi_copy_data_from_mpi_to_mpi,
  462. /* TODO: copy2D/3D? */
  463. .wait_request_completion = _starpu_mpi_common_wait_request_completion,
  464. .test_request_completion = _starpu_mpi_common_test_event,
  465. .is_direct_access_supported = _starpu_mpi_is_direct_access_supported,
  466. .malloc_on_node = _starpu_mpi_malloc_on_node,
  467. .free_on_node = _starpu_mpi_free_on_node,
  468. .name = "mpi driver"
  469. };