浏览代码

add MPI slave to MPI slave sync transfers

Corentin Salingue 8 年之前
父节点
当前提交
087b676559

+ 4 - 1
src/datawizard/coherency.c

@@ -261,8 +261,11 @@ static int worker_supports_direct_access(unsigned node, unsigned handling_node)
 			/* TODO: We don't handle direct MIC-MIC transfers yet */
 			return 0;
         case STARPU_MPI_MS_RAM:
+        {
             /* Don't support MPI-MPI transfers yet */
-            return 0;
+            enum starpu_node_kind kind = starpu_node_get_kind(handling_node);
+            return kind == STARPU_MPI_MS_RAM;
+        }
 		case STARPU_SCC_RAM:
 			return 1;
 		default:

+ 1 - 3
src/datawizard/copy_driver.c

@@ -762,8 +762,6 @@ int starpu_interface_copy(uintptr_t src, size_t src_offset, unsigned src_node, u
                     size);
 
     case _STARPU_MEMORY_NODE_TUPLE(STARPU_MPI_MS_RAM, STARPU_MPI_MS_RAM):
-        /* TODO : not used now */
-        STARPU_ABORT();
         if (async_data)
             return _starpu_mpi_copy_sink_to_sink_async(
                     (void*) (src + src_offset), src_node,
@@ -854,7 +852,7 @@ void _starpu_driver_wait_request_completion(struct _starpu_async_channel *async_
 #endif
 #ifdef STARPU_USE_MPI_MASTER_SLAVE
     case STARPU_MPI_MS_RAM:
-        _starpu_mpi_src_wait_event(async_channel);
+        _starpu_mpi_common_wait_event(async_channel);
         break;
 #endif
 	case STARPU_MAIN_RAM:

+ 2 - 1
src/datawizard/copy_driver.h

@@ -110,7 +110,8 @@ struct _starpu_async_channel
 	union _starpu_async_channel_event event;
 	enum starpu_node_kind type;
     /* Which node to polling when needing ACK msg */
-    struct _starpu_mp_node *polling_node;
+    struct _starpu_mp_node *polling_node_sender;
+    struct _starpu_mp_node *polling_node_receiver;
     /* Used to know if the acknowlegdment msg is arrived from sinks */
     volatile int starpu_mp_common_finished_sender; 
     volatile int starpu_mp_common_finished_receiver; 

+ 2 - 0
src/datawizard/data_request.c

@@ -155,6 +155,8 @@ struct _starpu_data_request *_starpu_create_data_request(starpu_data_handle_t ha
 	r->async_channel.type = STARPU_UNUSED;
     r->async_channel.starpu_mp_common_finished_sender = 0;
     r->async_channel.starpu_mp_common_finished_receiver = 0;
+    r->async_channel.polling_node_sender = NULL;
+    r->async_channel.polling_node_receiver = NULL;
 #ifdef STARPU_USE_MPI_MASTER_SLAVE
     r->async_channel.event.mpi_ms_event.requests = NULL;
 #endif

+ 8 - 1
src/drivers/mp_common/mp_common.c

@@ -233,7 +233,8 @@ _starpu_mp_common_node_create(enum _starpu_mp_node_kind node_kind,
 		node->nb_mp_sinks = 
 		node->devid = 
     */
-        node->mp_connection.mpi_remote_nodeid = peer_id+1;
+        node->peer_id = (_starpu_mpi_common_get_src_node() <= peer_id ? peer_id+1 : peer_id);
+        node->mp_connection.mpi_remote_nodeid = node->peer_id;
 
         node->init = _starpu_mpi_source_init;
         node->launch_workers = NULL;
@@ -306,6 +307,8 @@ _starpu_mp_common_node_create(enum _starpu_mp_node_kind node_kind,
 	mp_message_list_init(&node->message_queue);
 	STARPU_PTHREAD_MUTEX_INIT(&node->message_queue_mutex,NULL);
 
+    STARPU_PTHREAD_MUTEX_INIT(&node->connection_mutex, NULL);
+
     _starpu_mp_event_list_init(&node->event_list);
 
 	/* If the node is a sink then we must initialize some field */
@@ -367,6 +370,8 @@ void _starpu_mp_common_send_command(const struct _starpu_mp_node *node,
 {
 	STARPU_ASSERT_MSG(arg_size <= BUFFER_SIZE, "Too much data (%d) for the static MIC buffer (%d), increase BUFFER_SIZE perhaps?", arg_size, BUFFER_SIZE);
 
+    printf("SEND CMD : %d - arg_size %d by %lu \n", command, arg_size, pthread_self());
+
 	/* MIC and MPI sizes are given through a int */
 	int command_size = sizeof(enum _starpu_mp_command);
 	int arg_size_size = sizeof(int);
@@ -401,6 +406,8 @@ enum _starpu_mp_command _starpu_mp_common_recv_command(const struct _starpu_mp_n
 	command = *((enum _starpu_mp_command *) node->buffer);
 	*arg_size = *((int *) ((uintptr_t)node->buffer + command_size));
 
+    printf("RECV command : %d - arg_size %d by %lu \n", command, *arg_size, pthread_self());
+
 	/* If there is no argument (ie. arg_size == 0),
 	 * let's return the command right now */
 	if (!(*arg_size))

+ 5 - 0
src/drivers/mp_common/mp_common.h

@@ -193,6 +193,11 @@ struct _starpu_mp_node
 	 * Connection used for data transfers between the host and his sink. */
 	union _starpu_mp_connection host_sink_dt_connection;
 
+    /* Mutex preventing communications from interleaving when using one thread per node,
+     * for instance when one thread transfers a piece of data while another one wants to
+     * start a sink_to_sink communication */
+	starpu_pthread_mutex_t connection_mutex;
+
 	/* Only MIC use this for now !!
 	 * Only sink use this for now !!
 	 * Connection used for data transfer between devices.

+ 8 - 4
src/drivers/mp_common/sink_common.c

@@ -137,7 +137,8 @@ static void _starpu_sink_common_copy_from_host_async(struct _starpu_mp_node *mp_
     async_channel->type = STARPU_UNUSED;
     async_channel->starpu_mp_common_finished_sender = -1;
     async_channel->starpu_mp_common_finished_receiver = 0;
-    async_channel->polling_node = NULL;
+    async_channel->polling_node_receiver = NULL;
+    async_channel->polling_node_sender = NULL;
 
     mp_node->dt_recv(mp_node, cmd->addr, cmd->size, &sink_event->event);
     /* Push event on the list */
@@ -182,7 +183,8 @@ static void _starpu_sink_common_copy_to_host_async(struct _starpu_mp_node *mp_no
     async_channel->type = STARPU_UNUSED;
     async_channel->starpu_mp_common_finished_sender = 0;
     async_channel->starpu_mp_common_finished_receiver = -1;
-    async_channel->polling_node = NULL;
+    async_channel->polling_node_receiver = NULL;
+    async_channel->polling_node_sender = NULL;
 
     mp_node->dt_send(mp_node, cmd->addr, cmd->size, &sink_event->event);
     /* Push event on the list */
@@ -221,7 +223,8 @@ static void _starpu_sink_common_copy_from_sink_async(struct _starpu_mp_node *mp_
     async_channel->type = STARPU_UNUSED;
     async_channel->starpu_mp_common_finished_sender = -1;
     async_channel->starpu_mp_common_finished_receiver = 0;
-    async_channel->polling_node = NULL;
+    async_channel->polling_node_receiver = NULL;
+    async_channel->polling_node_sender = NULL;
 
     mp_node->dt_recv_from_device(mp_node, cmd->devid, cmd->addr, cmd->size, &sink_event->event);
     /* Push event on the list */
@@ -260,7 +263,8 @@ static void _starpu_sink_common_copy_to_sink_async(struct _starpu_mp_node *mp_no
     async_channel->type = STARPU_UNUSED;
     async_channel->starpu_mp_common_finished_sender = 0;
     async_channel->starpu_mp_common_finished_receiver = -1;
-    async_channel->polling_node = NULL;
+    async_channel->polling_node_receiver = NULL;
+    async_channel->polling_node_sender = NULL;
 
     mp_node->dt_send_to_device(mp_node, cmd->devid, cmd->addr, cmd->size, &sink_event->event);
 

+ 117 - 19
src/drivers/mp_common/source_common.c

@@ -83,7 +83,7 @@ static int _starpu_src_common_finalize_job (struct _starpu_job *j, struct _starp
 
 
 /* Complete the execution of the job */
-static int _starpu_src_common_process_completed_job(struct _starpu_worker_set *workerset, void * arg, int arg_size)
+static int _starpu_src_common_process_completed_job(struct _starpu_mp_node *node, struct _starpu_worker_set *workerset, void * arg, int arg_size, int stored)
 {
 	int coreid;
 
@@ -96,6 +96,10 @@ static int _starpu_src_common_process_completed_job(struct _starpu_worker_set *w
 
 	struct _starpu_worker * old_worker = _starpu_get_local_worker_key();
 
+    /* if arg is not copied we release the mutex */
+    if (!stored)
+        STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
+
 	_starpu_set_local_worker_key(worker);
 	_starpu_src_common_finalize_job (j, worker);
 	_starpu_set_local_worker_key(old_worker);
@@ -105,12 +109,17 @@ static int _starpu_src_common_process_completed_job(struct _starpu_worker_set *w
 }
 
 /* Tell the scheduler when the execution has begun */
-static void _starpu_src_common_pre_exec(void * arg, int arg_size)
+static void _starpu_src_common_pre_exec(struct _starpu_mp_node *node, void * arg, int arg_size, int stored)
 {
 	int cb_workerid, i;
 	STARPU_ASSERT(sizeof(cb_workerid) == arg_size);
 	cb_workerid = *(int *) arg;
 	struct _starpu_combined_worker *combined_worker = _starpu_get_combined_worker_struct(cb_workerid);
+
+    /* arg was not copied, so the caller still holds connection_mutex: release it (the caller re-locks) */
+    if (!stored)
+        STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
+
 	for(i=0; i < combined_worker->worker_size; i++)
 	{
 		struct _starpu_worker * worker = _starpu_get_worker_struct(combined_worker->combined_workerid[i]);
@@ -123,19 +132,19 @@ static void _starpu_src_common_pre_exec(void * arg, int arg_size)
  * return 0 if the message has not been handle (it's certainly mean that it's a synchronous message)
  * return 1 if the message has been handle
  */
-static int _starpu_src_common_handle_async(const struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED,
+static int _starpu_src_common_handle_async(struct _starpu_mp_node *node,
 		void * arg, int arg_size,
-		enum _starpu_mp_command answer)
+		enum _starpu_mp_command answer, int stored)
 {
 	struct _starpu_worker_set * worker_set = NULL;
 	switch(answer)
 	{
 		case STARPU_MP_COMMAND_EXECUTION_COMPLETED:
 			worker_set = _starpu_get_worker_struct(starpu_worker_get_id())->set;
-			_starpu_src_common_process_completed_job(worker_set, arg, arg_size);
+			_starpu_src_common_process_completed_job(node, worker_set, arg, arg_size, stored);
 			break;
 		case STARPU_MP_COMMAND_PRE_EXECUTION:
-			_starpu_src_common_pre_exec(arg,arg_size);
+			_starpu_src_common_pre_exec(node, arg,arg_size, stored);
 			break;
         case STARPU_MP_COMMAND_RECV_FROM_HOST_ASYNC_COMPLETED:
         case STARPU_MP_COMMAND_RECV_FROM_SINK_ASYNC_COMPLETED:
@@ -170,7 +179,7 @@ static void _starpu_src_common_handle_stored_async(struct _starpu_mp_node *node)
         /* Release mutex during handle */
 	    STARPU_PTHREAD_MUTEX_UNLOCK(&node->message_queue_mutex);
 		_starpu_src_common_handle_async(node, message->buffer,
-				message->size, message->type);
+				message->size, message->type, 1);
 		free(message->buffer);
 		mp_message_delete(message);
         /* Take it again */
@@ -246,7 +255,7 @@ static void _starpu_src_common_recv_async(struct _starpu_mp_node * node)
 	void *arg;
 	int arg_size;
 	answer = _starpu_mp_common_recv_command(node, &arg, &arg_size);
-	if(!_starpu_src_common_handle_async(node,arg,arg_size,answer))
+	if(!_starpu_src_common_handle_async(node,arg,arg_size,answer, 0))
 	{
 		printf("incorrect commande: unknown command or sync command");
 		STARPU_ASSERT(0);
@@ -288,13 +297,15 @@ static void _starpu_src_common_recv_async(struct _starpu_mp_node * node)
 
 
 /* Send a request to the sink NODE for the number of cores on it. */
-int _starpu_src_common_sink_nbcores (const struct _starpu_mp_node *node, int *buf)
+int _starpu_src_common_sink_nbcores (struct _starpu_mp_node *node, int *buf)
 {
 
 	enum _starpu_mp_command answer;
 	void *arg;
 	int arg_size = sizeof (int);
 
+    STARPU_PTHREAD_MUTEX_LOCK(&node->connection_mutex);
+
 	_starpu_mp_common_send_command (node, STARPU_MP_COMMAND_SINK_NBCORES, NULL, 0);
 
 	answer = _starpu_mp_common_recv_command (node, &arg, &arg_size);
@@ -303,6 +314,8 @@ int _starpu_src_common_sink_nbcores (const struct _starpu_mp_node *node, int *bu
 
 	memcpy (buf, arg, arg_size);
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
+
 	return 0;
 }
 
@@ -321,6 +334,8 @@ int _starpu_src_common_lookup(struct _starpu_mp_node *node,
 	/* strlen ignore the terminating '\0' */
 	arg_size = (strlen(func_name) + 1) * sizeof(char);
 
+    STARPU_PTHREAD_MUTEX_LOCK(&node->connection_mutex);
+
 	//_STARPU_DEBUG("Looking up %s\n", func_name);
 	_starpu_mp_common_send_command(node, STARPU_MP_COMMAND_LOOKUP, (void *) func_name,
 			arg_size);
@@ -328,9 +343,11 @@ int _starpu_src_common_lookup(struct _starpu_mp_node *node,
 	answer = _starpu_src_common_wait_command_sync(node, (void **) &arg,
 			&arg_size);
 
+
 	if (answer == STARPU_MP_COMMAND_ERROR_LOOKUP)
 	{
 		_STARPU_DISP("Error looking up symbol %s\n", func_name);
+        STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
 		return -ESPIPE;
 	}
 
@@ -341,6 +358,8 @@ int _starpu_src_common_lookup(struct _starpu_mp_node *node,
 
 	memcpy(func_ptr, arg, arg_size);
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
+
 	//_STARPU_DEBUG("got %p\n", *func_ptr);
 
 	return 0;
@@ -435,14 +454,22 @@ int _starpu_src_common_execute_kernel(struct _starpu_mp_node *node,
 	if (cl_arg)
 		memcpy((void*) buffer_ptr, cl_arg, cl_arg_size);
 
+    STARPU_PTHREAD_MUTEX_LOCK(&node->connection_mutex);
+
 	_starpu_mp_common_send_command(node, STARPU_MP_COMMAND_EXECUTE, buffer, buffer_size);
+
 	enum _starpu_mp_command answer = _starpu_src_common_wait_command_sync(node, &arg, &arg_size);
 
 	if (answer == STARPU_MP_COMMAND_ERROR_EXECUTE)
+    {
+        STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
 		return -EINVAL;
+    }
 
 	STARPU_ASSERT(answer == STARPU_MP_COMMAND_EXECUTION_SUBMITTED);
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
+
 	free(buffer);
 
 	return 0;
@@ -502,43 +529,56 @@ int _starpu_src_common_allocate(struct _starpu_mp_node *mp_node,
 	void *arg;
 	int arg_size;
 
+    STARPU_PTHREAD_MUTEX_LOCK(&mp_node->connection_mutex);
+
 	_starpu_mp_common_send_command(mp_node, STARPU_MP_COMMAND_ALLOCATE, &size,
 			sizeof(size));
 
 	answer = _starpu_src_common_wait_command_sync(mp_node, &arg, &arg_size);
 
 	if (answer == STARPU_MP_COMMAND_ERROR_ALLOCATE)
+    {
+        STARPU_PTHREAD_MUTEX_UNLOCK(&mp_node->connection_mutex);
 		return 1;
+    }
 
 	STARPU_ASSERT(answer == STARPU_MP_COMMAND_ANSWER_ALLOCATE &&
 			arg_size == sizeof(*addr));
     
 	memcpy(addr, arg, arg_size);
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&mp_node->connection_mutex);
+
 	return 0;
 }
 
 /* Send a request to the sink linked to the MP_NODE to deallocate the memory
  * area pointed by ADDR.
  */
-void _starpu_src_common_free(const struct _starpu_mp_node *mp_node,
+void _starpu_src_common_free(struct _starpu_mp_node *mp_node,
 		void *addr)
 {
+    STARPU_PTHREAD_MUTEX_LOCK(&mp_node->connection_mutex);
 	_starpu_mp_common_send_command(mp_node, STARPU_MP_COMMAND_FREE, &addr, sizeof(addr));
+    STARPU_PTHREAD_MUTEX_UNLOCK(&mp_node->connection_mutex);
 }
 
 /* Send SIZE bytes pointed by SRC to DST on the sink linked to the MP_NODE with a
  * synchronous mode.
  */
-int _starpu_src_common_copy_host_to_sink_sync(const struct _starpu_mp_node *mp_node,
+int _starpu_src_common_copy_host_to_sink_sync(struct _starpu_mp_node *mp_node,
 		void *src, void *dst, size_t size)
 {
 	struct _starpu_mp_transfer_command cmd = {size, dst, NULL};
 
+    STARPU_PTHREAD_MUTEX_LOCK(&mp_node->connection_mutex);
+
 	_starpu_mp_common_send_command(mp_node, STARPU_MP_COMMAND_RECV_FROM_HOST, &cmd, sizeof(cmd));
 
 	mp_node->dt_send(mp_node, src, size, NULL);
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&mp_node->connection_mutex);
+
 	return 0;
 }
 
@@ -550,16 +590,20 @@ int _starpu_src_common_copy_host_to_sink_async(struct _starpu_mp_node *mp_node,
 {
 	struct _starpu_mp_transfer_command cmd = {size, dst, event};
 
+    STARPU_PTHREAD_MUTEX_LOCK(&mp_node->connection_mutex);
+
     /* For asynchronous transfers, we save informations
      * to test is they are finished
      */
     struct _starpu_async_channel * async_channel = event;
-    async_channel->polling_node = mp_node;
+    async_channel->polling_node_receiver = mp_node;
 
 	_starpu_mp_common_send_command(mp_node, STARPU_MP_COMMAND_RECV_FROM_HOST_ASYNC, &cmd, sizeof(cmd));
 
 	mp_node->dt_send(mp_node, src, size, event);
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&mp_node->connection_mutex);
+
 	return -EAGAIN;
 }
 
@@ -574,6 +618,8 @@ int _starpu_src_common_copy_sink_to_host_sync(struct _starpu_mp_node *mp_node,
 	int arg_size;
 	struct _starpu_mp_transfer_command cmd = {size, src, NULL};
 
+    STARPU_PTHREAD_MUTEX_LOCK(&mp_node->connection_mutex);
+
 	_starpu_mp_common_send_command(mp_node, STARPU_MP_COMMAND_SEND_TO_HOST, &cmd, sizeof(cmd));
 
     answer = _starpu_src_common_wait_command_sync(mp_node, &arg, &arg_size);
@@ -582,6 +628,8 @@ int _starpu_src_common_copy_sink_to_host_sync(struct _starpu_mp_node *mp_node,
 
 	mp_node->dt_recv(mp_node, dst, size, NULL);
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&mp_node->connection_mutex);
+
 	return 0;
 }
 
@@ -593,16 +641,20 @@ int _starpu_src_common_copy_sink_to_host_async(struct _starpu_mp_node *mp_node,
 {
 	struct _starpu_mp_transfer_command cmd = {size, src, event};
 
+    STARPU_PTHREAD_MUTEX_LOCK(&mp_node->connection_mutex);
+
     /* For asynchronous transfers, we save informations
      * to test is they are finished
      */
     struct _starpu_async_channel * async_channel = event;
-    async_channel->polling_node = mp_node;
+    async_channel->polling_node_sender = mp_node;
 
 	_starpu_mp_common_send_command(mp_node, STARPU_MP_COMMAND_SEND_TO_HOST_ASYNC, &cmd, sizeof(cmd));
 
 	mp_node->dt_recv(mp_node, dst, size, event);
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&mp_node->connection_mutex);
+
 	return -EAGAIN;
 }
 
@@ -610,8 +662,8 @@ int _starpu_src_common_copy_sink_to_host_async(struct _starpu_mp_node *mp_node,
  * to the sink linked to DST_NODE. The latter store them in DST with a synchronous
  * mode.
  */
-int _starpu_src_common_copy_sink_to_sink_sync(const struct _starpu_mp_node *src_node,
-		const struct _starpu_mp_node *dst_node, void *src, void *dst, size_t size)
+int _starpu_src_common_copy_sink_to_sink_sync(struct _starpu_mp_node *src_node,
+		struct _starpu_mp_node *dst_node, void *src, void *dst, size_t size)
 {
 	enum _starpu_mp_command answer;
 	void *arg;
@@ -619,9 +671,24 @@ int _starpu_src_common_copy_sink_to_sink_sync(const struct _starpu_mp_node *src_
 
 	struct _starpu_mp_transfer_command_to_device cmd = {dst_node->peer_id, size, src, NULL};
 
+    /* lock the node with the smaller peer_id first to prevent deadlock */
+    if (src_node->peer_id > dst_node->peer_id)
+    {
+        STARPU_PTHREAD_MUTEX_LOCK(&dst_node->connection_mutex);
+        STARPU_PTHREAD_MUTEX_LOCK(&src_node->connection_mutex);
+    }
+    else
+    {
+        STARPU_PTHREAD_MUTEX_LOCK(&src_node->connection_mutex);
+        STARPU_PTHREAD_MUTEX_LOCK(&dst_node->connection_mutex);
+    }
+
 	/* Tell source to send data to dest. */
 	_starpu_mp_common_send_command(src_node, STARPU_MP_COMMAND_SEND_TO_SINK, &cmd, sizeof(cmd));
 
+    /* Release the source as fast as possible */
+    STARPU_PTHREAD_MUTEX_UNLOCK(&src_node->connection_mutex);
+
 	cmd.devid = src_node->peer_id;
 	cmd.size = size;
 	cmd.addr = dst;
@@ -630,10 +697,13 @@ int _starpu_src_common_copy_sink_to_sink_sync(const struct _starpu_mp_node *src_
 	_starpu_mp_common_send_command(dst_node, STARPU_MP_COMMAND_RECV_FROM_SINK, &cmd, sizeof(cmd));
 
 	/* Wait for answer from dest to know wether transfer is finished. */
-	answer = _starpu_mp_common_recv_command(dst_node, &arg, &arg_size);
+    answer = _starpu_src_common_wait_command_sync(dst_node, &arg, &arg_size);
 
 	STARPU_ASSERT(answer == STARPU_MP_COMMAND_TRANSFER_COMPLETE);
 
+    /* Release the receiver once we have received the acknowledgment */
+    STARPU_PTHREAD_MUTEX_UNLOCK(&dst_node->connection_mutex);
+
 	return 0;
 }
 
@@ -641,20 +711,35 @@ int _starpu_src_common_copy_sink_to_sink_sync(const struct _starpu_mp_node *src_
  * to the sink linked to DST_NODE. The latter store them in DST with an asynchronous
  * mode.
  */
-int _starpu_src_common_copy_sink_to_sink_async(const struct _starpu_mp_node *src_node,
-		const struct _starpu_mp_node *dst_node, void *src, void *dst, size_t size, void * event)
+int _starpu_src_common_copy_sink_to_sink_async(struct _starpu_mp_node *src_node,
+		struct _starpu_mp_node *dst_node, void *src, void *dst, size_t size, void * event)
 {
 	struct _starpu_mp_transfer_command_to_device cmd = {dst_node->peer_id, size, src, event};
 
+    /* lock the node with the smaller peer_id first to prevent deadlock */
+    if (src_node->peer_id > dst_node->peer_id)
+    {
+        STARPU_PTHREAD_MUTEX_LOCK(&dst_node->connection_mutex);
+        STARPU_PTHREAD_MUTEX_LOCK(&src_node->connection_mutex);
+    }
+    else
+    {
+        STARPU_PTHREAD_MUTEX_LOCK(&src_node->connection_mutex);
+        STARPU_PTHREAD_MUTEX_LOCK(&dst_node->connection_mutex);
+    }
+
     /* For asynchronous transfers, we save informations
      * to test is they are finished
      */
     struct _starpu_async_channel * async_channel = event;
-    async_channel->polling_node = NULL; /* TODO which node ? */
+    async_channel->polling_node_sender = src_node; 
+    async_channel->polling_node_receiver = dst_node; 
 
 	/* Tell source to send data to dest. */
 	_starpu_mp_common_send_command(src_node, STARPU_MP_COMMAND_SEND_TO_SINK_ASYNC, &cmd, sizeof(cmd));
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&src_node->connection_mutex);
+
 	cmd.devid = src_node->peer_id;
 	cmd.size = size;
 	cmd.addr = dst;
@@ -662,6 +747,7 @@ int _starpu_src_common_copy_sink_to_sink_async(const struct _starpu_mp_node *src
 	/* Tell dest to receive data from source. */
 	_starpu_mp_common_send_command(dst_node, STARPU_MP_COMMAND_RECV_FROM_SINK_ASYNC, &cmd, sizeof(cmd));
 
+    STARPU_PTHREAD_MUTEX_UNLOCK(&dst_node->connection_mutex);
 
 	return -EAGAIN;
 }
@@ -829,6 +915,8 @@ static void _starpu_src_common_send_workers(struct _starpu_mp_node * node, int b
 	msg[3] = baseworkerid;
 	msg[4] = starpu_worker_get_count();
 
+    STARPU_PTHREAD_MUTEX_LOCK(&node->connection_mutex);
+
 	/* tell the sink node that we will send him all workers */
 	_starpu_mp_common_send_command(node, STARPU_MP_COMMAND_SYNC_WORKERS,
 			&msg, sizeof(msg));
@@ -838,6 +926,8 @@ static void _starpu_src_common_send_workers(struct _starpu_mp_node * node, int b
 
 	/* Send all combined workers to the sink node */
 	node->dt_send(node, &config->combined_workers,combined_worker_size, NULL);
+
+    STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
 }
 
 
@@ -858,9 +948,17 @@ static void _starpu_src_common_worker_internal_work(struct _starpu_worker_set *
     /* Handle message which have been store */
     _starpu_src_common_handle_stored_async(mp_node);
 
+    STARPU_PTHREAD_MUTEX_LOCK(&mp_node->connection_mutex);
+
     /* poll the device for completed jobs.*/
     while(mp_node->mp_recv_is_ready(mp_node))
+    {
         _starpu_src_common_recv_async(mp_node);
+        /* The mutex is unlocked in _starpu_src_common_recv_async */
+        STARPU_PTHREAD_MUTEX_LOCK(&mp_node->connection_mutex);
+    }
+
+    STARPU_PTHREAD_MUTEX_UNLOCK(&mp_node->connection_mutex);
 
     /* get task for each worker*/
     res |= _starpu_get_multi_worker_task(worker_set->workers, tasks, worker_set->nworkers, memnode);

+ 8 - 8
src/drivers/mp_common/source_common.h

@@ -33,7 +33,7 @@ int _starpu_src_common_store_message(struct _starpu_mp_node *node,
 
 enum _starpu_mp_command _starpu_src_common_wait_completed_execution(struct _starpu_mp_node *node, int devid, void **arg, int * arg_size);
 
-int _starpu_src_common_sink_nbcores (const struct _starpu_mp_node *node, int *buf);
+int _starpu_src_common_sink_nbcores (struct _starpu_mp_node *node, int *buf);
 
 int _starpu_src_common_lookup(const struct _starpu_mp_node *node,
 			      void (**func_ptr)(void), const char *func_name);
@@ -41,7 +41,7 @@ int _starpu_src_common_lookup(const struct _starpu_mp_node *node,
 int _starpu_src_common_allocate(const struct _starpu_mp_node *mp_node,
 				void **addr, size_t size);
 
-void _starpu_src_common_free(const struct _starpu_mp_node *mp_node,
+void _starpu_src_common_free(struct _starpu_mp_node *mp_node,
 			     void *addr);
 
 int _starpu_src_common_execute_kernel(const struct _starpu_mp_node *node,
@@ -54,23 +54,23 @@ int _starpu_src_common_execute_kernel(const struct _starpu_mp_node *node,
 				      void *cl_arg, size_t cl_arg_size);
 
 
-int _starpu_src_common_copy_host_to_sink_sync(const struct _starpu_mp_node *mp_node,
+int _starpu_src_common_copy_host_to_sink_sync(struct _starpu_mp_node *mp_node,
 					 void *src, void *dst, size_t size);
 
 int _starpu_src_common_copy_sink_to_host_sync(struct _starpu_mp_node *mp_node,
 					 void *src, void *dst, size_t size);
 
-int _starpu_src_common_copy_sink_to_sink_sync(const struct _starpu_mp_node *src_node,
-					 const struct _starpu_mp_node *dst_node, void *src, void *dst, size_t size);
+int _starpu_src_common_copy_sink_to_sink_sync(struct _starpu_mp_node *src_node,
+					 struct _starpu_mp_node *dst_node, void *src, void *dst, size_t size);
 
-int _starpu_src_common_copy_host_to_sink_async(const struct _starpu_mp_node *mp_node,
+int _starpu_src_common_copy_host_to_sink_async(struct _starpu_mp_node *mp_node,
 					 void *src, void *dst, size_t size, void *event);
 
 int _starpu_src_common_copy_sink_to_host_async(struct _starpu_mp_node *mp_node,
 					 void *src, void *dst, size_t size, void *event);
 
-int _starpu_src_common_copy_sink_to_sink_async(const struct _starpu_mp_node *src_node,
-					 const struct _starpu_mp_node *dst_node, void *src, void *dst, size_t size, void *event);
+int _starpu_src_common_copy_sink_to_sink_async(struct _starpu_mp_node *src_node,
+					 struct _starpu_mp_node *dst_node, void *src, void *dst, size_t size, void *event);
 
 int _starpu_src_common_locate_file(char *located_file_name,
 				   const char *env_file_name, const char *env_mic_path,

+ 71 - 17
src/drivers/mpi/driver_mpi_common.c

@@ -264,7 +264,7 @@ void _starpu_mpi_common_send_to_device(const struct _starpu_mp_node *node, int d
     int id_proc;
     MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
 
-    printf("send %d bytes from %d from %p\n", len, dst_devid, msg);
+    printf("S_to_D send %d bytes from %d from %p\n", len, dst_devid, msg);
 
     if (event)
     {
@@ -308,7 +308,7 @@ void _starpu_mpi_common_recv_from_device(const struct _starpu_mp_node *node, int
     int id_proc;
     MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
 
-    printf("nop recv %d bytes from %d\n", len, src_devid);
+    printf("R_to_D nop recv %d bytes from %d\n", len, src_devid);
 
     if (event)
     {
@@ -339,11 +339,37 @@ void _starpu_mpi_common_recv_from_device(const struct _starpu_mp_node *node, int
     else
     {
         /* Synchronous recv */
-        res = MPI_Recv(msg, len, MPI_BYTE, src_devid, SYNC_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+        res = MPI_Recv(msg, len, MPI_BYTE, src_devid, SYNC_TAG, MPI_COMM_WORLD, &s);
+        int num_expected;
+        MPI_Get_count(&s, MPI_BYTE, &num_expected);
+
+        STARPU_ASSERT_MSG(num_expected == len, "MPI Master/Slave received a msg with a size of %d Bytes (expected %d Bytes) !", num_expected, len);
         STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
     }
 }
 
+static void _starpu_mpi_common_polling_node(struct _starpu_mp_node * node)
+{
+    /* Poll NODE for pending commands while holding its connection_mutex; does nothing when NODE is NULL. */
+    if (node != NULL)
+    {
+        STARPU_PTHREAD_MUTEX_LOCK(&node->connection_mutex); /* held for the whole polling loop */
+        while(node->mp_recv_is_ready(node))
+        {
+            enum _starpu_mp_command answer;
+            void *arg;
+            int arg_size;
+            answer = _starpu_mp_common_recv_command(node, &arg, &arg_size); /* read one command and its argument */
+            if(!_starpu_src_common_store_message(node,arg,arg_size,answer)) /* presumably queues the message for later handling -- confirm */
+            {
+                /* a zero return from store_message is treated as fatal */
+                printf("incorrect commande: unknown command or sync command");
+                STARPU_ASSERT(0);
+            }
+        }
+        STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
+    }
+}
+
  /* - In device to device communications, the first ack received by host
  * is considered as the sender (but it cannot be, in fact, the sender)
  */
@@ -383,29 +409,57 @@ int _starpu_mpi_common_test_event(struct _starpu_async_channel * event)
         }
     }
 
-    /* poll the asynchronous messages.*/
-    if (event->polling_node != NULL)
+    _starpu_mpi_common_polling_node(event->polling_node_sender);
+    _starpu_mpi_common_polling_node(event->polling_node_receiver);
+
+    return !event->starpu_mp_common_finished_sender && !event->starpu_mp_common_finished_receiver;
+}
+
+
+/* - In device to device communications, the first ack received by host
+ * is considered as the sender (but it may not, in fact, be the sender)
+ */
+void _starpu_mpi_common_wait_event(struct _starpu_async_channel * event)
+{
+    if (event->event.mpi_ms_event.requests != NULL && !_starpu_mpi_ms_event_request_list_empty(event->event.mpi_ms_event.requests))
     {
-        while(event->polling_node->mp_recv_is_ready(event->polling_node))
+        struct _starpu_mpi_ms_event_request * req = _starpu_mpi_ms_event_request_list_begin(event->event.mpi_ms_event.requests);
+        struct _starpu_mpi_ms_event_request * req_next;
+
+        while (req != _starpu_mpi_ms_event_request_list_end(event->event.mpi_ms_event.requests))
         {
-            enum _starpu_mp_command answer;
-            void *arg;
-            int arg_size;
-            answer = _starpu_mp_common_recv_command(event->polling_node, &arg, &arg_size);
-            if(!_starpu_src_common_store_message(event->polling_node,arg,arg_size,answer))
-            {
-                printf("incorrect commande: unknown command or sync command");
-                STARPU_ASSERT(0);
-            }
+            req_next = _starpu_mpi_ms_event_request_list_next(req);
+
+            MPI_Wait(&req->request, MPI_STATUS_IGNORE);
+            _starpu_mpi_ms_event_request_list_erase(event->event.mpi_ms_event.requests, req);
+
+            _starpu_mpi_ms_event_request_delete(req);
+            req = req_next;
+
+            if (event->event.mpi_ms_event.is_sender)
+                event->starpu_mp_common_finished_sender--;
+            else
+                event->starpu_mp_common_finished_receiver--;
+
         }
+
+        STARPU_ASSERT_MSG(_starpu_mpi_ms_event_request_list_empty(event->event.mpi_ms_event.requests), "MPI Request list is not empty after a wait_event !");
+
+        /* Destroy the list */
+        _starpu_mpi_ms_event_request_list_delete(event->event.mpi_ms_event.requests);
+        event->event.mpi_ms_event.requests = NULL;
     }
 
-    return !event->starpu_mp_common_finished_sender && !event->starpu_mp_common_finished_receiver;
+    //incoming ack from devices
+    while(event->starpu_mp_common_finished_sender > 0 || event->starpu_mp_common_finished_receiver > 0)
+    {
+        _starpu_mpi_common_polling_node(event->polling_node_sender);
+        _starpu_mpi_common_polling_node(event->polling_node_receiver);
+    }
 }
 
 
 
-
 void _starpu_mpi_common_barrier(void)
 {
     MPI_Barrier(MPI_COMM_WORLD);

+ 1 - 0
src/drivers/mpi/driver_mpi_common.h

@@ -44,6 +44,7 @@ void _starpu_mpi_common_recv_from_device(const struct _starpu_mp_node *node, int
 void _starpu_mpi_common_send_to_device(const struct _starpu_mp_node *node, int dst_devid, void *msg, int len, void * event);
 
 int _starpu_mpi_common_test_event(struct _starpu_async_channel * event);
+void _starpu_mpi_common_wait_event(struct _starpu_async_channel * event);
 
 void _starpu_mpi_common_barrier(void);
 

+ 1 - 56
src/drivers/mpi/driver_mpi_source.c

@@ -76,7 +76,7 @@ int _starpu_mpi_src_allocate_memory(void ** addr, size_t size, unsigned memory_n
 
 void _starpu_mpi_source_free_memory(void *addr, unsigned memory_node)
 {
-	const struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(memory_node);
+	struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(memory_node);
     _starpu_src_common_free(mp_node, addr);
 }
 
@@ -124,61 +124,6 @@ int _starpu_mpi_copy_sink_to_sink_async(void *src, unsigned src_node, void *dst,
             src, dst, size, event);
 }
 
-/* - In device to device communications, the first ack received by host
- * is considered as the sender (but it cannot be, in fact, the sender)
- */
-void _starpu_mpi_src_wait_event(struct _starpu_async_channel * event)
-{
-    if (event->event.mpi_ms_event.requests != NULL && !_starpu_mpi_ms_event_request_list_empty(event->event.mpi_ms_event.requests))
-    {
-        struct _starpu_mpi_ms_event_request * req = _starpu_mpi_ms_event_request_list_begin(event->event.mpi_ms_event.requests);
-        struct _starpu_mpi_ms_event_request * req_next;
-
-        while (req != _starpu_mpi_ms_event_request_list_end(event->event.mpi_ms_event.requests))
-        {
-            req_next = _starpu_mpi_ms_event_request_list_next(req);
-
-            MPI_Wait(&req->request, MPI_STATUS_IGNORE);
-            _starpu_mpi_ms_event_request_list_erase(event->event.mpi_ms_event.requests, req);
-
-            _starpu_mpi_ms_event_request_delete(req);
-            req = req_next;
-
-            if (event->event.mpi_ms_event.is_sender)
-                event->starpu_mp_common_finished_sender--;
-            else
-                event->starpu_mp_common_finished_receiver--;
-
-        }
-
-        STARPU_ASSERT_MSG(_starpu_mpi_ms_event_request_list_empty(event->event.mpi_ms_event.requests), "MPI Request list is not empty after a wait_event !");
-
-        /* Destroy the list */
-        _starpu_mpi_ms_event_request_list_delete(event->event.mpi_ms_event.requests);
-        event->event.mpi_ms_event.requests = NULL;
-    }
-
-    //XXX: Maybe cause deadlock when the same thread is waiting here and cannot handle
-    //incoming ack from devices
-    while(event->starpu_mp_common_finished_sender > 0 || event->starpu_mp_common_finished_receiver > 0)
-        /* poll the asynchronous messages.*/
-        if (event->polling_node != NULL)
-        {
-            while(event->polling_node->mp_recv_is_ready(event->polling_node))
-            {
-                enum _starpu_mp_command answer;
-                void *arg;
-                int arg_size;
-                answer = _starpu_mp_common_recv_command(event->polling_node, &arg, &arg_size);
-                if(!_starpu_src_common_store_message(event->polling_node,arg,arg_size,answer))
-                {
-                    printf("incorrect commande: unknown command or sync command");
-                    STARPU_ASSERT(0);
-                }
-            }
-        }
-}
-
 
 int _starpu_mpi_ms_src_register_kernel(starpu_mpi_ms_func_symbol_t *symbol, const char *func_name)
 {

+ 0 - 2
src/drivers/mpi/driver_mpi_source.h

@@ -45,8 +45,6 @@ int _starpu_mpi_copy_mpi_to_ram_async(void *src, unsigned src_node, void *dst, u
 int _starpu_mpi_copy_ram_to_mpi_async(void *src, unsigned src_node STARPU_ATTRIBUTE_UNUSED, void *dst, unsigned dst_node, size_t size, void * event);
 int _starpu_mpi_copy_sink_to_sink_async(void *src, unsigned src_node, void *dst, unsigned dst_node, size_t size, void * event);
 
-void _starpu_mpi_src_wait_event(struct _starpu_async_channel * event);
-
 void(* _starpu_mpi_ms_src_get_kernel_from_job(const struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED, struct _starpu_job *j))(void);
 
 #endif /* STARPU_USE_MPI_MASTER_SLAVE */