Bladeren bron

Change alt_callback for cached comms into single cb + flag.
Avoid multi callback parralel paths.

Romain LION 5 jaren geleden
bovenliggende
commit
6ae562e32b

+ 20 - 26
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint.c

@@ -32,8 +32,8 @@ starpu_pthread_mutex_t           cp_lib_mutex;
 int                              my_rank;
 
 
-extern struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *_arg, void (*alt_callback)(void *), void *_alt_arg, int sequential_consistency);
-extern struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *_arg, void (*alt_callback)(void *), void *_alt_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
+extern struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *_arg, int sequential_consistency, int* cache_flag);
+extern struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count, int* cache_flag);
 
 
 
@@ -79,6 +79,13 @@ void _starpu_mpi_push_cp_ack_recv_cb(struct _starpu_mpi_cp_ack_arg_cb* arg)
 void _recv_internal_dup_ro_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
+	if (arg->cache_flag) {
+		_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle): -1);
+	}
+	else
+	{
+		_STARPU_MPI_FT_STATS_RECV_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle): -1);
+	}
 	starpu_data_release(arg->copy_handle);
 	_starpu_mpi_store_data_and_send_ack_cb(arg);
 }
@@ -93,19 +100,6 @@ void _recv_cp_external_data_cb(void* _args)
 	_starpu_mpi_store_data_and_send_ack_cb(arg);
 }
 
-void _recv_cp_internal_data_cb(void* _args)
-{
-	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	_STARPU_MPI_FT_STATS_RECV_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle) : -1);
-}
-
-void _recv_cached_cp_internal_data_cb(void* _args)
-{
-	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(arg->type==STARPU_VALUE?arg->count:arg->type==STARPU_R?starpu_data_get_size(arg->handle):-1);
-	starpu_data_release(arg->handle);
-}
-
 void _send_cp_external_data_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
@@ -115,19 +109,21 @@ void _send_cp_external_data_cb(void* _args)
 	_starpu_mpi_push_cp_ack_recv_cb(arg);
 }
 
-void _send_cp_internal_data_cb(void* _args)
-{
-	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	_STARPU_MPI_FT_STATS_SEND_CP_DATA(
-			arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle) : -1);
+void _send_cp_internal_data_cb(void* _args) {
+	struct _starpu_mpi_cp_ack_arg_cb *arg = (struct _starpu_mpi_cp_ack_arg_cb *) _args;
+	if (arg->cache_flag) {
+		_STARPU_MPI_FT_STATS_SEND_CACHED_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle): -1);
+	}
+	else
+	{
+		_STARPU_MPI_FT_STATS_SEND_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle): -1);
+	}
 	_starpu_mpi_push_cp_ack_recv_cb(arg);
 }
 
 void _send_cached_cp_internal_data_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	_STARPU_MPI_FT_STATS_SEND_CACHED_CP_DATA(
-			arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle) : -1);
 	starpu_data_release(arg->handle);
 	_starpu_mpi_push_cp_ack_recv_cb(arg);
 }
@@ -193,8 +189,7 @@ int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_te
 					arg->msg.checkpoint_id = cp_template->cp_id;
 					arg->msg.checkpoint_instance = current_instance;
 					_starpu_mpi_isend_cache_aware(handle, item->backupped_by, starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0, 0,
-					                              &_send_cp_internal_data_cb, (void*)arg,
-					                              &_send_cached_cp_internal_data_cb, (void*)arg, 1);
+					                              &_send_cp_internal_data_cb, (void*)arg, 1, &arg->cache_flag);
 					// the callbacks need to post ack recv. The cache one needs to release the handle.
 
 				}
@@ -210,8 +205,7 @@ int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_te
 					arg->msg.checkpoint_id = cp_template->cp_id;
 					arg->msg.checkpoint_instance = current_instance;
 					_starpu_mpi_irecv_cache_aware(handle, starpu_mpi_data_get_rank(handle), starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0,
-					                              &_recv_cp_internal_data_cb, (void*)arg,
-					                              &_recv_cached_cp_internal_data_cb, (void*)arg, 1, 0, 1);
+												  NULL, NULL, 1, 0, 1, &arg->cache_flag);
 					// The callback needs to do nothing. The cached one must release the handle.
 					starpu_data_dup_ro(&arg->copy_handle, arg->handle, 1);
 					starpu_data_acquire_cb(arg->copy_handle, STARPU_R, _recv_internal_dup_ro_cb, arg);

+ 1 - 0
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint.h

@@ -49,6 +49,7 @@ struct _starpu_mpi_cp_ack_arg_cb
 	int count;
 	starpu_mpi_tag_t              tag;
 	struct _starpu_mpi_cp_ack_msg msg;
+	int cache_flag;
 };
 
 struct _starpu_mpi_cp_discard_arg_cb

+ 6 - 4
mpi/src/starpu_mpi.c

@@ -175,7 +175,7 @@ int starpu_mpi_issend_detached(starpu_data_handle_t data_handle, int dest, starp
 	return starpu_mpi_issend_detached_prio(data_handle, dest, data_tag, 0, comm, callback, arg);
 }
 
-struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *_arg, void (*alt_callback)(void *), void *_alt_arg, int sequential_consistency)
+struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *_arg, int sequential_consistency, int* cache_flag)
 {
 	struct _starpu_mpi_req* req = NULL;
 	int already_sent = starpu_mpi_cached_send_set(data_handle, dest);
@@ -185,11 +185,12 @@ struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_
 			_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
 		_STARPU_MPI_DEBUG(1, "Send data %p to %d\n", data_handle, dest);
 		req = _starpu_mpi_isend_common(data_handle, dest, data_tag, comm, detached, sync, prio, callback, _arg, sequential_consistency);
+		*cache_flag = 0;
 	}
 	else
 	{
 		_STARPU_MPI_DEBUG(1, "STARPU CACHE: Data already sent\n");
-		starpu_data_acquire_cb(data_handle, STARPU_R, alt_callback, _alt_arg);
+		*cache_flag = 1;
 	}
 	return req;
 }
@@ -257,7 +258,7 @@ int starpu_mpi_recv(starpu_data_handle_t data_handle, int source, starpu_mpi_tag
 	return 0;
 }
 
-struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *_arg, void (*alt_callback)(void *), void *_alt_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count)
+struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count, int* cache_flag)
 {
 	struct _starpu_mpi_req* req = NULL;
 	int already_received = starpu_mpi_cached_receive_set(data_handle);
@@ -267,11 +268,12 @@ struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_
 			_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
 		_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, source);
 		req = _starpu_mpi_irecv_common(data_handle, source, data_tag, comm, detached, sync, callback, _arg, sequential_consistency, is_internal_req, count);
+		*cache_flag = 0;
 	}
 	else
 	{
 		_STARPU_MPI_DEBUG(1, "STARPU CACHE: Data already received\n");
-		starpu_data_acquire_cb(data_handle, STARPU_R, alt_callback, _alt_arg);
+		*cache_flag =1;
 	}
 	return req;
 }