|
@@ -24,6 +24,7 @@
|
|
|
#include <cublas.h>
|
|
|
|
|
|
static int cublas_initialized[STARPU_NMAXWORKERS];
|
|
|
+static starpu_pthread_mutex_t mutex;
|
|
|
|
|
|
static unsigned get_idx(void) {
|
|
|
unsigned workerid = starpu_worker_get_id_check();
|
|
@@ -42,12 +43,14 @@ static unsigned get_idx(void) {
|
|
|
static void init_cublas_func(void *args STARPU_ATTRIBUTE_UNUSED)
|
|
|
{
|
|
|
unsigned idx = get_idx();
|
|
|
- if (STARPU_ATOMIC_ADD(&cublas_initialized[idx], 1) == 1)
|
|
|
+ STARPU_PTHREAD_MUTEX_LOCK(&mutex);
|
|
|
+ if (!(cublas_initialized[idx]++))
|
|
|
{
|
|
|
cublasStatus cublasst = cublasInit();
|
|
|
if (STARPU_UNLIKELY(cublasst))
|
|
|
STARPU_CUBLAS_REPORT_ERROR(cublasst);
|
|
|
}
|
|
|
+ STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
|
|
|
}
|
|
|
|
|
|
static void set_cublas_stream_func(void *args STARPU_ATTRIBUTE_UNUSED)
|
|
@@ -58,8 +61,10 @@ static void set_cublas_stream_func(void *args STARPU_ATTRIBUTE_UNUSED)
|
|
|
static void shutdown_cublas_func(void *args STARPU_ATTRIBUTE_UNUSED)
|
|
|
{
|
|
|
unsigned idx = get_idx();
|
|
|
- if (STARPU_ATOMIC_ADD(&cublas_initialized[idx], -1) == 0)
|
|
|
+ STARPU_PTHREAD_MUTEX_LOCK(&mutex);
|
|
|
+ if (!--cublas_initialized[idx])
|
|
|
cublasShutdown();
|
|
|
+ STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
|
|
|
}
|
|
|
#endif
|
|
|
|