Quellcode durchsuchen

In simgrid mode, avoid calling MPI_Comm_size/rank outside the MPI thread

Samuel Thibault vor 10 Jahren
Ursprung
Commit
4d5f8f17da

+ 3 - 1
mpi/include/starpu_mpi.h

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2009-2012, 2014  Université de Bordeaux
+ * Copyright (C) 2009-2012, 2014-2015  Université de Bordeaux
  * Copyright (C) 2010, 2011, 2012, 2013, 2014  Centre National de la Recherche Scientifique
  *
  * StarPU is free software; you can redistribute it and/or modify
@@ -73,6 +73,8 @@ void starpu_mpi_comm_amounts_retrieve(size_t *comm_amounts);
 void starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle);
 void starpu_mpi_cache_flush_all_data(MPI_Comm comm);
 
+int starpu_mpi_comm_size(MPI_Comm comm, int *size);
+int starpu_mpi_comm_rank(MPI_Comm comm, int *rank);
 int starpu_mpi_world_rank(void);
 
 int starpu_mpi_get_communication_tag(void);

+ 42 - 7
mpi/src/starpu_mpi.c

@@ -63,6 +63,11 @@ static starpu_pthread_mutex_t mutex;
 static starpu_pthread_t progress_thread;
 static int running = 0;
 
+#ifdef STARPU_SIMGRID
+static int _mpi_world_size;
+static int _mpi_world_rank;
+#endif
+
 /* Count requests posted by the application and not yet submitted to MPI */
 static starpu_pthread_mutex_t mutex_posted_requests;
 static int posted_requests = 0, newer_requests, barrier_running = 0;
@@ -1099,6 +1104,8 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 	MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
 
 #ifdef STARPU_SIMGRID
+	_mpi_world_size = worldsize;
+	_mpi_world_rank = rank;
 	/* Now that MPI is set up, let the rest of simgrid get initialized */
 	MSG_process_create_with_arguments("main", smpi_simulated_main_, NULL, _starpu_simgrid_get_host_by_name("MAIN"), *(argc_argv->argc), *(argc_argv->argv));
 #endif
@@ -1349,8 +1356,8 @@ static void _starpu_mpi_add_sync_point_in_fxt(void)
 	int worldsize;
 	int ret;
 
-	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-	MPI_Comm_size(MPI_COMM_WORLD, &worldsize);
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &worldsize);
 
 	ret = MPI_Barrier(MPI_COMM_WORLD);
 	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Barrier returning %d", ret);
@@ -1419,7 +1426,7 @@ int _starpu_mpi_initialize(int *argc, char ***argv, int initialize_mpi)
  * create MSG processes to run application's main */
 int _starpu_mpi_simgrid_init(int argc, char *argv[])
 {
-	_starpu_mpi_initialize(&argc, &argv, 1);
+	return _starpu_mpi_initialize(&argc, &argv, 1);
 }
 #endif
 
@@ -1445,6 +1452,11 @@ int starpu_mpi_initialize(void)
 
 int starpu_mpi_initialize_extended(int *rank, int *world_size)
 {
+#ifdef STARPU_SIMGRID
+	*world_size = _mpi_world_size;
+	*rank = _mpi_world_rank;
+	return 0;
+#else
 	int ret;
 
 	ret = _starpu_mpi_initialize(NULL, NULL, 1);
@@ -1455,6 +1467,7 @@ int starpu_mpi_initialize_extended(int *rank, int *world_size)
 		MPI_Comm_size(MPI_COMM_WORLD, world_size);
 	}
 	return ret;
+#endif
 }
 
 int starpu_mpi_shutdown(void)
@@ -1463,8 +1476,8 @@ int starpu_mpi_shutdown(void)
 	int rank, world_size;
 
 	/* We need to get the rank before calling MPI_Finalize to pass to _starpu_mpi_comm_amounts_display() */
-	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-	MPI_Comm_size(MPI_COMM_WORLD, &world_size);
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &world_size);
 
 	/* kill the progression thread */
 	STARPU_PTHREAD_MUTEX_LOCK(&mutex);
@@ -1502,7 +1515,7 @@ void starpu_mpi_data_register(starpu_data_handle_t data_handle, int tag, int ran
 #warning see if the following code is really needed, it deadlocks some applications
 #if 0
 	int my;
-	MPI_Comm_rank(MPI_COMM_WORLD, &my);
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &my);
 	if (my != rank)
 		STARPU_ASSERT_MSG(data_handle->home_node == -1, "Data does not belong to node %d, it should be assigned a home node -1", my);
 #endif
@@ -1512,9 +1525,31 @@ void starpu_mpi_data_register(starpu_data_handle_t data_handle, int tag, int ran
 	_starpu_data_set_unregister_hook(data_handle, _starpu_mpi_clear_cache);
 }
 
+int starpu_mpi_comm_size(MPI_Comm comm, int *size) /* Store the size of communicator comm in *size; returns 0 in simgrid builds, otherwise the MPI_Comm_size() return code. */
+{
+#ifdef STARPU_SIMGRID
+	STARPU_ASSERT_MSG(comm == MPI_COMM_WORLD, "StarPU-SMPI only works with COMM_WORLD for now"); /* only the MPI_COMM_WORLD size is cached, so any other communicator is rejected */
+	*size = _mpi_world_size; /* cached value assigned in _starpu_mpi_progress_thread_func(); no MPI call is made from the calling thread */
+	return 0;
+#else
+	return MPI_Comm_size(comm, size); /* non-simgrid build: forward directly to MPI */
+#endif
+}
+
+int starpu_mpi_comm_rank(MPI_Comm comm, int *rank) /* Store the caller's rank within communicator comm in *rank; returns 0 in simgrid builds, otherwise the MPI_Comm_rank() return code. */
+{
+#ifdef STARPU_SIMGRID
+	STARPU_ASSERT_MSG(comm == MPI_COMM_WORLD, "StarPU-SMPI only works with COMM_WORLD for now"); /* only the MPI_COMM_WORLD rank is cached, so any other communicator is rejected */
+	*rank = _mpi_world_rank; /* cached value assigned in _starpu_mpi_progress_thread_func(); no MPI call is made from the calling thread */
+	return 0;
+#else
+	return MPI_Comm_rank(comm, rank); /* non-simgrid build: forward directly to MPI */
+#endif
+}
+
 int starpu_mpi_world_rank(void)
 {
 	int rank;
-	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
 	return rank;
 }

+ 7 - 7
mpi/src/starpu_mpi_cache.c

@@ -1,7 +1,7 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
  * Copyright (C) 2011, 2012, 2013, 2014  Centre National de la Recherche Scientifique
- * Copyright (C) 2011-2014  Université de Bordeaux
+ * Copyright (C) 2011-2015  Université de Bordeaux
  * Copyright (C) 2014 INRIA
  *
  * StarPU is free software; you can redistribute it and/or modify
@@ -53,7 +53,7 @@ void _starpu_mpi_cache_init(MPI_Comm comm)
 		return;
 	}
 
-	MPI_Comm_size(comm, &nb_nodes);
+	starpu_mpi_comm_size(comm, &nb_nodes);
 	_STARPU_MPI_DEBUG(2, "Initialising htable for cache\n");
 
 	_cache_sent_data = malloc(nb_nodes * sizeof(struct _starpu_data_entry *));
@@ -126,7 +126,7 @@ void _starpu_mpi_cache_free(int world_size)
 void _starpu_mpi_cache_flush_sent(MPI_Comm comm, starpu_data_handle_t data)
 {
 	int n, size;
-	MPI_Comm_size(comm, &size);
+	starpu_mpi_comm_size(comm, &size);
 
 	for(n=0 ; n<size ; n++)
 	{
@@ -172,8 +172,8 @@ void starpu_mpi_cache_flush_all_data(MPI_Comm comm)
 
 	if (_starpu_cache_enabled == 0) return;
 
-	MPI_Comm_size(comm, &nb_nodes);
-	MPI_Comm_rank(comm, &my_rank);
+	starpu_mpi_comm_size(comm, &nb_nodes);
+	starpu_mpi_comm_rank(comm, &my_rank);
 
 	for(i=0 ; i<nb_nodes ; i++)
 	{
@@ -212,8 +212,8 @@ void starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
 
 	if (_starpu_cache_enabled == 0) return;
 
-	MPI_Comm_size(comm, &nb_nodes);
-	MPI_Comm_rank(comm, &my_rank);
+	starpu_mpi_comm_size(comm, &nb_nodes);
+	starpu_mpi_comm_rank(comm, &my_rank);
 	mpi_rank = starpu_data_get_rank(data_handle);
 
 	for(i=0 ; i<nb_nodes ; i++)

+ 1 - 1
mpi/src/starpu_mpi_cache_stats.c

@@ -36,7 +36,7 @@ void _starpu_mpi_cache_stats_init(MPI_Comm comm)
 
 	if (!getenv("STARPU_SILENT")) fprintf(stderr,"Warning: StarPU is executed with STARPU_MPI_CACHE_STATS=1, which slows down a bit\n");
 
-	MPI_Comm_size(comm, &world_size);
+	starpu_mpi_comm_size(comm, &world_size);
 	_STARPU_MPI_DEBUG(1, "allocating for %d nodes\n", world_size);
 
 	comm_cache_amount = (size_t *) calloc(world_size, sizeof(size_t));

+ 2 - 2
mpi/src/starpu_mpi_collective.c

@@ -47,7 +47,7 @@ int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, i
 	void (*callback_func)(void *) = NULL;
 	void (*callback)(void *);
 
-	MPI_Comm_rank(comm, &rank);
+	starpu_mpi_comm_rank(comm, &rank);
 
 	callback = (rank == root) ? scallback : rcallback;
 	if (callback)
@@ -108,7 +108,7 @@ int starpu_mpi_gather_detached(starpu_data_handle_t *data_handles, int count, in
 	void (*callback_func)(void *) = NULL;
 	void (*callback)(void *);
 
-	MPI_Comm_rank(comm, &rank);
+	starpu_mpi_comm_rank(comm, &rank);
 
 	callback = (rank == root) ? scallback : rcallback;
 	if (callback)

+ 5 - 5
mpi/src/starpu_mpi_private.h

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2010, 2012-2014  Université de Bordeaux
+ * Copyright (C) 2010, 2012-2015  Université de Bordeaux
  * Copyright (C) 2010, 2011, 2012, 2013, 2014  Centre National de la Recherche Scientifique
  *
  * StarPU is free software; you can redistribute it and/or modify
@@ -44,7 +44,7 @@ void _starpu_mpi_set_debug_level_max(int level);
 	{								\
 		if (!getenv("STARPU_SILENT") && _starpu_debug_level_min <= level && level <= _starpu_debug_level_max)	\
 		{							\
-			if (_starpu_debug_rank == -1) MPI_Comm_rank(MPI_COMM_WORLD, &_starpu_debug_rank); \
+			if (_starpu_debug_rank == -1) starpu_mpi_comm_rank(MPI_COMM_WORLD, &_starpu_debug_rank); \
 			fprintf(stderr, "%*s[%d][starpu_mpi][%s:%d] " fmt , (_starpu_debug_rank+1)*4, "", _starpu_debug_rank, __starpu_func__ , __LINE__,## __VA_ARGS__); \
 			fflush(stderr); \
 		}			\
@@ -54,17 +54,17 @@ void _starpu_mpi_set_debug_level_max(int level);
 #endif
 
 #define _STARPU_MPI_DISP(fmt, ...) do { if (!getenv("STARPU_SILENT")) { \
-	       				     if (_starpu_debug_rank == -1) MPI_Comm_rank(MPI_COMM_WORLD, &_starpu_debug_rank); \
+	       				     if (_starpu_debug_rank == -1) starpu_mpi_comm_rank(MPI_COMM_WORLD, &_starpu_debug_rank); \
                                              fprintf(stderr, "%*s[%d][starpu_mpi][%s:%d] " fmt , (_starpu_debug_rank+1)*4, "", _starpu_debug_rank, __starpu_func__ , __LINE__ ,## __VA_ARGS__); \
                                              fflush(stderr); }} while(0);
 
 #ifdef STARPU_VERBOSE0
 #  define _STARPU_MPI_LOG_IN()             do { if (!getenv("STARPU_SILENT")) { \
-                                               if (_starpu_debug_rank == -1) MPI_Comm_rank(MPI_COMM_WORLD, &_starpu_debug_rank);                        \
+                                               if (_starpu_debug_rank == -1) starpu_mpi_comm_rank(MPI_COMM_WORLD, &_starpu_debug_rank);                        \
                                                fprintf(stderr, "%*s[%d][starpu_mpi][%s:%d] -->\n", (_starpu_debug_rank+1)*4, "", _starpu_debug_rank, __starpu_func__ , __LINE__); \
                                                fflush(stderr); }} while(0)
 #  define _STARPU_MPI_LOG_OUT()            do { if (!getenv("STARPU_SILENT")) { \
-                                               if (_starpu_debug_rank == -1) MPI_Comm_rank(MPI_COMM_WORLD, &_starpu_debug_rank);                        \
+                                               if (_starpu_debug_rank == -1) starpu_mpi_comm_rank(MPI_COMM_WORLD, &_starpu_debug_rank);                        \
                                                fprintf(stderr, "%*s[%d][starpu_mpi][%s:%d] <--\n", (_starpu_debug_rank+1)*4, "", _starpu_debug_rank, __starpu_func__, __LINE__ ); \
                                                fflush(stderr); }} while(0)
 #else

+ 2 - 2
mpi/src/starpu_mpi_stats.c

@@ -36,7 +36,7 @@ void _starpu_mpi_comm_amounts_init(MPI_Comm comm)
 
 	if (!getenv("STARPU_SILENT")) fprintf(stderr,"Warning: StarPU is executed with STARPU_COMM_STATS=1, which slows down a bit\n");
 
-	MPI_Comm_size(comm, &world_size);
+	starpu_mpi_comm_size(comm, &world_size);
 	_STARPU_MPI_DEBUG(1, "allocating for %d nodes\n", world_size);
 
 	comm_amount = (size_t *) calloc(world_size, sizeof(size_t));
@@ -54,7 +54,7 @@ void _starpu_mpi_comm_amounts_inc(MPI_Comm comm, unsigned dst, MPI_Datatype data
 
 	if (stats_enabled == 0) return;
 
-	MPI_Comm_rank(comm, &src);
+	starpu_mpi_comm_rank(comm, &src);
 	MPI_Type_size(datatype, &size);
 
 	_STARPU_MPI_DEBUG(1, "[%d] adding %d to %d\n", src, count*size, dst);

+ 10 - 10
mpi/src/starpu_mpi_task_insert.c

@@ -1,7 +1,7 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
  * Copyright (C) 2011, 2012, 2013, 2014  Centre National de la Recherche Scientifique
- * Copyright (C) 2011-2014  Université de Bordeaux
+ * Copyright (C) 2011-2015  Université de Bordeaux
  * Copyright (C) 2014 INRIA
  *
  * StarPU is free software; you can redistribute it and/or modify
@@ -433,8 +433,8 @@ int _starpu_mpi_task_build_v(MPI_Comm comm, struct starpu_codelet *codelet, stru
 
 	_STARPU_MPI_LOG_IN();
 
-	MPI_Comm_rank(comm, &me);
-	MPI_Comm_size(comm, &nb_nodes);
+	starpu_mpi_comm_rank(comm, &me);
+	starpu_mpi_comm_size(comm, &nb_nodes);
 
 	/* Find out whether we are to execute the data because we own the data to be written to. */
 	ret = _starpu_mpi_task_decode_v(codelet, me, nb_nodes, &xrank, &do_execute, &descrs, &nb_data, varg_list);
@@ -470,7 +470,7 @@ int _starpu_mpi_task_postbuild_v(MPI_Comm comm, int xrank, int do_execute, struc
 {
 	int me, i;
 
-	MPI_Comm_rank(comm, &me);
+	starpu_mpi_comm_rank(comm, &me);
 
 	for(i=0 ; i<nb_data ; i++)
 	{
@@ -558,8 +558,8 @@ int starpu_mpi_task_post_build(MPI_Comm comm, struct starpu_codelet *codelet, ..
 	struct starpu_data_descr *descrs;
 	int nb_data;
 
-	MPI_Comm_rank(comm, &me);
-	MPI_Comm_size(comm, &nb_nodes);
+	starpu_mpi_comm_rank(comm, &me);
+	starpu_mpi_comm_size(comm, &nb_nodes);
 
 	va_start(varg_list, codelet);
 	/* Find out whether we are to execute the data because we own the data to be written to. */
@@ -584,7 +584,7 @@ void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t da
 	{
 		_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register() or starpu_data_set_tag()\n");
 	}
-	MPI_Comm_rank(comm, &me);
+	starpu_mpi_comm_rank(comm, &me);
 
 	if (node == rank) return;
 
@@ -614,7 +614,7 @@ void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle
 		fprintf(stderr,"StarPU needs to be told the MPI tag of this data, using starpu_data_set_tag\n");
 		STARPU_ABORT();
 	}
-	MPI_Comm_rank(comm, &me);
+	starpu_mpi_comm_rank(comm, &me);
 
 	if (node == rank) return;
 
@@ -705,8 +705,8 @@ void starpu_mpi_redux_data(MPI_Comm comm, starpu_data_handle_t data_handle)
 		STARPU_ABORT();
 	}
 
-	MPI_Comm_rank(comm, &me);
-	MPI_Comm_size(comm, &nb_nodes);
+	starpu_mpi_comm_rank(comm, &me);
+	starpu_mpi_comm_size(comm, &nb_nodes);
 
 	_STARPU_MPI_DEBUG(1, "Doing reduction for data %p on node %d with %d nodes ...\n", data_handle, rank, nb_nodes);
 

+ 1 - 0
src/common/thread.c

@@ -49,6 +49,7 @@ int starpu_pthread_create_on(char *name, starpu_pthread_t *thread, const starpu_
 	_args->arg = arg;
 	if (!host)
 		host = MSG_get_host_by_name("MAIN");
+	fprintf(stderr,"starting %p on %s\n", start_routine, MSG_host_get_name(host));
 	*thread = MSG_process_create(name, _starpu_simgrid_thread_start, _args, host);
 	return 0;
 }

+ 1 - 0
src/core/workers.c

@@ -1095,6 +1095,7 @@ int starpu_initialize(struct starpu_conf *user_conf, int *argc, char ***argv)
 #endif
 
 	STARPU_PTHREAD_MUTEX_LOCK(&init_mutex);
+	fprintf(stderr,"initialized is %d in %d\n", initialized, starpu_mpi_world_rank());
 	while (initialized == CHANGING)
 		/* Wait for the other one changing it */
 		STARPU_PTHREAD_COND_WAIT(&init_cond, &init_mutex);