cpu_mult.c 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2019 Mael Keryell
  4. *
  5. * StarPU is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * StarPU is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
#include <stddef.h>
#include <stdint.h>
#include <starpu.h>
  18. /*
  19. * The codelet is passed 3 matrices, the "descr" union-type field gives a
  20. * description of the layout of those 3 matrices in the local memory (ie. RAM
  21. * in the case of CPU, GPU frame buffer in the case of GPU etc.). Since we have
  22. * registered data with the "matrix" data interface, we use the matrix macros.
  23. */
  24. void cpu_mult(void *descr[], void *arg)
  25. {
  26. (void)arg;
  27. float *subA, *subB, *subC;
  28. uint32_t nxC, nyC, nyA;
  29. uint32_t ldA, ldB, ldC;
  30. /* .blas.ptr gives a pointer to the first element of the local copy */
  31. subA = (float *)STARPU_MATRIX_GET_PTR(descr[0]);
  32. subB = (float *)STARPU_MATRIX_GET_PTR(descr[1]);
  33. subC = (float *)STARPU_MATRIX_GET_PTR(descr[2]);
  34. /* .blas.nx is the number of rows (consecutive elements) and .blas.ny
  35. * is the number of lines that are separated by .blas.ld elements (ld
  36. * stands for leading dimension).
  37. * NB: in case some filters were used, the leading dimension is not
  38. * guaranteed to be the same in main memory (on the original matrix)
  39. * and on the accelerator! */
  40. nxC = STARPU_MATRIX_GET_NX(descr[2]);
  41. nyC = STARPU_MATRIX_GET_NY(descr[2]);
  42. nyA = STARPU_MATRIX_GET_NY(descr[0]);
  43. ldA = STARPU_MATRIX_GET_LD(descr[0]);
  44. ldB = STARPU_MATRIX_GET_LD(descr[1]);
  45. ldC = STARPU_MATRIX_GET_LD(descr[2]);
  46. /* we assume a FORTRAN-ordering! */
  47. unsigned i,j,k;
  48. for (i = 0; i < nyC; i++)
  49. {
  50. for (j = 0; j < nxC; j++)
  51. {
  52. float sum = 0.0;
  53. for (k = 0; k < nyA; k++)
  54. {
  55. sum += subA[j+k*ldA]*subB[k+i*ldB];
  56. }
  57. subC[j + i*ldC] = sum;
  58. }
  59. }
  60. }