|
@@ -13,6 +13,8 @@
|
|
|
#
|
|
|
# See the GNU Lesser General Public License in COPYING.LGPL for more details.
|
|
|
#
|
|
|
+using ThreadPools
|
|
|
+
|
|
|
struct jl_starpu_codelet
|
|
|
c_codelet :: starpu_codelet
|
|
|
perfmodel :: starpu_perfmodel
|
|
@@ -68,6 +70,9 @@ mutable struct jl_starpu_task
|
|
|
handle_pointers :: Vector{StarpuDataHandlePointer}
|
|
|
synchronous :: Bool
|
|
|
cl_arg # type depends on codelet
|
|
|
+ callback_signal :: Vector{Cint}
|
|
|
+ callback_function :: Union{Cvoid, Function}
|
|
|
+ callback_arg
|
|
|
c_task :: starpu_task
|
|
|
end
|
|
|
|
|
@@ -76,13 +81,13 @@ end
|
|
|
|
|
|
Creates a new task which will run the specified codelet on handle buffers and cl_args data
|
|
|
"""
|
|
|
-function starpu_task(; cl :: Union{Cvoid, jl_starpu_codelet} = nothing, handles :: Vector{StarpuDataHandle} = StarpuDataHandle[], cl_arg = ())
|
|
|
-
|
|
|
+function starpu_task(; cl :: Union{Cvoid, jl_starpu_codelet} = nothing, handles :: Vector{StarpuDataHandle} = StarpuDataHandle[], cl_arg = (),
|
|
|
+ callback :: Union{Cvoid, Function} = nothing, callback_arg = nothing)
|
|
|
if (cl == nothing)
|
|
|
error("\"cl\" field can't be empty when creating a StarpuTask")
|
|
|
end
|
|
|
|
|
|
- output = jl_starpu_task(cl, handles, map((x -> x.object), handles), false, nothing, starpu_task(zero))
|
|
|
+ output = jl_starpu_task(cl, handles, map((x -> x.object), handles), false, nothing, Vector{Cint}(undef, 1), callback, callback_arg, starpu_task(zero))
|
|
|
|
|
|
# handle scalar_parameters
|
|
|
codelet_name = cl.cpu_func
|
|
@@ -119,6 +124,14 @@ function starpu_task(; cl :: Union{Cvoid, jl_starpu_codelet} = nothing, handles
|
|
|
output.c_task.cl_arg = Base.unsafe_convert(Ptr{Cvoid}, Ref(output.cl_arg))
|
|
|
output.c_task.cl_arg_size = sizeof(output.cl_arg)
|
|
|
end
|
|
|
+
|
|
|
+ # callback
|
|
|
+ if output.callback_function != nothing
|
|
|
+ output.callback_signal[1] = 0
|
|
|
+ output.c_task.callback_arg = Base.unsafe_convert(Ptr{Cvoid}, output.callback_signal)
|
|
|
+ output.c_task.callback_func = load_wrapper_function_pointer("julia_callback_func")
|
|
|
+ end
|
|
|
+
|
|
|
return output
|
|
|
end
|
|
|
|
|
@@ -154,7 +167,34 @@ function starpu_task_submit(task :: jl_starpu_task)
|
|
|
error("Invalid number of handles for task : $(length(task.handles)) where given while codelet has $(task.cl.modes) modes")
|
|
|
end
|
|
|
|
|
|
+ # Prevent task from being garbage collected. This is necessary for tasks created
|
|
|
+ # inside callbacks.
|
|
|
starpu_task_submit(Ref(task.c_task))
|
|
|
+
|
|
|
+ if task.callback_function != nothing
|
|
|
+ callback_arg = task.callback_arg
|
|
|
+ callback_signal = task.callback_signal
|
|
|
+ callback_function = task.callback_function
|
|
|
+
|
|
|
+ @qbthreads for x in 1:1
|
|
|
+ begin
|
|
|
+ # Active waiting loop
|
|
|
+ # We're doing a fake computation on tmp to prevent optimization.
|
|
|
+ tmp = 0
|
|
|
+ while task.callback_signal[1] == 0
|
|
|
+ tmp += 1
|
|
|
+ end
|
|
|
+
|
|
|
+ # We've received the signal from the pthread, now execute the callback.
|
|
|
+ callback_function(callback_arg)
|
|
|
+
|
|
|
+ # Tell the pthread that the callback is done.
|
|
|
+ callback_signal[1] = 0
|
|
|
+
|
|
|
+ return callback_signal[1]
|
|
|
+ end
|
|
|
+ end
|
|
|
+ end
|
|
|
end
|
|
|
|
|
|
function starpu_modes(x :: Symbol)
|
|
@@ -205,7 +245,7 @@ end
|
|
|
"""
|
|
|
function starpu_task_wait_for_all()
|
|
|
@threadcall(@starpufunc(:starpu_task_wait_for_all),
|
|
|
- Cint, ())
|
|
|
+ Cint, ())
|
|
|
end
|
|
|
|
|
|
"""
|