瀏覽代碼

Do not factorize code in xgemm codelet, to avoid confusing the reader

Samuel Thibault 10 年之前
父節點
當前提交
d84a627a42
共有 1 個文件被更改,包括 33 次插入和 35 次删除
  1. 33 35
      examples/mult/xgemm.c

+ 33 - 35
examples/mult/xgemm.c

@@ -136,7 +136,8 @@ static void partition_mult_data(void)
 	starpu_data_map_filters(C_handle, 2, &vert, &horiz);
 }
 
-static void mult_kernel_common(void *descr[], int type)
+#ifdef STARPU_USE_CUDA
+static void cublas_mult(void *descr[], STARPU_ATTRIBUTE_UNUSED void *arg)
 {
 	TYPE *subA = (TYPE *)STARPU_MATRIX_GET_PTR(descr[0]);
 	TYPE *subB = (TYPE *)STARPU_MATRIX_GET_PTR(descr[1]);
@@ -150,50 +151,47 @@ static void mult_kernel_common(void *descr[], int type)
 	unsigned ldB = STARPU_MATRIX_GET_LD(descr[1]);
 	unsigned ldC = STARPU_MATRIX_GET_LD(descr[2]);
 
-	if (type == STARPU_CPU)
-	{
-		int worker_size = starpu_combined_worker_get_size();
+	CUBLAS_GEMM('n', 'n', nxC, nyC, nyA, (TYPE)1.0, subA, ldA, subB, ldB,
+				     (TYPE)0.0, subC, ldC);
+}
+#endif
 
-		if (worker_size == 1)
-		{
-			/* Sequential CPU task */
-			CPU_GEMM("N", "N", nxC, nyC, nyA, (TYPE)1.0, subA, ldA, subB, ldB, (TYPE)0.0, subC, ldC);
-		}
-		else
-		{
-			/* Parallel CPU task */
-			int rank = starpu_combined_worker_get_rank();
+static void cpu_mult(void *descr[], STARPU_ATTRIBUTE_UNUSED  void *arg)
+{
+	TYPE *subA = (TYPE *)STARPU_MATRIX_GET_PTR(descr[0]);
+	TYPE *subB = (TYPE *)STARPU_MATRIX_GET_PTR(descr[1]);
+	TYPE *subC = (TYPE *)STARPU_MATRIX_GET_PTR(descr[2]);
 
-			int block_size = (nyC + worker_size - 1)/worker_size;
-			int new_nyC = STARPU_MIN(nyC, block_size*(rank+1)) - block_size*rank;
+	unsigned nxC = STARPU_MATRIX_GET_NX(descr[2]);
+	unsigned nyC = STARPU_MATRIX_GET_NY(descr[2]);
+	unsigned nyA = STARPU_MATRIX_GET_NY(descr[0]);
 
-			STARPU_ASSERT(nyC = STARPU_MATRIX_GET_NY(descr[1]));
+	unsigned ldA = STARPU_MATRIX_GET_LD(descr[0]);
+	unsigned ldB = STARPU_MATRIX_GET_LD(descr[1]);
+	unsigned ldC = STARPU_MATRIX_GET_LD(descr[2]);
 
-			TYPE *new_subB = &subB[block_size*rank];
-			TYPE *new_subC = &subC[block_size*rank];
+	int worker_size = starpu_combined_worker_get_size();
 
-			CPU_GEMM("N", "N", nxC, new_nyC, nyA, (TYPE)1.0, subA, ldA, new_subB, ldB, (TYPE)0.0, new_subC, ldC);
-		}
+	if (worker_size == 1)
+	{
+		/* Sequential CPU task */
+		CPU_GEMM("N", "N", nxC, nyC, nyA, (TYPE)1.0, subA, ldA, subB, ldB, (TYPE)0.0, subC, ldC);
 	}
-#ifdef STARPU_USE_CUDA
 	else
 	{
-		CUBLAS_GEMM('n', 'n', nxC, nyC, nyA, (TYPE)1.0, subA, ldA, subB, ldB,
-					     (TYPE)0.0, subC, ldC);
-	}
-#endif
-}
+		/* Parallel CPU task */
+		int rank = starpu_combined_worker_get_rank();
 
-#ifdef STARPU_USE_CUDA
-static void cublas_mult(void *descr[], STARPU_ATTRIBUTE_UNUSED void *arg)
-{
-	mult_kernel_common(descr, STARPU_CUDA);
-}
-#endif
+		int block_size = (nyC + worker_size - 1)/worker_size;
+		int new_nyC = STARPU_MIN(nyC, block_size*(rank+1)) - block_size*rank;
 
-static void cpu_mult(void *descr[], STARPU_ATTRIBUTE_UNUSED  void *arg)
-{
-	mult_kernel_common(descr, STARPU_CPU);
+		STARPU_ASSERT(nyC = STARPU_MATRIX_GET_NY(descr[1]));
+
+		TYPE *new_subB = &subB[block_size*rank];
+		TYPE *new_subC = &subC[block_size*rank];
+
+		CPU_GEMM("N", "N", nxC, new_nyC, nyA, (TYPE)1.0, subA, ldA, new_subB, ldB, (TYPE)0.0, new_subC, ldC);
+	}
 }
 
 static struct starpu_perfmodel starpu_gemm_model =