|
@@ -431,16 +431,77 @@ int starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
|
|
|
}
|
|
|
|
|
|
/*
|
|
|
+ * Barrier
|
|
|
+ */
|
|
|
+
|
|
|
+static void starpu_mpi_barrier_func(struct starpu_mpi_req_s *barrier_req)
|
|
|
+{
|
|
|
+ _STARPU_MPI_LOG_IN();
|
|
|
+
|
|
|
+ barrier_req->ret = MPI_Barrier(barrier_req->comm);
|
|
|
+ STARPU_ASSERT(barrier_req->ret == MPI_SUCCESS);
|
|
|
+
|
|
|
+ handle_request_termination(barrier_req);
|
|
|
+ _STARPU_MPI_LOG_OUT();
|
|
|
+}
|
|
|
+
|
|
|
+int starpu_mpi_barrier(MPI_Comm comm)
|
|
|
+{
|
|
|
+ _STARPU_MPI_LOG_IN();
|
|
|
+ int ret;
|
|
|
+ struct starpu_mpi_req_s *barrier_req = calloc(1, sizeof(struct starpu_mpi_req_s));
|
|
|
+ STARPU_ASSERT(barrier_req);
|
|
|
+
|
|
|
+ /* Initialize the request structure */
|
|
|
+ PTHREAD_MUTEX_INIT(&(barrier_req->req_mutex), NULL);
|
|
|
+ PTHREAD_COND_INIT(&(barrier_req->req_cond), NULL);
|
|
|
+ barrier_req->func = starpu_mpi_barrier_func;
|
|
|
+ barrier_req->request_type = BARRIER_REQ;
|
|
|
+ barrier_req->comm = comm;
|
|
|
+
|
|
|
+ submit_mpi_req(barrier_req);
|
|
|
+
|
|
|
+ /* We wait for the MPI request to finish */
|
|
|
+ PTHREAD_MUTEX_LOCK(&barrier_req->req_mutex);
|
|
|
+ while (!barrier_req->completed)
|
|
|
+ PTHREAD_COND_WAIT(&barrier_req->req_cond, &barrier_req->req_mutex);
|
|
|
+ PTHREAD_MUTEX_UNLOCK(&barrier_req->req_mutex);
|
|
|
+
|
|
|
+ ret = barrier_req->ret;
|
|
|
+
|
|
|
+ //free(waiting_req);
|
|
|
+ _STARPU_MPI_LOG_OUT();
|
|
|
+ return ret;
|
|
|
+}
|
|
|
+
|
|
|
+/*
|
|
|
* Requests
|
|
|
*/
|
|
|
|
|
|
+static char *starpu_mpi_request_type(unsigned request_type)
|
|
|
+{
|
|
|
+ switch (request_type)
|
|
|
+ {
|
|
|
+ case SEND_REQ: return "send";
|
|
|
+ case RECV_REQ: return "recv";
|
|
|
+ case WAIT_REQ: return "wait";
|
|
|
+ case TEST_REQ: return "test";
|
|
|
+ case BARRIER_REQ: return "barrier";
|
|
|
+ default: return "unknown request type";
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
static void handle_request_termination(struct starpu_mpi_req_s *req)
|
|
|
{
|
|
|
_STARPU_MPI_LOG_IN();
|
|
|
- MPI_Type_free(&req->datatype);
|
|
|
- starpu_data_release(req->data_handle);
|
|
|
|
|
|
- _STARPU_MPI_DEBUG("complete MPI (%s %d) req %p - tag %x\n", (req->request_type == RECV_REQ)?"recv : source":"send : dest", req->srcdst, &req->request, req->mpi_tag);
|
|
|
+ _STARPU_MPI_DEBUG("complete MPI (%s %d) req %p - tag %x\n", starpu_mpi_request_type(req->request_type), req->srcdst, &req->request, req->mpi_tag);
|
|
|
+
|
|
|
+ if (req->request_type != BARRIER_REQ) {
|
|
|
+ MPI_Type_free(&req->datatype);
|
|
|
+ starpu_data_release(req->data_handle);
|
|
|
+ }
|
|
|
+
|
|
|
|
|
|
if (req->request_type == RECV_REQ)
|
|
|
{
|
|
@@ -563,8 +624,25 @@ static void handle_new_request(struct starpu_mpi_req_s *req)
|
|
|
_STARPU_MPI_LOG_OUT();
|
|
|
}
|
|
|
|
|
|
-static void *progress_thread_func(void *arg __attribute__((unused)))
|
|
|
+static void *progress_thread_func(void *arg)
|
|
|
{
|
|
|
+ int *initialize_mpi = (int *) arg;
|
|
|
+
|
|
|
+ if (*initialize_mpi) {
|
|
|
+#warning get real argc and argv from the application
|
|
|
+ int argc = 0;
|
|
|
+ char **argv = NULL;
|
|
|
+ int thread_support;
|
|
|
+ if (MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &thread_support) != MPI_SUCCESS) {
|
|
|
+ fprintf(stderr,"MPI_Init_thread failed\n");
|
|
|
+ exit(1);
|
|
|
+ }
|
|
|
+ if (thread_support == MPI_THREAD_FUNNELED)
|
|
|
+ fprintf(stderr,"Warning: MPI only has funneled thread support, not serialized, hoping this will work\n");
|
|
|
+ if (thread_support < MPI_THREAD_FUNNELED)
|
|
|
+ fprintf(stderr,"Warning: MPI does not have thread support!\n");
|
|
|
+ }
|
|
|
+
|
|
|
/* notify the main thread that the progression thread is ready */
|
|
|
PTHREAD_MUTEX_LOCK(&mutex);
|
|
|
running = 1;
|
|
@@ -657,6 +735,11 @@ static void _starpu_mpi_add_sync_point_in_fxt(void)
|
|
|
|
|
|
int starpu_mpi_initialize(void)
|
|
|
{
|
|
|
+ return starpu_mpi_initialize_extended(0, NULL, NULL);
|
|
|
+}
|
|
|
+
|
|
|
+int starpu_mpi_initialize_extended(int initialize_mpi, int *rank, int *world_size)
|
|
|
+{
|
|
|
PTHREAD_MUTEX_INIT(&mutex, NULL);
|
|
|
PTHREAD_COND_INIT(&cond, NULL);
|
|
|
new_requests = starpu_mpi_req_list_new();
|
|
@@ -664,13 +747,18 @@ int starpu_mpi_initialize(void)
|
|
|
PTHREAD_MUTEX_INIT(&detached_requests_mutex, NULL);
|
|
|
detached_requests = starpu_mpi_req_list_new();
|
|
|
|
|
|
- int ret = pthread_create(&progress_thread, NULL, progress_thread_func, NULL);
|
|
|
+ int ret = pthread_create(&progress_thread, NULL, progress_thread_func, (void *)&initialize_mpi);
|
|
|
|
|
|
PTHREAD_MUTEX_LOCK(&mutex);
|
|
|
while (!running)
|
|
|
PTHREAD_COND_WAIT(&cond, &mutex);
|
|
|
PTHREAD_MUTEX_UNLOCK(&mutex);
|
|
|
|
|
|
+ if (initialize_mpi) {
|
|
|
+ MPI_Comm_rank(MPI_COMM_WORLD, rank);
|
|
|
+ MPI_Comm_size(MPI_COMM_WORLD, world_size);
|
|
|
+ }
|
|
|
+
|
|
|
#ifdef USE_STARPU_ACTIVITY
|
|
|
hookid = starpu_progression_hook_register(progression_hook_func, NULL);
|
|
|
STARPU_ASSERT(hookid >= 0);
|