Browse Source

mpi/examples/reduction: minor fixes

Nathalie Furmento 13 years ago
parent
commit
a362e5e42c

+ 10 - 7
mpi/examples/reduction/mpi_reduction.c

@@ -90,6 +90,10 @@ int main(int argc, char **argv)
 		}
 		sum += x+1;
         }
+	if (my_rank == 0) {
+		dot = 14;
+		sum+= dot;
+	}
 
         for(x = 0; x < X; x++)
 	{
@@ -110,12 +114,9 @@ int main(int argc, char **argv)
 		}
 	}
 
-	if (my_rank == 0) {
-		dot = 14;
-		sum+= dot;
-	}
 	starpu_variable_data_register(&dot_handle, 0, (uintptr_t)&dot, sizeof(unsigned));
 	starpu_data_set_rank(dot_handle, 0);
+	starpu_data_set_tag(dot_handle, X+1);
 	starpu_data_set_reduction_methods(dot_handle, &redux_codelet, &init_codelet);
 
 	for (x = 0; x < X; x++)
@@ -126,6 +127,7 @@ int main(int argc, char **argv)
 				       STARPU_REDUX, dot_handle,
 				       0);
 	}
+	starpu_mpi_redux_data(MPI_COMM_WORLD, dot_handle);
 
         fprintf(stderr, "Waiting ...\n");
         starpu_task_wait_for_all();
@@ -142,10 +144,11 @@ int main(int argc, char **argv)
 	starpu_mpi_shutdown();
 	starpu_shutdown();
 
-	if (display)
+	if (display && my_rank == 0)
 	{
-                fprintf(stdout, "[%d] sum=%d\n", my_rank, sum);
-                fprintf(stdout, "[%d] dot=%d\n", my_rank, dot);
+                fprintf(stderr, "[%d] sum=%d\n", my_rank, sum);
+                fprintf(stderr, "[%d] dot=%d\n", my_rank, dot);
+		if (sum != dot) fprintf(stderr, "Error when computing reduction\n");
         }
 
 	return 0;

+ 6 - 6
mpi/examples/reduction/mpi_reduction_kernels.c

@@ -17,10 +17,10 @@
 #include <starpu.h>
 #include <mpi.h>
 
-#define _DISPLAY(fmt, args ...) { \
-		int rank; MPI_Comm_rank(MPI_COMM_WORLD, &rank);		\
-		fprintf(stderr, "[%d][%s] " fmt , rank, __func__ ,##args); \
-		fflush(stderr); }
+#define _DISPLAY(fmt, args ...) do { \
+		int _display_rank; MPI_Comm_rank(MPI_COMM_WORLD, &_display_rank);	\
+		fprintf(stderr, "[%d][%s] " fmt , _display_rank, __func__ ,##args); 	\
+		fflush(stderr); } while(0)
 
 /*
  *	Codelet to create a neutral element
@@ -40,8 +40,8 @@ void redux_cpu_func(void *descr[], void *cl_arg)
 	int *dota = (int *)STARPU_VARIABLE_GET_PTR(descr[0]);
 	int *dotb = (int *)STARPU_VARIABLE_GET_PTR(descr[1]);
 
-	_DISPLAY("Calling redux %d %d\n", *dota, *dotb);
 	*dota = *dota + *dotb;
+	_DISPLAY("Calling redux %d=%d+%d\n", *dota, *dota-*dotb, *dotb);
 }
 
 /*
@@ -53,6 +53,6 @@ void dot_cpu_func(void *descr[], void *cl_arg)
 	int *dot = (int *)STARPU_VARIABLE_GET_PTR(descr[1]);
 
 	*dot += *local_x;
-	_DISPLAY("Calling dot %d %d\n", *dot, *local_x);
+	_DISPLAY("Calling dot=%d=%d+%d\n", *dot, *dot-*local_x, *local_x);
 }