cuda.jl 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. # StarPU --- Runtime system for heterogeneous multicore architectures.
  2. #
  3. # Copyright (C) 2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
  4. #
  5. # StarPU is free software; you can redistribute it and/or modify
  6. # it under the terms of the GNU Lesser General Public License as published by
  7. # the Free Software Foundation; either version 2.1 of the License, or (at
  8. # your option) any later version.
  9. #
  10. # StarPU is distributed in the hope that it will be useful, but
  11. # WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. #
  14. # See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. #
  16. function is_indep_for_expr(x :: StarpuExpr)
  17. return isa(x, StarpuExprFor) && x.is_independant
  18. end
  19. function extract_init_indep_finish(expr :: StarpuExpr) # TODO : it is not a correct extraction (example : if (cond) {@indep for ...} else {return} would not work)
  20. # better use apply() (NOTE :assert_no_indep_for already exists) to find recursively every for loops
  21. init = StarpuExpr[]
  22. finish = StarpuExpr[]
  23. if is_indep_for_expr(expr)
  24. return init, StarpuIndepFor(expr), finish
  25. end
  26. if !isa(expr, StarpuExprBlock)
  27. return [expr], nothing, finish
  28. end
  29. for i in (1 : length(expr.exprs))
  30. if !is_indep_for_expr(expr.exprs[i])
  31. continue
  32. end
  33. init = expr.exprs[1 : i-1]
  34. indep = StarpuIndepFor(expr.exprs[i])
  35. finish = expr.exprs[i+1 : end]
  36. if any(is_indep_for_expr, finish)
  37. error("Sequence of several independant loops is not allowed") #same it may be tricked by a Block(Indep_for(...))
  38. end
  39. return init, indep, finish
  40. end
  41. return expr.exprs, nothing, finish
  42. end
  43. function analyse_variable_declarations(expr :: StarpuExpr, already_defined :: Vector{StarpuExprTypedVar} = StarpuExprTypedVar[])
  44. undefined_variables = Symbol[]
  45. defined_variable_names = map((x -> x.name), already_defined)
  46. defined_variable_types = map((x -> x.typ), already_defined)
  47. function func_to_apply(x :: StarpuExpr)
  48. if isa(x, StarpuExprFunction)
  49. error("No function declaration allowed in this section")
  50. end
  51. if isa(x, StarpuExprVar) || isa(x, StarpuExprTypedVar)
  52. if !(x.name in defined_variable_names) && !(x.name in undefined_variables)
  53. push!(undefined_variables, x.name)
  54. end
  55. return x
  56. end
  57. if isa(x, StarpuExprAffect) || isa(x, StarpuExprFor)
  58. if isa(x, StarpuExprAffect)
  59. var = x.var
  60. if !isa(var, StarpuExprTypedVar)
  61. return x
  62. end
  63. name = var.name
  64. typ = var.typ
  65. else
  66. name = x.iter
  67. typ = Int64
  68. end
  69. if name in defined_variable_names
  70. error("Multiple definition of variable $name")
  71. end
  72. filter!((sym -> sym != name), undefined_variables)
  73. push!(defined_variable_names, name)
  74. push!(defined_variable_types, typ)
  75. return x
  76. end
  77. return x
  78. end
  79. apply(func_to_apply, expr)
  80. defined_variable = map(StarpuExprTypedVar, defined_variable_names, defined_variable_types)
  81. return defined_variable, undefined_variables
  82. end
  83. function find_variable(name :: Symbol, vars :: Vector{StarpuExprTypedVar})
  84. for x in vars
  85. if x.name == name
  86. return x
  87. end
  88. end
  89. return nothing
  90. end
  91. function add_device_to_interval_call(expr :: StarpuExpr)
  92. function func_to_apply(x :: StarpuExpr)
  93. if isa(x, StarpuExprCall) && x.func == :jlstarpu_interval_size
  94. return StarpuExprCall(:jlstarpu_interval_size__device, x.args)
  95. end
  96. return x
  97. end
  98. return apply(func_to_apply, expr)
  99. end
  100. function translate_cublas(expr :: StarpuExpr)
  101. function func_to_run(x :: StarpuExpr)
  102. # STARPU_BLAS => (CUBLAS, TRANS, FILLMODE, ALPHA, SIDE, DIAG)
  103. blas_to_cublas = Dict(:STARPU_SGEMM => (:cublasSgemm, [1, 2], [], [6, 11], [], []),
  104. :STARPU_DGEMM => (:cublasDgemm, [1, 2], [], [6, 11], [], []),
  105. :STARPU_SGEMV => (:cublasSgemv, [1], [], [4,9], [], []),
  106. :STARPU_DGEMV => (:cublasDgemv, [1], [], [4,9], [], []),
  107. :STARPU_SSCAL => (:cublasSscal, [], [], [2], [], []),
  108. :STARPU_DSCAL => (:cublasDscal, [], [], [2], [], []),
  109. :STARPU_STRSM => (:cublasStrsm, [3], [2], [7], [1], [4]),
  110. :STARPU_DTRSM => (:cublasDtrsm, [3], [2], [7], [1], [4]),
  111. :STARPU_SSYR => (:cublasSsyr, [], [1], [3], [], []),
  112. :STARPU_SSYRK => (:cublasSsyrk, [2], [1], [5,8], [], []),
  113. :STARPU_SGER => (:cublasSger, [], [], [3], [], []),
  114. :STARPU_DGER => (:cublasDger, [], [], [3], [], []),
  115. :STARPU_STRSV => (:cublasStrsv, [2], [1], [], [], [3]),
  116. :STARPU_STRMM => (:cublasStrmm, [3], [2], [7], [1], [4]),
  117. :STARPU_DTRMM => (:cublasDtrmm, [3], [2], [7], [1], [4]),
  118. :STARPU_STRMV => (:cublasStrmv, [2], [1], [], [], [3]),
  119. :STARPU_SAXPY => (:cublasSaxpy, [], [], [2], [], []),
  120. :STARPU_DAXPY => (:cublasDaxpy, [], [], [2], [], []),
  121. :STARPU_SSWAP => (:cublasSswap, [], [], [], [], []),
  122. :STARPU_DSWAP => (:cublasDswap, [], [], [], [], []))
  123. if !(isa(x, StarpuExprCall) && x.func in keys(blas_to_cublas))
  124. return x
  125. end
  126. new_args = x.args
  127. # cublasOperation_t parameters (e.g. StarpuExprValue("N")
  128. for i in blas_to_cublas[x.func][2]
  129. if !isa(new_args[i], StarpuExprValue) || !isa(new_args[i].value, String)
  130. error("Argument $i of ", x.func, " must be a string")
  131. end
  132. value = new_args[i].value
  133. if value == "N" || value == "n"
  134. new_args[i] = StarpuExprVar(:CUBLAS_OP_N)
  135. elseif value == "T" || value == "t"
  136. new_args[i] = StarpuExprVar(:CUBLAS_OP_T)
  137. elseif value == "C" || value == "c"
  138. new_args[i] = StarpuExprVar(:CUBLAS_OP_C)
  139. else
  140. error("Unhandled value for rgument $i of ", x.func, ": ", value,
  141. "expecting (\"N\", \"T\", or \"C\")")
  142. end
  143. end
  144. # cublasFillMode_t parameters (e.g. StarpuExprValue("L")
  145. for i in blas_to_cublas[x.func][3]
  146. if !isa(new_args[i], StarpuExprValue) || !isa(new_args[i].value, String)
  147. error("Argument $i of ", x.func, " must be a string")
  148. end
  149. value = new_args[i].value
  150. if value == "L" || value == "l"
  151. new_args[i] = StarpuExprVar(:CUBLAS_FILL_MODE_LOWER)
  152. elseif value == "U" || value == "u"
  153. new_args[i] = StarpuExprVar(:CUBLAS_FILL_MODE_UPPER)
  154. else
  155. error("Unhandled value for rgument $i of ", x.func, ": ", value,
  156. "expecting (\"L\" or \"U\")")
  157. end
  158. end
  159. # scalar parameters (alpha, beta, ...): alpha -> &alpha
  160. for i in blas_to_cublas[x.func][4]
  161. if !isa(new_args[i], StarpuExprVar)
  162. error("Argument $i of ", x.func, " must be a variable")
  163. end
  164. var_name = new_args[i].name
  165. new_args[i] = StarpuExprVar(Symbol("&$var_name"))
  166. end
  167. # cublasSideMode_t parameters (e.g. StarpuExprValue("L")
  168. for i in blas_to_cublas[x.func][5]
  169. if !isa(new_args[i], StarpuExprValue) || !isa(new_args[i].value, String)
  170. error("Argument $i of ", x.func, " must be a string, got: ", new_args[i])
  171. end
  172. value = new_args[i].value
  173. if value == "L" || value == "l"
  174. new_args[i] = StarpuExprVar(:CUBLAS_SIDE_LEFT)
  175. elseif value == "R" || value == "r"
  176. new_args[i] = StarpuExprVar(:CUBLAS_SIDE_RIGHT)
  177. else
  178. error("Unhandled value for rgument $i of ", x.func, ": ", value,
  179. "expecting (\"L\" or \"R\")")
  180. end
  181. end
  182. # cublasDiag_Typet parameters (e.g. StarpuExprValue("N")
  183. for i in blas_to_cublas[x.func][6]
  184. if !isa(new_args[i], StarpuExprValue) || !isa(new_args[i].value, String)
  185. error("Argument $i of ", x.func, " must be a string")
  186. end
  187. value = new_args[i].value
  188. if value == "N" || value == "n"
  189. new_args[i] = StarpuExprVar(:CUBLAS_DIAG_NON_UNIT)
  190. elseif value == "U" || value == "u"
  191. new_args[i] = StarpuExprVar(:CUBLAS_DIAG_UNIT)
  192. else
  193. error("Unhandled value for rgument $i of ", x.func, ": ", value,
  194. "expecting (\"N\" or \"U\")")
  195. end
  196. end
  197. new_args = [@parse(starpu_cublas_get_local_handle()), x.args...]
  198. status_varname = "status"*rand_string()
  199. status_var = StarpuExprVar(Symbol("cublasStatus_t "*status_varname))
  200. call_expr = StarpuExprCall(blas_to_cublas[x.func][1], new_args)
  201. return StarpuExprBlock([StarpuExprAffect(status_var, call_expr),
  202. starpu_parse(Meta.parse("""if $status_varname != CUBLAS_STATUS_SUCCESS
  203. STARPU_CUBLAS_REPORT_ERROR($status_varname)
  204. end""")),
  205. @parse cudaStreamSynchronize(starpu_cuda_get_local_stream())])
  206. end
  207. return apply(func_to_run, expr)
  208. end
  209. function get_all_assignments(cpu_instr)
  210. ret = StarpuExpr[]
  211. function func_to_run(x :: StarpuExpr)
  212. if isa(x, StarpuExprAffect)
  213. push!(ret, x)
  214. end
  215. return x
  216. end
  217. apply(func_to_run, cpu_instr)
  218. return ret
  219. end
  220. function get_all_buffer_vars(cpu_instr)
  221. ret = StarpuExprTypedVar[]
  222. assignments = get_all_assignments(cpu_instr)
  223. for x in assignments
  224. var = x.var
  225. expr = x.expr
  226. if isa(expr, StarpuExprCall) && expr.func in [:STARPU_MATRIX_GET_PTR, :STARPU_VECTOR_GET_PTR]
  227. push!(ret, var)
  228. end
  229. end
  230. return ret
  231. end
  232. function get_all_buffer_stores(cpu_instr, vars)
  233. ret = StarpuExprAffect[]
  234. function func_to_run(x :: StarpuExpr)
  235. if isa(x, StarpuExprAffect) && isa(x.var, StarpuExprRef) && isa(x.var.ref, StarpuExprVar) &&
  236. x.var.ref.name in map(x -> x.name, vars)
  237. push!(ret, x)
  238. end
  239. return x
  240. end
  241. apply(func_to_run, cpu_instr)
  242. return ret
  243. end
  244. function get_all_buffer_refs(cpu_instr, vars)
  245. ret = []
  246. current_instr = nothing
  247. InstrTy = Union{StarpuExprAffect,
  248. StarpuExprCall,
  249. StarpuExprCudaCall,
  250. StarpuExprFor,
  251. StarpuExprIf,
  252. StarpuExprIfElse,
  253. StarpuExprReturn,
  254. StarpuExprBreak,
  255. StarpuExprWhile}
  256. parent = nothing
  257. function func_to_run(x :: StarpuExpr)
  258. if isa(x, InstrTy) && !(isa(x, StarpuExprCall) && x.func in [:(+), :(-), :(*), :(/), :(%), :(<), :(<=), :(==), :(!=), :(>=), :(>), :sqrt])
  259. current_instr = x
  260. end
  261. if isa(x, StarpuExprRef) && isa(x.ref, StarpuExprVar) && x.ref.name in map(x -> x.name, vars) && # var[...]
  262. !isa(parent, StarpuExprAddress) && # filter &var[..]
  263. !(isa(current_instr, StarpuExprAffect) && current_instr.var == x) # filter lhs ref
  264. push!(ret, (current_instr, x))
  265. end
  266. parent = x
  267. return x
  268. end
  269. visit_preorder(func_to_run, cpu_instr)
  270. return ret
  271. end
"""
    transform_cuda_device_loadstore(cpu_instr)

Rewrite host-side reads and writes of CUDA buffer elements in `cpu_instr` into explicit
`cudaMemcpy` transfers through fresh host variables, and return the rewritten block.

Buffers are the variables assigned from `STARPU_MATRIX_GET_PTR` / `STARPU_VECTOR_GET_PTR`
(found by `get_all_buffer_vars`). Each replacement is a nested `StarpuExprBlock`;
`transform_to_cuda_kernel` flattens the result afterwards with `flatten_blocks`.

NOTE(review): `loads` and `stores` hold nodes of the tree as it was before any rewrite.
Two cases may go wrong:
- one instruction contains several buffer loads;
- a store's right-hand side also loads from a buffer.
In both cases an earlier `substitute` has already replaced that node with a rewritten
copy. A later `substitute(cpu_instr, instr, ...)` / `substitute(cpu_instr, s, ...)` may
then fail to find it, depending on how `substitute` compares nodes — confirm.
"""
function transform_cuda_device_loadstore(cpu_instr :: StarpuExprBlock)
    # Get all CUDA buffer pointers
    buffer_vars = get_all_buffer_vars(cpu_instr)
    # Buffer name => declared Julia type; its eltype gives the element type to transfer.
    buffer_types = Dict{Symbol, Type}()
    for var in buffer_vars
        buffer_types[var.name] = var.typ
    end
    # Get all store to a CUDA buffer
    stores = get_all_buffer_stores(cpu_instr, buffer_vars)
    # Get all load from CUDA buffer
    loads = get_all_buffer_refs(cpu_instr, buffer_vars)
    # Replace each load L:
    # L: ... buffer[id]
    # With the following instruction block:
    # Type varX
    # cudaMemcpy(&varX, &buffer[id], sizeof(Type), cudaMemcpyDeviceToHost)
    # L: ... varX
    for l in loads
        (instr, ref) = l
        block = []
        buffer = ref.ref.name
        # Fresh host variable name (uniqueness presumably relies on rand_string — confirm).
        varX = "var"*rand_string()
        type = buffer_types[Symbol(buffer)]
        # C spelling of the element type, used in the sizeof() of the transfer.
        ctype = starpu_type_traduction(eltype(type))
        push!(block, StarpuExprTypedVar(Symbol(varX), eltype(type)))
        push!(block, StarpuExprCall(:cudaMemcpy,
                                    [StarpuExprAddress(StarpuExprVar(Symbol(varX))),
                                     StarpuExprAddress(ref),
                                     StarpuExprVar(Symbol("sizeof($ctype)")),
                                     StarpuExprVar(:cudaMemcpyDeviceToHost)]))
        # The original instruction, now reading the host copy instead of the buffer.
        push!(block, substitute(instr, ref, StarpuExprVar(Symbol("$varX"))))
        cpu_instr = substitute(cpu_instr, instr, StarpuExprBlock(block))
    end
    # Replace each Store S:
    # S: buffer[id] = expr
    # With the following instruction block:
    # Type varX
    # varX = expr
    # cudaMemcpy(&buffer[id], &varX, sizeof(Type), cudaMemcpyHostToDevice)
    for s in stores
        block = []
        buffer = s.var.ref.name
        varX = "var"*rand_string()
        type = buffer_types[Symbol(buffer)]
        ctype = starpu_type_traduction(eltype(type))
        push!(block, StarpuExprTypedVar(Symbol(varX), eltype(type)))
        push!(block, StarpuExprAffect(StarpuExprVar(Symbol("$varX")), s.expr))
        push!(block, StarpuExprCall(:cudaMemcpy,
                                    [StarpuExprAddress(s.var),
                                     StarpuExprAddress(StarpuExprVar(Symbol(varX))),
                                     StarpuExprVar(Symbol("sizeof($ctype)")),
                                     StarpuExprVar(:cudaMemcpyHostToDevice)]))
        cpu_instr = substitute(cpu_instr, s, StarpuExprBlock(block))
    end
    return cpu_instr
end
  328. function transform_to_cuda_kernel(func :: StarpuExprFunction)
  329. cpu_func = transform_to_cpu_kernel(func)
  330. init, indep, finish = extract_init_indep_finish(cpu_func.body)
  331. cpu_instr = init
  332. kernel = nothing
  333. # Generate a CUDA kernel only if there is an independent loop (@parallel macro).
  334. if (indep != nothing)
  335. prekernel_instr, kernel_args, kernel_instr = analyse_sets(indep)
  336. kernel_call = StarpuExprCudaCall(:cudaKernel, (@parse nblocks), (@parse THREADS_PER_BLOCK), StarpuExpr[])
  337. cpu_instr = vcat(cpu_instr, prekernel_instr)
  338. kernel_instr = vcat(kernel_instr, indep.body)
  339. indep_for_def, indep_for_undef = analyse_variable_declarations(StarpuExprBlock(kernel_instr), kernel_args)
  340. prekernel_def, prekernel_undef = analyse_variable_declarations(StarpuExprBlock(cpu_instr), cpu_func.args)
  341. for undef_var in indep_for_undef
  342. found_var = find_variable(undef_var, prekernel_def)
  343. if found_var == nothing # TODO : error then ?
  344. continue
  345. end
  346. push!(kernel_args, found_var)
  347. end
  348. call_args = map((x -> StarpuExprVar(x.name)), kernel_args)
  349. kernelname=Symbol("KERNEL_",func.func);
  350. cuda_call = StarpuExprCudaCall(kernelname, (@parse nblocks), (@parse THREADS_PER_BLOCK), call_args)
  351. push!(cpu_instr, cuda_call)
  352. push!(cpu_instr, @parse cudaStreamSynchronize(starpu_cuda_get_local_stream()))
  353. kernel = StarpuExprFunction(Nothing, kernelname, kernel_args, StarpuExprBlock(kernel_instr))
  354. kernel = add_device_to_interval_call(kernel)
  355. kernel = flatten_blocks(kernel)
  356. end
  357. cpu_instr = vcat(cpu_instr, finish)
  358. cpu_instr = StarpuExprBlock(cpu_instr)
  359. cpu_instr = transform_cuda_device_loadstore(cpu_instr)
  360. prekernel_name = Symbol("CUDA_", func.func)
  361. prekernel = StarpuExprFunction(Nothing, prekernel_name, cpu_func.args, cpu_instr)
  362. prekernel = translate_cublas(prekernel)
  363. prekernel = flatten_blocks(prekernel)
  364. return prekernel, kernel
  365. end
  366. struct StarpuIndepFor
  367. iters :: Vector{Symbol}
  368. sets :: Vector{StarpuExprInterval}
  369. body :: StarpuExpr
  370. end
  371. function assert_no_indep_for(expr :: StarpuExpr)
  372. function func_to_run(x :: StarpuExpr)
  373. if (isa(x, StarpuExprFor) && x.is_independant)
  374. error("Invalid usage of intricated @indep for loops")
  375. end
  376. return x
  377. end
  378. return apply(func_to_run, expr)
  379. end
  380. function StarpuIndepFor(expr :: StarpuExprFor)
  381. if !expr.is_independant
  382. error("For expression must be prefixed by @indep")
  383. end
  384. iters = []
  385. sets = []
  386. for_loop = expr
  387. while isa(for_loop, StarpuExprFor) && for_loop.is_independant
  388. push!(iters, for_loop.iter)
  389. push!(sets, for_loop.set)
  390. for_loop = for_loop.body
  391. while (isa(for_loop, StarpuExprBlock) && length(for_loop.exprs) == 1)
  392. for_loop = for_loop.exprs[1]
  393. end
  394. end
  395. return StarpuIndepFor(iters, sets, assert_no_indep_for(for_loop))
  396. end
  397. function translate_index_code(dims :: Vector{StarpuExprVar})
  398. ndims = length(dims)
  399. if ndims == 0
  400. error("No dimension specified")
  401. end
  402. prod = StarpuExprValue(1)
  403. output = StarpuExpr[]
  404. reversed_dim = reverse(dims)
  405. thread_index_patern = @parse € :: Int64 = (€ / €) % €
  406. thread_id = @parse THREAD_ID
  407. for i in (1 : ndims)
  408. index_lvalue = StarpuExprVar(Symbol(:kernel_ids__index_, ndims - i + 1))
  409. expr = replace_pattern(thread_index_patern, index_lvalue, thread_id, prod, reversed_dim[i])
  410. push!(output, expr)
  411. prod = StarpuExprCall(:(*), [prod, reversed_dim[i]])
  412. end
  413. thread_id_pattern = @parse begin
  414. € :: Int64 = blockIdx.x * blockDim.x + threadIdx.x
  415. if (€ >= €)
  416. return
  417. end
  418. end
  419. bound_verif = replace_pattern(thread_id_pattern, thread_id, thread_id, prod)
  420. push!(output, bound_verif)
  421. return reverse(output)
  422. end
  423. function kernel_index_declarations(ind_for :: StarpuIndepFor)
  424. pre_kernel_instr = StarpuExpr[]
  425. kernel_args = StarpuExprTypedVar[]
  426. kernel_instr = StarpuExpr[]
  427. decl_pattern = @parse € :: Int64 = €
  428. interv_size_decl_pattern = @parse € :: Int64 = jlstarpu_interval_size(€, €, €)
  429. iter_pattern = @parse € :: Int64 = € + € * €
  430. dims = StarpuExprVar[]
  431. ker_instr_to_add_later_on = StarpuExpr[]
  432. for k in (1 : length(ind_for.sets))
  433. set = ind_for.sets[k]
  434. start_var = starpu_parse(Symbol(:kernel_ids__start_, k))
  435. start_decl = replace_pattern(decl_pattern, start_var, set.start)
  436. step_var = starpu_parse(Symbol(:kernel_ids__step_, k))
  437. step_decl = replace_pattern(decl_pattern, step_var, set.step)
  438. dim_var = starpu_parse(Symbol(:kernel_ids__dim_, k))
  439. dim_decl = replace_pattern(interv_size_decl_pattern, dim_var, start_var, step_var, set.stop)
  440. push!(dims, dim_var)
  441. push!(pre_kernel_instr, start_decl, step_decl, dim_decl)
  442. push!(kernel_args, StarpuExprTypedVar(start_var.name, Int64))
  443. push!(kernel_args, StarpuExprTypedVar(step_var.name, Int64))
  444. push!(kernel_args, StarpuExprTypedVar(dim_var.name, Int64))
  445. iter_var = starpu_parse(ind_for.iters[k])
  446. index_var = starpu_parse(Symbol(:kernel_ids__index_, k))
  447. iter_decl = replace_pattern(iter_pattern, iter_var, start_var, index_var, step_var)
  448. push!(ker_instr_to_add_later_on, iter_decl)
  449. end
  450. return dims, ker_instr_to_add_later_on, pre_kernel_instr , kernel_args, kernel_instr
  451. end
  452. function analyse_sets(ind_for :: StarpuIndepFor)
  453. decl_pattern = @parse € :: Int64 = €
  454. nblocks_decl_pattern = @parse € :: Int64 = (€ + THREADS_PER_BLOCK - 1)/THREADS_PER_BLOCK
  455. dims, ker_instr_to_add, pre_kernel_instr, kernel_args, kernel_instr = kernel_index_declarations(ind_for)
  456. dim_prod = @parse 1
  457. for d in dims
  458. dim_prod = StarpuExprCall(:(*), [dim_prod, d])
  459. end
  460. nthreads_var = @parse nthreads
  461. nthreads_decl = replace_pattern(decl_pattern, nthreads_var, dim_prod)
  462. push!(pre_kernel_instr, nthreads_decl)
  463. nblocks_var = @parse nblocks
  464. nblocks_decl = replace_pattern(nblocks_decl_pattern, nblocks_var, nthreads_var)
  465. push!(pre_kernel_instr, nblocks_decl)
  466. index_decomposition = translate_index_code(dims)
  467. push!(kernel_instr, index_decomposition...)
  468. push!(kernel_instr, ker_instr_to_add...)
  469. return pre_kernel_instr, kernel_args, kernel_instr
  470. end