瀏覽代碼

julia: Update callback support.

Pierre Huchant 5 年之前
父節點
當前提交
382faf7937
共有 5 個文件被更改,包括 50 次插入和 19 次删除
  1. 6 1
      julia/src/callback_wrapper.c
  2. 1 1
      julia/src/globals.jl
  3. 9 0
      julia/src/init.jl
  4. 6 1
      julia/src/perfmodel.jl
  5. 28 16
      julia/src/task.jl

+ 6 - 1
julia/src/callback_wrapper.c

@@ -26,9 +26,14 @@ void julia_callback_func(void *user_data)
 {
   volatile int *signal = (int *) user_data;
 
-  // Wakeup callback
+  // wakeup callback
   *(signal) = 1;
 
   // Wait for callback to end.
   while ((*signal) != 0);
 }
+
+void julia_wait_signal(volatile int *signal)
+{
+  while ((*signal) == 0);
+}

+ 1 - 1
julia/src/globals.jl

@@ -46,4 +46,4 @@ global starpu_type_traduction_dict = Dict(
 )
 export starpu_type_traduction_dict
 
-global perfmodels = Vector{starpu_perfmodel}()
+global mutex = Threads.SpinLock()

+ 9 - 0
julia/src/init.jl

@@ -44,6 +44,8 @@ function starpu_init()
     global starpu_wrapper_library_handle= Libdl.dlopen(starpu_wrapper_library_name)
     output = starpu_init(C_NULL)
 
+    global task_pool = ThreadPools.QueuePool(2)
+
     starpu_enter_new_block()
 
     return output
@@ -57,5 +59,12 @@ function starpu_shutdown()
 
     starpu_exit_block()
     @starpucall starpu_shutdown Cvoid ()
+
+    lock(mutex)
+    empty!(perfmodel_list)
+    empty!(codelet_list)
+    empty!(task_list)
+    unlock(mutex)
+
     return nothing
 end

+ 6 - 1
julia/src/perfmodel.jl

@@ -13,6 +13,9 @@
 #
 # See the GNU Lesser General Public License in COPYING.LGPL for more details.
 #
+
+perfmodel_list = Vector{starpu_perfmodel}()
+
 function starpu_perfmodel(; perf_type::starpu_perfmodel_type, symbol::String)
     output = starpu_perfmodel(zero)
     output.type = perf_type
@@ -20,7 +23,9 @@ function starpu_perfmodel(; perf_type::starpu_perfmodel_type, symbol::String)
 
     # Performance models must not be garbage collected before starpu_shutdown
     # is called.
-    push!(perfmodels, output)
+    lock(mutex)
+    push!(perfmodel_list, output)
+    unlock(mutex)
 
     return output
 end

+ 28 - 16
julia/src/task.jl

@@ -24,6 +24,8 @@ struct jl_starpu_codelet
     modes
 end
 
+global codelet_list = Vector{jl_starpu_codelet}()
+
 function starpu_codelet(;
                         cpu_func :: String = "",
                         cuda_func :: String = "",
@@ -60,6 +62,11 @@ function starpu_codelet(;
     output.c_codelet.cuda_func = load_starpu_function_pointer(cuda_func)
     output.c_codelet.opencl_func = load_starpu_function_pointer(opencl_func)
 
+    # Codelets must not be garbage collected before starpu_shutdown is called.
+    lock(mutex)
+    push!(codelet_list, output)
+    unlock(mutex)
+
     return output
 end
 
@@ -76,6 +83,8 @@ mutable struct jl_starpu_task
     c_task :: starpu_task
 end
 
+task_list = Vector{jl_starpu_task}()
+
 """
             starpu_task(; cl :: jl_starpu_codelet, handles :: Vector{StarpuDataHandle}, cl_arg :: Ref)
 
@@ -132,6 +141,12 @@ function starpu_task(; cl :: Union{Cvoid, jl_starpu_codelet} = nothing, handles
         output.c_task.callback_func = load_wrapper_function_pointer("julia_callback_func")
     end
 
+    # Tasks must not be garbage collected before starpu_task_wait_for_all is called.
+    # This is necessary in particular for tasks created inside callback functions.
+    lock(mutex)
+    push!(task_list, output)
+    unlock(mutex)
+
     return output
 end
 
@@ -167,8 +182,6 @@ function starpu_task_submit(task :: jl_starpu_task)
         error("Invalid number of handles for task : $(length(task.handles)) where given while codelet has $(task.cl.modes) modes")
     end
 
-    # Prevent task from being garbage collected. This is necessary for tasks created
-    # inside callbacks.
     starpu_task_submit(Ref(task.c_task))
 
     if task.callback_function != nothing
@@ -176,24 +189,19 @@ function starpu_task_submit(task :: jl_starpu_task)
         callback_signal = task.callback_signal
         callback_function = task.callback_function
 
-        @qbthreads for x in 1:1
-            begin
-                # Active waiting loop
-                # We're doing a fake computation on tmp to prevent optimization.
-                tmp = 0
-                while task.callback_signal[1] == 0
-                    tmp += 1
-                end
+        lock(mutex)
+        put!(task_pool) do
 
-                # We've received the signal from the pthread, now execute the callback.
-                callback_function(callback_arg)
+            # Active waiting loop
+            @starpucall(julia_wait_signal, Cvoid, (Ptr{Cvoid},), Base.unsafe_convert(Ptr{Cvoid}, callback_signal))
 
-                # Tell the pthread that the callback is done.
-                callback_signal[1] = 0
+            # We've received the signal from the pthread, now execute the callback.
+            callback_function(callback_arg)
 
-                return callback_signal[1]
-            end
+            # Tell the pthread that the callback is done.
+            callback_signal[1] = 0
         end
+        unlock(mutex)
     end
 end
 
@@ -246,6 +254,10 @@ end
 function starpu_task_wait_for_all()
     @threadcall(@starpufunc(:starpu_task_wait_for_all),
                 Cint, ())
+
+    lock(mutex)
+    empty!(task_list)
+    unlock(mutex)
 end
 
 """