|
@@ -96,7 +96,7 @@ starpu_pthread_queue_t _starpu_mpi_thread_dontsleep;
|
|
|
/* Count requests posted by the application and not yet submitted to MPI */
|
|
|
static starpu_pthread_mutex_t mutex_posted_requests;
|
|
|
static starpu_pthread_mutex_t mutex_ready_requests;
|
|
|
-static int posted_requests = 0, ready_requests = 0, newer_requests, barrier_running = 0;
|
|
|
+static int posted_requests = 0, ready_requests = 0, newer_requests, mpi_wait_for_all_running = 0;
|
|
|
|
|
|
/* Add `value` to posted_requests while holding mutex_posted_requests. */
#define _STARPU_MPI_INC_POSTED_REQUESTS(value) \
	{ \
		STARPU_PTHREAD_MUTEX_LOCK(&mutex_posted_requests); \
		posted_requests += value; \
		STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_posted_requests); \
	}
|
|
|
/* Add `value` to ready_requests while holding mutex_ready_requests. */
#define _STARPU_MPI_INC_READY_REQUESTS(value) \
	{ \
		STARPU_PTHREAD_MUTEX_LOCK(&mutex_ready_requests); \
		ready_requests += value; \
		STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_ready_requests); \
	}
|
|
@@ -761,6 +761,31 @@ static void _starpu_mpi_barrier_func(struct _starpu_mpi_req *barrier_req)
|
|
|
/*
 * Post a barrier request on `comm` and block the caller until that request
 * is marked completed.
 *
 * The request is set up with priority INT_MAX, request type BARRIER_REQ and
 * _starpu_mpi_barrier_func as its `func` callback. It is counted as a posted
 * request before being passed to _starpu_mpi_submit_ready_request().
 *
 * Returns 0.
 */
int _starpu_mpi_barrier(MPI_Comm comm)
{
	struct _starpu_mpi_req *barrier_req;

	/* Entry trace matching the _STARPU_MPI_LOG_OUT() before the return.
	 * NOTE(review): assumes LOG_IN/LOG_OUT are meant to be used in pairs
	 * within one function, as they were before the barrier code was split
	 * out of the wait-for-all loop. */
	_STARPU_MPI_LOG_IN();

	/* Initialize the request structure */
	_starpu_mpi_request_init(&barrier_req);
	barrier_req->prio = INT_MAX;
	barrier_req->func = _starpu_mpi_barrier_func;
	barrier_req->request_type = BARRIER_REQ;
	barrier_req->node_tag.node.comm = comm;

	_STARPU_MPI_INC_POSTED_REQUESTS(1);
	_starpu_mpi_submit_ready_request(barrier_req);

	/* We wait for the MPI request to finish; the predicate is re-checked
	 * after every wake-up of the condition variable. */
	STARPU_PTHREAD_MUTEX_LOCK(&barrier_req->backend->req_mutex);
	while (!barrier_req->completed)
		STARPU_PTHREAD_COND_WAIT(&barrier_req->backend->req_cond, &barrier_req->backend->req_mutex);
	STARPU_PTHREAD_MUTEX_UNLOCK(&barrier_req->backend->req_mutex);

	_starpu_mpi_request_destroy(barrier_req);
	_STARPU_MPI_LOG_OUT();

	return 0;
}
|
|
|
+
|
|
|
+int _starpu_mpi_wait_for_all(MPI_Comm comm)
|
|
|
+{
|
|
|
int ret = posted_requests+ready_requests;
|
|
|
|
|
|
_STARPU_MPI_LOG_IN();
|
|
@@ -769,8 +794,8 @@ int _starpu_mpi_barrier(MPI_Comm comm)
|
|
|
* some tasks generate MPI requests, MPI requests generate tasks, etc.
|
|
|
*/
|
|
|
STARPU_PTHREAD_MUTEX_LOCK(&progress_mutex);
|
|
|
- STARPU_MPI_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
|
|
|
- barrier_running = 1;
|
|
|
+ STARPU_MPI_ASSERT_MSG(!mpi_wait_for_all_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
|
|
|
+ mpi_wait_for_all_running = 1;
|
|
|
do
|
|
|
{
|
|
|
while (posted_requests || ready_requests)
|
|
@@ -786,28 +811,8 @@ int _starpu_mpi_barrier(MPI_Comm comm)
|
|
|
* triggered by tasks completed and triggered tasks between
|
|
|
* wait_for_all finished and we take the lock */
|
|
|
} while (posted_requests || ready_requests || newer_requests);
|
|
|
- barrier_running = 0;
|
|
|
+ mpi_wait_for_all_running = 0;
|
|
|
STARPU_PTHREAD_MUTEX_UNLOCK(&progress_mutex);
|
|
|
-
|
|
|
- /* Initialize the request structure */
|
|
|
- _starpu_mpi_request_init(&barrier_req);
|
|
|
- barrier_req->prio = INT_MAX;
|
|
|
- barrier_req->func = _starpu_mpi_barrier_func;
|
|
|
- barrier_req->request_type = BARRIER_REQ;
|
|
|
- barrier_req->node_tag.node.comm = comm;
|
|
|
-
|
|
|
- _STARPU_MPI_INC_POSTED_REQUESTS(1);
|
|
|
- _starpu_mpi_submit_ready_request(barrier_req);
|
|
|
-
|
|
|
- /* We wait for the MPI request to finish */
|
|
|
- STARPU_PTHREAD_MUTEX_LOCK(&barrier_req->backend->req_mutex);
|
|
|
- while (!barrier_req->completed)
|
|
|
- STARPU_PTHREAD_COND_WAIT(&barrier_req->backend->req_cond, &barrier_req->backend->req_mutex);
|
|
|
- STARPU_PTHREAD_MUTEX_UNLOCK(&barrier_req->backend->req_mutex);
|
|
|
-
|
|
|
- _starpu_mpi_request_destroy(barrier_req);
|
|
|
- _STARPU_MPI_LOG_OUT();
|
|
|
-
|
|
|
return ret;
|
|
|
}
|
|
|
|
|
@@ -1269,7 +1274,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
|
|
|
_STARPU_MPI_DEBUG(3, "NO MORE REQUESTS TO HANDLE\n");
|
|
|
_STARPU_MPI_TRACE_SLEEP_BEGIN();
|
|
|
|
|
|
- if (barrier_running)
|
|
|
+ if (mpi_wait_for_all_running)
|
|
|
/* Tell _starpu_mpi_wait_for_all */
|
|
|
STARPU_PTHREAD_COND_SIGNAL(&barrier_cond);
|
|
|
STARPU_PTHREAD_COND_WAIT(&progress_cond, &progress_mutex);
|