Procházet zdrojové kódy

Merge branch 'master' of git+ssh://scm.gforge.inria.fr/gitroot/starpu/starpu

Samuel Thibault před 5 roky
rodič
revize
22736383b3

+ 3 - 2
mpi/examples/matrix_decomposition/mpi_cholesky_codelets.c

@@ -115,6 +115,7 @@ void dw_cholesky(float ***matA, unsigned ld, int rank, int nodes, double *timing
 		}
 	}
 
+	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
 	starpu_mpi_barrier(MPI_COMM_WORLD);
 	start = starpu_timing_now();
 
@@ -159,9 +160,9 @@ void dw_cholesky(float ***matA, unsigned ld, int rank, int nodes, double *timing
 		starpu_iteration_pop();
 	}
 
-	starpu_task_wait_for_all();
-
+	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
 	starpu_mpi_barrier(MPI_COMM_WORLD);
+
 	end = starpu_timing_now();
 
 	for (m = 0; m < nblocks; m++)

+ 2 - 0
mpi/examples/mpi_lu/pxlu.c

@@ -899,6 +899,8 @@ double STARPU_PLU(plu_main)(unsigned _nblocks, int _rank, int _world_size, unsig
 		starpu_iteration_pop();
 	}
 
+	int wait_ret = starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+	STARPU_ASSERT(wait_ret == MPI_SUCCESS);
 	int barrier_ret = starpu_mpi_barrier(MPI_COMM_WORLD);
 	STARPU_ASSERT(barrier_ret == MPI_SUCCESS);
 

+ 14 - 9
mpi/examples/mpi_lu/pxlu_implicit.c

@@ -57,7 +57,7 @@ static void create_task_11(unsigned k)
 static void create_task_12(unsigned k, unsigned j)
 {
 #ifdef STARPU_DEVEL
-#warning temporary fix 
+#warning temporary fix
 #endif
 	starpu_mpi_task_insert(MPI_COMM_WORLD,
 			       //&STARPU_PLU(cl12),
@@ -79,7 +79,7 @@ static void create_task_12(unsigned k, unsigned j)
 static void create_task_21(unsigned k, unsigned i)
 {
 #ifdef STARPU_DEVEL
-#warning temporary fix 
+#warning temporary fix
 #endif
 	starpu_mpi_task_insert(MPI_COMM_WORLD,
 			       //&STARPU_PLU(cl21),
@@ -114,13 +114,14 @@ static void create_task_22(unsigned k, unsigned i, unsigned j)
 }
 
 /*
- *	code to bootstrap the factorization 
+ *	code to bootstrap the factorization
  */
 
 double STARPU_PLU(plu_main)(unsigned _nblocks, int _rank, int _world_size, unsigned _no_prio)
 {
 	double start;
 	double end;
+	int ret;
 
 	nblocks = _nblocks;
 	rank = _rank;
@@ -130,7 +131,10 @@ double STARPU_PLU(plu_main)(unsigned _nblocks, int _rank, int _world_size, unsig
 	/* create all the DAG nodes */
 	unsigned i,j,k;
 
-	starpu_mpi_barrier(MPI_COMM_WORLD);
+	ret = starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+	STARPU_ASSERT(ret == MPI_SUCCESS);
+	ret = starpu_mpi_barrier(MPI_COMM_WORLD);
+	STARPU_ASSERT(ret == MPI_SUCCESS);
 
 	start = starpu_timing_now();
 
@@ -170,15 +174,16 @@ double STARPU_PLU(plu_main)(unsigned _nblocks, int _rank, int _world_size, unsig
 		starpu_iteration_pop();
 	}
 
-	starpu_task_wait_for_all();
-
-	starpu_mpi_barrier(MPI_COMM_WORLD);
+	ret = starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+	STARPU_ASSERT(ret == MPI_SUCCESS);
+	ret = starpu_mpi_barrier(MPI_COMM_WORLD);
+	STARPU_ASSERT(ret == MPI_SUCCESS);
 
 	end = starpu_timing_now();
 
 	double timing = end - start;
-	
+
 //	fprintf(stderr, "RANK %d -> took %f ms\n", rank, timing/1000);
-	
+
 	return timing;
 }

+ 1 - 1
mpi/examples/user_datatype/user_datatype.c

@@ -120,8 +120,8 @@ int main(int argc, char **argv)
 		starpu_mpi_isend_detached(handle0, 0, 20, MPI_COMM_WORLD, NULL, NULL);
 	}
 
-	starpu_mpi_barrier(MPI_COMM_WORLD);
 	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+	starpu_mpi_barrier(MPI_COMM_WORLD);
 
 	starpu_mpi_datatype_unregister(handle0);
 	starpu_data_unregister(handle0);

+ 1 - 1
mpi/examples/user_datatype/user_datatype2.c

@@ -80,8 +80,8 @@ int main(int argc, char **argv)
 		starpu_mpi_isend_detached(handle0, 0, 20, MPI_COMM_WORLD, NULL, NULL);
 	}
 
-	starpu_mpi_barrier(MPI_COMM_WORLD);
 	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+	starpu_mpi_barrier(MPI_COMM_WORLD);
 
 	starpu_mpi_datatype_unregister(handle0);
 	starpu_data_unregister(handle0);

+ 31 - 27
mpi/src/mpi/starpu_mpi_mpi.c

@@ -96,7 +96,7 @@ starpu_pthread_queue_t _starpu_mpi_thread_dontsleep;
 /* Count requests posted by the application and not yet submitted to MPI */
 static starpu_pthread_mutex_t mutex_posted_requests;
 static starpu_pthread_mutex_t mutex_ready_requests;
-static int posted_requests = 0, ready_requests = 0, newer_requests, barrier_running = 0;
+static int posted_requests = 0, ready_requests = 0, newer_requests, mpi_wait_for_all_running = 0;
 
 #define _STARPU_MPI_INC_POSTED_REQUESTS(value) { STARPU_PTHREAD_MUTEX_LOCK(&mutex_posted_requests); posted_requests += value; STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_posted_requests); }
 #define _STARPU_MPI_INC_READY_REQUESTS(value) { STARPU_PTHREAD_MUTEX_LOCK(&mutex_ready_requests); ready_requests += value; STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_ready_requests); }
@@ -761,16 +761,40 @@ static void _starpu_mpi_barrier_func(struct _starpu_mpi_req *barrier_req)
 int _starpu_mpi_barrier(MPI_Comm comm)
 {
 	struct _starpu_mpi_req *barrier_req;
-	int ret = posted_requests+ready_requests;
 
+	/* Initialize the request structure */
+	_starpu_mpi_request_init(&barrier_req);
+	barrier_req->prio = INT_MAX;
+	barrier_req->func = _starpu_mpi_barrier_func;
+	barrier_req->request_type = BARRIER_REQ;
+	barrier_req->node_tag.node.comm = comm;
+
+	_STARPU_MPI_INC_POSTED_REQUESTS(1);
+	_starpu_mpi_submit_ready_request(barrier_req);
+
+	/* We wait for the MPI request to finish */
+	STARPU_PTHREAD_MUTEX_LOCK(&barrier_req->backend->req_mutex);
+	while (!barrier_req->completed)
+		STARPU_PTHREAD_COND_WAIT(&barrier_req->backend->req_cond, &barrier_req->backend->req_mutex);
+	STARPU_PTHREAD_MUTEX_UNLOCK(&barrier_req->backend->req_mutex);
+
+	_starpu_mpi_request_destroy(barrier_req);
+	_STARPU_MPI_LOG_OUT();
+
+	return 0;
+}
+
+int _starpu_mpi_wait_for_all(MPI_Comm comm)
+{
+	(void) comm;
 	_STARPU_MPI_LOG_IN();
 
 	/* Wait for *both* all tasks and all MPI requests to finish, in case
 	 * some tasks generate MPI requests, MPI requests generate tasks, etc.
 	 */
 	STARPU_PTHREAD_MUTEX_LOCK(&progress_mutex);
-	STARPU_MPI_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
-	barrier_running = 1;
+	STARPU_MPI_ASSERT_MSG(!mpi_wait_for_all_running, "Concurrent starpu_mpi_wait_for_all is not implemented, even on different communicators");
+	mpi_wait_for_all_running = 1;
 	do
 	{
 		while (posted_requests || ready_requests)
@@ -786,29 +810,9 @@ int _starpu_mpi_barrier(MPI_Comm comm)
 		 * triggered by tasks completed and triggered tasks between
 		 * wait_for_all finished and we take the lock */
 	} while (posted_requests || ready_requests || newer_requests);
-	barrier_running = 0;
+	mpi_wait_for_all_running = 0;
 	STARPU_PTHREAD_MUTEX_UNLOCK(&progress_mutex);
-
-	/* Initialize the request structure */
-	_starpu_mpi_request_init(&barrier_req);
-	barrier_req->prio = INT_MAX;
-	barrier_req->func = _starpu_mpi_barrier_func;
-	barrier_req->request_type = BARRIER_REQ;
-	barrier_req->node_tag.node.comm = comm;
-
-	_STARPU_MPI_INC_POSTED_REQUESTS(1);
-	_starpu_mpi_submit_ready_request(barrier_req);
-
-	/* We wait for the MPI request to finish */
-	STARPU_PTHREAD_MUTEX_LOCK(&barrier_req->backend->req_mutex);
-	while (!barrier_req->completed)
-		STARPU_PTHREAD_COND_WAIT(&barrier_req->backend->req_cond, &barrier_req->backend->req_mutex);
-	STARPU_PTHREAD_MUTEX_UNLOCK(&barrier_req->backend->req_mutex);
-
-	_starpu_mpi_request_destroy(barrier_req);
-	_STARPU_MPI_LOG_OUT();
-
-	return ret;
+	return 0;
 }
 
 /********************************************************/
@@ -1269,7 +1273,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 			_STARPU_MPI_DEBUG(3, "NO MORE REQUESTS TO HANDLE\n");
 			_STARPU_MPI_TRACE_SLEEP_BEGIN();
 
-			if (barrier_running)
+			if (mpi_wait_for_all_running)
 				/* Tell _starpu_mpi_wait_for_all */
 				STARPU_PTHREAD_COND_SIGNAL(&barrier_cond);
 			STARPU_PTHREAD_COND_WAIT(&progress_cond, &progress_mutex);

+ 53 - 16
mpi/src/nmad/starpu_mpi_nmad.c

@@ -59,11 +59,15 @@ static starpu_pthread_cond_t progress_cond;
 static starpu_pthread_mutex_t progress_mutex;
 static volatile int running = 0;
 
-extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
+static starpu_pthread_cond_t mpi_wait_for_all_running_cond;
+static int mpi_wait_for_all_running = 0;
+static starpu_pthread_mutex_t mpi_wait_for_all_running_mutex;
 
-/* Count requests posted by the application and not yet submitted to MPI, i.e pushed into the new_requests list */
+extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
 
-static volatile int pending_request = 0;
+/* Count running requests: this counter is incremented just before StarPU
+ * submits an MPI request, and decremented when an MPI request finishes. */
+static volatile int nb_pending_requests = 0;
 
 #define REQ_FINALIZED 0x1
 
@@ -80,7 +84,7 @@ static starpu_sem_t callback_sem;
 
 void _starpu_mpi_req_willpost(struct _starpu_mpi_req *req STARPU_ATTRIBUTE_UNUSED)
 {
-	STARPU_ATOMIC_ADD( &pending_request, 1);
+	STARPU_ATOMIC_ADD( &nb_pending_requests, 1);
 }
 
 /********************************************************/
@@ -269,16 +273,39 @@ int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
 int _starpu_mpi_barrier(MPI_Comm comm)
 {
 	_STARPU_MPI_LOG_IN();
-	int ret;
-	//	STARPU_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
-	ret = MPI_Barrier(comm);
 
+	int ret = MPI_Barrier(comm);
 	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Barrier returning %d", ret);
 
 	_STARPU_MPI_LOG_OUT();
 	return ret;
 }
 
+int _starpu_mpi_wait_for_all(MPI_Comm comm)
+{
+	(void) comm;
+	_STARPU_MPI_LOG_IN();
+
+	STARPU_PTHREAD_MUTEX_LOCK(&mpi_wait_for_all_running_mutex);
+	STARPU_MPI_ASSERT_MSG(!mpi_wait_for_all_running, "Concurrent starpu_mpi_wait_for_all is not implemented, even on different communicators");
+	mpi_wait_for_all_running = 1;
+	do
+	{
+		while (nb_pending_requests)
+			STARPU_PTHREAD_COND_WAIT(&mpi_wait_for_all_running_cond, &mpi_wait_for_all_running_mutex);
+		STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_wait_for_all_running_mutex);
+
+		starpu_task_wait_for_all();
+
+		STARPU_PTHREAD_MUTEX_LOCK(&mpi_wait_for_all_running_mutex);
+	} while (nb_pending_requests);
+	mpi_wait_for_all_running = 0;
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_wait_for_all_running_mutex);
+
+	_STARPU_MPI_LOG_OUT();
+	return 0;
+}
+
 /********************************************************/
 /*                                                      */
 /*  Progression                                         */
@@ -353,9 +380,13 @@ void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req,nm_sr_ev
 			req->completed = 1;
 			piom_cond_signal(&req->backend->req_cond, REQ_FINALIZED);
 		}
-		int pending_remaining = STARPU_ATOMIC_ADD(&pending_request, -1);
-		if (!running && !pending_remaining)
-			starpu_sem_post(&callback_sem);
+		int pending_remaining = STARPU_ATOMIC_ADD(&nb_pending_requests, -1);
+		if (!pending_remaining)
+		{
+			STARPU_PTHREAD_COND_BROADCAST(&mpi_wait_for_all_running_cond);
+			if (!running)
+				starpu_sem_post(&callback_sem);
+		}
 	}
 	_STARPU_MPI_LOG_OUT();
 }
@@ -476,24 +507,24 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 		struct callback_lfstack_cell_s* c = callback_lfstack_pop(&callback_stack);
 		int err=0;
 
-		if(running || pending_request>0)
+		if(running || nb_pending_requests>0)
 		{
 			/* shall we block ? */
 			err = starpu_sem_wait(&callback_sem);
-			//running pending_request can change while waiting
+			// running and nb_pending_requests can change while waiting
 		}
 		if(c==NULL)
 		{
 			c = callback_lfstack_pop(&callback_stack);
 			if (c == NULL)
 			{
-				if(running && pending_request>0)
+				if(running && nb_pending_requests>0)
 				{
 					STARPU_ASSERT_MSG(c!=NULL, "Callback thread awakened without callback ready with error %d.",err);
 				}
 				else
 				{
-					if (pending_request==0)
+					if (nb_pending_requests==0)
 						break;
 				}
 				continue;
@@ -511,14 +542,14 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 			c->req->completed=1;
 			piom_cond_signal(&(c->req->backend->req_cond), REQ_FINALIZED);
 		}
-		STARPU_ATOMIC_ADD( &pending_request, -1);
+		STARPU_ATOMIC_ADD( &nb_pending_requests, -1);
 		/* we signal that the request is completed.*/
 
 		free(c);
 
 	}
 	STARPU_ASSERT_MSG(callback_lfstack_pop(&callback_stack)==NULL, "List of callback not empty.");
-	STARPU_ASSERT_MSG(pending_request==0, "Request still pending.");
+	STARPU_ASSERT_MSG(nb_pending_requests==0, "Request still pending.");
 
 	if (argc_argv->initialize_mpi)
 	{
@@ -580,6 +611,9 @@ int _starpu_mpi_progress_init(struct _starpu_mpi_argc_argv *argc_argv)
         STARPU_PTHREAD_MUTEX_INIT(&progress_mutex, NULL);
         STARPU_PTHREAD_COND_INIT(&progress_cond, NULL);
 
+        STARPU_PTHREAD_MUTEX_INIT(&mpi_wait_for_all_running_mutex, NULL);
+        STARPU_PTHREAD_COND_INIT(&mpi_wait_for_all_running_cond, NULL);
+
 	starpu_sem_init(&callback_sem, 0, 0);
 	running = 0;
 
@@ -669,6 +703,9 @@ void _starpu_mpi_progress_shutdown(void **value)
 
         STARPU_PTHREAD_MUTEX_DESTROY(&progress_mutex);
         STARPU_PTHREAD_COND_DESTROY(&progress_cond);
+
+        STARPU_PTHREAD_MUTEX_DESTROY(&mpi_wait_for_all_running_mutex);
+        STARPU_PTHREAD_COND_DESTROY(&mpi_wait_for_all_running_cond);
 }
 
 static int64_t _starpu_mpi_tag_max = INT64_MAX;

+ 1 - 3
mpi/src/starpu_mpi.c

@@ -437,7 +437,5 @@ void starpu_mpi_data_migrate(MPI_Comm comm, starpu_data_handle_t data, int new_r
 
 int starpu_mpi_wait_for_all(MPI_Comm comm)
 {
-	starpu_task_wait_for_all();
-	starpu_mpi_barrier(comm);
-	return 0;
+	return _starpu_mpi_wait_for_all(comm);
 }

+ 3 - 1
mpi/src/starpu_mpi_private.h

@@ -310,7 +310,6 @@ void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req);
 void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req);
 int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status);
 int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status);
-int _starpu_mpi_barrier(MPI_Comm comm);
 
 struct _starpu_mpi_argc_argv
 {
@@ -331,6 +330,9 @@ void _starpu_mpi_wait_for_initialization();
 #endif
 void _starpu_mpi_data_flush(starpu_data_handle_t data_handle);
 
+int _starpu_mpi_barrier(MPI_Comm comm);
+int _starpu_mpi_wait_for_all(MPI_Comm comm);
+
 /*
  * Specific functions to backend implementation
  */

+ 0 - 8
mpi/tests/sendrecv_gemm_bench.c

@@ -320,13 +320,6 @@ static void* comm_thread_func(void* arg)
 	return NULL;
 }
 
-#ifdef STARPU_USE_MPI_MPI
-int main(int argc, char **argv)
-{
-	FPRINTF(stderr, "This test does not work with the MPI backend.\n");
-	return STARPU_TEST_SKIPPED;
-}
-#else
 int main(int argc, char **argv)
 {
 	double start, end;
@@ -467,4 +460,3 @@ enodev:
 
 	return ret;
 }
-#endif