Browse Source

Pass a fresh copy of input each time, it seems fftw overwrites it

Samuel Thibault 5 years ago
parent
commit
64c6440b1c
1 changed files with 10 additions and 3 deletions
  1. 10 3
      starpufft/tests/testx.c

+ 10 - 3
starpufft/tests/testx.c

@@ -2,7 +2,7 @@
  *
  * Copyright (C) 2010-2015,2017                           CNRS
  * Copyright (C) 2012,2013,2017                           Inria
- * Copyright (C) 2009-2012,2014                           Université de Bordeaux
+ * Copyright (C) 2009-2012,2014,2019                      Université de Bordeaux
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -166,10 +166,12 @@ int main(int argc, char *argv[])
 	bytes = size * sizeof(STARPUFFT(complex));
 #endif
 
-	STARPUFFT(complex) *in = STARPUFFT(malloc)(size * sizeof(*in));
+	STARPUFFT(complex) *in_orig = STARPUFFT(malloc)(size * sizeof(*in_orig));
 	starpu_srand48(0);
 	for (i = 0; i < size; i++)
-		in[i] = starpu_drand48() + I * starpu_drand48();
+		in_orig[i] = starpu_drand48() + I * starpu_drand48();
+
+	STARPUFFT(complex) *in = STARPUFFT(malloc)(size * sizeof(*in));
 
 	STARPUFFT(complex) *out = STARPUFFT(malloc)(size * sizeof(*out));
 
@@ -209,6 +211,7 @@ int main(int argc, char *argv[])
 	}
 
 #ifdef STARPU_HAVE_FFTW
+	memcpy(in, in_orig, size * sizeof(*in));
 	gettimeofday(&begin, NULL);
 	_FFTW(execute_dft)(fftw_plan, in, out_fftw);
 	gettimeofday(&end, NULL);
@@ -217,6 +220,7 @@ int main(int argc, char *argv[])
 	printf("FFTW took %2.2f ms (%2.2f MB/s)\n\n", timing/1000, bytes/timing);
 #endif
 #ifdef STARPU_USE_CUDA
+	memcpy(in, in_orig, size * sizeof(*in));
 	gettimeofday(&begin, NULL);
 	if (cufftExecC2C(cuda_plan, (cufftComplex*) in, (cufftComplex*) out_cuda, CUFFT_FORWARD) != CUFFT_SUCCESS)
 		printf("erf2\n");
@@ -228,6 +232,7 @@ int main(int argc, char *argv[])
 	printf("CUDA took %2.2f ms (%2.2f MB/s)\n\n", timing/1000, bytes/timing);
 #endif
 
+	memcpy(in, in_orig, size * sizeof(*in));
 	ret = STARPUFFT(execute)(plan, in, out);
 	if (ret == -1) return 77;
 	STARPUFFT(showstats)(stdout);
@@ -240,6 +245,7 @@ int main(int argc, char *argv[])
 #endif
 
 #if 1
+	memcpy(in, in_orig, size * sizeof(*in));
 	starpu_vector_data_register(&in_handle, STARPU_MAIN_RAM, (uintptr_t) in, size, sizeof(*in));
 	starpu_vector_data_register(&out_handle, STARPU_MAIN_RAM, (uintptr_t) out, size, sizeof(*out));
 
@@ -275,6 +281,7 @@ int main(int argc, char *argv[])
 #endif
 #endif
 
+	STARPUFFT(free)(in_orig);
 	STARPUFFT(free)(in);
 	STARPUFFT(free)(out);