cholesky.jl 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # StarPU --- Runtime system for heterogeneous multicore architectures.
  2. #
  3. # Copyright (C) 2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
  4. #
  5. # StarPU is free software; you can redistribute it and/or modify
  6. # it under the terms of the GNU Lesser General Public License as published by
  7. # the Free Software Foundation; either version 2.1 of the License, or (at
  8. # your option) any later version.
  9. #
  10. # StarPU is distributed in the hope that it will be useful, but
  11. # WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. #
  14. # See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. #
  16. using StarPU
  17. using LinearAlgebra.BLAS
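  # Standard kernels for the tiled Cholesky factorization:
  # u11 computes the Cholesky factorization of a diagonal tile,
  # u21 is the triangular-solve (TRSM) update of a tile below the diagonal,
  # u22 is the matrix-multiply (GEMM) update of a trailing tile.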
  # u11: in-place Cholesky factorization of one diagonal tile `sub11`.
  # For each column (0-based `z`): replace the pivot by its square root, scale
  # the sub-column below the pivot by 1/pivot (SSCAL), then apply the rank-1
  # update A -= x * x' to the trailing block (SSYR with uplo "L", so only its
  # lower triangle is written). The strict upper triangle of `sub11` is never
  # touched. This body is consumed by `@codelet` (the kernel is looked up later
  # through CPU_CODELETS), so it keeps to explicit typed locals and views.
  @target STARPU_CPU+STARPU_CUDA
  @codelet function u11(sub11 :: Matrix{Float32}) :: Nothing
      nx :: Int32 = width(sub11)
      # Named `ld11` rather than `ld`: a local called `ld` would shadow the `ld`
      # accessor for the whole body, making `ld(sub11)` refer to the local.
      ld11 :: Int32 = ld(sub11)
      for z in 0:nx-1
          lambda11 :: Float32 = sqrt(sub11[z+1,z+1])
          sub11[z+1,z+1] = lambda11
          alpha ::Float32 = 1.0f0 / lambda11
          # Rows z+2..nx of column z+1: the part of the column below the pivot.
          X :: Vector{Float32} = view(sub11, z+2:z+2+(nx-z-2), z+1)
          STARPU_SSCAL(nx-z-1, alpha, X, 1)
          alpha = -1.0f0
          # Trailing (nx-z-1) x (nx-z-1) block below/right of the pivot.
          A :: Matrix{Float32} = view(sub11, z+2:z+2+(nx-z-2), z+2:z+2+(nx-z-2))
          STARPU_SSYR("L", nx-z-1, alpha, X, 1, A, ld11)
      end
      return
  end
  # u21: triangular solve for one tile below the diagonal.
  # With L11 the lower-triangular factor stored in `sub11`, `sub21` is
  # overwritten by sub21 * inv(L11)' via STRSM(side "R", uplo "L",
  # trans "T", diag "N").
  @target STARPU_CPU+STARPU_CUDA
  @codelet function u21(sub11 :: Matrix{Float32},
                        sub21 :: Matrix{Float32}) :: Nothing
      ld11 :: Int32 = ld(sub11)
      ld21 :: Int32 = ld(sub21)
      nx21 :: Int32 = width(sub21)
      ny21 :: Int32 = height(sub21)
      alpha :: Float32 = 1.0f0
      # BLAS M is the row count of B = sub21 (its height) and N its column count
      # (its width, which is also the order of the triangular factor in sub11).
      # Passing (width, height) instead is only correct for square tiles, which
      # an uneven block partition does not guarantee.
      STARPU_STRSM("R", "L", "T", "N", ny21, nx21, alpha, sub11, ld11, sub21, ld21)
      return
  end
  # u22: trailing-tile update  center <- center - left * right'
  # (SGEMM with transa "N", transb "T", alpha = -1, beta = 1).
  # `left` and `right` are already-solved tiles of the current block column and
  # `center` is the tile receiving their product.
  @target STARPU_CPU+STARPU_CUDA
  @codelet function u22(left :: Matrix{Float32},
                        right :: Matrix{Float32},
                        center :: Matrix{Float32}) :: Nothing
      dx :: Int32 = width(center)   # BLAS N
      dy :: Int32 = height(center)  # BLAS M
      dz :: Int32 = width(left)     # BLAS K (inner dimension)
      # Each leading dimension is read from the matrix it is passed with:
      # ld21 -> left, ld12 -> right, ld22 -> center.
      ld21 :: Int32 = ld(left)
      ld12 :: Int32 = ld(right)
      ld22 :: Int32 = ld(center)
      alpha :: Float32 = -1.0f0
      beta :: Float32 = 1.0f0
      STARPU_SGEMM("N", "T", dy, dx, dz, alpha, left, ld21, right, ld12, beta, center, ld22)
      return
  end
  """
      cholesky(mat, size, nblocks)

  Factorize `mat` in place with StarPU tasks over an `nblocks` x `nblocks`
  tiling, then block until every submitted task has completed. `mat` is assumed
  symmetric positive definite. Afterwards its lower triangle holds the factor L
  with L * L' equal to the input; upper tiles are not written. The `size`
  argument is accepted but not used.
  """
  function cholesky(mat :: Matrix{Float32}, size, nblocks)
      # NOTE(review): one history-based model ("history_perf") is shared by all
      # three codelets, so their timings share one history — confirm intended.
      shared_model = starpu_perfmodel(
          perf_type = starpu_perfmodel_type(STARPU_HISTORY_BASED),
          symbol = "history_perf"
      )

      # Diagonal-tile factorization. No cuda_func: this kernel cannot be
      # translated to CUDA yet.
      diag_cl = starpu_codelet(
          cpu_func = CPU_CODELETS["u11"],
          modes = [STARPU_RW],
          color = 0xffff00,
          perfmodel = shared_model
      )
      # Panel solve: reads the diagonal tile, updates the tile below it.
      # CUDA variant not registered.
      panel_cl = starpu_codelet(
          cpu_func = CPU_CODELETS["u21"],
          modes = [STARPU_R, STARPU_RW],
          color = 0x8080ff,
          perfmodel = shared_model
      )
      # Trailing update: reads two panel tiles, updates a third.
      # CUDA variant not registered.
      update_cl = starpu_codelet(
          cpu_func = CPU_CODELETS["u22"],
          modes = [STARPU_R, STARPU_R, STARPU_RW],
          color = 0x00ff00,
          perfmodel = shared_model
      )

      f_block = starpu_data_filter(STARPU_MATRIX_FILTER_BLOCK, nblocks)
      f_vblock = starpu_data_filter(STARPU_MATRIX_FILTER_VERTICAL_BLOCK, nblocks)

      @starpu_block let
          tiles = starpu_data_register(mat)
          starpu_data_map_filters(tiles, f_block, f_vblock)
          for k in 1:nblocks
              starpu_iteration_push(k)
              starpu_task_submit(starpu_task(cl = diag_cl, handles = [tiles[k, k]]))
              for i in k+1:nblocks
                  starpu_task_submit(starpu_task(cl = panel_cl,
                                                 handles = [tiles[k, k], tiles[i, k]]))
              end
              # Only tiles on or below the diagonal (j <= i) are updated.
              for i in k+1:nblocks, j in k+1:i
                  starpu_task_submit(starpu_task(cl = update_cl,
                                                 handles = [tiles[i, k], tiles[j, k], tiles[i, j]]))
              end
              starpu_iteration_pop()
          end
          starpu_task_wait_for_all()
      end
  end
  """
      check(mat::Matrix{Float32})

  Verify that the lower triangle L of `mat` satisfies L * L' ≈ A, where A is the
  test matrix A(i, j) = 1/(i + j - 1) + n*δ(i, j) (1-based, n = size(mat, 1)).
  Every lower-triangular entry must match to a relative error of at most 1e-4.
  The strict upper triangle of `mat` is ignored and `mat` is not modified.

  Throws an `ErrorException` describing the first mismatch (scanning rows, then
  columns). Otherwise prints "Verification successful !" and returns `nothing`.
  """
  function check(mat::Matrix{Float32})
      size_p = size(mat, 1)
      # Lower-triangular copy, so the caller's matrix keeps its upper triangle.
      lower = Float32[j <= i ? mat[i, j] : 0.0f0 for i in 1:size_p, j in 1:size_p]
      test_mat = zeros(Float32, size_p, size_p)
      # test_mat := lower * lower' (syrk! with 'L' fills only the lower triangle).
      syrk!('L', 'N', 1.0f0, lower, 0.0f0, test_mat)
      for i in 1:size_p, j in 1:i
          orig = (1.0f0/(1.0f0+(i-1)+(j-1))) + ((i == j) ? 1.0f0*size_p : 0.0f0)
          err = abs(test_mat[i,j] - orig) / orig
          if err > 0.0001
              got = test_mat[i,j]
              expected = orig
              error("[$i, $j] -> $got != $expected (err $err)")
          end
      end
      println("Verification successful !")
      return nothing
  end
  """
      main(size_p::Int, nblocks::Int, verbose = false)

  Build a `size_p` x `size_p` symmetric positive-definite test matrix, factorize
  it with `cholesky` over `nblocks` x `nblocks` tiles, print one tab-separated
  timing line (size, milliseconds, GFlop/s), and verify the result with `check`.
  When `verbose` is true the matrix is displayed before and after factorization.
  The matrix is unpinned and StarPU is shut down even if factorization or
  verification throws.
  """
  function main(size_p :: Int, nblocks :: Int, verbose = false)
      starpu_init()
      try
          # Hilbert matrix H(i, j) = 1/(i + j - 1) (1-based indices) plus size_p
          # on the diagonal: symmetric positive definite and well conditioned.
          mat :: Matrix{Float32} = zeros(Float32, size_p, size_p)
          for i in 1:size_p
              for j in 1:size_p
                  mat[i, j] = 1.0f0 / (1.0f0+(i-1)+(j-1)) + ((i == j) ? 1.0f0*size_p : 0.0f0)
              end
          end
          if verbose
              display(mat)
          end
          starpu_memory_pin(mat)
          # Only the factorization is timed; the matrix is unpinned even on error.
          elapsed_ns = try
              t_start = time_ns()
              cholesky(mat, size_p, nblocks)
              time_ns() - t_start
          finally
              starpu_memory_unpin(mat)
          end
          flop = (1.0*size_p*size_p*size_p)/3.0   # ~n^3/3 flops for Cholesky
          println("# size\tms\tGFlops")
          time_ms = elapsed_ns / 1e6
          gflops = flop/(time_ms*1000)/1000
          println("# $size_p\t$time_ms\t$gflops")
          if verbose
              display(mat)
          end
          check(mat)
      finally
          starpu_shutdown()
      end
      return nothing
  end
  # Script entry point: 1024 x 1024 matrix, 8 x 8 tiles, no matrix dumps.
  main(1024, 8, false)