
mpi/tests/burst.c: the MPI backend requires calling starpu_mpi_wait to properly release data from a non-blocking communication

Nathalie Furmento, 5 years ago
commit 4d39bf6c8d
1 changed file with 157 additions and 115 deletions
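The commit message refers to the rule that a non-blocking StarPU-MPI communication must be completed with starpu_mpi_wait() before the backend releases the data handle. Below is a minimal sketch of that pattern, not taken from the commit: it registers a vector, posts a non-blocking starpu_mpi_isend()/starpu_mpi_irecv() pair between two ranks, and then waits on the request. The buffer size NX, the tag value 0 and the overall structure are illustrative assumptions, a two-rank run is assumed, and error handling is omitted.

/* Minimal sketch (not part of the commit): a non-blocking StarPU-MPI
 * transfer completed with starpu_mpi_wait() so that the backend can
 * release the data handle. Assumes a two-rank run; NX and the tag are
 * illustrative and error handling is omitted. */
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <starpu.h>
#include <starpu_mpi.h>

#define NX 1024

int main(int argc, char **argv)
{
	int rank;
	float *buffer;
	starpu_data_handle_t handle;
	starpu_mpi_req req;

	starpu_init(NULL);
	starpu_mpi_init(&argc, &argv, 1);
	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);

	/* Register a plain vector so StarPU-MPI can transfer it. */
	buffer = malloc(NX * sizeof(float));
	memset(buffer, 0, NX * sizeof(float));
	starpu_vector_data_register(&handle, STARPU_MAIN_RAM, (uintptr_t) buffer, NX, sizeof(float));

	if (rank == 0)
		starpu_mpi_isend(handle, &req, 1, 0, MPI_COMM_WORLD);
	else if (rank == 1)
		starpu_mpi_irecv(handle, &req, 0, 0, MPI_COMM_WORLD);

	if (rank == 0 || rank == 1)
		/* Waiting on the request completes the non-blocking
		 * communication and releases the data handle. */
		starpu_mpi_wait(&req, MPI_STATUS_IGNORE);

	starpu_data_unregister(handle);
	free(buffer);

	starpu_mpi_shutdown();
	starpu_shutdown();
	return 0;
}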

mpi/tests/burst.c  +157 −115

@@ -23,7 +23,7 @@
 #include <starpu_mpi.h>
 #include "helper.h"
 
-#ifdef STARPU_SIMGRID
+#if defined(STARPU_SIMGRID) || defined(STARPU_QUICK_CHECK)
 #define NB_REQUESTS 10
 #else
 #define NB_REQUESTS 500
@@ -52,7 +52,6 @@ int main(int argc, char **argv)
 	float* send_buffers[NB_REQUESTS];
 	starpu_mpi_req recv_reqs[NB_REQUESTS];
 	starpu_mpi_req send_reqs[NB_REQUESTS];
-	MPI_Status status;
 
 	MPI_INIT_THREAD(&argc, &argv, MPI_THREAD_SERIALIZED, &mpi_init);
 
@@ -62,164 +61,207 @@ int main(int argc, char **argv)
 	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
 	starpu_mpi_comm_size(MPI_COMM_WORLD, &size);
 
-	if (rank > 1)
-	{
-		starpu_mpi_barrier(MPI_COMM_WORLD);
-		starpu_mpi_wait_for_all(MPI_COMM_WORLD);
-
-		starpu_mpi_barrier(MPI_COMM_WORLD);
-		starpu_mpi_wait_for_all(MPI_COMM_WORLD);
-
-		starpu_mpi_barrier(MPI_COMM_WORLD);
-		starpu_mpi_wait_for_all(MPI_COMM_WORLD);
-
-		starpu_mpi_barrier(MPI_COMM_WORLD);
-		starpu_mpi_wait_for_all(MPI_COMM_WORLD);
-
-		starpu_mpi_shutdown();
-		if (!mpi_init)
-			MPI_Finalize();
-
-		return 0;
-	}
-
 	other_rank = (rank == 0) ? 1 : 0;
 
-
-	/* Burst simultaneous from both nodes */
-	if (rank == 0)
+	if (rank == 0 || rank == 1)
 	{
-		printf("Simultaneous....\n");
-	}
-
-	for (int i = 0; i < NB_REQUESTS; i++)
-	{
-		send_buffers[i] = malloc(NX_ARRAY * sizeof(float));
-		memset(send_buffers[i], 0, NX_ARRAY * sizeof(float));
-		starpu_vector_data_register(&send_handles[i], STARPU_MAIN_RAM, (uintptr_t) send_buffers[i], NX_ARRAY, sizeof(float));
-
-		recv_buffers[i] = malloc(NX_ARRAY * sizeof(float));
-		memset(recv_buffers[i], 0, NX_ARRAY * sizeof(float));
-		starpu_vector_data_register(&recv_handles[i], STARPU_MAIN_RAM, (uintptr_t) recv_buffers[i], NX_ARRAY, sizeof(float));
+		for (int i = 0; i < NB_REQUESTS; i++)
+		{
+			send_buffers[i] = malloc(NX_ARRAY * sizeof(float));
+			memset(send_buffers[i], 0, NX_ARRAY * sizeof(float));
+			starpu_vector_data_register(&send_handles[i], STARPU_MAIN_RAM, (uintptr_t) send_buffers[i], NX_ARRAY, sizeof(float));
 
-		starpu_mpi_irecv(recv_handles[i], &recv_reqs[i], other_rank, i, MPI_COMM_WORLD);
+			recv_buffers[i] = malloc(NX_ARRAY * sizeof(float));
+			memset(recv_buffers[i], 0, NX_ARRAY * sizeof(float));
+			starpu_vector_data_register(&recv_handles[i], STARPU_MAIN_RAM, (uintptr_t) recv_buffers[i], NX_ARRAY, sizeof(float));
+		}
 	}
 
-	starpu_mpi_barrier(MPI_COMM_WORLD);
-
-	for (int i = 0; i < NB_REQUESTS; i++)
 	{
-		starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
-	}
+		/* Burst simultaneous from both nodes: 0 and 1 post all the recvs, synchronise, and then post all the sends */
+		FPRINTF(stderr, "Simultaneous....start (rank %d)\n", rank);
 
-	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+		if (rank == 0 || rank == 1)
+		{
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				recv_reqs[i] = NULL;
+				starpu_mpi_irecv(recv_handles[i], &recv_reqs[i], other_rank, i, MPI_COMM_WORLD);
+			}
+		}
 
+		starpu_mpi_barrier(MPI_COMM_WORLD);
 
-	/* Burst from 0 to 1 */
-	if (rank == 0)
-	{
-		printf("Done.\n");
-		printf("0 -> 1...\n");
-	}
-	else
-	{
-		for (int i = 0; i < NB_REQUESTS; i++)
+		if (rank == 0 || rank == 1)
 		{
-			starpu_mpi_irecv(recv_handles[i], &recv_reqs[i], other_rank, i, MPI_COMM_WORLD);
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				send_reqs[i] = NULL;
+				starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
+			}
 		}
-	}
 
-	starpu_mpi_barrier(MPI_COMM_WORLD);
-
-	if (rank == 0)
-	{
-		for (int i = 0; i < NB_REQUESTS; i++)
+		if (rank == 0 || rank == 1)
 		{
-			starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				if (recv_reqs[i]) starpu_mpi_wait(&recv_reqs[i], MPI_STATUS_IGNORE);
+				if (send_reqs[i]) starpu_mpi_wait(&send_reqs[i], MPI_STATUS_IGNORE);
+			}
 		}
+		starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+		FPRINTF(stderr, "Simultaneous....end (rank %d)\n", rank);
+		starpu_mpi_barrier(MPI_COMM_WORLD);
 	}
 
-	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
-
-
-	/* Burst from 1 to 0 */
-	if (rank == 0)
 	{
-		printf("Done.\n");
-		printf("1 -> 0...\n");
+		/* Burst from 0 to 1 : rank 1 posts all the recvs, barrier, then rank 0 posts all the sends */
+		FPRINTF(stderr, "0 -> 1...start (rank %d)\n", rank);
 
-		for (int i = 0; i < NB_REQUESTS; i++)
+		if (rank == 1)
 		{
-			starpu_mpi_irecv(recv_handles[i], &recv_reqs[i], other_rank, i, MPI_COMM_WORLD);
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				recv_reqs[i] = NULL;
+				starpu_mpi_irecv(recv_handles[i], &recv_reqs[i], other_rank, i, MPI_COMM_WORLD);
+			}
 		}
-	}
 
-	starpu_mpi_barrier(MPI_COMM_WORLD);
+		starpu_mpi_barrier(MPI_COMM_WORLD);
 
-	if (rank == 1)
-	{
-		for (int i = 0; i < NB_REQUESTS; i++)
+		if (rank == 0)
 		{
-			starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				send_reqs[i] = NULL;
+				starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
+			}
 		}
-	}
-
-	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
 
+		if (rank == 0 || rank == 1)
+		{
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				if (rank == 1 && recv_reqs[i]) starpu_mpi_wait(&recv_reqs[i], MPI_STATUS_IGNORE);
+				if (rank == 0 && send_reqs[i]) starpu_mpi_wait(&send_reqs[i], MPI_STATUS_IGNORE);
+			}
+		}
+		starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+		FPRINTF(stderr, "0 -> 1...done (rank %d)\n", rank);
+		starpu_mpi_barrier(MPI_COMM_WORLD);
+	}
 
-	/* Half burst from both nodes, second half burst is triggered after some requests finished. */
-	if (rank == 0)
 	{
-		printf("Done.\n");
-		printf("Half/half burst...\n");
-	}
+		FPRINTF(stderr, "1 -> 0...start (rank %d)\n", rank);
+		/* Burst from 1 to 0 */
+		if (rank == 0)
+		{
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				recv_reqs[i] = NULL;
+				starpu_mpi_irecv(recv_handles[i], &recv_reqs[i], other_rank, i, MPI_COMM_WORLD);
+			}
+		}
 
-	int received = 0;
+		starpu_mpi_barrier(MPI_COMM_WORLD);
 
-	for (int i = 0; i < NB_REQUESTS; i++)
-	{
-		if (i == (NB_REQUESTS / 4))
+		if (rank == 1)
 		{
-			starpu_mpi_irecv_detached(recv_handles[i], other_rank, i, MPI_COMM_WORLD, recv_callback, &received);
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				send_reqs[i] = NULL;
+				starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
+			}
 		}
-		else
+
+		if (rank == 0 || rank == 1)
 		{
-			starpu_mpi_irecv(recv_handles[i], &recv_reqs[i], other_rank, i, MPI_COMM_WORLD);
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				if (rank == 0 && recv_reqs[i]) starpu_mpi_wait(&recv_reqs[i], MPI_STATUS_IGNORE);
+				if (rank == 1 && send_reqs[i]) starpu_mpi_wait(&send_reqs[i], MPI_STATUS_IGNORE);
+			}
 		}
+		starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+		FPRINTF(stderr, "1 -> 0...done (rank %d)\n", rank);
+		starpu_mpi_barrier(MPI_COMM_WORLD);
 	}
 
-	starpu_mpi_barrier(MPI_COMM_WORLD);
-
-	for (int i = 0; i < (NB_REQUESTS / 2); i++)
 	{
-		starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
-	}
+		/* Half burst from both nodes, second half burst is triggered after some requests finished. */
+		FPRINTF(stderr, "Half/half burst...start (rank %d)\n", rank);
 
-	STARPU_PTHREAD_MUTEX_LOCK(&mutex);
-	while (!received)
-		STARPU_PTHREAD_COND_WAIT(&cond, &mutex);
-	STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
+		int received = 0;
 
-	for (int i = (NB_REQUESTS / 2); i < NB_REQUESTS; i++)
-	{
-		starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
-	}
+		if (rank == 0 || rank == 1)
+		{
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				recv_reqs[i] = NULL;
+				if (i % 2)
+				{
+					starpu_mpi_irecv_detached(recv_handles[i], other_rank, i, MPI_COMM_WORLD, recv_callback, &received);
+				}
+				else
+				{
+					starpu_mpi_irecv(recv_handles[i], &recv_reqs[i], other_rank, i, MPI_COMM_WORLD);
+				}
+			}
+		}
 
-	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+		starpu_mpi_barrier(MPI_COMM_WORLD);
 
-	if (rank == 0)
-	{
-		printf("Done.\n");
-	}
+		if (rank == 0 || rank == 1)
+		{
+			for (int i = 0; i < (NB_REQUESTS / 2); i++)
+			{
+				send_reqs[i] = NULL;
+				starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
+			}
+		}
+
+		if (rank == 0 || rank == 1)
+		{
+			STARPU_PTHREAD_MUTEX_LOCK(&mutex);
+			while (!received)
+				STARPU_PTHREAD_COND_WAIT(&cond, &mutex);
+			STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
+		}
+
+		if (rank == 0 || rank == 1)
+		{
+			for (int i = (NB_REQUESTS / 2); i < NB_REQUESTS; i++)
+			{
+				send_reqs[i] = NULL;
+				starpu_mpi_isend_prio(send_handles[i], &send_reqs[i], other_rank, i, i, MPI_COMM_WORLD);
+			}
+		}
 
+		if (rank == 0 || rank == 1)
+		{
+			for (int i = 0; i < NB_REQUESTS; i++)
+			{
+				if (recv_reqs[i]) starpu_mpi_wait(&recv_reqs[i], MPI_STATUS_IGNORE);
+				if (send_reqs[i]) starpu_mpi_wait(&send_reqs[i], MPI_STATUS_IGNORE);
+			}
+		}
 
-	for (int i = 0; i < NB_REQUESTS; i++)
+		starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+		FPRINTF(stderr, "Half/half burst...done (rank %d)\n", rank);
+		starpu_mpi_barrier(MPI_COMM_WORLD);
+	}
+
+	/* Clear up */
+	if (rank == 0 || rank == 1)
 	{
-		starpu_data_unregister(send_handles[i]);
-		free(send_buffers[i]);
+		for (int i = 0; i < NB_REQUESTS; i++)
+		{
+			starpu_data_unregister(send_handles[i]);
+			free(send_buffers[i]);
 
-		starpu_data_unregister(recv_handles[i]);
-		free(recv_buffers[i]);
+			starpu_data_unregister(recv_handles[i]);
+			free(recv_buffers[i]);
+		}
 	}
 
 	starpu_mpi_shutdown();