plu_solve.c 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2010 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu.h>
  17. #include "pxlu.h"
  18. void STARPU_PLU(display_data_content)(TYPE *data, unsigned blocksize)
  19. {
  20. fprintf(stderr, "DISPLAY BLOCK\n");
  21. unsigned i, j;
  22. for (j = 0; j < blocksize; j++)
  23. {
  24. for (i = 0; i < blocksize; i++)
  25. {
  26. fprintf(stderr, "%f ", data[j+i*blocksize]);
  27. }
  28. fprintf(stderr, "\n");
  29. }
  30. fprintf(stderr, "****\n");
  31. }
  32. static STARPU_PLU(compute_ax_block)(unsigned block_size, TYPE *block_data, TYPE *sub_x, TYPE *sub_y)
  33. {
  34. CPU_GEMV("N", block_size, block_size, 1.0, block_data, block_size, sub_x, 1, 1.0, sub_y, 1);
  35. }
  36. void STARPU_PLU(extract_upper)(unsigned block_size, TYPE *inblock, TYPE *outblock)
  37. {
  38. unsigned li, lj;
  39. for (lj = 0; lj < block_size; lj++)
  40. {
  41. /* Upper block diag is 1 */
  42. outblock[lj*(block_size + 1)] = (TYPE)1.0;
  43. for (li = lj + 1; li < block_size; li++)
  44. {
  45. outblock[lj + li*block_size] = inblock[lj + li*block_size];
  46. }
  47. }
  48. }
  49. static STARPU_PLU(compute_ax_block_upper)(unsigned size, unsigned nblocks,
  50. TYPE *block_data, TYPE *sub_x, TYPE *sub_y)
  51. {
  52. unsigned block_size = size/nblocks;
  53. fprintf(stderr, "KEEP UPPER\n");
  54. STARPU_PLU(display_data_content)(block_data, block_size);
  55. /* Take a copy of the upper part of the diagonal block */
  56. TYPE *upper_block_copy = calloc((block_size)*(block_size), sizeof(TYPE));
  57. STARPU_PLU(extract_upper)(block_size, block_data, upper_block_copy);
  58. STARPU_PLU(display_data_content)(upper_block_copy, block_size);
  59. STARPU_PLU(compute_ax_block)(size/nblocks, upper_block_copy, sub_x, sub_y);
  60. free(upper_block_copy);
  61. }
  62. void STARPU_PLU(extract_lower)(unsigned block_size, TYPE *inblock, TYPE *outblock)
  63. {
  64. unsigned li, lj;
  65. for (lj = 0; lj < block_size; lj++)
  66. {
  67. for (li = 0; li <= lj; li++)
  68. {
  69. outblock[lj + li*block_size] = inblock[lj + li*block_size];
  70. }
  71. }
  72. }
  73. TYPE *STARPU_PLU(reconstruct_matrix)(unsigned size, unsigned nblocks)
  74. {
  75. TYPE *bigmatrix = calloc(size*size, sizeof(TYPE));
  76. unsigned block_size = size/nblocks;
  77. unsigned bi, bj;
  78. for (bj = 0; bj < nblocks; bj++)
  79. for (bi = 0; bi < nblocks; bi++)
  80. {
  81. TYPE *block = STARPU_PLU(get_block)(bj, bi);
  82. //TYPE *block = STARPU_PLU(get_block)(bj, bi);
  83. unsigned j, i;
  84. for (j = 0; j < block_size; j++)
  85. for (i = 0; i < block_size; i++)
  86. {
  87. bigmatrix[(j + bj*block_size)+(i+bi*block_size)*size] =
  88. block[j+i*block_size];
  89. }
  90. }
  91. return bigmatrix;
  92. }
  93. static TYPE *reconstruct_lower(unsigned size, unsigned nblocks)
  94. {
  95. TYPE *lower = calloc(size*size, sizeof(TYPE));
  96. TYPE *bigmatrix = STARPU_PLU(reconstruct_matrix)(size, nblocks);
  97. STARPU_PLU(extract_lower)(size, bigmatrix, lower);
  98. return lower;
  99. }
  100. static TYPE *reconstruct_upper(unsigned size, unsigned nblocks)
  101. {
  102. TYPE *upper = calloc(size*size, sizeof(TYPE));
  103. TYPE *bigmatrix = STARPU_PLU(reconstruct_matrix)(size, nblocks);
  104. STARPU_PLU(extract_upper)(size, bigmatrix, upper);
  105. return upper;
  106. }
  107. void STARPU_PLU(compute_lu_matrix)(unsigned size, unsigned nblocks)
  108. {
  109. fprintf(stderr, "ALL\n\n");
  110. TYPE *all_r = STARPU_PLU(reconstruct_matrix)(size, nblocks);
  111. STARPU_PLU(display_data_content)(all_r, size);
  112. fprintf(stderr, "\nLOWER\n");
  113. TYPE *lower_r = reconstruct_lower(size, nblocks);
  114. STARPU_PLU(display_data_content)(lower_r, size);
  115. fprintf(stderr, "\nUPPER\n");
  116. TYPE *upper_r = reconstruct_upper(size, nblocks);
  117. STARPU_PLU(display_data_content)(upper_r, size);
  118. TYPE *lu_r = calloc(size*size, sizeof(TYPE));
  119. CPU_TRMM("R", "U", "N", "U", size, size, 1.0f, lower_r, size, upper_r, size);
  120. fprintf(stderr, "\nLU\n");
  121. STARPU_PLU(display_data_content)(lower_r, size);
  122. }
  123. static STARPU_PLU(compute_ax_block_lower)(unsigned size, unsigned nblocks,
  124. TYPE *block_data, TYPE *sub_x, TYPE *sub_y)
  125. {
  126. unsigned block_size = size/nblocks;
  127. fprintf(stderr, "KEEP LOWER\n");
  128. STARPU_PLU(display_data_content)(block_data, block_size);
  129. /* Take a copy of the upper part of the diagonal block */
  130. TYPE *lower_block_copy = calloc((block_size)*(block_size), sizeof(TYPE));
  131. STARPU_PLU(extract_lower)(block_size, block_data, lower_block_copy);
  132. STARPU_PLU(display_data_content)(lower_block_copy, block_size);
  133. STARPU_PLU(compute_ax_block)(size/nblocks, lower_block_copy, sub_x, sub_y);
  134. free(lower_block_copy);
  135. }
  136. void STARPU_PLU(compute_lux)(unsigned size, TYPE *x, TYPE *y, unsigned nblocks, int rank)
  137. {
  138. /* Create temporary buffers where all MPI processes are going to
  139. * compute Ui x = yi where Ai is the matrix containing the blocks of U
  140. * affected to process i, and 0 everywhere else. We then have y as the
  141. * sum of all yi. */
  142. TYPE *yi = calloc(size, sizeof(TYPE));
  143. unsigned block_size = size/nblocks;
  144. /* Compute UiX = Yi */
  145. unsigned long i,j;
  146. for (j = 0; j < nblocks; j++)
  147. {
  148. if (get_block_rank(j, j) == rank)
  149. {
  150. TYPE *block_data = STARPU_PLU(get_block)(j, j);
  151. TYPE *sub_x = &x[j*(block_size)];
  152. TYPE *sub_yi = &yi[j*(block_size)];
  153. STARPU_PLU(compute_ax_block_upper)(size, nblocks, block_data, sub_x, sub_yi);
  154. }
  155. for (i = j + 1; i < nblocks; i++)
  156. {
  157. if (get_block_rank(i, j) == rank)
  158. {
  159. /* That block belongs to the current MPI process */
  160. TYPE *block_data = STARPU_PLU(get_block)(j, i);
  161. TYPE *sub_x = &x[i*(block_size)];
  162. TYPE *sub_yi = &yi[j*(block_size)];
  163. STARPU_PLU(compute_ax_block)(size/nblocks, block_data, sub_x, sub_yi);
  164. }
  165. }
  166. }
  167. /* Grab Sum Yi in X */
  168. MPI_Reduce(yi, x, size, MPI_TYPE, MPI_SUM, 0, MPI_COMM_WORLD);
  169. memset(yi, 0, size*sizeof(TYPE));
  170. unsigned ind;
  171. if (rank == 0)
  172. {
  173. fprintf(stderr, "INTERMEDIATE\n");
  174. for (ind = 0; ind < STARPU_MIN(10, size); ind++)
  175. {
  176. fprintf(stderr, "x[%d] = %f\n", ind, (float)x[ind]);
  177. }
  178. fprintf(stderr, "****\n");
  179. }
  180. /* Everyone needs x */
  181. int bcst_ret;
  182. bcst_ret = MPI_Bcast(&x, size, MPI_TYPE, 0, MPI_COMM_WORLD);
  183. STARPU_ASSERT(bcst_ret == MPI_SUCCESS);
  184. /* Compute LiX = Yi (with X = UX) */
  185. for (j = 0; j < nblocks; j++)
  186. {
  187. if (j > 0)
  188. for (i = 0; i < j; i++)
  189. {
  190. if (get_block_rank(i, j) == rank)
  191. {
  192. /* That block belongs to the current MPI process */
  193. TYPE *block_data = STARPU_PLU(get_block)(j, i);
  194. TYPE *sub_x = &x[i*(block_size)];
  195. TYPE *sub_yi = &yi[j*(block_size)];
  196. STARPU_PLU(compute_ax_block)(size/nblocks, block_data, sub_x, sub_yi);
  197. }
  198. }
  199. if (get_block_rank(j, j) == rank)
  200. {
  201. TYPE *block_data = STARPU_PLU(get_block)(j, j);
  202. TYPE *sub_x = &x[j*(block_size)];
  203. TYPE *sub_yi = &yi[j*(block_size)];
  204. STARPU_PLU(compute_ax_block_lower)(size, nblocks, block_data, sub_x, sub_yi);
  205. }
  206. }
  207. /* Grab Sum Yi in Y */
  208. MPI_Reduce(yi, y, size, MPI_TYPE, MPI_SUM, 0, MPI_COMM_WORLD);
  209. free(yi);
  210. }
  211. /* x and y must be valid (at least) on 0 */
  212. void STARPU_PLU(compute_ax)(unsigned size, TYPE *x, TYPE *y, unsigned nblocks, int rank)
  213. {
  214. /* Send x to everyone */
  215. int bcst_ret;
  216. bcst_ret = MPI_Bcast(&x, size, MPI_TYPE, 0, MPI_COMM_WORLD);
  217. STARPU_ASSERT(bcst_ret == MPI_SUCCESS);
  218. if (rank == 0)
  219. {
  220. unsigned ind;
  221. for (ind = 0; ind < STARPU_MIN(10, size); ind++)
  222. fprintf(stderr, "x[%d] = %f\n", ind, (float)x[ind]);
  223. fprintf(stderr, "Compute AX = B\n");
  224. }
  225. /* Create temporary buffers where all MPI processes are going to
  226. * compute Ai x = yi where Ai is the matrix containing the blocks of A
  227. * affected to process i, and 0 everywhere else. We then have y as the
  228. * sum of all yi. */
  229. TYPE *yi = calloc(size, sizeof(TYPE));
  230. /* Compute Aix = yi */
  231. unsigned long i,j;
  232. for (j = 0; j < nblocks; j++)
  233. {
  234. for (i = 0; i < nblocks; i++)
  235. {
  236. if (get_block_rank(i, j) == rank)
  237. {
  238. /* That block belongs to the current MPI process */
  239. TYPE *block_data = STARPU_PLU(get_block)(j, i);
  240. TYPE *sub_x = &x[i*(size/nblocks)];
  241. TYPE *sub_yi = &yi[j*(size/nblocks)];
  242. STARPU_PLU(compute_ax_block)(size/nblocks, block_data, sub_x, sub_yi);
  243. }
  244. }
  245. }
  246. /* Compute the Sum of all yi = y */
  247. MPI_Reduce(yi, y, size, MPI_TYPE, MPI_SUM, 0, MPI_COMM_WORLD);
  248. free(yi);
  249. }