# vector_scal.jl
import Libdl
using StarPU
using LinearAlgebra

# Select the back ends the following @codelet is generated for (CPU and CUDA).
# NOTE(review): exact @target semantics are defined by StarPU.jl — confirm there.
@target STARPU_CPU+STARPU_CUDA

# In-place element-wise update of `v`: v[i] = v[i] * m + l + k for every i.
#
# The body is presumably translated by StarPU.jl's @codelet into a native
# kernel, so it must stay within the Julia subset that translator accepts
# (do not restyle it casually).
#
# Arguments:
#   m    - Int32 multiplier applied to each element
#   v    - Float32 vector updated in place; presumably bound to the task's
#          registered data handle — confirm against StarPU.jl
#   k, l - Float32 scalars added to each product
#
# NOTE(review): despite the name, this is not a pure scaling — it also adds
# l + k to every element; confirm this is intended.
# NOTE(review): the `:: Float32` return annotation has no matching return
# value in the body; confirm whether @codelet requires or uses it.
@codelet function vector_scal(m::Int32, v :: Vector{Float32}, k :: Float32, l :: Float32) :: Float32
    # Loop bound over all elements of v, declared as Int32.
    N :: Int32 = length(v)
    # Naive version
    # @parallel presumably asks the translator to parallelize this loop
    # (e.g. across CUDA threads) — confirm in StarPU.jl.
    @parallel for i in (1 : N)
        v[i] = v[i] * m + l + k
    end
end

# Initialize StarPU. It runs after the @codelet definition above; presumably
# this is what makes CPU_CODELETS["vector_scal"] available — confirm against
# StarPU.jl.
starpu_init()
  13. function vector_scal_with_starpu(v :: Vector{Float32}, m :: Int32, k :: Float32, l :: Float32)
  14. tmin=0
  15. @starpu_block let
  16. hV = starpu_data_register(v)
  17. tmin=0
  18. perfmodel = StarpuPerfmodel(
  19. perf_type = STARPU_HISTORY_BASED,
  20. symbol = "history_perf"
  21. )
  22. cl = StarpuCodelet(
  23. cpu_func = CPU_CODELETS["vector_scal"],
  24. # cuda_func = CUDA_CODELETS["vector_scal"],
  25. #opencl_func="ocl_matrix_mult",
  26. modes = [STARPU_RW],
  27. perfmodel = perfmodel
  28. )
  29. for i in (1 : 1)
  30. t=time_ns()
  31. @starpu_sync_tasks begin
  32. handles = [hV]
  33. task = StarpuTask(cl = cl, handles = handles, cl_arg=(m, k, l))
  34. starpu_task_submit(task)
  35. end
  36. # @starpu_sync_tasks for task in (1:1)
  37. # @starpu_async_cl vector_scal(hV, STARPU_RW, [m, k, l])
  38. # end
  39. t=time_ns()-t
  40. if (tmin==0 || tmin>t)
  41. tmin=t
  42. end
  43. end
  44. end
  45. return tmin
  46. end
  47. function compute_times(io,start_dim, step_dim, stop_dim)
  48. for size in (start_dim : step_dim : stop_dim)
  49. V = Array(rand(Cfloat, size))
  50. starpu_memory_pin(V)
  51. m :: Int32 = 10
  52. k :: Float32 = 2.
  53. l :: Float32 = 3.
  54. println("INPUT ", V[1:10])
  55. mt = vector_scal_with_starpu(V, m, k, l)
  56. starpu_memory_unpin(V)
  57. println("OUTPUT ", V[1:10])
  58. println(io,"$size $mt")
  59. println("$size $mt")
  60. end
  61. end
  62. io=open(ARGS[1],"w")
  63. compute_times(io,1024,1024,4096)
  64. close(io)
  65. starpu_shutdown()