Преглед на файлове

make starpu_mpi_barrier wait for tasks and requests first, to avoid not sending some important data that other ends might be waiting for

Samuel Thibault преди 13 години
родител
ревизия
5d08f1e686
променени са 1 файла, в които са добавени 42 реда и са изтрити 11 реда
  1. 42 11
      mpi/starpu_mpi.c

+ 42 - 11
mpi/starpu_mpi.c

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2009, 2010-2011  Université de Bordeaux 1
+ * Copyright (C) 2009, 2010-2012  Université de Bordeaux 1
  * Copyright (C) 2010, 2011  Centre National de la Recherche Scientifique
  *
  * StarPU is free software; you can redistribute it and/or modify
@@ -36,14 +36,17 @@ static struct _starpu_mpi_req_list *new_requests;
 static struct _starpu_mpi_req_list *detached_requests;
 static pthread_mutex_t detached_requests_mutex;
 
-static pthread_cond_t cond;
+/* Condition to wake up progression thread */
+static pthread_cond_t cond_progression;
+/* Condition to wake up waiting for all current MPI requests to finish */
+static pthread_cond_t cond_finished;
 static pthread_mutex_t mutex;
 static pthread_t progress_thread;
 static int running = 0;
 
 /* Count requests posted by the application and not yet submitted to MPI, i.e pushed into the new_requests list */
 static pthread_mutex_t mutex_posted_requests;
-static int posted_requests = 0;
+static int posted_requests = 0, newer_requests, barrier_running = 0;
 
 #define INC_POSTED_REQUESTS(value) { _STARPU_PTHREAD_MUTEX_LOCK(&mutex_posted_requests); posted_requests += value; _STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_posted_requests); }
 
@@ -433,6 +436,29 @@ int starpu_mpi_barrier(MPI_Comm comm)
 	struct _starpu_mpi_req *barrier_req = calloc(1, sizeof(struct _starpu_mpi_req));
 	STARPU_ASSERT(barrier_req);
 
+	/* First wait for *both* all tasks and MPI requests to finish, in case
+	 * some tasks generate MPI requests, MPI requests generate tasks, etc.
+	 */
+	_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
+	STARPU_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
+	barrier_running = 1;
+	do {
+		while (posted_requests)
+			/* Wait for all current MPI requests to finish */
+			_STARPU_PTHREAD_COND_WAIT(&cond_finished, &mutex);
+		/* No current request, clear flag */
+		newer_requests = 0;
+		_STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
+		/* Now wait for all tasks */
+		starpu_task_wait_for_all();
+		_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
+		/* Check newer_requests again, in case some MPI requests
+		 * triggered by tasks completed and triggered tasks between
+		 * wait_for_all finished and we take the lock */
+	} while (posted_requests || newer_requests);
+	barrier_running = 0;
+	_STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
+
 	/* Initialize the request structure */
 	_STARPU_PTHREAD_MUTEX_INIT(&(barrier_req->req_mutex), NULL);
 	_STARPU_PTHREAD_COND_INIT(&(barrier_req->req_cond), NULL);
@@ -513,8 +539,9 @@ static void submit_mpi_req(void *arg)
 
 	_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
 	_starpu_mpi_req_list_push_front(new_requests, req);
+	newer_requests = 1;
         _STARPU_MPI_DEBUG("Pushing new request type %d\n", req->request_type);
-	_STARPU_PTHREAD_COND_BROADCAST(&cond);
+	_STARPU_PTHREAD_COND_BROADCAST(&cond_progression);
 	_STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
         _STARPU_MPI_LOG_OUT();
 }
@@ -531,7 +558,7 @@ static unsigned progression_hook_func(void *arg __attribute__((unused)))
 	_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
 	if (!_starpu_mpi_req_list_empty(detached_requests))
 	{
-		_STARPU_PTHREAD_COND_SIGNAL(&cond);
+		_STARPU_PTHREAD_COND_SIGNAL(&cond_progression);
 		may_block = 0;
 	}
 	_STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
@@ -607,7 +634,7 @@ static void handle_new_request(struct _starpu_mpi_req *req)
 		/* put the submitted request into the list of pending requests
 		 * so that it can be handled by the progression mechanisms */
 		_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
-		_STARPU_PTHREAD_COND_SIGNAL(&cond);
+		_STARPU_PTHREAD_COND_SIGNAL(&cond_progression);
 		_STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
 	}
         _STARPU_MPI_LOG_OUT();
@@ -640,7 +667,7 @@ static void *progress_thread_func(void *arg)
 	/* notify the main thread that the progression thread is ready */
 	_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
 	running = 1;
-	_STARPU_PTHREAD_COND_SIGNAL(&cond);
+	_STARPU_PTHREAD_COND_SIGNAL(&cond_progression);
 	_STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
 
 	_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
@@ -655,7 +682,10 @@ static void *progress_thread_func(void *arg)
 		if (block)
 		{
                         _STARPU_MPI_DEBUG("NO MORE REQUESTS TO HANDLE\n");
-			_STARPU_PTHREAD_COND_WAIT(&cond, &mutex);
+			if (barrier_running)
+				/* Tell mpi_barrier */
+				_STARPU_PTHREAD_COND_SIGNAL(&cond_finished);
+			_STARPU_PTHREAD_COND_WAIT(&cond_progression, &mutex);
 		}
 
 		/* test whether there are some terminated "detached request" */
@@ -736,7 +766,8 @@ static
 int _starpu_mpi_initialize(int initialize_mpi, int *rank, int *world_size)
 {
 	_STARPU_PTHREAD_MUTEX_INIT(&mutex, NULL);
-	_STARPU_PTHREAD_COND_INIT(&cond, NULL);
+	_STARPU_PTHREAD_COND_INIT(&cond_progression, NULL);
+	_STARPU_PTHREAD_COND_INIT(&cond_finished, NULL);
 	new_requests = _starpu_mpi_req_list_new();
 
 	_STARPU_PTHREAD_MUTEX_INIT(&detached_requests_mutex, NULL);
@@ -748,7 +779,7 @@ int _starpu_mpi_initialize(int initialize_mpi, int *rank, int *world_size)
 
 	_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
 	while (!running)
-		_STARPU_PTHREAD_COND_WAIT(&cond, &mutex);
+		_STARPU_PTHREAD_COND_WAIT(&cond_progression, &mutex);
 	_STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
 
         if (rank && world_size) {
@@ -790,7 +821,7 @@ int starpu_mpi_shutdown(void)
 	/* kill the progression thread */
 	_STARPU_PTHREAD_MUTEX_LOCK(&mutex);
 	running = 0;
-	_STARPU_PTHREAD_COND_BROADCAST(&cond);
+	_STARPU_PTHREAD_COND_BROADCAST(&cond_progression);
 	_STARPU_PTHREAD_MUTEX_UNLOCK(&mutex);
 
 	pthread_join(progress_thread, &value);