Browse Source

StarPU MPI: Add extra functionalities for MPI Implementations which are not thread safe

   - New starpu_mpi initialization function which calls MPI_Init_Thread
   - New starpu_mpi_barrier() function to be called instead of MPI_Barrier

This ensures MPI functions are all called from the thread started by starpu_mpi
Nathalie Furmento 14 years ago
parent
commit
6b8dcf733f
3 changed files with 96 additions and 5 deletions
  1. 93 5
      mpi/starpu_mpi.c
  2. 2 0
      mpi/starpu_mpi.h
  3. 1 0
      mpi/starpu_mpi_private.h

+ 93 - 5
mpi/starpu_mpi.c

@@ -431,16 +431,77 @@ int starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status)
 }
 
 /*
+ *	Barrier
+ */
+
+static void starpu_mpi_barrier_func(struct starpu_mpi_req_s *barrier_req)
+{
+        _STARPU_MPI_LOG_IN();
+
+	barrier_req->ret = MPI_Barrier(barrier_req->comm);
+        STARPU_ASSERT(barrier_req->ret == MPI_SUCCESS);
+
+	handle_request_termination(barrier_req);
+        _STARPU_MPI_LOG_OUT();
+}
+
+int starpu_mpi_barrier(MPI_Comm comm)
+{
+        _STARPU_MPI_LOG_IN();
+	int ret;
+	struct starpu_mpi_req_s *barrier_req = calloc(1, sizeof(struct starpu_mpi_req_s));
+	STARPU_ASSERT(barrier_req);
+
+	/* Initialize the request structure */
+	PTHREAD_MUTEX_INIT(&(barrier_req->req_mutex), NULL);
+	PTHREAD_COND_INIT(&(barrier_req->req_cond), NULL);
+	barrier_req->func = starpu_mpi_barrier_func;
+	barrier_req->request_type = BARRIER_REQ;
+	barrier_req->comm = comm;
+
+	submit_mpi_req(barrier_req);
+
+	/* We wait for the MPI request to finish */
+	PTHREAD_MUTEX_LOCK(&barrier_req->req_mutex);
+	while (!barrier_req->completed)
+		PTHREAD_COND_WAIT(&barrier_req->req_cond, &barrier_req->req_mutex);
+	PTHREAD_MUTEX_UNLOCK(&barrier_req->req_mutex);
+
+	ret = barrier_req->ret;
+
+        //free(waiting_req);
+        _STARPU_MPI_LOG_OUT();
+	return ret;
+}
+
+/*
  *	Requests
  */
 
+static char *starpu_mpi_request_type(unsigned request_type)
+{
+        switch (request_type)
+                {
+                case SEND_REQ: return "send";
+                case RECV_REQ: return "recv";
+                case WAIT_REQ: return "wait";
+                case TEST_REQ: return "test";
+                case BARRIER_REQ: return "barrier";
+                default: return "unknown request type";
+                }
+}
+
 static void handle_request_termination(struct starpu_mpi_req_s *req)
 {
         _STARPU_MPI_LOG_IN();
-	MPI_Type_free(&req->datatype);
-	starpu_data_release(req->data_handle);
 
-	_STARPU_MPI_DEBUG("complete MPI (%s %d) req %p - tag %x\n", (req->request_type == RECV_REQ)?"recv : source":"send : dest", req->srcdst, &req->request, req->mpi_tag);
+	_STARPU_MPI_DEBUG("complete MPI (%s %d) req %p - tag %x\n", starpu_mpi_request_type(req->request_type), req->srcdst, &req->request, req->mpi_tag);
+
+        if (req->request_type != BARRIER_REQ) {
+                MPI_Type_free(&req->datatype);
+                starpu_data_release(req->data_handle);
+        }
+
 
 	if (req->request_type == RECV_REQ)
 	{
@@ -563,8 +624,25 @@ static void handle_new_request(struct starpu_mpi_req_s *req)
         _STARPU_MPI_LOG_OUT();
 }
 
-static void *progress_thread_func(void *arg __attribute__((unused)))
+static void *progress_thread_func(void *arg)
 {
+        int *initialize_mpi = (int *) arg;
+
+        if (*initialize_mpi) {
+#warning get real argc and argv from the application
+                int argc = 0;
+                char **argv = NULL;
+                int thread_support;
+                if (MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &thread_support) != MPI_SUCCESS) {
+                        fprintf(stderr,"MPI_Init_thread failed\n");
+                        exit(1);
+                }
+                if (thread_support == MPI_THREAD_FUNNELED)
+                        fprintf(stderr,"Warning: MPI only has funneled thread support, not serialized, hoping this will work\n");
+                if (thread_support < MPI_THREAD_FUNNELED)
+                        fprintf(stderr,"Warning: MPI does not have thread support!\n");
+        }
+
 	/* notify the main thread that the progression thread is ready */
 	PTHREAD_MUTEX_LOCK(&mutex);
 	running = 1;
@@ -657,6 +735,11 @@ static void _starpu_mpi_add_sync_point_in_fxt(void)
 
 int starpu_mpi_initialize(void)
 {
+        return starpu_mpi_initialize_extended(0, NULL, NULL);
+}
+
+int starpu_mpi_initialize_extended(int initialize_mpi, int *rank, int *world_size)
+{
 	PTHREAD_MUTEX_INIT(&mutex, NULL);
 	PTHREAD_COND_INIT(&cond, NULL);
 	new_requests = starpu_mpi_req_list_new();
@@ -664,13 +747,18 @@ int starpu_mpi_initialize(void)
 	PTHREAD_MUTEX_INIT(&detached_requests_mutex, NULL);
 	detached_requests = starpu_mpi_req_list_new();
 
-	int ret = pthread_create(&progress_thread, NULL, progress_thread_func, NULL);
+	int ret = pthread_create(&progress_thread, NULL, progress_thread_func, (void *)&initialize_mpi);
 
 	PTHREAD_MUTEX_LOCK(&mutex);
 	while (!running)
 		PTHREAD_COND_WAIT(&cond, &mutex);
 	PTHREAD_MUTEX_UNLOCK(&mutex);
 
+        if (initialize_mpi) {
+                MPI_Comm_rank(MPI_COMM_WORLD, rank);
+                MPI_Comm_size(MPI_COMM_WORLD, world_size);
+        }
+
 #ifdef USE_STARPU_ACTIVITY
 	hookid = starpu_progression_hook_register(progression_hook_func, NULL);
 	STARPU_ASSERT(hookid >= 0);

+ 2 - 0
mpi/starpu_mpi.h

@@ -30,7 +30,9 @@ int starpu_mpi_isend_detached(starpu_data_handle data_handle, int dest, int mpi_
 int starpu_mpi_irecv_detached(starpu_data_handle data_handle, int source, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg);
 int starpu_mpi_wait(starpu_mpi_req *req, MPI_Status *status);
 int starpu_mpi_test(starpu_mpi_req *req, int *flag, MPI_Status *status);
+int starpu_mpi_barrier(MPI_Comm comm);
 int starpu_mpi_initialize(void);
+int starpu_mpi_initialize_extended(int initialize_mpi, int *rank, int *world_size);
 int starpu_mpi_shutdown(void);
 
 /* Some helper functions */

+ 1 - 0
mpi/starpu_mpi_private.h

@@ -29,6 +29,7 @@
 #define RECV_REQ	1
 #define WAIT_REQ        2
 #define TEST_REQ        3
+#define BARRIER_REQ     4
 
 LIST_TYPE(starpu_mpi_req,
 	/* description of the data at StarPU level */