! nf_mm_cl.f90 (2.9 KB)
! StarPU --- Runtime system for heterogeneous multicore architectures.
!
! Copyright (C) 2016-2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
!
! StarPU is free software; you can redistribute it and/or modify
! it under the terms of the GNU Lesser General Public License as published by
! the Free Software Foundation; either version 2.1 of the License, or (at
! your option) any later version.
!
! StarPU is distributed in the hope that it will be useful, but
! WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!
! See the GNU Lesser General Public License in COPYING.LGPL for more details.
!
  !> Helper routines for a StarPU matrix-multiplication example:
  !> a matrix pretty-printer (mat_disp) and a CPU codelet that
  !> accumulates C = C + A*B (cl_cpu_mult).
  module nf_mm_cl
    implicit none
    private
    public :: mat_disp, cl_cpu_mult
  contains

    !> Print matrix m to standard output, one row per line.
    !>
    !> Each row is written as "| ", then every element in F6.1 format
    !> followed by a blank, then a closing "|". One empty line is written
    !> after the last row. Declared in this module so it can be used both by
    !> the main program and when debugging codelet routines.
    !>
    !> @param m  matrix to display (read only)
    subroutine mat_disp(m)
      use iso_c_binding, only: c_double
      real(kind=c_double), intent(in) :: m(:,:)
      integer :: i, j

      do i = lbound(m, 1), ubound(m, 1)
        write(*, fmt="(A2) ", advance="no") "| "
        do j = lbound(m, 2), ubound(m, 2)
          write(*, fmt="(F6.1,A1) ", advance="no") m(i, j), " "
        end do
        write(*, *) "|"
      end do
      write(*, *)
    end subroutine mat_disp

    !> StarPU CPU implementation of the matrix-multiplication codelet.
    !>
    !> Buffers 0, 1 and 2 are viewed as matrices A, B and C. The routine
    !> accumulates C(1:nx_C,1:ny_C) = C + A(1:nx_A,1:ny_A) * B(1:nx_B,1:ny_B).
    !> Here nx is the extent indexed by the first Fortran subscript, ny the
    !> extent of the second subscript, and ld the leading dimension used as
    !> the first extent of each pointer view.
    !>
    !> If the dimensions are not conformable, a message is written to
    !> standard error and execution ends with ERROR STOP 1.
    !>
    !> C is read as well as written, so it must hold valid data on entry.
    !> NOTE(review): presumably registered read-write and zero-initialised by
    !> the caller -- confirm against the codelet/task definition.
    !>
    !> The routine is RECURSIVE, so its locals are not statically allocated.
    !> NOTE(review): presumably needed because several StarPU CPU workers may
    !> execute this codelet concurrently.
    !>
    !> @param buffers  StarPU buffer-descriptor array (passed by value)
    !> @param cl_args  codelet arguments; unused
    recursive subroutine cl_cpu_mult(buffers, cl_args) bind(C)
      use iso_c_binding, only: c_ptr, c_double, c_f_pointer
      use iso_fortran_env, only: error_unit
      use fstarpu_mod, only: fstarpu_matrix_get_ld, fstarpu_matrix_get_nx, &
                             fstarpu_matrix_get_ny, fstarpu_matrix_get_ptr
      type(c_ptr), value, intent(in) :: buffers, cl_args ! cl_args is unused
      real(kind=c_double), pointer :: A(:,:), B(:,:), C(:,:)
      integer :: ld_A, nx_A, ny_A
      integer :: ld_B, nx_B, ny_B
      integer :: ld_C, nx_C, ny_C
      integer :: i, j, k

      ! Geometry of the three registered matrices.
      ld_A = fstarpu_matrix_get_ld(buffers, 0)
      ld_B = fstarpu_matrix_get_ld(buffers, 1)
      ld_C = fstarpu_matrix_get_ld(buffers, 2)
      nx_A = fstarpu_matrix_get_nx(buffers, 0)
      nx_B = fstarpu_matrix_get_nx(buffers, 1)
      nx_C = fstarpu_matrix_get_nx(buffers, 2)
      ny_A = fstarpu_matrix_get_ny(buffers, 0)
      ny_B = fstarpu_matrix_get_ny(buffers, 1)
      ny_C = fstarpu_matrix_get_ny(buffers, 2)

      ! Conformability: C(nx_C,ny_C) = A(nx_A,ny_A) * B(nx_B,ny_B).
      ! The error is unrecoverable inside a codelet, so report it on stderr
      ! and terminate with a non-zero status.
      if (ny_C /= ny_B) then
        write(error_unit, *) "C -- B column mismatch"
        error stop 1
      end if
      if (nx_C /= nx_A) then
        write(error_unit, *) "C -- A row mismatch"
        error stop 1
      end if
      if (ny_A /= nx_B) then
        write(error_unit, *) "A -- B col/row mismatch"
        error stop 1
      end if

      ! Map the raw buffers onto Fortran arrays. The first extent is the
      ! leading dimension; only the first nx rows of each view are used below.
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 0), A, shape=[ld_A, ny_A])
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 1), B, shape=[ld_B, ny_B])
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 2), C, shape=[ld_C, ny_C])

      ! Loop order k -> i -> j keeps the innermost index on the first
      ! (contiguous) subscript of both A and C. For every C(j,k) the products
      ! are still accumulated in increasing i, as in the naive j-k-i order,
      ! so the floating-point result is the same.
      do k = 1, ny_C
        do i = 1, nx_B
          do j = 1, nx_C
            C(j, k) = C(j, k) + A(j, i) * B(i, k)
          end do
        end do
      end do
    end subroutine cl_cpu_mult

  end module nf_mm_cl