Browse Source

Implement a parallel version of the GEMM kernel for the Cholesky example.

Cédric Augonnet 14 years ago
parent
commit
355f25c062
2 changed files with 40 additions and 13 deletions
  1. 10 0
      examples/cholesky/cholesky_implicit.c
  2. 30 13
      examples/cholesky/cholesky_kernels.c

+ 10 - 0
examples/cholesky/cholesky_implicit.c

@@ -25,6 +25,7 @@
 static starpu_codelet cl11 =
 {
 	.where = STARPU_CPU|STARPU_CUDA,
+	.type = STARPU_SEQ,
 	.cpu_func = chol_cpu_codelet_update_u11,
 #ifdef STARPU_USE_CUDA
 	.cuda_func = chol_cublas_codelet_update_u11,
@@ -36,6 +37,7 @@ static starpu_codelet cl11 =
 static starpu_codelet cl21 =
 {
 	.where = STARPU_CPU|STARPU_CUDA,
+	.type = STARPU_SEQ,
 	.cpu_func = chol_cpu_codelet_update_u21,
 #ifdef STARPU_USE_CUDA
 	.cuda_func = chol_cublas_codelet_update_u21,
@@ -47,6 +49,8 @@ static starpu_codelet cl21 =
 static starpu_codelet cl22 =
 {
 	.where = STARPU_CPU|STARPU_CUDA,
+	.type = STARPU_SEQ,
+	.max_parallelism = INT_MAX,
 	.cpu_func = chol_cpu_codelet_update_u22,
 #ifdef STARPU_USE_CUDA
 	.cuda_func = chol_cublas_codelet_update_u22,
@@ -60,6 +64,11 @@ static starpu_codelet cl22 =
  *	and construct the DAG
  */
 
+static void callback_turn_spmd_on(void *arg __attribute__ ((unused)))
+{
+	cl22.type = STARPU_SPMD;
+}
+
 static void _cholesky(starpu_data_handle dataA, unsigned nblocks)
 {
 	struct timeval start;
@@ -79,6 +88,7 @@ static void _cholesky(starpu_data_handle dataA, unsigned nblocks)
                 starpu_insert_task(&cl11,
                                    STARPU_PRIORITY, prio_level,
                                    STARPU_RW, sdatakk,
+				   STARPU_CALLBACK, (k == 3*nblocks/4)?callback_turn_spmd_on:NULL,
                                    0);
 
 		for (j = k+1; j<nblocks; j++)

+ 30 - 13
examples/cholesky/cholesky_kernels.c

@@ -41,23 +41,40 @@ static inline void chol_common_cpu_codelet_update_u22(void *descr[], int s, __at
 	unsigned ld12 = STARPU_MATRIX_GET_LD(descr[1]);
 	unsigned ld22 = STARPU_MATRIX_GET_LD(descr[2]);
 
-	switch (s) {
-		case 0:
+	if (s == 0)
+	{
+		int worker_size = starpu_combined_worker_get_size();
+
+		if (worker_size == 1)
+		{
+			/* Sequential CPU kernel */
 			SGEMM("N", "T", dy, dx, dz, -1.0f, left, ld21, 
 				right, ld12, 1.0f, center, ld22);
-			break;
+		}
+		else {
+			/* Parallel CPU kernel */
+			int rank = starpu_combined_worker_get_rank();
+
+			int block_size = (dx + worker_size - 1)/worker_size;
+			int new_dx = STARPU_MIN(dx, block_size*(rank+1)) - block_size*rank;
+			
+			float *new_left = &left[block_size*rank];
+			float *new_center = &center[block_size*rank];
+
+			SGEMM("N", "T", dy, new_dx, dz, -1.0f, new_left, ld21, 
+				right, ld12, 1.0f, new_center, ld22);
+		}
+	}
+	else
+	{
+		/* CUDA kernel */
 #ifdef STARPU_USE_CUDA
-		case 1:
-			cublasSgemm('n', 't', dy, dx, dz, 
-					-1.0f, left, ld21, right, ld12, 
-					 1.0f, center, ld22);
-			cudaStreamSynchronize(starpu_cuda_get_local_stream());
-
-			break;
+		cublasSgemm('n', 't', dy, dx, dz, 
+				-1.0f, left, ld21, right, ld12, 
+				 1.0f, center, ld22);
+		cudaStreamSynchronize(starpu_cuda_get_local_stream());
 #endif
-		default:
-			STARPU_ABORT();
-			break;
+
 	}
 }