Browse Source

mpi/tests/: update code so that it runs with simgrid

Nathalie Furmento 8 years ago
parent
commit
90eacff9dd

+ 6 - 4
mpi/tests/block_interface.c

@@ -27,16 +27,18 @@
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/block_interface_pinned.c

@@ -27,16 +27,18 @@
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/datatypes.c

@@ -332,16 +332,18 @@ int main(int argc, char **argv)
 {
 	int ret, rank, size;
 	int error=0;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/early_request.c

@@ -191,16 +191,18 @@ int main(int argc, char * argv[])
 	/* Init */
 	int ret;
 	int mpi_rank, mpi_size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &mpi_rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &mpi_size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &mpi_rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &mpi_size);
+
 	if (starpu_cpu_worker_get_count() == 0)
 	{
 		if (mpi_rank == 0)

+ 6 - 4
mpi/tests/gather.c

@@ -22,16 +22,18 @@ int main(int argc, char **argv)
 	int ret, rank, size;
 	starpu_data_handle_t handle;
 	int var;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size<3)
 	{
 		FPRINTF(stderr, "We need more than 2 processes.\n");

+ 6 - 4
mpi/tests/gather2.c

@@ -20,16 +20,18 @@
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size<3)
 	{
 		FPRINTF(stderr, "We need more than 2 processes.\n");

+ 13 - 4
mpi/tests/helper.h

@@ -16,16 +16,17 @@
 
 #include <errno.h>
 #include <starpu_mpi.h>
+#include <starpu_config.h>
 #include "../../tests/helper.h"
 
 #define FPRINTF(ofile, fmt, ...) do { if (!getenv("STARPU_SSILENT")) {fprintf(ofile, fmt, ## __VA_ARGS__); }} while(0)
 #define FPRINTF_MPI(ofile, fmt, ...) do { if (!getenv("STARPU_SSILENT")) { \
 			int _disp_rank; starpu_mpi_comm_rank(MPI_COMM_WORLD, &_disp_rank); \
 			fprintf(ofile, "[%d][starpu_mpi][%s] " fmt , _disp_rank, __starpu_func__ ,## __VA_ARGS__); \
-			fflush(ofile); }} while(0);
+			fflush(ofile); }} while(0)
 
-#define MPI_INIT_THREAD(argc, argv, required) do {	    \
-		int thread_support;					\
+#define MPI_INIT_THREAD_real(argc, argv, required) do {	\
+		int thread_support;				\
 		if (MPI_Init_thread(argc, argv, required, &thread_support) != MPI_SUCCESS) \
 		{						\
 			fprintf(stderr,"MPI_Init_thread failed\n");	\
@@ -34,5 +35,13 @@
 		if (thread_support == MPI_THREAD_FUNNELED)		\
 			fprintf(stderr,"Warning: MPI only has funneled thread support, not serialized, hoping this will work\n"); \
 		if (thread_support < MPI_THREAD_FUNNELED)		\
-			fprintf(stderr,"Warning: MPI does not have thread support!\n"); } while(0);
+			fprintf(stderr,"Warning: MPI does not have thread support!\n"); } while(0)
+
+#ifdef STARPU_SIMGRID
+#define MPI_INIT_THREAD(argc, argv, required, init) do { *(init) = 1 ; } while(0)
+#else
+#define MPI_INIT_THREAD(argc, argv, required, init) do {	\
+		*(init) = 0;                                    \
+		MPI_INIT_THREAD_real(argc, argv, required); } while(0)
+#endif
 

+ 1 - 1
mpi/tests/insert_task_compute.c

@@ -226,7 +226,7 @@ int main(int argc, char **argv)
 	int after_node[2][4] = {{220, 20, 11, 22}, {220, 20, 11, 22}};
 	int node, insert_task, data_array;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
+	MPI_INIT_THREAD_real(&argc, &argv, MPI_THREAD_SERIALIZED);
 	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
 
 	global_ret = 0;

+ 6 - 4
mpi/tests/insert_task_count.c

@@ -63,16 +63,18 @@ int main(int argc, char **argv)
 	int ret, rank, size;
 	int token = 0;
 	starpu_data_handle_t token_handle;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2 || (starpu_cpu_worker_get_count() + starpu_cuda_worker_get_count() == 0))
 	{
 		if (rank == 0)

+ 5 - 3
mpi/tests/insert_task_dyn_handles.c

@@ -73,15 +73,17 @@ int main(int argc, char **argv)
         starpu_data_handle_t *data_handles;
         starpu_data_handle_t factor_handle;
 	struct starpu_data_descr *descrs;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+
 	if (starpu_cpu_worker_get_count() == 0)
 	{
 		if (rank == 0)

+ 1 - 1
mpi/tests/insert_task_recv_cache.c

@@ -137,7 +137,7 @@ int main(int argc, char **argv)
 	size_t *comm_amount_with_cache;
 	size_t *comm_amount_without_cache;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
+	MPI_INIT_THREAD_real(&argc, &argv, MPI_THREAD_SERIALIZED);
 	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
 	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
 

+ 1 - 1
mpi/tests/insert_task_sent_cache.c

@@ -143,7 +143,7 @@ int main(int argc, char **argv)
 	size_t *comm_amount_with_cache;
 	size_t *comm_amount_without_cache;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
+	MPI_INIT_THREAD_real(&argc, &argv, MPI_THREAD_SERIALIZED);
 	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
 	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
 

+ 3 - 2
mpi/tests/load_balancer.c

@@ -46,14 +46,15 @@ int main(int argc, char **argv)
 {
 	int ret;
 	struct starpu_mpi_lb_conf itf;
+	int mpi_init;
 
 	itf.get_neighbors = get_neighbors;
 	itf.get_data_unit_to_migrate = get_data_unit_to_migrate;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	unsetenv("STARPU_MPI_LB");

+ 6 - 4
mpi/tests/matrix2.c

@@ -58,16 +58,18 @@ int main(int argc, char **argv)
 	unsigned X[N];
 	starpu_data_handle_t data_A[N];
 	starpu_data_handle_t data_X[N];
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if ((size < 3) || (starpu_cpu_worker_get_count() == 0))
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/mpi_detached_tag.c

@@ -33,16 +33,18 @@ starpu_data_handle_t tab_handle;
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/mpi_earlyrecv.c

@@ -25,16 +25,18 @@ int main(int argc, char **argv)
 	starpu_data_handle_t tab_handle[4];
 	int values[4];
 	starpu_mpi_req request[2] = {NULL, NULL};
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		FPRINTF_MPI(stderr, "We need a even number of processes.\n");

+ 6 - 4
mpi/tests/mpi_earlyrecv2.c

@@ -207,16 +207,18 @@ int main(int argc, char **argv)
 {
 	int ret=0, global_ret=0;
 	int rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		FPRINTF(stderr, "We need a even number of processes.\n");

+ 1 - 1
mpi/tests/mpi_earlyrecv2_sync.c

@@ -211,7 +211,7 @@ int main(int argc, char **argv)
 	int ret=0, global_ret=0;
 	int rank, size;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
+	MPI_INIT_THREAD_real(&argc, &argv, MPI_THREAD_SERIALIZED);
 	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
 	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
 

+ 6 - 4
mpi/tests/mpi_irecv.c

@@ -31,16 +31,18 @@ starpu_data_handle_t tab_handle;
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/mpi_irecv_detached.c

@@ -48,16 +48,18 @@ void callback(void *arg STARPU_ATTRIBUTE_UNUSED)
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/mpi_isend.c

@@ -31,16 +31,18 @@ starpu_data_handle_t tab_handle;
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/mpi_isend_detached.c

@@ -47,16 +47,18 @@ int main(int argc, char **argv)
 	int ret, rank, size;
 	float *tab;
 	starpu_data_handle_t tab_handle;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		if (rank == 0)

+ 8 - 6
mpi/tests/mpi_redux.c

@@ -36,18 +36,20 @@ int main(int argc, char **argv)
 	int ret, rank, size, sum;
 	int value=0;
 	starpu_data_handle_t *handles;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
-
-	sum = ((size-1) * (size) / 2);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
+	sum = ((size-1) * (size) / 2);
+
 	if (rank == 0)
 	{
 		int src;

+ 6 - 4
mpi/tests/mpi_test.c

@@ -31,16 +31,18 @@ int main(int argc, char **argv)
 	int ret, rank, size;
 	float *tab;
 	starpu_data_handle_t tab_handle;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/pingpong.c

@@ -32,16 +32,18 @@ starpu_data_handle_t tab_handle;
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size%2 != 0)
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/policy_register.c

@@ -69,16 +69,18 @@ int main(int argc, char **argv)
 	int policy;
 	struct starpu_task *task;
 	starpu_data_handle_t handles[2];
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2)
 	{
 		if (rank == 0)

+ 5 - 3
mpi/tests/policy_selection.c

@@ -56,16 +56,18 @@ int main(int argc, char **argv)
 	int policy = 12;
 	struct starpu_task *task;
 	starpu_data_handle_t handles[3];
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
 	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 3)
 	{
 		if (rank == 0)

+ 5 - 3
mpi/tests/policy_selection2.c

@@ -54,16 +54,18 @@ int main(int argc, char **argv)
 	int rank, size;
 	int data[3];
 	starpu_data_handle_t handles[3];
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
 	ret = starpu_mpi_init(NULL, NULL, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if ((size < 3) || (starpu_cpu_worker_get_count() == 0))
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/ring.c

@@ -76,16 +76,18 @@ void increment_token(void)
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2 || (starpu_cpu_worker_get_count() + starpu_cuda_worker_get_count() == 0))
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/ring_async.c

@@ -76,16 +76,18 @@ void increment_token(void)
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2 || (starpu_cpu_worker_get_count() + starpu_cuda_worker_get_count() == 0))
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/ring_sync.c

@@ -76,16 +76,18 @@ void increment_token(void)
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2 || (starpu_cpu_worker_get_count() + starpu_cuda_worker_get_count() == 0))
 	{
 		if (rank == 0)

+ 6 - 4
mpi/tests/ring_sync_detached.c

@@ -88,16 +88,18 @@ int main(int argc, char **argv)
 	int ret, rank, size;
 	int token = 42;
 	starpu_data_handle_t token_handle;
+	int mpi_init;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+
 	if (size < 2 || (starpu_cpu_worker_get_count() + starpu_cuda_worker_get_count() == 0))
 	{
 		if (rank == 0)

+ 3 - 2
mpi/tests/starpu_redefine.c

@@ -21,14 +21,15 @@ int main(int argc, char **argv)
 {
 	int ret;
 	starpu_data_handle_t handle;
+	int mpi_init;
 
 	disable_coredump();
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
+	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
 	ret = starpu_init(NULL);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
-	ret = starpu_mpi_init(NULL, NULL, 0);
+	ret = starpu_mpi_init(NULL, NULL, mpi_init);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init");
 
 	starpu_vector_data_register(&handle, STARPU_MAIN_RAM, (uintptr_t)&ret, 1, sizeof(int));

+ 3 - 3
mpi/tests/sync.c

@@ -24,9 +24,9 @@ int main(int argc, char **argv)
 	int ret;
 	starpu_data_handle_t data[2];
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
-        starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
-        starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
+	MPI_INIT_THREAD_real(&argc, &argv, MPI_THREAD_SERIALIZED);
+	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+        MPI_Comm_size(MPI_COMM_WORLD, &size);
 
         if (size % 2)
         {

+ 1 - 1
mpi/tests/tags_checking.c

@@ -124,7 +124,7 @@ int main(int argc, char **argv)
 	int ret=0;
 	int sdetached, rdetached;
 
-	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED);
+	MPI_INIT_THREAD_real(&argc, &argv, MPI_THREAD_SERIALIZED);
         starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
         starpu_mpi_comm_size(MPI_COMM_WORLD, &size);