dgemm.c 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2017,2018 Inria
  4. *
  5. * StarPU is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * StarPU is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. /* This example runs a MORSE dgemm on top of StarPURM, first on the default units and then spawned on subsets of the CPU and CUDA units, optionally checking each result against CBLAS. */
#define _GNU_SOURCE
#include <assert.h>
#include <errno.h>
#include <float.h>
#include <limits.h>
#include <pthread.h>
#include <sched.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <mkl.h>
#include <morse.h>
#include <starpurm.h>
#include <hwloc.h>
  26. #define CHECK
  27. static int rm_cpu_type_id = -1;
  28. static int rm_cuda_type_id = -1;
  29. static int rm_nb_cpu_units = 0;
  30. static int rm_nb_cuda_units = 0;
  31. static const int nb_random_tests = 10;
  32. static unsigned spawn_pending = 0;
  33. static pthread_mutex_t spawn_pending_mutex = PTHREAD_MUTEX_INITIALIZER;
  34. static pthread_cond_t spawn_pending_cond;
  35. static void _inc_spawn_pending(void)
  36. {
  37. pthread_mutex_lock(&spawn_pending_mutex);
  38. assert(spawn_pending < UINT_MAX);
  39. spawn_pending++;
  40. pthread_mutex_unlock(&spawn_pending_mutex);
  41. }
  42. static void _dec_spawn_pending(void)
  43. {
  44. pthread_mutex_lock(&spawn_pending_mutex);
  45. assert(spawn_pending > 0);
  46. spawn_pending--;
  47. if (spawn_pending == 0)
  48. pthread_cond_broadcast(&spawn_pending_cond);
  49. pthread_mutex_unlock(&spawn_pending_mutex);
  50. }
  51. static void _wait_pending_spawns(void)
  52. {
  53. pthread_mutex_lock(&spawn_pending_mutex);
  54. while (spawn_pending > 0)
  55. pthread_cond_wait(&spawn_pending_cond, &spawn_pending_mutex);
  56. pthread_mutex_unlock(&spawn_pending_mutex);
  57. }
  58. static void spawn_callback(void *_arg)
  59. {
  60. assert(42 == (uintptr_t)_arg);
  61. _dec_spawn_pending();
  62. }
  63. static void usage(void)
  64. {
  65. fprintf(stderr, "dgemm: M N K <trans_A=T|N> <trans_B=[T|N]>\n");
  66. exit(EXIT_FAILURE);
  67. }
  68. static void init_rm_infos(void)
  69. {
  70. int cpu_type = starpurm_get_device_type_id("cpu");
  71. int nb_cpu_units = starpurm_get_nb_devices_by_type(cpu_type);
  72. if (nb_cpu_units < 1)
  73. {
  74. /* No CPU unit available. */
  75. exit(77);
  76. }
  77. int cuda_type = starpurm_get_device_type_id("cuda");
  78. int nb_cuda_units = starpurm_get_nb_devices_by_type(cuda_type);
  79. rm_cpu_type_id = cpu_type;
  80. rm_cuda_type_id = cuda_type;
  81. rm_nb_cpu_units = nb_cpu_units;
  82. rm_nb_cuda_units = nb_cuda_units;
  83. }
  84. static void disp_cpuset(hwloc_cpuset_t selected_cpuset)
  85. {
  86. //hwloc_cpuset_t selected_cpuset = starpurm_get_selected_cpuset();
  87. int strl = hwloc_bitmap_snprintf(NULL, 0, selected_cpuset);
  88. char str[strl+1];
  89. hwloc_bitmap_snprintf(str, strl+1, selected_cpuset);
  90. printf("%llx: selected cpuset = %s\n", (unsigned long long)pthread_self(), str);
  91. }
  92. struct s_test_args
  93. {
  94. const int m;
  95. const int n;
  96. const int k;
  97. int transA;
  98. int transB;
  99. };
  100. static void test(void *_args)
  101. {
  102. struct s_test_args *args = _args;
  103. const int m = args->m;
  104. const int n = args->n;
  105. const int k = args->k;
  106. int transA = args->transA;
  107. int transB = args->transB;
  108. unsigned rand_seed = (unsigned)time(NULL);
  109. double *A = malloc(m * k * sizeof(double));
  110. double *B = malloc(k * n * sizeof(double));
  111. double *C = calloc(m * n, sizeof(double));
  112. double *C_test = calloc(m * n, sizeof(double));
  113. const double alpha = (double)rand_r(&rand_seed) / ((double)rand_r(&rand_seed) + DBL_MIN);
  114. const double beta = (double)rand_r(&rand_seed) / ((double)rand_r(&rand_seed) + DBL_MIN);
  115. int i;
  116. for (i = 0; i < m; i++)
  117. {
  118. int j;
  119. for (j = 0; j < n; j++)
  120. {
  121. A[i*n+j] = (double)rand_r(&rand_seed) / ((double)rand_r(&rand_seed) + DBL_MIN);
  122. B[i*n+j] = (double)rand_r(&rand_seed) / ((double)rand_r(&rand_seed) + DBL_MIN);
  123. }
  124. }
  125. MORSE_dgemm(transA, transB, m, n, k, alpha, A, k, B, n, beta, C, n);
  126. #ifdef CHECK
  127. /* Check */
  128. cblas_dgemm( CblasColMajor,
  129. ( CBLAS_TRANSPOSE ) transA,
  130. ( CBLAS_TRANSPOSE ) transB,
  131. m, n, k,
  132. alpha, A, k,
  133. B, n,
  134. beta, C_test, n );
  135. double C_test_inorm = LAPACKE_dlange(CblasColMajor, 'I', m, n, C_test, n);
  136. cblas_daxpy(m*n, -1, C, 1, C_test, 1);
  137. double inorm = LAPACKE_dlange(CblasColMajor, 'I', m, n, C_test, n);
  138. printf("%llx: ||C_test-C||_I / ||C_test||_I = %e\n", (unsigned long long)pthread_self(), inorm/C_test_inorm);
  139. #endif
  140. free(A);
  141. free(B);
  142. free(C);
  143. free(C_test);
  144. }
  145. static void select_units(hwloc_cpuset_t selected_cpuset, hwloc_cpuset_t available_cpuset, int offset, int nb)
  146. {
  147. int first_idx = hwloc_bitmap_first(available_cpuset);
  148. int last_idx = hwloc_bitmap_last(available_cpuset);
  149. int count = 0;
  150. int idx = first_idx;
  151. while (idx != -1 && idx <= last_idx && count < offset+nb)
  152. {
  153. if (hwloc_bitmap_isset(available_cpuset, idx))
  154. {
  155. if (count >= offset)
  156. {
  157. hwloc_bitmap_set(selected_cpuset, idx);
  158. }
  159. count ++;
  160. }
  161. idx = hwloc_bitmap_next(available_cpuset, idx);
  162. }
  163. assert(count == offset+nb);
  164. }
  165. void spawn_tests(int cpu_offset, int cpu_nb, int cuda_offset, int cuda_nb, void *args)
  166. {
  167. if (cpu_offset + cpu_nb > rm_nb_cpu_units)
  168. exit(77);
  169. if (cuda_offset + cuda_nb > rm_nb_cuda_units)
  170. exit(77);
  171. hwloc_cpuset_t cpu_cpuset = starpurm_get_all_cpu_workers_cpuset();
  172. hwloc_cpuset_t cuda_cpuset = starpurm_get_all_device_workers_cpuset_by_type(rm_cuda_type_id);
  173. hwloc_cpuset_t sel_cpuset = hwloc_bitmap_alloc();
  174. assert(sel_cpuset != NULL);
  175. select_units(sel_cpuset, cpu_cpuset, cpu_offset, cpu_nb);
  176. select_units(sel_cpuset, cuda_cpuset, cuda_offset, cuda_nb);
  177. {
  178. int strl1 = hwloc_bitmap_snprintf(NULL, 0, cpu_cpuset);
  179. char str1[strl1+1];
  180. hwloc_bitmap_snprintf(str1, strl1+1, cpu_cpuset);
  181. int strl2 = hwloc_bitmap_snprintf(NULL, 0, cuda_cpuset);
  182. char str2[strl2+1];
  183. hwloc_bitmap_snprintf(str2, strl2+1, cuda_cpuset);
  184. printf("all cpus cpuset = %s\n", str1);
  185. int strl3 = hwloc_bitmap_snprintf(NULL, 0, sel_cpuset);
  186. char str3[strl3+1];
  187. hwloc_bitmap_snprintf(str3, strl1+3, sel_cpuset);
  188. printf("spawn on selected cpuset = %s (avail cpu %s, avail cuda %s)\n", str3, str1, str2);
  189. }
  190. _inc_spawn_pending();
  191. starpurm_spawn_kernel_on_cpus_callback(NULL, test, args, sel_cpuset, spawn_callback, (void*)(uintptr_t)42);
  192. hwloc_bitmap_free(sel_cpuset);
  193. hwloc_bitmap_free(cpu_cpuset);
  194. hwloc_bitmap_free(cuda_cpuset);
  195. }
  196. int main( int argc, char const *argv[])
  197. {
  198. pthread_cond_init(&spawn_pending_cond, NULL);
  199. int transA = MorseTrans;
  200. int transB = MorseTrans;
  201. if (argc < 6 || argc > 6)
  202. usage();
  203. int m = atoi(argv[1]);
  204. if (m < 1)
  205. usage();
  206. int n = atoi(argv[2]);
  207. if (n < 1)
  208. usage();
  209. int k = atoi(argv[3]);
  210. if (k < 1)
  211. usage();
  212. if (strcmp(argv[4], "T") == 0)
  213. transA = MorseTrans;
  214. else if (strcmp(argv[4], "N") == 0)
  215. transA = MorseNoTrans;
  216. else
  217. usage();
  218. if (strcmp(argv[5], "T") == 0)
  219. transB = MorseTrans;
  220. else if (strcmp(argv[5], "N") == 0)
  221. transB = MorseNoTrans;
  222. else
  223. usage();
  224. srand(time(NULL));
  225. struct s_test_args test_args = { .m = m, .n = n, .k = k, .transA = transA, .transB = transB };
  226. /* Test case */
  227. starpurm_initialize();
  228. starpurm_set_drs_enable(NULL);
  229. init_rm_infos();
  230. printf("cpu units: %d\n", rm_nb_cpu_units);
  231. printf("cuda units: %d\n", rm_nb_cuda_units);
  232. printf("using default units\n");
  233. disp_cpuset(starpurm_get_selected_cpuset());
  234. MORSE_Init(rm_nb_cpu_units, rm_nb_cuda_units);
  235. test(&test_args);
  236. {
  237. int cpu_offset = 0;
  238. int cpu_nb = rm_nb_cpu_units/2;
  239. if (cpu_nb == 0 && rm_nb_cpu_units > 0)
  240. {
  241. cpu_nb = 1;
  242. }
  243. int cuda_offset = 0;
  244. int cuda_nb = rm_nb_cuda_units/2;
  245. if (cuda_nb == 0 && rm_nb_cuda_units > 0)
  246. {
  247. cuda_nb = 1;
  248. }
  249. spawn_tests(cpu_offset, cpu_nb, cuda_offset, cuda_nb, &test_args);
  250. }
  251. {
  252. int cpu_offset = rm_nb_cpu_units/2;
  253. int cpu_nb = cpu_offset;
  254. if (cpu_nb == 0 && rm_nb_cpu_units > 0)
  255. {
  256. cpu_nb = 1;
  257. }
  258. int cuda_offset = rm_nb_cuda_units/2;
  259. int cuda_nb = rm_nb_cuda_units/2;
  260. spawn_tests(cpu_offset, cpu_nb, cuda_offset, cuda_nb, &test_args);
  261. }
  262. _wait_pending_spawns();
  263. MORSE_Finalize();
  264. starpurm_shutdown();
  265. pthread_cond_destroy(&spawn_pending_cond);
  266. return 0;
  267. }