浏览代码

Cleanup the SpMV example a little, and make sure that the codelet structure is properly initialized.

Cédric Augonnet 15 年之前
父节点
当前提交
f81bb7fa8d
共有 1 个文件被更改,包括 11 次插入34 次删除
  1. 11 34
      examples/spmv/dw_spmv.c

+ 11 - 34
examples/spmv/dw_spmv.c

@@ -24,7 +24,6 @@ struct timeval start;
 struct timeval end;
 
 unsigned nblocks = 1;
-unsigned remainingtasks = -1;
 
 #ifdef STARPU_USE_CUDA
 
@@ -53,8 +52,6 @@ void spmv_kernel_cuda(void *descr[], void *args)
 
 #endif // STARPU_USE_CUDA
 
-
-sem_t sem;
 uint32_t size = 4194304;
 
 starpu_data_handle sparse_matrix;
@@ -211,33 +208,10 @@ void create_data(void)
 
 }
 
-void init_problem_callback(void *arg)
-{
-	unsigned *remaining = arg;
-
-
-	unsigned val = STARPU_ATOMIC_ADD(remaining, -1);
-
-	printf("callback %d remaining \n", val);
-	if ( val == 0 )
-	{
-		printf("DONE ...\n");
-		gettimeofday(&end, NULL);
-
-		starpu_unpartition_data(sparse_matrix, 0);
-		starpu_unpartition_data(vector_out, 0);
-
-		sem_post(&sem);
-	}
-}
-
-
 void call_spmv_codelet_filters(void)
 {
 
-	remainingtasks = nblocks;
-
-	starpu_codelet *cl = malloc(sizeof(starpu_codelet));
+	starpu_codelet *cl = calloc(1, sizeof(starpu_codelet));
 
 	/* partition the data along a block distribution */
 	starpu_filter csr_f, vector_f;
@@ -255,6 +229,7 @@ void call_spmv_codelet_filters(void)
 	cl->cuda_func = spmv_kernel_cuda;
 #endif
 	cl->nbuffers = 3;
+	cl->model = NULL;
 
 	gettimeofday(&start, NULL);
 
@@ -263,8 +238,8 @@ void call_spmv_codelet_filters(void)
 	{
 		struct starpu_task *task = starpu_task_create();
 
-		task->callback_func = init_problem_callback;
-		task->callback_arg = &remainingtasks;
+		task->callback_func = NULL;
+
 		task->cl = cl;
 		task->cl_arg = NULL;
 	
@@ -277,6 +252,13 @@ void call_spmv_codelet_filters(void)
 	
 		starpu_submit_task(task);
 	}
+
+	starpu_wait_all_tasks();
+
+	gettimeofday(&end, NULL);
+
+	starpu_unpartition_data(sparse_matrix, 0);
+	starpu_unpartition_data(vector_out, 0);
 }
 
 void init_problem(void)
@@ -306,13 +288,8 @@ int main(__attribute__ ((unused)) int argc,
 	/* start the runtime */
 	starpu_init(NULL);
 
-	sem_init(&sem, 0, 0U);
-
 	init_problem();
 
-	sem_wait(&sem);
-	sem_destroy(&sem);
-
 	print_results();
 
 	double timing = (double)((end.tv_sec - start.tv_sec)*1000000 + (end.tv_usec - start.tv_usec));