|
@@ -24,6 +24,8 @@ struct jl_starpu_codelet
|
|
|
modes
|
|
|
end
|
|
|
|
|
|
+global codelet_list = Vector{jl_starpu_codelet}()
|
|
|
+
|
|
|
function starpu_codelet(;
|
|
|
cpu_func :: String = "",
|
|
|
cuda_func :: String = "",
|
|
@@ -60,6 +62,11 @@ function starpu_codelet(;
|
|
|
output.c_codelet.cuda_func = load_starpu_function_pointer(cuda_func)
|
|
|
output.c_codelet.opencl_func = load_starpu_function_pointer(opencl_func)
|
|
|
|
|
|
+ # Codelets must not be garbage collected before starpu shutdown is called.
|
|
|
+ lock(mutex)
|
|
|
+ push!(codelet_list, output)
|
|
|
+ unlock(mutex)
|
|
|
+
|
|
|
return output
|
|
|
end
|
|
|
|
|
@@ -76,6 +83,8 @@ mutable struct jl_starpu_task
|
|
|
c_task :: starpu_task
|
|
|
end
|
|
|
|
|
|
+task_list = Vector{jl_starpu_task}()
|
|
|
+
|
|
|
"""
|
|
|
starpu_task(; cl :: jl_starpu_codelet, handles :: Vector{StarpuDataHandle}, cl_arg :: Ref)
|
|
|
|
|
@@ -132,6 +141,12 @@ function starpu_task(; cl :: Union{Cvoid, jl_starpu_codelet} = nothing, handles
|
|
|
output.c_task.callback_func = load_wrapper_function_pointer("julia_callback_func")
|
|
|
end
|
|
|
|
|
|
+ # Tasks must not be garbage collected before starpu_task_wait_for_all is called.
|
|
|
+ # This is necessary in particular for tasks created inside callback functions.
|
|
|
+ lock(mutex)
|
|
|
+ push!(task_list, output)
|
|
|
+ unlock(mutex)
|
|
|
+
|
|
|
return output
|
|
|
end
|
|
|
|
|
@@ -167,8 +182,6 @@ function starpu_task_submit(task :: jl_starpu_task)
|
|
|
error("Invalid number of handles for task : $(length(task.handles)) where given while codelet has $(task.cl.modes) modes")
|
|
|
end
|
|
|
|
|
|
- # Prevent task from being garbage collected. This is necessary for tasks created
|
|
|
- # inside callbacks.
|
|
|
starpu_task_submit(Ref(task.c_task))
|
|
|
|
|
|
if task.callback_function != nothing
|
|
@@ -176,24 +189,19 @@ function starpu_task_submit(task :: jl_starpu_task)
|
|
|
callback_signal = task.callback_signal
|
|
|
callback_function = task.callback_function
|
|
|
|
|
|
- @qbthreads for x in 1:1
|
|
|
- begin
|
|
|
- # Active waiting loop
|
|
|
- # We're doing a fake computation on tmp to prevent optimization.
|
|
|
- tmp = 0
|
|
|
- while task.callback_signal[1] == 0
|
|
|
- tmp += 1
|
|
|
- end
|
|
|
+ lock(mutex)
|
|
|
+ put!(task_pool) do
|
|
|
|
|
|
- # We've received the signal from the pthread, now execute the callback.
|
|
|
- callback_function(callback_arg)
|
|
|
+ # Active waiting loop
|
|
|
+ @starpucall(julia_wait_signal, Cvoid, (Ptr{Cvoid},), Base.unsafe_convert(Ptr{Cvoid}, callback_signal))
|
|
|
|
|
|
- # Tell the pthread that the callback is done.
|
|
|
- callback_signal[1] = 0
|
|
|
+ # We've received the signal from the pthread, now execute the callback.
|
|
|
+ callback_function(callback_arg)
|
|
|
|
|
|
- return callback_signal[1]
|
|
|
- end
|
|
|
+ # Tell the pthread that the callback is done.
|
|
|
+ callback_signal[1] = 0
|
|
|
end
|
|
|
+ unlock(mutex)
|
|
|
end
|
|
|
end
|
|
|
|
|
@@ -246,6 +254,10 @@ end
|
|
|
function starpu_task_wait_for_all()
|
|
|
@threadcall(@starpufunc(:starpu_task_wait_for_all),
|
|
|
Cint, ())
|
|
|
+
|
|
|
+ lock(mutex)
|
|
|
+ empty!(task_list)
|
|
|
+ unlock(mutex)
|
|
|
end
|
|
|
|
|
|
"""
|