Przeglądaj źródła

mpi: New function starpu_mpi_wait_for_all(MPI_Comm comm) that allows waiting until all StarPU tasks and communications for the given communicator are completed.

Nathalie Furmento 9 lat temu
rodzic
commit
3aa5ed22d3

+ 3 - 0
ChangeLog

@@ -187,6 +187,9 @@ Small features:
   * Add _starpu_fifo_pop_this_task.
   * Add STARPU_MAX_MEMORY_USE environment variable.
   * Add starpu_worker_get_id_check().
+  * New function starpu_mpi_wait_for_all(MPI_Comm comm) that allows
+    waiting until all StarPU tasks and communications for the given
+    communicator are completed.
 
 Changes:
   * Data interfaces (variable, vector, matrix and block) now define

+ 5 - 1
doc/doxygen/chapters/api/mpi.doxy

@@ -1,7 +1,7 @@
 /*
  * This file is part of the StarPU Handbook.
  * Copyright (C) 2009--2011  Universit@'e de Bordeaux
- * Copyright (C) 2010, 2011, 2012, 2013, 2014, 2015  CNRS
+ * Copyright (C) 2010, 2011, 2012, 2013, 2014, 2015, 2016  CNRS
  * Copyright (C) 2011, 2012 INRIA
  * See the file version.doxy for copying conditions.
  */
@@ -156,6 +156,10 @@ operation.
 Blocks the caller until all group members of the communicator \p comm
 have called it.
 
+\fn int starpu_mpi_wait_for_all(MPI_Comm comm)
+\ingroup API_MPI_Support
+Wait until all StarPU tasks and communications for the given communicator are completed.
+
 \fn int starpu_mpi_isend_detached_unlock_tag(starpu_data_handle_t data_handle, int dest, int mpi_tag, MPI_Comm comm, starpu_tag_t tag)
 \ingroup API_MPI_Support
 Posts a standard-mode, non blocking send of \p data_handle to the node

+ 1 - 1
mpi/examples/user_datatype/user_datatype.c

@@ -97,7 +97,7 @@ int main(int argc, char **argv)
 		starpu_mpi_isend_detached(handle0, 0, 20, MPI_COMM_WORLD, NULL, NULL);
 	}
 
-	starpu_task_wait_for_all();
+	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
 
 	starpu_mpi_datatype_unregister(handle0);
 	starpu_data_unregister(handle0);

+ 3 - 1
mpi/include/starpu_mpi.h

@@ -1,7 +1,7 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
  * Copyright (C) 2009-2012, 2014-2015  Université de Bordeaux
- * Copyright (C) 2010, 2011, 2012, 2013, 2014, 2015  CNRS
+ * Copyright (C) 2010, 2011, 2012, 2013, 2014, 2015, 2016  CNRS
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -108,6 +108,8 @@ int starpu_mpi_node_selection_set_current_policy(int policy);
 int starpu_mpi_cache_is_enabled();
 int starpu_mpi_cache_set(int enabled);
 
+int starpu_mpi_wait_for_all(MPI_Comm comm);
+
 typedef void (*starpu_mpi_datatype_allocate_func_t)(starpu_data_handle_t, MPI_Datatype *);
 typedef void (*starpu_mpi_datatype_free_func_t)(MPI_Datatype *);
 int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func);

+ 24 - 6
mpi/src/starpu_mpi.c

@@ -34,6 +34,7 @@
 #include <datawizard/interfaces/data_interface.h>
 #include <datawizard/coherency.h>
 #include <core/simgrid.h>
+#include <core/task.h>
 
 static void _starpu_mpi_add_sync_point_in_fxt(void);
 static void _starpu_mpi_submit_ready_request(void *arg);
@@ -829,12 +830,12 @@ static void _starpu_mpi_barrier_func(struct _starpu_mpi_req *barrier_req)
 	_STARPU_MPI_LOG_OUT();
 }
 
-int starpu_mpi_barrier(MPI_Comm comm)
+int _starpu_mpi_barrier(MPI_Comm comm)
 {
-	int ret;
-	struct _starpu_mpi_req *barrier_req;
-
 	_STARPU_MPI_LOG_IN();
+
+	int ret = posted_requests;
+	struct _starpu_mpi_req *barrier_req;
 	_starpu_mpi_request_init(&barrier_req);
 
 	/* First wait for *both* all tasks and MPI requests to finish, in case
@@ -877,14 +878,19 @@ int starpu_mpi_barrier(MPI_Comm comm)
 		STARPU_PTHREAD_COND_WAIT(&barrier_req->req_cond, &barrier_req->req_mutex);
 	STARPU_PTHREAD_MUTEX_UNLOCK(&barrier_req->req_mutex);
 
-	ret = barrier_req->ret;
-
 	free(barrier_req);
 	barrier_req = NULL;
 	_STARPU_MPI_LOG_OUT();
+
 	return ret;
 }
 
+int starpu_mpi_barrier(MPI_Comm comm)	/* Public barrier entry point: delegates to _starpu_mpi_barrier() on comm. */
+{
+	_starpu_mpi_barrier(comm);	/* Return value (presumably the posted_requests count read on entry) is discarded here. */
+	return 0;	/* Always 0. NOTE(review): before this change the function returned barrier_req->ret -- confirm no caller relies on it. */
+}
+
 /********************************************************/
 /*                                                      */
 /*  Progression                                         */
@@ -1785,3 +1791,15 @@ int starpu_mpi_world_rank(void)
 	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
 	return rank;
 }
+
+int starpu_mpi_wait_for_all(MPI_Comm comm)	/* Alternates task waits and barriers on comm until a full pass reports nothing for either; always returns 0. */
+{
+	int mpi = 1;	/* Last value returned by _starpu_mpi_barrier(); starts non-zero so the loop body runs at least once. */
+	int task = 1;	/* Last value returned by the task wait below; starts non-zero for the same reason. */
+	while (task || mpi)	/* Repeat while either call reported a non-zero value on the previous pass. */
+	{
+		task = _starpu_task_wait_for_all_and_return_nb_waited_tasks();	/* Presumably the number of tasks that were waited for -- TODO confirm against core/task.h. */
+		mpi = _starpu_mpi_barrier(comm);	/* Presumably the posted_requests value sampled when the barrier was entered -- verify ret is not reassigned in the elided part of that function. */
+	}
+	return 0;	/* No error path: the function only returns once the loop has exited. */
+}