|
|
@@ -1545,6 +1545,9 @@ void starpu_mpi_data_register_comm(starpu_data_handle_t data_handle, int tag, in
|
|
|
else
|
|
|
{
|
|
|
mpi_data = calloc(1, sizeof(struct _starpu_mpi_node_tag));
|
|
|
+ mpi_data->data_tag = -1;
|
|
|
+ mpi_data->rank = -1;
|
|
|
+ mpi_data->comm = MPI_COMM_WORLD;
|
|
|
data_handle->mpi_data = mpi_data;
|
|
|
_starpu_mpi_data_register_tag(data_handle, tag);
|
|
|
_starpu_data_set_unregister_hook(data_handle, _starpu_mpi_clear_cache);
|
|
|
@@ -1584,6 +1587,108 @@ int starpu_mpi_data_get_tag(starpu_data_handle_t data)
|
|
|
return ((struct _starpu_mpi_node_tag *)(data->mpi_data))->data_tag;
|
|
|
}
|
|
|
|
|
|
+void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg)
|
|
|
+{
|
|
|
+ int me, rank, tag;
|
|
|
+
|
|
|
+ rank = starpu_mpi_data_get_rank(data_handle);
|
|
|
+ tag = starpu_mpi_data_get_tag(data_handle);
|
|
|
+ if (rank == -1)
|
|
|
+ {
|
|
|
+ _STARPU_ERROR("StarPU needs to be told the MPI rank of this data, using starpu_mpi_data_register() or starpu_mpi_data_register()\n");
|
|
|
+ }
|
|
|
+ if (tag == -1)
|
|
|
+ {
|
|
|
+ _STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register() or starpu_mpi_data_register()\n");
|
|
|
+ }
|
|
|
+ starpu_mpi_comm_rank(comm, &me);
|
|
|
+
|
|
|
+ if (node == rank) return;
|
|
|
+
|
|
|
+ if (me == node)
|
|
|
+ {
|
|
|
+ _STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
|
|
|
+ void *already_received = _starpu_mpi_cache_received_data_set(data_handle, rank);
|
|
|
+ if (already_received == NULL)
|
|
|
+ {
|
|
|
+ _STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, rank);
|
|
|
+ starpu_mpi_irecv_detached(data_handle, rank, tag, comm, callback, arg);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else if (me == rank)
|
|
|
+ {
|
|
|
+ _STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
|
|
|
+ void *already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
|
|
|
+ if (already_sent == NULL)
|
|
|
+ {
|
|
|
+ _STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, node);
|
|
|
+ starpu_mpi_isend_detached(data_handle, node, tag, comm, NULL, NULL);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle, int node)
|
|
|
+{
|
|
|
+ int me, rank, tag;
|
|
|
+
|
|
|
+ rank = starpu_mpi_data_get_rank(data_handle);
|
|
|
+ tag = starpu_mpi_data_get_tag(data_handle);
|
|
|
+ if (rank == -1)
|
|
|
+ {
|
|
|
+ fprintf(stderr,"StarPU needs to be told the MPI rank of this data, using starpu_mpi_data_register\n");
|
|
|
+ STARPU_ABORT();
|
|
|
+ }
|
|
|
+ if (tag == -1)
|
|
|
+ {
|
|
|
+ fprintf(stderr,"StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
|
|
|
+ STARPU_ABORT();
|
|
|
+ }
|
|
|
+ starpu_mpi_comm_rank(comm, &me);
|
|
|
+
|
|
|
+ if (node == rank) return;
|
|
|
+
|
|
|
+ if (me == node)
|
|
|
+ {
|
|
|
+ MPI_Status status;
|
|
|
+ _STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
|
|
|
+ void *already_received = _starpu_mpi_cache_received_data_set(data_handle, rank);
|
|
|
+ if (already_received == NULL)
|
|
|
+ {
|
|
|
+ _STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, rank);
|
|
|
+ starpu_mpi_recv(data_handle, rank, tag, comm, &status);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else if (me == rank)
|
|
|
+ {
|
|
|
+ _STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
|
|
|
+ void *already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
|
|
|
+ if (already_sent == NULL)
|
|
|
+ {
|
|
|
+ _STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, node);
|
|
|
+ starpu_mpi_send(data_handle, node, tag, comm);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void starpu_mpi_data_migrate(MPI_Comm comm, starpu_data_handle_t data, int new_rank)
|
|
|
+{
|
|
|
+ int old_rank = starpu_mpi_data_get_rank(data);
|
|
|
+ if (new_rank == old_rank)
|
|
|
+ /* Already there */
|
|
|
+ return;
|
|
|
+
|
|
|
+ /* First submit data migration if it's not already on destination */
|
|
|
+ starpu_mpi_get_data_on_node_detached(comm, data, new_rank, NULL, NULL);
|
|
|
+
|
|
|
+ /* And note new owner */
|
|
|
+ starpu_mpi_data_set_rank_comm(data, new_rank, comm);
|
|
|
+
|
|
|
+ /* Flush cache in all other nodes */
|
|
|
+ /* TODO: Ideally we'd transmit the knowledge of who owns it */
|
|
|
+ starpu_mpi_cache_flush(comm, data);
|
|
|
+ return;
|
|
|
+}
|
|
|
+
|
|
|
int starpu_mpi_wait_for_all(MPI_Comm comm)
|
|
|
{
|
|
|
int mpi = 1;
|