csr_interface.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <starpu.h>
  17. #include <datawizard/data_parameters.h>
  18. #include <datawizard/coherency.h>
  19. #include <datawizard/copy-driver.h>
  20. #include <datawizard/hierarchy.h>
  21. #include <common/hash.h>
  22. static int dummy_copy_ram_to_ram(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node);
  23. #ifdef USE_CUDA
  24. static int copy_ram_to_cublas(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node);
  25. static int copy_cublas_to_ram(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node);
  26. #endif
  27. static const struct copy_data_methods_s csr_copy_data_methods_s = {
  28. .ram_to_ram = dummy_copy_ram_to_ram,
  29. .ram_to_spu = NULL,
  30. #ifdef USE_CUDA
  31. .ram_to_cuda = copy_ram_to_cublas,
  32. .cuda_to_ram = copy_cublas_to_ram,
  33. #endif
  34. .cuda_to_cuda = NULL,
  35. .cuda_to_spu = NULL,
  36. .spu_to_ram = NULL,
  37. .spu_to_cuda = NULL,
  38. .spu_to_spu = NULL
  39. };
  40. static size_t allocate_csr_buffer_on_node(struct starpu_data_state_t *state, uint32_t dst_node);
  41. static void liberate_csr_buffer_on_node(starpu_data_interface_t *interface, uint32_t node);
  42. static size_t dump_csr_interface(starpu_data_interface_t *interface, void *_buffer);
  43. static size_t csr_interface_get_size(struct starpu_data_state_t *state);
  44. static uint32_t footprint_csr_interface_crc32(data_state *state, uint32_t hstate);
  45. struct data_interface_ops_t interface_csr_ops = {
  46. .allocate_data_on_node = allocate_csr_buffer_on_node,
  47. .liberate_data_on_node = liberate_csr_buffer_on_node,
  48. .copy_methods = &csr_copy_data_methods_s,
  49. .dump_data_interface = dump_csr_interface,
  50. .get_size = csr_interface_get_size,
  51. .interfaceid = STARPU_CSR_INTERFACE_ID,
  52. .footprint = footprint_csr_interface_crc32
  53. };
  54. /* declare a new data with the BLAS interface */
  55. void starpu_register_csr_data(struct starpu_data_state_t **handle, uint32_t home_node,
  56. uint32_t nnz, uint32_t nrow, uintptr_t nzval, uint32_t *colind, uint32_t *rowptr, uint32_t firstentry, size_t elemsize)
  57. {
  58. struct starpu_data_state_t *state =
  59. starpu_data_state_create(sizeof(starpu_csr_interface_t));
  60. STARPU_ASSERT(handle);
  61. *handle = state;
  62. unsigned node;
  63. for (node = 0; node < MAXNODES; node++)
  64. {
  65. starpu_csr_interface_t *local_interface =
  66. starpu_data_get_interface_on_node(state, node);
  67. if (node == home_node) {
  68. local_interface->nzval = nzval;
  69. local_interface->colind = colind;
  70. local_interface->rowptr = rowptr;
  71. }
  72. else {
  73. local_interface->nzval = 0;
  74. local_interface->colind = NULL;
  75. local_interface->rowptr = NULL;
  76. }
  77. local_interface->nnz = nnz;
  78. local_interface->nrow = nrow;
  79. local_interface->firstentry = firstentry;
  80. local_interface->elemsize = elemsize;
  81. }
  82. state->ops = &interface_csr_ops;
  83. register_new_data(state, home_node, 0);
  84. }
  85. static inline uint32_t footprint_csr_interface_generic(uint32_t (*hash_func)(uint32_t input, uint32_t hstate), data_state *state, uint32_t hstate)
  86. {
  87. uint32_t hash;
  88. hash = hstate;
  89. hash = hash_func(starpu_get_csr_nnz(state), hash);
  90. return hash;
  91. }
  92. static uint32_t footprint_csr_interface_crc32(data_state *state, uint32_t hstate)
  93. {
  94. return footprint_csr_interface_generic(crc32_be, state, hstate);
  95. }
  /*
   * Packed (padding-free) byte layout produced by dump_csr_interface().
   * The three array fields hold raw pointer values, not the array contents.
   * NOTE(review): elemsize is narrowed to 32 bits here although it is handled
   * as a size_t elsewhere in this file; kept to preserve the dump format.
   */
  struct __attribute__ ((packed)) dumped_csr_interface_s {
      uint32_t nnz, nrow;
      uintptr_t nzval;
      uint32_t *colind, *rowptr;
      uint32_t firstentry;
      uint32_t elemsize;
  };
  105. static size_t dump_csr_interface(starpu_data_interface_t *interface, void *_buffer)
  106. {
  107. /* yes, that's DIRTY ... */
  108. struct dumped_csr_interface_s *buffer = _buffer;
  109. buffer->nnz = (*interface).csr.nnz;
  110. buffer->nrow = (*interface).csr.nrow;
  111. buffer->nzval = (*interface).csr.nzval;
  112. buffer->colind = (*interface).csr.colind;
  113. buffer->rowptr = (*interface).csr.rowptr;
  114. buffer->firstentry = (*interface).csr.firstentry;
  115. buffer->elemsize = (*interface).csr.elemsize;
  116. return (sizeof(struct dumped_csr_interface_s));
  117. }
  /* Accessors for the parameters of a registered CSR data handle. */
  119. uint32_t starpu_get_csr_nnz(struct starpu_data_state_t *state)
  120. {
  121. starpu_csr_interface_t *interface =
  122. starpu_data_get_interface_on_node(state, 0);
  123. return interface->nnz;
  124. }
  125. uint32_t starpu_get_csr_nrow(struct starpu_data_state_t *state)
  126. {
  127. starpu_csr_interface_t *interface =
  128. starpu_data_get_interface_on_node(state, 0);
  129. return interface->nrow;
  130. }
  131. uint32_t starpu_get_csr_firstentry(struct starpu_data_state_t *state)
  132. {
  133. starpu_csr_interface_t *interface =
  134. starpu_data_get_interface_on_node(state, 0);
  135. return interface->firstentry;
  136. }
  137. size_t starpu_get_csr_elemsize(struct starpu_data_state_t *state)
  138. {
  139. starpu_csr_interface_t *interface =
  140. starpu_data_get_interface_on_node(state, 0);
  141. return interface->elemsize;
  142. }
  143. uintptr_t starpu_get_csr_local_nzval(struct starpu_data_state_t *state)
  144. {
  145. unsigned node;
  146. node = get_local_memory_node();
  147. STARPU_ASSERT(state->per_node[node].allocated);
  148. starpu_csr_interface_t *interface =
  149. starpu_data_get_interface_on_node(state, node);
  150. return interface->nzval;
  151. }
  152. uint32_t *starpu_get_csr_local_colind(struct starpu_data_state_t *state)
  153. {
  154. unsigned node;
  155. node = get_local_memory_node();
  156. STARPU_ASSERT(state->per_node[node].allocated);
  157. starpu_csr_interface_t *interface =
  158. starpu_data_get_interface_on_node(state, node);
  159. return interface->colind;
  160. }
  161. uint32_t *starpu_get_csr_local_rowptr(struct starpu_data_state_t *state)
  162. {
  163. unsigned node;
  164. node = get_local_memory_node();
  165. STARPU_ASSERT(state->per_node[node].allocated);
  166. starpu_csr_interface_t *interface =
  167. starpu_data_get_interface_on_node(state, node);
  168. return interface->rowptr;
  169. }
  /*
   * Total footprint in bytes of the three CSR arrays:
   *   nzval  : nnz * elemsize
   *   colind : nnz * sizeof(uint32_t)
   *   rowptr : (nrow + 1) * sizeof(uint32_t)
   */
  static size_t csr_interface_get_size(struct starpu_data_state_t *state)
  {
      /* Widen to size_t before any arithmetic so that nrow + 1 cannot wrap
       * around in 32-bit unsigned arithmetic. */
      size_t nnz = starpu_get_csr_nnz(state);
      size_t nrow = starpu_get_csr_nrow(state);
      size_t elemsize = starpu_get_csr_elemsize(state);

      return nnz*elemsize + nnz*sizeof(uint32_t) + (nrow + 1)*sizeof(uint32_t);
  }
  179. /* memory allocation/deallocation primitives for the BLAS interface */
  180. /* returns the size of the allocated area */
  181. static size_t allocate_csr_buffer_on_node(struct starpu_data_state_t *state, uint32_t dst_node)
  182. {
  183. uintptr_t addr_nzval;
  184. uint32_t *addr_colind, *addr_rowptr;
  185. size_t allocated_memory;
  186. /* we need the 3 arrays to be allocated */
  187. starpu_csr_interface_t *interface =
  188. starpu_data_get_interface_on_node(state, dst_node);
  189. uint32_t nnz = interface->nnz;
  190. uint32_t nrow = interface->nrow;
  191. size_t elemsize = interface->elemsize;
  192. node_kind kind = get_node_kind(dst_node);
  193. switch(kind) {
  194. case RAM:
  195. addr_nzval = (uintptr_t)malloc(nnz*elemsize);
  196. if (!addr_nzval)
  197. goto fail_nzval;
  198. addr_colind = malloc(nnz*sizeof(uint32_t));
  199. if (!addr_colind)
  200. goto fail_colind;
  201. addr_rowptr = malloc((nrow+1)*sizeof(uint32_t));
  202. if (!addr_rowptr)
  203. goto fail_rowptr;
  204. break;
  205. #ifdef USE_CUDA
  206. case CUDA_RAM:
  207. cublasAlloc(nnz, elemsize, (void **)&addr_nzval);
  208. if (!addr_nzval)
  209. goto fail_nzval;
  210. cublasAlloc(nnz, sizeof(uint32_t), (void **)&addr_colind);
  211. if (!addr_colind)
  212. goto fail_colind;
  213. cublasAlloc((nrow+1), sizeof(uint32_t), (void **)&addr_rowptr);
  214. if (!addr_rowptr)
  215. goto fail_rowptr;
  216. break;
  217. #endif
  218. default:
  219. assert(0);
  220. }
  221. /* allocation succeeded */
  222. allocated_memory =
  223. nnz*elemsize + nnz*sizeof(uint32_t) + (nrow+1)*sizeof(uint32_t);
  224. /* update the data properly in consequence */
  225. interface->nzval = addr_nzval;
  226. interface->colind = addr_colind;
  227. interface->rowptr = addr_rowptr;
  228. return allocated_memory;
  229. fail_rowptr:
  230. switch(kind) {
  231. case RAM:
  232. free((void *)addr_colind);
  233. #ifdef USE_CUDA
  234. case CUDA_RAM:
  235. cublasFree((void*)addr_colind);
  236. break;
  237. #endif
  238. default:
  239. assert(0);
  240. }
  241. fail_colind:
  242. switch(kind) {
  243. case RAM:
  244. free((void *)addr_nzval);
  245. #ifdef USE_CUDA
  246. case CUDA_RAM:
  247. cublasFree((void*)addr_nzval);
  248. break;
  249. #endif
  250. default:
  251. assert(0);
  252. }
  253. fail_nzval:
  254. /* allocation failed */
  255. allocated_memory = 0;
  256. return allocated_memory;
  257. }
  258. static void liberate_csr_buffer_on_node(starpu_data_interface_t *interface, uint32_t node)
  259. {
  260. node_kind kind = get_node_kind(node);
  261. switch(kind) {
  262. case RAM:
  263. free((void*)interface->csr.nzval);
  264. free((void*)interface->csr.colind);
  265. free((void*)interface->csr.rowptr);
  266. break;
  267. #ifdef USE_CUDA
  268. case CUDA_RAM:
  269. cublasFree((void*)interface->csr.nzval);
  270. cublasFree((void*)interface->csr.colind);
  271. cublasFree((void*)interface->csr.rowptr);
  272. break;
  273. #endif
  274. default:
  275. assert(0);
  276. }
  277. }
  278. #ifdef USE_CUDA
  279. static int copy_cublas_to_ram(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node)
  280. {
  281. starpu_csr_interface_t *src_csr;
  282. starpu_csr_interface_t *dst_csr;
  283. src_csr = starpu_data_get_interface_on_node(state, src_node);
  284. dst_csr = starpu_data_get_interface_on_node(state, dst_node);
  285. uint32_t nnz = src_csr->nnz;
  286. uint32_t nrow = src_csr->nrow;
  287. size_t elemsize = src_csr->elemsize;
  288. cublasGetVector(nnz, elemsize, (uint8_t *)src_csr->nzval, 1,
  289. (uint8_t *)dst_csr->nzval, 1);
  290. cublasGetVector(nnz, sizeof(uint32_t), (uint8_t *)src_csr->colind, 1,
  291. (uint8_t *)dst_csr->colind, 1);
  292. cublasGetVector((nrow+1), sizeof(uint32_t), (uint8_t *)src_csr->rowptr, 1,
  293. (uint8_t *)dst_csr->rowptr, 1);
  294. TRACE_DATA_COPY(src_node, dst_node, nnz*elemsize + (nnz+nrow+1)*sizeof(uint32_t));
  295. return 0;
  296. }
  297. static int copy_ram_to_cublas(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node)
  298. {
  299. starpu_csr_interface_t *src_csr;
  300. starpu_csr_interface_t *dst_csr;
  301. src_csr = starpu_data_get_interface_on_node(state, src_node);
  302. dst_csr = starpu_data_get_interface_on_node(state, dst_node);
  303. uint32_t nnz = src_csr->nnz;
  304. uint32_t nrow = src_csr->nrow;
  305. size_t elemsize = src_csr->elemsize;
  306. cublasSetVector(nnz, elemsize, (uint8_t *)src_csr->nzval, 1,
  307. (uint8_t *)dst_csr->nzval, 1);
  308. cublasSetVector(nnz, sizeof(uint32_t), (uint8_t *)src_csr->colind, 1,
  309. (uint8_t *)dst_csr->colind, 1);
  310. cublasSetVector((nrow+1), sizeof(uint32_t), (uint8_t *)src_csr->rowptr, 1,
  311. (uint8_t *)dst_csr->rowptr, 1);
  312. TRACE_DATA_COPY(src_node, dst_node, nnz*elemsize + (nnz+nrow+1)*sizeof(uint32_t));
  313. return 0;
  314. }
  315. #endif // USE_CUDA
  316. /* as not all platform easily have a BLAS lib installed ... */
  317. static int dummy_copy_ram_to_ram(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node)
  318. {
  319. starpu_csr_interface_t *src_csr;
  320. starpu_csr_interface_t *dst_csr;
  321. src_csr = starpu_data_get_interface_on_node(state, src_node);
  322. dst_csr = starpu_data_get_interface_on_node(state, dst_node);
  323. uint32_t nnz = src_csr->nnz;
  324. uint32_t nrow = src_csr->nrow;
  325. size_t elemsize = src_csr->elemsize;
  326. memcpy((void *)dst_csr->nzval, (void *)src_csr->nzval, nnz*elemsize);
  327. memcpy((void *)dst_csr->colind, (void *)src_csr->colind, nnz*sizeof(uint32_t));
  328. memcpy((void *)dst_csr->rowptr, (void *)src_csr->rowptr, (nrow+1)*sizeof(uint32_t));
  329. TRACE_DATA_COPY(src_node, dst_node, nnz*elemsize + (nnz+nrow+1)*sizeof(uint32_t));
  330. return 0;
  331. }