dw_block_spmv_kernels.c 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include "dw_block_spmv.h"
  17. /*
  18. * U22
  19. */
  20. static inline void common_block_spmv(starpu_data_interface_t *buffers, int s, __attribute__((unused)) void *_args)
  21. {
  22. //printf("22\n");
  23. float *block = (float *)buffers[0].blas.ptr;
  24. float *in = (float *)buffers[1].vector.ptr;
  25. float *out = (float *)buffers[2].vector.ptr;
  26. unsigned dx = buffers[0].blas.nx;
  27. unsigned dy = buffers[0].blas.ny;
  28. unsigned ld = buffers[0].blas.ld;
  29. switch (s) {
  30. case 0:
  31. cblas_sgemv(CblasRowMajor, CblasNoTrans, dx, dy, 1.0f, block, ld, in, 1, 1.0f, out, 1);
  32. break;
  33. #ifdef USE_CUDA
  34. case 1:
  35. cublasSgemv ('t', dx, dy, 1.0f, block, ld, in, 1, 1.0f, out, 1);
  36. break;
  37. #endif
  38. default:
  39. STARPU_ASSERT(0);
  40. break;
  41. }
  42. }
  43. void core_block_spmv(starpu_data_interface_t *descr, void *_args)
  44. {
  45. // printf("CORE CODELET \n");
  46. common_block_spmv(descr, 0, _args);
  47. }
  #ifdef USE_CUDA
  /*
   * Block SpMV entry point for the CUBLAS path: forwards the buffer descriptors
   * and the argument pointer to common_block_spmv() with backend selector 1.
   * Only built when USE_CUDA is defined.
   */
  void cublas_block_spmv(starpu_data_interface_t *descr, void *_args)
  {
      const int cublas_backend = 1;

      common_block_spmv(descr, cublas_backend, _args);
  }
  #endif /* USE_CUDA */