dot_product.c 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2010-2011 Université de Bordeaux 1
  4. *
  5. * StarPU is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * StarPU is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu.h>
  17. #include <assert.h>
  18. #ifdef STARPU_USE_CUDA
  19. #include <cuda.h>
  20. #include <cublas.h>
  21. #include <starpu_cuda.h>
  22. #endif
  23. #define FPRINTF(ofile, fmt, args ...) do { if (!getenv("STARPU_SSILENT")) {fprintf(ofile, fmt, ##args); }} while(0)
  24. static float *x;
  25. static float *y;
  26. static starpu_data_handle_t *x_handles;
  27. static starpu_data_handle_t *y_handles;
  28. static unsigned nblocks = 4096;
  29. static unsigned entries_per_block = 1024;
  30. #define DOT_TYPE double
  31. static DOT_TYPE dot = 0.0f;
  32. static starpu_data_handle_t dot_handle;
  33. static int can_execute(unsigned workerid, struct starpu_task *task, unsigned nimpl)
  34. {
  35. const struct cudaDeviceProp *props;
  36. if (starpu_worker_get_type(workerid) == STARPU_CPU_WORKER)
  37. return 1;
  38. /* Cuda device */
  39. props = starpu_cuda_get_device_properties(workerid);
  40. if (props->major >= 2 || props->minor >= 3)
  41. /* At least compute capability 1.3, supports doubles */
  42. return 0;
  43. /* Old card, does not support doubles */
  44. return 0;
  45. }
  46. /*
  47. * Codelet to create a neutral element
  48. */
  49. void init_cpu_func(void *descr[], void *cl_arg)
  50. {
  51. DOT_TYPE *dot = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[0]);
  52. *dot = 0.0f;
  53. }
  54. #ifdef STARPU_USE_CUDA
  55. void init_cuda_func(void *descr[], void *cl_arg)
  56. {
  57. DOT_TYPE *dot = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[0]);
  58. cudaMemset(dot, 0, sizeof(DOT_TYPE));
  59. cudaThreadSynchronize();
  60. }
  61. #endif
  62. static struct starpu_codelet init_codelet = {
  63. .where = STARPU_CPU|STARPU_CUDA,
  64. .can_execute = can_execute,
  65. .cpu_func = init_cpu_func,
  66. #ifdef STARPU_USE_CUDA
  67. .cuda_func = init_cuda_func,
  68. #endif
  69. .nbuffers = 1
  70. };
  71. /*
  72. * Codelet to perform the reduction of two elements
  73. */
  74. void redux_cpu_func(void *descr[], void *cl_arg)
  75. {
  76. DOT_TYPE *dota = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[0]);
  77. DOT_TYPE *dotb = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[1]);
  78. *dota = *dota + *dotb;
  79. }
  80. #ifdef STARPU_USE_CUDA
  81. extern void redux_cuda_func(void *descr[], void *_args);
  82. #endif
  83. static struct starpu_codelet redux_codelet = {
  84. .where = STARPU_CPU|STARPU_CUDA,
  85. .can_execute = can_execute,
  86. .cpu_func = redux_cpu_func,
  87. #ifdef STARPU_USE_CUDA
  88. .cuda_func = redux_cuda_func,
  89. #endif
  90. .nbuffers = 2
  91. };
  92. /*
  93. * Dot product codelet
  94. */
  95. void dot_cpu_func(void *descr[], void *cl_arg)
  96. {
  97. float *local_x = (float *)STARPU_VECTOR_GET_PTR(descr[0]);
  98. float *local_y = (float *)STARPU_VECTOR_GET_PTR(descr[1]);
  99. DOT_TYPE *dot = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[2]);
  100. unsigned n = STARPU_VECTOR_GET_NX(descr[0]);
  101. DOT_TYPE local_dot = 0.0;
  102. unsigned i;
  103. for (i = 0; i < n; i++)
  104. {
  105. local_dot += (DOT_TYPE)local_x[i]*(DOT_TYPE)local_y[i];
  106. }
  107. *dot = *dot + local_dot;
  108. }
  109. #ifdef STARPU_USE_CUDA
  110. void dot_cuda_func(void *descr[], void *cl_arg)
  111. {
  112. DOT_TYPE current_dot;
  113. DOT_TYPE local_dot;
  114. float *local_x = (float *)STARPU_VECTOR_GET_PTR(descr[0]);
  115. float *local_y = (float *)STARPU_VECTOR_GET_PTR(descr[1]);
  116. DOT_TYPE *dot = (DOT_TYPE *)STARPU_VARIABLE_GET_PTR(descr[2]);
  117. unsigned n = STARPU_VECTOR_GET_NX(descr[0]);
  118. cudaMemcpy(&current_dot, dot, sizeof(DOT_TYPE), cudaMemcpyDeviceToHost);
  119. cudaThreadSynchronize();
  120. local_dot = (DOT_TYPE)cublasSdot(n, local_x, 1, local_y, 1);
  121. /* FPRINTF(stderr, "current_dot %f local dot %f -> %f\n", current_dot, local_dot, current_dot + local_dot); */
  122. current_dot += local_dot;
  123. cudaThreadSynchronize();
  124. cudaMemcpy(dot, &current_dot, sizeof(DOT_TYPE), cudaMemcpyHostToDevice);
  125. cudaThreadSynchronize();
  126. }
  127. #endif
  128. static struct starpu_codelet dot_codelet = {
  129. .where = STARPU_CPU|STARPU_CUDA,
  130. .can_execute = can_execute,
  131. .cpu_func = dot_cpu_func,
  132. #ifdef STARPU_USE_CUDA
  133. .cuda_func = dot_cuda_func,
  134. #endif
  135. .nbuffers = 3
  136. };
  137. /*
  138. * Tasks initialization
  139. */
  140. int main(int argc, char **argv)
  141. {
  142. starpu_init(NULL);
  143. starpu_helper_cublas_init();
  144. unsigned long nelems = nblocks*entries_per_block;
  145. size_t size = nelems*sizeof(float);
  146. x = (float *) malloc(size);
  147. y = (float *) malloc(size);
  148. x_handles = (starpu_data_handle_t *) calloc(nblocks, sizeof(starpu_data_handle_t));
  149. y_handles = (starpu_data_handle_t *) calloc(nblocks, sizeof(starpu_data_handle_t));
  150. assert(x && y);
  151. starpu_srand48(0);
  152. DOT_TYPE reference_dot = 0.0;
  153. unsigned long i;
  154. for (i = 0; i < nelems; i++)
  155. {
  156. x[i] = (float)starpu_drand48();
  157. y[i] = (float)starpu_drand48();
  158. reference_dot += (DOT_TYPE)x[i]*(DOT_TYPE)y[i];
  159. }
  160. unsigned block;
  161. for (block = 0; block < nblocks; block++)
  162. {
  163. starpu_vector_data_register(&x_handles[block], 0,
  164. (uintptr_t)&x[entries_per_block*block], entries_per_block, sizeof(float));
  165. starpu_vector_data_register(&y_handles[block], 0,
  166. (uintptr_t)&y[entries_per_block*block], entries_per_block, sizeof(float));
  167. }
  168. starpu_variable_data_register(&dot_handle, 0, (uintptr_t)&dot, sizeof(DOT_TYPE));
  169. /*
  170. * Compute dot product with StarPU
  171. */
  172. starpu_data_set_reduction_methods(dot_handle, &redux_codelet, &init_codelet);
  173. for (block = 0; block < nblocks; block++)
  174. {
  175. struct starpu_task *task = starpu_task_create();
  176. task->cl = &dot_codelet;
  177. task->destroy = 1;
  178. task->buffers[0].handle = x_handles[block];
  179. task->buffers[0].mode = STARPU_R;
  180. task->buffers[1].handle = y_handles[block];
  181. task->buffers[1].mode = STARPU_R;
  182. task->buffers[2].handle = dot_handle;
  183. task->buffers[2].mode = STARPU_REDUX;
  184. int ret = starpu_task_submit(task);
  185. if (ret == -ENODEV) goto enodev;
  186. STARPU_ASSERT(!ret);
  187. }
  188. for (block = 0; block < nblocks; block++)
  189. {
  190. starpu_data_unregister(x_handles[block]);
  191. starpu_data_unregister(y_handles[block]);
  192. }
  193. starpu_data_unregister(dot_handle);
  194. FPRINTF(stderr, "Reference : %e vs. %e (Delta %e)\n", reference_dot, dot, reference_dot - dot);
  195. starpu_helper_cublas_shutdown();
  196. starpu_shutdown();
  197. free(x);
  198. free(y);
  199. free(x_handles);
  200. free(y_handles);
  201. return 0;
  202. enodev:
  203. fprintf(stderr, "WARNING: No one can execute this task\n");
  204. /* yes, we do not perform the computation but we did detect that no one
  205. * could perform the kernel, so this is not an error from StarPU */
  206. return 77;
  207. }