Forráskód Böngészése

julia: Add support for views in kernel codes.

Pierre Huchant 5 éve
szülő
commit
bd74e8a01f
2 módosított fájl, 43 hozzáadás és 18 törlés
  1. 20 3
      julia/src/compiler/c.jl
  2. 23 15
      julia/src/compiler/expressions.jl

+ 20 - 3
julia/src/compiler/c.jl

@@ -73,6 +73,7 @@ function transform_to_cpu_kernel(expr :: StarpuExprFunction)
     output = add_for_loop_declarations(expr)
     output = substitute_args(output)
     output = substitute_func_calls(output)
+    output = substitute_views(output)
     output = substitute_indexing(output)
     output = flatten_blocks(output)
 
@@ -212,9 +213,9 @@ function substitute_args(expr :: StarpuExprFunction)
 
 
     new_args = [
-                    starpu_parse(:($buffer_arg_name :: Matrix{Nothing})),
-                    starpu_parse(:($cl_arg_name :: Vector{Nothing}))
-                ]
+        starpu_parse(:($buffer_arg_name :: Ptr{Ptr{Nothing}})),
+        starpu_parse(:($cl_arg_name :: Vector{Nothing}))
+    ]
     new_body = StarpuExprBlock([function_start_affectations..., new_body.exprs...])
 
     return StarpuExprFunction(expr.ret_type, expr.func, new_args, new_body)
@@ -243,6 +244,22 @@ function substitute_func_calls(expr :: StarpuExpr)
     return apply(func_to_apply, expr)
 end
 
+function substitute_views(expr :: StarpuExpr)
+    function func_to_apply(x :: StarpuExpr)
+
+        if !isa(x, StarpuExprCall) || x.func != :view
+            return x
+        end
+
+        ref = x.args[1]
+        indexes = map(i -> isa(i, StarpuExprInterval) ? i.start : i, x.args[2:end])
+
+        return StarpuExprAddress(StarpuExprRef(ref, indexes))
+    end
+
+    return apply(func_to_apply, expr)
+
+end
 
 function substitute_indexing(expr :: StarpuExpr)
 

+ 23 - 15
julia/src/compiler/expressions.jl

@@ -124,6 +124,9 @@ struct StarpuExprWhile <: StarpuExpr
     body :: StarpuExpr
 end
 
+struct StarpuExprAddress <: StarpuExpr
+    ref :: StarpuExpr
+end
 
 function starpu_parse_affect(x :: Expr)
 
@@ -296,7 +299,6 @@ function apply(func :: Function, expr :: StarpuExprCall)
     return func(StarpuExprCall(expr.func, map((x -> apply(func, x)), expr.args)))
 end
 
-
 #======================================================
                 CUDA KERNEL CALL
 ======================================================#
@@ -734,8 +736,6 @@ function print(io :: IO, x :: StarpuExprRef ; indent = 0,restrict=false)
 
 end
 
-
-
 function apply(func :: Function, expr :: StarpuExprRef)
 
     ref = apply(func, expr.ref)
@@ -744,6 +744,16 @@ function apply(func :: Function, expr :: StarpuExprRef)
     return func(StarpuExprRef(ref, indexes))
 end
 
+function print(io :: IO, x :: StarpuExprAddress ; indent = 0, restrict=false)
+    print(io, "&")
+    print(io, x.ref, indent = indent)
+end
+
+function apply(func :: Function, expr :: StarpuExprAddress)
+    ref = apply(func, expr.ref)
+    return func(StarpuExprAddress(ref))
+end
+
 #======================================================
                 BREAK EXPRESSION
 ======================================================#
@@ -799,7 +809,7 @@ function apply(func :: Function, expr :: StarpuExpr)
     return func(expr)
 end
 
-print(io :: IO, x :: StarpuExprVar ; indent = 0) = print(io, x.name)
+print(io :: IO, x :: StarpuExprVar ; indent = 0, restrict = false) = print(io, x.name)
 
 function print(io :: IO, x :: StarpuExprValue ; indent = 0,restrict=false)
 
@@ -869,26 +879,24 @@ end
 
 function starpu_type_traduction(x)
     if x <: Array
-        return starpu_type_traduction_array(x)
+        return starpu_type_traduction(eltype(x)) * "*"
     end
 
     if x <: Ptr
-        return starpu_type_traduction(eltype(x)) * "*"
+        depth = 1
+        type = eltype(x)
+        while type <: Ptr
+            depth +=1
+            type = eltype(type)
+        end
+
+        return starpu_type_traduction(type) * "*"^depth
     end
 
     return starpu_type_traduction_dict[x]
 
 end
 
-function starpu_type_traduction_array(x :: Type{Array{T,N}})  where {T,N}
-    output = starpu_type_traduction(T)
-    for i in (1 : N)
-        output *= "*"
-    end
-
-    return output
-end
-
 function print(io :: IO, x :: StarpuExprTyped ; indent = 0,restrict=false)
 
     if (isa(x, StarpuExprTypedVar))