소스 검색

Use a mutex to protect cuBLAS initialization and shutdown, since they do not actually seem to be thread-safe for the same device

Samuel Thibault 8 년 전
부모
커밋
fc69bb54c7
1개의 파일 변경, 7줄 추가, 2줄 삭제
  1. 7 2
      src/drivers/cuda/starpu_cublas.c

+ 7 - 2
src/drivers/cuda/starpu_cublas.c

@@ -24,6 +24,7 @@
 #include <cublas.h>
 
 static int cublas_initialized[STARPU_NMAXWORKERS];
+static starpu_pthread_mutex_t mutex;
 
 static unsigned get_idx(void) {
 	unsigned workerid = starpu_worker_get_id_check();
@@ -42,12 +43,14 @@ static unsigned get_idx(void) {
 static void init_cublas_func(void *args STARPU_ATTRIBUTE_UNUSED)
 {
 	unsigned idx = get_idx();
-	if (STARPU_ATOMIC_ADD(&cublas_initialized[idx], 1) == 1)
+	STARPU_PTHREAD_MUTEX_LOCK(&mutex);
+	if (!(cublas_initialized[idx]++))
 	{
 		cublasStatus cublasst = cublasInit();
 		if (STARPU_UNLIKELY(cublasst))
 			STARPU_CUBLAS_REPORT_ERROR(cublasst);
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
 }
 
 static void set_cublas_stream_func(void *args STARPU_ATTRIBUTE_UNUSED)
@@ -58,8 +61,10 @@ static void set_cublas_stream_func(void *args STARPU_ATTRIBUTE_UNUSED)
 static void shutdown_cublas_func(void *args STARPU_ATTRIBUTE_UNUSED)
 {
 	unsigned idx = get_idx();
-	if (STARPU_ATOMIC_ADD(&cublas_initialized[idx], -1) == 0)
+	STARPU_PTHREAD_MUTEX_LOCK(&mutex);
+	if (!--cublas_initialized[idx])
 		cublasShutdown();
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
 }
 #endif