Browse Source

mpi: The MPI collective callback function is only called when all communications are completed

Nathalie Furmento 12 years ago
parent
commit
a2d532afd9
1 changed file with 79 additions and 4 deletions
  1. 79 4
      mpi/src/starpu_mpi_collective.c

+ 79 - 4
mpi/src/starpu_mpi_collective.c

@@ -17,14 +17,61 @@
 #include <mpi.h>
 #include <starpu.h>
 #include <starpu_mpi.h>
+#include <starpu_mpi_private.h>
+
+struct _callback_arg	/* completion state shared by all the requests of one collective call */
+{
+	void (*callback)(void *);	/* function invoked by _callback_collective once nb equals count */
+	void *arg;	/* argument passed to callback */
+	int nb;	/* incremented by one on every invocation of _callback_collective */
+	int count;	/* value nb has to reach before callback is invoked */
+};
+
+/*
+ * Completion callback attached to each communication of a collective
+ * operation.
+ *
+ * arg points to a heap-allocated struct _callback_arg shared by all the
+ * communications of the operation.  Each invocation bumps the completion
+ * counter; only the invocation that makes it reach the expected count
+ * forwards to the user callback, so the user is notified once, after all
+ * communications are completed.
+ *
+ * NOTE(review): nb is updated without synchronisation; this presumably
+ * relies on callbacks being run one at a time -- confirm.
+ */
+void _callback_collective(void *arg)
+{
+	struct _callback_arg *callback_arg = arg;
+
+	callback_arg->nb ++;
+	if (callback_arg->nb == callback_arg->count)
+	{
+		/* The user is allowed not to register any callback */
+		if (callback_arg->callback)
+		{
+			callback_arg->callback(callback_arg->arg);
+		}
+		/* This was the last communication referencing the shared
+		 * state, nobody will read it any more: release it here. */
+		free(callback_arg);
+	}
+}
 
 int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, int root, MPI_Comm comm, void (*scallback)(void *), void *sarg, void (*rcallback)(void *), void *rarg)
 {
 	int rank;
 	int x;
+	struct _callback_arg *callback_arg;
 
 	MPI_Comm_rank(comm, &rank);
 
+#ifdef STARPU_DEVEL
+#warning callback_arg needs to be free-ed
+#endif
+	callback_arg = malloc(sizeof(struct _callback_arg));
+	callback_arg->count = 0;
+	callback_arg->nb = 0;
+	callback_arg->callback = (rank == root) ? scallback : rcallback;
+	callback_arg->arg = (rank == root) ? sarg : rarg;
+
+	for(x = 0; x < count ;  x++)
+	{
+		if (data_handles[x])
+		{
+			int owner = starpu_data_get_rank(data_handles[x]);
+			int mpi_tag = starpu_data_get_tag(data_handles[x]);
+			STARPU_ASSERT(mpi_tag >= 0);
+			if ((rank == root) && (owner != root))
+			{
+				callback_arg->count ++;
+			}
+			if ((rank != root) && (owner == rank))
+			{
+				callback_arg->count ++;
+			}
+		}
+	}
+
 	for(x = 0; x < count ;  x++)
 	{
 		if (data_handles[x])
@@ -35,12 +82,12 @@ int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, i
 			if ((rank == root) && (owner != root))
 			{
 				//fprintf(stderr, "[%d] Sending data[%d] to %d\n", rank, x, owner);
-				starpu_mpi_isend_detached(data_handles[x], owner, mpi_tag, comm, scallback, sarg);
+				starpu_mpi_isend_detached(data_handles[x], owner, mpi_tag, comm, _callback_collective, callback_arg);
 			}
 			if ((rank != root) && (owner == rank))
 			{
 				//fprintf(stderr, "[%d] Receiving data[%d] from %d\n", rank, x, root);
-				starpu_mpi_irecv_detached(data_handles[x], root, mpi_tag, comm, rcallback, rarg);
+				starpu_mpi_irecv_detached(data_handles[x], root, mpi_tag, comm, _callback_collective, callback_arg);
 			}
 		}
 	}
@@ -51,9 +98,37 @@ int starpu_mpi_gather_detached(starpu_data_handle_t *data_handles, int count, in
 {
 	int rank;
 	int x;
+	struct _callback_arg *callback_arg;
 
 	MPI_Comm_rank(comm, &rank);
 
+#ifdef STARPU_DEVEL
+#warning callback_arg needs to be free-ed
+#endif
+	callback_arg = malloc(sizeof(struct _callback_arg));
+	callback_arg->count = 0;
+	callback_arg->nb = 0;
+	callback_arg->callback = (rank == root) ? scallback : rcallback;
+	callback_arg->arg = (rank == root) ? sarg : rarg;
+
+	for(x = 0; x < count ;  x++)
+	{
+		if (data_handles[x])
+		{
+			int owner = starpu_data_get_rank(data_handles[x]);
+			int mpi_tag = starpu_data_get_tag(data_handles[x]);
+			STARPU_ASSERT(mpi_tag >= 0);
+			if ((rank == root) && (owner != root))
+			{
+				callback_arg->count ++;
+			}
+			if ((rank != root) && (owner == rank))
+			{
+				callback_arg->count ++;
+			}
+		}
+	}
+
 	for(x = 0; x < count ;  x++)
 	{
 		if (data_handles[x])
@@ -64,12 +139,12 @@ int starpu_mpi_gather_detached(starpu_data_handle_t *data_handles, int count, in
 			if ((rank == root) && (owner != root))
 			{
 				//fprintf(stderr, "[%d] Receiving data[%d] from %d\n", rank, x, owner);
-				starpu_mpi_irecv_detached(data_handles[x], owner, mpi_tag, comm, scallback, sarg);
+				starpu_mpi_irecv_detached(data_handles[x], owner, mpi_tag, comm, _callback_collective, callback_arg);
 			}
 			if ((rank != root) && (owner == rank))
 			{
 				//fprintf(stderr, "[%d] Sending data[%d] to %d\n", rank, x, root);
-				starpu_mpi_isend_detached(data_handles[x], root, mpi_tag, comm, rcallback, rarg);
+				starpu_mpi_isend_detached(data_handles[x], root, mpi_tag, comm, _callback_collective, callback_arg);
 			}
 		}
 	}