cg.h 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. /*
  2. * StarPU
  3. * Copyright (C) Université Bordeaux 1, CNRS 2008-2010 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #ifndef __STARPU_EXAMPLE_CG_H__
  17. #define __STARPU_EXAMPLE_CG_H__
  18. #include <starpu.h>
  19. #include <math.h>
  20. #include <common/blas.h>
  21. #include <cuda.h>
  22. #include <cublas.h>
  23. #include <starpu.h>
  24. #define DOUBLE
  25. #ifdef DOUBLE
  26. #define TYPE double
  27. #define GEMV DGEMV
  28. #define DOT DDOT
  29. #define GEMV DGEMV
  30. #define AXPY DAXPY
  31. #define SCAL DSCAL
  32. #define cublasdot cublasDdot
  33. #define cublasscal cublasDscal
  34. #define cublasaxpy cublasDaxpy
  35. #define cublasgemv cublasDgemv
  36. #else
  37. #define TYPE float
  38. #define GEMV SGEMV
  39. #define DOT SDOT
  40. #define GEMV SGEMV
  41. #define AXPY SAXPY
  42. #define SCAL SSCAL
  43. #define cublasdot cublasSdot
  44. #define cublasscal cublasSscal
  45. #define cublasaxpy cublasSaxpy
  46. #define cublasgemv cublasSgemv
  47. #endif
  48. void dot_kernel(starpu_data_handle v1,
  49. starpu_data_handle v2,
  50. starpu_data_handle s,
  51. unsigned nblocks);
  52. void gemv_kernel(starpu_data_handle v1,
  53. starpu_data_handle matrix,
  54. starpu_data_handle v2,
  55. TYPE p1, TYPE p2,
  56. unsigned nblocks);
  57. void axpy_kernel(starpu_data_handle v1,
  58. starpu_data_handle v2, TYPE p1,
  59. unsigned nblocks);
  60. void scal_axpy_kernel(starpu_data_handle v1, TYPE p1,
  61. starpu_data_handle v2, TYPE p2,
  62. unsigned nblocks);
  63. void copy_handle(starpu_data_handle dst,
  64. starpu_data_handle src,
  65. unsigned nblocks);
  66. #endif // __STARPU_EXAMPLE_CG_H__