Explorar o código

mpi: add an initialize_mpi parameter to starpu_mpi_init. We no longer rely on MPI_Initialized() to detect whether MPI is already initialised; instead, the upper layer tells us whether MPI needs to be initialised.

Nathalie Furmento hai 12 anos
pai
achega
cc29bc43df

+ 1 - 1
mpi/examples/Makefile.am

@@ -37,7 +37,7 @@ if STARPU_MPI_CHECK
 TESTS			=	$(starpu_mpi_EXAMPLES)
 endif
 
-check_PROGRAMS = $(LOADER)
+check_PROGRAMS = $(LOADER) $(starpu_mpi_EXAMPLES)
 starpu_mpi_EXAMPLES =
 
 BUILT_SOURCES =

+ 1 - 1
mpi/examples/cholesky/mpi_cholesky.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
 
-	starpu_mpi_init(&argc, &argv);
+	starpu_mpi_init(&argc, &argv, 1);
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);
 

+ 1 - 1
mpi/examples/cholesky/mpi_cholesky_distributed.c

@@ -42,7 +42,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);

+ 1 - 1
mpi/examples/complex/mpi_complex.c

@@ -39,7 +39,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);

+ 1 - 1
mpi/examples/mpi_lu/plu_example.c

@@ -428,7 +428,7 @@ int main(int argc, char **argv)
 	/* We disable sequential consistency in this example */
 	starpu_data_set_default_sequential_consistency_flag(0);
 
-	starpu_mpi_init(NULL, NULL);
+	starpu_mpi_init(NULL, NULL, 0);
 
 	STARPU_ASSERT(p*q == world_size);
 

+ 1 - 1
mpi/examples/stencil/stencil5.c

@@ -82,7 +82,7 @@ int main(int argc, char **argv)
 
 	int ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	starpu_mpi_init(&argc, &argv);
+	starpu_mpi_init(&argc, &argv, 1);
 	MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);
 

+ 1 - 1
mpi/include/starpu_mpi.h

@@ -39,8 +39,8 @@ int starpu_mpi_irecv_detached(starpu_data_handle_t data_handle, int source, int
 int starpu_mpi_wait(starpu_mpi_req *req, MPI_Status *status);
 int starpu_mpi_test(starpu_mpi_req *req, int *flag, MPI_Status *status);
 int starpu_mpi_barrier(MPI_Comm comm);
-int starpu_mpi_init(int *argc, char ***argv);
 
+int starpu_mpi_init(int *argc, char ***argv, int initialize_mpi);
 int starpu_mpi_initialize(void) STARPU_DEPRECATED;
 int starpu_mpi_initialize_extended(int *rank, int *world_size) STARPU_DEPRECATED;
 int starpu_mpi_shutdown(void);

+ 9 - 10
mpi/src/starpu_mpi.c

@@ -705,6 +705,7 @@ static void _starpu_mpi_handle_new_request(struct _starpu_mpi_req *req)
 
 struct _starpu_mpi_argc_argv
 {
+	int initialize_mpi;
 	int *argc;
 	char ***argv;
 };
@@ -734,11 +735,8 @@ static void _starpu_mpi_print_thread_level_support(int thread_level, char *msg)
 static void *_starpu_mpi_progress_thread_func(void *arg)
 {
 	struct _starpu_mpi_argc_argv *argc_argv = (struct _starpu_mpi_argc_argv *) arg;
-	int flag;
 
-	MPI_Initialized(&flag);
-	_STARPU_DEBUG("MPI_Initialized %d\n", flag);
-	if (flag == 0)
+	if (argc_argv->initialize_mpi)
 	{
 		int thread_support;
                 _STARPU_DEBUG("Calling MPI_Init_thread\n");
@@ -805,7 +803,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 	STARPU_ASSERT(_starpu_mpi_req_list_empty(new_requests));
         STARPU_ASSERT(posted_requests == 0);
 
-        if (flag == 0)
+	if (argc_argv->initialize_mpi)
 	{
                 _STARPU_MPI_DEBUG("Calling MPI_Finalize()\n");
                 MPI_Finalize();
@@ -859,7 +857,7 @@ static void _starpu_mpi_add_sync_point_in_fxt(void)
 }
 
 static
-int _starpu_mpi_initialize(int *argc, char ***argv)
+int _starpu_mpi_initialize(int *argc, char ***argv, int initialize_mpi)
 {
 	_STARPU_PTHREAD_MUTEX_INIT(&mutex, NULL);
 	_STARPU_PTHREAD_COND_INIT(&cond_progression, NULL);
@@ -872,6 +870,7 @@ int _starpu_mpi_initialize(int *argc, char ***argv)
         _STARPU_PTHREAD_MUTEX_INIT(&mutex_posted_requests, NULL);
 
 	struct _starpu_mpi_argc_argv *argc_argv = malloc(sizeof(struct _starpu_mpi_argc_argv));
+	argc_argv->initialize_mpi = initialize_mpi;
 	argc_argv->argc = argc;
 	argc_argv->argv = argv;
 	_STARPU_PTHREAD_CREATE("MPI progress", &progress_thread, NULL, _starpu_mpi_progress_thread_func, argc_argv);
@@ -898,21 +897,21 @@ int _starpu_mpi_initialize(int *argc, char ***argv)
 	return 0;
 }
 
-int starpu_mpi_init(int *argc, char ***argv)
+int starpu_mpi_init(int *argc, char ***argv, int initialize_mpi)
 {
-        return _starpu_mpi_initialize(argc, argv);
+        return _starpu_mpi_initialize(argc, argv, initialize_mpi);
 }
 
 int starpu_mpi_initialize(void)
 {
-        return _starpu_mpi_initialize(NULL, NULL);
+	return _starpu_mpi_initialize(NULL, NULL, 0);
 }
 
 int starpu_mpi_initialize_extended(int *rank, int *world_size)
 {
 	int ret;
 
-        ret = _starpu_mpi_initialize(NULL, NULL);
+        ret = _starpu_mpi_initialize(NULL, NULL, 1);
 	if (ret == 0)
 	{
 		_STARPU_DEBUG("Calling MPI_Comm_rank\n");

+ 1 - 1
mpi/tests/Makefile.am

@@ -37,7 +37,7 @@ if STARPU_MPI_CHECK
 TESTS			=	$(starpu_mpi_TESTS)
 endif
 
-check_PROGRAMS = $(LOADER)
+check_PROGRAMS = $(LOADER) $(starpu_mpi_TESTS)
 
 BUILT_SOURCES =
 

+ 1 - 1
mpi/tests/block_interface.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	/* Node 0 will allocate a big block and only register an inner part of

+ 1 - 1
mpi/tests/block_interface_pinned.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	/* Node 0 will allocate a big block and only register an inner part of

+ 1 - 1
mpi/tests/insert_task.c

@@ -54,7 +54,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/insert_task_block.c

@@ -71,7 +71,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_initialize_extended");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/insert_task_cache.c

@@ -61,7 +61,7 @@ void test_cache(int rank, int size, int enabled, size_t *comm_amount)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
         for(i = 0; i < 2; i++)

+ 1 - 1
mpi/tests/insert_task_owner.c

@@ -79,7 +79,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/insert_task_owner2.c

@@ -55,7 +55,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/insert_task_owner_data.c

@@ -45,7 +45,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/mpi_detached_tag.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_irecv.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_irecv_detached.c

@@ -58,7 +58,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_isend.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_isend_detached.c

@@ -57,7 +57,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/mpi_reduction.c

@@ -74,7 +74,7 @@ int main(int argc, char **argv)
 
 	int ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/mpi_scatter_gather.c

@@ -81,7 +81,7 @@ int main(int argc, char **argv)
 
 	int ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);

+ 1 - 1
mpi/tests/mpi_test.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/multiple_send.c

@@ -30,7 +30,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/pingpong.c

@@ -43,7 +43,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	tab = malloc(SIZE*sizeof(float));

+ 1 - 1
mpi/tests/ring.c

@@ -79,7 +79,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	starpu_vector_data_register(&token_handle, 0, (uintptr_t)&token, 1, sizeof(unsigned));

+ 1 - 1
mpi/tests/ring_async.c

@@ -79,7 +79,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	starpu_vector_data_register(&token_handle, 0, (uintptr_t)&token, 1, sizeof(unsigned));

+ 1 - 1
mpi/tests/ring_async_implicit.c

@@ -65,7 +65,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL);
+	ret = starpu_mpi_init(NULL, NULL, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &size);

+ 1 - 1
mpi/tests/user_defined_datatype.c

@@ -54,7 +54,7 @@ int main(int argc, char **argv)
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(&argc, &argv);
+	ret = starpu_mpi_init(&argc, &argv, 1);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
 	MPI_Comm_size(MPI_COMM_WORLD, &nodes);