Browse Source

Avoid cublasSdot bug when using non-blocking streams, fixed in cuda 7.5

Samuel Thibault 9 years ago
parent
commit
ba27f5ed74
1 changed files with 10 additions and 4 deletions
  1. 10 4
      examples/reductions/dot_product.c

+ 10 - 4
examples/reductions/dot_product.c

@@ -51,6 +51,8 @@ static unsigned _entries_per_block = 1024;
 static DOT_TYPE _dot = 0.0f;
 static starpu_data_handle_t _dot_handle;
 
+static int cublas_version;
+
 static int can_execute(unsigned workerid, struct starpu_task *task, unsigned nimpl)
 {
 	enum starpu_worker_archtype type = starpu_worker_get_type(workerid);
@@ -250,14 +252,15 @@ void dot_cuda_func(void *descr[], void *cl_arg)
 	unsigned n = STARPU_VECTOR_GET_NX(descr[0]);
 
 	cudaMemcpyAsync(&current_dot, dot, sizeof(DOT_TYPE), cudaMemcpyDeviceToHost, starpu_cuda_get_local_stream());
+	cudaStreamSynchronize(starpu_cuda_get_local_stream());
 
 	local_dot = (DOT_TYPE)cublasSdot(n, local_x, 1, local_y, 1);
 
 	/* FPRINTF(stderr, "current_dot %f local dot %f -> %f\n", current_dot, local_dot, current_dot + local_dot); */
-	cudaStreamSynchronize(starpu_cuda_get_local_stream());
 	current_dot += local_dot;
 
 	cudaMemcpyAsync(dot, &current_dot, sizeof(DOT_TYPE), cudaMemcpyHostToDevice, starpu_cuda_get_local_stream());
+	cudaStreamSynchronize(starpu_cuda_get_local_stream());
 }
 #endif
 
@@ -318,7 +321,6 @@ static struct starpu_codelet dot_codelet =
 	.cpu_funcs_name = {"dot_cpu_func"},
 #ifdef STARPU_USE_CUDA
 	.cuda_funcs = {dot_cuda_func},
-	.cuda_flags = {STARPU_CUDA_ASYNC},
 #endif
 #ifdef STARPU_USE_OPENCL
 	.opencl_funcs = {dot_opencl_func},
@@ -352,7 +354,10 @@ int main(int argc, char **argv)
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_opencl_load_opencl_from_file");
 #endif
 
-	starpu_cublas_init();
+	/* cublasSdot has synchronization issues when using a non-blocking stream */
+	cublasGetVersion(&cublas_version);
+	if (cublas_version >= 7050)
+		starpu_cublas_init();
 
 	unsigned long nelems = _nblocks*_entries_per_block;
 	size_t size = nelems*sizeof(float);
@@ -419,7 +424,8 @@ int main(int argc, char **argv)
 
 	FPRINTF(stderr, "Reference : %e vs. %e (Delta %e)\n", reference_dot, _dot, reference_dot - _dot);
 
-	starpu_cublas_shutdown();
+	if (cublas_version >= 7050)
+		starpu_cublas_shutdown();
 
 #ifdef STARPU_USE_OPENCL
         ret = starpu_opencl_unload_opencl(&_opencl_program);