浏览代码

Fix mult to not make it uselessly load the C tiles

Samuel Thibault 5 年之前
父节点
当前提交
abb207512d
共有 1 个文件被更改,包括 21 次插入1 次删除
  1. 21 1
      examples/mult/xgemm.c

+ 21 - 1
examples/mult/xgemm.c

@@ -173,6 +173,17 @@ static void cublas_mult(void *descr[], void *arg)
 	unsigned ldB = STARPU_MATRIX_GET_LD(descr[1]);
 	unsigned ldC = STARPU_MATRIX_GET_LD(descr[2]);
 
+	cudaStream_t stream = starpu_cuda_get_local_stream();
+
+	if (nxC == ldC)
+		cudaMemsetAsync(subC, 0, sizeof(*subC) * nxC * nyC, stream);
+	else
+	{
+		unsigned i;
+		for (i = 0; i < nyC; i++)
+			cudaMemsetAsync(subC + i*ldC, 0, sizeof(*subC) * nxC, stream);
+	}
+
 	cublasStatus_t status = CUBLAS_GEMM(starpu_cublas_get_local_handle(),
 			CUBLAS_OP_N, CUBLAS_OP_N,
 			nxC, nyC, nyA,
@@ -200,6 +211,15 @@ void cpu_mult(void *descr[], void *arg)
 
 	int worker_size = starpu_combined_worker_get_size();
 
+	if (nxC == ldC)
+		memset(subC, 0, sizeof(*subC) * nxC * nyC);
+	else
+	{
+		unsigned i;
+		for (i = 0; i < nyC; i++)
+			memset(subC + i*ldC, 0, sizeof(*subC) * nxC);
+	}
+
 	if (worker_size == 1)
 	{
 		/* Sequential CPU task */
@@ -241,7 +261,7 @@ static struct starpu_codelet cl =
 #endif
 	.cuda_flags = {STARPU_CUDA_ASYNC},
 	.nbuffers = 3,
-	.modes = {STARPU_R, STARPU_R, STARPU_RW},
+	.modes = {STARPU_R, STARPU_R, STARPU_W},
 	.model = &starpu_gemm_model
 };