Browse Source

Use streams instead of cudaThreadSynchronize

Cédric Augonnet 14 years ago
parent
commit
b8ffa4f5e8
1 changed files with 6 additions and 15 deletions
  1. 6 15
      examples/cholesky/cholesky_kernels.c

+ 6 - 15
examples/cholesky/cholesky_kernels.c

@@ -19,9 +19,7 @@
 #include "cholesky.h"
 #include "../common/blas.h"
 #ifdef STARPU_USE_CUDA
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <cublas.h>
+#include <starpu_cuda.h>
 #endif
 
 /*
@@ -43,10 +41,6 @@ static inline void chol_common_cpu_codelet_update_u22(void *descr[], int s, __at
 	unsigned ld12 = STARPU_MATRIX_GET_LD(descr[1]);
 	unsigned ld22 = STARPU_MATRIX_GET_LD(descr[2]);
 
-#ifdef STARPU_USE_CUDA
-	cublasStatus st;
-#endif
-
 	switch (s) {
 		case 0:
 			SGEMM("N", "T", dy, dx, dz, -1.0f, left, ld21, 
@@ -57,10 +51,7 @@ static inline void chol_common_cpu_codelet_update_u22(void *descr[], int s, __at
 			cublasSgemm('n', 't', dy, dx, dz, 
 					-1.0f, left, ld21, right, ld12, 
 					 1.0f, center, ld22);
-			st = cublasGetError();
-			STARPU_ASSERT(!st);
-
-			cudaThreadSynchronize();
+			cudaStreamSynchronize(starpu_cuda_get_local_stream());
 
 			break;
 #endif
@@ -108,7 +99,7 @@ static inline void chol_common_codelet_update_u21(void *descr[], int s, __attrib
 #ifdef STARPU_USE_CUDA
 		case 1:
 			cublasStrsm('R', 'L', 'T', 'N', nx21, ny21, 1.0f, sub11, ld11, sub21, ld21);
-			cudaThreadSynchronize();
+			cudaStreamSynchronize(starpu_cuda_get_local_stream());
 			break;
 #endif
 		default:
@@ -174,8 +165,8 @@ static inline void chol_common_codelet_update_u11(void *descr[], int s, __attrib
 			for (z = 0; z < nx; z++)
 			{
 				float lambda11;
-				cudaMemcpy(&lambda11, &sub11[z+z*ld], sizeof(float), cudaMemcpyDeviceToHost);
-				cudaStreamSynchronize(0);
+				cudaMemcpyAsync(&lambda11, &sub11[z+z*ld], sizeof(float), cudaMemcpyDeviceToHost, starpu_cuda_get_local_stream());
+				cudaStreamSynchronize(starpu_cuda_get_local_stream());
 
 				STARPU_ASSERT(lambda11 != 0.0f);
 				
@@ -190,7 +181,7 @@ static inline void chol_common_codelet_update_u11(void *descr[], int s, __attrib
 							&sub11[(z+1)+(z+1)*ld], ld);
 			}
 		
-			cudaThreadSynchronize();
+			cudaStreamSynchronize(starpu_cuda_get_local_stream());
 
 			break;
 #endif