|
@@ -51,6 +51,8 @@ static unsigned _entries_per_block = 1024;
|
|
|
static DOT_TYPE _dot = 0.0f;
|
|
|
static starpu_data_handle_t _dot_handle;
|
|
|
|
|
|
+static int cublas_version;
|
|
|
+
|
|
|
static int can_execute(unsigned workerid, struct starpu_task *task, unsigned nimpl)
|
|
|
{
|
|
|
enum starpu_worker_archtype type = starpu_worker_get_type(workerid);
|
|
@@ -250,14 +252,15 @@ void dot_cuda_func(void *descr[], void *cl_arg)
|
|
|
unsigned n = STARPU_VECTOR_GET_NX(descr[0]);
|
|
|
|
|
|
cudaMemcpyAsync(&current_dot, dot, sizeof(DOT_TYPE), cudaMemcpyDeviceToHost, starpu_cuda_get_local_stream());
|
|
|
+ cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
|
|
|
local_dot = (DOT_TYPE)cublasSdot(n, local_x, 1, local_y, 1);
|
|
|
|
|
|
/* FPRINTF(stderr, "current_dot %f local dot %f -> %f\n", current_dot, local_dot, current_dot + local_dot); */
|
|
|
- cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
current_dot += local_dot;
|
|
|
|
|
|
cudaMemcpyAsync(dot, &current_dot, sizeof(DOT_TYPE), cudaMemcpyHostToDevice, starpu_cuda_get_local_stream());
|
|
|
+ cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
}
|
|
|
#endif
|
|
|
|
|
@@ -318,7 +321,6 @@ static struct starpu_codelet dot_codelet =
|
|
|
.cpu_funcs_name = {"dot_cpu_func"},
|
|
|
#ifdef STARPU_USE_CUDA
|
|
|
.cuda_funcs = {dot_cuda_func},
|
|
|
- .cuda_flags = {STARPU_CUDA_ASYNC},
|
|
|
#endif
|
|
|
#ifdef STARPU_USE_OPENCL
|
|
|
.opencl_funcs = {dot_opencl_func},
|
|
@@ -352,7 +354,10 @@ int main(int argc, char **argv)
|
|
|
STARPU_CHECK_RETURN_VALUE(ret, "starpu_opencl_load_opencl_from_file");
|
|
|
#endif
|
|
|
|
|
|
- starpu_cublas_init();
|
|
|
+ /* cublasSdot has synchronization issues when using a non-blocking stream */
|
|
|
+ cublasGetVersion(&cublas_version);
|
|
|
+ if (cublas_version >= 7050)
|
|
|
+ starpu_cublas_init();
|
|
|
|
|
|
unsigned long nelems = _nblocks*_entries_per_block;
|
|
|
size_t size = nelems*sizeof(float);
|
|
@@ -419,7 +424,8 @@ int main(int argc, char **argv)
|
|
|
|
|
|
FPRINTF(stderr, "Reference : %e vs. %e (Delta %e)\n", reference_dot, _dot, reference_dot - _dot);
|
|
|
|
|
|
- starpu_cublas_shutdown();
|
|
|
+ if (cublas_version >= 7050)
|
|
|
+ starpu_cublas_shutdown();
|
|
|
|
|
|
#ifdef STARPU_USE_OPENCL
|
|
|
ret = starpu_opencl_unload_opencl(&_opencl_program);
|