dot_product.c 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2010-2011 Université de Bordeaux 1
  4. *
  5. * StarPU is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * StarPU is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu.h>
  17. #include <assert.h>
  18. #ifdef STARPU_USE_CUDA
  19. #include <cuda.h>
  20. #include <cublas.h>
  21. #include <starpu_cuda.h>
  22. #endif
  /*
   * fprintf() wrapper used for the example's regular output: the message is
   * printed to `ofile` unless the STARPU_SSILENT environment variable is set,
   * in which case nothing is printed.
   */
  #define FPRINTF(ofile, fmt, ...) \
      do { \
          if (!getenv("STARPU_SSILENT")) \
          { \
              fprintf(ofile, fmt, ##__VA_ARGS__); \
          } \
      } while (0)
  24. static float *x;
  25. static float *y;
  26. static starpu_data_handle_t *x_handles;
  27. static starpu_data_handle_t *y_handles;
  28. static unsigned nblocks = 4096;
  29. static unsigned entries_per_block = 1024;
  30. #define DOT_TYPE double
  31. static DOT_TYPE dot = 0.0f;
  32. static starpu_data_handle_t dot_handle;
  33. static int can_execute(unsigned workerid, struct starpu_task *task, unsigned nimpl)
  34. {
  35. const struct cudaDeviceProp *props;
  36. if (starpu_worker_get_type(workerid) == STARPU_CPU_WORKER)
  37. return 1;
  38. #ifdef STARPU_USE_CUDA
  39. /* Cuda device */
  40. props = starpu_cuda_get_device_properties(workerid);
  41. if (props->major >= 2 || props->minor >= 3)
  42. /* At least compute capability 1.3, supports doubles */
  43. return 0;
  44. #endif
  45. /* Old card, does not support doubles */
  46. return 0;
  47. }
  48. /*
  49. * Codelet to create a neutral element
  50. */
  51. void init_cpu_func(void *descr[], void *cl_arg)
  52. {
  53. DOT_TYPE *dot = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[0]);
  54. *dot = 0.0f;
  55. }
  #ifdef STARPU_USE_CUDA
  /*
   * Initialisation codelet, CUDA implementation: zero-fills the DOT_TYPE
   * variable described by descr[0] with cudaMemset(), then waits for the
   * device with cudaThreadSynchronize(). cl_arg is not used.
   */
  void init_cuda_func(void *descr[], void *cl_arg)
  {
      DOT_TYPE *neutral = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[0]);

      (void)cl_arg;

      cudaMemset(neutral, 0, sizeof *neutral);
      cudaThreadSynchronize();
  }
  #endif
  64. static struct starpu_codelet init_codelet = {
  65. .where = STARPU_CPU|STARPU_CUDA,
  66. .can_execute = can_execute,
  67. .cpu_func = init_cpu_func,
  68. #ifdef STARPU_USE_CUDA
  69. .cuda_func = init_cuda_func,
  70. #endif
  71. .nbuffers = 1
  72. };
  73. /*
  74. * Codelet to perform the reduction of two elements
  75. */
  76. void redux_cpu_func(void *descr[], void *cl_arg)
  77. {
  78. DOT_TYPE *dota = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[0]);
  79. DOT_TYPE *dotb = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[1]);
  80. *dota = *dota + *dotb;
  81. }
  82. #ifdef STARPU_USE_CUDA
  83. extern void redux_cuda_func(void *descr[], void *_args);
  84. #endif
  85. static struct starpu_codelet redux_codelet = {
  86. .where = STARPU_CPU|STARPU_CUDA,
  87. .can_execute = can_execute,
  88. .cpu_func = redux_cpu_func,
  89. #ifdef STARPU_USE_CUDA
  90. .cuda_func = redux_cuda_func,
  91. #endif
  92. .nbuffers = 2
  93. };
  94. /*
  95. * Dot product codelet
  96. */
  97. void dot_cpu_func(void *descr[], void *cl_arg)
  98. {
  99. float *local_x = (float *)STARPU_VECTOR_GET_PTR(descr[0]);
  100. float *local_y = (float *)STARPU_VECTOR_GET_PTR(descr[1]);
  101. DOT_TYPE *dot = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[2]);
  102. unsigned n = STARPU_VECTOR_GET_NX(descr[0]);
  103. DOT_TYPE local_dot = 0.0;
  104. unsigned i;
  105. for (i = 0; i < n; i++)
  106. {
  107. local_dot += (DOT_TYPE)local_x[i]*(DOT_TYPE)local_y[i];
  108. }
  109. *dot = *dot + local_dot;
  110. }
  #ifdef STARPU_USE_CUDA
  /*
   * Dot product codelet, CUDA implementation.
   *
   * descr[0], descr[1]: float vectors (one block of x and of y); the element
   *                     count is taken from descr[0].
   * descr[2]:           DOT_TYPE accumulator, read and written with
   *                     cudaMemcpy() device<->host transfers.
   * cl_arg:             not used.
   *
   * The accumulator is copied to the host, the single-precision result of
   * cublasSdot() is added to it, and the sum is copied back.
   */
  void dot_cuda_func(void *descr[], void *cl_arg)
  {
      float *vec_x = (float *)STARPU_VECTOR_GET_PTR(descr[0]);
      float *vec_y = (float *)STARPU_VECTOR_GET_PTR(descr[1]);
      DOT_TYPE *accumulator = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[2]);
      unsigned count = STARPU_VECTOR_GET_NX(descr[0]);
      DOT_TYPE host_value;

      (void)cl_arg;

      /* Fetch the current accumulator value. */
      cudaMemcpy(&host_value, accumulator, sizeof host_value, cudaMemcpyDeviceToHost);
      cudaThreadSynchronize();

      /* cublasSdot() works in float; the result is widened before the add. */
      host_value += (DOT_TYPE)cublasSdot(count, vec_x, 1, vec_y, 1);

      /* Store the updated value back. */
      cudaThreadSynchronize();
      cudaMemcpy(accumulator, &host_value, sizeof host_value, cudaMemcpyHostToDevice);
      cudaThreadSynchronize();
  }
  #endif
  130. static struct starpu_codelet dot_codelet = {
  131. .where = STARPU_CPU|STARPU_CUDA,
  132. .can_execute = can_execute,
  133. .cpu_func = dot_cpu_func,
  134. #ifdef STARPU_USE_CUDA
  135. .cuda_func = dot_cuda_func,
  136. #endif
  137. .nbuffers = 3
  138. };
  139. /*
  140. * Tasks initialization
  141. */
  142. int main(int argc, char **argv)
  143. {
  144. starpu_init(NULL);
  145. starpu_helper_cublas_init();
  146. unsigned long nelems = nblocks*entries_per_block;
  147. size_t size = nelems*sizeof(float);
  148. x = (float *) malloc(size);
  149. y = (float *) malloc(size);
  150. x_handles = (starpu_data_handle_t *) calloc(nblocks, sizeof(starpu_data_handle_t));
  151. y_handles = (starpu_data_handle_t *) calloc(nblocks, sizeof(starpu_data_handle_t));
  152. assert(x && y);
  153. starpu_srand48(0);
  154. DOT_TYPE reference_dot = 0.0;
  155. unsigned long i;
  156. for (i = 0; i < nelems; i++)
  157. {
  158. x[i] = (float)starpu_drand48();
  159. y[i] = (float)starpu_drand48();
  160. reference_dot += (DOT_TYPE)x[i]*(DOT_TYPE)y[i];
  161. }
  162. unsigned block;
  163. for (block = 0; block < nblocks; block++)
  164. {
  165. starpu_vector_data_register(&x_handles[block], 0,
  166. (uintptr_t)&x[entries_per_block*block], entries_per_block, sizeof(float));
  167. starpu_vector_data_register(&y_handles[block], 0,
  168. (uintptr_t)&y[entries_per_block*block], entries_per_block, sizeof(float));
  169. }
  170. starpu_variable_data_register(&dot_handle, 0, (uintptr_t)&dot, sizeof(DOT_TYPE));
  171. /*
  172. * Compute dot product with StarPU
  173. */
  174. starpu_data_set_reduction_methods(dot_handle, &redux_codelet, &init_codelet);
  175. for (block = 0; block < nblocks; block++)
  176. {
  177. struct starpu_task *task = starpu_task_create();
  178. task->cl = &dot_codelet;
  179. task->destroy = 1;
  180. task->buffers[0].handle = x_handles[block];
  181. task->buffers[0].mode = STARPU_R;
  182. task->buffers[1].handle = y_handles[block];
  183. task->buffers[1].mode = STARPU_R;
  184. task->buffers[2].handle = dot_handle;
  185. task->buffers[2].mode = STARPU_REDUX;
  186. int ret = starpu_task_submit(task);
  187. if (ret == -ENODEV) goto enodev;
  188. STARPU_ASSERT(!ret);
  189. }
  190. for (block = 0; block < nblocks; block++)
  191. {
  192. starpu_data_unregister(x_handles[block]);
  193. starpu_data_unregister(y_handles[block]);
  194. }
  195. starpu_data_unregister(dot_handle);
  196. FPRINTF(stderr, "Reference : %e vs. %e (Delta %e)\n", reference_dot, dot, reference_dot - dot);
  197. starpu_helper_cublas_shutdown();
  198. starpu_shutdown();
  199. free(x);
  200. free(y);
  201. free(x_handles);
  202. free(y_handles);
  203. return 0;
  204. enodev:
  205. fprintf(stderr, "WARNING: No one can execute this task\n");
  206. /* yes, we do not perform the computation but we did detect that no one
  207. * could perform the kernel, so this is not an error from StarPU */
  208. return 77;
  209. }