Browse Source

mpi/src/starpu_mpi_task_insert.c: (hopefully) simplify code

Nathalie Furmento 10 years ago
parent
commit
4102779158
1 changed files with 69 additions and 51 deletions
  1. 69 51
      mpi/src/starpu_mpi_task_insert.c

+ 69 - 51
mpi/src/starpu_mpi_task_insert.c

@@ -32,7 +32,7 @@
 #include <starpu_mpi_select_node.h>
 
 static
-int _starpu_mpi_find_executee_node(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int *do_execute, int *inconsistent_execute, int *dest)
+int _starpu_mpi_find_executee_node(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int *do_execute, int *inconsistent_execute, int *xrank)
 {
 	if (mode & STARPU_W)
 	{
@@ -60,12 +60,13 @@ int _starpu_mpi_find_executee_node(starpu_data_handle_t data, enum starpu_data_a
 			// This node owns the data
 			if (*do_execute == 0)
 			{
-				// Another node has already been selected to execute the codelet
+				_STARPU_MPI_DEBUG(100, "Another node has already been selected to execute the codelet\n");
 				*inconsistent_execute = 1;
 			}
 			else
 			{
-				// This node is going to execute the codelet
+				_STARPU_MPI_DEBUG(100, "This node is going to execute the codelet\n");
+				*xrank = me;
 				*do_execute = 1;
 			}
 		}
@@ -74,23 +75,23 @@ int _starpu_mpi_find_executee_node(starpu_data_handle_t data, enum starpu_data_a
 			// Another node owns the data
 			if (*do_execute == 1)
 			{
-				// But this node has already been selected to execute the codelet
+				_STARPU_MPI_DEBUG(100, "Another node owns the data but this node has already been selected to execute the codelet\n");
 				*inconsistent_execute = 1;
 			}
 			else
 			{
-				// This node will not execute the codelet
+				_STARPU_MPI_DEBUG(100, "This node will not execute the codelet\n");
 				*do_execute = 0;
-				*dest = mpi_rank;
-				/* That's the rank which needs the data to be sent to */
+				*xrank = mpi_rank;
 			}
 		}
 	}
+	_STARPU_MPI_DEBUG(100, "Executing: inconsistent=%d, do_execute=%d, xrank=%d\n", *inconsistent_execute, *do_execute, *xrank);
 	return 0;
 }
 
 static
-void _starpu_mpi_exchange_data_before_execution(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int dest, int do_execute, MPI_Comm comm)
+void _starpu_mpi_exchange_data_before_execution(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int xrank, int do_execute, MPI_Comm comm)
 {
 	if (data && mode & STARPU_R)
 	{
@@ -120,18 +121,18 @@ void _starpu_mpi_exchange_data_before_execution(starpu_data_handle_t data, enum
 		if (!do_execute && mpi_rank == me)
 		{
 			/* Somebody else will execute it, and I have the data, send it. */
-			void *already_sent = _starpu_mpi_already_sent(data, dest);
+			void *already_sent = _starpu_mpi_already_sent(data, xrank);
 			if (already_sent == NULL)
 			{
-				_STARPU_MPI_DEBUG(1, "Send data %p to %d\n", data, dest);
-				starpu_mpi_isend_detached(data, dest, mpi_tag, comm, NULL, NULL);
+				_STARPU_MPI_DEBUG(1, "Send data %p to %d\n", data, xrank);
+				starpu_mpi_isend_detached(data, xrank, mpi_tag, comm, NULL, NULL);
 			}
 		}
 	}
 }
 
 static
-void _starpu_mpi_exchange_data_after_execution(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int xrank, int dest, int do_execute, MPI_Comm comm)
+void _starpu_mpi_exchange_data_after_execution(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int xrank, int do_execute, MPI_Comm comm)
 {
 	if (mode & STARPU_W)
 	{
@@ -151,8 +152,8 @@ void _starpu_mpi_exchange_data_after_execution(starpu_data_handle_t data, enum s
 		{
 			if (xrank != -1 && me != xrank)
 			{
-				_STARPU_MPI_DEBUG(1, "Receive data %p back from the task %d which executed the codelet ...\n", data, dest);
-				starpu_mpi_irecv_detached(data, dest, mpi_tag, comm, NULL, NULL);
+				_STARPU_MPI_DEBUG(1, "Receive data %p back from the task %d which executed the codelet ...\n", data, xrank);
+				starpu_mpi_irecv_detached(data, xrank, mpi_tag, comm, NULL, NULL);
 			}
 		}
 		else if (do_execute)
@@ -311,12 +312,13 @@ int _starpu_mpi_task_select_node(struct starpu_codelet *codelet, int me, int nb_
 }
 
 static
-int _starpu_mpi_task_decode_v(struct starpu_codelet *codelet, int me, int nb_nodes, int *xrank, int *dest, int *do_execute, va_list varg_list)
+int _starpu_mpi_task_decode_v(struct starpu_codelet *codelet, int me, int nb_nodes, int *xrank, int *do_execute, va_list varg_list)
 {
 	va_list varg_list_copy;
 	int inconsistent_execute = 0;
 	int arg_type, arg_type_nocommute;
 	int current_data = 0;
+	int node_selected = 0;
 
 	*do_execute = -1;
 	*xrank = -1;
@@ -327,17 +329,27 @@ int _starpu_mpi_task_decode_v(struct starpu_codelet *codelet, int me, int nb_nod
 		if (arg_type==STARPU_EXECUTE_ON_NODE)
 		{
 			*xrank = va_arg(varg_list_copy, int);
-			_STARPU_MPI_DEBUG(100, "Executing on node %d\n", *xrank);
-			*do_execute = 1;
+			if (node_selected == 0)
+			{
+				_STARPU_MPI_DEBUG(100, "Executing on node %d\n", *xrank);
+				*do_execute = 1;
+				node_selected = 1;
+				inconsistent_execute = 0;
+			}
 		}
 		else if (arg_type==STARPU_EXECUTE_ON_DATA)
 		{
 			starpu_data_handle_t data = va_arg(varg_list_copy, starpu_data_handle_t);
-			*xrank = starpu_data_get_rank(data);
-			STARPU_ASSERT_MSG(*xrank != -1, "Rank of the data must be set using starpu_mpi_data_register() or starpu_data_set_rank()");
-			_STARPU_MPI_DEBUG(100, "Executing on data node %d\n", *xrank);
-			STARPU_ASSERT_MSG(*xrank <= nb_nodes, "Node %d to execute codelet is not a valid node (%d)", *xrank, nb_nodes);
-			*do_execute = 1;
+			if (node_selected == 0)
+			{
+				*xrank = starpu_data_get_rank(data);
+				STARPU_ASSERT_MSG(*xrank != -1, "Rank of the data must be set using starpu_mpi_data_register() or starpu_data_set_rank()");
+				_STARPU_MPI_DEBUG(100, "Executing on data node %d\n", *xrank);
+				STARPU_ASSERT_MSG(*xrank <= nb_nodes, "Node %d to execute codelet is not a valid node (%d)", *xrank, nb_nodes);
+				*do_execute = 1;
+				node_selected = 1;
+				inconsistent_execute = 0;
+			}
 		}
 		else if (arg_type==STARPU_EXECUTE_ON_WORKER)
 		{
@@ -355,10 +367,13 @@ int _starpu_mpi_task_decode_v(struct starpu_codelet *codelet, int me, int nb_nod
 		{
 			starpu_data_handle_t data = va_arg(varg_list_copy, starpu_data_handle_t);
 			enum starpu_data_access_mode mode = (enum starpu_data_access_mode) arg_type;
-			int ret = _starpu_mpi_find_executee_node(data, mode, me, do_execute, &inconsistent_execute, dest);
-			if (ret == -EINVAL)
+			if (node_selected == 0)
 			{
-				return ret;
+				int ret = _starpu_mpi_find_executee_node(data, mode, me, do_execute, &inconsistent_execute, xrank);
+				if (ret == -EINVAL)
+				{
+					return ret;
+				}
 			}
 			current_data ++;
 		}
@@ -366,16 +381,20 @@ int _starpu_mpi_task_decode_v(struct starpu_codelet *codelet, int me, int nb_nod
 		{
 			starpu_data_handle_t *datas = va_arg(varg_list_copy, starpu_data_handle_t *);
 			int nb_handles = va_arg(varg_list_copy, int);
-			int i;
-			for(i=0 ; i<nb_handles ; i++)
+			if (node_selected) current_data += nb_handles;
+			else
 			{
-				enum starpu_data_access_mode mode = STARPU_CODELET_GET_MODE(codelet, current_data);
-				int ret = _starpu_mpi_find_executee_node(datas[i], mode, me, do_execute, &inconsistent_execute, dest);
-				if (ret == -EINVAL)
+				int i;
+				for(i=0 ; i<nb_handles ; i++)
 				{
-					return ret;
+					enum starpu_data_access_mode mode = STARPU_CODELET_GET_MODE(codelet, current_data);
+					int ret = _starpu_mpi_find_executee_node(datas[i], mode, me, do_execute, &inconsistent_execute, xrank);
+					if (ret == -EINVAL)
+					{
+						return ret;
+					}
+					current_data ++;
 				}
-				current_data ++;
 			}
 		}
 		else if (arg_type==STARPU_VALUE)
@@ -440,31 +459,32 @@ int _starpu_mpi_task_decode_v(struct starpu_codelet *codelet, int me, int nb_nod
 	}
 	va_end(varg_list_copy);
 
-	if (inconsistent_execute == 1 && *xrank == -1)
+	if (inconsistent_execute == 1 || *xrank == -1)
 	{
 		// We need to find out which node is going to execute the codelet.
 		_STARPU_MPI_DISP("Different nodes are owning W data. Need to specify which node is going to execute the codelet, using STARPU_EXECUTE_ON_NODE or STARPU_EXECUTE_ON_DATA\n");
 		*xrank = _starpu_mpi_task_select_node(codelet, me, nb_nodes, varg_list, current_data);
 		*do_execute = (me == *xrank);
-		*dest = *xrank;
 	}
 	else
 	{
+		_STARPU_MPI_DEBUG(100, "Inconsistent=%d - xrank=%d\n", inconsistent_execute, *xrank);
 		*do_execute = (me == *xrank);
-		*dest = *xrank;
 	}
+	_STARPU_MPI_DEBUG(100, "do_execute=%d\n", *do_execute);
+
 	return 0;
 }
 
 static
-int _starpu_mpi_task_build_v(MPI_Comm comm, struct starpu_codelet *codelet, struct starpu_task **task, int *xrank_p, int *dest_p, va_list varg_list)
+int _starpu_mpi_task_build_v(MPI_Comm comm, struct starpu_codelet *codelet, struct starpu_task **task, int *xrank_p, va_list varg_list)
 {
 	int arg_type, arg_type_nocommute;
 	va_list varg_list_copy;
 	int me, do_execute, xrank, nb_nodes;
 	size_t arg_buffer_size = 0;
 	void *arg_buffer = NULL;
-	int ret, dest=0;
+	int ret;
 	int current_data;
 
 	_STARPU_MPI_LOG_IN();
@@ -487,7 +507,7 @@ int _starpu_mpi_task_build_v(MPI_Comm comm, struct starpu_codelet *codelet, stru
 			starpu_data_handle_t data = va_arg(varg_list_copy, starpu_data_handle_t);
 			enum starpu_data_access_mode mode = (enum starpu_data_access_mode) arg_type;
 
-			_starpu_mpi_exchange_data_before_execution(data, mode, me, dest, do_execute, comm);
+			_starpu_mpi_exchange_data_before_execution(data, mode, me, xrank, do_execute, comm);
 			current_data ++;
 
 		}
@@ -499,7 +519,7 @@ int _starpu_mpi_task_build_v(MPI_Comm comm, struct starpu_codelet *codelet, stru
 
 			for(i=0 ; i<nb_handles ; i++)
 			{
-				_starpu_mpi_exchange_data_before_execution(datas[i], STARPU_CODELET_GET_MODE(codelet, current_data), me, dest, do_execute, comm);
+				_starpu_mpi_exchange_data_before_execution(datas[i], STARPU_CODELET_GET_MODE(codelet, current_data), me, xrank, do_execute, comm);
 				current_data++;
 			}
 		}
@@ -581,7 +601,6 @@ int _starpu_mpi_task_build_v(MPI_Comm comm, struct starpu_codelet *codelet, stru
 	va_end(varg_list_copy);
 
 	if (xrank_p) *xrank_p = xrank;
-	if (dest_p) *dest_p = dest;
 
 	if (do_execute == 0) return 1;
 	else
@@ -599,7 +618,7 @@ int _starpu_mpi_task_build_v(MPI_Comm comm, struct starpu_codelet *codelet, stru
 			va_end(varg_list_copy);
 		}
 
-		_STARPU_MPI_DEBUG(100, "Execution of the codelet %p (%s)\n", codelet, codelet->name);
+		_STARPU_MPI_DEBUG(100, "Execution of the codelet %p (%s)\n", codelet, codelet?codelet->name:NULL);
 
 		*task = starpu_task_create();
 		(*task)->cl_arg_free = 1;
@@ -617,7 +636,7 @@ int _starpu_mpi_task_build_v(MPI_Comm comm, struct starpu_codelet *codelet, stru
 }
 
 static
-int _starpu_mpi_task_postbuild_v(MPI_Comm comm, struct starpu_codelet *codelet, va_list varg_list, int xrank, int dest, int do_execute)
+int _starpu_mpi_task_postbuild_v(MPI_Comm comm, struct starpu_codelet *codelet, va_list varg_list, int xrank, int do_execute)
 {
 	int arg_type, arg_type_nocommute;
 	va_list varg_list_copy;
@@ -636,7 +655,7 @@ int _starpu_mpi_task_postbuild_v(MPI_Comm comm, struct starpu_codelet *codelet,
 			starpu_data_handle_t data = va_arg(varg_list_copy, starpu_data_handle_t);
 			enum starpu_data_access_mode mode = (enum starpu_data_access_mode) arg_type;
 
-			_starpu_mpi_exchange_data_after_execution(data, mode, me, xrank, dest, do_execute, comm);
+			_starpu_mpi_exchange_data_after_execution(data, mode, me, xrank, do_execute, comm);
 			_starpu_mpi_clear_data_after_execution(data, mode, me, do_execute, comm);
 			current_data++;
 		}
@@ -648,7 +667,7 @@ int _starpu_mpi_task_postbuild_v(MPI_Comm comm, struct starpu_codelet *codelet,
 
 			for(i=0 ; i<nb_handles ; i++)
 			{
-				_starpu_mpi_exchange_data_after_execution(datas[i], STARPU_CODELET_GET_MODE(codelet, current_data), me, xrank, dest, do_execute, comm);
+				_starpu_mpi_exchange_data_after_execution(datas[i], STARPU_CODELET_GET_MODE(codelet, current_data), me, xrank, do_execute, comm);
 				_starpu_mpi_clear_data_after_execution(datas[i], STARPU_CODELET_GET_MODE(codelet, current_data), me, do_execute, comm);
 				current_data++;
 			}
@@ -736,10 +755,9 @@ int _starpu_mpi_task_insert_v(MPI_Comm comm, struct starpu_codelet *codelet, va_
 	struct starpu_task *task;
 	int ret;
 	int xrank;
-	int dest;
 	int do_execute = 0;
 
-	ret = _starpu_mpi_task_build_v(comm, codelet, &task, &xrank, &dest, varg_list);
+	ret = _starpu_mpi_task_build_v(comm, codelet, &task, &xrank, varg_list);
 	if (ret < 0) return ret;
 
 	if (ret == 0)
@@ -759,7 +777,7 @@ int _starpu_mpi_task_insert_v(MPI_Comm comm, struct starpu_codelet *codelet, va_
 			starpu_task_destroy(task);
 		}
 	}
-	return _starpu_mpi_task_postbuild_v(comm, codelet, varg_list, xrank, dest, do_execute);
+	return _starpu_mpi_task_postbuild_v(comm, codelet, varg_list, xrank, do_execute);
 }
 
 int starpu_mpi_task_insert(MPI_Comm comm, struct starpu_codelet *codelet, ...)
@@ -791,7 +809,7 @@ struct starpu_task *starpu_mpi_task_build(MPI_Comm comm, struct starpu_codelet *
 	int ret;
 
 	va_start(varg_list, codelet);
-	ret = _starpu_mpi_task_build_v(comm, codelet, &task, NULL, NULL, varg_list);
+	ret = _starpu_mpi_task_build_v(comm, codelet, &task, NULL, varg_list);
 	va_end(varg_list);
 	STARPU_ASSERT(ret >= 0);
 	if (ret > 0) return NULL; else return task;
@@ -799,7 +817,7 @@ struct starpu_task *starpu_mpi_task_build(MPI_Comm comm, struct starpu_codelet *
 
 int starpu_mpi_task_post_build(MPI_Comm comm, struct starpu_codelet *codelet, ...)
 {
-	int xrank, dest, do_execute;
+	int xrank, do_execute;
 	int ret, me, nb_nodes;
 	va_list varg_list;
 
@@ -808,11 +826,11 @@ int starpu_mpi_task_post_build(MPI_Comm comm, struct starpu_codelet *codelet, ..
 
 	va_start(varg_list, codelet);
 	/* Find out whether we are to execute the data because we own the data to be written to. */
-	ret = _starpu_mpi_task_decode_v(codelet, me, nb_nodes, &xrank, &dest, &do_execute, varg_list);
+	ret = _starpu_mpi_task_decode_v(codelet, me, nb_nodes, &xrank, &do_execute, varg_list);
 	if (ret < 0) return ret;
 	va_end(varg_list);
 
-	return _starpu_mpi_task_postbuild_v(comm, codelet, varg_list, xrank, dest, do_execute);
+	return _starpu_mpi_task_postbuild_v(comm, codelet, varg_list, xrank, do_execute);
 }
 
 void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg)