bcsr_interface.c 12 KB


  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #include <datawizard/data_parameters.h>
  17. #include <datawizard/coherency.h>
  18. #include <datawizard/copy-driver.h>
  19. #include <datawizard/hierarchy.h>
  20. #include <starpu.h>
  21. #include <common/hash.h>
  22. #ifdef USE_CUDA
  23. #include <cuda.h>
  24. #endif
  25. /*
  26. * BCSR : blocked CSR, we use blocks of size (r x c)
  27. */
  28. size_t allocate_bcsr_buffer_on_node(struct starpu_data_state_t *state, uint32_t dst_node);
  29. void liberate_bcsr_buffer_on_node(starpu_data_interface_t *interface, uint32_t node);
  30. size_t dump_bcsr_interface(starpu_data_interface_t *interface, void *_buffer);
  31. int do_copy_bcsr_buffer_1_to_1(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node);
  32. size_t bcsr_interface_get_size(struct starpu_data_state_t *state);
  33. uint32_t footprint_bcsr_interface_crc32(data_state *state, uint32_t hstate);
  34. struct data_interface_ops_t interface_bcsr_ops = {
  35. .allocate_data_on_node = allocate_bcsr_buffer_on_node,
  36. .liberate_data_on_node = liberate_bcsr_buffer_on_node,
  37. .copy_data_1_to_1 = do_copy_bcsr_buffer_1_to_1,
  38. .dump_data_interface = dump_bcsr_interface,
  39. .get_size = bcsr_interface_get_size,
  40. .interfaceid = BCSR_INTERFACE,
  41. .footprint = footprint_bcsr_interface_crc32
  42. };
  43. void starpu_register_bcsr_data(struct starpu_data_state_t **handle, uint32_t home_node,
  44. uint32_t nnz, uint32_t nrow, uintptr_t nzval, uint32_t *colind, uint32_t *rowptr, uint32_t firstentry, uint32_t r, uint32_t c, size_t elemsize)
  45. {
  46. struct starpu_data_state_t *state = calloc(1, sizeof(struct starpu_data_state_t));
  47. STARPU_ASSERT(state);
  48. STARPU_ASSERT(handle);
  49. *handle = state;
  50. unsigned node;
  51. for (node = 0; node < MAXNODES; node++)
  52. {
  53. starpu_bcsr_interface_t *local_interface = &state->interface[node].bcsr;
  54. if (node == home_node) {
  55. local_interface->nzval = nzval;
  56. local_interface->colind = colind;
  57. local_interface->rowptr = rowptr;
  58. }
  59. else {
  60. local_interface->nzval = 0;
  61. local_interface->colind = NULL;
  62. local_interface->rowptr = NULL;
  63. }
  64. local_interface->nnz = nnz;
  65. local_interface->nrow = nrow;
  66. local_interface->firstentry = firstentry;
  67. local_interface->r = r;
  68. local_interface->c = c;
  69. local_interface->elemsize = elemsize;
  70. }
  71. state->ops = &interface_bcsr_ops;
  72. register_new_data(state, home_node, 0);
  73. }
  74. static inline uint32_t footprint_bcsr_interface_generic(uint32_t (*hash_func)(uint32_t input, uint32_t hstate), data_state *state, uint32_t hstate)
  75. {
  76. uint32_t hash;
  77. hash = hstate;
  78. hash = hash_func(starpu_get_bcsr_nnz(state), hash);
  79. hash = hash_func(starpu_get_bcsr_c(state), hash);
  80. hash = hash_func(starpu_get_bcsr_r(state), hash);
  81. return hash;
  82. }
  83. uint32_t footprint_bcsr_interface_crc32(data_state *state, uint32_t hstate)
  84. {
  85. return footprint_bcsr_interface_generic(crc32_be, state, hstate);
  86. }
  /*
 * Flat record written by dump_bcsr_interface. It is packed so there is no
 * padding between members. The field order and widths define the dumped
 * layout and must not change.
 */
struct __attribute__ ((packed)) dumped_bcsr_interface_s {
	uint32_t nnz;
	uint32_t nrow;
	uintptr_t nzval;
	uint32_t *colind;
	uint32_t *rowptr;
	uint32_t firstentry, r, c, elemsize;
};
  98. size_t dump_bcsr_interface(starpu_data_interface_t *interface, void *_buffer)
  99. {
  100. /* yes, that's DIRTY ... */
  101. struct dumped_bcsr_interface_s *buffer = _buffer;
  102. buffer->nnz = (*interface).bcsr.nnz;
  103. buffer->nrow = (*interface).bcsr.nrow;
  104. buffer->nzval = (*interface).bcsr.nzval;
  105. buffer->colind = (*interface).bcsr.colind;
  106. buffer->rowptr = (*interface).bcsr.rowptr;
  107. buffer->firstentry = (*interface).bcsr.firstentry;
  108. buffer->r = (*interface).bcsr.r;
  109. buffer->c = (*interface).bcsr.c;
  110. buffer->elemsize = (*interface).bcsr.elemsize;
  111. return (sizeof(struct dumped_bcsr_interface_s));
  112. }
  113. /* offer an access to the data parameters */
  114. uint32_t starpu_get_bcsr_nnz(struct starpu_data_state_t *state)
  115. {
  116. return (state->interface[0].bcsr.nnz);
  117. }
  118. uint32_t starpu_get_bcsr_nrow(struct starpu_data_state_t *state)
  119. {
  120. return (state->interface[0].bcsr.nrow);
  121. }
  122. uint32_t starpu_get_bcsr_firstentry(struct starpu_data_state_t *state)
  123. {
  124. return (state->interface[0].bcsr.firstentry);
  125. }
  126. uint32_t starpu_get_bcsr_r(struct starpu_data_state_t *state)
  127. {
  128. return (state->interface[0].bcsr.r);
  129. }
  130. uint32_t starpu_get_bcsr_c(struct starpu_data_state_t *state)
  131. {
  132. return (state->interface[0].bcsr.c);
  133. }
  134. size_t starpu_get_bcsr_elemsize(struct starpu_data_state_t *state)
  135. {
  136. return (state->interface[0].bcsr.elemsize);
  137. }
  138. uintptr_t starpu_get_bcsr_local_nzval(struct starpu_data_state_t *state)
  139. {
  140. unsigned node;
  141. node = get_local_memory_node();
  142. STARPU_ASSERT(state->per_node[node].allocated);
  143. return (state->interface[node].bcsr.nzval);
  144. }
  145. uint32_t *starpu_get_bcsr_local_colind(struct starpu_data_state_t *state)
  146. {
  147. // unsigned node;
  148. // node = get_local_memory_node();
  149. //
  150. // STARPU_ASSERT(state->per_node[node].allocated);
  151. //
  152. // return (state->interface[node].bcsr.colind);
  153. /* XXX */
  154. return (state->interface[0].bcsr.colind);
  155. }
  156. uint32_t *starpu_get_bcsr_local_rowptr(struct starpu_data_state_t *state)
  157. {
  158. // unsigned node;
  159. // node = get_local_memory_node();
  160. //
  161. // STARPU_ASSERT(state->per_node[node].allocated);
  162. //
  163. // return (state->interface[node].bcsr.rowptr);
  164. /* XXX */
  165. return (state->interface[0].bcsr.rowptr);
  166. }
  /*
 * Total number of bytes taken by the three BCSR arrays of a handle:
 *   nzval  : nnz * r * c elements of elemsize bytes,
 *   colind : nnz uint32_t,
 *   rowptr : nrow + 1 uint32_t.
 */
size_t bcsr_interface_get_size(struct starpu_data_state_t *state)
{
	uint32_t nnz = starpu_get_bcsr_nnz(state);
	uint32_t nrow = starpu_get_bcsr_nrow(state);
	uint32_t r = starpu_get_bcsr_r(state);
	uint32_t c = starpu_get_bcsr_c(state);
	size_t elemsize = starpu_get_bcsr_elemsize(state);

	/* Widen to size_t before multiplying. Otherwise nnz*r*c (and nrow+1)
	 * are evaluated in 32-bit unsigned arithmetic and silently wrap
	 * around for large matrices. */
	size_t nzval_size = (size_t)nnz * r * c * elemsize;
	size_t colind_size = (size_t)nnz * sizeof(uint32_t);
	size_t rowptr_size = ((size_t)nrow + 1) * sizeof(uint32_t);

	return nzval_size + colind_size + rowptr_size;
}
  178. /* memory allocation/deallocation primitives for the BLAS interface */
  179. /* returns the size of the allocated area */
  180. size_t allocate_bcsr_buffer_on_node(struct starpu_data_state_t *state, uint32_t dst_node)
  181. {
  182. uintptr_t addr_nzval;
  183. uint32_t *addr_colind, *addr_rowptr;
  184. size_t allocated_memory;
  185. /* we need the 3 arrays to be allocated */
  186. uint32_t nnz = state->interface[dst_node].bcsr.nnz;
  187. uint32_t nrow = state->interface[dst_node].bcsr.nrow;
  188. size_t elemsize = state->interface[dst_node].bcsr.elemsize;
  189. uint32_t r = state->interface[dst_node].bcsr.r;
  190. uint32_t c = state->interface[dst_node].bcsr.c;
  191. node_kind kind = get_node_kind(dst_node);
  192. switch(kind) {
  193. case RAM:
  194. addr_nzval = (uintptr_t)malloc(nnz*r*c*elemsize);
  195. if (!addr_nzval)
  196. goto fail_nzval;
  197. addr_colind = malloc(nnz*sizeof(uint32_t));
  198. if (!addr_colind)
  199. goto fail_colind;
  200. addr_rowptr = malloc((nrow+1)*sizeof(uint32_t));
  201. if (!addr_rowptr)
  202. goto fail_rowptr;
  203. break;
  204. #ifdef USE_CUDA
  205. case CUDA_RAM:
  206. cublasAlloc(nnz*r*c, elemsize, (void **)&addr_nzval);
  207. if (!addr_nzval)
  208. goto fail_nzval;
  209. cublasAlloc(nnz, sizeof(uint32_t), (void **)&addr_colind);
  210. if (!addr_colind)
  211. goto fail_colind;
  212. cublasAlloc((nrow+1), sizeof(uint32_t), (void **)&addr_rowptr);
  213. if (!addr_rowptr)
  214. goto fail_rowptr;
  215. break;
  216. #endif
  217. default:
  218. assert(0);
  219. }
  220. /* allocation succeeded */
  221. allocated_memory =
  222. nnz*r*c*elemsize + nnz*sizeof(uint32_t) + (nrow+1)*sizeof(uint32_t);
  223. /* update the data properly in consequence */
  224. state->interface[dst_node].bcsr.nzval = addr_nzval;
  225. state->interface[dst_node].bcsr.colind = addr_colind;
  226. state->interface[dst_node].bcsr.rowptr = addr_rowptr;
  227. return allocated_memory;
  228. fail_rowptr:
  229. switch(kind) {
  230. case RAM:
  231. free((void *)addr_colind);
  232. #ifdef USE_CUDA
  233. case CUDA_RAM:
  234. cublasFree((void*)addr_colind);
  235. break;
  236. #endif
  237. default:
  238. assert(0);
  239. }
  240. fail_colind:
  241. switch(kind) {
  242. case RAM:
  243. free((void *)addr_nzval);
  244. #ifdef USE_CUDA
  245. case CUDA_RAM:
  246. cublasFree((void*)addr_nzval);
  247. break;
  248. #endif
  249. default:
  250. assert(0);
  251. }
  252. fail_nzval:
  253. /* allocation failed */
  254. allocated_memory = 0;
  255. return allocated_memory;
  256. }
  257. void liberate_bcsr_buffer_on_node(starpu_data_interface_t *interface, uint32_t node)
  258. {
  259. node_kind kind = get_node_kind(node);
  260. switch(kind) {
  261. case RAM:
  262. free((void*)interface->bcsr.nzval);
  263. free((void*)interface->bcsr.colind);
  264. free((void*)interface->bcsr.rowptr);
  265. break;
  266. #ifdef USE_CUDA
  267. case CUDA_RAM:
  268. cublasFree((void*)interface->bcsr.nzval);
  269. cublasFree((void*)interface->bcsr.colind);
  270. cublasFree((void*)interface->bcsr.rowptr);
  271. break;
  272. #endif
  273. default:
  274. assert(0);
  275. }
  276. }
  #ifdef USE_CUDA
/*
 * Transfers of the three BCSR arrays between a CUDA node and RAM, using the
 * CUBLAS vector helpers. Element counts are taken from the source node's
 * interface, and the destination arrays must already be allocated. Callers in
 * this file only invoke these from the thread that owns the CUDA node.
 */

/* CUDA -> RAM */
static void copy_cublas_to_ram(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node)
{
	starpu_bcsr_interface_t *from = &state->interface[src_node].bcsr;
	starpu_bcsr_interface_t *to = &state->interface[dst_node].bcsr;

	uint32_t nnz = from->nnz, nrow = from->nrow;
	uint32_t r = from->r, c = from->c;
	size_t elemsize = from->elemsize;

	/* non-zero blocks, then column indices, then row pointers */
	cublasGetVector(nnz*r*c, elemsize,
			(uint8_t *)from->nzval, 1, (uint8_t *)to->nzval, 1);
	cublasGetVector(nnz, sizeof(uint32_t),
			(uint8_t *)from->colind, 1, (uint8_t *)to->colind, 1);
	cublasGetVector((nrow+1), sizeof(uint32_t),
			(uint8_t *)from->rowptr, 1, (uint8_t *)to->rowptr, 1);

	TRACE_DATA_COPY(src_node, dst_node, nnz*r*c*elemsize + (nnz+nrow+1)*sizeof(uint32_t));
}

/* RAM -> CUDA */
static void copy_ram_to_cublas(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node)
{
	starpu_bcsr_interface_t *from = &state->interface[src_node].bcsr;
	starpu_bcsr_interface_t *to = &state->interface[dst_node].bcsr;

	uint32_t nnz = from->nnz, nrow = from->nrow;
	uint32_t r = from->r, c = from->c;
	size_t elemsize = from->elemsize;

	/* non-zero blocks, then column indices, then row pointers */
	cublasSetVector(nnz*r*c, elemsize,
			(uint8_t *)from->nzval, 1, (uint8_t *)to->nzval, 1);
	cublasSetVector(nnz, sizeof(uint32_t),
			(uint8_t *)from->colind, 1, (uint8_t *)to->colind, 1);
	cublasSetVector((nrow+1), sizeof(uint32_t),
			(uint8_t *)from->rowptr, 1, (uint8_t *)to->rowptr, 1);

	TRACE_DATA_COPY(src_node, dst_node, nnz*r*c*elemsize + (nnz+nrow+1)*sizeof(uint32_t));
}
#endif // USE_CUDA
  317. /* as not all platform easily have a BLAS lib installed ... */
  318. static void dummy_copy_ram_to_ram(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node)
  319. {
  320. starpu_bcsr_interface_t *src_bcsr;
  321. starpu_bcsr_interface_t *dst_bcsr;
  322. src_bcsr = &state->interface[src_node].bcsr;
  323. dst_bcsr = &state->interface[dst_node].bcsr;
  324. uint32_t nnz = src_bcsr->nnz;
  325. uint32_t nrow = src_bcsr->nrow;
  326. size_t elemsize = src_bcsr->elemsize;
  327. uint32_t r = src_bcsr->r;
  328. uint32_t c = src_bcsr->c;
  329. memcpy((void *)dst_bcsr->nzval, (void *)src_bcsr->nzval, nnz*elemsize*r*c);
  330. memcpy((void *)dst_bcsr->colind, (void *)src_bcsr->colind, nnz*sizeof(uint32_t));
  331. memcpy((void *)dst_bcsr->rowptr, (void *)src_bcsr->rowptr, (nrow+1)*sizeof(uint32_t));
  332. TRACE_DATA_COPY(src_node, dst_node, nnz*elemsize*r*c + (nnz+nrow+1)*sizeof(uint32_t));
  333. }
  334. int do_copy_bcsr_buffer_1_to_1(struct starpu_data_state_t *state, uint32_t src_node, uint32_t dst_node)
  335. {
  336. node_kind src_kind = get_node_kind(src_node);
  337. node_kind dst_kind = get_node_kind(dst_node);
  338. switch (dst_kind) {
  339. case RAM:
  340. switch (src_kind) {
  341. case RAM:
  342. /* RAM -> RAM */
  343. dummy_copy_ram_to_ram(state, src_node, dst_node);
  344. break;
  345. #ifdef USE_CUDA
  346. case CUDA_RAM:
  347. /* CUBLAS_RAM -> RAM */
  348. /* only the proper CUBLAS thread can initiate this ! */
  349. if (get_local_memory_node() == src_node)
  350. {
  351. copy_cublas_to_ram(state, src_node, dst_node);
  352. }
  353. else
  354. {
  355. post_data_request(state, src_node, dst_node);
  356. }
  357. break;
  358. #endif
  359. case SPU_LS:
  360. STARPU_ASSERT(0); // TODO
  361. break;
  362. case UNUSED:
  363. printf("error node %d UNUSED\n", src_node);
  364. default:
  365. assert(0);
  366. break;
  367. }
  368. break;
  369. #ifdef USE_CUDA
  370. case CUDA_RAM:
  371. switch (src_kind) {
  372. case RAM:
  373. /* RAM -> CUBLAS_RAM */
  374. /* only the proper CUBLAS thread can initiate this ! */
  375. STARPU_ASSERT(get_local_memory_node() == dst_node);
  376. copy_ram_to_cublas(state, src_node, dst_node);
  377. break;
  378. case CUDA_RAM:
  379. case SPU_LS:
  380. STARPU_ASSERT(0); // TODO
  381. break;
  382. case UNUSED:
  383. default:
  384. STARPU_ASSERT(0);
  385. break;
  386. }
  387. break;
  388. #endif
  389. case SPU_LS:
  390. STARPU_ASSERT(0); // TODO
  391. break;
  392. case UNUSED:
  393. default:
  394. assert(0);
  395. break;
  396. }
  397. return 0;
  398. }