plu_solve.c 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2010 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu.h>
  17. #include "pxlu.h"
  18. static STARPU_PLU(compute_ax_block)(unsigned size, unsigned nblocks,
  19. TYPE *block_data, TYPE *sub_x, TYPE *sub_y)
  20. {
  21. CPU_GEMV("N", size/nblocks, size/nblocks, 1.0, block_data, size/nblocks, sub_x, 1, 1.0, sub_y, 1);
  22. }
  23. /* y is only valid on node 0 */
  24. void STARPU_PLU(compute_ax)(unsigned size, TYPE *x, TYPE *y, unsigned nblocks, int rank)
  25. {
  26. /* Create temporary buffers where all MPI processes are going to
  27. * compute Ai x = yi where Ai is the matrix containing the blocks of A
  28. * affected to process i, and 0 everywhere else. We then have y as the
  29. * sum of all yi. */
  30. TYPE *yi = calloc(size, sizeof(TYPE));
  31. /* Compute Aix = yi */
  32. unsigned long i,j;
  33. for (j = 0; j < nblocks; j++)
  34. {
  35. for (i = 0; i < nblocks; i++)
  36. {
  37. if (get_block_rank(i, j) == rank)
  38. {
  39. /* That block belongs to the current MPI process */
  40. TYPE *block_data = STARPU_PLU(get_block)(j, i);
  41. TYPE *sub_x = &x[i*(size/nblocks)];
  42. TYPE *sub_yi = &yi[j*(size/nblocks)];
  43. STARPU_PLU(compute_ax_block)(size, nblocks, block_data, sub_x, sub_yi);
  44. }
  45. }
  46. }
  47. /* Compute the Sum of all yi = y */
  48. MPI_Reduce(yi, y, size, MPI_TYPE, MPI_SUM, 0, MPI_COMM_WORLD);
  49. }