|
@@ -59,11 +59,15 @@ static starpu_pthread_cond_t progress_cond;
|
|
|
static starpu_pthread_mutex_t progress_mutex;
|
|
|
static volatile int running = 0;
|
|
|
|
|
|
-extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
|
|
|
+static starpu_pthread_cond_t mpi_wait_for_all_running_cond;
|
|
|
+static int mpi_wait_for_all_running = 0;
|
|
|
+static starpu_pthread_mutex_t mpi_wait_for_all_running_mutex;
|
|
|
|
|
|
-/* Count requests posted by the application and not yet submitted to MPI, i.e pushed into the new_requests list */
|
|
|
+extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, int data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
|
|
|
|
|
|
-static volatile int pending_request = 0;
|
|
|
+/* Count running requests: this counter is incremented just before StarPU
|
|
|
+ * submits a MPI request, and decremented when a MPI request finishes. */
|
|
|
+static volatile int nb_pending_requests = 0;
|
|
|
|
|
|
#define REQ_FINALIZED 0x1
|
|
|
|
|
@@ -80,7 +84,7 @@ static starpu_sem_t callback_sem;
|
|
|
|
|
|
void _starpu_mpi_req_willpost(struct _starpu_mpi_req *req STARPU_ATTRIBUTE_UNUSED)
|
|
|
{
|
|
|
- STARPU_ATOMIC_ADD( &pending_request, 1);
|
|
|
+ STARPU_ATOMIC_ADD( &nb_pending_requests, 1);
|
|
|
}
|
|
|
|
|
|
/********************************************************/
|
|
@@ -269,16 +273,39 @@ int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
|
|
|
int _starpu_mpi_barrier(MPI_Comm comm)
|
|
|
{
|
|
|
_STARPU_MPI_LOG_IN();
|
|
|
- int ret;
|
|
|
- // STARPU_ASSERT_MSG(!barrier_running, "Concurrent starpu_mpi_barrier is not implemented, even on different communicators");
|
|
|
- ret = MPI_Barrier(comm);
|
|
|
|
|
|
+ int ret = MPI_Barrier(comm);
|
|
|
STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Barrier returning %d", ret);
|
|
|
|
|
|
_STARPU_MPI_LOG_OUT();
|
|
|
return ret;
|
|
|
}
|
|
|
|
|
|
+int _starpu_mpi_wait_for_all(MPI_Comm comm)
|
|
|
+{
|
|
|
+ (void) comm;
|
|
|
+ _STARPU_MPI_LOG_IN();
|
|
|
+
|
|
|
+ STARPU_PTHREAD_MUTEX_LOCK(&mpi_wait_for_all_running_mutex);
|
|
|
+ STARPU_MPI_ASSERT_MSG(!mpi_wait_for_all_running, "Concurrent starpu_mpi_wait_for_all is not implemented, even on different communicators");
|
|
|
+ mpi_wait_for_all_running = 1;
|
|
|
+ do
|
|
|
+ {
|
|
|
+ while (nb_pending_requests)
|
|
|
+ STARPU_PTHREAD_COND_WAIT(&mpi_wait_for_all_running_cond, &mpi_wait_for_all_running_mutex);
|
|
|
+ STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_wait_for_all_running_mutex);
|
|
|
+
|
|
|
+ starpu_task_wait_for_all();
|
|
|
+
|
|
|
+ STARPU_PTHREAD_MUTEX_LOCK(&mpi_wait_for_all_running_mutex);
|
|
|
+ } while (nb_pending_requests);
|
|
|
+ mpi_wait_for_all_running = 0;
|
|
|
+ STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_wait_for_all_running_mutex);
|
|
|
+
|
|
|
+ _STARPU_MPI_LOG_OUT();
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
/********************************************************/
|
|
|
/* */
|
|
|
/* Progression */
|
|
@@ -353,9 +380,13 @@ void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req,nm_sr_ev
|
|
|
req->completed = 1;
|
|
|
piom_cond_signal(&req->backend->req_cond, REQ_FINALIZED);
|
|
|
}
|
|
|
- int pending_remaining = STARPU_ATOMIC_ADD(&pending_request, -1);
|
|
|
- if (!running && !pending_remaining)
|
|
|
- starpu_sem_post(&callback_sem);
|
|
|
+ int pending_remaining = STARPU_ATOMIC_ADD(&nb_pending_requests, -1);
|
|
|
+ if (!pending_remaining)
|
|
|
+ {
|
|
|
+ STARPU_PTHREAD_COND_BROADCAST(&mpi_wait_for_all_running_cond);
|
|
|
+ if (!running)
|
|
|
+ starpu_sem_post(&callback_sem);
|
|
|
+ }
|
|
|
}
|
|
|
_STARPU_MPI_LOG_OUT();
|
|
|
}
|
|
@@ -476,24 +507,24 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
|
|
|
struct callback_lfstack_cell_s* c = callback_lfstack_pop(&callback_stack);
|
|
|
int err=0;
|
|
|
|
|
|
- if(running || pending_request>0)
|
|
|
+ if(running || nb_pending_requests>0)
|
|
|
{
|
|
|
/* shall we block ? */
|
|
|
err = starpu_sem_wait(&callback_sem);
|
|
|
- //running pending_request can change while waiting
|
|
|
+ //running nb_pending_requests can change while waiting
|
|
|
}
|
|
|
if(c==NULL)
|
|
|
{
|
|
|
c = callback_lfstack_pop(&callback_stack);
|
|
|
if (c == NULL)
|
|
|
{
|
|
|
- if(running && pending_request>0)
|
|
|
+ if(running && nb_pending_requests>0)
|
|
|
{
|
|
|
STARPU_ASSERT_MSG(c!=NULL, "Callback thread awakened without callback ready with error %d.",err);
|
|
|
}
|
|
|
else
|
|
|
{
|
|
|
- if (pending_request==0)
|
|
|
+ if (nb_pending_requests==0)
|
|
|
break;
|
|
|
}
|
|
|
continue;
|
|
@@ -511,14 +542,14 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
|
|
|
c->req->completed=1;
|
|
|
piom_cond_signal(&(c->req->backend->req_cond), REQ_FINALIZED);
|
|
|
}
|
|
|
- STARPU_ATOMIC_ADD( &pending_request, -1);
|
|
|
+ STARPU_ATOMIC_ADD( &nb_pending_requests, -1);
|
|
|
/* we signal that the request is completed.*/
|
|
|
|
|
|
free(c);
|
|
|
|
|
|
}
|
|
|
STARPU_ASSERT_MSG(callback_lfstack_pop(&callback_stack)==NULL, "List of callback not empty.");
|
|
|
- STARPU_ASSERT_MSG(pending_request==0, "Request still pending.");
|
|
|
+ STARPU_ASSERT_MSG(nb_pending_requests==0, "Request still pending.");
|
|
|
|
|
|
if (argc_argv->initialize_mpi)
|
|
|
{
|
|
@@ -580,6 +611,9 @@ int _starpu_mpi_progress_init(struct _starpu_mpi_argc_argv *argc_argv)
|
|
|
STARPU_PTHREAD_MUTEX_INIT(&progress_mutex, NULL);
|
|
|
STARPU_PTHREAD_COND_INIT(&progress_cond, NULL);
|
|
|
|
|
|
+ STARPU_PTHREAD_MUTEX_INIT(&mpi_wait_for_all_running_mutex, NULL);
|
|
|
+ STARPU_PTHREAD_COND_INIT(&mpi_wait_for_all_running_cond, NULL);
|
|
|
+
|
|
|
starpu_sem_init(&callback_sem, 0, 0);
|
|
|
running = 0;
|
|
|
|
|
@@ -669,6 +703,9 @@ void _starpu_mpi_progress_shutdown(void **value)
|
|
|
|
|
|
STARPU_PTHREAD_MUTEX_DESTROY(&progress_mutex);
|
|
|
STARPU_PTHREAD_COND_DESTROY(&progress_cond);
|
|
|
+
|
|
|
+ STARPU_PTHREAD_MUTEX_DESTROY(&mpi_wait_for_all_running_mutex);
|
|
|
+ STARPU_PTHREAD_COND_DESTROY(&mpi_wait_for_all_running_cond);
|
|
|
}
|
|
|
|
|
|
static int64_t _starpu_mpi_tag_max = INT64_MAX;
|