# cholesky_common.jl

  # Standard kernels for the blocked Cholesky factorization:
  # U22 is the gemm (trailing-block) update
  # U21 is the trsm (off-diagonal block) update
  # U11 is the Cholesky factorization of a diagonal block
  # U11 kernel: in-place unblocked Cholesky factorization of the diagonal
  # block `sub11`.
  #
  # The loop index `z` is 0-based (hence the `+1`/`+2` offsets).
  # For each column:
  #   1. the diagonal entry is replaced by its square root (lambda11);
  #   2. the entries below the diagonal (rows z+2..nx) are scaled by
  #      1/lambda11 (SSCAL);
  #   3. the trailing block below/right of the pivot gets the rank-1 update
  #      A := A - X*X' (SSYR with uplo "L", so only its lower triangle is
  #      written).
  # NOTE(review): there is no guard against a non-positive pivot; the block
  # is assumed to be symmetric positive definite.
  # NOTE(review): ranges are written as `start:start+(len-1)` rather than
  # `start:nx`, presumably because that is the form the @codelet translator
  # expects; they are kept unchanged.
  @target STARPU_CPU+STARPU_CUDA
  @codelet function u11(sub11 :: Matrix{Float32}) :: Nothing
      nx :: Int32 = width(sub11)
      # Named `ld11` (not `ld`) so the local does not shadow the `ld`
      # accessor it is initialized from; matches the naming in u21/u22.
      ld11 :: Int32 = ld(sub11)

      for z in 0:nx-1
          lambda11 :: Float32 = sqrt(sub11[z+1,z+1])
          sub11[z+1,z+1] = lambda11

          # Scale the sub-diagonal part of column z+1 (rows z+2..nx).
          alpha ::Float32 = 1.0f0 / lambda11
          X :: Vector{Float32} = view(sub11, z+2:z+2+(nx-z-2), z+1)
          STARPU_SSCAL(nx-z-1, alpha, X, 1)

          # Trailing update of the (nx-z-1) x (nx-z-1) block: A := A - X*X'.
          alpha = -1.0f0
          A :: Matrix{Float32} = view(sub11, z+2:z+2+(nx-z-2), z+2:z+2+(nx-z-2))
          STARPU_SSYR("L", nx-z-1, alpha, X, 1, A, ld11)
      end
      return
  end
  # U21 kernel: in-place triangular solve for an off-diagonal block,
  #   sub21 := sub21 * inv(L11')
  # where L11 is the lower triangle of the already-factorized diagonal block
  # `sub11` (STRSM with side "R", uplo "L", trans "T", diag "N", alpha 1).
  # NOTE(review): BLAS strsm takes M = rows of B and N = columns of B. Here
  # the block width is passed first and the height second, which coincide
  # only for square blocks; confirm against the binding's width/height
  # conventions.
  @target STARPU_CPU+STARPU_CUDA
  @codelet function u21(sub11 :: Matrix{Float32},
                        sub21 :: Matrix{Float32}) :: Nothing
      # Shape of the right-hand-side block that gets overwritten.
      b_w :: Int32 = width(sub21)
      b_h :: Int32 = height(sub21)

      # Leading dimensions of the triangular factor and of the RHS block.
      a_ld :: Int32 = ld(sub11)
      b_ld :: Int32 = ld(sub21)

      one_f :: Float32 = 1.0f0
      STARPU_STRSM("R", "L", "T", "N", b_w, b_h, one_f, sub11, a_ld, sub21, b_ld)
      return
  end
  # U22 kernel: trailing GEMM update of a block,
  #   center := center - left * right'
  # (SGEMM with transa "N", transb "T", alpha -1, beta 1).
  # Dimensions: m = height(center), n = width(center), k = width(left).
  @target STARPU_CPU+STARPU_CUDA
  @codelet function u22(left :: Matrix{Float32},
                        right :: Matrix{Float32},
                        center :: Matrix{Float32}) :: Nothing
      dx :: Int32 = width(center)
      dy :: Int32 = height(center)
      dz :: Int32 = width(left)

      # Each leading dimension must come from the matrix it is passed with
      # (lda <- left, ldb <- right, ldc <- center). Taking ldb from `center`
      # and ldc from `right` is only correct when all three blocks share the
      # same leading dimension.
      ld21 :: Int32 = ld(left)
      ld12 :: Int32 = ld(right)
      ld22 :: Int32 = ld(center)

      alpha :: Float32 = -1.0f0
      beta :: Float32 = 1.0f0
      STARPU_SGEMM("N", "T", dy, dx, dz, alpha, left, ld21, right, ld12, beta, center, ld22)
      return
  end
  # Tag for the U11 task of step `k`: kind 1 stored from bit 60 upward,
  # `k` in the low bits.
  @inline function tag11(k)
      kind_bits = UInt64(1) << 60
      return starpu_tag_t(kind_bits | UInt64(k))
  end
  # Tag for the U21 task (k, j): kind 3 stored from bit 60 upward, `k` from
  # bit 32, `j` in the low bits.
  @inline function tag21(k, j)
      kind_bits = UInt64(3) << 60
      step_bits = UInt64(k) << 32
      return starpu_tag_t(kind_bits | step_bits | UInt64(j))
  end
  # Tag for the U22 task (k, i, j): kind 4 stored from bit 60 upward, `k`
  # from bit 32, `i` from bit 16, `j` in the low bits.
  @inline function tag22(k, i, j)
      kind_bits = UInt64(4) << 60
      step_bits = UInt64(k) << 32
      row_bits = UInt64(i) << 16
      return starpu_tag_t(kind_bits | step_bits | row_bits | UInt64(j))
  end
  """
      check(mat::Matrix{Float32})

  Verify that the lower triangle of `mat` is the Cholesky factor `L` of the
  test matrix with entries `1/(i+j-1)`, plus `size(mat, 1)` on the diagonal.

  `mat` is modified in place: its strictly upper triangle is overwritten with
  zeros so that it holds exactly `L`. The lower triangle of `L*L'` is then
  formed with BLAS `syrk!` and compared entry by entry, scanning row by row,
  against the expected value with a relative tolerance of `1e-4`.

  Throws an `ErrorException` describing the first mismatching entry.
  Otherwise prints a success message to `stderr`.
  """
  function check(mat::Matrix{Float32})
      n = size(mat, 1)

      # Drop whatever is stored above the diagonal (column-major walk).
      for col in 1:n, row in 1:col-1
          mat[row, col] = 0.0f0
      end

      # Lower triangle of L * L'.
      product = zeros(Float32, n, n)
      syrk!('L', 'N', 1.0f0, mat, 0.0f0, product)

      for i in 1:n, j in 1:i
          expected = (1.0f0/(1.0f0+(i-1)+(j-1))) + ((i == j) ? 1.0f0*n : 0.0f0)
          got = product[i, j]
          err = abs(got - expected) / expected
          err > 0.0001 && error("[$i, $j] -> $got != $expected (err $err)")
      end

      println(stderr, "Verification successful !")
      return nothing
  end
  """
      clean_tags(nblocks)

  Remove the StarPU tags produced by `tag11`, `tag21` and `tag22` for a
  factorization over `nblocks` block indices. For each step `k` it removes:
  - `tag11(k)`;
  - `tag21(k, m)` for every `m` in `k+1:nblocks`;
  - `tag22(k, m, n)` for every pair with `k < n <= m`.
  """
  function clean_tags(nblocks)
      for k in 1:nblocks
          starpu_tag_remove(tag11(k))
          for m in k+1:nblocks
              starpu_tag_remove(tag21(k, m))
              # Lower-triangular pairs only: n never exceeds m.
              for n in k+1:m
                  starpu_tag_remove(tag22(k, m, n))
              end
          end
      end
      return nothing
  end
  """
      main(size_p::Int, nblocks::Int; verify = false, verbose = false)

  Build a `size_p` x `size_p` test matrix, factorize it in place with
  `cholesky(mat, size_p, nblocks)`, and print one tab-separated line
  `size_p  time_ms  gflops` to stdout. Afterwards it calls
  `clean_tags(nblocks)`.

  The test matrix is the (1-based) Hilbert matrix `1/(i+j-1)` with `size_p`
  added to every diagonal entry. The Hilbert matrix is symmetric positive
  definite, and the diagonal shift keeps it so while making it well
  conditioned.

  The matrix is pinned with `starpu_memory_pin` only around the
  factorization. It is always unpinned, even if the factorization throws.

  # Keywords
  - `verify`: when `true`, run `check` on the result (throws on mismatch).
  - `verbose`: when `true`, display the matrix before and after factorization.
  """
  function main(size_p :: Int, nblocks :: Int; verify = false, verbose = false)
      mat :: Matrix{Float32} = zeros(Float32, size_p, size_p)
      # Symmetric positive definite test matrix: 1-based Hilbert entries
      # 1/(i+j-1), plus size_p on the diagonal. Filled column by column to
      # follow Julia's column-major layout.
      for j in 1:size_p, i in 1:size_p
          mat[i, j] = 1.0f0 / (1.0f0+(i-1)+(j-1)) + ((i == j) ? 1.0f0*size_p : 0.0f0)
      end

      if verbose
          display(mat)
      end

      starpu_memory_pin(mat)
      t_start = time_ns()
      t_end = try
          cholesky(mat, size_p, nblocks)
          time_ns()
      finally
          # Runs whether or not the factorization throws, so the matrix is
          # never left pinned.
          starpu_memory_unpin(mat)
      end

      # A Cholesky factorization costs about n^3/3 flops.
      flop = (1.0*size_p*size_p*size_p)/3.0
      time_ms = (t_end-t_start) / 1e6
      # flop / (time_ms * 1000) is flop per microsecond (Mflop/s); / 1000 -> Gflop/s.
      gflops = flop/(time_ms*1000)/1000
      println("$size_p\t$time_ms\t$gflops")

      clean_tags(nblocks)

      if verbose
          display(mat)
      end
      if verify
          check(mat)
      end
      return nothing
  end