Переглянути джерело

nmad: implement starpu_mpi_wait_for_all()

Philippe SWARTVAGHER 5 роки тому
батько
коміт
ae30f44077

+ 1 - 1
mpi/src/mpi/starpu_mpi_mpi.c

@@ -784,7 +784,7 @@ int _starpu_mpi_barrier(MPI_Comm comm)
 	return 0;
 }
 
-int _starpu_mpi_wait_for_all(MPI_Comm comm)
+int _starpu_mpi_wait_for_all()
 {
 	_STARPU_MPI_LOG_IN();
 

+ 48 - 15
mpi/src/nmad/starpu_mpi_nmad.c

@@ -59,11 +59,15 @@ static starpu_pthread_cond_t progress_cond;
 static starpu_pthread_mutex_t progress_mutex;
 static volatile int running = 0;
 
-extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
+static starpu_pthread_cond_t mpi_wait_for_all_running_cond;
+static int mpi_wait_for_all_running = 0;
+static starpu_pthread_mutex_t mpi_wait_for_all_running_mutex;
 
-/* Count requests posted by the application and not yet submitted to MPI, i.e pushed into the new_requests list */
+extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
 
-static volatile int pending_request = 0;
+/* Count running requests: this counter is incremented just before StarPU
+ * submits a MPI request, and decremented when a MPI request finishes. */
+static volatile int nb_pending_requests = 0;
 
 #define REQ_FINALIZED 0x1
 
@@ -80,7 +84,7 @@ static starpu_sem_t callback_sem;
 
 void _starpu_mpi_req_willpost(struct _starpu_mpi_req *req STARPU_ATTRIBUTE_UNUSED)
 {
-	STARPU_ATOMIC_ADD( &pending_request, 1);
+	STARPU_ATOMIC_ADD( &nb_pending_requests, 1);
 }
 
 /********************************************************/
@@ -277,9 +281,28 @@ int _starpu_mpi_barrier(MPI_Comm comm)
 	return ret;
 }
 
-int _starpu_mpi_wait_for_all(MPI_Comm comm)
+int _starpu_mpi_wait_for_all()
 {
-	assert(0);
+	_STARPU_MPI_LOG_IN();
+
+	STARPU_PTHREAD_MUTEX_LOCK(&mpi_wait_for_all_running_mutex);
+	STARPU_MPI_ASSERT_MSG(!mpi_wait_for_all_running, "Concurrent starpu_mpi_wait_for_all is not implemented, even on different communicators");
+	mpi_wait_for_all_running = 1;
+	do
+	{
+		while (nb_pending_requests)
+			STARPU_PTHREAD_COND_WAIT(&mpi_wait_for_all_running_cond, &mpi_wait_for_all_running_mutex);
+		STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_wait_for_all_running_mutex);
+
+		starpu_task_wait_for_all();
+
+		STARPU_PTHREAD_MUTEX_LOCK(&mpi_wait_for_all_running_mutex);
+	} while (nb_pending_requests);
+	mpi_wait_for_all_running = 0;
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_wait_for_all_running_mutex);
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
 }
 
 /********************************************************/
@@ -356,9 +379,13 @@ void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req,nm_sr_ev
 			req->completed = 1;
 			piom_cond_signal(&req->backend->req_cond, REQ_FINALIZED);
 		}
-		int pending_remaining = STARPU_ATOMIC_ADD(&pending_request, -1);
-		if (!running && !pending_remaining)
-			starpu_sem_post(&callback_sem);
+		int pending_remaining = STARPU_ATOMIC_ADD(&nb_pending_requests, -1);
+		if (!pending_remaining)
+		{
+			STARPU_PTHREAD_COND_BROADCAST(&mpi_wait_for_all_running_cond);
+			if (!running)
+				starpu_sem_post(&callback_sem);
+		}
 	}
 	_STARPU_MPI_LOG_OUT();
 }
@@ -479,24 +506,24 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 		struct callback_lfstack_cell_s* c = callback_lfstack_pop(&callback_stack);
 		int err=0;
 
-		if(running || pending_request>0)
+		if(running || nb_pending_requests>0)
 		{
 			/* shall we block ? */
 			err = starpu_sem_wait(&callback_sem);
-			//running pending_request can change while waiting
+			//running nb_pending_requests can change while waiting
 		}
 		if(c==NULL)
 		{
 			c = callback_lfstack_pop(&callback_stack);
 			if (c == NULL)
 			{
-				if(running && pending_request>0)
+				if(running && nb_pending_requests>0)
 				{
 					STARPU_ASSERT_MSG(c!=NULL, "Callback thread awakened without callback ready with error %d.",err);
 				}
 				else
 				{
-					if (pending_request==0)
+					if (nb_pending_requests==0)
 						break;
 				}
 				continue;
@@ -514,14 +541,14 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 			c->req->completed=1;
 			piom_cond_signal(&(c->req->backend->req_cond), REQ_FINALIZED);
 		}
-		STARPU_ATOMIC_ADD( &pending_request, -1);
+		STARPU_ATOMIC_ADD( &nb_pending_requests, -1);
 		/* we signal that the request is completed.*/
 
 		free(c);
 
 	}
 	STARPU_ASSERT_MSG(callback_lfstack_pop(&callback_stack)==NULL, "List of callback not empty.");
-	STARPU_ASSERT_MSG(pending_request==0, "Request still pending.");
+	STARPU_ASSERT_MSG(nb_pending_requests==0, "Request still pending.");
 
 	if (argc_argv->initialize_mpi)
 	{
@@ -583,6 +610,9 @@ int _starpu_mpi_progress_init(struct _starpu_mpi_argc_argv *argc_argv)
         STARPU_PTHREAD_MUTEX_INIT(&progress_mutex, NULL);
         STARPU_PTHREAD_COND_INIT(&progress_cond, NULL);
 
+        STARPU_PTHREAD_MUTEX_INIT(&mpi_wait_for_all_running_mutex, NULL);
+        STARPU_PTHREAD_COND_INIT(&mpi_wait_for_all_running_cond, NULL);
+
 	starpu_sem_init(&callback_sem, 0, 0);
 	running = 0;
 
@@ -672,6 +702,9 @@ void _starpu_mpi_progress_shutdown(void **value)
 
         STARPU_PTHREAD_MUTEX_DESTROY(&progress_mutex);
         STARPU_PTHREAD_COND_DESTROY(&progress_cond);
+
+        STARPU_PTHREAD_MUTEX_DESTROY(&mpi_wait_for_all_running_mutex);
+        STARPU_PTHREAD_COND_DESTROY(&mpi_wait_for_all_running_cond);
 }
 
 static int64_t _starpu_mpi_tag_max = INT64_MAX;

+ 2 - 2
mpi/src/starpu_mpi.c

@@ -435,7 +435,7 @@ void starpu_mpi_data_migrate(MPI_Comm comm, starpu_data_handle_t data, int new_r
 	return;
 }
 
-int starpu_mpi_wait_for_all(MPI_Comm comm)
+int starpu_mpi_wait_for_all(MPI_Comm comm STARPU_ATTRIBUTE_UNUSED)
 {
-	return _starpu_mpi_wait_for_all(comm);
+	return _starpu_mpi_wait_for_all();
 }

+ 1 - 1
mpi/src/starpu_mpi_private.h

@@ -331,7 +331,7 @@ void _starpu_mpi_wait_for_initialization();
 void _starpu_mpi_data_flush(starpu_data_handle_t data_handle);
 
 int _starpu_mpi_barrier(MPI_Comm comm);
-int _starpu_mpi_wait_for_all(MPI_Comm comm);
+int _starpu_mpi_wait_for_all();
 
 /*
  * Specific functions to backend implementation