
The counting of checkpoint communications that benefit the computation is simplified and corrected. Only the receive count is corrected (correcting the send count would be more expensive).

Romain LION, 5 years ago
parent
commit
beb3330bd7
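In a minimal self-contained sketch of the intended accounting (the counter and function names below are illustrative stand-ins for the _STARPU_MPI_FT_STATS_* macros, and treating the "cancel" as a plain subtraction is an assumption), a checkpoint-induced receive is first charged to checkpoint-only traffic, then re-classified as beneficial once the computation reuses the cached copy:

#include <stdio.h>
#include <stddef.h>

static size_t cp_recv_bytes        = 0; /* data received only for checkpointing */
static size_t cp_recv_cached_bytes = 0; /* checkpoint data later reused by the computation */

/* A checkpoint-induced receive is first counted as pure checkpoint traffic. */
static void account_cp_receive(size_t size)
{
	cp_recv_bytes += size;
}

/* If the application later needs the same data and finds it in the cache,
 * the bytes move from the "checkpoint only" count to the "beneficial" count. */
static void account_cache_hit(size_t size, int ft_induced)
{
	if (ft_induced)
	{
		cp_recv_cached_bytes += size;
		cp_recv_bytes        -= size; /* cancel the earlier checkpoint count */
	}
}

int main(void)
{
	account_cp_receive(1024);   /* a checkpoint pulls 1 KiB from its backup */
	account_cache_hit(1024, 1); /* the computation reuses that cached copy */
	printf("checkpoint-only bytes: %zu, reused by computation: %zu\n",
	       cp_recv_bytes, cp_recv_cached_bytes);
	return 0;
}

Only the receive side is corrected this way; doing the same for sends would require the sender to learn whether the remote copy was ever reused, which is presumably why the commit message calls the send-side correction more expensive.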

mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint.c (+32 -18)

@@ -35,7 +35,7 @@ int                              my_rank;
 extern struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *_arg, int sequential_consistency, int* cache_flag);
 extern struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count, int* cache_flag);
 
-
+extern int starpu_mpi_cache_set_ft_induced_cache_receive(starpu_data_handle_t data_handle);
 
 void _ack_msg_send_cb(void* _args)
 {
@@ -79,15 +79,7 @@ void _starpu_mpi_push_cp_ack_recv_cb(struct _starpu_mpi_cp_ack_arg_cb* arg)
 void _recv_internal_dup_ro_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	if (arg->cache_flag) {
-		_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->copy_handle): -1);
-	}
-	else
-	{
-		struct _starpu_mpi_data* mpi_data = _starpu_mpi_data_get(arg->copy_handle);
-		mpi_data->cache_received.ft_induced_cache = 1;
-		_STARPU_MPI_FT_STATS_RECV_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->copy_handle): -1);
-	}
+
 	starpu_data_release(arg->copy_handle);
 	_starpu_mpi_store_data_and_send_ack_cb(arg);
 }
@@ -95,7 +87,7 @@ void _recv_internal_dup_ro_cb(void* _args)
 void _recv_cp_external_data_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	_STARPU_MPI_FT_STATS_RECV_CP_DATA(arg->type==STARPU_VALUE?arg->count:arg->type==STARPU_R?starpu_data_get_size(arg->handle):-1);
+	_STARPU_MPI_FT_STATS_RECV_CP_DATA(starpu_data_get_size(arg->handle));
 	// an handle has specifically been created, Let's get the value back, and unregister the handle
 	arg->copy_handle = starpu_data_handle_to_pointer(arg->handle, STARPU_MAIN_RAM);
 	starpu_data_unregister_submit(arg->handle);
@@ -105,7 +97,7 @@ void _recv_cp_external_data_cb(void* _args)
 void _send_cp_external_data_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	_STARPU_MPI_FT_STATS_SEND_CP_DATA(arg->type==STARPU_VALUE?arg->count:arg->type==STARPU_R?starpu_data_get_size(arg->handle):-1);
+	_STARPU_MPI_FT_STATS_SEND_CP_DATA(starpu_data_get_size(arg->handle));
 	free(starpu_data_handle_to_pointer(arg->handle, STARPU_MAIN_RAM));
 	starpu_data_unregister_submit(arg->handle);
 	_starpu_mpi_push_cp_ack_recv_cb(arg);
@@ -113,18 +105,39 @@ void _send_cp_external_data_cb(void* _args)
 
 void _send_cp_internal_data_cb(void* _args) {
 	struct _starpu_mpi_cp_ack_arg_cb *arg = (struct _starpu_mpi_cp_ack_arg_cb *) _args;
+	_starpu_mpi_push_cp_ack_recv_cb(arg);
+}
+
+void _send_internal_data_stats(struct _starpu_mpi_cp_ack_arg_cb* arg)
+{
 	if (arg->cache_flag) {
-		_STARPU_MPI_FT_STATS_SEND_CACHED_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle): -1);
+		_STARPU_MPI_FT_STATS_SEND_CACHED_CP_DATA(starpu_data_get_size(arg->handle));
 	}
 	else
 	{
-		struct _starpu_mpi_data* mpi_data = _starpu_mpi_data_get(arg->handle);
-		mpi_data->cache_sent[arg->rank].ft_induced_cache = 1;
-		_STARPU_MPI_FT_STATS_SEND_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle): -1);
+		_STARPU_MPI_FT_STATS_SEND_CP_DATA(starpu_data_get_size(arg->handle));
 	}
-	_starpu_mpi_push_cp_ack_recv_cb(arg);
 }
 
+#ifdef STARPU_USE_MPI_FT_STATS
+void _recv_internal_data_stats(struct _starpu_mpi_cp_ack_arg_cb* arg)
+{
+	if (arg->cache_flag) {
+		_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(starpu_data_get_size(arg->handle));
+	}
+	else
+	{
+		_STARPU_MPI_FT_STATS_RECV_CP_DATA(starpu_data_get_size(arg->handle));
+		starpu_mpi_cache_set_ft_induced_cache_receive(arg->handle);
+	}
+}
+#else
+void _recv_internal_data_stats(STARPU_ATTRIBUTE_UNUSED struct _starpu_mpi_cp_ack_arg_cb* arg)
+{
+	return;
+}
+#endif
+
 int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_template)
 {
 	starpu_data_handle_t handle;
@@ -188,7 +201,7 @@ int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_te
 					_starpu_mpi_isend_cache_aware(handle, item->backupped_by, starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0, 0,
 					                              &_send_cp_internal_data_cb, (void*)arg, 1, &arg->cache_flag);
 					// the callbacks need to post ack recv. The cache one needs to release the handle.
-
+					_send_internal_data_stats(arg);
 				}
 				else if (item->backup_of == starpu_mpi_data_get_rank(handle))
 				{
@@ -204,6 +217,7 @@ int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_te
 					_starpu_mpi_irecv_cache_aware(handle, starpu_mpi_data_get_rank(handle), starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0,
 												  NULL, NULL, 1, 0, 1, &arg->cache_flag);
 					// The callback needs to do nothing. The cached one must release the handle.
+					_recv_internal_data_stats(arg);
 					starpu_data_dup_ro(&arg->copy_handle, arg->handle, 1);
 					starpu_data_acquire_cb(arg->copy_handle, STARPU_R, _recv_internal_dup_ro_cb, arg);
 					// The callback need to store the data and post ack send.
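The new _recv_internal_data_stats helper only does real work when STARPU_USE_MPI_FT_STATS is defined; otherwise an empty stub is compiled, so the call site above needs no #ifdef of its own. A standalone sketch of that compile-time pattern, using illustrative names rather than the actual StarPU macros:

#include <stdio.h>
#include <stddef.h>

/* Uncomment to enable the counters, mirroring STARPU_USE_MPI_FT_STATS. */
/* #define USE_FT_STATS 1 */

#ifdef USE_FT_STATS
static void record_recv_stats(size_t size, int cached)
{
	printf("%s receive of %zu bytes\n", cached ? "cached" : "actual", size);
}
#else
static void record_recv_stats(size_t size, int cached)
{
	(void)size;   /* stub: statistics disabled at compile time */
	(void)cached;
}
#endif

int main(void)
{
	record_recv_stats(4096, 0); /* the call site is identical in both builds */
	return 0;
}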

mpi/src/starpu_mpi_cache.c (+34 -35)

@@ -130,13 +130,12 @@ void _starpu_mpi_cache_data_init(starpu_data_handle_t data_handle)
 		return;
 
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
-	mpi_data->cache_received.in_cache         = 0;
-	mpi_data->cache_received.ft_induced_cache = 0;
+	mpi_data->cache_received = 0;
+	mpi_data->ft_induced_cache_received = 0;
 	_STARPU_MALLOC(mpi_data->cache_sent, _starpu_cache_comm_size*sizeof(mpi_data->cache_sent[0]));
 	for(i=0 ; i<_starpu_cache_comm_size ; i++)
 	{
-		mpi_data->cache_sent[i].in_cache         = 0;
-		mpi_data->cache_sent[i].ft_induced_cache = 0;
+		mpi_data->cache_sent[i] = 0;
 	}
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 }
@@ -187,14 +186,14 @@ void starpu_mpi_cached_receive_clear(starpu_data_handle_t data_handle)
 	STARPU_ASSERT(mpi_data->magic == 42);
 	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
 
-	if (mpi_data->cache_received.in_cache == 1)
+	if (mpi_data->cache_received == 1)
 	{
 #ifdef STARPU_DEVEL
 #  warning TODO: Somebody else will write to the data, so discard our cached copy if any. starpu_mpi could just remember itself.
 #endif
 		_STARPU_MPI_DEBUG(2, "Clearing receive cache for data %p\n", data_handle);
-		mpi_data->cache_received.in_cache         = 0;
-		mpi_data->cache_received.ft_induced_cache = 0;
+		mpi_data->cache_received = 0;
+		mpi_data->ft_induced_cache_received = 0;
 		starpu_data_invalidate_submit(data_handle);
 		_starpu_mpi_cache_data_remove_nolock(data_handle);
 		_starpu_mpi_cache_stats_dec(mpi_rank, data_handle);
@@ -202,6 +201,12 @@ void starpu_mpi_cached_receive_clear(starpu_data_handle_t data_handle)
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 }
 
+int starpu_mpi_cache_set_ft_induced_cache_receive(starpu_data_handle_t data_handle)
+{
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
+	mpi_data->ft_induced_cache_received = 1;
+	return 0;
+}
+
 int starpu_mpi_cached_receive_set(starpu_data_handle_t data_handle)
 {
 	int mpi_rank = starpu_mpi_data_get_rank(data_handle);
@@ -214,22 +219,24 @@ int starpu_mpi_cached_receive_set(starpu_data_handle_t data_handle)
 	STARPU_ASSERT(mpi_data->magic == 42);
 	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
 
-	int already_received = mpi_data->cache_received.in_cache;
+	int already_received = mpi_data->cache_received;
 	if (already_received == 0)
 	{
 		_STARPU_MPI_DEBUG(2, "Noting that data %p has already been received by %d\n", data_handle, mpi_rank);
-		mpi_data->cache_received.in_cache = 1;
+		mpi_data->cache_received = 1;
 		_starpu_mpi_cache_data_add_nolock(data_handle);
 		_starpu_mpi_cache_stats_inc(mpi_rank, data_handle);
 	}
 	else
 	{
-		if (mpi_data->cache_received.ft_induced_cache == 1)
-		{
-			_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(starpu_data_get_size(data_handle));
-			_STARPU_MPI_FT_STATS_CANCEL_RECV_CP_DATA(starpu_data_get_size(data_handle));
-			mpi_data->cache_received.ft_induced_cache = 0;
-		}
+#ifdef STARPU_USE_MPI_FT_STATS
+		if (mpi_data->ft_induced_cache_received == 1)
+		{
+			_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(starpu_data_get_size(data_handle));
+			_STARPU_MPI_FT_STATS_CANCEL_RECV_CP_DATA(starpu_data_get_size(data_handle));
+			mpi_data->ft_induced_cache_received = 0;
+		}
+#endif /* STARPU_USE_MPI_FT_STATS */
 		_STARPU_MPI_DEBUG(2, "Do not receive data %p from node %d as it is already available\n", data_handle, mpi_rank);
 	}
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
@@ -246,7 +253,7 @@ int starpu_mpi_cached_receive(starpu_data_handle_t data_handle)
 
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
 	STARPU_ASSERT(mpi_data->magic == 42);
-	already_received = mpi_data->cache_received.in_cache;
+	already_received = mpi_data->cache_received;
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 	return already_received;
 }
@@ -266,11 +273,10 @@ void starpu_mpi_cached_send_clear(starpu_data_handle_t data_handle)
 	starpu_mpi_comm_size(mpi_data->node_tag.node.comm, &size);
 	for(n=0 ; n<size ; n++)
 	{
-		if (mpi_data->cache_sent[n].in_cache == 1)
+		if (mpi_data->cache_sent[n] == 1)
 		{
 			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data_handle);
-			mpi_data->cache_sent[n].in_cache = 0;
-			mpi_data->cache_sent[n].ft_induced_cache = 0;
+			mpi_data->cache_sent[n] = 0;
 			_starpu_mpi_cache_data_remove_nolock(data_handle);
 		}
 	}
@@ -287,21 +293,15 @@ int starpu_mpi_cached_send_set(starpu_data_handle_t data_handle, int dest)
 	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
 
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
-	int already_sent = mpi_data->cache_sent[dest].in_cache;
-	if (mpi_data->cache_sent[dest].in_cache == 0)
+	int already_sent = mpi_data->cache_sent[dest];
+	if (mpi_data->cache_sent[dest] == 0)
 	{
-		mpi_data->cache_sent[dest].in_cache = 1;
+		mpi_data->cache_sent[dest] = 1;
 		_starpu_mpi_cache_data_add_nolock(data_handle);
 		_STARPU_MPI_DEBUG(2, "Noting that data %p has already been sent to %d\n", data_handle, dest);
 	}
 	else
 	{
-		if (mpi_data->cache_sent[dest].ft_induced_cache == 1)
-		{
-			_STARPU_MPI_FT_STATS_SEND_CACHED_CP_DATA(starpu_data_get_size(data_handle));
-			_STARPU_MPI_FT_STATS_CANCEL_SEND_CP_DATA(starpu_data_get_size(data_handle));
-			mpi_data->cache_sent[dest].ft_induced_cache = 0;
-		}
 		_STARPU_MPI_DEBUG(2, "Do not send data %p to node %d as it has already been sent\n", data_handle, dest);
 	}
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
@@ -318,7 +318,7 @@ int starpu_mpi_cached_send(starpu_data_handle_t data_handle, int dest)
 
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
 	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
-	already_sent = mpi_data->cache_sent[dest].in_cache;
+	already_sent = mpi_data->cache_sent[dest];
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 	return already_sent;
 }
@@ -334,21 +334,20 @@ static void _starpu_mpi_cache_flush_nolock(starpu_data_handle_t data_handle)
 	starpu_mpi_comm_size(mpi_data->node_tag.node.comm, &nb_nodes);
 	for(i=0 ; i<nb_nodes ; i++)
 	{
-		if (mpi_data->cache_sent[i].in_cache == 1)
+		if (mpi_data->cache_sent[i] == 1)
 		{
 			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data_handle);
-			mpi_data->cache_sent[i].in_cache         = 0;
-			mpi_data->cache_sent[i].ft_induced_cache = 0;
+			mpi_data->cache_sent[i] = 0;
 			_starpu_mpi_cache_stats_dec(i, data_handle);
 		}
 	}
 
-	if (mpi_data->cache_received.in_cache == 1)
+	if (mpi_data->cache_received == 1)
 	{
 		int mpi_rank = starpu_mpi_data_get_rank(data_handle);
 		_STARPU_MPI_DEBUG(2, "Clearing received cache for data %p\n", data_handle);
-		mpi_data->cache_received.in_cache         = 0;
-		mpi_data->cache_received.ft_induced_cache = 0;
+		mpi_data->cache_received = 0;
+		mpi_data->ft_induced_cache_received = 0;
 		_starpu_mpi_cache_stats_dec(mpi_rank, data_handle);
 	}
 }
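Put together, the corrected receive path behaves like the following minimal sketch, where the field and function names are simplified stand-ins for the StarPU ones: the flag raised by a checkpoint-induced receive is consumed the first time the application actually benefits from the cached copy.

#include <stdio.h>

struct data_cache_state
{
	unsigned int cache_received;            /* data already present locally */
	unsigned int ft_induced_cache_received; /* it got there because of a checkpoint */
};

/* Called when a checkpoint receive completes and the data enters the cache. */
static void checkpoint_receive(struct data_cache_state *d)
{
	d->cache_received = 1;
	d->ft_induced_cache_received = 1;
}

/* Called when the application would receive the data: a hit on a
 * checkpoint-induced entry means the checkpoint traffic was beneficial. */
static int application_receive(struct data_cache_state *d)
{
	if (d->cache_received && d->ft_induced_cache_received)
	{
		printf("checkpoint receive turned out to benefit the computation\n");
		d->ft_induced_cache_received = 0; /* count the benefit only once */
	}
	return d->cache_received; /* non-zero: no MPI receive actually needed */
}

int main(void)
{
	struct data_cache_state d = { 0, 0 };
	checkpoint_receive(&d);
	if (application_receive(&d))
		printf("MPI receive skipped: data already cached\n");
	return 0;
}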

mpi/src/starpu_mpi_private.h (+3 -2)

@@ -206,8 +206,9 @@ struct _starpu_mpi_data
 {
 	int                         magic;
 	struct _starpu_mpi_node_tag node_tag;
-	struct _starpu_cache_info   *cache_sent;
-	struct _starpu_cache_info   cache_received;
+	unsigned int                *cache_sent;
+	unsigned int                cache_received;
+	unsigned int                ft_induced_cache_received:1;
 
 	/** Rendez-vous data for opportunistic cooperative sends */
 	/** Needed to synchronize between submit thread and workers */
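A rough sketch of the simplified bookkeeping this layout allows (names are illustrative; the per-rank allocation mirrors what _starpu_mpi_cache_data_init does, and only the receive side keeps a checkpoint-induced bit, consistent with the commit message):

#include <stdlib.h>

struct cache_bookkeeping
{
	unsigned int *cache_sent;                 /* one entry per MPI rank */
	unsigned int cache_received;
	unsigned int ft_induced_cache_received:1; /* only the receive side is tracked */
};

static int cache_bookkeeping_init(struct cache_bookkeeping *c, int comm_size)
{
	c->cache_sent = calloc((size_t)comm_size, sizeof(*c->cache_sent));
	if (!c->cache_sent)
		return -1;
	c->cache_received = 0;
	c->ft_induced_cache_received = 0;
	return 0;
}

int main(void)
{
	struct cache_bookkeeping c;
	if (cache_bookkeeping_init(&c, 16) == 0)
		free(c.cache_sent);
	return 0;
}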

src/util/starpu_data_cpy.c (+0 -1)

@@ -198,7 +198,6 @@ int starpu_data_dup_ro(starpu_data_handle_t *dst_handle, starpu_data_handle_t sr
 	_starpu_spin_unlock(&src_handle->header_lock);
 
 	starpu_data_register_same(dst_handle, src_handle);
-	(*dst_handle)->mpi_data = src_handle->mpi_data;
 	_starpu_data_cpy(*dst_handle, src_handle, asynchronous, NULL, NULL, 0, NULL);
 	(*dst_handle)->readonly = 1;