Browse Source

mpi cholesky: fix checking result

We have to run the whole syrk on node 0 and check only there.
Samuel Thibault 5 years ago
parent
commit
f724b5ebdc

+ 1 - 1
mpi/examples/matrix_decomposition/mpi_cholesky.c

@@ -58,7 +58,7 @@ int main(int argc, char **argv)
 #ifndef STARPU_SIMGRID
 	matrix_display(bmat, rank);
 
-	if (check)
+	if (check && rank == 0)
 		dw_cholesky_check_computation(bmat, rank, nodes, &correctness, &flops, 0.001);
 #endif
 

+ 12 - 16
mpi/examples/matrix_decomposition/mpi_cholesky_codelets.c

@@ -91,7 +91,7 @@ void dw_cholesky(float ***matA, unsigned ld, int rank, int nodes, double *timing
 		for(n = 0; n < nblocks ; n++)
 		{
 			int mpi_rank = my_distrib(m, n, nodes);
-			if (mpi_rank == rank)
+			if (mpi_rank == rank || (check && rank == 0))
 			{
 				//fprintf(stderr, "[%d] Owning data[%d][%d]\n", rank, n, m);
 				starpu_matrix_data_register(&data_handles[m][n], STARPU_MAIN_RAM, (uintptr_t)matA[m][n],
@@ -170,7 +170,7 @@ void dw_cholesky(float ***matA, unsigned ld, int rank, int nodes, double *timing
 		for(n = 0; n < nblocks ; n++)
 		{
 			/* Get back data on node 0 for the check */
-			if (check)
+			if (check && data_handles[m][n])
 				starpu_mpi_get_data_on_node(MPI_COMM_WORLD, data_handles[m][n], 0);
 
 			if (data_handles[m][n])
@@ -248,24 +248,20 @@ void dw_cholesky_check_computation(float ***matA, int rank, int nodes, int *corr
 	{
 		for (m = 0; m < nblocks; m++)
 		{
-			int mpi_rank = my_distrib(m, n, nodes);
-			if (mpi_rank == rank)
+			for (nn = BLOCKSIZE*n ; nn < BLOCKSIZE*(n+1); nn++)
 			{
-				for (nn = (size/nblocks)*n ; nn < (size/nblocks)*n+(size/nblocks); nn++)
+				for (mm = BLOCKSIZE*m ; mm < BLOCKSIZE*(m+1); mm++)
 				{
-					for (mm = (size/nblocks)*m ; mm < (size/nblocks)*m+(size/nblocks); mm++)
+					if (nn <= mm)
 					{
-						if (nn <= mm)
+						float orig = (1.0f/(1.0f+nn+mm)) + ((nn == mm)?1.0f*size:0.0f);
+						float err = fabsf(test_mat[mm +nn*size] - orig) / orig;
+						if (err > epsilon)
 						{
-							float orig = (1.0f/(1.0f+nn+mm)) + ((nn == mm)?1.0f*size:0.0f);
-							float err = fabsf(test_mat[mm +nn*size] - orig) / orig;
-							if (err > epsilon)
-							{
-								FPRINTF(stderr, "[%d] Error[%u, %u] --> %2.20f != %2.20f (err %2.20f)\n", rank, nn, mm, test_mat[mm +nn*size], orig, err);
-								*correctness = 0;
-								*flops = 0;
-								break;
-							}
+							FPRINTF(stderr, "[%d] Error[%u, %u] --> %2.20f != %2.20f (err %2.20f)\n", rank, nn, mm, test_mat[mm +nn*size], orig, err);
+							*correctness = 0;
+							*flops = 0;
+							break;
 						}
 					}
 				}