Browse Source

mpi/src/starpu_mpi.c: fix mechanism to send data with user-defined datatype.

 The previous method did not wait for the completion of the receive for the size before receiving the data itself, and it lead to incorrect size.
 Waiting for the completion of this request would have meant to call a blocking function from a StarPU callback.
 We now send or receive the size as soon as the starpu-mpi request is posted.
Nathalie Furmento 12 years ago
parent
commit
e390ca2a94
1 changed files with 41 additions and 41 deletions
  1. 41 41
      mpi/src/starpu_mpi.c

+ 41 - 41
mpi/src/starpu_mpi.c

@@ -1,7 +1,7 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
  * Copyright (C) 2009, 2010-2012  Université de Bordeaux 1
- * Copyright (C) 2010, 2011, 2012  Centre National de la Recherche Scientifique
+ * Copyright (C) 2010, 2011, 2012, 2013  Centre National de la Recherche Scientifique
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -69,6 +69,7 @@ static int posted_requests = 0, newer_requests, barrier_running = 0;
 /********************************************************/
 
 static struct _starpu_mpi_req *_starpu_mpi_isend_irecv_common(starpu_data_handle_t data_handle,
+							      size_t size,
 							      int srcdst, int mpi_tag, MPI_Comm comm,
 							      unsigned detached, void (*callback)(void *), void *arg,
 							      enum _starpu_mpi_request_type request_type, void (*func)(struct _starpu_mpi_req *),
@@ -90,6 +91,7 @@ static struct _starpu_mpi_req *_starpu_mpi_isend_irecv_common(starpu_data_handle
 	req->request_type = request_type;
 
 	req->data_handle = data_handle;
+	req->count = size;
 	req->srcdst = srcdst;
 	req->mpi_tag = mpi_tag;
 	req->comm = comm;
@@ -141,13 +143,7 @@ static void _starpu_mpi_isend_data_func(struct _starpu_mpi_req *req)
 	_STARPU_MPI_LOG_OUT();
 }
 
-static void _starpu_mpi_isend_size_callback(void *arg)
-{
-	struct _starpu_mpi_req *req = (struct _starpu_mpi_req *) arg;
-	_starpu_mpi_isend_data_func(req);
-}
-
-static void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req)
+static void _starpu_mpi_isend_pack_func(struct _starpu_mpi_req *req)
 {
 	_starpu_mpi_handle_allocate_datatype(req->data_handle, &req->datatype, &req->user_datatype);
 	if (req->user_datatype == 0)
@@ -158,12 +154,8 @@ static void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req)
 	}
 	else
 	{
-		starpu_data_handle_t count_handle;
-
 		starpu_handle_pack_data(req->data_handle, &req->ptr, &req->count);
-		starpu_variable_data_register(&count_handle, 0, (uintptr_t)&req->count, sizeof(req->count));
-		_starpu_mpi_isend_common(count_handle, req->srcdst, req->mpi_tag, req->comm, 1, _starpu_mpi_isend_size_callback, req);
-		starpu_data_unregister_submit(count_handle);
+		_starpu_mpi_isend_data_func(req);
 	}
 }
 
@@ -171,7 +163,19 @@ static struct _starpu_mpi_req *_starpu_mpi_isend_common(starpu_data_handle_t dat
 							int dest, int mpi_tag, MPI_Comm comm,
 							unsigned detached, void (*callback)(void *), void *arg)
 {
-	return _starpu_mpi_isend_irecv_common(data_handle, dest, mpi_tag, comm, detached, callback, arg, SEND_REQ, _starpu_mpi_isend_size_func, STARPU_R);
+	enum starpu_data_interface_id id = starpu_handle_get_interface_id(data_handle);
+	size_t size;
+
+	size = starpu_handle_get_size(data_handle);
+
+	if (id >= STARPU_MAX_INTERFACE_ID)
+	{
+		starpu_data_handle_t size_handle;
+		starpu_variable_data_register(&size_handle, 0, (uintptr_t)&(size), sizeof(size));
+		starpu_mpi_send(size_handle, dest, mpi_tag, comm);
+	}
+
+	return _starpu_mpi_isend_irecv_common(data_handle, size, dest, mpi_tag, comm, detached, callback, arg, SEND_REQ, _starpu_mpi_isend_pack_func, STARPU_R);
 }
 
 int starpu_mpi_isend(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int mpi_tag, MPI_Comm comm)
@@ -242,27 +246,7 @@ static void _starpu_mpi_irecv_data_func(struct _starpu_mpi_req *req)
 	_STARPU_MPI_LOG_OUT();
 }
 
-struct _starpu_mpi_irecv_size_callback
-{
-	starpu_data_handle_t handle;
-	struct _starpu_mpi_req *req;
-};
-
-static void _starpu_mpi_irecv_size_callback(void *arg)
-{
-	struct _starpu_mpi_irecv_size_callback *callback = (struct _starpu_mpi_irecv_size_callback *)arg;
-
-	starpu_data_unregister(callback->handle);
-	callback->req->ptr = malloc(callback->req->count);
-#ifdef STARPU_DEVEL
-#warning TODO: in some cases, callback->req->count is incorrect, we need to fix that
-#endif
-	STARPU_ASSERT_MSG(callback->req->ptr, "cannot allocate message of size %ld\n", callback->req->count);
-	_starpu_mpi_irecv_data_func(callback->req);
-	free(callback);
-}
-
-static void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req)
+static void _starpu_mpi_irecv_pack_func(struct _starpu_mpi_req *req)
 {
 	_STARPU_MPI_LOG_IN();
 
@@ -275,16 +259,27 @@ static void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req)
 	}
 	else
 	{
-		struct _starpu_mpi_irecv_size_callback *callback = malloc(sizeof(struct _starpu_mpi_irecv_size_callback));
-		callback->req = req;
-		starpu_variable_data_register(&callback->handle, 0, (uintptr_t)&(callback->req->count), sizeof(callback->req->count));
-		_starpu_mpi_irecv_common(callback->handle, req->srcdst, req->mpi_tag, req->comm, 1, _starpu_mpi_irecv_size_callback, callback);
+		req->ptr = malloc(req->count);
+		_starpu_mpi_irecv_data_func(req);
 	}
 }
 
 static struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, unsigned detached, void (*callback)(void *), void *arg)
 {
-	return _starpu_mpi_isend_irecv_common(data_handle, source, mpi_tag, comm, detached, callback, arg, RECV_REQ, _starpu_mpi_irecv_size_func, STARPU_W);
+	enum starpu_data_interface_id id = starpu_handle_get_interface_id(data_handle);
+	size_t size=0;
+
+	if (id >= STARPU_MAX_INTERFACE_ID)
+	{
+		starpu_data_handle_t size_handle;
+		MPI_Status status;
+
+		starpu_variable_data_register(&size_handle, 0, (uintptr_t)&(size), sizeof(size));
+		starpu_mpi_recv(size_handle, source, mpi_tag, comm, &status);
+		starpu_data_unregister(size_handle);
+	}
+
+	return _starpu_mpi_isend_irecv_common(data_handle, size, source, mpi_tag, comm, detached, callback, arg, RECV_REQ, _starpu_mpi_irecv_pack_func, STARPU_W);
 }
 
 int starpu_mpi_irecv(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int source, int mpi_tag, MPI_Comm comm)
@@ -349,8 +344,13 @@ static void _starpu_mpi_probe_func(struct _starpu_mpi_req *req)
 
 int starpu_mpi_irecv_probe_detached(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
 {
+	size_t size;
+
 	_STARPU_MPI_LOG_IN();
-	_starpu_mpi_isend_irecv_common(data_handle, source, mpi_tag, comm, 1, callback, arg, PROBE_REQ, _starpu_mpi_probe_func, STARPU_W);
+
+	size = starpu_handle_get_size(data_handle);
+	_starpu_mpi_isend_irecv_common(data_handle, size, source, mpi_tag, comm, 1, callback, arg, PROBE_REQ, _starpu_mpi_probe_func, STARPU_W);
+
 	_STARPU_MPI_LOG_OUT();
 	return 0;
 }