strassen_kernels.c 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include "strassen.h"
  17. static void mult_common_codelet(starpu_data_interface_t *buffers, int s, __attribute__((unused)) void *arg)
  18. {
  19. float *center = (float *)buffers[0].blas.ptr;
  20. float *left = (float *)buffers[1].blas.ptr;
  21. float *right = (float *)buffers[2].blas.ptr;
  22. unsigned dx = buffers[0].blas.nx;
  23. unsigned dy = buffers[0].blas.ny;
  24. unsigned dz = buffers[1].blas.nx;
  25. unsigned ld21 = buffers[1].blas.ld;
  26. unsigned ld12 = buffers[2].blas.ld;
  27. unsigned ld22 = buffers[0].blas.ld;
  28. switch (s) {
  29. case 0:
  30. cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
  31. dy, dx, dz, -1.0f, left, ld21, right, ld12,
  32. 1.0f, center, ld22);
  33. break;
  34. #ifdef USE_CUDA
  35. case 1:
  36. cublasSgemm('t', 'n', dx, dy, dz,
  37. -1.0f, right, ld12, left, ld21,
  38. 1.0f, center, ld22);
  39. break;
  40. #endif
  41. default:
  42. STARPU_ASSERT(0);
  43. break;
  44. }
  45. }
  46. void mult_core_codelet(starpu_data_interface_t *descr, void *_args)
  47. {
  48. mult_common_codelet(descr, 0, _args);
  49. }
  50. #ifdef USE_CUDA
  51. void mult_cublas_codelet(starpu_data_interface_t *descr, void *_args)
  52. {
  53. mult_common_codelet(descr, 1, _args);
  54. }
  55. #endif
  56. static void add_sub_common_codelet(starpu_data_interface_t *buffers, int s, __attribute__((unused)) void *arg, float alpha)
  57. {
  58. /* C = A op B */
  59. float *C = (float *)buffers[0].blas.ptr;
  60. float *A = (float *)buffers[1].blas.ptr;
  61. float *B = (float *)buffers[2].blas.ptr;
  62. unsigned dx = buffers[0].blas.nx;
  63. unsigned dy = buffers[0].blas.ny;
  64. unsigned ldA = buffers[1].blas.ld;
  65. unsigned ldB = buffers[2].blas.ld;
  66. unsigned ldC = buffers[0].blas.ld;
  67. // TODO check dim ...
  68. unsigned line;
  69. switch (s) {
  70. case 0:
  71. for (line = 0; line < dy; line++)
  72. {
  73. /* copy line A into C */
  74. cblas_saxpy(dx, 1.0f, &A[line*ldA], 1, &C[line*ldC], 1);
  75. /* add line B to C = A */
  76. cblas_saxpy(dx, alpha, &B[line*ldB], 1, &C[line*ldC], 1);
  77. }
  78. break;
  79. #ifdef USE_CUDA
  80. case 1:
  81. for (line = 0; line < dy; line++)
  82. {
  83. /* copy line A into C */
  84. cublasSaxpy(dx, 1.0f, &A[line*ldA], 1, &C[line*ldC], 1);
  85. /* add line B to C = A */
  86. cublasSaxpy(dx, alpha, &B[line*ldB], 1, &C[line*ldC], 1);
  87. }
  88. break;
  89. #endif
  90. default:
  91. STARPU_ASSERT(0);
  92. break;
  93. }
  94. }
  95. void sub_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  96. {
  97. add_sub_common_codelet(descr, 0, arg, -1.0f);
  98. }
  99. void add_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  100. {
  101. add_sub_common_codelet(descr, 0, arg, 1.0f);
  102. }
  103. #ifdef USE_CUDA
  104. void sub_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  105. {
  106. add_sub_common_codelet(descr, 1, arg, -1.0f);
  107. }
  108. void add_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  109. {
  110. add_sub_common_codelet(descr, 1, arg, 1.0f);
  111. }
  112. #endif
  113. static void self_add_sub_common_codelet(starpu_data_interface_t *buffers, int s, __attribute__((unused)) void *arg, float alpha)
  114. {
  115. /* C +=/-= A */
  116. float *C = (float *)buffers[0].blas.ptr;
  117. float *A = (float *)buffers[1].blas.ptr;
  118. unsigned dx = buffers[0].blas.nx;
  119. unsigned dy = buffers[0].blas.ny;
  120. unsigned ldA = buffers[1].blas.ld;
  121. unsigned ldC = buffers[0].blas.ld;
  122. // TODO check dim ...
  123. unsigned line;
  124. switch (s) {
  125. case 0:
  126. for (line = 0; line < dy; line++)
  127. {
  128. /* add line A to C */
  129. cblas_saxpy(dx, alpha, &A[line*ldA], 1, &C[line*ldC], 1);
  130. }
  131. break;
  132. #ifdef USE_CUDA
  133. case 1:
  134. for (line = 0; line < dy; line++)
  135. {
  136. /* add line A to C */
  137. cublasSaxpy(dx, alpha, &A[line*ldA], 1, &C[line*ldC], 1);
  138. }
  139. break;
  140. #endif
  141. default:
  142. STARPU_ASSERT(0);
  143. break;
  144. }
  145. }
  146. void self_add_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  147. {
  148. self_add_sub_common_codelet(descr, 0, arg, 1.0f);
  149. }
  150. void self_sub_core_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  151. {
  152. self_add_sub_common_codelet(descr, 0, arg, -1.0f);
  153. }
  154. #ifdef USE_CUDA
  155. void self_add_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  156. {
  157. self_add_sub_common_codelet(descr, 1, arg, 1.0f);
  158. }
  159. void self_sub_cublas_codelet(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  160. {
  161. self_add_sub_common_codelet(descr, 1, arg, -1.0f);
  162. }
  163. #endif