task.jl 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. # StarPU --- Runtime system for heterogeneous multicore architectures.
  2. #
  3. # Copyright (C) 2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
  4. #
  5. # StarPU is free software; you can redistribute it and/or modify
  6. # it under the terms of the GNU Lesser General Public License as published by
  7. # the Free Software Foundation; either version 2.1 of the License, or (at
  8. # your option) any later version.
  9. #
  10. # StarPU is distributed in the hope that it will be useful, but
  11. # WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. #
  14. # See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. #
  16. using ThreadPools
  """
  Julia-side wrapper around a StarPU `starpu_codelet`.

  Besides the C structure handed to StarPU, it holds references to the Julia
  values the codelet was built from (perfmodel, implementation identifiers and
  buffer access modes).
  """
  mutable struct jl_starpu_codelet
      c_codelet   :: starpu_codelet              # structure passed to StarPU
      perfmodel   :: starpu_perfmodel            # object whose address is stored in c_codelet.model
      cpu_func    :: Union{String, STARPU_BLAS}  # CPU implementation: codelet name or BLAS identifier
      cuda_func   :: Union{String, STARPU_BLAS}  # CUDA implementation: codelet name or BLAS identifier
      opencl_func :: String                      # OpenCL implementation name (not wired into c_codelet)
      modes                                      # access mode of each buffer
  end

  # Every codelet built by starpu_codelet() is appended here so that it is never
  # garbage collected while StarPU may still dereference it.
  global codelet_list = jl_starpu_codelet[]
  """
      starpu_codelet(; cpu_func = "", cuda_func = "", opencl_func = "", modes = [],
                     perfmodel, where_to_execute = nothing, color = 0x00000000)

  Build a `jl_starpu_codelet` and register it in `codelet_list` so that it is not
  garbage collected while StarPU may still use it.

  # Keywords
  - `cpu_func`, `cuda_func`: name of a codelet registered in `CPU_CODELETS` /
    `CUDA_CODELETS`, a `STARPU_BLAS` identifier, or `nothing` when there is no
    implementation for that architecture. A name that is not registered
    (including `""`) yields the function pointer loaded for `""`.
  - `opencl_func`: stored on the wrapper only; the OpenCL function pointer is
    always the one loaded for `""`.
  - `modes`: access mode of each buffer, at most `STARPU_NMAXBUFS` entries.
  - `perfmodel`: performance model; its address is stored in the C codelet.
  - `where_to_execute`: explicit `where` mask. When `nothing`, the mask contains
    `STARPU_CPU` / `STARPU_CUDA` for each of `cpu_func` / `cuda_func` that is
    not `nothing`.
  - `color`: codelet color.
  """
  function starpu_codelet(;
                          cpu_func :: Union{String, STARPU_BLAS, Cvoid} = "",
                          cuda_func :: Union{String, STARPU_BLAS, Cvoid} = "",
                          opencl_func :: String = "",
                          modes = [],
                          perfmodel :: starpu_perfmodel,
                          where_to_execute :: Union{Cvoid, UInt32} = nothing,
                          color :: UInt32 = 0x00000000
                          )
      if length(modes) > STARPU_NMAXBUFS
          error("Codelet has too much buffers ($(length(modes)) but only $STARPU_NMAXBUFS are allowed)")
      end

      if where_to_execute === nothing
          real_where = ((cpu_func !== nothing) * STARPU_CPU) | ((cuda_func !== nothing) * STARPU_CUDA)
      else
          real_where = where_to_execute
      end

      # The wrapper fields cannot hold `nothing`; now that the `where` mask has
      # been computed, a missing implementation is looked up like "".
      cpu_impl = something(cpu_func, "")
      cuda_impl = something(cuda_func, "")

      output = jl_starpu_codelet(starpu_codelet(zero), perfmodel, cpu_impl, cuda_impl, opencl_func, modes)
      ## TODO: starpu_codelet_init
      output.c_codelet.where = real_where
      for (i, mode) in enumerate(modes)
          output.c_codelet.modes[i] = mode
      end
      output.c_codelet.nbuffers = length(modes)
      # StarPU only receives a raw address; `output.perfmodel` keeps the object referenced.
      output.c_codelet.model = pointer_from_objref(perfmodel)
      output.c_codelet.color = color

      if cpu_impl isa STARPU_BLAS
          output.cpu_func = cpu_blas_codelets[cpu_impl]
          output.c_codelet.cpu_func = load_wrapper_function_pointer(output.cpu_func)
      else
          output.c_codelet.cpu_func = load_starpu_function_pointer(get(CPU_CODELETS, cpu_impl, ""))
      end

      if cuda_impl isa STARPU_BLAS
          output.cuda_func = cuda_blas_codelets[cuda_impl]
          output.c_codelet.cuda_func = load_wrapper_function_pointer(output.cuda_func)
          output.c_codelet.cuda_flags[1] = STARPU_CUDA_ASYNC
      else
          output.c_codelet.cuda_func = load_starpu_function_pointer(get(CUDA_CODELETS, cuda_impl, ""))
      end

      output.c_codelet.opencl_func = load_starpu_function_pointer("")

      # Codelets must not be garbage collected before starpu shutdown is called.
      lock(mutex)
      push!(codelet_list, output)
      unlock(mutex)
      return output
  end
  """
  Julia-side wrapper around a StarPU `starpu_task`, built by `starpu_task`.

  Holds the C task structure together with the Julia values used to fill it in
  (codelet, handles, scalar arguments and callback state).
  """
  mutable struct jl_starpu_task
      cl                :: jl_starpu_codelet
      handles           :: Vector{StarpuDataHandle}
      handle_pointers   :: Vector{StarpuDataHandlePointer}  # `.object` of each handle, copied into c_task.handles
      synchronous       :: Bool
      cl_arg                                                # scalar arguments; the type depends on the codelet
      callback_signal   :: Vector{Cint}                     # one-element flag used to hand-shake with the C callback
      callback_function :: Union{Cvoid, Function}
      callback_arg
      c_task            :: starpu_task
  end

  # Tasks stay referenced here until starpu_task_wait_for_all() empties the list,
  # including tasks created from inside callback functions.
  task_list = jl_starpu_task[]
  """
      starpu_task(; cl, handles = StarpuDataHandle[], cl_arg = (), callback = nothing,
                  callback_arg = nothing, tag = nothing, tag_only = nothing,
                  sequential_consistency = true, detach = 1, color = nothing,
                  where = nothing)

  Create (without submitting) a task that runs codelet `cl` on the `handles`
  buffers with the scalar arguments `cl_arg`.

  The codelet's CPU function name, or failing that its CUDA function name, is
  looked up in `CODELETS_SCALARS`. When it is found, the number of scalars in
  `cl_arg` must match and they are packed with `create_param_struct_from_clarg`.
  Otherwise `cl_arg` is passed through unchanged.

  - `callback(callback_arg)` is run on the Julia side after the task completes
    (see `starpu_task_submit`).
  - `tag` sets the task tag id and `use_tag`; `tag_only` sets only the tag id.
  - `color` and `where` override the corresponding task fields when given.

  The task is kept in `task_list` until `starpu_task_wait_for_all` is called.
  """
  function starpu_task(;
                       cl :: Union{Cvoid, jl_starpu_codelet} = nothing,
                       handles :: Vector{StarpuDataHandle} = StarpuDataHandle[],
                       cl_arg = (),
                       callback :: Union{Cvoid, Function} = nothing,
                       callback_arg = nothing,
                       tag :: Union{Cvoid, starpu_tag_t} = nothing,
                       tag_only :: Union{Cvoid, starpu_tag_t} = nothing,
                       sequential_consistency = true,
                       detach = 1,
                       color :: Union{Cvoid, UInt32} = nothing,
                       where :: Union{Cvoid, Int32} = nothing)
      if cl === nothing
          error("\"cl\" field can't be empty when creating a StarpuTask")
      end

      output = jl_starpu_task(cl, handles, map((x -> x.object), handles), false, nothing, Vector{Cint}(undef, 1), callback, callback_arg, starpu_task(zero))

      # handle scalar_parameters: they are registered under the name of the
      # codelet function (CPU implementation first, then CUDA).
      codelet_name = ""
      if isa(cl.cpu_func, String) && cl.cpu_func != ""
          codelet_name = cl.cpu_func
      elseif isa(cl.cuda_func, String) && cl.cuda_func != ""
          codelet_name = cl.cuda_func
      end
      scalar_parameters = get(CODELETS_SCALARS, codelet_name, nothing)
      if scalar_parameters !== nothing
          nb_scalar_required = length(scalar_parameters)
          nb_scalar_provided = tuple_len(cl_arg)
          if nb_scalar_provided != nb_scalar_required
              error("$nb_scalar_provided scalar parameters provided but $nb_scalar_required are required by $codelet_name.")
          end
          output.cl_arg = create_param_struct_from_clarg(codelet_name, cl_arg)
      else
          output.cl_arg = cl_arg
      end

      starpu_task_init(Ref(output.c_task))
      output.c_task.cl = pointer_from_objref(cl.c_codelet)
      output.c_task.synchronous = false
      output.c_task.sequential_consistency = sequential_consistency
      output.c_task.detach = detach
      ## TODO: check num handles equals num codelet buffers (starpu_task_submit checks it)
      for (i, handle_pointer) in enumerate(output.handle_pointers)
          output.c_task.handles[i] = handle_pointer
      end
      if tuple_len(cl_arg) > 0
          # NOTE(review): the Ref created here is not stored anywhere, so nothing
          # visibly keeps the memory behind c_task.cl_arg rooted after this
          # function returns — confirm it cannot be collected before the task runs.
          output.c_task.cl_arg = Base.unsafe_convert(Ptr{Cvoid}, Ref(output.cl_arg))
          output.c_task.cl_arg_size = sizeof(output.cl_arg)
      end

      # callback: the C-side julia_callback_func receives the address of
      # callback_signal, which starpu_task_submit's waiting job polls and resets.
      if output.callback_function !== nothing
          output.callback_signal[1] = 0
          output.c_task.callback_arg = Base.unsafe_convert(Ptr{Cvoid}, output.callback_signal)
          output.c_task.callback_func = load_wrapper_function_pointer("julia_callback_func")
      end

      if tag !== nothing
          output.c_task.tag_id = tag
          output.c_task.use_tag = 1
      end
      if tag_only !== nothing
          output.c_task.tag_id = tag_only
      end
      if color !== nothing
          output.c_task.color = color
      end
      if where !== nothing
          output.c_task.where = where
      end

      # Tasks must not be garbage collected before starpu_task_wait_for_all is called.
      # This is necessary in particular for tasks created inside callback functions.
      lock(mutex)
      push!(task_list, output)
      unlock(mutex)
      return output
  end
  """
      create_param_struct_from_clarg(codelet_name, cl_arg)

  Build the scalar-parameter structure registered for `codelet_name` in
  `CODELETS_PARAMS_STRUCT`, passing the values of `cl_arg` positionally to its
  constructor. Throws an error when no structure is registered for the codelet.
  """
  function create_param_struct_from_clarg(codelet_name, cl_arg)
      struct_params_name = get(CODELETS_PARAMS_STRUCT, codelet_name, false)
      if struct_params_name == false
          error("structure name not found in CODELETS_PARAMS_STRUCT for codelet $codelet_name")
      end
      # Only the structure *name* is parsed and resolved in this module. The
      # argument values are handed to the constructor as-is rather than printed
      # into source text, so they keep their exact value and type, and no
      # module-level variable is written (this may run from several callback
      # threads at once).
      struct_type = eval(Meta.parse(string(struct_params_name)))
      # The structure may have been defined at runtime after the caller was
      # compiled, so call its constructor in the latest world.
      return Base.invokelatest(struct_type, cl_arg...)
  end
  """
      starpu_task_submit(task::jl_starpu_task)

  Submit `task` for execution. When the C task's `synchronous` field is false
  (which is how `starpu_task` creates tasks), the call returns without waiting
  for the task to run.

  If the task has a Julia callback, a job is queued on `task_pool`. The job
  waits on `task.callback_signal`, runs `callback_function(callback_arg)`, then
  resets the signal to 0 to tell the waiting C thread the callback is done.

  Throws an error if the number of handles differs from the number of access
  modes declared by the task's codelet.
  """
  function starpu_task_submit(task :: jl_starpu_task)
      nb_handles = length(task.handles)
      nb_modes = length(task.cl.modes)
      if nb_handles != nb_modes
          error("Invalid number of handles for task : $nb_handles were given while codelet has $nb_modes modes")
      end
      # NOTE(review): the status returned by the low-level submit is discarded; if
      # submission can fail, the callback job queued below would wait forever.
      starpu_task_submit(Ref(task.c_task))
      if task.callback_function !== nothing
          callback_arg = task.callback_arg
          callback_signal = task.callback_signal
          callback_function = task.callback_function
          lock(mutex)
          try
              put!(task_pool) do
                  # Active waiting loop
                  @starpucall(julia_wait_signal, Cvoid, (Ptr{Cvoid},), Base.unsafe_convert(Ptr{Cvoid}, callback_signal))
                  # We've received the signal from the pthread, now execute the callback.
                  callback_function(callback_arg)
                  # Tell the pthread that the callback is done.
                  callback_signal[1] = 0
              end
          finally
              # Always released, so a failure to queue the job cannot leave every
              # later codelet/task creation blocked on the mutex.
              unlock(mutex)
          end
      end
      return nothing
  end
  """
      starpu_modes(x::Symbol)

  Map an access-mode symbol to its StarPU constant: `:STARPU_R` gives `STARPU_R`
  and `:STARPU_RW` gives `STARPU_RW`. Any other symbol gives `STARPU_W`.
  """
  function starpu_modes(x :: Symbol)
      x == :STARPU_R && return STARPU_R
      x == :STARPU_RW && return STARPU_RW
      return STARPU_W
  end
  # Caches filled on demand when starpu_task_insert is called without an explicit
  # perfmodel or codelet.
  default_perfmodel = Dict{String, starpu_perfmodel}()
  default_codelet = Dict{String, jl_starpu_codelet}()

  """
      get_default_perfmodel(name)

  Return the history-based perfmodel cached under `name`. On first use it is
  created with `symbol = name` and stored in the cache.
  """
  function get_default_perfmodel(name)
      return get!(default_perfmodel, name) do
          starpu_perfmodel(
              perf_type = starpu_perfmodel_type(STARPU_HISTORY_BASED),
              symbol = name
          )
      end
  end
  """
      get_default_codelet(codelet_name, perfmodel, modes) -> jl_starpu_codelet

  Return the codelet cached under `codelet_name`. On first use it is built with
  `starpu_codelet`, using `codelet_name` as CPU and/or CUDA implementation when
  it is registered in `CPU_CODELETS` / `CUDA_CODELETS`. Once a codelet is cached,
  later calls return it regardless of the `perfmodel` and `modes` passed.
  """
  function get_default_codelet(codelet_name, perfmodel, modes) :: jl_starpu_codelet
      return get!(default_codelet, codelet_name) do
          starpu_codelet(
              cpu_func = haskey(CPU_CODELETS, codelet_name) ? codelet_name : "",
              cuda_func = haskey(CUDA_CODELETS, codelet_name) ? codelet_name : "",
              modes = modes,
              perfmodel = perfmodel,
          )
      end
  end
  """
      starpu_task_insert(; codelet_name = nothing, cl = nothing, perfmodel = nothing,
                         handles = StarpuDataHandle[], cl_arg = (), callback = nothing,
                         callback_arg = nothing, tag = nothing, tag_only = nothing,
                         sequential_consistency = true, detach = 1, where = nothing,
                         color = nothing, modes = nothing)

  Create a task with `starpu_task` and submit it right away, returning the
  result of `starpu_task_submit`.

  At least one of `cl` and `codelet_name` is required. When `cl` is omitted,
  `modes` is required too and the codelet comes from `get_default_codelet`.
  When `perfmodel` is omitted, the cached model from `get_default_perfmodel`
  (named after `codelet_name`, or `"default"`) is used. It only ends up in a
  codelet built here, never in a `cl` passed in.
  """
  function starpu_task_insert(;
                              codelet_name :: Union{Cvoid, String} = nothing,
                              cl :: Union{Cvoid, jl_starpu_codelet} = nothing,
                              perfmodel :: Union{starpu_perfmodel, Cvoid} = nothing,
                              handles :: Vector{StarpuDataHandle} = StarpuDataHandle[],
                              cl_arg = (),
                              callback :: Union{Cvoid, Function} = nothing,
                              callback_arg = nothing,
                              tag :: Union{Cvoid, starpu_tag_t} = nothing,
                              tag_only :: Union{Cvoid, starpu_tag_t} = nothing,
                              sequential_consistency = true,
                              detach = 1,
                              where :: Union{Cvoid, Int32} = nothing,
                              color :: Union{Cvoid, UInt32} = nothing,
                              modes = nothing)
      if isnothing(cl)
          isnothing(codelet_name) && error("At least one of the two parameters codelet_name or cl must be provided when calling starpu_task_insert.")
          isnothing(modes) && error("Modes must be defined when calling starpu_task_insert without a codelet.")
      end

      # The default perfmodel is fetched (and cached) even when `cl` is given.
      if isnothing(perfmodel)
          perfmodel = get_default_perfmodel(something(codelet_name, "default"))
      end
      if isnothing(cl)
          cl = get_default_codelet(codelet_name, perfmodel, modes)
      end

      task = starpu_task(cl = cl, handles = handles, cl_arg = cl_arg,
                         callback = callback, callback_arg = callback_arg,
                         tag = tag, tag_only = tag_only,
                         sequential_consistency = sequential_consistency,
                         detach = detach, color = color, where = where)
      return starpu_task_submit(task)
  end
  """
      @starpu_async_cl func(handle1, handle2, ...) [MODE1, MODE2, ...] [cl_arg] [color]

  Create and submit an asynchronous task running codelet function `func` on the
  given handles. There is one access mode (`STARPU_R`, `STARPU_W` or `STARPU_RW`)
  per handle, written as a vector literal.

  A new codelet is built each time the macro is expanded. Its CPU and CUDA
  implementations are both named `func`, and it uses a history-based perfmodel
  with symbol `"history_perf"`.

  Ex : @starpu_async_cl cl(handle1, handle2) [STARPU_R, STARPU_RW]
  """
  macro starpu_async_cl(expr, modes, cl_arg=(), color ::UInt32=0x00000000)
      if !isa(expr, Expr) || expr.head != :call
          error("Invalid task submit syntax")
      end
      # `modes` must be a vector literal such as [STARPU_R, STARPU_RW].
      if !isa(modes, Expr) || modes.head != :vect
          error("Invalid task submit syntax")
      end
      func_name = string(expr.args[1])
      perfmodel = starpu_perfmodel(
          perf_type = starpu_perfmodel_type(STARPU_HISTORY_BASED),
          symbol = "history_perf"
      )
      cl = starpu_codelet(
          cpu_func = func_name,
          cuda_func = func_name,
          #opencl_func="ocl_matrix_mult",
          ### TODO: CORRECT !
          modes = map(starpu_modes, modes.args),
          perfmodel = perfmodel,
          color = color
      )
      # The call arguments become a vector literal of handles.
      handles = Expr(:vect, expr.args[2:end]...)
      return quote
          task = starpu_task(cl = $(esc(cl)), handles = $(esc(handles)), cl_arg=$(esc(cl_arg)))
          starpu_task_submit(task)
      end
  end
  """
      starpu_task_wait(task::jl_starpu_task)

  Block until `task` has terminated and return the `Cint` result of the C
  `starpu_task_wait`. The call goes through `@threadcall`, so it runs on another
  thread and the current Julia thread is not blocked inside C.
  """
  function starpu_task_wait(task :: jl_starpu_task)
      status = @threadcall(@starpufunc(:starpu_task_wait),
                           Cint, (Ptr{Cvoid},), Ref(task.c_task))
      return status
  end
  """
      starpu_task_wait_for_all()

  Block until every submitted task has finished, then empty `task_list` so the
  Julia task wrappers it kept referenced can be garbage collected.
  """
  function starpu_task_wait_for_all()
      # @threadcall runs the blocking C call on another thread rather than
      # blocking the current Julia thread inside C.
      @threadcall(@starpufunc(:starpu_task_wait_for_all), Cint, ())
      # Nothing is running any more: drop the references kept since creation.
      lock(mutex)
      empty!(task_list)
      unlock(mutex)
  end
  """
      @starpu_sync_tasks expr

  Evaluate `expr`, then block until every submitted task has finished by calling
  `starpu_task_wait_for_all`. This also waits for tasks submitted outside `expr`.

  Ex : @starpu_sync_tasks begin
      [...]
      starpu_task_submit(task)
      [...]
  end

  TODO : Make the macro only wait for tasks declared inside the following expression.
  (similar mechanism as @starpu_block)
  """
  macro starpu_sync_tasks(expr)
      user_block = esc(expr)
      return quote
          $user_block
          starpu_task_wait_for_all()
      end
  end
  """
      starpu_task_destroy(task::jl_starpu_task)

  Forward the C structure of `task` to the low-level `starpu_task_destroy` method.

  NOTE(review): `c_task` is Julia-owned memory initialised with `starpu_task_init`,
  not allocated by `starpu_task_create`. Confirm that the C `starpu_task_destroy`
  (rather than `starpu_task_clean`) is safe to call on such a task.
  """
  starpu_task_destroy(task :: jl_starpu_task) = starpu_task_destroy(Ref(task.c_task))