Parcourir la source

sink_to_sink transfers with async mode

Corentin Salingue il y a 8 ans
Parent
commit
fea37039f6
2 fichiers modifiés avec 9 ajouts et 1 suppressions
  1. 7 0
      src/drivers/mp_common/source_common.c
  2. 2 1
      src/drivers/mpi/driver_mpi_common.c

+ 7 - 0
src/drivers/mp_common/source_common.c

@@ -151,6 +151,8 @@ static int _starpu_src_common_handle_async(struct _starpu_mp_node *node,
         {
             struct _starpu_async_channel * event = *((struct _starpu_async_channel **) arg);
             event->starpu_mp_common_finished_receiver--;
+            if (!stored)
+                STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
             break;
         }
         case STARPU_MP_COMMAND_SEND_TO_HOST_ASYNC_COMPLETED:
@@ -158,6 +160,8 @@ static int _starpu_src_common_handle_async(struct _starpu_mp_node *node,
         {
             struct _starpu_async_channel * event = *((struct _starpu_async_channel **) arg);
             event->starpu_mp_common_finished_sender--;
+            if (!stored)
+                STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);
             break;
         }
 		default:
@@ -734,6 +738,9 @@ int _starpu_src_common_copy_sink_to_sink_async(struct _starpu_mp_node *src_node,
     struct _starpu_async_channel * async_channel = event;
     async_channel->polling_node_sender = src_node; 
     async_channel->polling_node_receiver = dst_node; 
+    /* Increase number of ack waited */
+    async_channel->starpu_mp_common_finished_receiver++;
+    async_channel->starpu_mp_common_finished_sender++;
 
 	/* Tell source to send data to dest. */
 	_starpu_mp_common_send_command(src_node, STARPU_MP_COMMAND_SEND_TO_SINK_ASYNC, &cmd, sizeof(cmd));

+ 2 - 1
src/drivers/mpi/driver_mpi_common.c

@@ -339,6 +339,7 @@ void _starpu_mpi_common_recv_from_device(const struct _starpu_mp_node *node, int
     else
     {
         /* Synchronous recv */
+        MPI_Status s;
         res = MPI_Recv(msg, len, MPI_BYTE, src_devid, SYNC_TAG, MPI_COMM_WORLD, &s);
         int num_expected;
         MPI_Get_count(&s, MPI_BYTE, &num_expected);
@@ -411,7 +412,7 @@ int _starpu_mpi_common_test_event(struct _starpu_async_channel * event)
 
     _starpu_mpi_common_polling_node(event->polling_node_sender);
     _starpu_mpi_common_polling_node(event->polling_node_receiver);
-
+    
     return !event->starpu_mp_common_finished_sender && !event->starpu_mp_common_finished_receiver;
 }