Explorar o código

Fix potrf kernel: benefit from the MKL implementation, and Magma is actually not asynchronous

Samuel Thibault hai 10 anos
pai
achega
7fa7ccb485
Modificáronse 3 ficheiros con 30 adicións e 3 borrados
  1. 5 3
      examples/cholesky/cholesky_kernels.c
  2. 15 0
      examples/common/blas.c
  3. 10 0
      examples/common/blas.h

+ 5 - 3
examples/cholesky/cholesky_kernels.c

@@ -158,6 +158,9 @@ static inline void chol_common_codelet_update_u11(void *descr[], int s, STARPU_A
 	{
 		case 0:
 
+#ifdef STARPU_MKL
+			STARPU_SPOTRF("L", nx, sub11, ld);
+#else
 			/*
 			 *	- alpha 11 <- lambda 11 = sqrt(alpha11)
 			 *	- alpha 21 <- l 21	= alpha 21 / lambda 11
@@ -178,6 +181,7 @@ static inline void chol_common_codelet_update_u11(void *descr[], int s, STARPU_A
 							&sub11[(z+1)+z*ld], 1,
 							&sub11[(z+1)+(z+1)*ld], ld);
 			}
+#endif
 			break;
 #ifdef STARPU_USE_CUDA
 		case 1:
@@ -191,6 +195,7 @@ static inline void chol_common_codelet_update_u11(void *descr[], int s, STARPU_A
 				fprintf(stderr, "Error in Magma: %d\n", ret);
 				STARPU_ABORT();
 			}
+			cudaThreadSynchronize();
 			}
 #else
 			{
@@ -252,9 +257,6 @@ struct starpu_codelet cl11 =
 #elif defined(STARPU_SIMGRID)
 	.cuda_funcs = {(void*)1},
 #endif
-#ifdef STARPU_HAVE_MAGMA
-	.cuda_flags = {STARPU_CUDA_ASYNC},
-#endif
 	.nbuffers = 1,
 	.modes = { STARPU_RW },
 	.model = &chol_model_11

+ 15 - 0
examples/common/blas.c

@@ -414,6 +414,17 @@ void STARPU_DSWAP(const int n, double *X, const int incX, double *Y, const int i
 	dswap_(&n, X, &incX, Y, &incY);
 }
 
+void STARPU_SPOTRF(const char*uplo, const int n, float *a, const int lda)
+{
+	int info = 0;
+	spotrf_(uplo, &n, a, &lda, &info);
+}
+
+void STARPU_DPOTRF(const char*uplo, const int n, double *a, const int lda)
+{
+	int info = 0;
+	dpotrf_(uplo, &n, a, &lda, &info);
+}
 
 #elif defined(STARPU_SIMGRID)
 inline void STARPU_SGEMM(char *transa, char *transb, int M, int N, int K, 
@@ -498,6 +509,10 @@ void STARPU_SSWAP(const int n, float *X, const int incX, float *Y, const int inc
 
 void STARPU_DSWAP(const int n, double *X, const int incX, double *Y, const int incY) { }
 
+void STARPU_SPOTRF(const char*uplo, const int n, float *a, const int lda) { } /* stub with an empty body: arguments are ignored, nothing is computed */
+
+void STARPU_DPOTRF(const char*uplo, const int n, double *a, const int lda) { } /* double-precision counterpart, likewise an empty stub */
+
 
 #else
 #error "no BLAS lib available..."

+ 10 - 0
examples/common/blas.h

@@ -82,6 +82,11 @@ double STARPU_DDOT(const int n, const double *x, const int incx, const double *y
 void STARPU_SSWAP(const int n, float *x, const int incx, float *y, const int incy);
 void STARPU_DSWAP(const int n, double *x, const int incx, double *y, const int incy);
 
+#ifdef STARPU_MKL /* potrf wrappers are only declared for MKL builds */
+void STARPU_SPOTRF(const char*uplo, const int n, float *a, const int lda); /* single precision; n and lda are taken by value */
+void STARPU_DPOTRF(const char*uplo, const int n, double *a, const int lda); /* double precision; same parameter layout */
+#endif /* STARPU_MKL */
+
 #if defined(STARPU_GOTO) || defined(STARPU_SYSTEM_BLAS) || defined(STARPU_MKL)
 
 extern void sgemm_ (const char *transa, const char *transb, const int *m,
@@ -152,6 +157,11 @@ extern double ddot_(const int *n, const double *x, const int *incx, const double
 extern void sswap_(const int *n, float *x, const int *incx, float *y, const int *incy);
 extern void dswap_(const int *n, double *x, const int *incx, double *y, const int *incy);
 
+#ifdef STARPU_MKL /* external potrf symbols are only declared for MKL builds */
+extern void spotrf_(const char*uplo, const int *n, float *a, const int *lda, int *info); /* single precision; n and lda by pointer, status written to *info */
+extern void dpotrf_(const char*uplo, const int *n, double *a, const int *lda, int *info); /* double precision; same calling convention */
+#endif /* STARPU_MKL */
+
 #endif
 
 #endif /* __BLAS_H__ */