Explorar el Código

mpi lb: fixes

Nathalie Furmento hace 8 años
padre
commit
b708310412

+ 1 - 2
mpi/include/starpu_mpi_lb.h

@@ -26,13 +26,12 @@ struct starpu_mpi_lb_conf
 {
 	void (*get_neighbors)(int **neighbor_ids, int *nneighbors);
 	void (*get_data_unit_to_migrate)(starpu_data_handle_t **handle_unit, int *nhandles, int dst_node);
-	const char *name;
 };
 
 /* Inits the load balancer's environment with the load policy provided by the
  * user
  */
-void starpu_mpi_lb_init(struct starpu_mpi_lb_conf *);
+void starpu_mpi_lb_init(const char *lb_policy_name, struct starpu_mpi_lb_conf *);
 void starpu_mpi_lb_shutdown();
 
 #ifdef __cplusplus

+ 28 - 38
mpi/src/load_balancer/load_balancer.c

@@ -26,15 +26,16 @@
 #include "policy/load_balancer_policy.h"
 
 static struct load_balancer_policy *defined_policy = NULL;
-static void (*saved_post_exec_hook)(struct starpu_task *task, unsigned sched_ctx_id) = NULL;
+typedef void (*_post_exec_hook_func_t)(struct starpu_task *task, unsigned sched_ctx_id);
+static _post_exec_hook_func_t saved_post_exec_hook[STARPU_NMAX_SCHED_CTXS];
 
 static void post_exec_hook_wrapper(struct starpu_task *task, unsigned sched_ctx_id)
 {
 	//fprintf(stderr,"I am called ! \n");
 	if (defined_policy && defined_policy->finished_task_entry_point)
 		defined_policy->finished_task_entry_point();
-	if (saved_post_exec_hook)
-		saved_post_exec_hook(task, sched_ctx_id);
+	if (saved_post_exec_hook[sched_ctx_id])
+		saved_post_exec_hook[sched_ctx_id](task, sched_ctx_id);
 }
 
 static struct load_balancer_policy *predefined_policies[] =
@@ -43,13 +44,13 @@ static struct load_balancer_policy *predefined_policies[] =
 	NULL
 };
 
-void starpu_mpi_lb_init(struct starpu_mpi_lb_conf *itf)
+void starpu_mpi_lb_init(const char *lb_policy_name, struct starpu_mpi_lb_conf *itf)
 {
 	int ret;
 
 	const char *policy_name = starpu_getenv("STARPU_MPI_LB");
-	if (!policy_name && itf)
-		policy_name = itf->name;
+	if (!policy_name)
+		policy_name = lb_policy_name;
 
 	if (!policy_name || (strcmp(policy_name, "help") == 0))
 	{
@@ -103,25 +104,19 @@ void starpu_mpi_lb_init(struct starpu_mpi_lb_conf *itf)
 	/* starpu_register_hook(finished_task, defined_policy->finished_task_entry_point); */
 	if (defined_policy->finished_task_entry_point)
 	{
-		STARPU_ASSERT(saved_post_exec_hook == NULL);
-		struct starpu_sched_policy **predefined_sched_policies = starpu_sched_get_predefined_policies();
-		struct starpu_sched_policy **sched_policy;
-		const char *sched_policy_name = starpu_getenv("STARPU_SCHED");
-
-		if (!sched_policy_name)
-			sched_policy_name = "eager";
-
-		for(sched_policy=predefined_sched_policies ; *sched_policy!=NULL ; sched_policy++)
+		int i;
+		for(i = 0; i < STARPU_NMAX_SCHED_CTXS; i++)
 		{
-			struct starpu_sched_policy *sched_p = *sched_policy;
-			if (strcmp(sched_policy_name, sched_p->policy_name) == 0)
+			struct starpu_sched_policy *sched_policy = starpu_sched_ctx_get_sched_policy(i);
+			if (sched_policy)
 			{
-				/* We found the scheduling policy with the requested name */
-				saved_post_exec_hook = sched_p->post_exec_hook;
-				break;
+				_STARPU_DEBUG("Setting post_exec_hook for scheduling context %d %s (%d)\n", i, sched_policy->policy_name, STARPU_NMAX_SCHED_CTXS);
+				saved_post_exec_hook[i] = sched_policy->post_exec_hook;
+				sched_policy->post_exec_hook = post_exec_hook_wrapper;
 			}
+			else
+				saved_post_exec_hook[i] = NULL;
 		}
-		starpu_sched_policy_set_post_exec_hook(post_exec_hook_wrapper, sched_policy_name);
 	}
 
 	return;
@@ -132,35 +127,30 @@ void starpu_mpi_lb_shutdown()
 	if (!defined_policy)
 		return;
 
-	if (defined_policy && defined_policy->deinit())
+	int ret = defined_policy->deinit();
+	if (ret != 0)
+	{
+		_STARPU_MSG("Error (%d) in %s->deinit\n", ret, defined_policy->policy_name);
 		return;
+	}
 
 	/* starpu_unregister_hook(submitted_task, defined_policy->submitted_task_entry_point); */
 	if (defined_policy->submitted_task_entry_point)
 		starpu_mpi_pre_submit_hook_unregister();
 
 	/* starpu_unregister_hook(finished_task, defined_policy->finished_task_entry_point); */
-	if (defined_policy->finished_task_entry_point && saved_post_exec_hook != NULL)
+	if (defined_policy->finished_task_entry_point)
 	{
-		struct starpu_sched_policy **predefined_sched_policies = starpu_sched_get_predefined_policies();
-		struct starpu_sched_policy **sched_policy;
-		const char *sched_policy_name = starpu_getenv("STARPU_SCHED");
-
-		if (!sched_policy_name)
-			sched_policy_name = "eager";
-
-		for(sched_policy=predefined_sched_policies ; *sched_policy!=NULL ; sched_policy++)
+		int i;
+		for(i = 0; i < STARPU_NMAX_SCHED_CTXS; i++)
 		{
-			struct starpu_sched_policy *sched_p = *sched_policy;
-			if (strcmp(sched_policy_name, sched_p->policy_name) == 0)
+			if (saved_post_exec_hook[i])
 			{
-				/* We found the scheduling policy with the requested name */
-				sched_p->post_exec_hook = saved_post_exec_hook;
-				saved_post_exec_hook = NULL;
-				break;
+				struct starpu_sched_policy *sched_policy = starpu_sched_ctx_get_sched_policy(i);
+				sched_policy->post_exec_hook = saved_post_exec_hook[i];
+				saved_post_exec_hook[i] = NULL;
 			}
 		}
 	}
-	STARPU_ASSERT(saved_post_exec_hook == NULL);
 	defined_policy = NULL;
 }

+ 3 - 1
mpi/src/load_balancer/policy/load_heat_propagation.c

@@ -554,11 +554,13 @@ static int deinit_heat()
 	if ((!user_itf) || (nneighbors == 0))
 		return 1;
 
+	_STARPU_DEBUG("Shutting down heat lb policy\n");
+
 	unsigned int ndata_to_move_back = HASH_COUNT(mdh);
 
 	if (ndata_to_move_back)
 	{
-		//fprintf(stderr,"Move back %u data on node %d ..\n", ndata_to_move_back, my_rank);
+		_STARPU_DEBUG("Move back %u data on node %d ..\n", ndata_to_move_back, my_rank);
 		data_movements_reallocate_tables(data_movements_handles[my_rank], ndata_to_move_back);
 
 		int *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);

+ 9 - 10
mpi/src/starpu_mpi_task_insert.c

@@ -40,17 +40,16 @@ static void (*pre_submit_hook)(struct starpu_task *task) = NULL;
 
 int starpu_mpi_pre_submit_hook_register(void (*f)(struct starpu_task *))
 {
-    if (pre_submit_hook)
-        fprintf(stderr,"Warning: a pre_submit_hook has already been registered.\nPlease check if you really want to erase the previously registered hook.\n");
-
-    pre_submit_hook = f; 
-    return 0;
+	if (pre_submit_hook)
+		_STARPU_MSG("Warning: a pre_submit_hook has already been registered. Please check if you really want to erase the previously registered hook.\n");
+	pre_submit_hook = f;
+	return 0;
 }
 
 int starpu_mpi_pre_submit_hook_unregister()
 {
-    pre_submit_hook = NULL;
-    return 0;
+	pre_submit_hook = NULL;
+	return 0;
 }
 
 int _starpu_mpi_find_executee_node(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int *do_execute, int *inconsistent_execute, int *xrank)
@@ -547,10 +546,10 @@ int _starpu_mpi_task_insert_v(MPI_Comm comm, struct starpu_codelet *codelet, va_
 
 	int val = _starpu_mpi_task_postbuild_v(comm, xrank, do_execute, descrs, nb_data);
 
-    if (ret == 0 && pre_submit_hook)
-        pre_submit_hook(task);
+	if (ret == 0 && pre_submit_hook)
+		pre_submit_hook(task);
 
-    return val;
+	return val;
 }
 
 int starpu_mpi_task_insert(MPI_Comm comm, struct starpu_codelet *codelet, ...)

+ 4 - 5
mpi/tests/load_balancer.c

@@ -39,7 +39,7 @@ void get_neighbors(int **neighbor_ids, int *nneighbors)
 
 void get_data_unit_to_migrate(starpu_data_handle_t **handle_unit, int *nhandles, int dst_node)
 {
-	STARPU_ASSERT(0);
+	*nhandles = 0;
 }
 
 int main(int argc, char **argv)
@@ -49,7 +49,6 @@ int main(int argc, char **argv)
 
 	itf.get_neighbors = get_neighbors;
 	itf.get_data_unit_to_migrate = get_data_unit_to_migrate;
-	itf.name = "my_itf";
 
 	MPI_Init(&argc, &argv);
 	ret = starpu_init(NULL);
@@ -58,11 +57,11 @@ int main(int argc, char **argv)
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	unsetenv("STARPU_MPI_LB");
-	starpu_mpi_lb_init(NULL);
+	starpu_mpi_lb_init(NULL, NULL);
 	starpu_mpi_lb_shutdown();
 
-	setenv("STARPU_MPI_LB", "heat", 1);
-	starpu_mpi_lb_init(&itf);
+	starpu_mpi_lb_init("heat", &itf);
+	starpu_mpi_lb_shutdown();
 
 	starpu_mpi_shutdown();
 	starpu_shutdown();