Browse Source

mpi/src/: when exchanging user-defined datatypes, the size of the data is now sent first, in a separate message

Nathalie Furmento 12 years ago
parent
commit
b952994b31
1 changed file with 68 additions and 28 deletions
  1. 68 28
      mpi/src/starpu_mpi.c

+ 68 - 28
mpi/src/starpu_mpi.c

@@ -18,7 +18,7 @@
 #include <stdlib.h>
 #include <starpu_mpi.h>
 #include <starpu_mpi_datatype.h>
-//#define STARPU_MPI_VERBOSE	1
+#define STARPU_MPI_VERBOSE	1
 #include <starpu_mpi_private.h>
 #include <starpu_profiling.h>
 #include <starpu_mpi_stats.h>
@@ -33,6 +33,10 @@ static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req);
 #ifdef STARPU_MPI_VERBOSE
 static char *_starpu_mpi_request_type(enum _starpu_mpi_request_type request_type);
 #endif
+static struct _starpu_mpi_req *_starpu_mpi_isend_common(starpu_data_handle_t data_handle,
+							int dest, int mpi_tag, MPI_Comm comm,
+							unsigned detached, void (*callback)(void *), void *arg);
+static struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, unsigned detached, void (*callback)(void *), void *arg);
 
 /* The list of requests that have been newly submitted by the application */
 static struct _starpu_mpi_req_list *new_requests;
@@ -109,23 +113,13 @@ static struct _starpu_mpi_req *_starpu_mpi_isend_irecv_common(starpu_data_handle
 /*                                                      */
 /********************************************************/
 
-static void _starpu_mpi_isend_func(struct _starpu_mpi_req *req)
+static void _starpu_mpi_isend_data_func(struct _starpu_mpi_req *req)
 {
         _STARPU_MPI_LOG_IN();
 
-	req->needs_unpacking = starpu_mpi_handle_to_datatype(req->data_handle, &req->datatype);
-	if (req->needs_unpacking)
-	{
-		starpu_handle_pack_data(req->data_handle, &req->ptr, &req->count);
-	}
-	else
-	{
-		req->count = 1;
-		req->ptr = starpu_handle_get_local_ptr(req->data_handle);
-	}
 	STARPU_ASSERT(req->ptr);
 
-        _STARPU_MPI_DEBUG("post MPI isend tag %d dst %d ptr %p datatype %p count %d req %p\n", req->mpi_tag, req->srcdst, req->ptr, req->datatype, req->count, &req->request);
+        _STARPU_MPI_DEBUG("post MPI isend tag %d dst %d ptr %p datatype %p count %d req %p\n", req->mpi_tag, req->srcdst, req->ptr, req->datatype, (int)req->count, &req->request);
 
 	_starpu_mpi_comm_amounts_inc(req->comm, req->srcdst, req->datatype, req->count);
 
@@ -142,11 +136,37 @@ static void _starpu_mpi_isend_func(struct _starpu_mpi_req *req)
         _STARPU_MPI_LOG_OUT();
 }
 
+static void _starpu_mpi_isend_size_callback(void *arg)
+{
+	struct _starpu_mpi_req *req = (struct _starpu_mpi_req *) arg;
+	_starpu_mpi_isend_data_func(req);
+}
+
+static void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req)
+{
+	req->needs_unpacking = starpu_mpi_handle_to_datatype(req->data_handle, &req->datatype);
+	if (!req->needs_unpacking)
+	{
+		req->count = 1;
+		req->ptr = starpu_handle_get_local_ptr(req->data_handle);
+		_starpu_mpi_isend_data_func(req);
+	}
+	else
+	{
+		starpu_data_handle_t count_handle;
+
+		starpu_handle_pack_data(req->data_handle, &req->ptr, &req->count);
+		starpu_variable_data_register(&count_handle, 0, (uintptr_t)&req->count, sizeof(req->count));
+		_starpu_mpi_isend_common(count_handle, req->srcdst, req->mpi_tag, req->comm, 1, _starpu_mpi_isend_size_callback, req);
+		starpu_data_unregister_submit(count_handle);
+	}
+}
+
 static struct _starpu_mpi_req *_starpu_mpi_isend_common(starpu_data_handle_t data_handle,
 							int dest, int mpi_tag, MPI_Comm comm,
 							unsigned detached, void (*callback)(void *), void *arg)
 {
-	return _starpu_mpi_isend_irecv_common(data_handle, dest, mpi_tag, comm, detached, callback, arg, SEND_REQ, _starpu_mpi_isend_func, STARPU_R);
+	return _starpu_mpi_isend_irecv_common(data_handle, dest, mpi_tag, comm, detached, callback, arg, SEND_REQ, _starpu_mpi_isend_size_func, STARPU_R);
 }
 
 int starpu_mpi_isend(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int mpi_tag, MPI_Comm comm)
@@ -195,24 +215,13 @@ int starpu_mpi_send(starpu_data_handle_t data_handle, int dest, int mpi_tag, MPI
 /*                                                      */
 /********************************************************/
 
-static void _starpu_mpi_irecv_func(struct _starpu_mpi_req *req)
+static void _starpu_mpi_irecv_data_func(struct _starpu_mpi_req *req)
 {
         _STARPU_MPI_LOG_IN();
 
-	req->needs_unpacking = starpu_mpi_handle_to_datatype(req->data_handle, &req->datatype);
-	if (req->needs_unpacking == 1)
-	{
-		req->count = starpu_handle_get_size(req->data_handle);
-		req->ptr = malloc(req->count);
-	}
-	else
-	{
-		req->count = 1;
-		req->ptr = starpu_handle_get_local_ptr(req->data_handle);
-	}
 	STARPU_ASSERT(req->ptr);
 
-	_STARPU_MPI_DEBUG("post MPI irecv tag %d src %d data %p ptr %p req %p datatype %p\n", req->mpi_tag, req->srcdst, req->data_handle, req->ptr, &req->request, req->datatype);
+	_STARPU_MPI_DEBUG("post MPI irecv tag %d src %d data %p ptr %p datatype %p count %d req %p \n", req->mpi_tag, req->srcdst, req->data_handle, req->ptr, req->datatype, (int)req->count, &req->request);
 
         req->ret = MPI_Irecv(req->ptr, req->count, req->datatype, req->srcdst, req->mpi_tag, req->comm, &req->request);
         STARPU_ASSERT(req->ret == MPI_SUCCESS);
@@ -225,9 +234,40 @@ static void _starpu_mpi_irecv_func(struct _starpu_mpi_req *req)
         _STARPU_MPI_LOG_OUT();
 }
 
+static void _starpu_mpi_irecv_size_callback(void *arg)
+{
+	struct _starpu_mpi_req *req = (struct _starpu_mpi_req *) arg;
+#ifdef STARPU_DEVEL
+#  warning are we sure that req->count can be used as we have not released count_handle?
+#endif
+	req->ptr = malloc(req->count);
+	_starpu_mpi_irecv_data_func(req);
+}
+
+static void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req)
+{
+        _STARPU_MPI_LOG_IN();
+
+	req->needs_unpacking = starpu_mpi_handle_to_datatype(req->data_handle, &req->datatype);
+	if (!req->needs_unpacking)
+	{
+		req->count = 1;
+		req->ptr = starpu_handle_get_local_ptr(req->data_handle);
+		_starpu_mpi_irecv_data_func(req);
+	}
+	else
+	{
+		starpu_data_handle_t count_handle;
+
+		starpu_variable_data_register(&count_handle, 0, (uintptr_t)&req->count, sizeof(req->count));
+		_starpu_mpi_irecv_common(count_handle, req->srcdst, req->mpi_tag, req->comm, 1, _starpu_mpi_irecv_size_callback, req);
+		starpu_data_unregister_submit(count_handle);
+	}
+}
+
 static struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, unsigned detached, void (*callback)(void *), void *arg)
 {
-	return _starpu_mpi_isend_irecv_common(data_handle, source, mpi_tag, comm, detached, callback, arg, RECV_REQ, _starpu_mpi_irecv_func, STARPU_W);
+	return _starpu_mpi_isend_irecv_common(data_handle, source, mpi_tag, comm, detached, callback, arg, RECV_REQ, _starpu_mpi_irecv_size_func, STARPU_W);
 }
 
 int starpu_mpi_irecv(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int source, int mpi_tag, MPI_Comm comm)