Kaynağa Gözat

- Provide the status returned by the MPI_Irecv function.
- In addition to the datatype, MPI calls also need the address of the first
element of the message: this is now given by starpu_mpi_handle_to_ptr.

Cédric Augonnet 15 yıl önce
ebeveyn
işleme
1470ba3508

+ 12 - 11
mpi/starpu_mpi.c

@@ -36,18 +36,18 @@ int starpu_mpi_irecv(starpu_data_handle data_handle, starpu_mpi_req_t *req,
 }
 
 int starpu_mpi_recv(starpu_data_handle data_handle,
-		int source, int mpi_tag, MPI_Comm comm)
+		int source, int mpi_tag, MPI_Comm comm, MPI_Status *status)
 {
-	/* TODO test if we are blocking in a callback .. */
+	/* test if we are blocking in a callback .. */
+	int ret = starpu_sync_data_with_mem(data_handle, STARPU_W);
+	if (ret)
+		return ret;
 
-	starpu_sync_data_with_mem(data_handle, STARPU_W);
-
-	void *ptr = (void *)starpu_get_vector_local_ptr(data_handle);
+	void *ptr = starpu_mpi_handle_to_ptr(data_handle);
 	
-	MPI_Status status;
 	MPI_Datatype datatype;
 	starpu_mpi_handle_to_datatype(data_handle, &datatype);
-	MPI_Recv(ptr, 1, datatype, source, mpi_tag, comm, &status);
+	MPI_Recv(ptr, 1, datatype, source, mpi_tag, comm, status);
 
 	starpu_release_data_from_mem(data_handle);
 
@@ -57,11 +57,12 @@ int starpu_mpi_recv(starpu_data_handle data_handle,
 int starpu_mpi_send(starpu_data_handle data_handle,
 		int dest, int mpi_tag, MPI_Comm comm)
 {
-	/* TODO test if we are blocking in a callback .. */
-
-	starpu_sync_data_with_mem(data_handle, STARPU_R);
+	/* test if we are blocking in a callback .. */
+	int ret = starpu_sync_data_with_mem(data_handle, STARPU_R);
+	if (ret)
+		return ret;
 
-	void *ptr = (void *)starpu_get_vector_local_ptr(data_handle);
+	void *ptr = starpu_mpi_handle_to_ptr(data_handle);
 	
 	MPI_Status status;
 	MPI_Datatype datatype;

+ 1 - 1
mpi/starpu_mpi.h

@@ -37,7 +37,7 @@ int starpu_mpi_irecv(starpu_data_handle data_handle, starpu_mpi_req_t *req,
 int starpu_mpi_send(starpu_data_handle data_handle,
 		int dest, int mpi_tag, MPI_Comm comm);
 int starpu_mpi_recv(starpu_data_handle data_handle,
-		int source, int mpi_tag, MPI_Comm comm);
+		int source, int mpi_tag, MPI_Comm comm, MPI_Status *status);
 int starpu_mpi_wait(starpu_mpi_req_t *req);
 int starpu_mpi_test(starpu_mpi_req_t *req, int *flag);
 int starpu_mpi_initialize(void);

+ 39 - 0
mpi/starpu_mpi_datatype.c

@@ -16,7 +16,17 @@
 
 #include <starpu_mpi_datatype.h>
 
+/*
+ *	MPI_* functions usually requires both a pointer to the first element of
+ *	a datatype and the datatype itself, so we need to provide both.
+ */
+
 typedef int (*handle_to_datatype_func)(starpu_data_handle, MPI_Datatype *);
+typedef void *(*handle_to_ptr_func)(starpu_data_handle);
+
+/*
+ * 	Vector
+ */
 
 static int handle_to_datatype_vector(starpu_data_handle data_handle, MPI_Datatype *datatype)
 {
@@ -29,6 +39,15 @@ static int handle_to_datatype_vector(starpu_data_handle data_handle, MPI_Datatyp
 	return 0;
 }
 
+static void *handle_to_ptr_vector(starpu_data_handle data_handle)
+{
+	return (void *)starpu_get_vector_local_ptr(data_handle);
+}
+
+/*
+ *	Generic
+ */
+
 static handle_to_datatype_func handle_to_datatype_funcs[STARPU_NINTERFACES_ID] = {
 	[STARPU_BLAS_INTERFACE_ID]	= NULL,
 	[STARPU_BLOCK_INTERFACE_ID]	= NULL,
@@ -38,6 +57,15 @@ static handle_to_datatype_func handle_to_datatype_funcs[STARPU_NINTERFACES_ID] =
 	[STARPU_BCSCR_INTERFACE_ID]	= NULL
 };
 
+static handle_to_ptr_func handle_to_ptr_funcs[STARPU_NINTERFACES_ID] = {
+	[STARPU_BLAS_INTERFACE_ID]	= NULL,
+	[STARPU_BLOCK_INTERFACE_ID]	= NULL,
+	[STARPU_VECTOR_INTERFACE_ID]	= handle_to_ptr_vector,
+	[STARPU_CSR_INTERFACE_ID]	= NULL,
+	[STARPU_CSC_INTERFACE_ID]	= NULL,
+	[STARPU_BCSCR_INTERFACE_ID]	= NULL
+};
+
 int starpu_mpi_handle_to_datatype(starpu_data_handle data_handle, MPI_Datatype *datatype)
 {
 	unsigned id = starpu_get_handle_interface_id(data_handle);
@@ -48,3 +76,14 @@ int starpu_mpi_handle_to_datatype(starpu_data_handle data_handle, MPI_Datatype *
 
 	return func(data_handle, datatype);
 }
+
+void *starpu_mpi_handle_to_ptr(starpu_data_handle data_handle)
+{
+	unsigned id = starpu_get_handle_interface_id(data_handle);
+
+	handle_to_ptr_func func = handle_to_ptr_funcs[id];
+	
+	STARPU_ASSERT(func);
+
+	return func(data_handle);
+}

+ 1 - 0
mpi/starpu_mpi_datatype.h

@@ -20,5 +20,6 @@
 #include <starpu_mpi.h>
 
 int starpu_mpi_handle_to_datatype(starpu_data_handle data_handle, MPI_Datatype *datatype);
+void *starpu_mpi_handle_to_ptr(starpu_data_handle data_handle);
 
 #endif // __STARPU_MPI_DATATYPE_H__

+ 1 - 1
mpi/tests/pingpong.c

@@ -59,7 +59,7 @@ int main(int argc, char **argv)
 			starpu_mpi_send(tab_handle, other_rank, loop, MPI_COMM_WORLD);
 		}
 		else {
-			starpu_mpi_recv(tab_handle, other_rank, loop, MPI_COMM_WORLD);
+			starpu_mpi_recv(tab_handle, other_rank, loop, MPI_COMM_WORLD, NULL);
 		}
 	}
 	

+ 1 - 1
mpi/tests/ring.c

@@ -90,7 +90,7 @@ int main(int argc, char **argv)
 		if (!((loop == 0) && (rank == 0)))
 		{
 			token = 0;
-			starpu_mpi_recv(token_handle, (rank+size-1)%size, tag, MPI_COMM_WORLD);
+			starpu_mpi_recv(token_handle, (rank+size-1)%size, tag, MPI_COMM_WORLD, NULL);
 		}
 		else {
 			token = 0;