浏览代码

Add starpu_mpi_get_data_on_node_detached which does not wait for termination

Samuel Thibault 13 年之前
父节点
当前提交
2f6989798d
共有 4 个文件被更改,包括 44 次插入27 次删除
  1. 2 1
      mpi/starpu_mpi.h
  2. 23 6
      mpi/starpu_mpi_insert_task.c
  3. 18 19
      mpi/tests/insert_task_owner2.c
  4. 1 1
      mpi/tests/insert_task_owner_data.c

+ 2 - 1
mpi/starpu_mpi.h

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2009-2011  Université de Bordeaux 1
+ * Copyright (C) 2009-2012  Université de Bordeaux 1
  * Copyright (C) 2010, 2011  Centre National de la Recherche Scientifique
  *
  * StarPU is free software; you can redistribute it and/or modify
@@ -42,6 +42,7 @@ int starpu_mpi_shutdown(void);
 
 int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...);
 void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle, int node);
+void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg);
 void starpu_mpi_redux_data(MPI_Comm comm, starpu_data_handle_t data_handle);
 
 int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, int root, MPI_Comm comm);

+ 23 - 6
mpi/starpu_mpi_insert_task.c

@@ -410,6 +410,26 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
         return 0;
 }
 
+void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg)
+{
+        int me, rank, tag;
+
+        rank = starpu_data_get_rank(data_handle);
+        tag = starpu_data_get_tag(data_handle);
+	MPI_Comm_rank(comm, &me);
+
+        if (node == rank) return;
+
+        if (me == node)
+        {
+		starpu_mpi_irecv_detached(data_handle, rank, tag, comm, callback, arg);
+        }
+        else if (me == rank)
+        {
+		starpu_mpi_isend_detached(data_handle, node, tag, comm, callback, arg);
+        }
+}
+
 void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle, int node)
 {
         int me, rank, tag;
@@ -422,16 +442,13 @@ void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle
 
         if (me == node)
         {
-                starpu_mpi_irecv_detached(data_handle, rank, tag, comm, NULL, NULL);
+                MPI_Status status;
+                starpu_mpi_recv(data_handle, rank, tag, comm, &status);
         }
         else if (me == rank)
         {
-                starpu_mpi_isend_detached(data_handle, node, tag, comm, NULL, NULL);
+                starpu_mpi_send(data_handle, node, tag, comm);
         }
-#ifdef STARPU_DEVEL
-#warning TODO: wait for completion of these communication only instead
-#endif
-        starpu_task_wait_for_all();
 }
 
 void starpu_mpi_redux_data(MPI_Comm comm, starpu_data_handle_t data_handle)

+ 18 - 19
mpi/tests/insert_task_owner2.c

@@ -59,26 +59,15 @@ int main(int argc, char **argv)
 	ret = starpu_mpi_initialize_extended(&rank, &size);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_initialize_extended");
 
-        if (rank > 1)
-	{
-                starpu_mpi_shutdown();
-                starpu_shutdown();
-                return STARPU_TEST_SKIPPED;
-        }
-
         if (rank == 0)
 	{
                 for(i=0 ; i<3 ; i++)
 		{
                         x[i] = 10*(i+1);
                         starpu_variable_data_register(&data_handles[i], 0, (uintptr_t)&x[i], sizeof(x[i]));
-                        starpu_data_set_rank(data_handles[i], rank);
-			starpu_data_set_tag(data_handles[i], i);
                 }
                 y = -1;
                 starpu_variable_data_register(&data_handles[3], -1, (uintptr_t)NULL, sizeof(int));
-                starpu_data_set_rank(data_handles[3], 1);
-		starpu_data_set_tag(data_handles[3], 3);
         }
         else if (rank == 1)
 	{
@@ -86,16 +75,24 @@ int main(int argc, char **argv)
 		{
                         x[i] = -1;
                         starpu_variable_data_register(&data_handles[i], -1, (uintptr_t)NULL, sizeof(int));
-                        starpu_data_set_rank(data_handles[i], 0);
-			starpu_data_set_tag(data_handles[i], i);
                 }
                 y=200;
                 starpu_variable_data_register(&data_handles[3], 0, (uintptr_t)&y, sizeof(int));
-                starpu_data_set_rank(data_handles[3], rank);
-		starpu_data_set_tag(data_handles[3], 3);
-        }
+        } else
+	{
+                for(i=0 ; i<4 ; i++)
+                        starpu_variable_data_register(&data_handles[i], -1, (uintptr_t)NULL, sizeof(int));
+	}
         FPRINTF(stderr, "[%d][init] VALUES: %d %d %d %d\n", rank, x[0], x[1], x[2], y);
 
+	for(i=0 ; i<3 ; i++)
+	{
+		starpu_data_set_rank(data_handles[i], 0);
+		starpu_data_set_tag(data_handles[i], i);
+	}
+	starpu_data_set_rank(data_handles[3], 1);
+	starpu_data_set_tag(data_handles[3], 3);
+
         err = starpu_mpi_insert_task(MPI_COMM_WORLD, &mycodelet,
                                      STARPU_R, data_handles[0], STARPU_RW, data_handles[1],
                                      STARPU_W, data_handles[2],
@@ -107,9 +104,11 @@ int main(int argc, char **argv)
         int *values = malloc(4 * sizeof(int *));
         for(i=0 ; i<4 ; i++)
 	{
-                starpu_mpi_get_data_on_node(MPI_COMM_WORLD, data_handles[i], 0);
-                starpu_data_acquire(data_handles[i], STARPU_R);
-                values[i] = *((int *)starpu_mpi_handle_to_ptr(data_handles[i]));
+                starpu_mpi_get_data_on_node_detached(MPI_COMM_WORLD, data_handles[i], 0, NULL, NULL);
+		if (rank == 0) {
+			starpu_data_acquire(data_handles[i], STARPU_R);
+			values[i] = *((int *)starpu_mpi_handle_to_ptr(data_handles[i]));
+		}
         }
         FPRINTF(stderr, "[%d][local ptr] VALUES: %d %d %d %d\n", rank, values[0], values[1], values[2], values[3]);
         FPRINTF(stderr, "[%d][end] VALUES: %d %d %d %d\n", rank, x[0], x[1], x[2], y);

+ 1 - 1
mpi/tests/insert_task_owner_data.c

@@ -81,7 +81,7 @@ int main(int argc, char **argv)
 
         for(i=0 ; i<2 ; i++)
 	{
-                starpu_mpi_get_data_on_node(MPI_COMM_WORLD, data_handles[i], 0);
+                starpu_mpi_get_data_on_node_detached(MPI_COMM_WORLD, data_handles[i], 0, NULL, NULL);
 		if (rank == 0) {
 			starpu_data_acquire(data_handles[i], STARPU_R);
 			values[i] = *((int *)starpu_mpi_handle_to_ptr(data_handles[i]));