Selaa lähdekoodia

openmp: rework parsing of OMP_NUM_THREADS

Samuel Pitoiset 9 vuotta sitten
vanhempi
commit
3dd1fbd2b9
1 muutettua tiedostoa jossa 34 lisäystä ja 68 poistoa
  1. 34 68
      src/util/openmp_runtime_support_environment.c

+ 34 - 68
src/util/openmp_runtime_support_environment.c

@@ -330,56 +330,6 @@ static void convert_bind_string(const char *_str, int *bind_list, const int max_
 	free(str);
 }
 
-static void convert_num_threads_string(const char *_str, int *num_threads_list, const int max_levels)
-{
-	char *str = strdup(_str);
-	if (str == NULL)
-		_STARPU_ERROR("memory allocation failed\n");
-	remove_spaces(str);
-	if (str[0] == '\0')
-	{
-		free(str);
-		return;
-	}
-	enum { state_split, state_read };
-	int level = 0;
-	int i = 0;
-	int state = state_read;
-	while (1)
-	{
-		/* split a comma separated list of numerical items */
-		if (state == state_split)
-		{
-			if (str[i] == '\0')
-				break;
-			if (str[i] != ',')
-				_STARPU_ERROR("num_threads list parse error\n");
-			i++;
-			state = state_read;
-		}
-		/* read a numerical item */
-		else if (state == state_read)
-		{
-			char *endptr = NULL;
-			errno = 0;
-			int num_threads = (int)strtol(str+i, &endptr, 10);
-			if (errno != 0)
-				_STARPU_ERROR("num_threads list parse error, strtol failed with error %s\n", strerror(errno));
-			if (num_threads < 1)
-				_STARPU_ERROR("num_threads list invalid value\n");
-			num_threads_list[level] = num_threads;
-			level++;
-			if (level == max_levels)
-				break;
-			i = endptr - str;
-			state = state_split;
-		}
-		else
-			_STARPU_ERROR("invalid state in parsing num_threads list\n");
-	}
-	free(str);
-}
-
 static int convert_place_name(const char *str, size_t n)
 {
 	static const char *strings[] = { "threads", "cores", "sockets", NULL };
@@ -665,6 +615,39 @@ static void free_places(struct starpu_omp_place *places)
 	}
 }
 
+static void read_num_threads_var()
+{
+	const int max_levels = _initial_icv_values.max_active_levels_var + 1;
+	int *num_threads_list = NULL;
+	int level = 0;
+	char *env;
+
+	num_threads_list = calloc(max_levels, sizeof(*num_threads_list));
+	if (!num_threads_list)
+		_STARPU_ERROR("memory allocation failed\n");
+
+	env = starpu_getenv("OMP_NUM_THREADS");
+	if (env)
+	{
+		char *saveptr, *token;
+
+		token = strtok_r(env, ",", &saveptr);
+		for (; token != NULL; token = strtok_r(NULL, ",", &saveptr)) {
+			int value;
+
+			if (!read_int_var(token, &value))
+			{
+				fprintf(stderr, "StarPU: Invalid value for environment variable OMP_NUM_THREADS\n");
+				break;
+			}
+
+			num_threads_list[level++] = value;
+		}
+	}
+
+	_initial_icv_values.nthreads_var = num_threads_list;
+}
+
 static void read_omp_int_var(const char *name, int *icv)
 {
 	int ret, value;
@@ -739,24 +722,7 @@ static void read_omp_environment(void)
 		_initial_icv_values.bind_var = bind_list;
 	}
 
-	/* read OMP_NUM_THREADS */
-	{
-		int *num_threads_list = malloc((1+max_levels) * sizeof(*num_threads_list));
-		if (num_threads_list == NULL)
-			_STARPU_ERROR("memory allocation failed\n");
-		int level;
-		for (level = 0;level < max_levels+1;level++)
-		{
-			/* TODO: check what should be used as default value */
-			num_threads_list[level] = 0;
-		}
-		const char *env = starpu_getenv("OMP_NUM_THREADS");
-		if (env)
-		{
-			convert_num_threads_string(env, num_threads_list, max_levels);
-		}
-		_initial_icv_values.nthreads_var = num_threads_list;
-	}
+	read_num_threads_var();
 
 	/* read OMP_PLACES */
 	{