gemm_bare.jl 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # StarPU --- Runtime system for heterogeneous multicore architectures.
  2. #
  3. # Copyright (C) 2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
  4. #
  5. # StarPU is free software; you can redistribute it and/or modify
  6. # it under the terms of the GNU Lesser General Public License as published by
  7. # the Free Software Foundation; either version 2.1 of the License, or (at
  8. # your option) any later version.
  9. #
  10. # StarPU is distributed in the hope that it will be useful, but
  11. # WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. #
  14. # See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. #
  16. using StarPU
  17. using LinearAlgebra.BLAS
# Select the code-generation targets for the codelet that follows
# (presumably StarPU.jl's `@target` applies to the next `@codelet` — confirm).
@target STARPU_CPU+STARPU_CUDA
# StarPU codelet computing C = alpha*A*B + beta*C on single-precision matrices.
#
# The body is written in the restricted subset handled by `@codelet`; the
# explicit Int32 annotations on every local are presumably required by that
# translation step (TODO confirm before changing them).
#
# A and B are accessed read-only and C read-write (see the `modes` given to
# `starpu_codelet` in multiply_with_starpu); alpha and beta arrive through the
# task's `cl_arg`.
@codelet function gemm(A :: Matrix{Float32}, B :: Matrix{Float32}, C :: Matrix{Float32}, alpha :: Float32, beta :: Float32) :: Nothing
    M :: Int32 = height(A)   # number of rows of A (and of C)
    N :: Int32 = width(B)    # number of columns of B (and of C)
    K :: Int32 = width(A)    # inner dimension: columns of A / rows of B
    # Leading dimensions of the StarPU matrix buffers as reported by `ld`
    # (may differ from the row count if the buffer is strided — confirm).
    lda :: Int32 = ld(A)
    ldb :: Int32 = ld(B)
    ldc :: Int32 = ld(C)
    # SGEMM without transposition ("N", "N"); assumed to follow BLAS sgemm
    # semantics, i.e. C = alpha*A*B + beta*C.
    STARPU_SGEMM("N", "N", M, N, K, alpha, A, lda, B, ldb, beta, C, ldc)
    return
end
  29. function multiply_with_starpu(A :: Matrix{Float32}, B :: Matrix{Float32}, C :: Matrix{Float32}, alpha :: Float32, beta :: Float32, nslicesx, nslicesy)
  30. scale= 3
  31. tmin=0
  32. hA,hB,hC = starpu_data_register(A, B, C)
  33. tmin=0
  34. perfmodel = starpu_perfmodel(
  35. perf_type = starpu_perfmodel_type(STARPU_HISTORY_BASED),
  36. symbol = "gemm"
  37. )
  38. cl = starpu_codelet(
  39. cpu_func = "gemm",
  40. cuda_func = "",
  41. modes =[STARPU_R,STARPU_R,STARPU_RW],
  42. perfmodel = perfmodel,
  43. )
  44. task = starpu_task(cl = cl, handles =[hA,hB,hC], cl_arg = (alpha,beta), callback = nothing,
  45. callback_arg = nothing, tag = nothing, tag_only = nothing,
  46. sequential_consistency = true,
  47. detach = 1, color = nothing, where = nothing)
  48. for i in (1 : 10 )
  49. t=time_ns()
  50. starpu_task_submit(Ref(task.c_task))
  51. #starpu_task_submit(task)
  52. starpu_task_wait_for_all()
  53. t=time_ns()-t
  54. if (tmin==0 || tmin>t)
  55. tmin=t
  56. end
  57. end
  58. starpu_data_unregister(hA)
  59. starpu_data_unregister(hB)
  60. starpu_data_unregister(hC)
  61. return tmin
  62. end
  63. function approximately_equals(
  64. A :: Matrix{Cfloat},
  65. B :: Matrix{Cfloat},
  66. eps = 1e-2
  67. )
  68. (height, width) = size(A)
  69. for j in (1 : width)
  70. for i in (1 : height)
  71. if (abs(A[i,j] - B[i,j]) > eps * max(abs(B[i,j]), abs(A[i,j])))
  72. println("A[$i,$j] : $(A[i,j]), B[$i,$j] : $(B[i,j])")
  73. return false
  74. end
  75. end
  76. end
  77. return true
  78. end
  79. function check(expected, A, B, C, alpha, beta)
  80. for i in 1 : 10
  81. gemm!('N', 'N', alpha, A, B, beta, expected)
  82. end
  83. height,width = size(C)
  84. for i in 1:height
  85. for j in 1:width
  86. got = C[i, j]
  87. exp = expected[i, j]
  88. err = abs(exp - got) / exp
  89. if err > 0.0001
  90. error("[$i] -> $got != $exp (err $err)")
  91. end
  92. end
  93. end
  94. end
  95. function compute_times(io,start_dim, step_dim, stop_dim, nslicesx, nslicesy)
  96. for dim in (start_dim : step_dim : stop_dim)
  97. A = Array(rand(Cfloat, dim, dim))
  98. B = Array(rand(Cfloat, dim, dim))
  99. C = zeros(Float32, dim, dim)
  100. C_ref = copy(C)
  101. starpu_memory_pin(A)
  102. starpu_memory_pin(B)
  103. starpu_memory_pin(C)
  104. alpha = 4.0f0
  105. beta = 2.0f0
  106. mt = multiply_with_starpu(A, B, C, alpha, beta, nslicesx, nslicesy)
  107. gflop = 2 * dim * dim * dim * 1.e-9
  108. gflops = gflop / (mt * 1.e-9)
  109. size=dim*dim*dim*4*3/1024/1024
  110. println(io,"$dim $gflops")
  111. println("$dim $gflops")
  112. starpu_memory_unpin(A)
  113. starpu_memory_unpin(B)
  114. starpu_memory_unpin(C)
  115. #check(C_ref, A, B, C, alpha, beta)
  116. end
  117. end
  118. if size(ARGS, 1) < 1
  119. filename="x.dat"
  120. else
  121. filename=ARGS[1]
  122. end
  123. starpu_init()
  124. starpu_cublas_init()
  125. io=open(filename,"w")
  126. compute_times(io,64,512,4096,1,1)
  127. close(io)
  128. starpu_shutdown()