cg.c 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. /*
  2. * StarPU
  3. * Copyright (C) Université Bordeaux 1, CNRS 2008-2010 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  #include <assert.h>
  #include <errno.h>
  #include <limits.h>
  #include <math.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>
  #include <starpu.h>
  #include <common/blas.h>
  #ifdef STARPU_USE_CUDA
  #include <cuda.h>
  #include <cublas.h>
  #endif
  24. /*
  25. * Conjugate Gradient
  26. *
  27. * Input:
  28. * - matrix A
  29. * - vector b
  30. * - vector x (starting value)
  31. * - int i_max, error tolerance eps < 1.
  32. * Output:
  33. * - vector x
  34. *
  35. * Pseudo code:
  36. *
  37. * i <- 0
  38. * r <- b - Ax
  39. * d <- r
  40. * delta_new <- dot(r,r)
  41. * delta_0 <- delta_new
  42. *
  43. * while (i < i_max && delta_new > eps^2 delta_0)
  44. * {
  45. * q <- Ad
  46. * alpha <- delta_new/dot(d, q)
  47. * x <- x + alpha d
  48. *
  49. * If (i is divisible by 50)
  50. * r <- b - Ax
  51. * else
  52. * r <- r - alpha q
  53. *
  54. * delta_old <- delta_new
  55. * delta_new <- dot(r,r)
  56. * beta <- delta_new/delta_old
  57. * d <- r + beta d
  58. * i <- i + 1
  59. * }
  60. *
  61. */
  62. #include "cg.h"
  63. static int long long n = 1024;
  64. static int nblocks = 8;
  65. static starpu_data_handle A_handle, b_handle, x_handle;
  66. static TYPE *A, *b, *x;
  67. static int i_max = 4000;
  68. static double eps = (10e-14);
  69. static starpu_data_handle r_handle, d_handle, q_handle;
  70. static TYPE *r, *d, *q;
  71. static starpu_data_handle dtq_handle, rtr_handle;
  72. static TYPE dtq, rtr;
  73. /*
  74. * Generate Input data
  75. */
  76. static void generate_random_problem(void)
  77. {
  78. srand48(0xdeadbeef);
  79. int i, j;
  80. starpu_data_malloc_pinned_if_possible((void **)&A, n*n*sizeof(TYPE));
  81. starpu_data_malloc_pinned_if_possible((void **)&b, n*sizeof(TYPE));
  82. starpu_data_malloc_pinned_if_possible((void **)&x, n*sizeof(TYPE));
  83. assert(A && b && x);
  84. /* Create a random matrix (A) and two random vectors (x and b) */
  85. for (j = 0; j < n; j++)
  86. {
  87. b[j] = (TYPE)1.0;
  88. x[j] = (TYPE)0.0;
  89. #if 0
  90. for (i = 0; i < n; i++)
  91. {
  92. A[n*j + i] = (i <= j)?1.0:0.0;
  93. }
  94. #else
  95. /* We take Hilbert matrix that is not well conditionned but definite positive: H(i,j) = 1/(1+i+j) */
  96. for (i = 0; i < n; i++)
  97. {
  98. A[n*j + i] = (TYPE)(1.0/(1.0+i+j));
  99. }
  100. #endif
  101. }
  102. /* Internal vectors */
  103. starpu_data_malloc_pinned_if_possible((void **)&r, n*sizeof(TYPE));
  104. starpu_data_malloc_pinned_if_possible((void **)&d, n*sizeof(TYPE));
  105. starpu_data_malloc_pinned_if_possible((void **)&q, n*sizeof(TYPE));
  106. assert(r && d && q);
  107. memset(r, 0, n*sizeof(TYPE));
  108. memset(d, 0, n*sizeof(TYPE));
  109. memset(q, 0, n*sizeof(TYPE));
  110. }
  111. static void register_data(void)
  112. {
  113. starpu_matrix_data_register(&A_handle, 0, (uintptr_t)A, n, n, n, sizeof(TYPE));
  114. starpu_vector_data_register(&b_handle, 0, (uintptr_t)b, n, sizeof(TYPE));
  115. starpu_vector_data_register(&x_handle, 0, (uintptr_t)x, n, sizeof(TYPE));
  116. starpu_vector_data_register(&r_handle, 0, (uintptr_t)r, n, sizeof(TYPE));
  117. starpu_vector_data_register(&d_handle, 0, (uintptr_t)d, n, sizeof(TYPE));
  118. starpu_vector_data_register(&q_handle, 0, (uintptr_t)q, n, sizeof(TYPE));
  119. starpu_variable_data_register(&dtq_handle, 0, (uintptr_t)&dtq, sizeof(TYPE));
  120. starpu_variable_data_register(&rtr_handle, 0, (uintptr_t)&rtr, sizeof(TYPE));
  121. }
  122. /*
  123. * Data partitioning filters
  124. */
  125. struct starpu_data_filter vector_filter;
  126. struct starpu_data_filter matrix_filter_1;
  127. struct starpu_data_filter matrix_filter_2;
  128. static void partition_data(void)
  129. {
  130. assert(n % nblocks == 0);
  131. /*
  132. * Partition the A matrix
  133. */
  134. /* Partition into contiguous parts */
  135. matrix_filter_1.filter_func = starpu_block_filter_func;
  136. matrix_filter_1.nchildren = nblocks;
  137. /* Partition into non-contiguous parts */
  138. matrix_filter_2.filter_func = starpu_vertical_block_filter_func;
  139. matrix_filter_2.nchildren = nblocks;
  140. /* A is in FORTRAN ordering, starpu_data_get_sub_data(A_handle, 2, i,
  141. * j) designates the block in column i and row j. */
  142. starpu_data_map_filters(A_handle, 2, &matrix_filter_1, &matrix_filter_2);
  143. /*
  144. * Partition the vectors
  145. */
  146. vector_filter.filter_func = starpu_block_filter_func_vector;
  147. vector_filter.nchildren = nblocks;
  148. starpu_data_partition(b_handle, &vector_filter);
  149. starpu_data_partition(x_handle, &vector_filter);
  150. starpu_data_partition(r_handle, &vector_filter);
  151. starpu_data_partition(d_handle, &vector_filter);
  152. starpu_data_partition(q_handle, &vector_filter);
  153. }
  154. /*
  155. * Debug
  156. */
  157. #if 0
  158. static void display_vector(starpu_data_handle handle, TYPE *ptr)
  159. {
  160. unsigned block_size = n / nblocks;
  161. unsigned b, ind;
  162. for (b = 0; b < nblocks; b++)
  163. {
  164. starpu_data_acquire(starpu_data_get_sub_data(handle, 1, b), STARPU_R);
  165. for (ind = 0; ind < block_size; ind++)
  166. {
  167. fprintf(stderr, "%2.2e ", ptr[b*block_size + ind]);
  168. }
  169. fprintf(stderr, "| ");
  170. starpu_data_release(starpu_data_get_sub_data(handle, 1, b));
  171. }
  172. fprintf(stderr, "\n");
  173. }
  174. static void display_matrix(void)
  175. {
  176. unsigned i, j;
  177. for (i = 0; i < n; i++)
  178. {
  179. for (j = 0; j < n; j++)
  180. {
  181. fprintf(stderr, "%2.2e ", A[j*n + i]);
  182. }
  183. fprintf(stderr, "\n");
  184. }
  185. }
  186. #endif
  187. /*
  188. * Main loop
  189. */
  190. static void cg(void)
  191. {
  192. double delta_new, delta_old, delta_0;
  193. double alpha, beta;
  194. int i = 0;
  195. /* r <- b */
  196. copy_handle(r_handle, b_handle, nblocks);
  197. /* r <- r - A x */
  198. gemv_kernel(r_handle, A_handle, x_handle, 1.0, -1.0, nblocks);
  199. /* d <- r */
  200. copy_handle(d_handle, r_handle, nblocks);
  201. /* delta_new = dot(r,r) */
  202. dot_kernel(r_handle, r_handle, rtr_handle, nblocks);
  203. starpu_data_acquire(rtr_handle, STARPU_R);
  204. delta_new = rtr;
  205. delta_0 = delta_new;
  206. starpu_data_release(rtr_handle);
  207. fprintf(stderr, "*************** INITIAL ************ \n");
  208. fprintf(stderr, "Delta 0: %e\n", delta_new);
  209. while ((i < i_max) && ((double)delta_new > (double)(eps*eps*delta_0)))
  210. {
  211. /* q <- A d */
  212. gemv_kernel(q_handle, A_handle, d_handle, 0.0, 1.0, nblocks);
  213. /* dtq <- dot(d,q) */
  214. dot_kernel(d_handle, q_handle, dtq_handle, nblocks);
  215. /* alpha = delta_new / dtq */
  216. starpu_data_acquire(dtq_handle, STARPU_R);
  217. alpha = delta_new/dtq;
  218. starpu_data_release(dtq_handle);
  219. /* x <- x + alpha d */
  220. axpy_kernel(x_handle, d_handle, alpha, nblocks);
  221. if ((i % 50) == 0)
  222. {
  223. /* r <- b */
  224. copy_handle(r_handle, b_handle, nblocks);
  225. /* r <- r - A x */
  226. gemv_kernel(r_handle, A_handle, x_handle, 1.0, -1.0, nblocks);
  227. }
  228. else {
  229. /* r <- r - alpha q */
  230. axpy_kernel(r_handle, q_handle, -alpha, nblocks);
  231. }
  232. /* delta_new = dot(r,r) */
  233. dot_kernel(r_handle, r_handle, rtr_handle, nblocks);
  234. starpu_data_acquire(rtr_handle, STARPU_R);
  235. delta_old = delta_new;
  236. delta_new = rtr;
  237. beta = delta_new / delta_old;
  238. starpu_data_release(rtr_handle);
  239. /* d <- beta d + r */
  240. scal_axpy_kernel(d_handle, beta, r_handle, 1.0, nblocks);
  241. /* We here take the error as ||r||_2 / (n||b||_2) */
  242. double error = sqrt(delta_new/delta_0)/(1.0*n);
  243. fprintf(stderr, "*****************************************\n");
  244. fprintf(stderr, "iter %d DELTA %e - %e\n", i, delta_new, error);
  245. i++;
  246. }
  247. }
  /* Placeholder for result verification: performs no check and always
   * returns 0, which main() uses as the process exit status. */
  static int check(void)
  {
      const int status = 0;

      return status;
  }
  252. static void parse_args(int argc, char **argv)
  253. {
  254. int i;
  255. for (i = 1; i < argc; i++) {
  256. if (strcmp(argv[i], "-n") == 0) {
  257. n = (int long long)atoi(argv[++i]);
  258. continue;
  259. }
  260. if (strcmp(argv[i], "-maxiter") == 0) {
  261. i_max = atoi(argv[++i]);
  262. continue;
  263. }
  264. if (strcmp(argv[i], "-nblocks") == 0) {
  265. nblocks = atoi(argv[++i]);
  266. continue;
  267. }
  268. }
  269. }
  /*
   * Entry point: parse the options, start StarPU and the CUBLAS helper,
   * build, register and partition the problem, run the solver, then shut
   * everything down. Returns the value of check(), or 1 when StarPU could
   * not be initialized.
   *
   * NOTE(review): the registered data and the buffers allocated in
   * generate_random_problem() are not released before starpu_shutdown().
   */
  int main(int argc, char **argv)
  {
      int ret;

      parse_args(argc, argv);

      /* Do not go on submitting work if the runtime failed to start. */
      ret = starpu_init(NULL);
      if (ret != 0)
      {
          fprintf(stderr, "cg: starpu_init failed (error %d)\n", ret);
          return 1;
      }

      starpu_helper_cublas_init();

      generate_random_problem();
      register_data();
      partition_data();

      cg();

      ret = check();

      starpu_helper_cublas_shutdown();
      starpu_shutdown();

      return ret;
  }