
mpi/src: fix barrier to also wait for ready requests

Nathalie Furmento, 7 years ago
parent
commit
02d2a94c5a
1 changed file with 17 additions and 6 deletions
mpi/src/starpu_mpi.c (+17, -6)

@@ -97,9 +97,11 @@ int _starpu_mpi_fake_world_rank = -1;
 
 /* Count requests posted by the application and not yet submitted to MPI */
 static starpu_pthread_mutex_t mutex_posted_requests;
-static int posted_requests = 0, newer_requests, barrier_running = 0;
+static starpu_pthread_mutex_t mutex_ready_requests;
+static int posted_requests = 0, ready_requests = 0, newer_requests, barrier_running = 0;
 
 #define _STARPU_MPI_INC_POSTED_REQUESTS(value) { STARPU_PTHREAD_MUTEX_LOCK(&mutex_posted_requests); posted_requests += value; STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_posted_requests); }
+#define _STARPU_MPI_INC_READY_REQUESTS(value) { STARPU_PTHREAD_MUTEX_LOCK(&mutex_ready_requests); ready_requests += value; STARPU_PTHREAD_MUTEX_UNLOCK(&mutex_ready_requests); }
 
 #pragma weak smpi_simulated_main_
 extern int smpi_simulated_main_(int argc, char *argv[]);
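
The new counter mirrors posted_requests just above it: each counter is guarded by its own mutex, so updates from the application threads and from the MPI progress thread never race. A minimal standalone sketch of the pattern, with plain pthreads standing in for the STARPU_PTHREAD_* wrappers (names here are illustrative, not StarPU API):

#include <pthread.h>

static pthread_mutex_t mutex_ready_requests = PTHREAD_MUTEX_INITIALIZER;
static int ready_requests = 0;

/* Same job as _STARPU_MPI_INC_READY_REQUESTS: the counter is only
 * ever touched under its own mutex, so push sites (+1) and pop or
 * erase sites (-1) can run concurrently. */
static void inc_ready_requests(int value)
{
	pthread_mutex_lock(&mutex_ready_requests);
	ready_requests += value;
	pthread_mutex_unlock(&mutex_ready_requests);
}
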
@@ -235,6 +237,7 @@ static void _starpu_mpi_submit_ready_request(void *arg)
 					  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr,
 					  req->datatype_name, (int)req->count, req->registered_datatype);
 			_starpu_mpi_req_list_push_front(&ready_recv_requests, req);
+			_STARPU_MPI_INC_READY_REQUESTS(+1);
 
 			/* inform the starpu mpi thread that the request has been pushed in the ready_requests list */
 			STARPU_PTHREAD_MUTEX_UNLOCK(&progress_mutex);
@@ -301,6 +304,7 @@ static void _starpu_mpi_submit_ready_request(void *arg)
 						_STARPU_MPI_MALLOC(req->ptr, req->count);
 					}
 					_starpu_mpi_req_list_push_front(&ready_recv_requests, req);
+					_STARPU_MPI_INC_READY_REQUESTS(+1);
 					_starpu_mpi_request_destroy(sync_req);
 				}
 				else
@@ -317,6 +321,7 @@ static void _starpu_mpi_submit_ready_request(void *arg)
 			_starpu_mpi_req_prio_list_push_front(&ready_send_requests, req);
 		else
 			_starpu_mpi_req_list_push_front(&ready_recv_requests, req);
+		_STARPU_MPI_INC_READY_REQUESTS(+1);
 		_STARPU_MPI_DEBUG(3, "Pushing new request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d registered_datatype %d \n",
 				  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr,
 				  req->datatype_name, (int)req->count, req->registered_datatype);
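
The hunks above pair every insertion into ready_recv_requests or ready_send_requests with an increment, so ready_requests always counts the requests that are queued but not yet handled. A hypothetical wrapper (the patch inlines the increment at each push site instead) makes the pairing explicit:

/* Hypothetical helper, for illustration only; like the push sites
 * above, it must run with progress_mutex held. */
static void _push_ready_recv(struct _starpu_mpi_req *req)
{
	_starpu_mpi_req_list_push_front(&ready_recv_requests, req);
	_STARPU_MPI_INC_READY_REQUESTS(+1);
}
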
@@ -991,10 +996,10 @@ static void _starpu_mpi_barrier_func(struct _starpu_mpi_req *barrier_req)
 
 int _starpu_mpi_barrier(MPI_Comm comm)
 {
-	_STARPU_MPI_LOG_IN();
-
-	int ret = posted_requests;
 	struct _starpu_mpi_req *barrier_req;
+	int ret = posted_requests+ready_requests;
+
+	_STARPU_MPI_LOG_IN();
 
 	/* First wait for *both* all tasks and MPI requests to finish, in case
 	 * some tasks generate MPI requests, MPI requests generate tasks, etc.
@@ -1004,7 +1009,7 @@ int _starpu_mpi_barrier(MPI_Comm comm)
 	barrier_running = 1;
 	do
 	{
-		while (posted_requests)
+		while (posted_requests || ready_requests)
 			/* Wait for all current MPI requests to finish */
 			STARPU_PTHREAD_COND_WAIT(&barrier_cond, &progress_mutex);
 		/* No current request, clear flag */
@@ -1016,7 +1021,7 @@ int _starpu_mpi_barrier(MPI_Comm comm)
 		/* Check newer_requests again, in case some MPI requests
 		 * triggered by tasks completed and triggered tasks between
 		 * wait_for_all finished and we take the lock */
-	} while (posted_requests || newer_requests);
+	} while (posted_requests || ready_requests || newer_requests);
 	barrier_running = 0;
 	STARPU_PTHREAD_MUTEX_UNLOCK(&progress_mutex);
 
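
Before this fix the barrier only tested posted_requests, so a request that had already moved from posted to ready (queued for the progress thread but not yet handled) could slip past the wait. The wake-up side is not part of this diff; a sketch of how the progress thread might signal barrier_cond once both counters have drained (an assumption about the surrounding code, not taken from the patch):

/* Assumed wake-up path, called with progress_mutex held after a
 * request has been fully handled. */
static void _wake_barrier_if_idle(void)
{
	if (barrier_running && posted_requests == 0 && ready_requests == 0)
		STARPU_PTHREAD_COND_BROADCAST(&barrier_cond);
}
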
@@ -1391,6 +1396,7 @@ static void _starpu_mpi_receive_early_data(struct _starpu_mpi_envelope *envelope
 	// Handle the request immediately to make sure the mpi_irecv is
 	// posted before receiving another envelope
 	_starpu_mpi_req_list_erase(&ready_recv_requests, early_data_handle->req);
+	_STARPU_MPI_INC_READY_REQUESTS(-1);
 	STARPU_PTHREAD_MUTEX_UNLOCK(&progress_mutex);
 	_starpu_mpi_handle_ready_request(early_data_handle->req);
 	STARPU_PTHREAD_MUTEX_LOCK(&progress_mutex);
@@ -1503,6 +1509,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 				break;
 
 			req = _starpu_mpi_req_list_pop_back(&ready_recv_requests);
+			_STARPU_MPI_INC_READY_REQUESTS(-1);
 
 			/* handling a request is likely to block for a while
 			 * (on a sync_data_with_mem call), we want to let the
@@ -1524,6 +1531,7 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 				break;
 
 			req = _starpu_mpi_req_prio_list_pop_back_highest(&ready_send_requests);
+			_STARPU_MPI_INC_READY_REQUESTS(-1);
 
 			/* handling a request is likely to block for a while
 			 * (on a sync_data_with_mem call), we want to let the
@@ -1756,6 +1764,8 @@ int _starpu_mpi_progress_init(struct _starpu_mpi_argc_argv *argc_argv)
 	_starpu_mpi_req_list_init(&detached_requests);
 
         STARPU_PTHREAD_MUTEX_INIT(&mutex_posted_requests, NULL);
+        STARPU_PTHREAD_MUTEX_INIT(&mutex_ready_requests, NULL);
+
         _starpu_mpi_comm_debug = starpu_getenv("STARPU_MPI_COMM") != NULL;
 	nready_process = starpu_get_env_number_default("STARPU_MPI_NREADY_PROCESS", 10);
 	ndetached_send = starpu_get_env_number_default("STARPU_MPI_NDETACHED_SEND", 10);
@@ -1812,6 +1822,7 @@ void _starpu_mpi_progress_shutdown(int *value)
 #endif
 
         STARPU_PTHREAD_MUTEX_DESTROY(&mutex_posted_requests);
+        STARPU_PTHREAD_MUTEX_DESTROY(&mutex_ready_requests);
         STARPU_PTHREAD_MUTEX_DESTROY(&progress_mutex);
         STARPU_PTHREAD_COND_DESTROY(&barrier_cond);
 }
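
Putting the pieces together, a self-contained toy (plain pthreads, hypothetical names, compile with -pthread) shows why the barrier must also test the ready count: waiting only on newly posted requests would let the barrier return while the progress thread is still draining its queue.

#include <pthread.h>
#include <unistd.h>

static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t barrier_cond = PTHREAD_COND_INITIALIZER;
static int ready_requests = 0;

/* Progress thread: drain queued requests, then wake the barrier. */
static void *progress(void *arg)
{
	(void)arg;
	for (;;)
	{
		pthread_mutex_lock(&mutex);
		if (ready_requests == 0)
		{
			pthread_cond_broadcast(&barrier_cond);
			pthread_mutex_unlock(&mutex);
			return NULL;
		}
		ready_requests--;	/* one request handled */
		pthread_mutex_unlock(&mutex);
		usleep(1000);		/* stand-in for the MPI work */
	}
}

int main(void)
{
	pthread_t t;
	ready_requests = 5;	/* five requests already queued */
	pthread_create(&t, NULL, progress, NULL);

	/* The barrier: without testing ready_requests here, it would
	 * return while the progress thread still has work queued. */
	pthread_mutex_lock(&mutex);
	while (ready_requests)
		pthread_cond_wait(&barrier_cond, &mutex);
	pthread_mutex_unlock(&mutex);

	pthread_join(t, NULL);
	return 0;
}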