Browse Source

nmad: split src/starpu_mpi_nmad.c into src/starpu_mpi.c and src/nmad/starpu_mpi_nmad.c for nmad specific functions

Nathalie Furmento 7 years ago
parent
commit
dc839f4a0b

+ 1 - 0
mpi/src/starpu_mpi.c

@@ -159,6 +159,7 @@ int starpu_mpi_irecv(starpu_data_handle_t data_handle, starpu_mpi_req *public_re
 	_STARPU_MPI_TRACE_IRECV_COMPLETE_BEGIN(source, data_tag);
 	req = _starpu_mpi_irecv_common(data_handle, source, data_tag, comm, 0, 0, NULL, NULL, 1, 0, 0);
 	_STARPU_MPI_TRACE_IRECV_COMPLETE_END(source, data_tag);
+
 	STARPU_MPI_ASSERT_MSG(req, "Invalid return for _starpu_mpi_irecv_common");
 	*public_req = req;
 

+ 2 - 1
nmad/src/Makefile.am

@@ -68,7 +68,8 @@ noinst_HEADERS =					\
 	starpu_mpi_init.h
 
 libstarpumpi_@STARPU_EFFECTIVE_VERSION@_la_SOURCES =	\
-	starpu_mpi_nmad.c				\
+	nmad/starpu_mpi_nmad.c				\
+	starpu_mpi.c					\
 	starpu_mpi_helper.c				\
 	starpu_mpi_datatype.c				\
 	starpu_mpi_task_insert.c			\

+ 17 - 398
nmad/src/starpu_mpi_nmad.c

@@ -25,8 +25,6 @@
 #include <starpu_mpi_stats.h>
 #include <starpu_mpi_cache.h>
 #include <starpu_mpi_select_node.h>
-#include <starpu_mpi_tag.h>
-#include <starpu_mpi_comm.h>
 #include <starpu_mpi_init.h>
 #include <common/config.h>
 #include <common/thread.h>
@@ -40,12 +38,6 @@ static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req,n
 #ifdef STARPU_VERBOSE
 static char *_starpu_mpi_request_type(enum _starpu_mpi_request_type request_type);
 #endif
-static struct _starpu_mpi_req *_starpu_mpi_isend_common(starpu_data_handle_t data_handle,
-							int dest, int data_tag, MPI_Comm comm,
-							unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg);
-static struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle,
-							int source, int data_tag, MPI_Comm comm,
-							unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency);
 static void _starpu_mpi_handle_new_request(void *arg);
 
 static void _starpu_mpi_handle_pending_request(struct _starpu_mpi_req *req);
@@ -61,6 +53,8 @@ static starpu_pthread_cond_t progress_cond;
 static starpu_pthread_mutex_t progress_mutex;
 static volatile int running = 0;
 
+extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
+
 /* Count requests posted by the application and not yet submitted to MPI, i.e pushed into the new_requests list */
 
 static volatile int pending_request = 0;
@@ -72,7 +66,7 @@ static callback_lfstack_t callback_stack = NULL;
 
 static starpu_sem_t callback_sem;
 
-static void _starpu_mpi_request_init(struct _starpu_mpi_req **req)
+void _starpu_mpi_request_init(struct _starpu_mpi_req **req)
 {
 	_STARPU_MPI_CALLOC(*req, 1, sizeof(struct _starpu_mpi_req));
 
@@ -134,7 +128,7 @@ static void _starpu_mpi_request_init(struct _starpu_mpi_req **req)
 #endif
 }
 
-static void _starpu_mpi_request_destroy(struct _starpu_mpi_req *req)
+void _starpu_mpi_request_destroy(struct _starpu_mpi_req *req)
 {
 	piom_cond_destroy(&(req->req_cond));
 	free(req);
@@ -151,12 +145,14 @@ static void nop_acquire_cb(void *arg)
 	starpu_data_release(arg);
 }
 
-static struct _starpu_mpi_req *_starpu_mpi_isend_irecv_common(starpu_data_handle_t data_handle,
-							      int srcdst, int data_tag, MPI_Comm comm,
-							      unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg,
-							      enum _starpu_mpi_request_type request_type, void (*func)(struct _starpu_mpi_req *),
-							      enum starpu_data_access_mode mode,
-							      int sequential_consistency)
+struct _starpu_mpi_req *_starpu_mpi_isend_irecv_common(starpu_data_handle_t data_handle,
+						       int srcdst, int data_tag, MPI_Comm comm,
+						       unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg,
+						       enum _starpu_mpi_request_type request_type, void (*func)(struct _starpu_mpi_req *),
+						       enum starpu_data_access_mode mode,
+						       int sequential_consistency,
+						       int is_internal_req,
+						       starpu_ssize_t count)
 {
 
 	struct _starpu_mpi_req *req;
@@ -237,7 +233,7 @@ static void _starpu_mpi_isend_data_func(struct _starpu_mpi_req *req)
 	_STARPU_MPI_LOG_OUT();
 }
 
-static void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req)
+void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req)
 {
 	_starpu_mpi_datatype_allocate(req->data_handle, req);
 
@@ -288,103 +284,6 @@ static void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req)
 	_starpu_mpi_isend_data_func(req);
 }
 
-static struct _starpu_mpi_req *_starpu_mpi_isend_common(starpu_data_handle_t data_handle,
-							int dest, int data_tag, MPI_Comm comm,
-							 unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg)
-{
-	return _starpu_mpi_isend_irecv_common(data_handle, dest, data_tag, comm, detached, sync, prio, callback, arg, SEND_REQ, _starpu_mpi_isend_size_func, STARPU_R,1);
-}
-
-int starpu_mpi_isend_prio(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int data_tag, int prio, MPI_Comm comm)
-{
-	_STARPU_MPI_LOG_IN();
-	STARPU_MPI_ASSERT_MSG(public_req, "starpu_mpi_isend needs a valid starpu_mpi_req");
-
-	struct _starpu_mpi_req *req;
-	_STARPU_MPI_TRACE_ISEND_COMPLETE_BEGIN(dest, data_tag, 0);
-	req = _starpu_mpi_isend_common(data_handle, dest, data_tag, comm, 0, 0, prio, NULL, NULL);
-	_STARPU_MPI_TRACE_ISEND_COMPLETE_END(dest, data_tag, 0);
-
-	STARPU_MPI_ASSERT_MSG(req, "Invalid return for _starpu_mpi_isend_common");
-	*public_req = req;
-
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
-int starpu_mpi_isend(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int data_tag, MPI_Comm comm)
-{
-	return starpu_mpi_isend_prio(data_handle, public_req, dest, data_tag, 0, comm);
-}
-
-int starpu_mpi_isend_detached_prio(starpu_data_handle_t data_handle, int dest, int data_tag, int prio, MPI_Comm comm, void (*callback)(void *), void *arg)
-{
-	_STARPU_MPI_LOG_IN();
-	_starpu_mpi_isend_common(data_handle, dest, data_tag, comm, 1, 0, prio, callback, arg);
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
-int starpu_mpi_isend_detached(starpu_data_handle_t data_handle, int dest, int data_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
-{
-	return starpu_mpi_isend_detached_prio(data_handle, dest, data_tag, 0, comm, callback, arg);
-}
-
-int starpu_mpi_send_prio(starpu_data_handle_t data_handle, int dest, int data_tag, int prio, MPI_Comm comm)
-{
-	starpu_mpi_req req;
-	MPI_Status status;
-
-	_STARPU_MPI_LOG_IN();
-	memset(&status, 0, sizeof(MPI_Status));
-
-	starpu_mpi_isend_prio(data_handle, &req, dest, data_tag, prio, comm);
-	starpu_mpi_wait(&req, &status);
-
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
-int starpu_mpi_send(starpu_data_handle_t data_handle, int dest, int data_tag, MPI_Comm comm)
-{
-	return starpu_mpi_send_prio(data_handle, dest, data_tag, 0, comm);
-}
-
-int starpu_mpi_issend_prio(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int data_tag, int prio, MPI_Comm comm)
-{
-	_STARPU_MPI_LOG_IN();
-	STARPU_MPI_ASSERT_MSG(public_req, "starpu_mpi_issend needs a valid starpu_mpi_req");
-
-	struct _starpu_mpi_req *req;
-	req = _starpu_mpi_isend_common(data_handle, dest, data_tag, comm, 0, 1, prio, NULL, NULL);
-
-	STARPU_MPI_ASSERT_MSG(req, "Invalid return for _starpu_mpi_isend_common");
-	*public_req = req;
-
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
-int starpu_mpi_issend(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int data_tag, MPI_Comm comm)
-{
-	return starpu_mpi_issend_prio(data_handle, public_req, dest, data_tag, 0, comm);
-}
-
-int starpu_mpi_issend_detached_prio(starpu_data_handle_t data_handle, int dest, int data_tag, int prio, MPI_Comm comm, void (*callback)(void *), void *arg)
-{
-	_STARPU_MPI_LOG_IN();
-
-	_starpu_mpi_isend_common(data_handle, dest, data_tag, comm, 1, 1, prio, callback, arg);
-
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
-int starpu_mpi_issend_detached(starpu_data_handle_t data_handle, int dest, int data_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
-{
-	return starpu_mpi_issend_detached_prio(data_handle, dest, data_tag, 0, comm, callback, arg);
-}
-
 /********************************************************/
 /*                                                      */
 /*  Receive functionalities                             */
@@ -430,7 +329,7 @@ static void _starpu_mpi_irecv_size_callback(void *arg)
 	free(callback);
 }
 
-static void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req)
+void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req)
 {
 	_STARPU_MPI_LOG_IN();
 
@@ -447,77 +346,18 @@ static void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req)
 		callback->req = req;
 		starpu_variable_data_register(&callback->handle, 0, (uintptr_t)&(callback->req->count), sizeof(callback->req->count));
 		_STARPU_MPI_DEBUG(4, "Receiving size with tag %d from node %d\n", req->node_tag.data_tag, req->node_tag.rank);
-		_starpu_mpi_irecv_common(callback->handle, req->node_tag.rank, req->node_tag.data_tag, req->node_tag.comm, 1, 0, _starpu_mpi_irecv_size_callback, callback,1);
+		_starpu_mpi_irecv_common(callback->handle, req->node_tag.rank, req->node_tag.data_tag, req->node_tag.comm, 1, 0, _starpu_mpi_irecv_size_callback, callback,1,0,0);
 	}
 
 }
 
-static struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency)
-{
-	return _starpu_mpi_isend_irecv_common(data_handle, source, mpi_tag, comm, detached, sync, 0, callback, arg, RECV_REQ, _starpu_mpi_irecv_size_func, STARPU_W,sequential_consistency);
-}
-
-int starpu_mpi_irecv(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int source, int mpi_tag, MPI_Comm comm)
-{
-	_STARPU_MPI_LOG_IN();
-	STARPU_ASSERT_MSG(public_req, "starpu_mpi_irecv needs a valid starpu_mpi_req");
-
-	struct _starpu_mpi_req *req;
-	_STARPU_MPI_TRACE_IRECV_COMPLETE_BEGIN(source, mpi_tag);
-	req = _starpu_mpi_irecv_common(data_handle, source, mpi_tag, comm, 0, 0, NULL, NULL,1);
-	_STARPU_MPI_TRACE_IRECV_COMPLETE_END(source, mpi_tag);
-
-	STARPU_ASSERT_MSG(req, "Invalid return for _starpu_mpi_irecv_common");
-	*public_req = req;
-
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
-int starpu_mpi_irecv_detached(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
-{
-	_STARPU_MPI_LOG_IN();
-	_starpu_mpi_irecv_common(data_handle, source, mpi_tag, comm, 1, 0, callback, arg,1);
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
-int starpu_mpi_irecv_detached_sequential_consistency(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, void (*callback)(void *), void *arg, int sequential_consistency)
-{
-	_STARPU_MPI_LOG_IN();
-
-//	// We check if a tag is defined for the data handle, if not,
-//	// we define the one given for the communication.
-//	// A tag is necessary for the internal mpi engine.
-//	int tag = starpu_data_get_tag(data_handle);
-//	if (tag == -1)
-//		starpu_data_set_tag(data_handle, data_tag);
-
-	_starpu_mpi_irecv_common(data_handle, source, data_tag, comm, 1, 0, callback, arg, sequential_consistency);
-
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
-int starpu_mpi_recv(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, MPI_Status *status)
-{
-	starpu_mpi_req req;
-
-	_STARPU_MPI_LOG_IN();
-	starpu_mpi_irecv(data_handle, &req, source, mpi_tag, comm);
-	starpu_mpi_wait(&req, status);
-
-	_STARPU_MPI_LOG_OUT();
-	return 0;
-}
-
 /********************************************************/
 /*                                                      */
 /*  Wait functionalities                                */
 /*                                                      */
 /********************************************************/
 
-int starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status)
+int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status)
 {
 	_STARPU_MPI_LOG_IN();
 	STARPU_MPI_ASSERT_MSG(public_req, "starpu_mpi_wait needs a valid starpu_mpi_req");
@@ -546,7 +386,7 @@ int starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status)
 /*                                                      */
 /********************************************************/
 
-int starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
+int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
 {
 	_STARPU_MPI_LOG_IN();
 	STARPU_MPI_ASSERT_MSG(public_req, "starpu_mpi_test needs a valid starpu_mpi_req");
@@ -580,19 +420,6 @@ int starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
 /*                                                      */
 /********************************************************/
 
-int starpu_mpi_barrier(MPI_Comm comm)
-{
-	_STARPU_MPI_LOG_IN();
-	int ret;
-	//	STARPU_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
-	ret = MPI_Barrier(comm);
-
-	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Barrier returning %d", ret);
-
-	_STARPU_MPI_LOG_OUT();
-	return ret;
-}
-
 /********************************************************/
 /*                                                      */
 /*  Progression                                         */
@@ -716,28 +543,6 @@ static void _starpu_mpi_handle_new_request(void *arg)
 	_STARPU_MPI_LOG_OUT();
 }
 
-static void _starpu_mpi_print_thread_level_support(int thread_level, char *msg)
-{
-	switch (thread_level)
-	{
-	case MPI_THREAD_SERIALIZED:
-	{
-		_STARPU_DISP("MPI%s MPI_THREAD_SERIALIZED; Multiple threads may make MPI calls, but only one at a time.\n", msg);
-		break;
-	}
-	case MPI_THREAD_FUNNELED:
-	{
-		_STARPU_DISP("MPI%s MPI_THREAD_FUNNELED; The application can safely make calls to StarPU-MPI functions, but should not call directly MPI communication functions.\n", msg);
-		break;
-	}
-	case MPI_THREAD_SINGLE:
-	{
-		_STARPU_DISP("MPI%s MPI_THREAD_SINGLE; MPI does not have multi-thread support, this might cause problems. The application can make calls to StarPU-MPI functions, but not call directly MPI Communication functions.\n", msg);
-		break;
-	}
-	}
-}
-
 static void *_starpu_mpi_progress_thread_func(void *arg)
 {
 	struct _starpu_mpi_argc_argv *argc_argv = (struct _starpu_mpi_argc_argv *) arg;
@@ -935,189 +740,3 @@ void _starpu_mpi_progress_shutdown(int *value)
         STARPU_PTHREAD_COND_DESTROY(&progress_cond);
 }
 
-void _starpu_mpi_data_clear(starpu_data_handle_t data_handle)
-{
-	_starpu_mpi_cache_data_clear(data_handle);
-	free(data_handle->mpi_data);
-}
-
-void starpu_mpi_data_register_comm(starpu_data_handle_t data_handle, int tag, int rank, MPI_Comm comm)
-{
-	struct _starpu_mpi_data *mpi_data;
-	if (data_handle->mpi_data)
-	{
-		mpi_data = data_handle->mpi_data;
-	}
-	else
-	{
-		_STARPU_CALLOC(mpi_data, 1, sizeof(struct _starpu_mpi_data));
-		mpi_data->magic = 42;
-		mpi_data->node_tag.data_tag = -1;
-		mpi_data->node_tag.rank = -1;
-		mpi_data->node_tag.comm = MPI_COMM_WORLD;
-		data_handle->mpi_data = mpi_data;
-		_starpu_mpi_cache_data_init(data_handle);
-		_starpu_data_set_unregister_hook(data_handle, _starpu_mpi_data_clear);
-	}
-
-	if (tag != -1)
-	{
-		mpi_data->node_tag.data_tag = tag;
-	}
-	if (rank != -1)
-	{
-		_STARPU_MPI_TRACE_DATA_SET_RANK(data_handle, rank);
-		mpi_data->node_tag.rank = rank;
-		mpi_data->node_tag.comm = comm;
-	}
-}
-
-void starpu_mpi_data_set_rank_comm(starpu_data_handle_t handle, int rank, MPI_Comm comm)
-{
-	starpu_mpi_data_register_comm(handle, -1, rank, comm);
-}
-
-void starpu_mpi_data_set_tag(starpu_data_handle_t handle, int tag)
-{
-	starpu_mpi_data_register_comm(handle, tag, -1, MPI_COMM_WORLD);
-}
-
-int starpu_mpi_data_get_rank(starpu_data_handle_t data)
-{
-	STARPU_ASSERT_MSG(data->mpi_data, "starpu_mpi_data_register MUST be called for data %p\n", data);
-	return ((struct _starpu_mpi_data *)(data->mpi_data))->node_tag.rank;
-}
-
-int starpu_mpi_data_get_tag(starpu_data_handle_t data)
-{
-	STARPU_ASSERT_MSG(data->mpi_data, "starpu_mpi_data_register MUST be called for data %p\n", data);
-	return ((struct _starpu_mpi_data *)(data->mpi_data))->node_tag.data_tag;
-}
-
-void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg)
-{
-	int me, rank, tag;
-
-	rank = starpu_mpi_data_get_rank(data_handle);
-	if (rank == -1)
-	{
-		_STARPU_ERROR("StarPU needs to be told the MPI rank of this data, using starpu_mpi_data_register() or starpu_mpi_data_register()\n");
-	}
-
-	starpu_mpi_comm_rank(comm, &me);
-	if (node == rank)
-		return;
-
-	tag = starpu_mpi_data_get_tag(data_handle);
-	if (tag == -1)
-	{
-		_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register() or starpu_mpi_data_register()\n");
-	}
-
-	if (me == node)
-	{
-		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
-		int already_received = _starpu_mpi_cache_received_data_set(data_handle);
-		if (already_received == 0)
-		{
-			_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, rank);
-			starpu_mpi_irecv_detached(data_handle, rank, tag, comm, callback, arg);
-		}
-	}
-	else if (me == rank)
-	{
-		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
-		int already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
-		if (already_sent == 0)
-		{
-			_STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, node);
-			starpu_mpi_isend_detached(data_handle, node, tag, comm, NULL, NULL);
-		}
-	}
-}
-
-void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle, int node)
-{
-	int me, rank, tag;
-
-	rank = starpu_mpi_data_get_rank(data_handle);
-	if (rank == -1)
-	{
-		_STARPU_ERROR("StarPU needs to be told the MPI rank of this data, using starpu_mpi_data_register\n");
-	}
-
-	starpu_mpi_comm_rank(comm, &me);
-	if (node == rank)
-		return;
-
-	tag = starpu_mpi_data_get_tag(data_handle);
-	if (tag == -1)
-	{
-		_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
-	}
-
-	if (me == node)
-	{
-		MPI_Status status;
-		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
-		int already_received = _starpu_mpi_cache_received_data_set(data_handle);
-		if (already_received == 0)
-		{
-			_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, rank);
-			starpu_mpi_recv(data_handle, rank, tag, comm, &status);
-		}
-	}
-	else if (me == rank)
-	{
-		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
-		int already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
-		if (already_sent == 0)
-		{
-			_STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, node);
-			starpu_mpi_send(data_handle, node, tag, comm);
-		}
-	}
-}
-
-void starpu_mpi_get_data_on_all_nodes_detached(MPI_Comm comm, starpu_data_handle_t data_handle)
-{
-	int size, i;
-	starpu_mpi_comm_size(comm, &size);
-#ifdef STARPU_DEVEL
-#warning TODO: use binary communication tree to optimize broadcast
-#endif
-	for (i = 0; i < size; i++)
-		starpu_mpi_get_data_on_node_detached(comm, data_handle, i, NULL, NULL);
-}
-
-void starpu_mpi_data_migrate(MPI_Comm comm, starpu_data_handle_t data, int new_rank)
-{
-	int old_rank = starpu_mpi_data_get_rank(data);
-	if (new_rank == old_rank)
-		/* Already there */
-		return;
-
-	/* First submit data migration if it's not already on destination */
-	starpu_mpi_get_data_on_node_detached(comm, data, new_rank, NULL, NULL);
-
-	/* And note new owner */
-	starpu_mpi_data_set_rank_comm(data, new_rank, comm);
-
-	/* Flush cache in all other nodes */
-	/* TODO: Ideally we'd transmit the knowledge of who owns it */
-	starpu_mpi_cache_flush(comm, data);
-	return;
-}
-
-int starpu_mpi_wait_for_all(MPI_Comm comm)
-{
-	int mpi = 1;
-	int task = 1;
-	while (task || mpi)
-	{
-		task = _starpu_task_wait_for_all_and_return_nb_waited_tasks();
-		mpi = starpu_mpi_barrier(comm);
-	}
-	return 0;
-}
-

+ 403 - 0
nmad/src/starpu_mpi.c

@@ -0,0 +1,403 @@
+/* StarPU --- Runtime system for heterogeneous multicore architectures.
+ *
+ * Copyright (C) 2009, 2010-2014, 2017  Université de Bordeaux
+ * Copyright (C) 2010, 2011, 2012, 2013, 2014, 2015  Centre National de la Recherche Scientifique
+ *
+ * StarPU is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or (at
+ * your option) any later version.
+ *
+ * StarPU is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ *
+ * See the GNU Lesser General Public License in COPYING.LGPL for more details.
+ */
+
+#include <stdlib.h>
+#include <limits.h>
+#include <starpu_mpi.h>
+#include <starpu_mpi_datatype.h>
+#include <starpu_mpi_private.h>
+#include <starpu_mpi_cache.h>
+#include <starpu_profiling.h>
+#include <starpu_mpi_stats.h>
+#include <starpu_mpi_cache.h>
+#include <starpu_mpi_select_node.h>
+#include <starpu_mpi_init.h>
+#include <common/config.h>
+#include <common/thread.h>
+#include <datawizard/coherency.h>
+#include <nm_sendrecv_interface.h>
+#include <nm_mpi_nmad.h>
+#include <core/task.h>
+#include <core/topology.h>
+
+static struct _starpu_mpi_req *_starpu_mpi_isend_common(starpu_data_handle_t data_handle,
+							int dest, int data_tag, MPI_Comm comm,
+							unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg,
+							int sequential_consistency)
+{
+	return _starpu_mpi_isend_irecv_common(data_handle, dest, data_tag, comm, detached, sync, prio, callback, arg, SEND_REQ, _starpu_mpi_isend_size_func,
+#ifdef STARPU_MPI_PEDANTIC_ISEND
+					      STARPU_RW,
+#else
+					      STARPU_R,
+#endif
+					      sequential_consistency, 0, 0);
+}
+
+int starpu_mpi_isend_prio(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int data_tag, int prio, MPI_Comm comm)
+{
+	_STARPU_MPI_LOG_IN();
+	STARPU_MPI_ASSERT_MSG(public_req, "starpu_mpi_isend needs a valid starpu_mpi_req");
+
+	struct _starpu_mpi_req *req;
+	_STARPU_MPI_TRACE_ISEND_COMPLETE_BEGIN(dest, data_tag, 0);
+	req = _starpu_mpi_isend_common(data_handle, dest, data_tag, comm, 0, 0, prio, NULL, NULL, 1);
+	_STARPU_MPI_TRACE_ISEND_COMPLETE_END(dest, data_tag, 0);
+
+	STARPU_MPI_ASSERT_MSG(req, "Invalid return for _starpu_mpi_isend_common");
+	*public_req = req;
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+int starpu_mpi_isend(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int data_tag, MPI_Comm comm)
+{
+	return starpu_mpi_isend_prio(data_handle, public_req, dest, data_tag, 0, comm);
+}
+
+int starpu_mpi_isend_detached_prio(starpu_data_handle_t data_handle, int dest, int data_tag, int prio, MPI_Comm comm, void (*callback)(void *), void *arg)
+{
+	_STARPU_MPI_LOG_IN();
+	_starpu_mpi_isend_common(data_handle, dest, data_tag, comm, 1, 0, prio, callback, arg, 1);
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+int starpu_mpi_isend_detached(starpu_data_handle_t data_handle, int dest, int data_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
+{
+	return starpu_mpi_isend_detached_prio(data_handle, dest, data_tag, 0, comm, callback, arg);
+}
+
+int starpu_mpi_send_prio(starpu_data_handle_t data_handle, int dest, int data_tag, int prio, MPI_Comm comm)
+{
+	starpu_mpi_req req;
+	MPI_Status status;
+
+	_STARPU_MPI_LOG_IN();
+	memset(&status, 0, sizeof(MPI_Status));
+
+	starpu_mpi_isend_prio(data_handle, &req, dest, data_tag, prio, comm);
+	starpu_mpi_wait(&req, &status);
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+int starpu_mpi_send(starpu_data_handle_t data_handle, int dest, int data_tag, MPI_Comm comm)
+{
+	return starpu_mpi_send_prio(data_handle, dest, data_tag, 0, comm);
+}
+
+int starpu_mpi_issend_prio(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int data_tag, int prio, MPI_Comm comm)
+{
+	_STARPU_MPI_LOG_IN();
+	STARPU_MPI_ASSERT_MSG(public_req, "starpu_mpi_issend needs a valid starpu_mpi_req");
+
+	struct _starpu_mpi_req *req;
+	req = _starpu_mpi_isend_common(data_handle, dest, data_tag, comm, 0, 1, prio, NULL, NULL, 1);
+
+	STARPU_MPI_ASSERT_MSG(req, "Invalid return for _starpu_mpi_isend_common");
+	*public_req = req;
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+int starpu_mpi_issend(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int dest, int data_tag, MPI_Comm comm)
+{
+	return starpu_mpi_issend_prio(data_handle, public_req, dest, data_tag, 0, comm);
+}
+
+int starpu_mpi_issend_detached_prio(starpu_data_handle_t data_handle, int dest, int data_tag, int prio, MPI_Comm comm, void (*callback)(void *), void *arg)
+{
+	_STARPU_MPI_LOG_IN();
+
+	_starpu_mpi_isend_common(data_handle, dest, data_tag, comm, 1, 1, prio, callback, arg, 1);
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+int starpu_mpi_issend_detached(starpu_data_handle_t data_handle, int dest, int data_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
+{
+	return starpu_mpi_issend_detached_prio(data_handle, dest, data_tag, 0, comm, callback, arg);
+}
+
+struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count)
+{
+	return _starpu_mpi_isend_irecv_common(data_handle, source, data_tag, comm, detached, sync, 0, callback, arg, RECV_REQ, _starpu_mpi_irecv_size_func, STARPU_W, sequential_consistency, is_internal_req, count);
+}
+
+int starpu_mpi_irecv(starpu_data_handle_t data_handle, starpu_mpi_req *public_req, int source, int data_tag, MPI_Comm comm)
+{
+	_STARPU_MPI_LOG_IN();
+	STARPU_MPI_ASSERT_MSG(public_req, "starpu_mpi_irecv needs a valid starpu_mpi_req");
+
+	struct _starpu_mpi_req *req;
+	_STARPU_MPI_TRACE_IRECV_COMPLETE_BEGIN(source, data_tag);
+	req = _starpu_mpi_irecv_common(data_handle, source, data_tag, comm, 0, 0, NULL, NULL, 1, 0, 0);
+	_STARPU_MPI_TRACE_IRECV_COMPLETE_END(source, data_tag);
+
+	STARPU_MPI_ASSERT_MSG(req, "Invalid return for _starpu_mpi_irecv_common");
+	*public_req = req;
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+int starpu_mpi_irecv_detached(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
+{
+	_STARPU_MPI_LOG_IN();
+
+	_starpu_mpi_irecv_common(data_handle, source, data_tag, comm, 1, 0, callback, arg, 1, 0, 0);
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+int starpu_mpi_irecv_detached_sequential_consistency(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, void (*callback)(void *), void *arg, int sequential_consistency)
+{
+	_STARPU_MPI_LOG_IN();
+
+	_starpu_mpi_irecv_common(data_handle, source, data_tag, comm, 1, 0, callback, arg, sequential_consistency, 0, 0);
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+int starpu_mpi_recv(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, MPI_Status *status)
+{
+	starpu_mpi_req req;
+
+	_STARPU_MPI_LOG_IN();
+	starpu_mpi_irecv(data_handle, &req, source, data_tag, comm);
+	starpu_mpi_wait(&req, status);
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
+extern int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status);
+int starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status)
+{
+	return _starpu_mpi_wait(public_req, status);
+}
+
+extern int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status);
+int starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
+{
+	return _starpu_mpi_test(public_req, flag, status);
+}
+
+int starpu_mpi_barrier(MPI_Comm comm)
+{
+	_STARPU_MPI_LOG_IN();
+	int ret;
+	//	STARPU_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
+	ret = MPI_Barrier(comm);
+
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Barrier returning %d", ret);
+
+	_STARPU_MPI_LOG_OUT();
+	return ret;
+}
+
+void _starpu_mpi_data_clear(starpu_data_handle_t data_handle)
+{
+	_starpu_mpi_cache_data_clear(data_handle);
+	free(data_handle->mpi_data);
+}
+
+void starpu_mpi_data_register_comm(starpu_data_handle_t data_handle, int tag, int rank, MPI_Comm comm)
+{
+	struct _starpu_mpi_data *mpi_data;
+	if (data_handle->mpi_data)
+	{
+		mpi_data = data_handle->mpi_data;
+	}
+	else
+	{
+		_STARPU_CALLOC(mpi_data, 1, sizeof(struct _starpu_mpi_data));
+		mpi_data->magic = 42;
+		mpi_data->node_tag.data_tag = -1;
+		mpi_data->node_tag.rank = -1;
+		mpi_data->node_tag.comm = MPI_COMM_WORLD;
+		data_handle->mpi_data = mpi_data;
+		_starpu_mpi_cache_data_init(data_handle);
+		_starpu_data_set_unregister_hook(data_handle, _starpu_mpi_data_clear);
+	}
+
+	if (tag != -1)
+	{
+		mpi_data->node_tag.data_tag = tag;
+	}
+	if (rank != -1)
+	{
+		_STARPU_MPI_TRACE_DATA_SET_RANK(data_handle, rank);
+		mpi_data->node_tag.rank = rank;
+		mpi_data->node_tag.comm = comm;
+	}
+}
+
+void starpu_mpi_data_set_rank_comm(starpu_data_handle_t handle, int rank, MPI_Comm comm)
+{
+	starpu_mpi_data_register_comm(handle, -1, rank, comm);
+}
+
+void starpu_mpi_data_set_tag(starpu_data_handle_t handle, int tag)
+{
+	starpu_mpi_data_register_comm(handle, tag, -1, MPI_COMM_WORLD);
+}
+
+int starpu_mpi_data_get_rank(starpu_data_handle_t data)
+{
+	STARPU_ASSERT_MSG(data->mpi_data, "starpu_mpi_data_register MUST be called for data %p\n", data);
+	return ((struct _starpu_mpi_data *)(data->mpi_data))->node_tag.rank;
+}
+
+int starpu_mpi_data_get_tag(starpu_data_handle_t data)
+{
+	STARPU_ASSERT_MSG(data->mpi_data, "starpu_mpi_data_register MUST be called for data %p\n", data);
+	return ((struct _starpu_mpi_data *)(data->mpi_data))->node_tag.data_tag;
+}
+
+void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg)
+{
+	int me, rank, tag;
+
+	rank = starpu_mpi_data_get_rank(data_handle);
+	if (rank == -1)
+	{
+		_STARPU_ERROR("StarPU needs to be told the MPI rank of this data, using starpu_mpi_data_register() or starpu_mpi_data_register()\n");
+	}
+
+	starpu_mpi_comm_rank(comm, &me);
+	if (node == rank)
+		return;
+
+	tag = starpu_mpi_data_get_tag(data_handle);
+	if (tag == -1)
+	{
+		_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register() or starpu_mpi_data_register()\n");
+	}
+
+	if (me == node)
+	{
+		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
+		int already_received = _starpu_mpi_cache_received_data_set(data_handle);
+		if (already_received == 0)
+		{
+			_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, rank);
+			starpu_mpi_irecv_detached(data_handle, rank, tag, comm, callback, arg);
+		}
+	}
+	else if (me == rank)
+	{
+		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
+		int already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
+		if (already_sent == 0)
+		{
+			_STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, node);
+			starpu_mpi_isend_detached(data_handle, node, tag, comm, NULL, NULL);
+		}
+	}
+}
+
+void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle, int node)
+{
+	int me, rank, tag;
+
+	rank = starpu_mpi_data_get_rank(data_handle);
+	if (rank == -1)
+	{
+		_STARPU_ERROR("StarPU needs to be told the MPI rank of this data, using starpu_mpi_data_register\n");
+	}
+
+	starpu_mpi_comm_rank(comm, &me);
+	if (node == rank)
+		return;
+
+	tag = starpu_mpi_data_get_tag(data_handle);
+	if (tag == -1)
+	{
+		_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
+	}
+
+	if (me == node)
+	{
+		MPI_Status status;
+		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
+		int already_received = _starpu_mpi_cache_received_data_set(data_handle);
+		if (already_received == 0)
+		{
+			_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, rank);
+			starpu_mpi_recv(data_handle, rank, tag, comm, &status);
+		}
+	}
+	else if (me == rank)
+	{
+		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
+		int already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
+		if (already_sent == 0)
+		{
+			_STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, node);
+			starpu_mpi_send(data_handle, node, tag, comm);
+		}
+	}
+}
+
+void starpu_mpi_get_data_on_all_nodes_detached(MPI_Comm comm, starpu_data_handle_t data_handle)
+{
+	int size, i;
+	starpu_mpi_comm_size(comm, &size);
+#ifdef STARPU_DEVEL
+#warning TODO: use binary communication tree to optimize broadcast
+#endif
+	for (i = 0; i < size; i++)
+		starpu_mpi_get_data_on_node_detached(comm, data_handle, i, NULL, NULL);
+}
+
+void starpu_mpi_data_migrate(MPI_Comm comm, starpu_data_handle_t data, int new_rank)
+{
+	int old_rank = starpu_mpi_data_get_rank(data);
+	if (new_rank == old_rank)
+		/* Already there */
+		return;
+
+	/* First submit data migration if it's not already on destination */
+	starpu_mpi_get_data_on_node_detached(comm, data, new_rank, NULL, NULL);
+
+	/* And note new owner */
+	starpu_mpi_data_set_rank_comm(data, new_rank, comm);
+
+	/* Flush cache in all other nodes */
+	/* TODO: Ideally we'd transmit the knowledge of who owns it */
+	starpu_mpi_cache_flush(comm, data);
+	return;
+}
+
+int starpu_mpi_wait_for_all(MPI_Comm comm)
+{
+	int mpi = 1;
+	int task = 1;
+	while (task || mpi)
+	{
+		task = _starpu_task_wait_for_all_and_return_nb_waited_tasks();
+		mpi = starpu_mpi_barrier(comm);
+	}
+	return 0;
+}

+ 0 - 4
nmad/src/starpu_mpi_init.c

@@ -24,11 +24,7 @@
 #include <starpu_profiling.h>
 #include <starpu_mpi_stats.h>
 #include <starpu_mpi_cache.h>
-#include <starpu_mpi_sync_data.h>
-#include <starpu_mpi_early_data.h>
-#include <starpu_mpi_early_request.h>
 #include <starpu_mpi_select_node.h>
-#include <starpu_mpi_tag.h>
 #include <common/config.h>
 #include <common/thread.h>
 #include <datawizard/interfaces/data_interface.h>

+ 18 - 0
nmad/src/starpu_mpi_private.h

@@ -238,6 +238,24 @@ LIST_TYPE(_starpu_mpi_req,
 );
 PRIO_LIST_TYPE(_starpu_mpi_req, prio)
 
+struct _starpu_mpi_req *_starpu_mpi_isend_irecv_common(starpu_data_handle_t data_handle,
+						       int srcdst, int data_tag, MPI_Comm comm,
+						       unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg,
+						       enum _starpu_mpi_request_type request_type, void (*func)(struct _starpu_mpi_req *),
+						       enum starpu_data_access_mode mode,
+						       int sequential_consistency,
+						       int is_internal_req,
+						       starpu_ssize_t count);
+
+void _starpu_mpi_submit_ready_request_inc(struct _starpu_mpi_req *req);
+void _starpu_mpi_request_init(struct _starpu_mpi_req **req);
+void _starpu_mpi_request_destroy(struct _starpu_mpi_req *req);
+void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req);
+void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req);
+void _starpu_mpi_wait_func(struct _starpu_mpi_req *waiting_req);
+void _starpu_mpi_test_func(struct _starpu_mpi_req *testing_req);
+int _starpu_mpi_barrier(MPI_Comm comm);
+
 struct _starpu_mpi_argc_argv
 {
 	int initialize_mpi;