|
@@ -19,9 +19,7 @@
|
|
|
#include "cholesky.h"
|
|
|
#include "../common/blas.h"
|
|
|
#ifdef STARPU_USE_CUDA
|
|
|
-#include <cuda.h>
|
|
|
-#include <cuda_runtime.h>
|
|
|
-#include <cublas.h>
|
|
|
+#include <starpu_cuda.h>
|
|
|
#endif
|
|
|
|
|
|
|
|
@@ -43,10 +41,6 @@ static inline void chol_common_cpu_codelet_update_u22(void *descr[], int s, __at
|
|
|
unsigned ld12 = STARPU_MATRIX_GET_LD(descr[1]);
|
|
|
unsigned ld22 = STARPU_MATRIX_GET_LD(descr[2]);
|
|
|
|
|
|
-#ifdef STARPU_USE_CUDA
|
|
|
- cublasStatus st;
|
|
|
-#endif
|
|
|
-
|
|
|
switch (s) {
|
|
|
case 0:
|
|
|
SGEMM("N", "T", dy, dx, dz, -1.0f, left, ld21,
|
|
@@ -57,10 +51,7 @@ static inline void chol_common_cpu_codelet_update_u22(void *descr[], int s, __at
|
|
|
cublasSgemm('n', 't', dy, dx, dz,
|
|
|
-1.0f, left, ld21, right, ld12,
|
|
|
1.0f, center, ld22);
|
|
|
- st = cublasGetError();
|
|
|
- STARPU_ASSERT(!st);
|
|
|
-
|
|
|
- cudaThreadSynchronize();
|
|
|
+ cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
|
|
|
break;
|
|
|
#endif
|
|
@@ -108,7 +99,7 @@ static inline void chol_common_codelet_update_u21(void *descr[], int s, __attrib
|
|
|
#ifdef STARPU_USE_CUDA
|
|
|
case 1:
|
|
|
cublasStrsm('R', 'L', 'T', 'N', nx21, ny21, 1.0f, sub11, ld11, sub21, ld21);
|
|
|
- cudaThreadSynchronize();
|
|
|
+ cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
break;
|
|
|
#endif
|
|
|
default:
|
|
@@ -174,8 +165,8 @@ static inline void chol_common_codelet_update_u11(void *descr[], int s, __attrib
|
|
|
for (z = 0; z < nx; z++)
|
|
|
{
|
|
|
float lambda11;
|
|
|
- cudaMemcpy(&lambda11, &sub11[z+z*ld], sizeof(float), cudaMemcpyDeviceToHost);
|
|
|
- cudaStreamSynchronize(0);
|
|
|
+ cudaMemcpyAsync(&lambda11, &sub11[z+z*ld], sizeof(float), cudaMemcpyDeviceToHost, starpu_cuda_get_local_stream());
|
|
|
+ cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
|
|
|
STARPU_ASSERT(lambda11 != 0.0f);
|
|
|
|
|
@@ -190,7 +181,7 @@ static inline void chol_common_codelet_update_u11(void *descr[], int s, __attrib
|
|
|
&sub11[(z+1)+(z+1)*ld], ld);
|
|
|
}
|
|
|
|
|
|
- cudaThreadSynchronize();
|
|
|
+ cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
|
|
|
break;
|
|
|
#endif
|