瀏覽代碼

mpi/examples/stencil: add parameters for command line: -iter and -display

Nathalie Furmento 14 年之前
父節點
當前提交
655a5a5230
共有 1 個文件被更改,包括 30 次插入10 次删除
  1. 30 10
      mpi/examples/stencil/stencil5.c

+ 30 - 10
mpi/examples/stencil/stencil5.c

@@ -35,9 +35,12 @@ starpu_codelet stencil5_cl = {
         .nbuffers = 5
 };
 
-#define NITER 2000
-#define X     15
-#define Y     50
+#define NITER_DEF 2000
+#define X         15
+#define Y         50
+
+int display = 0;
+int niter = NITER_DEF;
 
 /* Returns the MPI node number where data indexes index is */
 int my_distrib(int x, int y, int nb_nodes) {
@@ -48,6 +51,20 @@ int my_distrib(int x, int y, int nb_nodes) {
 }
 
 
+static void parse_args(int argc, char **argv)
+{
+	int i;
+	for (i = 1; i < argc; i++) {
+		if (strcmp(argv[i], "-iter") == 0) {
+			char *argptr;
+			niter = strtol(argv[++i], &argptr, 10);
+		}
+		if (strcmp(argv[i], "-display") == 0) {
+			display = 1;
+		}
+	}
+}
+
 int main(int argc, char **argv)
 {
         int rank, size, x, y, loop;
@@ -57,6 +74,7 @@ int main(int argc, char **argv)
 
 	starpu_init(NULL);
 	starpu_mpi_initialize_extended(1, &rank, &size);
+        parse_args(argc, argv);
 
         for(x = 0; x < X; x++) {
                 for (y = 0; y < Y; y++) {
@@ -88,7 +106,7 @@ int main(int argc, char **argv)
                 }
         }
 
-        for(loop=0 ; loop<NITER; loop++) {
+        for(loop=0 ; loop<niter; loop++) {
                 for (x = 1; x < X-1; x++) {
                         for (y = 1; y < Y-1; y++) {
                                 starpu_mpi_insert_task(MPI_COMM_WORLD, &stencil5_cl, STARPU_RW, data_handles[x][y],
@@ -104,13 +122,15 @@ int main(int argc, char **argv)
 	starpu_mpi_shutdown();
 	starpu_shutdown();
 
-        fprintf(stdout, "[%d] mean=%d\n", rank, mean);
-        for(x = 0; x < X; x++) {
-                fprintf(stdout, "[%d] ", rank);
-                for (y = 0; y < Y; y++) {
-                        fprintf(stdout, "%3d ", matrix[x][y]);
+        if (display) {
+                fprintf(stdout, "[%d] mean=%d\n", rank, mean);
+                for(x = 0; x < X; x++) {
+                        fprintf(stdout, "[%d] ", rank);
+                        for (y = 0; y < Y; y++) {
+                                fprintf(stdout, "%3d ", matrix[x][y]);
+                        }
+                        fprintf(stdout, "\n");
                 }
-                fprintf(stdout, "\n");
         }
 
 	return 0;