Pārlūkot izejas kodu

julia: Check cublas return status.

Pierre Huchant 5 gadi atpakaļ
vecāks
revīzija
ce213cfc8b
2 mainītis faili ar 9 papildinājumiem un 2 dzēšanām
  1. 8 1
      julia/src/compiler/cuda.jl
  2. 1 1
      julia/src/compiler/expressions.jl

+ 8 - 1
julia/src/compiler/cuda.jl

@@ -259,7 +259,14 @@ function translate_cublas(expr :: StarpuExpr)
 
 
         new_args = [@parse(starpu_cublas_get_local_handle()), x.args...]
         new_args = [@parse(starpu_cublas_get_local_handle()), x.args...]
 
 
-        return StarpuExprBlock([StarpuExprCall(blas_to_cublas[x.func][1], new_args),
+        status_varname = "status"*rand_string()
+        status_var = StarpuExprVar(Symbol("cublasStatus_t "*status_varname))
+        call_expr = StarpuExprCall(blas_to_cublas[x.func][1], new_args)
+
+        return StarpuExprBlock([StarpuExprAffect(status_var, call_expr),
+                                starpu_parse(Meta.parse("""if $status_varname != CUBLAS_STATUS_SUCCESS
+                                                              STARPU_CUBLAS_REPORT_ERROR($status_varname)
+                                                          end""")),
                                 @parse cudaStreamSynchronize(starpu_cuda_get_local_stream())])
                                 @parse cudaStreamSynchronize(starpu_cuda_get_local_stream())])
     end
     end
 
 

+ 1 - 1
julia/src/compiler/expressions.jl

@@ -253,7 +253,7 @@ function starpu_parse_call(x :: Expr)
 end
 end
 
 
 
 
-starpu_infix_operators = (:(+), :(*), :(-), :(/), :(<), :(>), :(<=), :(>=), :(%))
+starpu_infix_operators = (:(+), :(*), :(-), :(/), :(<), :(>), :(<=), :(>=), :(!=), :(%))
 
 
 
 
 function print_prefix(io :: IO, x :: StarpuExprCall ; indent = 0, restrict=false)
 function print_prefix(io :: IO, x :: StarpuExprCall ; indent = 0, restrict=false)