Przeglądaj źródła

mpi cholesky: move main loop into its own function

Samuel Thibault 4 lata temu
rodzic
commit
8e07ec3ec5

+ 48 - 43
mpi/examples/matrix_decomposition/mpi_cholesky_codelets.c

@@ -68,6 +68,53 @@ static struct starpu_codelet cl22 =
 	.color = 0x00ff00,
 };
 
+static void run_cholesky(starpu_data_handle_t **data_handles, int rank, int nodes)
+{
+	unsigned k, m, n;
+	unsigned unbound_prio = STARPU_MAX_PRIO == INT_MAX && STARPU_MIN_PRIO == INT_MIN;
+
+	for (k = 0; k < nblocks; k++)
+	{
+		starpu_iteration_push(k);
+
+		starpu_mpi_task_insert(MPI_COMM_WORLD, &cl11,
+				       STARPU_PRIORITY, noprio ? STARPU_DEFAULT_PRIO : unbound_prio ? (int)(2*nblocks - 2*k) : STARPU_MAX_PRIO,
+				       STARPU_RW, data_handles[k][k],
+				       0);
+
+		for (m = k+1; m<nblocks; m++)
+		{
+			starpu_mpi_task_insert(MPI_COMM_WORLD, &cl21,
+					       STARPU_PRIORITY, noprio ? STARPU_DEFAULT_PRIO : unbound_prio ? (int)(2*nblocks - 2*k - m) : (m == k+1)?STARPU_MAX_PRIO:STARPU_DEFAULT_PRIO,
+					       STARPU_R, data_handles[k][k],
+					       STARPU_RW, data_handles[m][k],
+					       0);
+
+			starpu_mpi_cache_flush(MPI_COMM_WORLD, data_handles[k][k]);
+			if (my_distrib(k, k, nodes) == rank)
+				starpu_data_wont_use(data_handles[k][k]);
+
+			for (n = k+1; n<nblocks; n++)
+			{
+				if (n <= m)
+				{
+					starpu_mpi_task_insert(MPI_COMM_WORLD, &cl22,
+							       STARPU_PRIORITY, noprio ? STARPU_DEFAULT_PRIO : unbound_prio ? (int)(2*nblocks - 2*k - m - n) : ((n == k+1) && (m == k+1))?STARPU_MAX_PRIO:STARPU_DEFAULT_PRIO,
+							       STARPU_R, data_handles[n][k],
+							       STARPU_R, data_handles[m][k],
+							       STARPU_RW | STARPU_COMMUTE, data_handles[m][n],
+							       0);
+				}
+			}
+
+			starpu_mpi_cache_flush(MPI_COMM_WORLD, data_handles[m][k]);
+			if (my_distrib(m, k, nodes) == rank)
+				starpu_data_wont_use(data_handles[m][k]);
+		}
+		starpu_iteration_pop();
+	}
+}
+
 /*
  *	code to bootstrap the factorization
  *	and construct the DAG
@@ -79,8 +126,6 @@ void dw_cholesky(float ***matA, unsigned ld, int rank, int nodes, double *timing
 	starpu_data_handle_t **data_handles;
 	unsigned k, m, n;
 
-	unsigned unbound_prio = STARPU_MAX_PRIO == INT_MAX && STARPU_MIN_PRIO == INT_MIN;
-
 	/* create all the DAG nodes */
 
 	data_handles = malloc(nblocks*sizeof(starpu_data_handle_t *));
@@ -119,50 +164,10 @@ void dw_cholesky(float ***matA, unsigned ld, int rank, int nodes, double *timing
 	starpu_mpi_barrier(MPI_COMM_WORLD);
 	start = starpu_timing_now();
 
-	for (k = 0; k < nblocks; k++)
-	{
-		starpu_iteration_push(k);
-
-		starpu_mpi_task_insert(MPI_COMM_WORLD, &cl11,
-				       STARPU_PRIORITY, noprio ? STARPU_DEFAULT_PRIO : unbound_prio ? (int)(2*nblocks - 2*k) : STARPU_MAX_PRIO,
-				       STARPU_RW, data_handles[k][k],
-				       0);
-
-		for (m = k+1; m<nblocks; m++)
-		{
-			starpu_mpi_task_insert(MPI_COMM_WORLD, &cl21,
-					       STARPU_PRIORITY, noprio ? STARPU_DEFAULT_PRIO : unbound_prio ? (int)(2*nblocks - 2*k - m) : (m == k+1)?STARPU_MAX_PRIO:STARPU_DEFAULT_PRIO,
-					       STARPU_R, data_handles[k][k],
-					       STARPU_RW, data_handles[m][k],
-					       0);
-
-			starpu_mpi_cache_flush(MPI_COMM_WORLD, data_handles[k][k]);
-			if (my_distrib(k, k, nodes) == rank)
-				starpu_data_wont_use(data_handles[k][k]);
-
-			for (n = k+1; n<nblocks; n++)
-			{
-				if (n <= m)
-				{
-					starpu_mpi_task_insert(MPI_COMM_WORLD, &cl22,
-							       STARPU_PRIORITY, noprio ? STARPU_DEFAULT_PRIO : unbound_prio ? (int)(2*nblocks - 2*k - m - n) : ((n == k+1) && (m == k+1))?STARPU_MAX_PRIO:STARPU_DEFAULT_PRIO,
-							       STARPU_R, data_handles[n][k],
-							       STARPU_R, data_handles[m][k],
-							       STARPU_RW | STARPU_COMMUTE, data_handles[m][n],
-							       0);
-				}
-			}
-
-			starpu_mpi_cache_flush(MPI_COMM_WORLD, data_handles[m][k]);
-			if (my_distrib(m, k, nodes) == rank)
-				starpu_data_wont_use(data_handles[m][k]);
-		}
-		starpu_iteration_pop();
-	}
+	run_cholesky(data_handles, rank, nodes);
 
 	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
 	starpu_mpi_barrier(MPI_COMM_WORLD);
-
 	end = starpu_timing_now();
 
 	for (m = 0; m < nblocks; m++)