starpufftx1d.c 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. /*
  2. * StarPU
  3. * Copyright (C) INRIA 2009 (see AUTHORS file)
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU Lesser General Public License as published by
  7. * the Free Software Foundation; either version 2.1 of the License, or (at
  8. * your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR in PARTICULAR PURPOSE.
  13. *
  14. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  15. */
  16. #define DIV_1D 64
  17. #define STEP_TAG_1D(plan, step, i) _STEP_TAG(plan, step, i)
  18. #ifdef USE_CUDA
  19. /* Twist the full vector into a n2 chunk */
  20. static void
  21. STARPUFFT(twist1_1d_kernel_gpu)(void *descr[], void *_args)
  22. {
  23. struct STARPUFFT(args) *args = _args;
  24. STARPUFFT(plan) plan = args->plan;
  25. int i = args->i;
  26. int n1 = plan->n1[0];
  27. int n2 = plan->n2[0];
  28. _cufftComplex * restrict in = (_cufftComplex *)GET_VECTOR_PTR(descr[0]);
  29. _cufftComplex * restrict twisted1 = (_cufftComplex *)GET_VECTOR_PTR(descr[1]);
  30. cudaStream_t stream = STARPUFFT(get_local_stream)(plan, starpu_get_worker_id());
  31. STARPUFFT(cuda_twist1_1d_host)(in, twisted1, i, n1, n2, stream);
  32. cudaStreamSynchronize(stream);
  33. }
  34. /* Perform an n2 fft */
  35. static void
  36. STARPUFFT(fft1_1d_kernel_gpu)(void *descr[], void *_args)
  37. {
  38. struct STARPUFFT(args) *args = _args;
  39. STARPUFFT(plan) plan = args->plan;
  40. int i = args->i;
  41. int n2 = plan->n2[0];
  42. cufftResult cures;
  43. _cufftComplex * restrict in = (_cufftComplex *)GET_VECTOR_PTR(descr[0]);
  44. _cufftComplex * restrict out = (_cufftComplex *)GET_VECTOR_PTR(descr[1]);
  45. const _cufftComplex * restrict roots = (_cufftComplex *)GET_VECTOR_PTR(descr[2]);
  46. int workerid = starpu_get_worker_id();
  47. cudaStream_t stream;
  48. if (!plan->plans[workerid].initialized1) {
  49. cures = cufftPlan1d(&plan->plans[workerid].plan1_cuda, n2, _CUFFT_C2C, 1);
  50. stream = STARPUFFT(get_local_stream)(plan, workerid);
  51. cufftSetStream(plan->plans[workerid].plan1_cuda, stream);
  52. STARPU_ASSERT(cures == CUFFT_SUCCESS);
  53. plan->plans[workerid].initialized1 = 1;
  54. }
  55. stream = plan->plans[workerid].stream;
  56. cures = _cufftExecC2C(plan->plans[workerid].plan1_cuda, in, out, plan->sign == -1 ? CUFFT_FORWARD : CUFFT_INVERSE);
  57. STARPU_ASSERT(cures == CUFFT_SUCCESS);
  58. STARPUFFT(cuda_twiddle_1d_host)(out, roots, n2, i, stream);
  59. cudaStreamSynchronize(plan->plans[workerid].stream);
  60. }
  61. static void
  62. STARPUFFT(fft2_1d_kernel_gpu)(void *descr[], void *_args)
  63. {
  64. struct STARPUFFT(args) *args = _args;
  65. STARPUFFT(plan) plan = args->plan;
  66. int n1 = plan->n1[0];
  67. int n2 = plan->n2[0];
  68. int n3 = n2/DIV_1D;
  69. cufftResult cures;
  70. _cufftComplex * restrict in = (_cufftComplex *)GET_VECTOR_PTR(descr[0]);
  71. _cufftComplex * restrict out = (_cufftComplex *)GET_VECTOR_PTR(descr[1]);
  72. int workerid = starpu_get_worker_id();
  73. if (!plan->plans[workerid].initialized2) {
  74. cures = cufftPlan1d(&plan->plans[workerid].plan2_cuda, n1, _CUFFT_C2C, n3);
  75. cudaStream_t stream = STARPUFFT(get_local_stream)(plan, workerid);
  76. cufftSetStream(plan->plans[workerid].plan2_cuda, stream);
  77. STARPU_ASSERT(cures == CUFFT_SUCCESS);
  78. plan->plans[workerid].initialized2 = 1;
  79. }
  80. /* NOTE using batch support */
  81. cures = _cufftExecC2C(plan->plans[workerid].plan2_cuda, in, out, plan->sign == -1 ? CUFFT_FORWARD : CUFFT_INVERSE);
  82. STARPU_ASSERT(cures == CUFFT_SUCCESS);
  83. cudaStreamSynchronize(plan->plans[workerid].stream);
  84. }
  85. #endif
  86. /* Twist the full vector into a n2 chunk */
  87. static void
  88. STARPUFFT(twist1_1d_kernel_cpu)(void *descr[], void *_args)
  89. {
  90. struct STARPUFFT(args) *args = _args;
  91. STARPUFFT(plan) plan = args->plan;
  92. int i = args->i;
  93. int j;
  94. int n1 = plan->n1[0];
  95. int n2 = plan->n2[0];
  96. STARPUFFT(complex) * restrict in = (STARPUFFT(complex) *)GET_VECTOR_PTR(descr[0]);
  97. STARPUFFT(complex) * restrict twisted1 = (STARPUFFT(complex) *)GET_VECTOR_PTR(descr[1]);
  98. //printf("twist1 %d %g\n", i, (double) cabs(plan->in[i]));
  99. for (j = 0; j < n2; j++)
  100. twisted1[j] = in[i+j*n1];
  101. }
  102. #ifdef HAVE_FFTW
  103. /* Perform an n2 fft */
  104. static void
  105. STARPUFFT(fft1_1d_kernel_cpu)(void *descr[], void *_args)
  106. {
  107. struct STARPUFFT(args) *args = _args;
  108. STARPUFFT(plan) plan = args->plan;
  109. int i = args->i;
  110. int j;
  111. int n2 = plan->n2[0];
  112. int workerid = starpu_get_worker_id();
  113. const STARPUFFT(complex) * restrict twisted1 = (STARPUFFT(complex) *)GET_VECTOR_PTR(descr[0]);
  114. STARPUFFT(complex) * restrict fft1 = (STARPUFFT(complex) *)GET_VECTOR_PTR(descr[1]);
  115. _fftw_complex * restrict worker_in1 = (STARPUFFT(complex) *)plan->plans[workerid].in1;
  116. _fftw_complex * restrict worker_out1 = (STARPUFFT(complex) *)plan->plans[workerid].out1;
  117. //printf("fft1 %d %g\n", i, (double) cabs(twisted1[0]));
  118. memcpy(worker_in1, twisted1, plan->totsize2 * sizeof(*worker_in1));
  119. _FFTW(execute)(plan->plans[workerid].plan1_cpu);
  120. for (j = 0; j < n2; j++)
  121. fft1[j] = worker_out1[j] * plan->roots[0][i*j];
  122. }
  123. #endif
  124. /* Twist the full vector into a package of n2/DIV_1D (n1) chunks */
  125. static void
  126. STARPUFFT(twist2_1d_kernel_cpu)(void *descr[], void *_args)
  127. {
  128. struct STARPUFFT(args) *args = _args;
  129. STARPUFFT(plan) plan = args->plan;
  130. int jj = args->jj; /* between 0 and DIV_1D */
  131. int jjj; /* beetween 0 and n3 */
  132. int i;
  133. int n1 = plan->n1[0];
  134. int n2 = plan->n2[0];
  135. int n3 = n2/DIV_1D;
  136. STARPUFFT(complex) * restrict twisted2 = (STARPUFFT(complex) *)GET_VECTOR_PTR(descr[0]);
  137. //printf("twist2 %d %g\n", jj, (double) cabs(plan->fft1[jj]));
  138. for (jjj = 0; jjj < n3; jjj++) {
  139. int j = jj * n3 + jjj;
  140. for (i = 0; i < n1; i++)
  141. twisted2[jjj*n1+i] = plan->fft1[i*n2+j];
  142. }
  143. }
  144. #ifdef HAVE_FFTW
  145. /* Perform n2/DIV_1D (n1) ffts */
  146. static void
  147. STARPUFFT(fft2_1d_kernel_cpu)(void *descr[], void *_args)
  148. {
  149. struct STARPUFFT(args) *args = _args;
  150. STARPUFFT(plan) plan = args->plan;
  151. //int jj = args->jj;
  152. int workerid = starpu_get_worker_id();
  153. const STARPUFFT(complex) * restrict twisted2 = (STARPUFFT(complex) *)GET_VECTOR_PTR(descr[0]);
  154. STARPUFFT(complex) * restrict fft2 = (STARPUFFT(complex) *)GET_VECTOR_PTR(descr[1]);
  155. //printf("fft2 %d %g\n", jj, (double) cabs(twisted2[plan->totsize4-1]));
  156. _fftw_complex * restrict worker_in2 = (STARPUFFT(complex) *)plan->plans[workerid].in2;
  157. _fftw_complex * restrict worker_out2 = (STARPUFFT(complex) *)plan->plans[workerid].out2;
  158. memcpy(worker_in2, twisted2, plan->totsize4 * sizeof(*worker_in2));
  159. _FFTW(execute)(plan->plans[workerid].plan2_cpu);
  160. /* no twiddle */
  161. memcpy(fft2, worker_out2, plan->totsize4 * sizeof(*worker_out2));
  162. }
  163. #endif
  164. /* Spread the package of n2/DIV_1D (n1) chunks into the full vector */
  165. static void
  166. STARPUFFT(twist3_1d_kernel_cpu)(void *descr[], void *_args)
  167. {
  168. struct STARPUFFT(args) *args = _args;
  169. STARPUFFT(plan) plan = args->plan;
  170. int jj = args->jj; /* between 0 and DIV_1D */
  171. int jjj; /* beetween 0 and n3 */
  172. int i;
  173. int n1 = plan->n1[0];
  174. int n2 = plan->n2[0];
  175. int n3 = n2/DIV_1D;
  176. const STARPUFFT(complex) * restrict fft2 = (STARPUFFT(complex) *)GET_VECTOR_PTR(descr[0]);
  177. //printf("twist3 %d %g\n", jj, (double) cabs(fft2[0]));
  178. for (jjj = 0; jjj < n3; jjj++) {
  179. int j = jj * n3 + jjj;
  180. for (i = 0; i < n1; i++)
  181. plan->out[i*n2+j] = fft2[jjj*n1+i];
  182. }
  183. }
  184. static struct starpu_perfmodel_t STARPUFFT(twist1_1d_model) = {
  185. .type = HISTORY_BASED,
  186. .symbol = TYPE"twist1_1d"
  187. };
  188. static struct starpu_perfmodel_t STARPUFFT(fft1_1d_model) = {
  189. .type = HISTORY_BASED,
  190. .symbol = TYPE"fft1_1d"
  191. };
  192. static struct starpu_perfmodel_t STARPUFFT(twist2_1d_model) = {
  193. .type = HISTORY_BASED,
  194. .symbol = TYPE"twist2_1d"
  195. };
  196. static struct starpu_perfmodel_t STARPUFFT(fft2_1d_model) = {
  197. .type = HISTORY_BASED,
  198. .symbol = TYPE"fft2_1d"
  199. };
  200. static struct starpu_perfmodel_t STARPUFFT(twist3_1d_model) = {
  201. .type = HISTORY_BASED,
  202. .symbol = TYPE"twist3_1d"
  203. };
  204. static starpu_codelet STARPUFFT(twist1_1d_codelet) = {
  205. .where =
  206. #ifdef USE_CUDA
  207. CUDA|
  208. #endif
  209. CORE,
  210. #ifdef USE_CUDA
  211. .cuda_func = STARPUFFT(twist1_1d_kernel_gpu),
  212. #endif
  213. .core_func = STARPUFFT(twist1_1d_kernel_cpu),
  214. .model = &STARPUFFT(twist1_1d_model),
  215. .nbuffers = 2
  216. };
  217. static starpu_codelet STARPUFFT(fft1_1d_codelet) = {
  218. .where =
  219. #ifdef USE_CUDA
  220. CUDA|
  221. #endif
  222. #ifdef HAVE_FFTW
  223. CORE|
  224. #endif
  225. 0,
  226. #ifdef USE_CUDA
  227. .cuda_func = STARPUFFT(fft1_1d_kernel_gpu),
  228. #endif
  229. #ifdef HAVE_FFTW
  230. .core_func = STARPUFFT(fft1_1d_kernel_cpu),
  231. #endif
  232. .model = &STARPUFFT(fft1_1d_model),
  233. .nbuffers = 3
  234. };
  235. static starpu_codelet STARPUFFT(twist2_1d_codelet) = {
  236. .where = CORE,
  237. .core_func = STARPUFFT(twist2_1d_kernel_cpu),
  238. .model = &STARPUFFT(twist2_1d_model),
  239. .nbuffers = 1
  240. };
  241. static starpu_codelet STARPUFFT(fft2_1d_codelet) = {
  242. .where =
  243. #ifdef USE_CUDA
  244. CUDA|
  245. #endif
  246. #ifdef HAVE_FFTW
  247. CORE|
  248. #endif
  249. 0,
  250. #ifdef USE_CUDA
  251. .cuda_func = STARPUFFT(fft2_1d_kernel_gpu),
  252. #endif
  253. #ifdef HAVE_FFTW
  254. .core_func = STARPUFFT(fft2_1d_kernel_cpu),
  255. #endif
  256. .model = &STARPUFFT(fft2_1d_model),
  257. .nbuffers = 2
  258. };
  259. static starpu_codelet STARPUFFT(twist3_1d_codelet) = {
  260. .where = CORE,
  261. .core_func = STARPUFFT(twist3_1d_kernel_cpu),
  262. .model = &STARPUFFT(twist3_1d_model),
  263. .nbuffers = 1
  264. };
  265. STARPUFFT(plan)
  266. STARPUFFT(plan_dft_1d)(int n, int sign, unsigned flags)
  267. {
  268. int workerid;
  269. int n1 = DIV_1D;
  270. int n2 = n / n1;
  271. int n3;
  272. int z;
  273. struct starpu_task *task;
  274. /*
  275. * Simple strategy:
  276. *
  277. * - twist1: twist input in n1 (n2) chunks
  278. * - fft1: perform n1 (n2) ffts
  279. * - twist2: twist into n2 (n1) chunks distributed in
  280. * DIV_1D groups
  281. * - fft2: perform DIV_1D times n3 (n1) ffts
  282. * - twist3: twist back into output
  283. */
  284. #ifdef USE_CUDA
  285. /* cufft 1D limited to 8M elements */
  286. while (n2 > 8 << 20) {
  287. n1 *= 2;
  288. n2 /= 2;
  289. }
  290. #endif
  291. STARPU_ASSERT(n == n1*n2);
  292. STARPU_ASSERT(n1 < (1ULL << I_BITS));
  293. /* distribute the n2 second ffts into DIV_1D packages */
  294. n3 = n2 / DIV_1D;
  295. STARPU_ASSERT(n2 == n3*DIV_1D);
  296. /* TODO: flags? Automatically set FFTW_MEASURE on calibration? */
  297. STARPU_ASSERT(flags == 0);
  298. STARPUFFT(plan) plan = malloc(sizeof(*plan));
  299. memset(plan, 0, sizeof(*plan));
  300. plan->number = STARPU_ATOMIC_ADD(&starpufft_last_plan_number, 1) - 1;
  301. /* 4bit limitation in the tag space */
  302. STARPU_ASSERT(plan->number < (1ULL << NUMBER_BITS));
  303. plan->dim = 1;
  304. plan->n = malloc(plan->dim * sizeof(*plan->n));
  305. plan->n[0] = n;
  306. check_dims(plan);
  307. plan->n1 = malloc(plan->dim * sizeof(*plan->n1));
  308. plan->n1[0] = n1;
  309. plan->n2 = malloc(plan->dim * sizeof(*plan->n2));
  310. plan->n2[0] = n2;
  311. plan->totsize = n;
  312. plan->totsize1 = n1;
  313. plan->totsize2 = n2;
  314. plan->totsize3 = DIV_1D;
  315. plan->totsize4 = plan->totsize / plan->totsize3;
  316. plan->type = C2C;
  317. plan->sign = sign;
  318. compute_roots(plan);
  319. /* Initialize per-worker working set */
  320. for (workerid = 0; workerid < starpu_get_worker_count(); workerid++) {
  321. switch (starpu_get_worker_type(workerid)) {
  322. case STARPU_CORE_WORKER:
  323. #ifdef HAVE_FFTW
  324. /* first fft plan: one n2 fft */
  325. plan->plans[workerid].in1 = _FFTW(malloc)(plan->totsize2 * sizeof(_fftw_complex));
  326. memset(plan->plans[workerid].in1, 0, plan->totsize2 * sizeof(_fftw_complex));
  327. plan->plans[workerid].out1 = _FFTW(malloc)(plan->totsize2 * sizeof(_fftw_complex));
  328. memset(plan->plans[workerid].out1, 0, plan->totsize2 * sizeof(_fftw_complex));
  329. plan->plans[workerid].plan1_cpu = _FFTW(plan_dft_1d)(n2, plan->plans[workerid].in1, plan->plans[workerid].out1, sign, _FFTW_FLAGS);
  330. STARPU_ASSERT(plan->plans[workerid].plan1_cpu);
  331. /* second fft plan: n3 n1 ffts */
  332. plan->plans[workerid].in2 = _FFTW(malloc)(plan->totsize4 * sizeof(_fftw_complex));
  333. memset(plan->plans[workerid].in2, 0, plan->totsize4 * sizeof(_fftw_complex));
  334. plan->plans[workerid].out2 = _FFTW(malloc)(plan->totsize4 * sizeof(_fftw_complex));
  335. memset(plan->plans[workerid].out2, 0, plan->totsize4 * sizeof(_fftw_complex));
  336. plan->plans[workerid].plan2_cpu = _FFTW(plan_many_dft)(plan->dim,
  337. plan->n1, n3,
  338. /* input */ plan->plans[workerid].in2, NULL, 1, plan->totsize1,
  339. /* output */ plan->plans[workerid].out2, NULL, 1, plan->totsize1,
  340. sign, _FFTW_FLAGS);
  341. STARPU_ASSERT(plan->plans[workerid].plan2_cpu);
  342. #else
  343. #warning libstarpufft can not work correctly without libfftw3
  344. #endif
  345. break;
  346. case STARPU_CUDA_WORKER:
  347. #ifdef USE_CUDA
  348. plan->plans[workerid].initialized1 = 0;
  349. plan->plans[workerid].initialized2 = 0;
  350. #endif
  351. break;
  352. default:
  353. STARPU_ABORT();
  354. break;
  355. }
  356. }
  357. plan->twisted1 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->twisted1));
  358. memset(plan->twisted1, 0, plan->totsize * sizeof(*plan->twisted1));
  359. plan->fft1 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->fft1));
  360. memset(plan->fft1, 0, plan->totsize * sizeof(*plan->fft1));
  361. plan->twisted2 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->twisted2));
  362. memset(plan->twisted2, 0, plan->totsize * sizeof(*plan->twisted2));
  363. plan->fft2 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->fft2));
  364. memset(plan->fft2, 0, plan->totsize * sizeof(*plan->fft2));
  365. plan->twisted1_handle = malloc(plan->totsize1 * sizeof(*plan->twisted1_handle));
  366. plan->fft1_handle = malloc(plan->totsize1 * sizeof(*plan->fft1_handle));
  367. plan->twisted2_handle = malloc(plan->totsize3 * sizeof(*plan->twisted2_handle));
  368. plan->fft2_handle = malloc(plan->totsize3 * sizeof(*plan->fft2_handle));
  369. plan->twist1_tasks = malloc(plan->totsize1 * sizeof(*plan->twist1_tasks));
  370. plan->fft1_tasks = malloc(plan->totsize1 * sizeof(*plan->fft1_tasks));
  371. plan->twist2_tasks = malloc(plan->totsize3 * sizeof(*plan->twist2_tasks));
  372. plan->fft2_tasks = malloc(plan->totsize3 * sizeof(*plan->fft2_tasks));
  373. plan->twist3_tasks = malloc(plan->totsize3 * sizeof(*plan->twist3_tasks));
  374. plan->fft1_args = malloc(plan->totsize1 * sizeof(*plan->fft1_args));
  375. plan->fft2_args = malloc(plan->totsize3 * sizeof(*plan->fft2_args));
  376. /* Create first-round tasks */
  377. for (z = 0; z < plan->totsize1; z++) {
  378. int i = z;
  379. #define STEP_TAG(step) STEP_TAG_1D(plan, step, i)
  380. plan->fft1_args[z].plan = plan;
  381. plan->fft1_args[z].i = i;
  382. /* Register (n2) chunks */
  383. starpu_register_vector_data(&plan->twisted1_handle[z], 0, (uintptr_t) &plan->twisted1[z*plan->totsize2], plan->totsize2, sizeof(*plan->twisted1));
  384. starpu_register_vector_data(&plan->fft1_handle[z], 0, (uintptr_t) &plan->fft1[z*plan->totsize2], plan->totsize2, sizeof(*plan->fft1));
  385. /* We'll need it on the CPU for the second twist anyway */
  386. starpu_data_set_wb_mask(plan->fft1_handle[z], 1<<0);
  387. /* Create twist1 task */
  388. plan->twist1_tasks[z] = task = starpu_task_create();
  389. task->cl = &STARPUFFT(twist1_1d_codelet);
  390. //task->buffers[0].handle = to be filled at execution
  391. task->buffers[0].mode = STARPU_R;
  392. task->buffers[1].handle = plan->twisted1_handle[z];
  393. task->buffers[1].mode = STARPU_W;
  394. task->cl_arg = &plan->fft1_args[z];
  395. task->tag_id = STEP_TAG(TWIST1);
  396. task->use_tag = 1;
  397. task->detach = 1;
  398. task->destroy = 0;
  399. /* Tell that fft1 depends on twisted1 */
  400. starpu_tag_declare_deps(STEP_TAG(FFT1),
  401. 1, STEP_TAG(TWIST1));
  402. /* Create FFT1 task */
  403. plan->fft1_tasks[z] = task = starpu_task_create();
  404. task->cl = &STARPUFFT(fft1_1d_codelet);
  405. task->buffers[0].handle = plan->twisted1_handle[z];
  406. task->buffers[0].mode = STARPU_R;
  407. task->buffers[1].handle = plan->fft1_handle[z];
  408. task->buffers[1].mode = STARPU_W;
  409. task->buffers[2].handle = plan->roots_handle[0];
  410. task->buffers[2].mode = STARPU_R;
  411. task->cl_arg = &plan->fft1_args[z];
  412. task->tag_id = STEP_TAG(FFT1);
  413. task->use_tag = 1;
  414. task->detach = 1;
  415. task->destroy = 0;
  416. /* Tell that to be done with first step we need to have
  417. * finished this fft1 */
  418. starpu_tag_declare_deps(STEP_TAG_1D(plan, JOIN, 0),
  419. 1, STEP_TAG(FFT1));
  420. #undef STEP_TAG
  421. }
  422. /* Create join task */
  423. plan->join_task = task = starpu_task_create();
  424. task->cl = NULL;
  425. task->tag_id = STEP_TAG_1D(plan, JOIN, 0);
  426. task->use_tag = 1;
  427. task->detach = 1;
  428. task->destroy = 0;
  429. /* Create second-round tasks */
  430. for (z = 0; z < plan->totsize3; z++) {
  431. int jj = z;
  432. #define STEP_TAG(step) STEP_TAG_1D(plan, step, jj)
  433. plan->fft2_args[z].plan = plan;
  434. plan->fft2_args[z].jj = jj;
  435. /* Register n3 (n1) chunks */
  436. starpu_register_vector_data(&plan->twisted2_handle[z], 0, (uintptr_t) &plan->twisted2[z*plan->totsize4], plan->totsize4, sizeof(*plan->twisted2));
  437. starpu_register_vector_data(&plan->fft2_handle[z], 0, (uintptr_t) &plan->fft2[z*plan->totsize4], plan->totsize4, sizeof(*plan->fft2));
  438. /* We'll need it on the CPU for the last twist anyway */
  439. starpu_data_set_wb_mask(plan->fft2_handle[z], 1<<0);
  440. /* Tell that twisted2 depends on the whole first step to be
  441. * done */
  442. starpu_tag_declare_deps(STEP_TAG(TWIST2),
  443. 1, STEP_TAG_1D(plan, JOIN, 0));
  444. /* Create twist2 task */
  445. plan->twist2_tasks[z] = task = starpu_task_create();
  446. task->cl = &STARPUFFT(twist2_1d_codelet);
  447. task->buffers[0].handle = plan->twisted2_handle[z];
  448. task->buffers[0].mode = STARPU_W;
  449. task->cl_arg = &plan->fft2_args[z];
  450. task->tag_id = STEP_TAG(TWIST2);
  451. task->use_tag = 1;
  452. task->detach = 1;
  453. task->destroy = 0;
  454. /* Tell that fft2 depends on twisted2 */
  455. starpu_tag_declare_deps(STEP_TAG(FFT2),
  456. 1, STEP_TAG(TWIST2));
  457. /* Create FFT2 task */
  458. plan->fft2_tasks[z] = task = starpu_task_create();
  459. task->cl = &STARPUFFT(fft2_1d_codelet);
  460. task->buffers[0].handle = plan->twisted2_handle[z];
  461. task->buffers[0].mode = STARPU_R;
  462. task->buffers[1].handle = plan->fft2_handle[z];
  463. task->buffers[1].mode = STARPU_W;
  464. task->cl_arg = &plan->fft2_args[z];
  465. task->tag_id = STEP_TAG(FFT2);
  466. task->use_tag = 1;
  467. task->detach = 1;
  468. task->destroy = 0;
  469. /* Tell that twist3 depends on fft2 */
  470. starpu_tag_declare_deps(STEP_TAG(TWIST3),
  471. 1, STEP_TAG(FFT2));
  472. /* Create twist3 tasks */
  473. plan->twist3_tasks[z] = task = starpu_task_create();
  474. task->cl = &STARPUFFT(twist3_1d_codelet);
  475. task->buffers[0].handle = plan->fft2_handle[z];
  476. task->buffers[0].mode = STARPU_R;
  477. task->cl_arg = &plan->fft2_args[z];
  478. task->tag_id = STEP_TAG(TWIST3);
  479. task->use_tag = 1;
  480. task->detach = 1;
  481. task->destroy = 0;
  482. /* Tell that to be completely finished we need to have finished this twisted3 */
  483. starpu_tag_declare_deps(STEP_TAG_1D(plan, END, 0),
  484. 1, STEP_TAG(TWIST3));
  485. #undef STEP_TAG
  486. }
  487. /* Create end task */
  488. plan->end_task = task = starpu_task_create();
  489. task->cl = NULL;
  490. task->tag_id = STEP_TAG_1D(plan, END, 0);
  491. task->use_tag = 1;
  492. task->detach = 1;
  493. task->destroy = 0;
  494. return plan;
  495. }
  496. static starpu_tag_t
  497. STARPUFFT(start1dC2C)(STARPUFFT(plan) plan)
  498. {
  499. STARPU_ASSERT(plan->type == C2C);
  500. int z;
  501. for (z=0; z < plan->totsize1; z++) {
  502. starpu_submit_task(plan->twist1_tasks[z]);
  503. starpu_submit_task(plan->fft1_tasks[z]);
  504. }
  505. starpu_submit_task(plan->join_task);
  506. for (z=0; z < plan->totsize3; z++) {
  507. starpu_submit_task(plan->twist2_tasks[z]);
  508. starpu_submit_task(plan->fft2_tasks[z]);
  509. starpu_submit_task(plan->twist3_tasks[z]);
  510. }
  511. starpu_submit_task(plan->end_task);
  512. return STEP_TAG_1D(plan, END, 0);
  513. }
  514. static void
  515. STARPUFFT(free_1d_tags)(STARPUFFT(plan) plan)
  516. {
  517. unsigned i;
  518. int n1 = plan->n1[0];
  519. for (i = 0; i < n1; i++) {
  520. starpu_tag_remove(STEP_TAG_1D(plan, TWIST1, i));
  521. starpu_tag_remove(STEP_TAG_1D(plan, FFT1, i));
  522. }
  523. starpu_tag_remove(STEP_TAG_1D(plan, JOIN, 0));
  524. for (i = 0; i < DIV_1D; i++) {
  525. starpu_tag_remove(STEP_TAG_1D(plan, TWIST2, i));
  526. starpu_tag_remove(STEP_TAG_1D(plan, FFT2, i));
  527. starpu_tag_remove(STEP_TAG_1D(plan, TWIST3, i));
  528. }
  529. starpu_tag_remove(STEP_TAG_1D(plan, END, 0));
  530. }