dw_cholesky_kernels.c 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu_config.h>
  17. #include "dw_cholesky.h"
  18. #include "../common/blas.h"
  19. #ifdef STARPU_USE_CUDA
  20. #include <cuda.h>
  21. #include <cuda_runtime.h>
  22. #include <cublas.h>
  23. #endif
  24. /*
  25. * U22
  26. */
  27. static inline void chol_common_cpu_codelet_update_u22(void *descr[], int s, __attribute__((unused)) void *_args)
  28. {
  29. //printf("22\n");
  30. float *left = (float *)STARPU_GET_MATRIX_PTR(descr[0]);
  31. float *right = (float *)STARPU_GET_MATRIX_PTR(descr[1]);
  32. float *center = (float *)STARPU_GET_MATRIX_PTR(descr[2]);
  33. unsigned dx = STARPU_GET_MATRIX_NY(descr[2]);
  34. unsigned dy = STARPU_GET_MATRIX_NX(descr[2]);
  35. unsigned dz = STARPU_GET_MATRIX_NY(descr[0]);
  36. unsigned ld21 = STARPU_GET_MATRIX_LD(descr[0]);
  37. unsigned ld12 = STARPU_GET_MATRIX_LD(descr[1]);
  38. unsigned ld22 = STARPU_GET_MATRIX_LD(descr[2]);
  39. #ifdef STARPU_USE_CUDA
  40. cublasStatus st;
  41. #endif
  42. switch (s) {
  43. case 0:
  44. SGEMM("N", "T", dy, dx, dz, -1.0f, left, ld21,
  45. right, ld12, 1.0f, center, ld22);
  46. break;
  47. #ifdef STARPU_USE_CUDA
  48. case 1:
  49. cublasSgemm('n', 't', dy, dx, dz,
  50. -1.0f, left, ld21, right, ld12,
  51. 1.0f, center, ld22);
  52. st = cublasGetError();
  53. STARPU_ASSERT(!st);
  54. cudaThreadSynchronize();
  55. break;
  56. #endif
  57. default:
  58. STARPU_ABORT();
  59. break;
  60. }
  61. }
  62. void chol_cpu_codelet_update_u22(void *descr[], void *_args)
  63. {
  64. chol_common_cpu_codelet_update_u22(descr, 0, _args);
  65. }
  66. #ifdef STARPU_USE_CUDA
  67. void chol_cublas_codelet_update_u22(void *descr[], void *_args)
  68. {
  69. chol_common_cpu_codelet_update_u22(descr, 1, _args);
  70. }
  71. #endif// STARPU_USE_CUDA
  72. /*
  73. * U21
  74. */
  75. static inline void chol_common_codelet_update_u21(void *descr[], int s, __attribute__((unused)) void *_args)
  76. {
  77. // printf("21\n");
  78. float *sub11;
  79. float *sub21;
  80. sub11 = (float *)STARPU_GET_MATRIX_PTR(descr[0]);
  81. sub21 = (float *)STARPU_GET_MATRIX_PTR(descr[1]);
  82. unsigned ld11 = STARPU_GET_MATRIX_LD(descr[0]);
  83. unsigned ld21 = STARPU_GET_MATRIX_LD(descr[1]);
  84. unsigned nx21 = STARPU_GET_MATRIX_NY(descr[1]);
  85. unsigned ny21 = STARPU_GET_MATRIX_NX(descr[1]);
  86. switch (s) {
  87. case 0:
  88. STRSM("R", "L", "T", "N", nx21, ny21, 1.0f, sub11, ld11, sub21, ld21);
  89. break;
  90. #ifdef STARPU_USE_CUDA
  91. case 1:
  92. cublasStrsm('R', 'L', 'T', 'N', nx21, ny21, 1.0f, sub11, ld11, sub21, ld21);
  93. cudaThreadSynchronize();
  94. break;
  95. #endif
  96. default:
  97. STARPU_ABORT();
  98. break;
  99. }
  100. }
  101. void chol_cpu_codelet_update_u21(void *descr[], void *_args)
  102. {
  103. chol_common_codelet_update_u21(descr, 0, _args);
  104. }
  105. #ifdef STARPU_USE_CUDA
  106. void chol_cublas_codelet_update_u21(void *descr[], void *_args)
  107. {
  108. chol_common_codelet_update_u21(descr, 1, _args);
  109. }
  110. #endif
  111. /*
  112. * U11
  113. */
  114. static inline void chol_common_codelet_update_u11(void *descr[], int s, __attribute__((unused)) void *_args)
  115. {
  116. // printf("11\n");
  117. float *sub11;
  118. sub11 = (float *)STARPU_GET_MATRIX_PTR(descr[0]);
  119. unsigned nx = STARPU_GET_MATRIX_NY(descr[0]);
  120. unsigned ld = STARPU_GET_MATRIX_LD(descr[0]);
  121. unsigned z;
  122. switch (s) {
  123. case 0:
  124. /*
  125. * - alpha 11 <- lambda 11 = sqrt(alpha11)
  126. * - alpha 21 <- l 21 = alpha 21 / lambda 11
  127. * - A22 <- A22 - l21 trans(l21)
  128. */
  129. for (z = 0; z < nx; z++)
  130. {
  131. float lambda11;
  132. lambda11 = sqrt(sub11[z+z*ld]);
  133. sub11[z+z*ld] = lambda11;
  134. STARPU_ASSERT(lambda11 != 0.0f);
  135. SSCAL(nx - z - 1, 1.0f/lambda11, &sub11[(z+1)+z*ld], 1);
  136. SSYR("L", nx - z - 1, -1.0f,
  137. &sub11[(z+1)+z*ld], 1,
  138. &sub11[(z+1)+(z+1)*ld], ld);
  139. }
  140. break;
  141. #ifdef STARPU_USE_CUDA
  142. case 1:
  143. for (z = 0; z < nx; z++)
  144. {
  145. float lambda11;
  146. cudaMemcpy(&lambda11, &sub11[z+z*ld], sizeof(float), cudaMemcpyDeviceToHost);
  147. cudaStreamSynchronize(0);
  148. STARPU_ASSERT(lambda11 != 0.0f);
  149. lambda11 = sqrt(lambda11);
  150. cublasSetVector(1, sizeof(float), &lambda11, sizeof(float), &sub11[z+z*ld], sizeof(float));
  151. cublasSscal(nx - z - 1, 1.0f/lambda11, &sub11[(z+1)+z*ld], 1);
  152. cublasSsyr('U', nx - z - 1, -1.0f,
  153. &sub11[(z+1)+z*ld], 1,
  154. &sub11[(z+1)+(z+1)*ld], ld);
  155. }
  156. cudaThreadSynchronize();
  157. break;
  158. #endif
  159. default:
  160. STARPU_ABORT();
  161. break;
  162. }
  163. }
  164. void chol_cpu_codelet_update_u11(void *descr[], void *_args)
  165. {
  166. chol_common_codelet_update_u11(descr, 0, _args);
  167. }
  168. #ifdef STARPU_USE_CUDA
  169. void chol_cublas_codelet_update_u11(void *descr[], void *_args)
  170. {
  171. chol_common_codelet_update_u11(descr, 1, _args);
  172. }
  173. #endif// STARPU_USE_CUDA