data.jl 7.8 KB


  1. # StarPU --- Runtime system for heterogeneous multicore architectures.
  2. #
  3. # Copyright (C) 2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
  4. #
  5. # StarPU is free software; you can redistribute it and/or modify
  6. # it under the terms of the GNU Lesser General Public License as published by
  7. # the Free Software Foundation; either version 2.1 of the License, or (at
  8. # your option) any later version.
  9. #
  10. # StarPU is distributed in the hope that it will be useful, but
  11. # WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. #
  14. # See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. #
  # Raw StarPU data handle as seen from Julia (an opaque C pointer).
  const StarpuDataHandlePointer = Ptr{Cvoid}
  # Julia-side data handle: the raw pointer wrapped in a StarpuDestructible.
  # Declared `const` because it is used in method signatures and in a return-type
  # annotation. A non-const global would be looked up dynamically there, which defeats
  # type inference.
  const StarpuDataHandle = StarpuDestructible{StarpuDataHandlePointer}
  # Built-in StarPU filters selectable from Julia. Each value maps to a C filter function
  # looked up by name in the constructor starpu_data_filter(::StarpuDataFilterFunc, n).
  @enum StarpuDataFilterFunc begin
      STARPU_MATRIX_FILTER_VERTICAL_BLOCK = 0
      STARPU_MATRIX_FILTER_BLOCK = 1
      STARPU_VECTOR_FILTER_BLOCK = 2
  end
  export starpu_data_filter

  """
      starpu_data_filter(filter_func::StarpuDataFilterFunc, nchildren::Integer)

  Build a `starpu_data_filter` that splits data into `nchildren` parts using the StarPU C
  filter selected by `filter_func`. The C function pointer is resolved by name with
  `Libdl.dlsym` in `starpu_wrapper_library_handle`.

  Throws `InexactError` if `nchildren` does not fit in a `UInt32`, and `ArgumentError` for
  a `filter_func` value that has no known C implementation.
  """
  function starpu_data_filter(filter_func ::StarpuDataFilterFunc, nchildren ::Integer)
      output = starpu_data_filter(zero)
      output.nchildren = UInt32(nchildren)
      # Every enum value gets an explicit branch. The previous `else <condition>` evaluated
      # the comparison and discarded it instead of testing it.
      symbol_name = if filter_func == STARPU_MATRIX_FILTER_VERTICAL_BLOCK
          "starpu_matrix_filter_vertical_block"
      elseif filter_func == STARPU_MATRIX_FILTER_BLOCK
          "starpu_matrix_filter_block"
      elseif filter_func == STARPU_VECTOR_FILTER_BLOCK
          "starpu_vector_filter_block"
      else
          throw(ArgumentError("unsupported StarPU filter function: $filter_func"))
      end
      output.filter_func = Libdl.dlsym(starpu_wrapper_library_handle, symbol_name)
      return output
  end
  # Pin the whole buffer of `data` (`sizeof(data)` bytes) through the two-argument
  # `starpu_memory_pin`; the result is asserted to be a `Cint`.
  starpu_memory_pin(data :: Union{Vector{T}, Matrix{T}}) where {T} =
      starpu_memory_pin(data, sizeof(data))::Cint
  # Unpin the whole buffer of `data` (`sizeof(data)` bytes) through the two-argument
  # `starpu_memory_unpin`; the result is asserted to be a `Cint`.
  starpu_memory_unpin(data :: Union{Vector{T}, Matrix{T}}) where {T} =
      starpu_memory_unpin(data, sizeof(data))::Cint
  # Wrap a raw StarPU handle pointer, together with the destructors to attach to it,
  # into a Julia-side `StarpuDataHandle`.
  StarpuNewDataHandle(ptr :: StarpuDataHandlePointer, destr :: Function...) :: StarpuDataHandle =
      StarpuDestructible(ptr, destr...)
  # Pointer-level unregister; attached as a destructor by `starpu_data_register`.
  starpu_data_unregister_pointer(ptr :: StarpuDataHandlePointer) = starpu_data_unregister(ptr)
  # Unregister each handle by running its `starpu_data_unregister_pointer` destructor
  # through `starpu_execute_destructor!`, in argument order.
  function starpu_data_unregister(handles :: StarpuDataHandle...)
      foreach(handles) do handle
          starpu_execute_destructor!(handle, starpu_data_unregister_pointer)
      end
      return nothing
  end
  # Register vector `v` (length(v) elements of sizeof(T) bytes, in main RAM) with StarPU.
  # Returns a handle whose destructor unregisters it.
  # NOTE(review): only `pointer(v)` is handed to StarPU, and nothing here keeps `v`
  # reachable, so the caller presumably must keep `v` alive while the handle is
  # registered. Confirm.
  function starpu_data_register(v :: Vector{T}) where T
      handle_ref = Ref{Ptr{Cvoid}}(C_NULL)
      starpu_vector_data_register(handle_ref, STARPU_MAIN_RAM, pointer(v), length(v), sizeof(T))
      return StarpuNewDataHandle(handle_ref[], starpu_data_unregister_pointer)
  end
  # Register matrix `m` (in main RAM) with StarPU and return a handle whose destructor
  # unregisters it. The row count is passed twice, followed by the column count,
  # presumably as StarPU's (ld, nx, ny) for column-major storage. Confirm against the
  # wrapper's signature.
  function starpu_data_register(m :: Matrix{T}) where T
      handle_ref = Ref{Ptr{Cvoid}}(C_NULL)
      nrows, ncols = size(m)
      starpu_matrix_data_register(handle_ref, STARPU_MAIN_RAM, pointer(m),
                                  nrows, nrows, ncols, sizeof(T))
      return StarpuNewDataHandle(handle_ref[], starpu_data_unregister_pointer)
  end
  # Register 3-D array `block` (in main RAM) with StarPU as a block. The two strides are
  # the first dimension and the size of a 2-D slice; the three sizes follow.
  # Returns a handle whose destructor unregisters it.
  function starpu_data_register(block :: Array{T,3}) where T
      handle_ref = Ref{Ptr{Cvoid}}(C_NULL)
      nx, ny, nz = size(block)
      starpu_block_data_register(handle_ref, STARPU_MAIN_RAM, pointer(block),
                                 nx, nx * ny, nx, ny, nz, sizeof(T))
      return StarpuNewDataHandle(handle_ref[], starpu_data_unregister_pointer)
  end
  # Register the single value behind `ref` (sizeof(T) bytes, in main RAM) as a StarPU
  # variable. Returns a handle whose destructor unregisters it.
  function starpu_data_register(ref :: Ref{T}) where T
      handle_ref = Ref{Ptr{Cvoid}}(C_NULL)
      starpu_variable_data_register(handle_ref, STARPU_MAIN_RAM, ref, sizeof(T))
      return StarpuNewDataHandle(handle_ref[], starpu_data_unregister_pointer)
  end
  # Register two or more objects, in argument order, and return their handles as an array.
  function starpu_data_register(x1, x2, next_args...)
      handles = map(starpu_data_register, (x1, x2, next_args...))
      return [handles...]
  end
  84. import Base.getindex
  # `handle[i, j, ...]`: fetch the sub-data handle of a partitioned handle. The 1-based
  # Julia indices are shifted by one before being passed to `starpu_data_get_sub_data`,
  # along with the nesting depth. The returned handle carries no destructor.
  function Base.getindex(handle :: StarpuDataHandle, indexes...)
      zero_based = map(i -> i - 1, indexes)
      depth = length(indexes)
      sub_pointer = starpu_data_get_sub_data(handle.object, depth, zero_based...)
      return StarpuNewDataHandle(sub_pointer)
  end
  # Pointer-level unpartition, gathering the data back to main RAM; attached as a
  # destructor when a handle is partitioned or filtered.
  starpu_data_unpartition_pointer(ptr :: StarpuDataHandlePointer) =
      starpu_data_unpartition(ptr, STARPU_MAIN_RAM)
  """
      starpu_data_partition(handle::StarpuDataHandle, filter::starpu_data_filter)

  Partition `handle` with `filter` and register `starpu_data_unpartition_pointer` as a
  destructor of `handle`. Returns whatever the pointer-level `starpu_data_partition` returns.
  """
  function starpu_data_partition(handle :: StarpuDataHandle, filter :: starpu_data_filter)
      starpu_add_destructor!(handle, starpu_data_unpartition_pointer)
      # `pointer_from_objref` does not root `filter`. Keep it alive while C reads it.
      return GC.@preserve filter starpu_data_partition(handle.object, pointer_from_objref(filter))
  end
  # Undo the partitioning of each handle by running its `starpu_data_unpartition_pointer`
  # destructor through `starpu_execute_destructor!`, in argument order.
  function starpu_data_unpartition(handles :: StarpuDataHandle...)
      foreach(handles) do handle
          starpu_execute_destructor!(handle, starpu_data_unpartition_pointer)
      end
      return nothing
  end
  """
      starpu_data_map_filters(handle::StarpuDataHandle, filter::starpu_data_filter)

  Apply one filter to `handle` through the pointer-level `starpu_data_map_filters`, and
  register `starpu_data_unpartition_pointer` as a destructor of `handle`.
  """
  function starpu_data_map_filters(handle :: StarpuDataHandle, filter :: starpu_data_filter)
      starpu_add_destructor!(handle, starpu_data_unpartition_pointer)
      # `pointer_from_objref` does not root `filter`. Keep it alive during the C call.
      return GC.@preserve filter starpu_data_map_filters(handle.object, 1, pointer_from_objref(filter))
  end
  """
      starpu_data_map_filters(handle::StarpuDataHandle, filter_1, filter_2)

  Apply two filters to `handle` through the pointer-level `starpu_data_map_filters`, and
  register `starpu_data_unpartition_pointer` as a destructor of `handle`.
  """
  function starpu_data_map_filters(handle :: StarpuDataHandle, filter_1 :: starpu_data_filter, filter_2 :: starpu_data_filter)
      starpu_add_destructor!(handle, starpu_data_unpartition_pointer)
      # `pointer_from_objref` does not root its argument. Keep both filters alive during
      # the C call.
      return GC.@preserve filter_1 filter_2 starpu_data_map_filters(handle.object, 2,
          pointer_from_objref(filter_1), pointer_from_objref(filter_2))
  end
  # Read the sequential-consistency flag of `handle` via its raw pointer.
  starpu_data_get_sequential_consistency_flag(handle :: StarpuDataHandle) =
      starpu_data_get_sequential_consistency_flag(handle.object)
  # Set the sequential-consistency flag of `handle`. Any `Integer` flag is accepted (it
  # used to be restricted to `Int`) and forwarded unchanged to the pointer-level method.
  function starpu_data_set_sequential_consistency_flag(handle :: StarpuDataHandle, flag :: Integer)
      return starpu_data_set_sequential_consistency_flag(handle.object, flag)
  end
  # Acquire `handle` on memory node `node` with access `mode`. Any `Integer` node id is
  # accepted (it used to be restricted to `Int`) and forwarded to the pointer-level method.
  function starpu_data_acquire_on_node(handle :: StarpuDataHandle, node :: Integer, mode)
      return starpu_data_acquire_on_node(handle.object, node, mode)
  end
  # Release `handle` previously acquired on memory node `node`. Any `Integer` node id is
  # accepted (it used to be restricted to `Int`) and forwarded to the pointer-level method.
  function starpu_data_release_on_node(handle :: StarpuDataHandle, node :: Integer)
      return starpu_data_release_on_node(handle.object, node)
  end
  # Leaves of an index expression (identifiers and literals) are returned unchanged.
  repl(x::Symbol) = x
  repl(x::Number) = x

  """
      repl(x::Expr)

  Simplify an index expression written for `@starpu_filter`, treating the placeholder `_`
  as zero inside `+` and `-` calls:
  - `_ + k` and `k + _` become `k`;
  - `_ - k` becomes `-k`;
  - `k - _` becomes `k`;
  - `-_` becomes `0`.

  Any other expression keeps its head and has its sub-expressions simplified recursively.
  """
  function repl(x::Expr)
      # Recurse only into sub-expressions. Other leaves (QuoteNode, LineNumberNode, ...)
      # are kept as they are.
      simplify(arg) = arg isa Expr ? repl(arg) : arg
      iscall(op) = x.head === :call && !isempty(x.args) && x.args[1] === op

      if iscall(:+)
          # Drop every `_` operand. N-ary sums such as `_ + 1 + 2` keep all other terms.
          terms = Any[simplify(a) for a in x.args[2:end] if a !== :_]
          isempty(terms) && return 0
          length(terms) == 1 && return terms[1]
          return Expr(:call, :+, terms...)
      elseif iscall(:-) && length(x.args) == 2
          # Unary minus: `-_` is zero.
          operand = x.args[2]
          return operand === :_ ? 0 : Expr(:call, :-, simplify(operand))
      elseif iscall(:-) && length(x.args) == 3
          lhs, rhs = x.args[2], x.args[3]
          if rhs === :_
              # `a - _` is `a`, and `_ - _` is zero.
              return lhs === :_ ? 0 : simplify(lhs)
          elseif lhs === :_
              return Expr(:call, :-, simplify(rhs))
          end
          return Expr(:call, :-, simplify(lhs), simplify(rhs))
      end
      # Any other expression: same head, all arguments simplified (not only the first two).
      return Expr(x.head, map(simplify, x.args)...)
  end
  """
      @starpu_filter handle = A[_:_+k, :]
      @starpu_filter handle = A[:, _:_+k]
      @starpu_filter handle = A

  Declare a partitioned view of a matrix.

  The macro registers `A` with `starpu_data_register` and assigns the resulting handle to
  `handle` in the caller's scope. When the right-hand side indexes `A`:
  - exactly one of the two indices must be `:`;
  - the other index must be a range `_:expr`;
  - `expr` is simplified with `repl` (`_` treated as zero) and passed as `nchildren` to
    `starpu_data_filter`;
  - the handle is then partitioned with `starpu_data_partition`.

  A right-hand side without indexing is registered as is, without partitioning.

  Throws `ArgumentError` for unsupported forms, and `AssertionError` when the range does
  not start at `_`.

  Ex : @starpu_filter ha = A[ _:_+1, : ]
  """
  macro starpu_filter(expr)
      if !(expr isa Expr && expr.head === :(=))
          throw(ArgumentError("@starpu_filter expects an assignment, got $expr"))
      end
      farray, region = expr.args[1], expr.args[2]

      # No indexing on the right-hand side: register it whole, without partitioning.
      if !(region isa Expr && region.head === :ref)
          return quote
              $(esc(farray)) = starpu_data_register($(esc(region)))
          end
      end

      length(region.args) == 3 ||
          throw(ArgumentError("@starpu_filter supports exactly two indices, got $region"))

      # The `:` index marks the dimension kept whole; the other index holds the `_:expr`
      # range.
      # NOTE(review): starpu_data_register passes the row count as StarPU's nx. This
      # dimension-to-filter mapping is kept from the original code. Confirm it is intended.
      if region.args[2] === :(:)
          index, filter_func = 3, :STARPU_MATRIX_FILTER_BLOCK
      elseif region.args[3] === :(:)
          index, filter_func = 2, :STARPU_MATRIX_FILTER_VERTICAL_BLOCK
      else
          throw(ArgumentError("@starpu_filter needs exactly one `:` index, got $region"))
      end

      slice = region.args[index]
      if !(slice isa Expr && slice.head === :call && length(slice.args) == 3 &&
           slice.args[1] === :(:))
          throw(ArgumentError("@starpu_filter expects a range `_:expr`, got $slice"))
      end
      if slice.args[2] !== :_
          throw(AssertionError("LHS must be _"))
      end
      nchildren = repl(slice.args[3])

      return quote
          # escape and not global for farray!
          $(esc(farray)) = starpu_data_register($(esc(region.args[1])))
          # The filter constant is left unescaped so it resolves in this module.
          starpu_data_partition($(esc(farray)), starpu_data_filter($filter_func, $(esc(nchildren))))
      end
  end