! nf_mm_cl.f90 3.0 KB

  ! StarPU --- Runtime system for heterogeneous multicore architectures.
  !
  ! Copyright (C) 2017 CNRS
  ! Copyright (C) 2016 Inria
  !
  ! StarPU is free software; you can redistribute it and/or modify
  ! it under the terms of the GNU Lesser General Public License as published by
  ! the Free Software Foundation; either version 2.1 of the License, or (at
  ! your option) any later version.
  !
  ! StarPU is distributed in the hope that it will be useful, but
  ! WITHOUT ANY WARRANTY; without even the implied warranty of
  ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  !
  ! See the GNU Lesser General Public License in COPYING.LGPL for more details.
  !
  !> Matrix-product CPU codelet and a matrix display helper.
  !>
  !> Public entities:
  !>   - mat_disp     : print a rank-2 double precision matrix to standard output
  !>   - cl_cpu_mult  : C-interoperable codelet computing C = C + A*B on
  !>                    StarPU matrix buffers
  module nf_mm_cl
    implicit none
    private
    public :: mat_disp, cl_cpu_mult

  contains

    !> Print a rank-2 matrix to standard output, one matrix row per output line.
    !>
    !> Each row starts with "| ", every element is written with the F6.1 edit
    !> descriptor followed by one blank, and the row is closed with "|".
    !> An empty record is written after the last row.
    !>
    !> Declared here so it can be used both for the program and for debugging
    !> codelet routines.
    !>
    !> @param[in] m  matrix to display
    subroutine mat_disp (m)
      use, intrinsic :: iso_c_binding, only: c_double   ! C interfacing module
      real(kind=c_double), intent(in) :: m(:,:)
      integer :: i, j

      do i = lbound(m,1), ubound(m,1)
        ! non-advancing output keeps the whole matrix row on a single line
        write(*, fmt="(A2) ", advance="no") "| "
        do j = lbound(m,2), ubound(m,2)
          write(*, fmt="(F6.1,A1) ", advance="no") m(i,j), " "
        end do
        write(*,*) "|"
      end do
      write(*,*)
    end subroutine mat_disp

    !> CPU codelet: accumulate the matrix product A*B into C (C = C + A*B).
    !>
    !> The operands are taken from the StarPU matrix buffers 0 (A), 1 (B) and
    !> 2 (C).  For every buffer, nx is used as the extent of the first array
    !> index, ny as the extent of the second array index, and ld as the leading
    !> dimension of the underlying storage.
    !>
    !> If the operand extents are not conformable, a message is written to
    !> standard output and the program is stopped with code 1.
    !>
    !> @param[in] buffers  C pointer handed to the fstarpu_matrix_get_* accessors
    !> @param[in] cl_args  codelet arguments (unused)
    recursive subroutine cl_cpu_mult (buffers, cl_args) bind(C)
      use, intrinsic :: iso_c_binding, only: c_ptr, c_double, c_f_pointer  ! C interfacing module
      use fstarpu_mod, only: fstarpu_matrix_get_ld, fstarpu_matrix_get_nx, &
                             fstarpu_matrix_get_ny, fstarpu_matrix_get_ptr  ! StarPU interfacing module

      type(c_ptr), value, intent(in) :: buffers, cl_args ! cl_args is unused
      real(kind=c_double), pointer :: A(:,:), B(:,:), C(:,:)
      integer :: ld_A, nx_A, ny_A
      integer :: ld_B, nx_B, ny_B
      integer :: ld_C, nx_C, ny_C
      integer :: i, j, k

      ! Query the geometry of the three matrix buffers.
      ld_A = fstarpu_matrix_get_ld(buffers, 0)
      ld_B = fstarpu_matrix_get_ld(buffers, 1)
      ld_C = fstarpu_matrix_get_ld(buffers, 2)

      nx_A = fstarpu_matrix_get_nx(buffers, 0)
      nx_B = fstarpu_matrix_get_nx(buffers, 1)
      nx_C = fstarpu_matrix_get_nx(buffers, 2)

      ny_A = fstarpu_matrix_get_ny(buffers, 0)
      ny_B = fstarpu_matrix_get_ny(buffers, 1)
      ny_C = fstarpu_matrix_get_ny(buffers, 2)

      ! Conformance checks for C(nx_C,ny_C) = C + A(nx_A,ny_A) * B(nx_B,ny_B).
      if (ny_C /= ny_B) then
        write(*,*) "C -- B column mismatch"
        stop 1
      end if

      if (nx_C /= nx_A) then
        write(*,*) "C -- A row mismatch"
        stop 1
      end if

      if (ny_A /= nx_B) then
        write(*,*) "A -- B col/row mismatch"
        stop 1
      end if

      ! Map the raw buffers onto Fortran arrays.  The first extent is the
      ! leading dimension (not nx), so that element (i,j) is addressed with the
      ! storage stride reported for the buffer.
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 0), A, shape=[ld_A,ny_A])
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 1), B, shape=[ld_B,ny_B])
      call c_f_pointer(fstarpu_matrix_get_ptr(buffers, 2), C, shape=[ld_C,ny_C])

      ! Loop nest ordered k / i / j: the innermost loop runs over the first
      ! (fastest varying) index of both C and A, giving stride-1 accesses in
      ! column-major storage.  For a given C(j,k) the contributions are still
      ! added in increasing i, so the accumulation order per element is that of
      ! the straightforward j-then-i nesting (assuming C does not overlap A or B).
      do k = 1, ny_C
        do i = 1, nx_B
          do j = 1, nx_C
            C(j,k) = C(j,k) + A(j,i) * B(i,k)
          end do
        end do
      end do
    end subroutine cl_cpu_mult

  end module nf_mm_cl