Browse Source

StarPU MPI:
- When StarPU calls MPI_Init, it should also call MPI_Finalize
- Count the number of request posted by the application so as to not terminate the progression thread until all these requests have been processed, i.e the MPI request has been posted

Nathalie Furmento 14 years ago
parent
commit
09eac15629
1 changed files with 36 additions and 5 deletions
  1. 36 5
      mpi/starpu_mpi.c

+ 36 - 5
mpi/starpu_mpi.c

@@ -60,6 +60,12 @@ static pthread_mutex_t mutex;
 static pthread_t progress_thread;
 static int running = 0;
 
+/* Count requests posted by the application and not yet submitted to MPI, i.e pushed into the new_requests list */
+static pthread_mutex_t mutex_posted_requests;
+static int posted_requests = 0;
+
+#define INC_POSTED_REQUESTS(value) { PTHREAD_MUTEX_LOCK(&mutex_posted_requests); posted_requests += value; PTHREAD_MUTEX_UNLOCK(&mutex_posted_requests); }
+
 #if 0
 void starpu_mpi_debug(FILE *stream, const char *format, ...) {
         int rank;
@@ -108,6 +114,9 @@ static struct starpu_mpi_req_s *_starpu_mpi_isend_common(starpu_data_handle data
 	STARPU_ASSERT(req);
 
         _STARPU_MPI_LOG_IN();
+
+        INC_POSTED_REQUESTS(1);
+
 	/* Initialize the request structure */
 	req->submitted = 0;
 	req->completed = 0;
@@ -195,6 +204,8 @@ static struct starpu_mpi_req_s *_starpu_mpi_irecv_common(starpu_data_handle data
 	struct starpu_mpi_req_s *req = calloc(1, sizeof(struct starpu_mpi_req_s));
 	STARPU_ASSERT(req);
 
+        INC_POSTED_REQUESTS(1);
+
 	/* Initialize the request structure */
 	req->submitted = 0;
 	PTHREAD_MUTEX_INIT(&req->req_mutex, NULL);
@@ -311,6 +322,8 @@ int starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status)
 	STARPU_ASSERT(waiting_req);
 	struct starpu_mpi_req_s *req = *public_req;
 
+        INC_POSTED_REQUESTS(1);
+
 	/* We cannot try to complete a MPI request that was not actually posted
 	 * to MPI yet. */
 	PTHREAD_MUTEX_LOCK(&(req->req_mutex));
@@ -403,8 +416,9 @@ int starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
 		testing_req->completed = 0;
                 testing_req->request_type = TEST_REQ;
 
-		submit_mpi_req(testing_req);
-	
+                INC_POSTED_REQUESTS(1);
+                submit_mpi_req(testing_req);
+
 		/* We wait for the test request to finish */
 		PTHREAD_MUTEX_LOCK(&(testing_req->req_mutex));
 		while (!(testing_req->completed))
@@ -459,6 +473,7 @@ int starpu_mpi_barrier(MPI_Comm comm)
 	barrier_req->request_type = BARRIER_REQ;
 	barrier_req->comm = comm;
 
+        INC_POSTED_REQUESTS(1);
 	submit_mpi_req(barrier_req);
 
 	/* We wait for the MPI request to finish */
@@ -526,8 +541,11 @@ static void submit_mpi_req(void *arg)
         _STARPU_MPI_LOG_IN();
 	struct starpu_mpi_req_s *req = arg;
 
+        INC_POSTED_REQUESTS(-1);
+
 	PTHREAD_MUTEX_LOCK(&mutex);
 	starpu_mpi_req_list_push_front(new_requests, req);
+        _STARPU_MPI_DEBUG("Pushing new request type %d\n", req->request_type);
 	PTHREAD_COND_BROADCAST(&cond);
 	PTHREAD_MUTEX_UNLOCK(&mutex);
         _STARPU_MPI_LOG_OUT();
@@ -605,6 +623,7 @@ static void handle_new_request(struct starpu_mpi_req_s *req)
 	STARPU_ASSERT(req);
 
 	/* submit the request to MPI */
+        _STARPU_MPI_DEBUG("Handling new request type %d\n", req->request_type);
 	req->func(req);
 
 	if (req->detached)
@@ -626,13 +645,16 @@ static void handle_new_request(struct starpu_mpi_req_s *req)
 
 static void *progress_thread_func(void *arg)
 {
-        int *initialize_mpi = (int *) arg;
+        int initialize_mpi = *((int *) arg);
+
+        _STARPU_MPI_DEBUG("Initialize mpi: %d\n", initialize_mpi);
 
-        if (*initialize_mpi) {
+        if (initialize_mpi) {
 #warning get real argc and argv from the application
                 int argc = 0;
                 char **argv = NULL;
                 int thread_support;
+                _STARPU_MPI_DEBUG("Calling MPI_Init_thread\n");
                 if (MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &thread_support) != MPI_SUCCESS) {
                         fprintf(stderr,"MPI_Init_thread failed\n");
                         exit(1);
@@ -650,7 +672,7 @@ static void *progress_thread_func(void *arg)
 	PTHREAD_MUTEX_UNLOCK(&mutex);
 
 	PTHREAD_MUTEX_LOCK(&mutex);
-	while (running || !(starpu_mpi_req_list_empty(new_requests)) || !(starpu_mpi_req_list_empty(detached_requests))) {
+	while (running || posted_requests || !(starpu_mpi_req_list_empty(new_requests)) || !(starpu_mpi_req_list_empty(detached_requests))) {
 		/* shall we block ? */
 		unsigned block = starpu_mpi_req_list_empty(new_requests);
 
@@ -687,6 +709,12 @@ static void *progress_thread_func(void *arg)
 
 	STARPU_ASSERT(starpu_mpi_req_list_empty(detached_requests));
 	STARPU_ASSERT(starpu_mpi_req_list_empty(new_requests));
+        STARPU_ASSERT(posted_requests == 0);
+
+        if (initialize_mpi) {
+                _STARPU_MPI_DEBUG("Calling MPI_Finalize()\n");
+                MPI_Finalize();
+        }
 
 	PTHREAD_MUTEX_UNLOCK(&mutex);
 
@@ -747,6 +775,8 @@ int starpu_mpi_initialize_extended(int initialize_mpi, int *rank, int *world_siz
 	PTHREAD_MUTEX_INIT(&detached_requests_mutex, NULL);
 	detached_requests = starpu_mpi_req_list_new();
 
+        PTHREAD_MUTEX_INIT(&mutex_posted_requests, NULL);
+
 	int ret = pthread_create(&progress_thread, NULL, progress_thread_func, (void *)&initialize_mpi);
 
 	PTHREAD_MUTEX_LOCK(&mutex);
@@ -791,3 +821,4 @@ int starpu_mpi_shutdown(void)
 
 	return 0;
 }
+