Browse Source

mpi: fix semantics for functions starpu_mpi_wait_for_all() and starpu_mpi_barrier()
- starpu_mpi_barrier() implements a MPI barrier
- starpu_mpi_wait_for_all() waits for tasks and local communications

Nathalie Furmento 5 years ago
parent
commit
021cbab561
4 changed files with 40 additions and 32 deletions
  1. 30 25
      mpi/src/mpi/starpu_mpi_mpi.c
  2. 6 3
      mpi/src/nmad/starpu_mpi_nmad.c
  3. 1 3
      mpi/src/starpu_mpi.c
  4. 3 1
      mpi/src/starpu_mpi_private.h

+ 30 - 25
mpi/src/mpi/starpu_mpi_mpi.c

@@ -96,7 +96,7 @@ starpu_pthread_queue_t _starpu_mpi_thread_dontsleep;
 /* Count requests posted by the application and not yet submitted to MPI */
 static starpu_pthread_mutex_t mutex_posted_requests;
 static starpu_pthread_mutex_t mutex_ready_requests;
-static int posted_requests = 0, ready_requests = 0, newer_requests, barrier_running = 0;
+static int posted_requests = 0, ready_requests = 0, newer_requests, mpi_wait_for_all_running = 0;
 
 #define _STARPU_MPI_INC_POSTED_REQUESTS(value) { STARPU_PTHREAD_MUTEX_LOCK(&mutex_posted_requests); posted_requests += value; STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_posted_requests); }
 #define _STARPU_MPI_INC_READY_REQUESTS(value) { STARPU_PTHREAD_MUTEX_LOCK(&mutex_ready_requests); ready_requests += value; STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_ready_requests); }
@@ -761,6 +761,31 @@ static void _starpu_mpi_barrier_func(struct _starpu_mpi_req *barrier_req)
 int _starpu_mpi_barrier(MPI_Comm comm)
 {
 	struct _starpu_mpi_req *barrier_req;
+
+	/* Initialize the request structure */
+	_starpu_mpi_request_init(&barrier_req);
+	barrier_req->prio = INT_MAX;
+	barrier_req->func = _starpu_mpi_barrier_func;
+	barrier_req->request_type = BARRIER_REQ;
+	barrier_req->node_tag.node.comm = comm;
+
+	_STARPU_MPI_INC_POSTED_REQUESTS(1);
+	_starpu_mpi_submit_ready_request(barrier_req);
+
+	/* We wait for the MPI request to finish */
+	STARPU_PTHREAD_MUTEX_LOCK(&barrier_req->backend->req_mutex);
+	while (!barrier_req->completed)
+		STARPU_PTHREAD_COND_WAIT(&barrier_req->backend->req_cond, &barrier_req->backend->req_mutex);
+	STARPU_PTHREAD_MUTEX_UNLOCK(&barrier_req->backend->req_mutex);
+
+	_starpu_mpi_request_destroy(barrier_req);
+	_STARPU_MPI_LOG_OUT();
+
+	return 0;
+}
+
+int _starpu_mpi_wait_for_all(MPI_Comm comm)
+{
 	int ret = posted_requests+ready_requests;
 
 	_STARPU_MPI_LOG_IN();
@@ -769,8 +794,8 @@ int _starpu_mpi_barrier(MPI_Comm comm)
 	 * some tasks generate MPI requests, MPI requests generate tasks, etc.
 	 */
 	STARPU_PTHREAD_MUTEX_LOCK(&progress_mutex);
-	STARPU_MPI_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
-	barrier_running = 1;
+	STARPU_MPI_ASSERT_MSG(!mpi_wait_for_all_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
+	mpi_wait_for_all_running = 1;
 	do
 	{
 		while (posted_requests || ready_requests)
@@ -786,28 +811,8 @@ int _starpu_mpi_barrier(MPI_Comm comm)
 		 * triggered by tasks completed and triggered tasks between
 		 * wait_for_all finished and we take the lock */
 	} while (posted_requests || ready_requests || newer_requests);
-	barrier_running = 0;
+	mpi_wait_for_all_running = 0;
 	STARPU_PTHREAD_MUTEX_UNLOCK(&progress_mutex);
-
-	/* Initialize the request structure */
-	_starpu_mpi_request_init(&barrier_req);
-	barrier_req->prio = INT_MAX;
-	barrier_req->func = _starpu_mpi_barrier_func;
-	barrier_req->request_type = BARRIER_REQ;
-	barrier_req->node_tag.node.comm = comm;
-
-	_STARPU_MPI_INC_POSTED_REQUESTS(1);
-	_starpu_mpi_submit_ready_request(barrier_req);
-
-	/* We wait for the MPI request to finish */
-	STARPU_PTHREAD_MUTEX_LOCK(&barrier_req->backend->req_mutex);
-	while (!barrier_req->completed)
-		STARPU_PTHREAD_COND_WAIT(&barrier_req->backend->req_cond, &barrier_req->backend->req_mutex);
-	STARPU_PTHREAD_MUTEX_UNLOCK(&barrier_req->backend->req_mutex);
-
-	_starpu_mpi_request_destroy(barrier_req);
-	_STARPU_MPI_LOG_OUT();
-
 	return ret;
 }
 
@@ -1269,7 +1274,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 			_STARPU_MPI_DEBUG(3, "NO MORE REQUESTS TO HANDLE\n");
 			_STARPU_MPI_TRACE_SLEEP_BEGIN();
 
-			if (barrier_running)
+			if (mpi_wait_for_all_running)
 				/* Tell mpi_barrier */
 				STARPU_PTHREAD_COND_SIGNAL(&barrier_cond);
 			STARPU_PTHREAD_COND_WAIT(&progress_cond, &progress_mutex);

+ 6 - 3
mpi/src/nmad/starpu_mpi_nmad.c

@@ -269,16 +269,19 @@ int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
 int _starpu_mpi_barrier(MPI_Comm comm)
 {
 	_STARPU_MPI_LOG_IN();
-	int ret;
-	//	STARPU_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
-	ret = MPI_Barrier(comm);
 
+	int ret = MPI_Barrier(comm);
 	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Barrier returning %d", ret);
 
 	_STARPU_MPI_LOG_OUT();
 	return ret;
 }
 
+int _starpu_mpi_wait_for_all(MPI_Comm comm)
+{
+	assert(0);
+}
+
 /********************************************************/
 /*                                                      */
 /*  Progression                                         */

+ 1 - 3
mpi/src/starpu_mpi.c

@@ -437,7 +437,5 @@ void starpu_mpi_data_migrate(MPI_Comm comm, starpu_data_handle_t data, int new_r
 
 int starpu_mpi_wait_for_all(MPI_Comm comm)
 {
-	starpu_task_wait_for_all();
-	starpu_mpi_barrier(comm);
-	return 0;
+	return _starpu_mpi_wait_for_all(comm);
 }

+ 3 - 1
mpi/src/starpu_mpi_private.h

@@ -310,7 +310,6 @@ void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req);
 void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req);
 int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status);
 int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status);
-int _starpu_mpi_barrier(MPI_Comm comm);
 
 struct _starpu_mpi_argc_argv
 {
@@ -331,6 +330,9 @@ void _starpu_mpi_wait_for_initialization();
 #endif
 void _starpu_mpi_data_flush(starpu_data_handle_t data_handle);
 
+int _starpu_mpi_barrier(MPI_Comm comm);
+int _starpu_mpi_wait_for_all(MPI_Comm comm);
+
 /*
  * Specific functions to backend implementation
  */