Переглянути джерело

Prevent wrong thread destination by using different tags

Corentin Salingue 8 роки тому
батько
коміт
776a294c84
1 змінених файлів з 51 додано та 17 видалено
  1. 51 17
      src/drivers/mpi/driver_mpi_common.c

+ 51 - 17
src/drivers/mpi/driver_mpi_common.c

@@ -24,9 +24,6 @@
 static int mpi_initialized;
 static int src_node_id;
 
-#define STARPU_MPI_MS_MSG_TAG 42
-
-
 static void _starpu_mpi_set_src_node_id()
 {
 	int node_id = starpu_get_env_number("STARPU_MPI_MASTER_NODE");
@@ -95,19 +92,24 @@ int _starpu_mpi_common_is_mp_initialized()
 /* common parts to initialize a source or a sink node */
 void _starpu_mpi_common_mp_initialize_src_sink(struct _starpu_mp_node *node)
 {
-    /* TODO : use hwloc */
-    node->nb_cores = 4;
-
-    /* TODO next */
+    struct _starpu_machine_topology *topology = &_starpu_get_machine_config()->topology;
 
+    node->nb_cores = topology->nhwcpus;
 }
 
 int _starpu_mpi_common_recv_is_ready(const struct _starpu_mp_node *mp_node)
 {
-    int res;
+    int res, tag;
     int flag = 0;
+    int id_proc;
+    MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
 
-    res = MPI_Iprobe(MPI_ANY_SOURCE, STARPU_MPI_MS_MSG_TAG, MPI_COMM_WORLD, &flag, MPI_STATUS_IGNORE);
+    if (id_proc == src_node_id)
+        tag = mp_node->mp_connection.mpi_remote_nodeid;
+    else
+        tag = id_proc;
+    
+    res = MPI_Iprobe(MPI_ANY_SOURCE, tag, MPI_COMM_WORLD, &flag, MPI_STATUS_IGNORE);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot test if we received a message !");
 
     return flag;
@@ -117,8 +119,16 @@ int _starpu_mpi_common_recv_is_ready(const struct _starpu_mp_node *mp_node)
 void _starpu_mpi_common_send(const struct _starpu_mp_node *node, void *msg, int len)
 {
     printf("envoi %d B to %d \n", len, node->mp_connection.mpi_remote_nodeid);
-    int res;
-    res = MPI_Send(msg, len, MPI_BYTE, node->mp_connection.mpi_remote_nodeid, STARPU_MPI_MS_MSG_TAG, MPI_COMM_WORLD);
+    int res, tag;
+    int id_proc;
+    MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
+
+    if (id_proc == src_node_id)
+        tag = node->mp_connection.mpi_remote_nodeid;
+    else
+        tag = id_proc;
+
+    res = MPI_Send(msg, len, MPI_BYTE, node->mp_connection.mpi_remote_nodeid, tag, MPI_COMM_WORLD);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
 }
 
@@ -126,23 +136,47 @@ void _starpu_mpi_common_send(const struct _starpu_mp_node *node, void *msg, int
 void _starpu_mpi_common_recv(const struct _starpu_mp_node *node, void *msg, int len)
 {
     printf("recv %d B from %d \n", len, node->mp_connection.mpi_remote_nodeid);
-    int res;
-    res = MPI_Recv(msg, len, MPI_BYTE, node->mp_connection.mpi_remote_nodeid, STARPU_MPI_MS_MSG_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+    int res, tag;
+    int id_proc;
+    MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
+
+    if (id_proc == src_node_id)
+        tag = node->mp_connection.mpi_remote_nodeid;
+    else
+        tag = id_proc;
+
+    res = MPI_Recv(msg, len, MPI_BYTE, node->mp_connection.mpi_remote_nodeid, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
 }
 
 /* SEND to any node */
 void _starpu_mpi_common_send_to_device(const struct _starpu_mp_node *node, int dst_devid, void *msg, int len)
 {   
-    int res;
-    res = MPI_Send(msg, len, MPI_BYTE, dst_devid, STARPU_MPI_MS_MSG_TAG, MPI_COMM_WORLD);
+    int res, tag;
+    int id_proc;
+    MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
+
+    if (id_proc == src_node_id)
+        tag = node->mp_connection.mpi_remote_nodeid;
+    else
+        tag = id_proc;
+
+    res = MPI_Send(msg, len, MPI_BYTE, dst_devid, tag, MPI_COMM_WORLD);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
 }
 
 /* RECV to any node */
 void _starpu_mpi_common_recv_from_device(const struct _starpu_mp_node *node, int src_devid, void *msg, int len)
 {
-    int res;
-    res = MPI_Recv(msg, len, MPI_BYTE, src_devid, STARPU_MPI_MS_MSG_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+    int res, tag;
+    int id_proc;
+    MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
+
+    if (id_proc == src_node_id)
+        tag = node->mp_connection.mpi_remote_nodeid;
+    else
+        tag = id_proc;
+
+    res = MPI_Recv(msg, len, MPI_BYTE, src_devid, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
 }