nf_mm_cl.f90 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. ! StarPU --- Runtime system for heterogeneous multicore architectures.
  2. !
  3. ! Copyright (C) 2016 Inria
  4. !
  5. ! StarPU is free software; you can redistribute it and/or modify
  6. ! it under the terms of the GNU Lesser General Public License as published by
  7. ! the Free Software Foundation; either version 2.1 of the License, or (at
  8. ! your option) any later version.
  9. !
  10. ! StarPU is distributed in the hope that it will be useful, but
  11. ! WITHOUT ANY WARRANTY; without even the implied warranty of
  12. ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. !
  14. ! See the GNU Lesser General Public License in COPYING.LGPL for more details.
  !> Codelet routines for the StarPU native-Fortran matrix-multiply example.
  !>
  !> mat_disp  : prints a matrix (usable from the program and when debugging
  !>             codelet routines).
  !> cl_cpu_mult: CPU implementation of the codelet, computing C = C + A * B.
  module nf_mm_cl
    implicit none
    private
    public :: mat_disp, cl_cpu_mult

  contains

    !> Print matrix m to standard output, one row per line.
    !>
    !> Each row is written as "| " followed by every element in F6.1 format
    !> plus a separating blank, and is closed by a list-directed "|" (which
    !> begins with a blank, as list-directed records do). An empty line is
    !> written after the last row.
    !>
    !> Declared in the module so it can be used both by the program and for
    !> debugging codelet routines.
    !>
    !> @param m  matrix to display; only read
    subroutine mat_disp(m)
      use iso_c_binding, only: c_double
      real(kind=c_double), intent(in) :: m(:,:)
      integer :: i, j

      do i = lbound(m, 1), ubound(m, 1)
        write(*, fmt="(A2) ", advance="no") "| "
        do j = lbound(m, 2), ubound(m, 2)
          write(*, fmt="(F6.1,A1) ", advance="no") m(i, j), " "
        end do
        ! Terminates the non-advancing record started above.
        write(*, *) "|"
      end do
      write(*, *)
    end subroutine mat_disp

    !> CPU implementation of the matrix-multiply codelet: C = C + A * B.
    !>
    !> Called through the C ABI (bind(C)) with the task's buffer array.
    !> Buffer 0 is A (nx_A x ny_A), buffer 1 is B (nx_B x ny_B) and
    !> buffer 2 is C (nx_C x ny_C), as read through fstarpu_matrix_get_*.
    !> Each buffer is viewed as an ld_* x ny_* Fortran array, so ld_* acts
    !> as the column stride; only rows 1..nx_* of each array are touched.
    !>
    !> If the dimensions are incompatible, a message is written to
    !> error_unit and the program terminates with `error stop 1`.
    !>
    !> @param buffers  StarPU buffer array (opaque C pointer)
    !> @param cl_args  codelet arguments; unused here
    !>
    !> NOTE(review): `recursive` is presumably kept so that locals do not get
    !> static storage, because StarPU may run this codelet on several worker
    !> threads at once. Confirm before removing it.
    recursive subroutine cl_cpu_mult(buffers, cl_args) bind(C)
      use iso_c_binding, only: c_ptr, c_double, c_f_pointer
      use iso_fortran_env, only: error_unit
      use fstarpu_mod, only: fstarpu_matrix_get_ld, fstarpu_matrix_get_nx, &
                             fstarpu_matrix_get_ny, fstarpu_matrix_get_ptr
      type(c_ptr), value, intent(in) :: buffers, cl_args ! cl_args is unused
      real(kind=c_double), pointer :: A(:,:), B(:,:), C(:,:)
      integer :: ld_A, nx_A, ny_A
      integer :: ld_B, nx_B, ny_B
      integer :: ld_C, nx_C, ny_C
      integer :: i, j, k

      ld_A = fstarpu_matrix_get_ld(buffers, 0)
      ld_B = fstarpu_matrix_get_ld(buffers, 1)
      ld_C = fstarpu_matrix_get_ld(buffers, 2)
      nx_A = fstarpu_matrix_get_nx(buffers, 0)
      nx_B = fstarpu_matrix_get_nx(buffers, 1)
      nx_C = fstarpu_matrix_get_nx(buffers, 2)
      ny_A = fstarpu_matrix_get_ny(buffers, 0)
      ny_B = fstarpu_matrix_get_ny(buffers, 1)
      ny_C = fstarpu_matrix_get_ny(buffers, 2)

      ! Shape checks for C(nx_C,ny_C) = A(nx_A,ny_A) * B(nx_B,ny_B).
      if (ny_C /= ny_B) then
        write(error_unit, *) "C -- B column mismatch"
        error stop 1
      end if
      if (nx_C /= nx_A) then
        write(error_unit, *) "C -- A row mismatch"
        error stop 1
      end if
      if (ny_A /= nx_B) then
        write(error_unit, *) "A -- B col/row mismatch"
        error stop 1
      end if

      ! Map the raw buffers with their leading dimension as the first extent.
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 0), A, shape=[ld_A, ny_A])
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 1), B, shape=[ld_B, ny_B])
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 2), C, shape=[ld_C, ny_C])

      ! Column-major storage: keep the first index (j) innermost so that
      ! A(:,i) and C(:,k) are walked with unit stride. Each C(j,k) still
      ! accumulates its products for i = 1..nx_B in increasing order, so
      ! the summation order is unchanged.
      do k = 1, ny_C
        do i = 1, nx_B
          do j = 1, nx_C
            C(j, k) = C(j, k) + A(j, i) * B(i, k)
          end do
        end do
      end do
    end subroutine cl_cpu_mult

  end module nf_mm_cl