xgemm.c 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include "dw_mult.h"
  17. #define str(s) #s
  18. #define xstr(s) str(s)
  19. #define STARPU_GEMM_STR(name) xstr(STARPU_GEMM(name))
  20. TYPE *A, *B, *C;
  21. starpu_data_handle A_handle, B_handle, C_handle;
  /*
   * This program computes C = A * B
   *
   *   A of size (z,y)
   *   B of size (x,z)
   *   C of size (x,y)
   *
   *              |---------------|
   *            z |       B       |
   *              |---------------|
   *       z              x
   *     |----|   |---------------|
   *     |    |   |               |
   *     |    |   |               |
   *     | A  | y |       C       |
   *     |    |   |               |
   *     |    |   |               |
   *     |----|   |---------------|
   */
  40. static void check_output(void)
  41. {
  42. /* check results */
  43. /* compute C = C - AB */
  44. CPU_GEMM("N", "N", ydim, xdim, zdim, (TYPE)-1.0, A, ydim, B, zdim, (TYPE)1.0f, C, ydim);
  45. /* make sure C = 0 */
  46. TYPE err;
  47. err = CPU_ASUM(xdim*ydim, C, 1);
  48. int max;
  49. max = CPU_IAMAX(xdim*ydim, C, 1);
  50. fprintf(stderr, "Avg error : %e\n", err/(xdim*ydim));
  51. fprintf(stderr, "Max error : %e\n", C[max]);
  52. }
  53. void callback_func(void *arg)
  54. {
  55. /* do some accounting */
  56. int id = starpu_get_worker_id();
  57. flop_per_worker[id] += BLAS3_FLOP(conf.m, conf.n, conf.k);
  58. ls_per_worker[id] += BLAS3_LS(conf.m, conf.n, conf.k);
  59. }
  60. static void init_problem_data(void)
  61. {
  62. unsigned i,j;
  63. #ifdef STARPU_USE_CUDA
  64. if (pin) {
  65. starpu_malloc_pinned_if_possible((void **)&A, zdim*ydim*sizeof(TYPE));
  66. starpu_malloc_pinned_if_possible((void **)&B, xdim*zdim*sizeof(TYPE));
  67. starpu_malloc_pinned_if_possible((void **)&C, xdim*ydim*sizeof(TYPE));
  68. } else
  69. #endif
  70. {
  71. #ifdef STARPU_HAVE_POSIX_MEMALIGN
  72. posix_memalign((void **)&A, 4096, zdim*ydim*sizeof(TYPE));
  73. posix_memalign((void **)&B, 4096, xdim*zdim*sizeof(TYPE));
  74. posix_memalign((void **)&C, 4096, xdim*ydim*sizeof(TYPE));
  75. #else
  76. A = malloc(zdim*ydim*sizeof(TYPE));
  77. B = malloc(xdim*zdim*sizeof(TYPE));
  78. C = malloc(xdim*ydim*sizeof(TYPE));
  79. #endif
  80. }
  81. /* fill the A and B matrices */
  82. if (norandom) {
  83. for (j=0; j < ydim; j++) {
  84. for (i=0; i < zdim; i++) {
  85. A[j+i*ydim] = (TYPE)(i);
  86. }
  87. }
  88. for (j=0; j < zdim; j++) {
  89. for (i=0; i < xdim; i++) {
  90. B[j+i*zdim] = (TYPE)(j);
  91. }
  92. }
  93. }
  94. else {
  95. for (j=0; j < ydim; j++) {
  96. for (i=0; i < zdim; i++) {
  97. A[j+i*ydim] = (TYPE)(starpu_drand48());
  98. }
  99. }
  100. for (j=0; j < zdim; j++) {
  101. for (i=0; i < xdim; i++) {
  102. B[j+i*zdim] = (TYPE)(starpu_drand48());
  103. }
  104. }
  105. }
  106. for (j=0; j < ydim; j++) {
  107. for (i=0; i < xdim; i++) {
  108. C[j+i*ydim] = (TYPE)(0);
  109. }
  110. }
  111. /* display memory consumption */
  112. fprintf(stderr, "Total memory : %ld MB\n",
  113. ( ydim*zdim*sizeof(TYPE)
  114. + zdim*xdim*sizeof(TYPE)
  115. + ydim*xdim*sizeof(TYPE) )/(1024*1024));
  116. }
  117. static void partition_mult_data(void)
  118. {
  119. starpu_register_blas_data(&A_handle, 0, (uintptr_t)A,
  120. ydim, ydim, zdim, sizeof(TYPE));
  121. starpu_register_blas_data(&B_handle, 0, (uintptr_t)B,
  122. zdim, zdim, xdim, sizeof(TYPE));
  123. starpu_register_blas_data(&C_handle, 0, (uintptr_t)C,
  124. ydim, ydim, xdim, sizeof(TYPE));
  125. starpu_data_set_wb_mask(C_handle, 1<<0);
  126. conf.k = zdim;
  127. conf.m = ydim/nslicesy;
  128. conf.n = xdim/nslicesx;
  129. starpu_filter f;
  130. f.filter_func = starpu_vertical_block_filter_func;
  131. f.filter_arg = nslicesx;
  132. starpu_filter f2;
  133. f2.filter_func = starpu_block_filter_func;
  134. f2.filter_arg = nslicesy;
  135. starpu_partition_data(B_handle, &f);
  136. starpu_partition_data(A_handle, &f2);
  137. starpu_map_filters(C_handle, 2, &f, &f2);
  138. }
  139. static void unpartition_mult_data(void)
  140. {
  141. starpu_unpartition_data(C_handle, 0);
  142. starpu_delete_data(C_handle);
  143. }
  144. static struct starpu_perfmodel_t gemm_model = {
  145. .type = STARPU_HISTORY_BASED,
  146. #ifdef STARPU_ATLAS
  147. .symbol = STARPU_GEMM_STR(gemm_atlas)
  148. #elif defined(STARPU_GOTO)
  149. .symbol = STARPU_GEMM_STR(gemm_goto)
  150. #else
  151. .symbol = STARPU_GEMM_STR(gemm)
  152. #endif
  153. };
  154. static starpu_codelet cl = {
  155. .where = STARPU_CPU|STARPU_CUDA,
  156. .cpu_func = STARPU_GEMM(cpu_mult),
  157. #ifdef STARPU_USE_CUDA
  158. .cuda_func = STARPU_GEMM(cublas_mult),
  159. #endif
  160. .model = &gemm_model,
  161. .nbuffers = 3
  162. };
  163. static void launch_codelets(void)
  164. {
  165. /* partition the work into slices */
  166. unsigned taskx, tasky;
  167. for (taskx = 0; taskx < nslicesx; taskx++)
  168. {
  169. for (tasky = 0; tasky < nslicesy; tasky++)
  170. {
  171. /* A B[task] = C[task] */
  172. struct starpu_task *task = starpu_task_create();
  173. task->cl = &cl;
  174. task->cl_arg = &conf;
  175. task->cl_arg_size = sizeof(struct block_conf);
  176. /* we have a callback to do some accounting */
  177. task->callback_func = callback_func;
  178. task->callback_arg = NULL;
  179. task->buffers[0].handle = starpu_get_sub_data(A_handle, 1, tasky);
  180. task->buffers[0].mode = STARPU_R;
  181. task->buffers[1].handle = starpu_get_sub_data(B_handle, 1, taskx);
  182. task->buffers[1].mode = STARPU_R;
  183. task->buffers[2].handle = starpu_get_sub_data(C_handle, 2, taskx, tasky);
  184. task->buffers[2].mode = STARPU_RW;
  185. starpu_submit_task(task);
  186. }
  187. }
  188. }
  189. int main(__attribute__ ((unused)) int argc,
  190. __attribute__ ((unused)) char **argv)
  191. {
  192. parse_args(argc, argv);
  193. /* start the runtime */
  194. starpu_init(NULL);
  195. starpu_helper_init_cublas();
  196. init_problem_data();
  197. gettimeofday(&start, NULL);
  198. partition_mult_data();
  199. launch_codelets();
  200. starpu_wait_all_tasks();
  201. gettimeofday(&end, NULL);
  202. double timing = (double)((end.tv_sec - start.tv_sec)*1000000 +
  203. (end.tv_usec - start.tv_usec));
  204. display_stats(timing);
  205. unpartition_mult_data();
  206. if (check)
  207. check_output();
  208. starpu_helper_shutdown_cublas();
  209. starpu_shutdown();
  210. return 0;
  211. }