Browse Source

Use a task to increment the variable in the mpi-like tests

Cédric Augonnet 15 years ago
parent
commit
a3c7c3a755
2 changed files with 61 additions and 21 deletions
  1. 33 9
      tests/datawizard/mpi_like.c
  2. 28 12
      tests/datawizard/mpi_like_async.c

+ 33 - 9
tests/datawizard/mpi_like.c

@@ -15,6 +15,7 @@
  */
 
 #include <starpu.h>
+#include <errno.h>
 #include <pthread.h>
 
 #define NTHREADS	4
@@ -41,14 +42,40 @@ static struct thread_data problem_data[NTHREADS];
 /* We implement some ring transfer, every thread will try to receive a piece of
  * data from its neighbour and increment it before transmitting it to its
  * successor. */
+static void increment_handle_cpu_kernel(void *descr[], void *cl_arg __attribute__((unused)))
+{
+	unsigned *val = (unsigned *)STARPU_GET_VARIABLE_PTR(descr[0]);
+	*val += 1;
+}
+
+static starpu_codelet increment_handle_cl = {
+	.where = STARPU_CPU,
+	.cpu_func = increment_handle_cpu_kernel,
+	.nbuffers = 1
+};
 
 static void increment_handle(struct thread_data *thread_data)
 {
-	starpu_data_sync_with_mem(thread_data->handle, STARPU_RW);
-	thread_data->val++;
-	starpu_data_release_from_mem(thread_data->handle);
+	struct starpu_task *task = starpu_task_create();
+	task->cl = &increment_handle_cl;
+
+	task->buffers[0].handle = thread_data->handle;
+	task->buffers[0].mode = STARPU_RW;
+
+	task->cl_arg = thread_data;
+
+	task->destroy = 1;
+	task->detach = 0;
+
+	int ret = starpu_task_submit(task);
+	STARPU_ASSERT(!ret);
+
+	ret = starpu_task_wait(task);
+	STARPU_ASSERT(!ret);
 }
 
+
+
 static void recv_handle(struct thread_data *thread_data)
 {
 	starpu_data_sync_with_mem(thread_data->handle, STARPU_W);
@@ -99,8 +126,6 @@ static void *thread_func(void *arg)
 	struct thread_data *thread_data = arg;
 	unsigned index = thread_data->index;
 
-//	fprintf(stderr, "Hello from thread %d\n", thread_data->index);
-
 	starpu_variable_data_register(&thread_data->handle, 0, (uintptr_t)&thread_data->val, sizeof(unsigned));
 
 	for (iter = 0; iter < NITER; iter++)
@@ -119,10 +144,6 @@ static void *thread_func(void *arg)
 		}
 	}
 
-//	starpu_data_sync_with_mem(thread_data->handle, STARPU_R);
-//	fprintf(stderr, "Final value on thread %d: %d\n", thread_data->index, thread_data->val);
-//	starpu_data_release_from_mem(thread_data->handle);
-
 	return NULL;
 }
 
@@ -160,7 +181,10 @@ int main(int argc, char **argv)
 	starpu_data_handle last_handle = problem_data[NTHREADS - 1].handle;
 	starpu_data_sync_with_mem(last_handle, STARPU_R);
 	if (problem_data[NTHREADS - 1].val != (NTHREADS * NITER))
+	{
+		fprintf(stderr, "Final value : %d should be %d\n", problem_data[NTHREADS - 1].val, (NTHREADS * NITER));
 		STARPU_ABORT();
+	}
 	starpu_data_release_from_mem(last_handle);
 
 	starpu_shutdown();

+ 28 - 12
tests/datawizard/mpi_like_async.c

@@ -42,11 +42,31 @@ static struct thread_data problem_data[NTHREADS];
  * data from its neighbour and increment it before transmitting it to its
  * successor. */
 
-static void increment_handle_async(void *_thread_data)
+static void increment_handle_cpu_kernel(void *descr[], void *cl_arg __attribute__((unused)))
 {
-	struct thread_data *thread_data = _thread_data;
-	thread_data->val++;
-	starpu_data_release_from_mem(thread_data->handle);
+	unsigned *val = (unsigned *)STARPU_GET_VARIABLE_PTR(descr[0]);
+	*val += 1;
+}
+
+static starpu_codelet increment_handle_cl = {
+	.where = STARPU_CPU,
+	.cpu_func = increment_handle_cpu_kernel,
+	.nbuffers = 1
+};
+
+static void increment_handle_async(struct thread_data *thread_data)
+{
+	struct starpu_task *task = starpu_task_create();
+	task->cl = &increment_handle_cl;
+
+	task->buffers[0].handle = thread_data->handle;
+	task->buffers[0].mode = STARPU_RW;
+
+	task->detach = 1;
+	task->destroy = 1;
+
+	int ret = starpu_task_submit(task);
+	STARPU_ASSERT(!ret);
 }
 
 static void recv_handle_async(void *_thread_data)
@@ -113,10 +133,7 @@ static void *thread_func(void *arg)
 			);
 		}
 		
-		starpu_data_sync_with_mem_non_blocking(
-			thread_data->handle, STARPU_RW,
-			increment_handle_async, thread_data
-		);
+		increment_handle_async(thread_data);
 
 		if (!((index == (NTHREADS - 1)) && (iter == (NITER - 1))))
 		{
@@ -129,10 +146,6 @@ static void *thread_func(void *arg)
 
 	starpu_task_wait_for_all();
 
-//	starpu_data_sync_with_mem(thread_data->handle, STARPU_R);
-//	fprintf(stderr, "Final value on thread %d: %d\n", thread_data->index, thread_data->val);
-//	starpu_data_release_from_mem(thread_data->handle);
-
 	return NULL;
 }
 
@@ -170,7 +183,10 @@ int main(int argc, char **argv)
 	starpu_data_handle last_handle = problem_data[NTHREADS - 1].handle;
 	starpu_data_sync_with_mem(last_handle, STARPU_R);
 	if (problem_data[NTHREADS - 1].val != (NTHREADS * NITER))
+	{
+		fprintf(stderr, "Final value : %d should be %d\n", problem_data[NTHREADS - 1].val, (NTHREADS * NITER));
 		STARPU_ABORT();
+	}
 	starpu_data_release_from_mem(last_handle);
 
 	starpu_shutdown();