Browse Source

- add api for selecting the number of threads participating to a parallel section

Olivier Aumage 11 years ago
parent
commit
ec1acc1c6a

+ 1 - 0
include/starpu_openmp.h

@@ -52,6 +52,7 @@ struct starpu_omp_parallel_region_attr
 	unsigned  cl_arg_free;
 
 	int if_clause;
+	int num_threads;
 };
 
 struct starpu_omp_task_region_attr

+ 6 - 2
src/util/openmp_runtime_support.c

@@ -781,12 +781,16 @@ void starpu_omp_parallel_region(const struct starpu_omp_parallel_region_attr *at
 	if (attr->if_clause != 0)
 	{
 		const int max_threads = (int)starpu_cpu_worker_get_count();
-		if (generating_region->icvs.nthreads_var[0] < max_threads)
+		if (attr->num_threads > 0)
 		{
-			nb_threads = generating_region->icvs.nthreads_var[0];
+			nb_threads = attr->num_threads;
 		}
 		else
 		{
+			nb_threads = generating_region->icvs.nthreads_var[0];
+		}
+		if (nb_threads > max_threads)
+		{
 			nb_threads = max_threads;
 		}
 		if (nb_threads > 1 && generating_region->icvs.active_levels_var+1 > max_active_levels)

+ 7 - 2
src/util/openmp_runtime_support_omp_api.c

@@ -22,8 +22,13 @@
 
 void starpu_omp_set_num_threads(int threads)
 {
-	(void) threads;
-	__not_implemented__;
+	STARPU_ASSERT(threads > 0);
+	struct starpu_omp_task *task = _starpu_omp_get_task();
+	STARPU_ASSERT(task != NULL);
+	struct starpu_omp_region *region;
+	region = task->owner_region;
+	STARPU_ASSERT(region != NULL);
+	region->icvs.nthreads_var[0] = threads;
 }
 
 int starpu_omp_get_num_threads()