Selaa lähdekoodia

mpi: re-implement reduction

     src: new function void starpu_mpi_redux_data(MPI_Comm comm, starpu_data_handle_t data_handle)
	 which performs a reduction on the owner node of the data. The
	 data must have been used previously in codelet with the
	 STARPU_REDUX mode.

     examples/reduction: always display result
Nathalie Furmento 13 vuotta sitten
vanhempi
commit
4391eac2ab
3 muutettua tiedostoa jossa 45 lisäystä ja 20 poistoa
  1. 1 17
      mpi/examples/reduction/mpi_reduction.c
  2. 1 0
      mpi/starpu_mpi.h
  3. 43 3
      mpi/starpu_mpi_insert_task.c

+ 1 - 17
mpi/examples/reduction/mpi_reduction.c

@@ -19,8 +19,6 @@
 
 #define X         7
 
-int display = 0;
-
 extern void init_cpu_func(void *descr[], void *cl_arg);
 extern void redux_cpu_func(void *descr[], void *cl_arg);
 extern void dot_cpu_func(void *descr[], void *cl_arg);
@@ -50,18 +48,6 @@ static struct starpu_codelet dot_codelet =
 	.name = "dot_codelet"
 };
 
-static void parse_args(int argc, char **argv)
-{
-	int i;
-	for (i = 1; i < argc; i++)
-	{
-		if (strcmp(argv[i], "-display") == 0)
-		{
-			display = 1;
-		}
-	}
-}
-
 /* Returns the MPI node number where data indexes index is */
 int my_distrib(int x, int nb_nodes)
 {
@@ -71,7 +57,6 @@ int my_distrib(int x, int nb_nodes)
 int main(int argc, char **argv)
 {
         int my_rank, size, x;
-        int value=0;
         unsigned vector[X];
 	unsigned dot, sum=0;
         starpu_data_handle_t handles[X];
@@ -79,7 +64,6 @@ int main(int argc, char **argv)
 
 	starpu_init(NULL);
 	starpu_mpi_initialize_extended(&my_rank, &size);
-        parse_args(argc, argv);
 
         for(x = 0; x < X; x++)
 	{
@@ -144,7 +128,7 @@ int main(int argc, char **argv)
 	starpu_mpi_shutdown();
 	starpu_shutdown();
 
-	if (display && my_rank == 0)
+	if (my_rank == 0)
 	{
                 fprintf(stderr, "[%d] sum=%d\n", my_rank, sum);
                 fprintf(stderr, "[%d] dot=%d\n", my_rank, dot);

+ 1 - 0
mpi/starpu_mpi.h

@@ -42,6 +42,7 @@ int starpu_mpi_shutdown(void);
 
 int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...);
 void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle, int node);
+void starpu_mpi_redux_data(MPI_Comm comm, starpu_data_handle_t data_handle);
 
 int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, int root, MPI_Comm comm);
 int starpu_mpi_gather_detached(starpu_data_handle_t *data_handles, int count, int root, MPI_Comm comm);

+ 43 - 3
mpi/starpu_mpi_insert_task.c

@@ -186,7 +186,7 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
 		}
 		free(size_on_nodes);
 		if (xrank != -1) {
-			_STARPU_MPI_DEBUG("Node %d is having the most REDUX data\n", xrank);
+			_STARPU_MPI_DEBUG("Node %d is having the most R data\n", xrank);
 			do_execute = 1;
 		}
 	}
@@ -209,8 +209,6 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
 		dest = xrank;
 	}
 
-	_STARPU_MPI_DEBUG("Executing %d - Sending to node %d\n", do_execute, dest);
-
         /* Send and receive data as requested */
 	va_start(varg_list, codelet);
 	while ((arg_type = va_arg(varg_list, int)) != 0) {
@@ -431,3 +429,45 @@ void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle
         }
         starpu_task_wait_for_all();
 }
+
+void starpu_mpi_redux_data(MPI_Comm comm, starpu_data_handle_t data_handle)
+{
+        int me, rank, tag, nb_nodes;
+
+        rank = starpu_data_get_rank(data_handle);
+        tag = starpu_data_get_tag(data_handle);
+
+	MPI_Comm_rank(comm, &me);
+	MPI_Comm_size(comm, &nb_nodes);
+
+	_STARPU_MPI_DEBUG("Doing reduction for data %p on node %d with %d nodes ...\n", data_handle, rank, nb_nodes);
+
+	// need to count how many nodes have the data in redux mode
+	if (me == rank) {
+		int i;
+
+		STARPU_ASSERT(data_handle->ops->allocate_new_data);
+
+		for(i=0 ; i<nb_nodes ; i++) {
+			if (i != rank) {
+				void *data_interface;
+				starpu_data_handle_t new_handle;
+
+				data_handle->ops->allocate_new_data(data_handle, &data_interface);
+				starpu_data_register(&new_handle, -1, data_interface, data_handle->ops);
+
+				_STARPU_MPI_DEBUG("Receiving redux handle from %d in %p ...\n", i, new_handle);
+
+				starpu_mpi_irecv_detached(new_handle, i, tag, comm, NULL, NULL);
+				starpu_insert_task(data_handle->redux_cl,
+						   STARPU_RW, data_handle,
+						   STARPU_R, new_handle,
+						   0);
+			}
+		}
+	}
+	else {
+		_STARPU_MPI_DEBUG("Sending redux handle to %d ...\n", rank);
+		starpu_mpi_isend_detached(data_handle, rank, tag, comm, NULL, NULL);
+	}
+}