Browse Source

Clarify the behaviour of starpu_wait_all_tasks with regenarable tasks. A
regenerable task is not considered finished until the "regenerate" field of the
task structure has been explicitely set to 0. To facilitate this, we provide a
new function called "starpu_get_current_task" which returns a pointer to the
task currently being executed.

Cédric Augonnet 15 years ago
parent
commit
e241148129

+ 5 - 0
include/starpu-task.h

@@ -214,6 +214,11 @@ int starpu_wait_all_tasks(void);
 
 void starpu_display_codelet_stats(struct starpu_codelet_t *cl);
 
+/* Return the task currently executed by the worker, or NULL if this is called
+ * either from a thread that is not a task or simply because there is no task
+ * being executed at the moment. */
+struct starpu_task *starpu_get_current_task(void);
+
 #ifdef __cplusplus
 }
 #endif

+ 9 - 4
src/core/jobs.c

@@ -107,10 +107,14 @@ void _starpu_handle_job_termination(starpu_job_t j)
 		/* so that we can check whether we are doing blocking calls
 		 * within the callback */
 		_starpu_set_local_worker_status(STATUS_CALLBACK);
+		
+		_starpu_set_current_task(task);
 
 		STARPU_TRACE_START_CALLBACK(j);
 		task->callback_func(task->callback_arg);
 		STARPU_TRACE_END_CALLBACK(j);
+		
+		_starpu_set_current_task(NULL);
 
 		_starpu_set_local_worker_status(STATUS_UNKNOWN);
 	}
@@ -122,7 +126,7 @@ void _starpu_handle_job_termination(starpu_job_t j)
 	int detach = task->detach;
 	int regenerate = task->regenerate;
 
-	if (!task->detach)
+	if (!detach)
 	{
 		/* we do not desallocate the job structure if some is going to
 		 * wait after the task */
@@ -138,16 +142,17 @@ void _starpu_handle_job_termination(starpu_job_t j)
 			starpu_task_destroy(task);
 	}
 
-	_starpu_decrement_nsubmitted_tasks();
-
 	if (regenerate)
 	{
 		STARPU_ASSERT(detach && !destroy && !task->synchronous);
 
 		/* We reuse the same job structure */
-		int ret = _starpu_submit_job(j);
+		int ret = _starpu_submit_job(j, 1);
 		STARPU_ASSERT(!ret);
 	}	
+	else {
+		_starpu_decrement_nsubmitted_tasks();
+	}
 }
 
 /* This function is called when a new task is submitted to StarPU 

+ 41 - 6
src/core/task.c

@@ -27,6 +27,14 @@ static pthread_cond_t submitted_cond = PTHREAD_COND_INITIALIZER;
 static pthread_mutex_t submitted_mutex = PTHREAD_MUTEX_INITIALIZER;
 static long int nsubmitted = 0;
 
+static void _starpu_increment_nsubmitted_tasks(void);
+
+/* This key stores the task currently handled by the thread, note that we
+ * cannot use the worker structure to store that information because it is
+ * possible that we have a task with a NULL codelet, which means its callback
+ * could be executed by a user thread as well. */
+static pthread_key_t current_task_key;
+
 void starpu_task_init(struct starpu_task *task)
 {
 	STARPU_ASSERT(task);
@@ -135,9 +143,12 @@ starpu_job_t _starpu_get_job_associated_to_task(struct starpu_task *task)
 	return (struct starpu_job_s *)task->starpu_private;
 }
 
-int _starpu_submit_job(starpu_job_t j)
+/* NB in case we have a regenerable task, it is possible that the job was
+ * already counted. */
+int _starpu_submit_job(starpu_job_t j, unsigned do_not_increment_nsubmitted)
 {
-	_starpu_increment_nsubmitted_tasks();
+	if (!do_not_increment_nsubmitted)
+		_starpu_increment_nsubmitted_tasks();
 
 	j->submitted = 1;
 
@@ -191,9 +202,8 @@ int starpu_submit_task(struct starpu_task *task)
 		j = (struct starpu_job_s *)task->starpu_private;
 	}
 
-	ret = _starpu_submit_job(j);
+	ret = _starpu_submit_job(j, 0);
 
-	/* XXX modify when we'll have starpu_wait_task */
 	if (is_sync)
 		_starpu_wait_job(j);
 
@@ -222,6 +232,11 @@ void starpu_display_codelet_stats(struct starpu_codelet_t *cl)
 	}
 }
 
+/*
+ * We wait for all the tasks that have already been submitted. Note that a
+ * regenerable is not considered finished until it was explicitely set as
+ * non-regenerale anymore (eg. from a callback).
+ */
 int starpu_wait_all_tasks(void)
 {
 	int res;
@@ -229,7 +244,6 @@ int starpu_wait_all_tasks(void)
 	if (STARPU_UNLIKELY(!_starpu_worker_may_perform_blocking_calls()))
 		return -EDEADLK;
 
-
 	pthread_mutex_lock(&submitted_mutex);
 
 	if (nsubmitted > 0)
@@ -257,9 +271,30 @@ void _starpu_decrement_nsubmitted_tasks(void)
 
 }
 
-void _starpu_increment_nsubmitted_tasks(void)
+static void _starpu_increment_nsubmitted_tasks(void)
 {
 	pthread_mutex_lock(&submitted_mutex);
 	nsubmitted++;
 	pthread_mutex_unlock(&submitted_mutex);
 }
+
+void _starpu_initialize_current_task_key(void)
+{
+	pthread_key_create(&current_task_key, NULL);
+}
+
+/* Return the task currently executed by the worker, or NULL if this is called
+ * either from a thread that is not a task or simply because there is no task
+ * being executed at the moment. */
+struct starpu_task *starpu_get_current_task(void)
+{
+	return pthread_getspecific(current_task_key);
+}
+
+void _starpu_set_current_task(struct starpu_task *task)
+{
+	if (task)
+		STARPU_ASSERT(pthread_getspecific(current_task_key) == NULL);
+
+	pthread_setspecific(current_task_key, task);
+}

+ 6 - 2
src/core/task.h

@@ -22,10 +22,14 @@
 
 /* In order to implement starpu_wait_all_tasks, we keep track of the number of
  * task currently submitted */
-void _starpu_increment_nsubmitted_tasks(void);
 void _starpu_decrement_nsubmitted_tasks(void);
 
-int _starpu_submit_job(starpu_job_t j);
+void _starpu_initialize_current_task_key(void);
+void _starpu_set_current_task(struct starpu_task *task);
+
+/* NB the second argument makes it possible to count regenerable tasks only
+ * once. */
+int _starpu_submit_job(starpu_job_t j, unsigned do_not_increment_nsubmitted);
 starpu_job_t _starpu_get_job_associated_to_task(struct starpu_task *task);
 
 #endif // __CORE_TASK_H__

+ 5 - 0
src/core/workers.c

@@ -19,6 +19,7 @@
 #include <common/config.h>
 #include <core/workers.h>
 #include <core/debug.h>
+#include <core/task.h>
 
 #ifdef __MINGW32__
 #include <windows.h>
@@ -266,6 +267,10 @@ int starpu_init(struct starpu_conf *user_conf)
 		return ret;
 	}
 
+	/* We need to store the current task handled by the different
+	 * threads */
+	_starpu_initialize_current_task_key();	
+
 	/* initialize the scheduler */
 
 	/* initialize the queue containing the jobs */

+ 5 - 0
src/drivers/cpu/driver_cpu.c

@@ -177,7 +177,12 @@ void *_starpu_cpu_worker(void *arg)
 			continue;
 		}
 
+		_starpu_set_current_task(j->task);
+
                 res = execute_job_on_cpu(j, cpu_arg);
+
+		_starpu_set_current_task(NULL);
+
 		if (res) {
 			switch (res) {
 				case -EAGAIN:

+ 1 - 0
src/drivers/cpu/driver_cpu.h

@@ -19,6 +19,7 @@
 
 #include <common/config.h>
 #include <core/jobs.h>
+#include <core/task.h>
 
 #include <core/perfmodel/perfmodel.h>
 #include <common/fxt.h>

+ 5 - 0
src/drivers/cuda/driver_cuda.c

@@ -259,8 +259,13 @@ void *_starpu_cuda_worker(void *arg)
 			continue;
 		}
 
+#warning TODO adapt to OpenCL !
+		_starpu_set_current_task(j->task);
+
 		res = execute_job_on_cuda(j, args);
 
+		_starpu_set_current_task(NULL);
+
 		if (res) {
 			switch (res) {
 				case -EAGAIN:

+ 1 - 0
src/drivers/cuda/driver_cuda.h

@@ -31,6 +31,7 @@
 #include <common/config.h>
 
 #include <core/jobs.h>
+#include <core/task.h>
 #include <datawizard/datawizard.h>
 #include <core/perfmodel/perfmodel.h>
 

+ 8 - 0
tests/Makefile.am

@@ -79,12 +79,14 @@ check_PROGRAMS += 				\
 	core/static_restartable_using_initializer\
 	core/static_restartable_tag		\
 	core/regenerate				\
+	core/wait_all_regenerable_tasks		\
 	core/subgraph_repeat			\
 	core/subgraph_repeat_regenerate		\
 	core/empty_task				\
 	core/empty_task_sync_point		\
 	core/tag-wait-api			\
 	core/task-wait-api			\
+	core/get_current_task			\
 	datawizard/sync_and_notify_data		\
 	datawizard/dsm_stress			\
 	datawizard/write_only_tmp_buffer	\
@@ -135,6 +137,9 @@ core_static_restartable_tag_SOURCES =		\
 core_regenerate_SOURCES =			\
 	core/regenerate.c
 
+core_wait_all_regenerable_tasks_SOURCES =	\
+	core/wait_all_regenerable_tasks.c
+
 core_subgraph_repeat_SOURCES =			\
 	core/subgraph_repeat.c
 
@@ -153,6 +158,9 @@ core_tag_wait_api_SOURCES =			\
 core_task_wait_api_SOURCES =			\
 	core/task-wait-api.c
 
+core_get_current_task_SOURCES =			\
+	core/get_current_task.c
+
 datawizard_dsm_stress_SOURCES =			\
 	datawizard/dsm_stress.c
 

+ 97 - 0
tests/core/get_current_task.c

@@ -0,0 +1,97 @@
+/*
+ * StarPU
+ * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or (at
+ * your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ *
+ * See the GNU Lesser General Public License in COPYING.LGPL for more details.
+ */
+
+#include <stdio.h>
+#include <unistd.h>
+#include <starpu.h>
+
+static unsigned ntasks = 65536;
+
+static void check_task_func(void *descr[], void *arg)
+{
+	/* We check that the returned task is valid from the codelet */
+	struct starpu_task *task = arg;
+	STARPU_ASSERT(task == starpu_get_current_task());
+}
+
+static void check_task_callback(void *arg)
+{
+	/* We check that the returned task is valid from the callback */
+	struct starpu_task *task = arg;
+	STARPU_ASSERT(task == starpu_get_current_task());
+}
+
+static struct starpu_codelet_t dummy_cl = {
+	.where = STARPU_CUDA|STARPU_CPU,
+	.cuda_func = check_task_func,
+	.cpu_func = check_task_func,
+	.model = NULL,
+	.nbuffers = 0
+};
+
+int main(int argc, char **argv)
+{
+	double timing;
+	struct timeval start;
+	struct timeval end;
+
+	starpu_init(NULL);
+
+	fprintf(stderr, "#tasks : %d\n", ntasks);
+
+	int i;
+	for (i = 0; i < ntasks; i++)
+	{
+		struct starpu_task *task = starpu_task_create();
+
+		/* We check if the function is valid from the codelet or from
+		 * the callback */
+		task->cl = &dummy_cl;
+		task->cl_arg = task;
+
+		task->callback_func = check_task_callback;
+		task->callback_arg = task;
+
+		int ret = starpu_submit_task(task);
+		STARPU_ASSERT(!ret);
+	}
+
+	starpu_wait_all_tasks();
+	
+	fprintf(stderr, "#empty tasks : %d\n", ntasks);
+
+	/* We repeat the same experiment with null codelets */
+
+	for (i = 0; i < ntasks; i++)
+	{
+		struct starpu_task *task = starpu_task_create();
+
+		task->cl = NULL;
+
+		/* We check if the function is valid from the callback */
+		task->callback_func = check_task_callback;
+		task->callback_arg = task;
+
+		int ret = starpu_submit_task(task);
+		STARPU_ASSERT(!ret);
+	}
+
+	starpu_wait_all_tasks();
+
+	starpu_shutdown();
+
+	return 0;
+}

+ 2 - 3
tests/core/regenerate.c

@@ -27,9 +27,9 @@ static unsigned completed = 0;
 static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
 static pthread_cond_t cond = PTHREAD_COND_INITIALIZER;
 
-static void callback(void *arg)
+static void callback(void *arg __attribute__ ((unused)))
 {
-	struct starpu_task *task = arg;
+	struct starpu_task *task = starpu_get_current_task();
 
 	cnt++;
 
@@ -89,7 +89,6 @@ int main(int argc, char **argv)
 	task.detach = 1;
 
 	task.callback_func = callback;
-	task.callback_arg = &task;
 
 	fprintf(stderr, "#tasks : %d\n", ntasks);
 

+ 117 - 0
tests/core/wait_all_regenerable_tasks.c

@@ -0,0 +1,117 @@
+/*
+ * StarPU
+ * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or (at
+ * your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ *
+ * See the GNU Lesser General Public License in COPYING.LGPL for more details.
+ */
+
+#include <sys/time.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <pthread.h>
+#include <starpu.h>
+
+static unsigned ntasks = 1024;
+
+static void callback(void *arg)
+{
+	struct starpu_task *task = starpu_get_current_task();
+
+	unsigned *cnt = arg;
+
+	(*cnt)++;
+
+	if (*cnt == ntasks)
+		task->regenerate = 0;
+}
+
+static void dummy_func(void *descr[] __attribute__ ((unused)), void *arg __attribute__ ((unused)))
+{
+}
+
+static starpu_codelet dummy_codelet = 
+{
+	.where = STARPU_CPU|STARPU_CUDA,
+	.cpu_func = dummy_func,
+	.cuda_func = dummy_func,
+	.model = NULL,
+	.nbuffers = 0
+};
+
+static void parse_args(int argc, char **argv)
+{
+	int c;
+	while ((c = getopt(argc, argv, "i:")) != -1)
+	switch(c) {
+		case 'i':
+			ntasks = atoi(optarg);
+			break;
+	}
+}
+
+#define K	128
+
+int main(int argc, char **argv)
+{
+	double timing;
+	struct timeval start;
+	struct timeval end;
+
+	parse_args(argc, argv);
+
+	starpu_init(NULL);
+
+	struct starpu_task task[K];
+	unsigned cnt[K];;
+
+	int i;
+	for (i = 0; i < K; i++)
+	{
+		starpu_task_init(&task[i]);
+		cnt[i] = 0;
+
+		task[i].cl = &dummy_codelet;
+		task[i].regenerate = 1;
+		task[i].detach = 1;
+
+		task[i].callback_func = callback;
+		task[i].callback_arg = &cnt[i];
+	}
+
+	fprintf(stderr, "#tasks : %d x %d tasks\n", K, ntasks);
+
+	gettimeofday(&start, NULL);
+	
+	for (i = 0; i < K; i++)
+		starpu_submit_task(&task[i]);
+
+	starpu_wait_all_tasks();
+
+	gettimeofday(&end, NULL);
+
+	/* Check that all the tasks have been properly executed */
+	unsigned total_cnt = 0;
+	for (i = 0; i < K; i++)
+		total_cnt += cnt[i];
+
+	STARPU_ASSERT(total_cnt == K*ntasks);
+
+	timing = (double)((end.tv_sec - start.tv_sec)*1000000
+				+ (end.tv_usec - start.tv_usec));
+
+	fprintf(stderr, "Total: %lf secs\n", timing/1000000);
+	fprintf(stderr, "Per task: %lf usecs\n", timing/(K*ntasks));
+
+	starpu_shutdown();
+
+	return 0;
+}