Przeglądaj źródła

Add support for SPMD parallel tasks in the GEMM example

Cédric Augonnet 14 lat temu
rodzic
commit
a03bf99d4b
1 zmienionych plików z 73 dodań i 50 usunięć
  1. 73 50
      examples/mult/xgemm.c

+ 73 - 50
examples/mult/xgemm.c

@@ -27,6 +27,7 @@
 #ifdef STARPU_USE_CUDA
 #include <cuda.h>
 #include <cublas.h>
+#include <starpu_cuda.h>
 #endif
 
 static unsigned niter = 100;
@@ -40,52 +41,6 @@ static unsigned check = 0;
 static TYPE *A, *B, *C;
 static starpu_data_handle A_handle, B_handle, C_handle;
 
-static void parse_args(int argc, char **argv)
-{
-	int i;
-	for (i = 1; i < argc; i++) {
-		if (strcmp(argv[i], "-nblocks") == 0) {
-			char *argptr;
-			nslicesx = strtol(argv[++i], &argptr, 10);
-			nslicesy = nslicesx;
-		}
-
-		if (strcmp(argv[i], "-nblocksx") == 0) {
-			char *argptr;
-			nslicesx = strtol(argv[++i], &argptr, 10);
-		}
-
-		if (strcmp(argv[i], "-nblocksy") == 0) {
-			char *argptr;
-			nslicesy = strtol(argv[++i], &argptr, 10);
-		}
-
-		if (strcmp(argv[i], "-x") == 0) {
-			char *argptr;
-			xdim = strtol(argv[++i], &argptr, 10);
-		}
-
-		if (strcmp(argv[i], "-y") == 0) {
-			char *argptr;
-			ydim = strtol(argv[++i], &argptr, 10);
-		}
-
-		if (strcmp(argv[i], "-z") == 0) {
-			char *argptr;
-			zdim = strtol(argv[++i], &argptr, 10);
-		}
-
-		if (strcmp(argv[i], "-iter") == 0) {
-			char *argptr;
-			niter = strtol(argv[++i], &argptr, 10);
-		}
-
-		if (strcmp(argv[i], "-check") == 0) {
-			check = 1;
-		}
-	}
-}
-
 static void check_output(void)
 {
 	/* compute C = C - AB */
@@ -144,8 +99,6 @@ static void partition_mult_data(void)
 	starpu_matrix_data_register(&C_handle, 0, (uintptr_t)C, 
 		ydim, ydim, xdim, sizeof(TYPE));
 
-	starpu_data_set_wt_mask(C_handle, 1<<0);
-
 	struct starpu_data_filter f;
 	memset(&f, 0, sizeof(f));
 	f.filter_func = starpu_vertical_block_filter_func;
@@ -177,13 +130,31 @@ static void mult_kernel_common(void *descr[], int type)
 	unsigned ldC = STARPU_MATRIX_GET_LD(descr[2]);
 
 	if (type == STARPU_CPU) {
-		CPU_GEMM("N", "N", nxC, nyC, nyA, (TYPE)1.0, subA, ldA, subB, ldB, (TYPE)0.0, subC, ldC);
+		int worker_size = starpu_combined_worker_get_size();
+
+		if (worker_size == 1)
+		{
+			/* Sequential CPU task */
+			CPU_GEMM("N", "N", nxC, nyC, nyA, (TYPE)1.0, subA, ldA, subB, ldB, (TYPE)0.0, subC, ldC);
+		}
+		else {
+			/* Parallel CPU task */
+			int rank = starpu_combined_worker_get_rank();
+		
+			int block_size = (nyC + worker_size - 1)/worker_size;
+			int new_nyC = STARPU_MIN(nyC, block_size*(rank+1)) - block_size*rank;
+
+			TYPE *new_subA = &subA[block_size*rank];
+			TYPE *new_subC = &subC[block_size*rank];
+
+			CPU_GEMM("N", "N", nxC, new_nyC, nyA, (TYPE)1.0, new_subA, ldA, subB, ldB, (TYPE)0.0, new_subC, ldC);
+		}
 	}
 #ifdef STARPU_USE_CUDA
 	else {
 		CUBLAS_GEMM('n', 'n', nxC, nyC, nyA, (TYPE)1.0, subA, ldA, subB, ldB,
 					     (TYPE)0.0, subC, ldC);
-		cudaThreadSynchronize();
+		cudaStreamSynchronize(starpu_cuda_get_local_stream());
 	}
 #endif
 }
@@ -207,6 +178,8 @@ static struct starpu_perfmodel_t starpu_gemm_model = {
 
 static starpu_codelet cl = {
 	.where = STARPU_CPU|STARPU_CUDA,
+	.type = STARPU_SEQ, /* changed to STARPU_SPMD if -spmd is passed */
+	.max_parallelism = INT_MAX,
 	.cpu_func = cpu_mult,
 #ifdef STARPU_USE_CUDA
 	.cuda_func = cublas_mult,
@@ -215,6 +188,56 @@ static starpu_codelet cl = {
 	.model = &starpu_gemm_model
 };
 
+static void parse_args(int argc, char **argv)
+{
+	int i;
+	for (i = 1; i < argc; i++) {
+		if (strcmp(argv[i], "-nblocks") == 0) {
+			char *argptr;
+			nslicesx = strtol(argv[++i], &argptr, 10);
+			nslicesy = nslicesx;
+		}
+
+		if (strcmp(argv[i], "-nblocksx") == 0) {
+			char *argptr;
+			nslicesx = strtol(argv[++i], &argptr, 10);
+		}
+
+		if (strcmp(argv[i], "-nblocksy") == 0) {
+			char *argptr;
+			nslicesy = strtol(argv[++i], &argptr, 10);
+		}
+
+		if (strcmp(argv[i], "-x") == 0) {
+			char *argptr;
+			xdim = strtol(argv[++i], &argptr, 10);
+		}
+
+		if (strcmp(argv[i], "-y") == 0) {
+			char *argptr;
+			ydim = strtol(argv[++i], &argptr, 10);
+		}
+
+		if (strcmp(argv[i], "-z") == 0) {
+			char *argptr;
+			zdim = strtol(argv[++i], &argptr, 10);
+		}
+
+		if (strcmp(argv[i], "-iter") == 0) {
+			char *argptr;
+			niter = strtol(argv[++i], &argptr, 10);
+		}
+
+		if (strcmp(argv[i], "-check") == 0) {
+			check = 1;
+		}
+
+		if (strcmp(argv[i], "-spmd") == 0) {
+			cl.type = STARPU_SPMD;
+		}
+	}
+}
+
 int main(int argc, char **argv)
 {
 	struct timeval start;