浏览代码

mpi/tests/insert_task.c: update stencil to do a mean on the values, and iterate the computation over the matrix several times

Nathalie Furmento 14 年之前
父节点
当前提交
03c8d19810
共有 1 个文件被更改,包括 23 次插入21 次删除
  1. 23 21
      mpi/tests/insert_task.c

+ 23 - 21
mpi/tests/insert_task.c

@@ -25,8 +25,8 @@ void stencil5_cpu(void *descr[], __attribute__ ((unused)) void *_args)
 	unsigned *xym1 = (unsigned *)STARPU_VARIABLE_GET_PTR(descr[3]);
 	unsigned *xyp1 = (unsigned *)STARPU_VARIABLE_GET_PTR(descr[4]);
 
-        fprintf(stdout, "VALUES: %d %d %d %d %d\n", *xy, *xm1y, *xp1y, *xym1, *xyp1);
-        *xy += *xm1y + *xp1y + *xym1 + *xyp1;
+        //        fprintf(stdout, "VALUES: %d %d %d %d %d\n", *xy, *xm1y, *xp1y, *xym1, *xyp1);
+        *xy = (*xy + *xm1y + *xp1y + *xym1 + *xyp1) / 5;
 }
 
 starpu_codelet stencil5_cl = {
@@ -35,8 +35,9 @@ starpu_codelet stencil5_cl = {
         .nbuffers = 5
 };
 
-#define X 3
-#define Y 5
+#define NITER 2000
+#define X     15
+#define Y     50
 
 /* Returns the MPI node number where data indexes index is */
 int my_distrib(int x, int y, int nb_nodes) {
@@ -49,8 +50,8 @@ int my_distrib(int x, int y, int nb_nodes) {
 
 int main(int argc, char **argv)
 {
-        int rank, size, x, y;
-        int value=10;
+        int rank, size, x, y, loop;
+        int value=0, mean=0;
         unsigned matrix[X][Y];
         starpu_data_handle data_handles[X][Y];
 
@@ -59,22 +60,18 @@ int main(int argc, char **argv)
 
         for(x = 0; x < X; x++) {
                 for (y = 0; y < Y; y++) {
-                        matrix[x][y] = value;
+                        matrix[x][y] = (rank+1)*10 + value;
                         value++;
+                        mean += matrix[x][y];
                 }
         }
-        for(x = 0; x < X; x++) {
-                for (y = 0; y < Y; y++) {
-                        fprintf(stdout, "%4d ", matrix[x][y]);
-                }
-                fprintf(stdout, "\n");
-        }
+        mean /= value;
 
         for(x = 0; x < X; x++) {
                 for (y = 0; y < Y; y++) {
                         int mpi_rank = my_distrib(x, y, size);
                         if (mpi_rank == rank) {
-                                fprintf(stderr, "[%d] Owning data[%d][%d]\n", rank, x, y);
+                                //fprintf(stderr, "[%d] Owning data[%d][%d]\n", rank, x, y);
                                 starpu_variable_data_register(&data_handles[x][y], 0, (uintptr_t)&(matrix[x][y]), sizeof(unsigned));
                         }
                         else if (rank == mpi_rank+1 || rank == mpi_rank-1) {
@@ -92,22 +89,27 @@ int main(int argc, char **argv)
                 }
         }
 
-	for (x = 1; x < X-1; x++) {
-                for (y = 1; y < Y-1; y++) {
-                        starpu_mpi_insert_task(MPI_COMM_WORLD, &stencil5_cl, STARPU_RW, data_handles[x][y],
-                                               STARPU_R, data_handles[x-1][y], STARPU_R, data_handles[x+1][y],
-                                               STARPU_R, data_handles[x][y-1], STARPU_R, data_handles[x][y+1],
-                                               0);
+        for(loop=0 ; loop<NITER; loop++) {
+                for (x = 1; x < X-1; x++) {
+                        for (y = 1; y < Y-1; y++) {
+                                starpu_mpi_insert_task(MPI_COMM_WORLD, &stencil5_cl, STARPU_RW, data_handles[x][y],
+                                                       STARPU_R, data_handles[x-1][y], STARPU_R, data_handles[x+1][y],
+                                                       STARPU_R, data_handles[x][y-1], STARPU_R, data_handles[x][y+1],
+                                                       0);
+                        }
                 }
         }
+        fprintf(stderr, "Waiting ...\n");
         starpu_task_wait_for_all();
 
 	starpu_mpi_shutdown();
 	starpu_shutdown();
 
+        fprintf(stdout, "[%d] mean=%d\n", rank, mean);
         for(x = 0; x < X; x++) {
+                fprintf(stdout, "[%d] ", rank);
                 for (y = 0; y < Y; y++) {
-                        fprintf(stdout, "%4d ", matrix[x][y]);
+                        fprintf(stdout, "%3d ", matrix[x][y]);
                 }
                 fprintf(stdout, "\n");
         }