# cholesky_native.jl
  1. using LinearAlgebra.BLAS
  """
      u11(sub11)

  Factorize the square diagonal block `sub11` in place with an unblocked,
  right-looking Cholesky sweep.

  For each column the diagonal entry is replaced by its square root, the entries
  below it are divided by that pivot (`scal!`), and the lower triangle of the
  trailing submatrix receives a symmetric rank-1 downdate (`syr!` with `'L'`).
  Entries strictly above the diagonal are neither read nor written.

  # Arguments
  - `sub11`: square `Float32` block accepted by the BLAS wrappers `scal!` and
    `syr!` (for example a `Matrix{Float32}` or a range-indexed view of one).

  # Returns
  `nothing`; the result is stored in `sub11`.

  # Throws
  - `ErrorException` if a pivot is exactly zero.
  - `DomainError` (raised by `sqrt`) if a diagonal entry is negative.
  """
  function u11(sub11)
      n = size(sub11, 1)
      for col in 1:n
          pivot::Float32 = sqrt(sub11[col, col])
          sub11[col, col] = pivot
          if pivot == 0.0f0
              # Dividing by a zero pivot would fill the column with Inf/NaN.
              error("zero pivot at diagonal index $col; block is not positive definite")
          end
          # Rows strictly below the pivot; empty on the last column, in which case
          # both BLAS calls operate on zero elements.
          below = (col + 1):n
          column = view(sub11, below, col)
          scal!(length(below), 1.0f0 / pivot, column, 1)
          trailing = view(sub11, below, below)
          syr!('L', -1.0f0, column, trailing)
      end
      return nothing
  end
  """
      u21(sub11, sub21)

  Solve `X * L' = sub21` for `X` and store `X` back into `sub21`, where `L` is the
  lower triangle (non-unit diagonal) of `sub11`.

  This is a direct call to BLAS `trsm!` with side `'R'`, triangle `'L'`,
  transpose `'T'`, diagonal `'N'` and `alpha = 1.0f0`. `sub11` is only read.
  Returns whatever `trsm!` returns (the updated `sub21`).
  """
  u21(sub11, sub21) = trsm!('R', 'L', 'T', 'N', 1.0f0, sub11, sub21)
  """
      u22(left, right, center)

  Apply the trailing update `center = center - left * right'` in place.

  This is a direct call to BLAS `gemm!` with `'N'` for `left`, `'T'` for `right`,
  `alpha = -1.0f0` and `beta = 1.0f0`. `left` and `right` are only read.
  Returns whatever `gemm!` returns (the updated `center`).
  """
  u22(left, right, center) = gemm!('N', 'T', -1.0f0, left, right, 1.0f0, center)
  """
      get_block(mat, m, n, nblocks)

  Return a view (no copy) of block `(m, n)` of the square matrix `mat`, which is
  treated as an `nblocks` x `nblocks` grid of equally sized square blocks.

  # Arguments
  - `mat`: square `Matrix{Float32}`.
  - `m`, `n`: zero-based block row and block column.
  - `nblocks`: number of blocks per dimension; must be positive and divide the
    matrix order.

  # Returns
  A `view` of `mat` covering rows `m*stride+1:(m+1)*stride` and columns
  `n*stride+1:(n+1)*stride`, with `stride = size(mat, 1) ÷ nblocks`. Writes
  through the view modify `mat`.

  # Throws
  - `ErrorException` if `mat` is not square, if `nblocks` is not positive, or if
    the matrix order is not a multiple of `nblocks`.
  - `BoundsError` (from `view`) if `m` or `n` addresses a block outside `mat`.
  """
  function get_block(mat :: Matrix{Float32}, m, n, nblocks)
      dim = size(mat, 1)
      if dim != size(mat, 2)
          error("mat must be a square matrix")
      end
      # A negative count would pass the divisibility test below and silently
      # produce empty views, so reject it (and zero) explicitly.
      if nblocks <= 0
          error("nblocks must be positive")
      end
      if dim % nblocks != 0
          error("dim must be a multiple of nblocks")
      end
      # Integer division; the Int() conversion keeps integral non-Int counts working.
      stride = Int(dim ÷ nblocks)
      rows = (m * stride + 1):((m + 1) * stride)
      cols = (n * stride + 1):((n + 1) * stride)
      return view(mat, rows, cols)
  end
  """
      cholesky(mat, size, nblocks)

  Run a blocked, right-looking Cholesky factorization of `mat` in place, using an
  `nblocks` x `nblocks` grid of blocks obtained from `get_block`.

  For each block column `k`:
  1. the diagonal block `(k, k)` is factorized with `u11`;
  2. every block `(m, k)` below it is updated with `u21`;
  3. every block `(m, n)` with `k < n <= m` is updated with
     `u22(block(m, k), block(n, k), block(m, n))`.

  Blocks strictly above the block diagonal are never touched.

  # Arguments
  - `mat`: square `Matrix{Float32}`, overwritten with the result.
  - `size`: not used by the body; kept for interface compatibility.
  - `nblocks`: number of blocks per dimension (validated by `get_block`).

  # Returns
  `nothing`.
  """
  function cholesky(mat :: Matrix{Float32}, size, nblocks)
      last_block = nblocks - 1
      for k in 0:last_block
          diag_block = get_block(mat, k, k, nblocks)
          u11(diag_block)

          # Panel solve: all blocks below the diagonal block in block column k.
          for m in (k + 1):last_block
              u21(diag_block, get_block(mat, m, k, nblocks))
          end

          # Trailing update restricted to the lower block triangle (n <= m).
          for m in (k + 1):last_block
              panel_m = get_block(mat, m, k, nblocks)
              for n in (k + 1):m
                  panel_n = get_block(mat, n, k, nblocks)
                  target = get_block(mat, m, n, nblocks)
                  u22(panel_m, panel_n, target)
              end
          end
      end
      return nothing
  end
  """
      check(mat)

  Verify that the lower triangle of `mat` is a Cholesky factor of the reference
  test matrix whose entry `(i, j)` is `1 / (1 + (i-1) + (j-1))`, plus
  `size(mat, 1)` when `i == j`.

  The strict upper triangle of `mat` is first overwritten with zeros (a side
  effect visible to the caller). The lower triangle of `mat * mat'` is then
  formed with `syrk!` and compared entry by entry against the reference, using a
  relative error threshold of `0.0001`.

  Prints `Verification successful !` when every entry passes.

  # Throws
  - `ErrorException` naming the first failing entry, scanning row by row.
  """
  function check(mat::Matrix{Float32})
      n = size(mat, 1)

      # Discard everything strictly above the diagonal.
      for i in 1:n
          for j in (i + 1):n
              mat[i, j] = 0.0f0
          end
      end

      # Only the lower triangle of `product` is filled in by syrk!('L', ...).
      product::Matrix{Float32} = zeros(Float32, n, n)
      syrk!('L', 'N', 1.0f0, mat, 0.0f0, product)

      for i in 1:n, j in 1:i
          orig = (1.0f0/(1.0f0+(i-1)+(j-1))) + ((i == j) ? 1.0f0*n : 0.0f0)
          err = abs(product[i, j] - orig) / orig
          err > 0.0001 || continue
          got = product[i, j]
          expected = orig
          error("[$i, $j] -> $got != $expected (err $err)")
      end

      println("Verification successful !")
      return nothing
  end
  """
      main(size_p, nblocks, display = false)

  Build a `size_p` x `size_p` `Float32` test matrix, factorize it in place with
  `cholesky(mat, size_p, nblocks)`, print the elapsed time and the derived
  GFlops figure, then verify the result with `check`.

  # Arguments
  - `size_p`: matrix order.
  - `nblocks`: number of blocks per dimension passed to `cholesky`.
  - `display`: when `true`, show the matrix before and after the factorization.

  # Returns
  `nothing`.
  """
  function main(size_p :: Int, nblocks :: Int, display = false)
      mat :: Matrix{Float32} = zeros(Float32, size_p, size_p)
      # Create a simple symmetric positive definite matrix: the Hilbert matrix
      # h(i,j) = 1/(i+j+1) for zero-based i, j, with size_p added on the diagonal.
      # Filled column by column to follow Julia's column-major storage.
      for j in 1:size_p
          for i in 1:size_p
              mat[i, j] = 1.0f0 / (1.0f0+(i-1)+(j-1)) + ((i == j) ? 1.0f0*size_p : 0.0f0)
          end
      end
      # The `display` argument shadows the `display` function inside this body, so
      # the function has to be reached through `Base`; a bare `display(mat)` would
      # try to call the Bool and raise a MethodError.
      if display
          Base.display(mat)
      end
      t_start = time_ns()
      cholesky(mat, size_p, nblocks)
      t_end = time_ns()
      # Flop count used for the rate: size_p^3 / 3.
      flop = (1.0*size_p*size_p*size_p)/3.0
      println("# size\tms\tGFlops")
      time_ms = (t_end-t_start) / 1e6
      # flop / microseconds = MFlops; divide by 1000 again for GFlops.
      gflops = flop/(time_ms*1000)/1000
      println("# $size_p\t$time_ms\t$gflops")
      if display
          Base.display(mat)
      end
      check(mat)
      return nothing
  end
  # Script entry point: run the benchmark on a matrix of order 20 * 1024 split
  # into 8 blocks per dimension.
  main(20 * 1024, 8)