xgemm.c 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. /*
  2. * StarPU
  3. * Copyright (C) Université Bordeaux 1, CNRS 2008-2010 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include "dw_mult.h"
  /* Turn the argument into a string literal exactly as written (no expansion). */
  #define str(token) #token
  /* Macro-expand the argument first, then turn the result into a string literal. */
  #define xstr(token) str(token)
  /* String literal spelling whatever STARPU_GEMM(kernel) expands to. */
  #define STARPU_GEMM_STR(kernel) xstr(STARPU_GEMM(kernel))
  20. TYPE *A, *B, *C;
  21. starpu_data_handle A_handle, B_handle, C_handle;
  /*
   * This program computes C = A * B
   *
   *   A of size (z,y)
   *   B of size (x,z)
   *   C of size (x,y)
   *
   *                |---------------|
   *              z |       B       |
   *                |---------------|
   *         z              x
   *       |----|   |---------------|
   *       |    |   |               |
   *       |    |   |               |
   *       | A  | y |       C       |
   *       |    |   |               |
   *       |    |   |               |
   *       |----|   |---------------|
   */
  40. static void check_output(void)
  41. {
  42. /* check results */
  43. /* compute C = C - AB */
  44. CPU_GEMM("N", "N", ydim, xdim, zdim, (TYPE)-1.0, A, ydim, B, zdim, (TYPE)1.0f, C, ydim);
  45. /* make sure C = 0 */
  46. TYPE err;
  47. err = CPU_ASUM(xdim*ydim, C, 1);
  48. int max;
  49. max = CPU_IAMAX(xdim*ydim, C, 1);
  50. fprintf(stderr, "Avg error : %e\n", err/(xdim*ydim));
  51. fprintf(stderr, "Max error : %e\n", C[max]);
  52. }
  53. void callback_func(void *arg)
  54. {
  55. /* do some accounting */
  56. int id = starpu_worker_get_id();
  57. flop_per_worker[id] += BLAS3_FLOP(conf.m, conf.n, conf.k);
  58. ls_per_worker[id] += BLAS3_LS(conf.m, conf.n, conf.k);
  59. }
  60. static void init_problem_data(void)
  61. {
  62. unsigned i,j;
  63. #ifdef STARPU_USE_CUDA
  64. if (pin) {
  65. starpu_data_malloc_pinned_if_possible((void **)&A, zdim*ydim*sizeof(TYPE));
  66. starpu_data_malloc_pinned_if_possible((void **)&B, xdim*zdim*sizeof(TYPE));
  67. starpu_data_malloc_pinned_if_possible((void **)&C, xdim*ydim*sizeof(TYPE));
  68. } else
  69. #endif
  70. {
  71. #ifdef STARPU_HAVE_POSIX_MEMALIGN
  72. posix_memalign((void **)&A, 4096, zdim*ydim*sizeof(TYPE));
  73. posix_memalign((void **)&B, 4096, xdim*zdim*sizeof(TYPE));
  74. posix_memalign((void **)&C, 4096, xdim*ydim*sizeof(TYPE));
  75. #else
  76. A = malloc(zdim*ydim*sizeof(TYPE));
  77. B = malloc(xdim*zdim*sizeof(TYPE));
  78. C = malloc(xdim*ydim*sizeof(TYPE));
  79. #endif
  80. }
  81. /* fill the A and B matrices */
  82. if (norandom) {
  83. for (j=0; j < ydim; j++) {
  84. for (i=0; i < zdim; i++) {
  85. A[j+i*ydim] = (TYPE)(i);
  86. }
  87. }
  88. for (j=0; j < zdim; j++) {
  89. for (i=0; i < xdim; i++) {
  90. B[j+i*zdim] = (TYPE)(j);
  91. }
  92. }
  93. }
  94. else {
  95. for (j=0; j < ydim; j++) {
  96. for (i=0; i < zdim; i++) {
  97. A[j+i*ydim] = (TYPE)(starpu_drand48());
  98. }
  99. }
  100. for (j=0; j < zdim; j++) {
  101. for (i=0; i < xdim; i++) {
  102. B[j+i*zdim] = (TYPE)(starpu_drand48());
  103. }
  104. }
  105. }
  106. for (j=0; j < ydim; j++) {
  107. for (i=0; i < xdim; i++) {
  108. C[j+i*ydim] = (TYPE)(0);
  109. }
  110. }
  111. /* display memory consumption */
  112. fprintf(stderr, "Total memory : %ld MB\n",
  113. ( ydim*zdim*sizeof(TYPE)
  114. + zdim*xdim*sizeof(TYPE)
  115. + ydim*xdim*sizeof(TYPE) )/(1024*1024));
  116. }
  117. static void partition_mult_data(void)
  118. {
  119. starpu_matrix_data_register(&A_handle, 0, (uintptr_t)A,
  120. ydim, ydim, zdim, sizeof(TYPE));
  121. starpu_matrix_data_register(&B_handle, 0, (uintptr_t)B,
  122. zdim, zdim, xdim, sizeof(TYPE));
  123. starpu_matrix_data_register(&C_handle, 0, (uintptr_t)C,
  124. ydim, ydim, xdim, sizeof(TYPE));
  125. starpu_data_set_wt_mask(C_handle, 1<<0);
  126. conf.k = zdim;
  127. conf.m = ydim/nslicesy;
  128. conf.n = xdim/nslicesx;
  129. struct starpu_data_filter f;
  130. f.filter_func = starpu_vertical_block_filter_func;
  131. f.nchildren = nslicesx;
  132. f.get_nchildren = NULL;
  133. f.get_child_ops = NULL;
  134. struct starpu_data_filter f2;
  135. f2.filter_func = starpu_block_filter_func;
  136. f2.nchildren = nslicesy;
  137. f2.get_nchildren = NULL;
  138. f2.get_child_ops = NULL;
  139. starpu_data_partition(B_handle, &f);
  140. starpu_data_partition(A_handle, &f2);
  141. starpu_data_map_filters(C_handle, 2, &f, &f2);
  142. }
  143. static void unpartition_mult_data(void)
  144. {
  145. starpu_data_unpartition(C_handle, 0);
  146. starpu_data_unregister(C_handle);
  147. }
  148. static struct starpu_perfmodel_t gemm_model = {
  149. .type = STARPU_HISTORY_BASED,
  150. #ifdef STARPU_ATLAS
  151. .symbol = STARPU_GEMM_STR(gemm_atlas)
  152. #elif defined(STARPU_GOTO)
  153. .symbol = STARPU_GEMM_STR(gemm_goto)
  154. #else
  155. .symbol = STARPU_GEMM_STR(gemm)
  156. #endif
  157. };
  158. static starpu_codelet cl = {
  159. .where = STARPU_CPU|STARPU_CUDA,
  160. .cpu_func = STARPU_GEMM(cpu_mult),
  161. #ifdef STARPU_USE_CUDA
  162. .cuda_func = STARPU_GEMM(cublas_mult),
  163. #endif
  164. .model = &gemm_model,
  165. .nbuffers = 3
  166. };
  167. static void launch_codelets(void)
  168. {
  169. /* partition the work into slices */
  170. unsigned taskx, tasky;
  171. for (taskx = 0; taskx < nslicesx; taskx++)
  172. {
  173. for (tasky = 0; tasky < nslicesy; tasky++)
  174. {
  175. /* A B[task] = C[task] */
  176. struct starpu_task *task = starpu_task_create();
  177. task->cl = &cl;
  178. task->cl_arg = &conf;
  179. task->cl_arg_size = sizeof(struct block_conf);
  180. /* we have a callback to do some accounting */
  181. task->callback_func = callback_func;
  182. task->callback_arg = NULL;
  183. task->buffers[0].handle = starpu_data_get_sub_data(A_handle, 1, tasky);
  184. task->buffers[0].mode = STARPU_R;
  185. task->buffers[1].handle = starpu_data_get_sub_data(B_handle, 1, taskx);
  186. task->buffers[1].mode = STARPU_R;
  187. task->buffers[2].handle = starpu_data_get_sub_data(C_handle, 2, taskx, tasky);
  188. task->buffers[2].mode = STARPU_RW;
  189. starpu_task_submit(task);
  190. }
  191. }
  192. }
  193. int main(__attribute__ ((unused)) int argc,
  194. __attribute__ ((unused)) char **argv)
  195. {
  196. parse_args(argc, argv);
  197. /* start the runtime */
  198. starpu_init(NULL);
  199. starpu_helper_cublas_init();
  200. init_problem_data();
  201. gettimeofday(&start, NULL);
  202. partition_mult_data();
  203. launch_codelets();
  204. starpu_task_wait_for_all();
  205. gettimeofday(&end, NULL);
  206. double timing = (double)((end.tv_sec - start.tv_sec)*1000000 +
  207. (end.tv_usec - start.tv_usec));
  208. display_stats(timing);
  209. unpartition_mult_data();
  210. if (check)
  211. check_output();
  212. starpu_helper_cublas_shutdown();
  213. starpu_shutdown();
  214. return 0;
  215. }