/* starpu-audio-processing.c */
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <pthread.h>
#include <sys/types.h>
#include <sys/time.h>
#include <starpu.h>
#include <fftw3.h>
#ifdef USE_CUDA
#include <cufft.h>
#endif
  13. //#define SAVE_RAW 1
  14. #define DEFAULTINPUTFILE "input.wav"
  15. #define DEFAULTOUTPUTFILE "output.wav"
  16. #define NSAMPLES (256*1024)
  17. #define SAMPLERATE 44100
  18. static unsigned nsamples = NSAMPLES;
  19. /* This is a band filter, we want to stop everything that is not between LOWFREQ and HIGHFREQ*/
  20. /* LOWFREQ < i * SAMPLERATE / NSAMPLE */
  21. #define LOWFREQ 500U
  22. #define HIFREQ 800U
  23. static const size_t headersize = 37+9;
  24. static FILE *infile, *outfile;
  25. static FILE *infile_raw, *outfile_raw;
  26. static char *inputfilename = DEFAULTINPUTFILE;
  27. static char *outputfilename = DEFAULTOUTPUTFILE;
  28. static unsigned use_pin = 0;
  29. unsigned length_data;
  30. /* buffer containing input WAV data */
  31. float *A;
  32. starpu_data_handle A_handle;
  33. /* For performance evaluation */
  34. static struct timeval start;
  35. static struct timeval end;
  36. static unsigned task_per_worker[STARPU_NMAXWORKERS] = {0};
  37. /*
  38. * Functions to Manipulate WAV files
  39. */
  40. unsigned get_wav_data_bytes_length(FILE *file)
  41. {
  42. /* this is clearly suboptimal !! */
  43. fseek(file, headersize, SEEK_SET);
  44. unsigned cnt = 0;
  45. while (fgetc(file) != EOF)
  46. cnt++;
  47. return cnt;
  48. }
  49. void copy_wav_header(FILE *srcfile, FILE *dstfile)
  50. {
  51. unsigned char buffer[128];
  52. fseek(srcfile, 0, SEEK_SET);
  53. fseek(dstfile, 0, SEEK_SET);
  54. fread(buffer, 1, headersize, infile);
  55. fwrite(buffer, 1, headersize, outfile);
  56. }
  57. void read_16bit_wav(FILE *infile, unsigned size, float *arrayout, FILE *save_file)
  58. {
  59. int v;
  60. #if SAVE_RAW
  61. unsigned currentpos = 0;
  62. #endif
  63. /* we skip the header to only keep the data */
  64. fseek(infile, headersize, SEEK_SET);
  65. for (v=0;v<size;v++) {
  66. signed char val = (signed char)fgetc(infile);
  67. signed char val2 = (signed char)fgetc(infile);
  68. arrayout[v] = 256*val2 + val;
  69. #if SAVE_RAW
  70. fprintf(save_file, "%d %f\n", currentpos++, arrayout[v]);
  71. #endif
  72. }
  73. }
  74. /* we only write the data, not the header !*/
  75. void write_16bit_wav(FILE *outfile, unsigned size, float *arrayin, FILE *save_file)
  76. {
  77. int v;
  78. #if SAVE_RAW
  79. unsigned currentpos = 0;
  80. #endif
  81. /* we assume that the header is copied using copy_wav_header */
  82. fseek(outfile, headersize, SEEK_SET);
  83. for (v=0;v<size;v++) {
  84. signed char val = ((int)arrayin[v]) % 256;
  85. signed char val2 = ((int)arrayin[v]) / 256;
  86. fputc(val, outfile);
  87. fputc(val2, outfile);
  88. #if SAVE_RAW
  89. if (save_file)
  90. fprintf(save_file, "%d %f\n", currentpos++, arrayin[v]);
  91. #endif
  92. }
  93. }
  94. /*
  95. *
  96. * The actual kernels
  97. *
  98. */
  99. /* we don't reinitialize the CUFFT plan for every kernel, so we "cache" it */
  100. typedef struct {
  101. unsigned is_initialized;
  102. #ifdef USE_CUDA
  103. cufftHandle plan;
  104. cufftHandle inv_plan;
  105. cufftComplex *localout;
  106. #endif
  107. fftwf_complex *localout_cpu;
  108. float *Acopy;
  109. fftwf_plan plan_cpu;
  110. fftwf_plan inv_plan_cpu;
  111. } fft_plan_cache;
  112. static fft_plan_cache plans[STARPU_NMAXWORKERS];
  113. #ifdef USE_CUDA
  114. static void band_filter_kernel_gpu(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  115. {
  116. cufftResult cures;
  117. float *localA = (float *)descr[0].vector.ptr;
  118. cufftComplex *localout;
  119. int workerid = starpu_get_worker_id();
  120. /* initialize the plane only during the first iteration */
  121. if (!plans[workerid].is_initialized)
  122. {
  123. cures = cufftPlan1d(&plans[workerid].plan, nsamples, CUFFT_R2C, 1);
  124. STARPU_ASSERT(cures == CUFFT_SUCCESS);
  125. cures = cufftPlan1d(&plans[workerid].inv_plan, nsamples, CUFFT_C2R, 1);
  126. STARPU_ASSERT(cures == CUFFT_SUCCESS);
  127. cudaMalloc((void **)&plans[workerid].localout,
  128. nsamples*sizeof(cufftComplex));
  129. STARPU_ASSERT(plans[workerid].localout);
  130. plans[workerid].is_initialized = 1;
  131. }
  132. localout = plans[workerid].localout;
  133. /* FFT */
  134. cures = cufftExecR2C(plans[workerid].plan, localA, localout);
  135. STARPU_ASSERT(cures == CUFFT_SUCCESS);
  136. /* filter low freqs */
  137. unsigned lowfreq_index = (LOWFREQ*nsamples)/SAMPLERATE;
  138. cudaMemset(&localout[0], 0, lowfreq_index*sizeof(fftwf_complex));
  139. /* filter high freqs */
  140. unsigned hifreq_index = (HIFREQ*nsamples)/SAMPLERATE;
  141. cudaMemset(&localout[hifreq_index], nsamples/2, (nsamples/2 - hifreq_index)*sizeof(fftwf_complex));
  142. /* inverse FFT */
  143. cures = cufftExecC2R(plans[workerid].inv_plan, localout, localA);
  144. STARPU_ASSERT(cures == CUFFT_SUCCESS);
  145. /* FFTW does not normalize its output ! */
  146. cublasSscal (nsamples, 1.0f/nsamples, localA, 1);
  147. }
  148. #endif
  149. static pthread_mutex_t fftw_mutex = PTHREAD_MUTEX_INITIALIZER;
  150. static void band_filter_kernel_cpu(starpu_data_interface_t *descr, __attribute__((unused)) void *arg)
  151. {
  152. float *localA = (float *)descr[0].vector.ptr;
  153. int workerid = starpu_get_worker_id();
  154. /* initialize the plane only during the first iteration */
  155. if (!plans[workerid].is_initialized)
  156. {
  157. plans[workerid].localout_cpu = malloc(nsamples*sizeof(fftwf_complex));
  158. plans[workerid].Acopy = malloc(nsamples*sizeof(float));
  159. /* create plans, only "fftwf_execute" is thread safe in FFTW ... */
  160. pthread_mutex_lock(&fftw_mutex);
  161. plans[workerid].plan_cpu = fftwf_plan_dft_r2c_1d(nsamples,
  162. plans[workerid].Acopy,
  163. plans[workerid].localout_cpu,
  164. FFTW_ESTIMATE);
  165. plans[workerid].inv_plan_cpu = fftwf_plan_dft_c2r_1d(nsamples,
  166. plans[workerid].localout_cpu,
  167. plans[workerid].Acopy,
  168. FFTW_ESTIMATE);
  169. pthread_mutex_unlock(&fftw_mutex);
  170. plans[workerid].is_initialized = 1;
  171. }
  172. fftwf_complex *localout = plans[workerid].localout_cpu;
  173. /* copy data into the temporary buffer */
  174. memcpy(plans[workerid].Acopy, localA, nsamples*sizeof(float));
  175. /* FFT */
  176. fftwf_execute(plans[workerid].plan_cpu);
  177. /* filter low freqs */
  178. unsigned lowfreq_index = (LOWFREQ*nsamples)/SAMPLERATE;
  179. memset(&localout[0], 0, lowfreq_index*sizeof(fftwf_complex));
  180. /* filter high freqs */
  181. unsigned hifreq_index = (HIFREQ*nsamples)/SAMPLERATE;
  182. memset(&localout[hifreq_index], nsamples/2, (nsamples/2 - hifreq_index)*sizeof(fftwf_complex));
  183. /* inverse FFT */
  184. fftwf_execute(plans[workerid].inv_plan_cpu);
  185. /* copy data into the temporary buffer */
  186. memcpy(localA, plans[workerid].Acopy, nsamples*sizeof(float));
  187. /* FFTW does not normalize its output ! */
  188. /* TODO use BLAS ?*/
  189. int i;
  190. for (i = 0; i < nsamples; i++)
  191. localA[i] /= nsamples;
  192. }
  193. struct starpu_perfmodel_t band_filter_model = {
  194. .type = HISTORY_BASED,
  195. .symbol = "FFT_band_filter"
  196. };
  197. static starpu_codelet band_filter_cl = {
  198. .where = CORE|CUDA,
  199. #ifdef USE_CUDA
  200. .cuda_func = band_filter_kernel_gpu,
  201. #endif
  202. .core_func = band_filter_kernel_cpu,
  203. .model = &band_filter_model,
  204. .nbuffers = 1
  205. };
  206. void callback(void *arg)
  207. {
  208. /* do some accounting */
  209. int id = starpu_get_worker_id();
  210. task_per_worker[id]++;
  211. }
  212. void create_starpu_task(unsigned iter)
  213. {
  214. struct starpu_task *task = starpu_task_create();
  215. task->cl = &band_filter_cl;
  216. task->buffers[0].handle = get_sub_data(A_handle, 1, iter);
  217. task->buffers[0].mode = STARPU_RW;
  218. task->callback_func = callback;
  219. task->callback_arg = NULL;
  220. starpu_submit_task(task);
  221. }
  222. static void init_problem(void)
  223. {
  224. infile = fopen(inputfilename, "r");
  225. if (outputfilename)
  226. outfile = fopen(outputfilename, "w+");
  227. #if SAVE_RAW
  228. infile_raw = fopen("input.raw", "w");
  229. outfile_raw = fopen("output.raw", "w");
  230. #endif
  231. /* copy input's header into output WAV */
  232. if (outputfilename)
  233. copy_wav_header(infile, outfile);
  234. /* read length of input WAV's data */
  235. /* each element is 2 bytes long (16bits)*/
  236. length_data = get_wav_data_bytes_length(infile)/2;
  237. /* allocate a buffer to store the content of input file */
  238. if (use_pin)
  239. {
  240. starpu_malloc_pinned_if_possible((void **)&A, length_data*sizeof(float));
  241. }
  242. else {
  243. A = malloc(length_data*sizeof(float));
  244. }
  245. /* allocate working buffer (this could be done online, but we'll keep it simple) */
  246. //starpu_malloc_pinned_if_possible((void **)&outdata, length_data*sizeof(fftwf_complex));
  247. /* read input data into buffer "A" */
  248. read_16bit_wav(infile, length_data, A, infile_raw);
  249. }
  250. static void parse_args(int argc, char **argv)
  251. {
  252. int i;
  253. for (i = 1; i < argc; i++) {
  254. if (strcmp(argv[i], "-h") == 0) {
  255. fprintf(stderr, "Usage: %s [-pin] [-nsamples block_size] [-i input.wav] [-o output.wav | -no-output] [-h]\n", argv[0]);
  256. exit(-1);
  257. }
  258. if (strcmp(argv[i], "-i") == 0) {
  259. inputfilename = argv[++i];;
  260. }
  261. if (strcmp(argv[i], "-o") == 0) {
  262. outputfilename = argv[++i];;
  263. }
  264. if (strcmp(argv[i], "-no-output") == 0) {
  265. outputfilename = NULL;;
  266. }
  267. /* block size */
  268. if (strcmp(argv[i], "-nsamples") == 0) {
  269. char *argptr;
  270. nsamples = strtol(argv[++i], &argptr, 10);
  271. }
  272. if (strcmp(argv[i], "-pin") == 0) {
  273. use_pin = 1;
  274. }
  275. }
  276. }
  277. int main(int argc, char **argv)
  278. {
  279. unsigned iter;
  280. parse_args(argc, argv);
  281. fprintf(stderr, "Reading input data\n");
  282. init_problem();
  283. unsigned niter = length_data/nsamples;
  284. fprintf(stderr, "input: %s\noutput: %s\n#chunks %d\n", inputfilename, outputfilename, niter);
  285. /* launch StarPU */
  286. starpu_init(NULL);
  287. starpu_register_vector_data(&A_handle, 0, (uintptr_t)A, niter*nsamples, sizeof(float));
  288. starpu_filter f =
  289. {
  290. .filter_func = starpu_block_filter_func_vector,
  291. .filter_arg = niter
  292. };
  293. starpu_partition_data(A_handle, &f);
  294. for (iter = 0; iter < niter; iter++)
  295. starpu_data_set_wb_mask(get_sub_data(A_handle, 1, iter), 1<<0);
  296. gettimeofday(&start, NULL);
  297. for (iter = 0; iter < niter; iter++)
  298. {
  299. create_starpu_task(iter);
  300. }
  301. starpu_wait_all_tasks();
  302. gettimeofday(&end, NULL);
  303. double timing = (double)((end.tv_sec - start.tv_sec)*1000000 + (end.tv_usec - start.tv_usec));
  304. fprintf(stderr, "Computation took %2.2f ms\n", timing/1000);
  305. int worker;
  306. for (worker = 0; worker < STARPU_NMAXWORKERS; worker++)
  307. {
  308. if (task_per_worker[worker])
  309. {
  310. char name[32];
  311. starpu_get_worker_name(worker, name, 32);
  312. unsigned long bytes = nsamples*sizeof(float)*task_per_worker[worker];
  313. fprintf(stderr, "\t%s -> %2.2f MB\t%2.2f\tMB/s\t%2.2f %%\n", name, (1.0*bytes)/(1024*1024), bytes/timing, (100.0*task_per_worker[worker])/niter);
  314. }
  315. }
  316. if (outputfilename)
  317. fprintf(stderr, "Writing output data\n");
  318. /* make sure that the output is in RAM before quitting StarPU */
  319. starpu_unpartition_data(A_handle, 0);
  320. starpu_delete_data(A_handle);
  321. /* we are done ! */
  322. starpu_shutdown();
  323. fclose(infile);
  324. if (outputfilename)
  325. {
  326. write_16bit_wav(outfile, length_data, A, outfile_raw);
  327. fclose(outfile);
  328. }
  329. #if SAVE_RAW
  330. fclose(infile_raw);
  331. fclose(outfile_raw);
  332. #endif
  333. return 0;
  334. }