Prechádzať zdrojové kódy

mpi: add possibility to specify a communicator when initializing MPI.

Nathalie Furmento 10 rokov pred
rodič
commit
db67c3db93

+ 2 - 0
ChangeLog

@@ -51,6 +51,8 @@ New features:
 	  starpu_mpi_node_selection_unregister_policy()
 	- New environment variable STARPU_MPI_COMM which enables
 	  basic tracing of communications.
+	- New function starpu_mpi_init_comm() which allows specifying
+	  an MPI communicator.
 
   * New STARPU_COMMUTE flag which can be passed along STARPU_W or STARPU_RW to
     let starpu commute write accesses.

+ 15 - 4
doc/doxygen/chapters/16mpi_support.doxy

@@ -529,15 +529,26 @@ starpu_mpi_gather_detached(data_handles, nblocks, 0, MPI_COMM_WORLD);
 MPI examples are available in the StarPU source code in mpi/examples:
 
 <ul>
-<li><c>complex</c> is a simple example using a user-define data interface over
+<li>
+<c>comm</c> shows how to use communicators with StarPU-MPI
+</li>
+<li>
+<c>complex</c> is a simple example using a user-defined data interface over
 MPI (complex numbers),
-<li><c>stencil5</c> is a simple stencil example using starpu_mpi_task_insert(),
-<li><c>matrix_decomposition</c> is a cholesky decomposition example using
+</li>
+<li>
+<c>stencil5</c> is a simple stencil example using starpu_mpi_task_insert(),
+</li>
+<li>
+<c>matrix_decomposition</c> is a cholesky decomposition example using
 starpu_mpi_task_insert(). The non-distributed version can check for
 algorithm correctness in 1-node configuration, the distributed version uses
 exactly the same source code, to be used over MPI,
-<li><c>mpi_lu</c> is an LU decomposition example, provided in three versions:
+</li>
+<li>
+<c>mpi_lu</c> is an LU decomposition example, provided in three versions:
 <c>plu_example</c> uses explicit MPI data transfers, <c>plu_implicit_example</c>
 uses implicit MPI data transfers, <c>plu_outofcore_example</c> uses implicit MPI
 data transfers and supports data matrices which do not fit in memory (out-of-core).
+</li>
 </ul>

+ 8 - 4
doc/doxygen/chapters/api/mpi.doxy

@@ -17,12 +17,16 @@ This macro is defined when StarPU has been installed with MPI
 support. It should be used in your code to detect the availability of
 MPI.
 
+\fn int starpu_mpi_init_comm(int *argc, char ***argv, int initialize_mpi, MPI_Comm comm)
+\ingroup API_MPI_Support
+Initializes the starpumpi library with the given communicator.
+\p initialize_mpi indicates if MPI should be initialized or not by StarPU.
+If the value is not 0, MPI will be initialized by calling
+<c>MPI_Init_thread(argc, argv, MPI_THREAD_SERIALIZED, ...)</c>.
+
 \fn int starpu_mpi_init(int *argc, char ***argv, int initialize_mpi)
 \ingroup API_MPI_Support
-Initializes the starpumpi library. \p initialize_mpi indicates if MPI
-should be initialized or not by StarPU. If the value is not 0, MPI
-will be initialized by calling <c>MPI_Init_Thread(argc, argv,
-MPI_THREAD_SERIALIZED, ...)</c>.
+Call starpu_mpi_init_comm() with the MPI communicator MPI_COMM_WORLD.
 
 \fn int starpu_mpi_initialize(void)
 \deprecated

+ 1 - 0
mpi/include/starpu_mpi.h

@@ -44,6 +44,7 @@ int starpu_mpi_barrier(MPI_Comm comm);
 
 int starpu_mpi_irecv_detached_sequential_consistency(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg, int sequential_consistency);
 
+int starpu_mpi_init_comm(int *argc, char ***argv, int initialize_mpi, MPI_Comm comm);
 int starpu_mpi_init(int *argc, char ***argv, int initialize_mpi);
 int starpu_mpi_initialize(void) STARPU_DEPRECATED;
 int starpu_mpi_initialize_extended(int *rank, int *world_size) STARPU_DEPRECATED;

+ 20 - 13
mpi/src/starpu_mpi.c

@@ -1126,6 +1126,7 @@ struct _starpu_mpi_argc_argv
 	int initialize_mpi;
 	int *argc;
 	char ***argv;
+	MPI_Comm comm;
 };
 
 static void _starpu_mpi_print_thread_level_support(int thread_level, char *msg)
@@ -1228,9 +1229,9 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 		_starpu_mpi_print_thread_level_support(provided, " has been initialized with");
 	}
 
-	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-	MPI_Comm_size(MPI_COMM_WORLD, &worldsize);
-	MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
+	MPI_Comm_rank(argc_argv->comm, &rank);
+	MPI_Comm_size(argc_argv->comm, &worldsize);
+	MPI_Comm_set_errhandler(argc_argv->comm, MPI_ERRORS_RETURN);
 
 #ifdef STARPU_SIMGRID
 	_mpi_world_size = worldsize;
@@ -1247,8 +1248,8 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 	}
 
 	_starpu_mpi_add_sync_point_in_fxt();
-	_starpu_mpi_comm_amounts_init(MPI_COMM_WORLD);
-	_starpu_mpi_cache_init(MPI_COMM_WORLD);
+	_starpu_mpi_comm_amounts_init(argc_argv->comm);
+	_starpu_mpi_cache_init(argc_argv->comm);
 	_starpu_mpi_select_node_init();
 	_starpu_mpi_tag_init();
 
@@ -1316,7 +1317,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 		{
 			_STARPU_MPI_DEBUG(3, "Posting a receive to get a data envelop\n");
 			_STARPU_MPI_COMM_FROM_DEBUG(sizeof(struct _starpu_mpi_envelope), MPI_BYTE, MPI_ANY_SOURCE, _STARPU_MPI_TAG_ENVELOPE, _STARPU_MPI_TAG_ENVELOPE);
-			MPI_Irecv(envelope, sizeof(struct _starpu_mpi_envelope), MPI_BYTE, MPI_ANY_SOURCE, _STARPU_MPI_TAG_ENVELOPE, MPI_COMM_WORLD, &envelope_request);
+			MPI_Irecv(envelope, sizeof(struct _starpu_mpi_envelope), MPI_BYTE, MPI_ANY_SOURCE, _STARPU_MPI_TAG_ENVELOPE, argc_argv->comm, &envelope_request);
 			envelope_request_submitted = 1;
 		}
 
@@ -1373,7 +1374,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 							new_req->data_handle = NULL;
 							new_req->srcdst = status.MPI_SOURCE;
 							new_req->data_tag = envelope->data_tag;
-							new_req->comm = MPI_COMM_WORLD;
+							new_req->comm = argc_argv->comm;
 							new_req->detached = 1;
 							new_req->sync = 1;
 							new_req->callback = NULL;
@@ -1387,7 +1388,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 						}
 						else
 						{
-							_starpu_mpi_receive_early_data(envelope, status, MPI_COMM_WORLD);
+							_starpu_mpi_receive_early_data(envelope, status, argc_argv->comm);
 						}
 					}
 					/* Case: a matching application request has been found for
@@ -1513,7 +1514,7 @@ static void _starpu_mpi_add_sync_point_in_fxt(void)
 }
 
 static
-int _starpu_mpi_initialize(int *argc, char ***argv, int initialize_mpi)
+int _starpu_mpi_initialize(int *argc, char ***argv, int initialize_mpi, MPI_Comm comm)
 {
 	STARPU_PTHREAD_MUTEX_INIT(&mutex, NULL);
 	STARPU_PTHREAD_COND_INIT(&cond_progression, NULL);
@@ -1529,6 +1530,7 @@ int _starpu_mpi_initialize(int *argc, char ***argv, int initialize_mpi)
 	argc_argv->initialize_mpi = initialize_mpi;
 	argc_argv->argc = argc;
 	argc_argv->argv = argv;
+	argc_argv->comm = comm;
 
 #ifdef STARPU_MPI_ACTIVITY
 	hookid = starpu_progression_hook_register(_starpu_mpi_progression_hook_func, NULL);
@@ -1559,23 +1561,28 @@ int _starpu_mpi_simgrid_init(int argc, char *argv[])
 }
 #endif
 
-int starpu_mpi_init(int *argc, char ***argv, int initialize_mpi)
+int starpu_mpi_init_comm(int *argc, char ***argv, int initialize_mpi, MPI_Comm comm)
 {
 #ifdef STARPU_SIMGRID
 	STARPU_MPI_ASSERT_MSG(initialize_mpi, "application has to let StarPU initialize MPI");
 	return 0;
 #else
-	return _starpu_mpi_initialize(argc, argv, initialize_mpi);
+	return _starpu_mpi_initialize(argc, argv, initialize_mpi, comm);
 #endif
 }
 
+int starpu_mpi_init(int *argc, char ***argv, int initialize_mpi)
+{
+	return starpu_mpi_init_comm(argc, argv, initialize_mpi, MPI_COMM_WORLD);
+}
+
 int starpu_mpi_initialize(void)
 {
 #ifdef STARPU_SIMGRID
 	STARPU_MPI_ASSERT_MSG(0, "application has to let StarPU initialize MPI");
 	return 0;
 #else
-	return _starpu_mpi_initialize(NULL, NULL, 0);
+	return _starpu_mpi_initialize(NULL, NULL, 0, MPI_COMM_WORLD);
 #endif
 }
 
@@ -1588,7 +1595,7 @@ int starpu_mpi_initialize_extended(int *rank, int *world_size)
 #else
 	int ret;
 
-	ret = _starpu_mpi_initialize(NULL, NULL, 1);
+	ret = _starpu_mpi_initialize(NULL, NULL, 1, MPI_COMM_WORLD);
 	if (ret == 0)
 	{
 		_STARPU_DEBUG("Calling MPI_Comm_rank\n");

+ 14 - 14
mpi/src/starpu_mpi_cache.c

@@ -35,6 +35,8 @@ static starpu_pthread_mutex_t *_cache_received_mutex;
 static struct _starpu_data_entry **_cache_sent_data = NULL;
 static struct _starpu_data_entry **_cache_received_data = NULL;
 int _starpu_cache_enabled=1;
+MPI_Comm _starpu_cache_comm;
+int _starpu_cache_comm_size;
 
 int starpu_mpi_cache_is_enabled()
 {
@@ -52,10 +54,8 @@ int starpu_mpi_cache_set(int enabled)
 		if (_starpu_cache_enabled)
 		{
 			// We need to clean the cache
-			int world_size;
-			starpu_mpi_cache_flush_all_data(MPI_COMM_WORLD);
-			starpu_mpi_comm_size(MPI_COMM_WORLD, &world_size);
-			_starpu_mpi_cache_free(world_size);
+			starpu_mpi_cache_flush_all_data(_starpu_cache_comm);
+			_starpu_mpi_cache_free(_starpu_cache_comm_size);
 		}
 		_starpu_cache_enabled = 0;
 	}
@@ -64,7 +64,6 @@ int starpu_mpi_cache_set(int enabled)
 
 void _starpu_mpi_cache_init(MPI_Comm comm)
 {
-	int nb_nodes;
 	int i;
 
 	_starpu_cache_enabled = starpu_get_env_number("STARPU_MPI_CACHE");
@@ -79,15 +78,16 @@ void _starpu_mpi_cache_init(MPI_Comm comm)
 		return;
 	}
 
-	starpu_mpi_comm_size(comm, &nb_nodes);
+	_starpu_cache_comm = comm;
+	starpu_mpi_comm_size(comm, &_starpu_cache_comm_size);
 	_STARPU_MPI_DEBUG(2, "Initialising htable for cache\n");
 
-	_cache_sent_data = malloc(nb_nodes * sizeof(struct _starpu_data_entry *));
-	_cache_received_data = malloc(nb_nodes * sizeof(struct _starpu_data_entry *));
-	_cache_sent_mutex = malloc(nb_nodes * sizeof(starpu_pthread_mutex_t));
-	_cache_received_mutex = malloc(nb_nodes * sizeof(starpu_pthread_mutex_t));
+	_cache_sent_data = malloc(_starpu_cache_comm_size * sizeof(struct _starpu_data_entry *));
+	_cache_received_data = malloc(_starpu_cache_comm_size * sizeof(struct _starpu_data_entry *));
+	_cache_sent_mutex = malloc(_starpu_cache_comm_size * sizeof(starpu_pthread_mutex_t));
+	_cache_received_mutex = malloc(_starpu_cache_comm_size * sizeof(starpu_pthread_mutex_t));
 
-	for(i=0 ; i<nb_nodes ; i++)
+	for(i=0 ; i<_starpu_cache_comm_size ; i++)
 	{
 		_cache_sent_data[i] = NULL;
 		_cache_received_data[i] = NULL;
@@ -129,17 +129,17 @@ void _starpu_mpi_cache_empty_tables(int world_size)
 	}
 }
 
-void _starpu_mpi_cache_free(int world_size)
+void _starpu_mpi_cache_free()
 {
 	int i;
 
 	if (_starpu_cache_enabled == 0) return;
 
-	_starpu_mpi_cache_empty_tables(world_size);
+	_starpu_mpi_cache_empty_tables(_starpu_cache_comm_size);
 	free(_cache_sent_data);
 	free(_cache_received_data);
 
-	for(i=0 ; i<world_size ; i++)
+	for(i=0 ; i<_starpu_cache_comm_size ; i++)
 	{
 		STARPU_PTHREAD_MUTEX_DESTROY(&_cache_sent_mutex[i]);
 		STARPU_PTHREAD_MUTEX_DESTROY(&_cache_received_mutex[i]);

+ 1 - 1
mpi/src/starpu_mpi_cache.h

@@ -29,7 +29,7 @@ extern "C" {
 
 extern int _starpu_cache_enabled;
 void _starpu_mpi_cache_init(MPI_Comm comm);
-void _starpu_mpi_cache_free(int world_size);
+void _starpu_mpi_cache_free();
 
 /*
  * If the data is already available in the cache, return a pointer to the data