Procházet zdrojové kódy

- add if_clause for omp parallel region
- implement omp_get_num_threads/omp_get_thread_num

Olivier Aumage před 11 roky
rodič
revize
d5a4860212

+ 1 - 1
include/starpu_openmp.h

@@ -52,7 +52,7 @@ extern "C"
 extern int starpu_omp_init(void) __STARPU_OMP_NOTHROW;
 extern void starpu_omp_shutdown(void) __STARPU_OMP_NOTHROW;
 
-extern void starpu_omp_parallel_region(const struct starpu_codelet * const parallel_region_cl, starpu_data_handle_t *handles, void * const cl_arg, size_t cl_arg_size, unsigned cl_arg_free) __STARPU_OMP_NOTHROW;
+extern void starpu_omp_parallel_region(const struct starpu_codelet * const parallel_region_cl, starpu_data_handle_t *handles, void * const cl_arg, size_t cl_arg_size, unsigned cl_arg_free, int if_clause) __STARPU_OMP_NOTHROW;
 
 extern void starpu_omp_barrier(void) __STARPU_OMP_NOTHROW;
 

+ 14 - 2
src/util/openmp_runtime_support.c

@@ -56,6 +56,18 @@ static void _wake_up_locked_task(struct starpu_omp_task *task);
 static void wake_up_barrier(struct starpu_omp_region *parallel_region);
 static void starpu_omp_task_preempt(void);
 
+struct starpu_omp_thread *_starpu_omp_get_thread(void)
+{
+	struct starpu_omp_thread *thread = STARPU_PTHREAD_GETSPECIFIC(omp_thread_key);
+	return thread;
+}
+
+struct starpu_omp_task *_starpu_omp_get_task(void)
+{
+	struct starpu_omp_task *task = STARPU_PTHREAD_GETSPECIFIC(omp_task_key);
+	return task;
+}
+
 static void condition_init(struct starpu_omp_condition *condition)
 {
 	condition->contention_list_head = NULL;
@@ -668,7 +680,7 @@ void starpu_omp_shutdown(void)
 }
 
 void starpu_omp_parallel_region(const struct starpu_codelet * const _parallel_region_cl, starpu_data_handle_t *handles,
-		void * const cl_arg, size_t cl_arg_size, unsigned cl_arg_free)
+		void * const cl_arg, size_t cl_arg_size, unsigned cl_arg_free, int if_clause)
 {
 	struct starpu_omp_thread *master_thread = STARPU_PTHREAD_GETSPECIFIC(omp_thread_key);
 	struct starpu_omp_task *task = STARPU_PTHREAD_GETSPECIFIC(omp_task_key);
@@ -678,7 +690,7 @@ void starpu_omp_parallel_region(const struct starpu_codelet * const _parallel_re
 	/* TODO: compute the proper nb_threads and launch additional workers as needed.
 	 * for now, the level 1 parallel region spans all the threads
 	 * and level >= 2 parallel regions have only one thread */
-	int nb_threads = (region->level == 0)?starpu_cpu_worker_get_count():1;
+	int nb_threads = (if_clause != 0 && region->level == 0)?starpu_cpu_worker_get_count():1;
 
 	struct starpu_omp_region *new_region = 
 		create_omp_region_struct(region, _global_state.initial_device);

+ 2 - 0
src/util/openmp_runtime_support.h

@@ -349,6 +349,8 @@ extern double _starpu_omp_clock_ref;
  */
 void _starpu_omp_environment_init(void);
 void _starpu_omp_environment_exit(void);
+struct starpu_omp_thread *_starpu_omp_get_thread(void);
+struct starpu_omp_task *_starpu_omp_get_task(void);
 #endif // STARPU_OPENMP
 
 #endif // __OPENMP_RUNTIME_SUPPORT_H__

+ 28 - 6
src/util/openmp_runtime_support_omp_api.c

@@ -28,18 +28,40 @@ void starpu_omp_set_num_threads(int threads)
 
 int starpu_omp_get_num_threads()
 {
-	return starpu_cpu_worker_get_count();
+	struct starpu_omp_task *task = _starpu_omp_get_task();
+	struct starpu_omp_region *region;
+	if (task == NULL)
+		return 1;
+
+	region = task->owner_region;
+	return region->nb_threads;
 }
 
 int starpu_omp_get_thread_num()
 {
-	int tid = starpu_worker_get_id();
-	/* TODO: handle master thread case */
-	if (tid < 0)
+	struct starpu_omp_thread *thread = _starpu_omp_get_thread();
+	struct starpu_omp_task *task = _starpu_omp_get_task();
+	struct starpu_omp_region *region;
+	if (thread == NULL || task == NULL)
+		return 0;
+
+	region = task->owner_region;
+	if (thread == region->master_thread)
+		return 0;
+
+	struct starpu_omp_thread * region_thread;
+	int tid = 1;
+	for (region_thread  = starpu_omp_thread_list_begin(region->thread_list);
+			region_thread != starpu_omp_thread_list_end(region->thread_list);
+			region_thread  = starpu_omp_thread_list_next(region_thread))
 	{
-		_STARPU_ERROR("starpu_omp_get_thread_num: no worker associated to this thread\n");
+		if (thread == region_thread)
+		{
+			return tid;
+		}
+		tid++;
 	}
-	return tid;
+	_STARPU_ERROR("unrecognized omp thread\n");
 }
 
 int starpu_omp_get_max_threads()

+ 1 - 1
tests/openmp/parallel_01.c

@@ -59,7 +59,7 @@ static struct starpu_codelet parallel_region_cl =
 
 int
 main (int argc, char *argv[]) {
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif

+ 2 - 2
tests/openmp/parallel_02.c

@@ -67,7 +67,7 @@ void parallel_region_1_f(void *buffers[], void *args)
 	worker_id = starpu_worker_get_id();
 	printf("[tid %p] parallel region 1: task thread = %d\n", (void *)tid, worker_id);
 
-	starpu_omp_parallel_region(&parallel_region_2_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_2_cl, NULL, NULL, 0, 0, 1);
 }
 
 static struct starpu_codelet parallel_region_1_cl =
@@ -80,7 +80,7 @@ static struct starpu_codelet parallel_region_1_cl =
 
 int
 main (int argc, char *argv[]) {
-	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif

+ 2 - 2
tests/openmp/parallel_03.c

@@ -59,8 +59,8 @@ static struct starpu_codelet parallel_region_cl =
 
 int
 main (int argc, char *argv[]) {
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif

+ 1 - 1
tests/openmp/parallel_barrier_01.c

@@ -68,7 +68,7 @@ int
 main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif

+ 2 - 2
tests/openmp/parallel_critical_01.c

@@ -77,9 +77,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 2 - 2
tests/openmp/parallel_critical_inline_01.c

@@ -80,9 +80,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 2 - 2
tests/openmp/parallel_critical_named_01.c

@@ -87,9 +87,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 2 - 2
tests/openmp/parallel_critical_named_inline_01.c

@@ -80,9 +80,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 6 - 6
tests/openmp/parallel_for_01.c

@@ -197,27 +197,27 @@ static void check_array(void)
 int
 main (int argc, char *argv[]) {
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_2_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_2_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_3_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_3_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_4_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_4_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_5_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_5_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_6_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_6_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 	return 0;
 }

+ 1 - 1
tests/openmp/parallel_for_02.c

@@ -86,7 +86,7 @@ static struct starpu_codelet parallel_region_1_cl =
 
 int
 main (int argc, char *argv[]) {
-	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif

+ 6 - 6
tests/openmp/parallel_for_ordered_01.c

@@ -207,28 +207,28 @@ static void check_array(void)
 int
 main (int argc, char *argv[]) {
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 	return 0;
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_2_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_2_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_3_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_3_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_4_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_4_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_5_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_5_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 
 	clear_array();
-	starpu_omp_parallel_region(&parallel_region_6_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_6_cl, NULL, NULL, 0, 0, 1);
 	check_array();
 	return 0;
 }

+ 2 - 2
tests/openmp/parallel_master_01.c

@@ -77,9 +77,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 2 - 2
tests/openmp/parallel_master_inline_01.c

@@ -71,9 +71,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 1 - 1
tests/openmp/parallel_sections_01.c

@@ -102,7 +102,7 @@ static struct starpu_codelet parallel_region_1_cl =
 
 int
 main (int argc, char *argv[]) {
-	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_1_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif

+ 2 - 2
tests/openmp/parallel_single_inline_01.c

@@ -88,9 +88,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 2 - 2
tests/openmp/parallel_single_nowait_01.c

@@ -77,9 +77,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 2 - 2
tests/openmp/parallel_single_wait_01.c

@@ -77,9 +77,9 @@ main (int argc, char *argv[]) {
 	pthread_t tid;
 	tid = pthread_self();
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	printf("<main>\n");
 	return 0;
 }

+ 1 - 1
tests/openmp/task_01.c

@@ -90,7 +90,7 @@ static struct starpu_codelet parallel_region_cl =
 
 int
 main (int argc, char *argv[]) {
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif

+ 1 - 1
tests/openmp/taskgroup_01.c

@@ -102,7 +102,7 @@ static struct starpu_codelet parallel_region_cl =
 
 int
 main (int argc, char *argv[]) {
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif

+ 1 - 1
tests/openmp/taskwait_01.c

@@ -99,7 +99,7 @@ static struct starpu_codelet parallel_region_cl =
 
 int
 main (int argc, char *argv[]) {
-	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0);
+	starpu_omp_parallel_region(&parallel_region_cl, NULL, NULL, 0, 0, 1);
 	return 0;
 }
 #endif