strassen_kernels.c 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include "strassen.h"
  17. static void mult_common_codelet(void *descr[], int s, __attribute__((unused)) void *arg)
  18. {
  19. float *center = (float *)STARPU_GET_BLAS_PTR(descr[0]);
  20. float *left = (float *)STARPU_GET_BLAS_PTR(descr[1]);
  21. float *right = (float *)STARPU_GET_BLAS_PTR(descr[2]);
  22. unsigned dx = STARPU_GET_BLAS_NX(descr[0]);
  23. unsigned dy = STARPU_GET_BLAS_NY(descr[0]);
  24. unsigned dz = STARPU_GET_BLAS_NX(descr[1]);
  25. unsigned ld21 = STARPU_GET_BLAS_LD(descr[1]);
  26. unsigned ld12 = STARPU_GET_BLAS_LD(descr[2]);
  27. unsigned ld22 = STARPU_GET_BLAS_LD(descr[0]);
  28. switch (s) {
  29. case 0:
  30. cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
  31. dy, dx, dz, -1.0f, left, ld21, right, ld12,
  32. 1.0f, center, ld22);
  33. break;
  34. #ifdef USE_CUDA
  35. case 1:
  36. cublasSgemm('t', 'n', dx, dy, dz,
  37. -1.0f, right, ld12, left, ld21,
  38. 1.0f, center, ld22);
  39. cudaThreadSynchronize();
  40. break;
  41. #endif
  42. default:
  43. STARPU_ABORT();
  44. break;
  45. }
  46. }
  /* Host (CBLAS) "mult" codelet: forwards to mult_common_codelet with backend selector 0. */
  void mult_cpu_codelet(void *descr[], void *_args)
  {
      const int cpu_backend = 0;
      mult_common_codelet(descr, cpu_backend, _args);
  }
  #ifdef USE_CUDA
  /* CUBLAS "mult" codelet: forwards to mult_common_codelet with backend selector 1. */
  void mult_cublas_codelet(void *descr[], void *_args)
  {
      const int cublas_backend = 1;
      mult_common_codelet(descr, cublas_backend, _args);
  }
  #endif
  /*
   * Shared body of the "add"/"sub" codelets: C = A + alpha * B.
   *
   * descr[0] is C (output), descr[1] is A and descr[2] is B.
   * The blocks are walked row by row: NY(C) rows of NX(C) elements each,
   * with consecutive rows LD elements apart in each block.
   * alpha is +1 for the add codelets and -1 for the sub codelets.
   * s selects the backend:
   *   0 = host CBLAS;
   *   1 = CUBLAS, only when built with USE_CUDA;
   *   any other value aborts.
   *
   * Each row of C is first overwritten with the matching row of A (scopy),
   * then alpha times the row of B is accumulated into it (saxpy).
   * The first step must be a copy: saxpy with alpha = 1 would compute
   * C += A and fold whatever C held before into the result.
   */
  static void add_sub_common_codelet(void *descr[], int s, __attribute__((unused)) void *arg, float alpha)
  {
      float *C = (float *)STARPU_GET_BLAS_PTR(descr[0]);
      float *A = (float *)STARPU_GET_BLAS_PTR(descr[1]);
      float *B = (float *)STARPU_GET_BLAS_PTR(descr[2]);
      unsigned dx = STARPU_GET_BLAS_NX(descr[0]);
      unsigned dy = STARPU_GET_BLAS_NY(descr[0]);
      unsigned ldA = STARPU_GET_BLAS_LD(descr[1]);
      unsigned ldB = STARPU_GET_BLAS_LD(descr[2]);
      unsigned ldC = STARPU_GET_BLAS_LD(descr[0]);
      /* TODO: check that A and B are at least dy x dx; not verified here. */
      unsigned line;

      switch (s) {
      case 0:
          for (line = 0; line < dy; line++) {
              /* C_row = A_row (overwrite, do not accumulate) */
              cblas_scopy(dx, &A[line*ldA], 1, &C[line*ldC], 1);
              /* C_row += alpha * B_row, giving A_row + alpha * B_row */
              cblas_saxpy(dx, alpha, &B[line*ldB], 1, &C[line*ldC], 1);
          }
          break;
  #ifdef USE_CUDA
      case 1:
          for (line = 0; line < dy; line++) {
              /* C_row = A_row (overwrite, do not accumulate) */
              cublasScopy(dx, &A[line*ldA], 1, &C[line*ldC], 1);
              /* C_row += alpha * B_row, giving A_row + alpha * B_row */
              cublasSaxpy(dx, alpha, &B[line*ldB], 1, &C[line*ldC], 1);
          }
          /* Wait for the device work to finish before the codelet returns. */
          cudaThreadSynchronize();
          break;
  #endif
      default:
          STARPU_ABORT();
          break;
      }
  }
  /* Host (CBLAS) "sub" codelet: add_sub_common_codelet with backend 0 and alpha = -1. */
  void sub_cpu_codelet(void *descr[], __attribute__((unused)) void *arg)
  {
      const float minus_one = -1.0f;
      add_sub_common_codelet(descr, 0, arg, minus_one);
  }
  /* Host (CBLAS) "add" codelet: add_sub_common_codelet with backend 0 and alpha = +1. */
  void add_cpu_codelet(void *descr[], __attribute__((unused)) void *arg)
  {
      const float plus_one = 1.0f;
      add_sub_common_codelet(descr, 0, arg, plus_one);
  }
  #ifdef USE_CUDA
  /* CUBLAS "sub" codelet: add_sub_common_codelet with backend 1 and alpha = -1. */
  void sub_cublas_codelet(void *descr[], __attribute__((unused)) void *arg)
  {
      const float minus_one = -1.0f;
      add_sub_common_codelet(descr, 1, arg, minus_one);
  }
  /* CUBLAS "add" codelet: add_sub_common_codelet with backend 1 and alpha = +1. */
  void add_cublas_codelet(void *descr[], __attribute__((unused)) void *arg)
  {
      const float plus_one = 1.0f;
      add_sub_common_codelet(descr, 1, arg, plus_one);
  }
  #endif
  /*
   * Shared body of the in-place "self add"/"self sub" codelets:
   *     C = C + alpha * A
   * applied one row at a time.
   *
   * descr[0] is C (read and updated) and descr[1] is A.
   * C has NY(C) rows of NX(C) elements; row r of each block starts
   * r * LD elements from its base pointer.
   * alpha is +1 for self-add and -1 for self-sub.
   * s selects the backend:
   *   0 = host CBLAS;
   *   1 = CUBLAS, only when built with USE_CUDA;
   *   any other value aborts.
   * TODO: A's dimensions are not checked against C's.
   */
  static void self_add_sub_common_codelet(void *descr[], int s, __attribute__((unused)) void *arg, float alpha)
  {
      float *dst = (float *)STARPU_GET_BLAS_PTR(descr[0]);
      float *src = (float *)STARPU_GET_BLAS_PTR(descr[1]);
      unsigned width = STARPU_GET_BLAS_NX(descr[0]);
      unsigned height = STARPU_GET_BLAS_NY(descr[0]);
      unsigned ld_src = STARPU_GET_BLAS_LD(descr[1]);
      unsigned ld_dst = STARPU_GET_BLAS_LD(descr[0]);
      unsigned row;

      if (s == 0) {
          for (row = 0; row < height; row++)
              cblas_saxpy(width, alpha, src + row * ld_src, 1, dst + row * ld_dst, 1);
          return;
      }
  #ifdef USE_CUDA
      if (s == 1) {
          for (row = 0; row < height; row++)
              cublasSaxpy(width, alpha, src + row * ld_src, 1, dst + row * ld_dst, 1);
          /* Wait for the device work to finish before the codelet returns. */
          cudaThreadSynchronize();
          return;
      }
  #endif
      STARPU_ABORT();
  }
  /* Host (CBLAS) in-place add: self_add_sub_common_codelet with backend 0 and alpha = +1. */
  void self_add_cpu_codelet(void *descr[], __attribute__((unused)) void *arg)
  {
      const float plus_one = 1.0f;
      self_add_sub_common_codelet(descr, 0, arg, plus_one);
  }
  /* Host (CBLAS) in-place sub: self_add_sub_common_codelet with backend 0 and alpha = -1. */
  void self_sub_cpu_codelet(void *descr[], __attribute__((unused)) void *arg)
  {
      const float minus_one = -1.0f;
      self_add_sub_common_codelet(descr, 0, arg, minus_one);
  }
  #ifdef USE_CUDA
  /* CUBLAS in-place add: self_add_sub_common_codelet with backend 1 and alpha = +1. */
  void self_add_cublas_codelet(void *descr[], __attribute__((unused)) void *arg)
  {
      const float plus_one = 1.0f;
      self_add_sub_common_codelet(descr, 1, arg, plus_one);
  }
  /* CUBLAS in-place sub: self_add_sub_common_codelet with backend 1 and alpha = -1. */
  void self_sub_cublas_codelet(void *descr[], __attribute__((unused)) void *arg)
  {
      const float minus_one = -1.0f;
      self_add_sub_common_codelet(descr, 1, arg, minus_one);
  }
  #endif