Browse Source

fix MPI LU

Cédric Augonnet 15 years ago
parent
commit
c1f052654a

+ 26 - 27
mpi/examples/mpi_lu/plu_example.c

@@ -95,8 +95,7 @@ static void fill_block_with_random(TYPE *blockptr, unsigned size, unsigned nbloc
 	for (j = 0; j < block_size; j++)
 	for (i = 0; i < block_size; i++)
 	{
-	//	blockptr[i+j*block_size] = (TYPE)drand48();
-		blockptr[i+j*block_size] = (i == j)?2.0:(TYPE)j;
+		blockptr[i+j*block_size] = (TYPE)drand48();
 	}
 }
 
@@ -199,8 +198,7 @@ static void init_matrix(int rank)
 			size/nblocks, size/nblocks, size/nblocks, sizeof(TYPE));
 	}
 
-
-	display_all_blocks(nblocks, size/nblocks);
+//	display_all_blocks(nblocks, size/nblocks);
 }
 
 int get_block_rank(unsigned i, unsigned j)
@@ -280,9 +278,8 @@ int main(int argc, char **argv)
 
 	init_matrix(rank);
 
-	TYPE *a_r = STARPU_PLU(reconstruct_matrix)(size, nblocks);
-	STARPU_PLU(display_data_content)(a_r, size);
-	
+	TYPE *a_r;
+//	STARPU_PLU(display_data_content)(a_r, size);
 
 	TYPE *x, *y;
 
@@ -295,26 +292,19 @@ int main(int argc, char **argv)
 
 		y = calloc(size, sizeof(TYPE));
 		STARPU_ASSERT(y);
-		
+
+		a_r = STARPU_PLU(reconstruct_matrix)(size, nblocks);
+	
 		if (rank == 0)
 		{
-			fprintf(stderr, "Compute AX = B\n");
-
 			for (ind = 0; ind < size; ind++)
 			{
-				x[ind] = (TYPE)ind;
-//				x[ind] = (TYPE)drand48();
+				x[ind] = (TYPE)drand48();
 				y[ind] = (TYPE)0.0;
 			}
 		}
 
 		STARPU_PLU(compute_ax)(size, x, y, nblocks, rank);
-
-		if (rank == 0)
-		for (ind = 0; ind < STARPU_MIN(10, size); ind++)
-		{
-			fprintf(stderr, "y[%d] = %f\n", ind, (float)y[ind]);
-		}
 	}
 
 	barrier_ret = MPI_Barrier(MPI_COMM_WORLD);
@@ -359,12 +349,20 @@ int main(int argc, char **argv)
 	 *	Test Result Correctness
 	 */
 
-	STARPU_PLU(compute_lu_matrix)(size, nblocks);
-
 	TYPE *y2;
 
 	if (check)
 	{
+		/*
+		 *	Compute || A - LU ||
+		 */
+
+		STARPU_PLU(compute_lu_matrix)(size, nblocks, a_r);
+
+		/*
+		 *	Compute || Ax - LUx ||
+		 */
+
 		unsigned ind;
 
 		y2 = calloc(size, sizeof(TYPE));
@@ -372,8 +370,6 @@ int main(int argc, char **argv)
 		
 		if (rank == 0)
 		{
-			fprintf(stderr, "Compute LUX = B2\n");
-
 			for (ind = 0; ind < size; ind++)
 			{
 				y2[ind] = (TYPE)0.0;
@@ -382,11 +378,14 @@ int main(int argc, char **argv)
 
 		STARPU_PLU(compute_lux)(size, x, y2, nblocks, rank);
 
-		if (rank == 0)
-		for (ind = 0; ind < STARPU_MIN(10, size); ind++)
-		{
-			fprintf(stderr, "y[%d] = %f\n", ind, (float)y2[ind]);
-		}
+		/* Compute y2 = y2 - y */
+	        CPU_AXPY(size, -1.0, y, 1, y2, 1);
+	
+	        TYPE err = CPU_ASUM(size, y2, 1);
+	        int max = CPU_IAMAX(size, y2, 1);
+	
+	        fprintf(stderr, "(A - LU)X Avg error : %e\n", err/(size*size));
+	        fprintf(stderr, "(A - LU)X Max error : %e\n", y2[max]);
 	}
 
 	/*

+ 100 - 51
mpi/examples/mpi_lu/plu_solve.c

@@ -54,38 +54,64 @@ void STARPU_PLU(extract_upper)(unsigned block_size, TYPE *inblock, TYPE *outbloc
 	}
 }
 
-static STARPU_PLU(compute_ax_block_upper)(unsigned size, unsigned nblocks,
-				 TYPE *block_data, TYPE *sub_x, TYPE *sub_y)
+void STARPU_PLU(extract_lower)(unsigned block_size, TYPE *inblock, TYPE *outblock)
 {
-	unsigned block_size = size/nblocks;
-
-	fprintf(stderr, "KEEP UPPER\n");
-	STARPU_PLU(display_data_content)(block_data, block_size);
-
-	/* Take a copy of the upper part of the diagonal block */
-	TYPE *upper_block_copy = calloc((block_size)*(block_size), sizeof(TYPE));
-	STARPU_PLU(extract_upper)(block_size, block_data, upper_block_copy);
-		
-	STARPU_PLU(display_data_content)(upper_block_copy, block_size);
-
-	STARPU_PLU(compute_ax_block)(size/nblocks, upper_block_copy, sub_x, sub_y);
-	
-	free(upper_block_copy);
+	unsigned li, lj;
+	for (lj = 0; lj < block_size; lj++)
+	{
+		for (li = 0; li <= lj; li++)
+		{
+			outblock[lj + li*block_size] = inblock[lj + li*block_size];
+		}
+	}
 }
 
+#if 0
+void STARPU_PLU(extract_upper)(unsigned block_size, TYPE *inblock, TYPE *outblock)
+{
+	unsigned li, lj;
+	for (lj = 0; lj < block_size; lj++)
+	{
+		for (li = lj; li < block_size; li++)
+		{
+			outblock[lj + li*block_size] = inblock[lj + li*block_size];
+		}
+	}
+}
 
 void STARPU_PLU(extract_lower)(unsigned block_size, TYPE *inblock, TYPE *outblock)
 {
 	unsigned li, lj;
 	for (lj = 0; lj < block_size; lj++)
 	{
-		for (li = 0; li <= lj; li++)
+		for (li = 0; li < lj; li++)
 		{
 			outblock[lj + li*block_size] = inblock[lj + li*block_size];
 		}
+
+		outblock[lj*(block_size + 1)] = (TYPE)1.0;
 	}
 }
+#endif
 
+static STARPU_PLU(compute_ax_block_upper)(unsigned size, unsigned nblocks,
+				 TYPE *block_data, TYPE *sub_x, TYPE *sub_y)
+{
+	unsigned block_size = size/nblocks;
+
+//	fprintf(stderr, "KEEP UPPER\n");
+//	STARPU_PLU(display_data_content)(block_data, block_size);
+
+	/* Take a copy of the upper part of the diagonal block */
+	TYPE *upper_block_copy = calloc((block_size)*(block_size), sizeof(TYPE));
+	STARPU_PLU(extract_upper)(block_size, block_data, upper_block_copy);
+		
+//	STARPU_PLU(display_data_content)(upper_block_copy, block_size);
+
+	STARPU_PLU(compute_ax_block)(size/nblocks, upper_block_copy, sub_x, sub_y);
+	
+	free(upper_block_copy);
+}
 
 TYPE *STARPU_PLU(reconstruct_matrix)(unsigned size, unsigned nblocks)
 {
@@ -135,25 +161,57 @@ static TYPE *reconstruct_upper(unsigned size, unsigned nblocks)
 }
 
 
-void STARPU_PLU(compute_lu_matrix)(unsigned size, unsigned nblocks)
+void STARPU_PLU(compute_lu_matrix)(unsigned size, unsigned nblocks, TYPE *Asaved)
 {
-	fprintf(stderr, "ALL\n\n");
+//	fprintf(stderr, "ALL\n\n");
 	TYPE *all_r = STARPU_PLU(reconstruct_matrix)(size, nblocks);
-	STARPU_PLU(display_data_content)(all_r, size);
+//	STARPU_PLU(display_data_content)(all_r, size);
+
+        TYPE *L = malloc((size_t)size*size*sizeof(TYPE));
+        TYPE *U = malloc((size_t)size*size*sizeof(TYPE));
 
-	fprintf(stderr, "\nLOWER\n");
-	TYPE *lower_r = reconstruct_lower(size, nblocks);
-	STARPU_PLU(display_data_content)(lower_r, size);
+        memset(L, 0, size*size*sizeof(TYPE));
+        memset(U, 0, size*size*sizeof(TYPE));
 
-	fprintf(stderr, "\nUPPER\n");
-	TYPE *upper_r = reconstruct_upper(size, nblocks);
-	STARPU_PLU(display_data_content)(upper_r, size);
+        /* only keep the lower part */
+	unsigned i, j;
+        for (j = 0; j < size; j++)
+        {
+                for (i = 0; i < j; i++)
+                {
+                        L[j+i*size] = all_r[j+i*size];
+                }
+
+                /* diag i = j */
+                L[j+j*size] = all_r[j+j*size];
+                U[j+j*size] = 1.0;
+
+                for (i = j+1; i < size; i++)
+                {
+                        U[j+i*size] = all_r[j+i*size];
+                }
+        }
+
+//	STARPU_PLU(display_data_content)(L, size);
+//	STARPU_PLU(display_data_content)(U, size);
+
+        /* now A_err = L, compute L*U */
+        CPU_TRMM("R", "U", "N", "U", size, size, 1.0f, U, size, L, size);
+
+//	fprintf(stderr, "\nLU\n");
+//	STARPU_PLU(display_data_content)(L, size);
+
+        /* compute "LU - A" in L*/
+        CPU_AXPY(size*size, -1.0, Asaved, 1, L, 1);
+
+        TYPE err = CPU_ASUM(size*size, L, 1);
+        int max = CPU_IAMAX(size*size, L, 1);
+
+//	STARPU_PLU(display_data_content)(L, size);
 
-	TYPE *lu_r = calloc(size*size, sizeof(TYPE));
-	CPU_TRMM("R", "U", "N", "U", size, size, 1.0f, lower_r, size, upper_r, size);
+        fprintf(stderr, "(A - LU) Avg error : %e\n", err/(size*size));
+        fprintf(stderr, "(A - LU) Max error : %e\n", L[max]);
 
-	fprintf(stderr, "\nLU\n");
-	STARPU_PLU(display_data_content)(lower_r, size);
 }
 
 static STARPU_PLU(compute_ax_block_lower)(unsigned size, unsigned nblocks,
@@ -161,14 +219,14 @@ static STARPU_PLU(compute_ax_block_lower)(unsigned size, unsigned nblocks,
 {
 	unsigned block_size = size/nblocks;
 
-	fprintf(stderr, "KEEP LOWER\n");
-	STARPU_PLU(display_data_content)(block_data, block_size);
+//	fprintf(stderr, "KEEP LOWER\n");
+//	STARPU_PLU(display_data_content)(block_data, block_size);
 
 	/* Take a copy of the upper part of the diagonal block */
 	TYPE *lower_block_copy = calloc((block_size)*(block_size), sizeof(TYPE));
 	STARPU_PLU(extract_lower)(block_size, block_data, lower_block_copy);
 
-	STARPU_PLU(display_data_content)(lower_block_copy, block_size);
+//	STARPU_PLU(display_data_content)(lower_block_copy, block_size);
 
 	STARPU_PLU(compute_ax_block)(size/nblocks, lower_block_copy, sub_x, sub_y);
 	
@@ -217,15 +275,15 @@ void STARPU_PLU(compute_lux)(unsigned size, TYPE *x, TYPE *y, unsigned nblocks,
 	memset(yi, 0, size*sizeof(TYPE));
 
 	unsigned ind;
-	if (rank == 0)
-	{
-		fprintf(stderr, "INTERMEDIATE\n");
-		for (ind = 0; ind < STARPU_MIN(10, size); ind++)
-		{
-			fprintf(stderr, "x[%d] = %f\n", ind, (float)x[ind]);
-		}
-		fprintf(stderr, "****\n");
-	}
+//	if (rank == 0)
+//	{
+//		fprintf(stderr, "INTERMEDIATE\n");
+//		for (ind = 0; ind < STARPU_MIN(10, size); ind++)
+//		{
+//			fprintf(stderr, "x[%d] = %f\n", ind, (float)x[ind]);
+//		}
+//		fprintf(stderr, "****\n");
+//	}
 
 	/* Everyone needs x */
 	int bcst_ret;
@@ -274,15 +332,6 @@ void STARPU_PLU(compute_ax)(unsigned size, TYPE *x, TYPE *y, unsigned nblocks, i
 	bcst_ret = MPI_Bcast(&x, size, MPI_TYPE, 0, MPI_COMM_WORLD);
 	STARPU_ASSERT(bcst_ret == MPI_SUCCESS);
 
-	if (rank == 0)
-	{
-		unsigned ind;
-		for (ind = 0; ind < STARPU_MIN(10, size); ind++)
-			fprintf(stderr, "x[%d] = %f\n", ind, (float)x[ind]);
-
-		fprintf(stderr, "Compute AX = B\n");
-	}
-
 	/* Create temporary buffers where all MPI processes are going to
 	 * compute Ai x = yi where Ai is the matrix containing the blocks of A
 	 * affected to process i, and 0 everywhere else. We then have y as the

+ 22 - 12
mpi/examples/mpi_lu/pxlu.c

@@ -346,7 +346,8 @@ static void callback_task_12_real(void *_arg)
 	rank_mask[rank] = 0;
 
 	/* Send the block to those nodes */
-	starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(j, k);
+//	starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(j, k);
+	starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(k, j);
 	starpu_tag_t tag = TAG12_SAVE(k, j);
 	int mpi_tag = MPI_TAG12(k, j);
 	send_data_to_mask(block_handle, rank_mask, mpi_tag, tag);
@@ -369,7 +370,8 @@ static void create_task_12_real(unsigned k, unsigned j)
 
 	task->buffers[0].handle = diag_block; 
 	task->buffers[0].mode = STARPU_R;
-	task->buffers[1].handle = STARPU_PLU(get_block_handle)(j, k); 
+//	task->buffers[1].handle = STARPU_PLU(get_block_handle)(j, k); 
+	task->buffers[1].handle = STARPU_PLU(get_block_handle)(k, j); 
 	task->buffers[1].mode = STARPU_RW;
 
 	struct callback_arg *arg = malloc(sizeof(struct callback_arg));
@@ -477,7 +479,8 @@ static void callback_task_21_real(void *_arg)
 	rank_mask[rank] = 0;
 
 	/* Send the block to those nodes */
-	starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(k, i);
+//	starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(k, i);
+	starpu_data_handle block_handle = STARPU_PLU(get_block_handle)(i, k);
 	starpu_tag_t tag = TAG21_SAVE(k, i);
 	int mpi_tag = MPI_TAG21(k, i);
 	send_data_to_mask(block_handle, rank_mask, mpi_tag, tag);
@@ -500,7 +503,8 @@ static void create_task_21_real(unsigned k, unsigned i)
 
 	task->buffers[0].handle = diag_block; 
 	task->buffers[0].mode = STARPU_R;
-	task->buffers[1].handle = STARPU_PLU(get_block_handle)(k, i);
+//	task->buffers[1].handle = STARPU_PLU(get_block_handle)(k, i);
+	task->buffers[1].handle = STARPU_PLU(get_block_handle)(i, k);
 	task->buffers[1].mode = STARPU_RW;
 
 	struct callback_arg *arg = malloc(sizeof(struct callback_arg));
@@ -565,7 +569,10 @@ static void create_task_22_real(unsigned k, unsigned i, unsigned j)
 	/* produced by TAG21_SAVE(k, i) */ 
 	starpu_data_handle block21;
 	if (get_block_rank(i, k) == rank)
-		block21 = STARPU_PLU(get_block_handle)(k, i);
+	{
+	//	block21 = STARPU_PLU(get_block_handle)(k, i);
+		block21 = STARPU_PLU(get_block_handle)(i, k);
+	}
 	else 
 		block21 = STARPU_PLU(get_tmp_21_block_handle)(i);
 
@@ -575,7 +582,10 @@ static void create_task_22_real(unsigned k, unsigned i, unsigned j)
 	/* produced by TAG12_SAVE(k, j) */
 	starpu_data_handle block12;
 	if (get_block_rank(k, j) == rank)
-		block12 = STARPU_PLU(get_block_handle)(j, k);
+	{
+	//	block12 = STARPU_PLU(get_block_handle)(j, k);
+		block12 = STARPU_PLU(get_block_handle)(k, j);
+	}
 	else 
 		block12 = STARPU_PLU(get_tmp_12_block_handle)(j);
 
@@ -583,7 +593,8 @@ static void create_task_22_real(unsigned k, unsigned i, unsigned j)
 	task->buffers[1].mode = STARPU_R;
 
 	/* produced by TAG22(k-1, i, j) */
-	task->buffers[2].handle = STARPU_PLU(get_block_handle)(j, i);
+//	task->buffers[2].handle = STARPU_PLU(get_block_handle)(j, i);
+	task->buffers[2].handle = STARPU_PLU(get_block_handle)(i, j);
 	task->buffers[2].mode = STARPU_RW;
 
 	if (!no_prio &&  (i == k + 1) && (j == k +1) ) {
@@ -634,7 +645,6 @@ static void wait_termination(void)
 		if (get_block_rank(k, k) == rank)
 		{
 			starpu_data_handle diag_block = STARPU_PLU(get_block_handle)(k, k);
-//			fprintf(stderr, "Rank %d : waiting tag %lx = TAG11_SAVE(k=%d)\n", rank, TAG11_SAVE(k), k);
 			wait_tag_and_fetch_handle(TAG11_SAVE(k), diag_block);
 		}
 		
@@ -644,8 +654,8 @@ static void wait_termination(void)
 			/* Wait task 21ki is needed */
 			if (get_block_rank(i, k) == rank)
 			{
-				starpu_data_handle block21 = STARPU_PLU(get_block_handle)(k, i);
-//				fprintf(stderr, "Rank %d : waiting tag %lx = TAG21_SAVE(k=%d, i=%d)\n", rank, TAG21_SAVE(k, i), k, i);
+				//starpu_data_handle block21 = STARPU_PLU(get_block_handle)(k, i);
+				starpu_data_handle block21 = STARPU_PLU(get_block_handle)(i, k);
 				wait_tag_and_fetch_handle(TAG21_SAVE(k, i), block21);
 			}
 		}
@@ -655,8 +665,8 @@ static void wait_termination(void)
 			/* Wait task 12kj is needed */
 			if (get_block_rank(k, j) == rank)
 			{
-				starpu_data_handle block12 = STARPU_PLU(get_block_handle)(j, k);
-//				fprintf(stderr, "Rank %d : waiting tag %lx = TAG12_SAVE(k=%d, j=%d)\n", rank, TAG12_SAVE(k, j), k, j);
+				//starpu_data_handle block12 = STARPU_PLU(get_block_handle)(j, k);
+				starpu_data_handle block12 = STARPU_PLU(get_block_handle)(k, j);
 				wait_tag_and_fetch_handle(TAG12_SAVE(k, j), block12);
 			}
 		}

+ 26 - 4
mpi/examples/mpi_lu/pxlu_kernels.c

@@ -39,7 +39,7 @@ static inline void STARPU_PLU(common_u22)(void *descr[],
 
 	int rank;
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-	fprintf(stderr, "KERNEL 22 %d\n", rank);
+	//fprintf(stderr, "KERNEL 22 %d\n", rank);
 
 #ifdef USE_CUDA
 	cublasStatus status;
@@ -129,13 +129,20 @@ static inline void STARPU_PLU(common_u12)(void *descr[],
 
 	int rank;
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-	fprintf(stderr, "KERNEL 12 %d\n", rank);
+//	fprintf(stderr, "KERNEL 12 %d\n", rank);
 
 #ifdef USE_CUDA
 	cublasStatus status;
 	cudaError_t cures;
 #endif
 
+//	fprintf(stderr, "INPUT 12 U11\n");
+//	STARPU_PLU(display_data_content)(sub11, nx12);
+//	fprintf(stderr, "INPUT 12 U12\n");
+//	STARPU_PLU(display_data_content)(sub12, nx12);
+
+
+
 	/* solve L11 U12 = A12 (find U12) */
 	switch (s) {
 		case 0:
@@ -160,6 +167,9 @@ static inline void STARPU_PLU(common_u12)(void *descr[],
 			STARPU_ABORT();
 			break;
 	}
+
+//	fprintf(stderr, "OUTPUT 12 U12\n");
+//	STARPU_PLU(display_data_content)(sub12, nx12);
 }
 
 static void STARPU_PLU(cpu_u12)(void *descr[], void *_args)
@@ -217,7 +227,13 @@ static inline void STARPU_PLU(common_u21)(void *descr[],
 	
 	int rank;
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-	fprintf(stderr, "KERNEL 21 %d\n", rank);
+//	fprintf(stderr, "KERNEL 21 %d \n", rank);
+
+	//fprintf(stderr, "INPUT 21 U11\n");
+	//STARPU_PLU(display_data_content)(sub11, nx21);
+	//fprintf(stderr, "INPUT 21 U12\n");
+	//STARPU_PLU(display_data_content)(sub21, nx21);
+
 
 
 #ifdef USE_CUDA
@@ -225,6 +241,7 @@ static inline void STARPU_PLU(common_u21)(void *descr[],
 	cudaError_t cures;
 #endif
 
+
 	switch (s) {
 		case 0:
 			CPU_TRSM("R", "U", "N", "U", nx21, ny21,
@@ -247,6 +264,11 @@ static inline void STARPU_PLU(common_u21)(void *descr[],
 			STARPU_ABORT();
 			break;
 	}
+
+//	fprintf(stderr, "INPUT 21 U21\n");
+//	STARPU_PLU(display_data_content)(sub21, nx21);
+
+
 }
 
 static void STARPU_PLU(cpu_u21)(void *descr[], void *_args)
@@ -301,7 +323,7 @@ static inline void STARPU_PLU(common_u11)(void *descr[],
 
 	int rank;
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-	fprintf(stderr, "KERNEL 11 %d\n", rank);
+//	fprintf(stderr, "KERNEL 11 %d\n", rank);
 
 	switch (s) {
 		case 0: