Prechádzať zdrojové kódy

mpi: add a parameter initialize_mpi to starpu_mpi_init, then we no longer rely on MPI_Initialized() to detect if MPI is already initialised or not, we just ask the upper layer to tell us if it needs to be initialised

Nathalie Furmento 12 rokov pred
rodič
commit
cc29bc43df

+ 1 - 1
mpi/examples/Makefile.am

@@ -37,7 +37,7 @@ if STARPU_MPI_CHECK
 TESTS			=	$(starpu_mpi_EXAMPLES)
 endif
 
-check_PROGRAMS = $(LOADER)
+check_PROGRAMS = $(LOADER) $(starpu_mpi_EXAMPLES)
 starpu_mpi_EXAMPLES =
 
 BUILT_SOURCES =

+ 1 - 1
mpi/examples/cholesky/mpi_cholesky.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
 
-	starpu_mpi_init(&argc, &argv);
+	starpu_mpi_init(&argc, &argv, 1);
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);
 

+ 1 - 1
mpi/examples/cholesky/mpi_cholesky_distributed.c

@@ -42,7 +42,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);

+ 1 - 1
mpi/examples/complex/mpi_complex.c

@@ -39,7 +39,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);

+ 1 - 1
mpi/examples/mpi_lu/plu_example.c

@@ -428,7 +428,7 @@ int main(int argc, char **argv)
 	/* We disable sequential consistency in this example */
 	starpu_data_set_default_sequential_consistency_flag(0);
 
-	starpu_mpi_init(NULL, NULL);
+	starpu_mpi_init(NULL, NULL, 0);
 
 	STARPU_ASSERT(p*q == world_size);
 

+ 1 - 1
mpi/examples/stencil/stencil5.c

@@ -82,7 +82,7 @@ int main(int argc, char **argv)
 
 	int ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	starpu_mpi_init(&argc, &argv);
+	starpu_mpi_init(&argc, &argv, 1);
 	MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);
 

+ 1 - 1
mpi/include/starpu_mpi.h

@@ -39,8 +39,8 @@ int starpu_mpi_irecv_detached(starpu_data_handle_t data_handle, int source, int
 int starpu_mpi_wait(starpu_mpi_req *req, MPI_Status *status);
 int starpu_mpi_test(starpu_mpi_req *req, int *flag, MPI_Status *status);
 int starpu_mpi_barrier(MPI_Comm comm);
-int starpu_mpi_init(int *argc, char ***argv);
 
+int starpu_mpi_init(int *argc, char ***argv, int initialize_mpi);
 int starpu_mpi_initialize(void) STARPU_DEPRECATED;
 int starpu_mpi_initialize_extended(int *rank, int *world_size) STARPU_DEPRECATED;
 int starpu_mpi_shutdown(void);

+ 9 - 10
mpi/src/starpu_mpi.c

@@ -705,6 +705,7 @@ static void _starpu_mpi_handle_new_request(struct _starpu_mpi_req *req)
 
 struct _starpu_mpi_argc_argv
 {
+	int initialize_mpi;
 	int *argc;
 	char ***argv;
 };
@@ -734,11 +735,8 @@ static void _starpu_mpi_print_thread_level_support(int thread_level, char *msg)
 static void *_starpu_mpi_progress_thread_func(void *arg)
 {
 	struct _starpu_mpi_argc_argv *argc_argv = (struct _starpu_mpi_argc_argv *) arg;
-	int flag;
 
-	MPI_Initialized(&flag);
-	_STARPU_DEBUG("MPI_Initialized %d\n", flag);
-	if (flag == 0)
+	if (argc_argv->initialize_mpi)
 	{
 		int thread_support;
                 _STARPU_DEBUG("Calling MPI_Init_thread\n");
@@ -805,7 +803,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 	STARPU_ASSERT(_starpu_mpi_req_list_empty(new_requests));
         STARPU_ASSERT(posted_requests == 0);
 
-        if (flag == 0)
+	if (argc_argv->initialize_mpi)
 	{
                 _STARPU_MPI_DEBUG("Calling MPI_Finalize()\n");
                 MPI_Finalize();
@@ -859,7 +857,7 @@ static void _starpu_mpi_add_sync_point_in_fxt(void)
 }
 
 static
-int _starpu_mpi_initialize(int *argc, char ***argv)
+int _starpu_mpi_initialize(int *argc, char ***argv, int initialize_mpi)
 {
 	_STARPU_PTHREAD_MUTEX_INIT(&mutex, NULL);
 	_STARPU_PTHREAD_COND_INIT(&cond_progression, NULL);
@@ -872,6 +870,7 @@ int _starpu_mpi_initialize(int *argc, char ***argv)
         _STARPU_PTHREAD_MUTEX_INIT(&mutex_posted_requests, NULL);
 
 	struct _starpu_mpi_argc_argv *argc_argv = malloc(sizeof(struct _starpu_mpi_argc_argv));
+	argc_argv->initialize_mpi = initialize_mpi;
 	argc_argv->argc = argc;
 	argc_argv->argv = argv;
 	_STARPU_PTHREAD_CREATE("MPI progress", &progress_thread, NULL, _starpu_mpi_progress_thread_func, argc_argv);
@@ -898,21 +897,21 @@ int _starpu_mpi_initialize(int *argc, char ***argv)
 	return 0;
 }
 
-int starpu_mpi_init(int *argc, char ***argv)
+int starpu_mpi_init(int *argc, char ***argv, int initialize_mpi)
 {
-        return _starpu_mpi_initialize(argc, argv);
+        return _starpu_mpi_initialize(argc, argv, initialize_mpi);
 }
 
 int starpu_mpi_initialize(void)
 {
-        return _starpu_mpi_initialize(NULL, NULL);
+	return _starpu_mpi_initialize(NULL, NULL, 0);
 }
 
 int starpu_mpi_initialize_extended(int *rank, int *world_size)
 {
 	int ret;
 
-        ret = _starpu_mpi_initialize(NULL, NULL);
+        ret = _starpu_mpi_initialize(NULL, NULL, 1);
 	if (ret == 0)
 	{
 		_STARPU_DEBUG("Calling MPI_Comm_rank\n");

+ 1 - 1
mpi/tests/Makefile.am

@@ -37,7 +37,7 @@ if STARPU_MPI_CHECK
 TESTS			=	$(starpu_mpi_TESTS)
 endif
 
-check_PROGRAMS = $(LOADER)
+check_PROGRAMS = $(LOADER) $(starpu_mpi_TESTS)
 
 BUILT_SOURCES =
 

+ 1 - 1
mpi/tests/block_interface.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	/* Node 0 will allocate a big block and only register an inner part of

+ 1 - 1
mpi/tests/block_interface_pinned.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	/* Node 0 will allocate a big block and only register an inner part of

+ 1 - 1
mpi/tests/insert_task.c

@@ -54,7 +54,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/insert_task_block.c

@@ -71,7 +71,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_initialize_extended");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/insert_task_cache.c

@@ -61,7 +61,7 @@ void test_cache(int rank, int size, int enabled, size_t *comm_amount)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
         for(i = 0; i < 2; i++)

+ 1 - 1
mpi/tests/insert_task_owner.c

@@ -79,7 +79,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/insert_task_owner2.c

@@ -55,7 +55,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/insert_task_owner_data.c

@@ -45,7 +45,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/mpi_detached_tag.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_irecv.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_irecv_detached.c

@@ -58,7 +58,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_isend.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_isend_detached.c

@@ -57,7 +57,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_reduction.c

@@ -74,7 +74,7 @@ int main(int argc, char **argv)
 
 	int ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/mpi_scatter_gather.c

@@ -81,7 +81,7 @@ int main(int argc, char **argv)
 
 	int ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);

+ 1 - 1
mpi/tests/mpi_test.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/multiple_send.c

@@ -30,7 +30,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/pingpong.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/ring.c

@@ -79,7 +79,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	starpu_vector_data_register(&token_handle, 0, (uintptr_t)&token, 1, sizeof(unsigned));

+ 1 - 1
mpi/tests/ring_async.c

@@ -79,7 +79,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	starpu_vector_data_register(&token_handle, 0, (uintptr_t)&token, 1, sizeof(unsigned));

+ 1 - 1
mpi/tests/ring_async_implicit.c

@@ -65,7 +65,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/user_defined_datatype.c

@@ -54,7 +54,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);