Browse Source

mpi/examples/reduction: simplify test

Nathalie Furmento 13 years ago
parent
commit
49cbf57bfc

+ 28 - 54
mpi/examples/reduction/mpi_reduction.c

@@ -44,8 +44,8 @@ static struct starpu_codelet dot_codelet =
 {
 	.where = STARPU_CPU,
 	.cpu_funcs = {dot_cpu_func, NULL},
-	.nbuffers = 3,
-	.modes = {STARPU_R, STARPU_R, STARPU_REDUX}
+	.nbuffers = 2,
+	.modes = {STARPU_R, STARPU_REDUX}
 };
 
 static void parse_args(int argc, char **argv)
@@ -61,21 +61,18 @@ static void parse_args(int argc, char **argv)
 }
 
 /* Returns the MPI node number where data indexes index is */
-int my_distrib(int x, int y, int nb_nodes)
+int my_distrib(int x, int nb_nodes)
 {
-	/* Block distrib */
-	return ((int)(x / sqrt(nb_nodes) + (y / sqrt(nb_nodes)) * sqrt(nb_nodes))) % nb_nodes;
+	return x % nb_nodes;
 }
 
 int main(int argc, char **argv)
 {
-        int my_rank, size, x, y;
+        int my_rank, size, x;
         int value=0;
-        unsigned matrixA[X][Y];
-        unsigned matrixB[X][Y];
+        unsigned vector[X];
 	unsigned dot, sum=0;
-        starpu_data_handle_t handlesA[X][Y];
-        starpu_data_handle_t handlesB[X][Y];
+        starpu_data_handle_t handles[X];
 	starpu_data_handle_t dot_handle;
 
 	starpu_init(NULL);
@@ -84,41 +81,26 @@ int main(int argc, char **argv)
 
         for(x = 0; x < X; x++)
 	{
-                for (y = 0; y < Y; y++)
-		{
-                        matrixA[x][y] = value;
-                        matrixB[x][y] = 10+value;
-                        value++;
-                        sum += matrixA[x][y] + matrixB[x][y];
-                }
+		vector[x] = x;
+		sum += x;
         }
 
         for(x = 0; x < X; x++)
 	{
-                for (y = 0; y < Y; y++)
+		int mpi_rank = my_distrib(x, size);
+		if (mpi_rank == my_rank)
+		{
+			/* Owning data */
+			starpu_variable_data_register(&handles[x], 0, (uintptr_t)&(vector[x]), sizeof(unsigned));
+		}
+		else
+		{
+			starpu_variable_data_register(&handles[x], -1, (uintptr_t)NULL, sizeof(unsigned));
+		}
+		if (handles[x])
 		{
-                        int mpi_rank = my_distrib(x, y, size);
-                        if (mpi_rank == my_rank)
-			{
-				/* Owning data */
-				starpu_variable_data_register(&handlesA[x][y], 0, (uintptr_t)&(matrixA[x][y]), sizeof(unsigned));
-				starpu_variable_data_register(&handlesB[x][y], 0, (uintptr_t)&(matrixB[x][y]), sizeof(unsigned));
-			}
-			else
-			{
-				starpu_variable_data_register(&handlesA[x][y], -1, (uintptr_t)NULL, sizeof(unsigned));
-				starpu_variable_data_register(&handlesB[x][y], -1, (uintptr_t)NULL, sizeof(unsigned));
-			}
-			if (handlesA[x][y])
-			{
-                                starpu_data_set_rank(handlesA[x][y], mpi_rank);
-                                starpu_data_set_tag(handlesA[x][y], (y*X)+x);
-			}
-			if (handlesB[x][y])
-			{
-                                starpu_data_set_rank(handlesB[x][y], mpi_rank);
-                                starpu_data_set_tag(handlesB[x][y], (y*X)+x);
-			}
+			starpu_data_set_rank(handles[x], mpi_rank);
+			starpu_data_set_tag(handles[x], x);
 		}
 	}
 
@@ -128,15 +110,11 @@ int main(int argc, char **argv)
 
 	for (x = 0; x < X; x++)
 	{
-		for (y = 0; y < Y ; y++)
-		{
-			starpu_mpi_insert_task(MPI_COMM_WORLD,
-					       &dot_codelet,
-					       STARPU_R, handlesA[x][y],
-					       STARPU_R, handlesB[x][y],
-					       STARPU_REDUX, dot_handle,
-					       0);
-		}
+		starpu_mpi_insert_task(MPI_COMM_WORLD,
+				       &dot_codelet,
+				       STARPU_R, handles[x],
+				       STARPU_REDUX, dot_handle,
+				       0);
 	}
 
         fprintf(stderr, "Waiting ...\n");
@@ -144,11 +122,7 @@ int main(int argc, char **argv)
 
         for(x = 0; x < X; x++)
 	{
-                for (y = 0; y < Y; y++)
-		{
-			if (handlesA[x][y]) starpu_data_unregister(handlesA[x][y]);
-			if (handlesB[x][y]) starpu_data_unregister(handlesB[x][y]);
-		}
+		if (handles[x]) starpu_data_unregister(handles[x]);
 	}
 	if (dot_handle)
 	{

+ 2 - 3
mpi/examples/reduction/mpi_reduction_kernels.c

@@ -42,9 +42,8 @@ void redux_cpu_func(void *descr[], void *cl_arg)
 void dot_cpu_func(void *descr[], void *cl_arg)
 {
 	int *local_x = (int *)STARPU_VARIABLE_GET_PTR(descr[0]);
-	int *local_y = (int *)STARPU_VARIABLE_GET_PTR(descr[1]);
-	int *dot = (int *)STARPU_VARIABLE_GET_PTR(descr[2]);
+	int *dot = (int *)STARPU_VARIABLE_GET_PTR(descr[1]);
 
-	*dot += *local_x + *local_y;
+	*dot += *local_x;
 }