Explorar o código

Use execute_dft, which permits to avoid memcpy

Samuel Thibault %!s(int64=13) %!d(string=hai) anos
pai
achega
209a8a5cdd

+ 0 - 7
examples/starpufft/starpufftx.c

@@ -93,9 +93,6 @@ struct STARPUFFT(plan) {
 #ifdef STARPU_HAVE_FFTW
 		/* FFTW plans */
 		_fftw_plan plan1_cpu, plan2_cpu;
-		/* Buffers used by the plans above */
-		_fftw_complex *in1, *out1;
-		_fftw_complex *in2, *out2;
 #endif
 	} plans[STARPU_NMAXWORKERS];
 
@@ -231,11 +228,7 @@ STARPUFFT(destroy_plan)(STARPUFFT(plan) plan)
 		switch (starpu_worker_get_type(workerid)) {
 		case STARPU_CPU_WORKER:
 #ifdef STARPU_HAVE_FFTW
-			_FFTW(free)(plan->plans[workerid].in1);
-			_FFTW(free)(plan->plans[workerid].out1);
 			_FFTW(destroy_plan)(plan->plans[workerid].plan1_cpu);
-			_FFTW(free)(plan->plans[workerid].in2);
-			_FFTW(free)(plan->plans[workerid].out2);
 			_FFTW(destroy_plan)(plan->plans[workerid].plan2_cpu);
 #endif
 			break;

+ 9 - 27
examples/starpufft/starpufftx1d.c

@@ -202,20 +202,16 @@ STARPUFFT(fft1_1d_kernel_cpu)(void *descr[], void *_args)
 
 	task_per_worker[workerid]++;
 
-	const STARPUFFT(complex) * restrict twisted1 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[0]);
+	STARPUFFT(complex) * restrict twisted1 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[0]);
 	STARPUFFT(complex) * restrict fft1 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[1]);
 
-	_fftw_complex * restrict worker_in1 = (STARPUFFT(complex) *)plan->plans[workerid].in1;
-	_fftw_complex * restrict worker_out1 = (STARPUFFT(complex) *)plan->plans[workerid].out1;
-
 	/* printf("fft1 %d %g\n", i, (double) cabs(twisted1[0])); */
 
-	memcpy(worker_in1, twisted1, plan->totsize2 * sizeof(*worker_in1));
-	_FFTW(execute)(plan->plans[workerid].plan1_cpu);
+	_FFTW(execute_dft)(plan->plans[workerid].plan1_cpu, twisted1, fft1);
 
-	/* twiddle while copying from fftw output buffer to fft1 buffer */
+	/* twiddle fft1 buffer */
 	for (j = 0; j < n2; j++)
-		fft1[j] = worker_out1[j] * plan->roots[0][i*j];
+		fft1[j] = fft1[j] * plan->roots[0][i*j];
 }
 #endif
 
@@ -260,18 +256,12 @@ STARPUFFT(fft2_1d_kernel_cpu)(void *descr[], void *_args)
 
 	task_per_worker[workerid]++;
 
-	const STARPUFFT(complex) * restrict twisted2 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[0]);
+	STARPUFFT(complex) * restrict twisted2 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[0]);
 	STARPUFFT(complex) * restrict fft2 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[1]);
 
 	/* printf("fft2 %d %g\n", jj, (double) cabs(twisted2[plan->totsize4-1])); */
 
-	_fftw_complex * restrict worker_in2 = (STARPUFFT(complex) *)plan->plans[workerid].in2;
-	_fftw_complex * restrict worker_out2 = (STARPUFFT(complex) *)plan->plans[workerid].out2;
-
-	memcpy(worker_in2, twisted2, plan->totsize4 * sizeof(*worker_in2));
-	_FFTW(execute)(plan->plans[workerid].plan2_cpu);
-	/* no twiddle */
-	memcpy(fft2, worker_out2, plan->totsize4 * sizeof(*worker_out2));
+	_FFTW(execute_dft)(plan->plans[workerid].plan2_cpu, twisted2, fft2);
 }
 #endif
 
@@ -473,22 +463,14 @@ STARPUFFT(plan_dft_1d)(int n, int sign, unsigned flags)
 			/* first fft plan: one fft of size n2.
 			 * FFTW imposes that buffer pointers are known at
 			 * planning time. */
-			plan->plans[workerid].in1 = _FFTW(malloc)(plan->totsize2 * sizeof(_fftw_complex));
-			memset(plan->plans[workerid].in1, 0, plan->totsize2 * sizeof(_fftw_complex));
-			plan->plans[workerid].out1 = _FFTW(malloc)(plan->totsize2 * sizeof(_fftw_complex));
-			memset(plan->plans[workerid].out1, 0, plan->totsize2 * sizeof(_fftw_complex));
-			plan->plans[workerid].plan1_cpu = _FFTW(plan_dft_1d)(n2, plan->plans[workerid].in1, plan->plans[workerid].out1, sign, _FFTW_FLAGS);
+			plan->plans[workerid].plan1_cpu = _FFTW(plan_dft_1d)(n2, NULL, NULL, sign, _FFTW_FLAGS);
 			STARPU_ASSERT(plan->plans[workerid].plan1_cpu);
 
 			/* second fft plan: n3 ffts of size n1 */
-			plan->plans[workerid].in2 = _FFTW(malloc)(plan->totsize4 * sizeof(_fftw_complex));
-			memset(plan->plans[workerid].in2, 0, plan->totsize4 * sizeof(_fftw_complex));
-			plan->plans[workerid].out2 = _FFTW(malloc)(plan->totsize4 * sizeof(_fftw_complex));
-			memset(plan->plans[workerid].out2, 0, plan->totsize4 * sizeof(_fftw_complex));
 			plan->plans[workerid].plan2_cpu = _FFTW(plan_many_dft)(plan->dim,
 					plan->n1, n3,
-					/* input */ plan->plans[workerid].in2, NULL, 1, plan->totsize1,
-					/* output */ plan->plans[workerid].out2, NULL, 1, plan->totsize1,
+					NULL, NULL, 1, plan->totsize1,
+					NULL, NULL, 1, plan->totsize1,
 					sign, _FFTW_FLAGS);
 			STARPU_ASSERT(plan->plans[workerid].plan2_cpu);
 #else

+ 8 - 26
examples/starpufft/starpufftx2d.c

@@ -182,19 +182,15 @@ STARPUFFT(fft1_2d_kernel_cpu)(void *descr[], void *_args)
 
 	task_per_worker[workerid]++;
 
-	const STARPUFFT(complex) *twisted1 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[0]);
+	STARPUFFT(complex) *twisted1 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[0]);
 	STARPUFFT(complex) *fft1 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[1]);
 
-	_fftw_complex * restrict worker_in1 = (STARPUFFT(complex) *)plan->plans[workerid].in1;
-	_fftw_complex * restrict worker_out1 = (STARPUFFT(complex) *)plan->plans[workerid].out1;
-
 	/* printf("fft1 %d %d %g\n", i, j, (double) cabs(twisted1[0])); */
 
-	memcpy(worker_in1, twisted1, plan->totsize2 * sizeof(*worker_in1));
-	_FFTW(execute)(plan->plans[workerid].plan1_cpu);
+	_FFTW(execute_dft)(plan->plans[workerid].plan1_cpu, twisted1, fft1);
 	for (k = 0; k < n2; k++)
 		for (l = 0; l < m2; l++)
-			fft1[k*m2 + l] = worker_out1[k*m2 + l] * plan->roots[0][i*k] * plan->roots[1][j*l];
+			fft1[k*m2 + l] = fft1[k*m2 + l] * plan->roots[0][i*k] * plan->roots[1][j*l];
 }
 #endif
 
@@ -243,18 +239,12 @@ STARPUFFT(fft2_2d_kernel_cpu)(void *descr[], void *_args)
 
 	task_per_worker[workerid]++;
 
-	const STARPUFFT(complex) *twisted2 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[0]);
+	STARPUFFT(complex) *twisted2 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[0]);
 	STARPUFFT(complex) *fft2 = (STARPUFFT(complex) *)STARPU_VECTOR_GET_PTR(descr[1]);
 
 	/* printf("fft2 %d %d %g\n", kk, ll, (double) cabs(twisted2[plan->totsize4-1])); */
 
-	_fftw_complex * restrict worker_in2 = (STARPUFFT(complex) *)plan->plans[workerid].in2;
-	_fftw_complex * restrict worker_out2 = (STARPUFFT(complex) *)plan->plans[workerid].out2;
-
-	memcpy(worker_in2, twisted2, plan->totsize4 * sizeof(*worker_in2));
-	_FFTW(execute)(plan->plans[workerid].plan2_cpu);
-	/* no twiddle */
-	memcpy(fft2, worker_out2, plan->totsize4 * sizeof(*worker_out2));
+	_FFTW(execute_dft)(plan->plans[workerid].plan2_cpu, twisted2, fft2);
 }
 #endif
 
@@ -474,22 +464,14 @@ STARPUFFT(plan_dft_2d)(int n, int m, int sign, unsigned flags)
 		case STARPU_CPU_WORKER:
 #ifdef STARPU_HAVE_FFTW
 			/* first fft plan: one n2*m2 fft */
-			plan->plans[workerid].in1 = _FFTW(malloc)(plan->totsize2 * sizeof(_fftw_complex));
-			memset(plan->plans[workerid].in1, 0, plan->totsize2 * sizeof(_fftw_complex));
-			plan->plans[workerid].out1 = _FFTW(malloc)(plan->totsize2 * sizeof(_fftw_complex));
-			memset(plan->plans[workerid].out1, 0, plan->totsize2 * sizeof(_fftw_complex));
-			plan->plans[workerid].plan1_cpu = _FFTW(plan_dft_2d)(n2, m2, plan->plans[workerid].in1, plan->plans[workerid].out1, sign, _FFTW_FLAGS);
+			plan->plans[workerid].plan1_cpu = _FFTW(plan_dft_2d)(n2, m2, NULL, NULL, sign, _FFTW_FLAGS);
 			STARPU_ASSERT(plan->plans[workerid].plan1_cpu);
 
 			/* second fft plan: n3*m3 n1*m1 ffts */
-			plan->plans[workerid].in2 = _FFTW(malloc)(plan->totsize4 * sizeof(_fftw_complex));
-			memset(plan->plans[workerid].in2, 0, plan->totsize4 * sizeof(_fftw_complex));
-			plan->plans[workerid].out2 = _FFTW(malloc)(plan->totsize4 * sizeof(_fftw_complex));
-			memset(plan->plans[workerid].out2, 0, plan->totsize4 * sizeof(_fftw_complex));
 			plan->plans[workerid].plan2_cpu = _FFTW(plan_many_dft)(plan->dim,
 					plan->n1, n3*m3,
-					/* input */ plan->plans[workerid].in2, NULL, 1, plan->totsize1,
-					/* output */ plan->plans[workerid].out2, NULL, 1, plan->totsize1,
+					NULL, NULL, 1, plan->totsize1,
+					NULL, NULL, 1, plan->totsize1,
 					sign, _FFTW_FLAGS);
 			STARPU_ASSERT(plan->plans[workerid].plan2_cpu);
 #else