Browse Source

starpufft: do not use deprecated function starpu_free

Nathalie Furmento 4 years ago
parent
commit
a4e0ec5aed

+ 1 - 1
starpufft/include/starpufft.h

@@ -46,7 +46,7 @@ starpufft(plan) starpufft(plan_dft_r2c_1d)(int n, unsigned flags); \
 starpufft(plan) starpufft(plan_dft_c2r_1d)(int n, unsigned flags); \
 \
 void *starpufft(malloc)(size_t n); \
-void starpufft(free)(void *p); \
+void starpufft(free)(void *p, size_t dim);   \
 \
 int starpufft(execute)(starpufft(plan) p, void *in, void *out); \
 struct starpu_task *starpufft(start)(starpufft(plan) p, void *in, void *out); \

+ 7 - 6
starpufft/src/starpufftx.c

@@ -110,6 +110,7 @@ struct STARPUFFT(plan)
 
 	/* Buffers for codelets */
 	STARPUFFT(complex) *in, *twisted1, *fft1, *twisted2, *fft2, *out;
+	size_t twisted1_size, twisted2_size, fft1_size, fft2_size;
 
 	/* corresponding starpu DSM handles */
 	starpu_data_handle_t in_handle, *twisted1_handle, *fft1_handle, *twisted2_handle, *fft2_handle, out_handle;
@@ -406,10 +407,10 @@ STARPUFFT(destroy_plan)(STARPUFFT(plan) plan)
 
 		free(plan->n1);
 		free(plan->n2);
-		STARPUFFT(free)(plan->twisted1);
-		STARPUFFT(free)(plan->fft1);
-		STARPUFFT(free)(plan->twisted2);
-		STARPUFFT(free)(plan->fft2);
+		STARPUFFT(free)(plan->twisted1, plan->twisted1_size);
+		STARPUFFT(free)(plan->fft1, plan->fft1_size);
+		STARPUFFT(free)(plan->twisted2, plan->twisted2_size);
+		STARPUFFT(free)(plan->fft2, plan->fft2_size);
 	}
 	free(plan->n);
 	free(plan);
@@ -432,10 +433,10 @@ STARPUFFT(malloc)(size_t n)
 }
 
 void
-STARPUFFT(free)(void *p)
+STARPUFFT(free)(void *p, size_t dim)
 {
 #ifdef STARPU_USE_CUDA
-	starpu_free(p);
+	starpu_free_noflag(p, dim);
 #else
 #  ifdef STARPU_HAVE_FFTW
 	_FFTW(free)(p);

+ 15 - 8
starpufft/src/starpufftx1d.c

@@ -620,14 +620,21 @@ if (PARALLEL) {
 
 if (PARALLEL) {
 	/* Allocate buffers. */
-	plan->twisted1 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->twisted1));
-	memset(plan->twisted1, 0, plan->totsize * sizeof(*plan->twisted1));
-	plan->fft1 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->fft1));
-	memset(plan->fft1, 0, plan->totsize * sizeof(*plan->fft1));
-	plan->twisted2 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->twisted2));
-	memset(plan->twisted2, 0, plan->totsize * sizeof(*plan->twisted2));
-	plan->fft2 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->fft2));
-	memset(plan->fft2, 0, plan->totsize * sizeof(*plan->fft2));
+	plan->twisted1_size = plan->totsize * sizeof(*plan->twisted1);
+	plan->twisted1 = STARPUFFT(malloc)(plan->twisted1_size);
+	memset(plan->twisted1, 0, plan->twisted1_size);
+
+	plan->fft1_size = plan->totsize * sizeof(*plan->fft1);
+	plan->fft1 = STARPUFFT(malloc)(plan->fft1_size);
+	memset(plan->fft1, 0, plan->fft1_size);
+
+	plan->twisted2_size = plan->totsize * sizeof(*plan->twisted2);
+	plan->twisted2 = STARPUFFT(malloc)(plan->twisted2_size);
+	memset(plan->twisted2, 0, plan->twisted2_size);
+
+	plan->fft2_size = plan->totsize * sizeof(*plan->fft2);
+	plan->fft2 = STARPUFFT(malloc)(plan->fft2_size);
+	memset(plan->fft2, 0, plan->fft2_size);
 
 	/* Allocate handle arrays */
 	plan->twisted1_handle = malloc(plan->totsize1 * sizeof(*plan->twisted1_handle));

+ 15 - 8
starpufft/src/starpufftx2d.c

@@ -622,14 +622,21 @@ if (PARALLEL) {
 
 if (PARALLEL) {
 	/* Allocate buffers. */
-	plan->twisted1 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->twisted1));
-	memset(plan->twisted1, 0, plan->totsize * sizeof(*plan->twisted1));
-	plan->fft1 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->fft1));
-	memset(plan->fft1, 0, plan->totsize * sizeof(*plan->fft1));
-	plan->twisted2 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->twisted2));
-	memset(plan->twisted2, 0, plan->totsize * sizeof(*plan->twisted2));
-	plan->fft2 = STARPUFFT(malloc)(plan->totsize * sizeof(*plan->fft2));
-	memset(plan->fft2, 0, plan->totsize * sizeof(*plan->fft2));
+	plan->twisted1_size = plan->totsize * sizeof(*plan->twisted1);
+	plan->twisted1 = STARPUFFT(malloc)(plan->twisted1_size);
+	memset(plan->twisted1, 0, plan->twisted1_size);
+
+	plan->fft1_size = plan->totsize * sizeof(*plan->fft1);
+	plan->fft1 = STARPUFFT(malloc)(plan->fft1_size);
+	memset(plan->fft1, 0, plan->fft1_size);
+
+	plan->twisted2_size = plan->totsize * sizeof(*plan->twisted2);
+	plan->twisted2 = STARPUFFT(malloc)(plan->twisted2_size);
+	memset(plan->twisted2, 0, plan->twisted2_size);
+
+	plan->fft2_size = plan->totsize * sizeof(*plan->fft2);
+	plan->fft2 = STARPUFFT(malloc)(plan->fft2_size);
+	memset(plan->fft2, 0, plan->fft2_size);
 
 	/* Allocate handle arrays */
 	plan->twisted1_handle = malloc(plan->totsize1 * sizeof(*plan->twisted1_handle));