Просмотр исходного кода

Simplify and correct the count of checkpoint communications that benefit the computation. Only the receive count is corrected (correcting the send count is more expensive).

Romain LION лет назад: 5
Родитель
Commit
beb3330bd7

+ 32 - 18
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint.c

@@ -35,7 +35,7 @@ int                              my_rank;
 extern struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *_arg, int sequential_consistency, int* cache_flag);
 extern struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count, int* cache_flag);
 
-
+extern int starpu_mpi_cache_set_ft_induced_cache_receive(starpu_data_handle_t data_handle);
 
 void _ack_msg_send_cb(void* _args)
 {
@@ -79,15 +79,7 @@ void _starpu_mpi_push_cp_ack_recv_cb(struct _starpu_mpi_cp_ack_arg_cb* arg)
 void _recv_internal_dup_ro_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	if (arg->cache_flag) {
-		_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->copy_handle): -1);
-	}
-	else
-	{
-		struct _starpu_mpi_data* mpi_data = _starpu_mpi_data_get(arg->copy_handle);
-		mpi_data->cache_received.ft_induced_cache = 1;
-		_STARPU_MPI_FT_STATS_RECV_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->copy_handle): -1);
-	}
+
 	starpu_data_release(arg->copy_handle);
 	_starpu_mpi_store_data_and_send_ack_cb(arg);
 }
@@ -95,7 +87,7 @@ void _recv_internal_dup_ro_cb(void* _args)
 void _recv_cp_external_data_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	_STARPU_MPI_FT_STATS_RECV_CP_DATA(arg->type==STARPU_VALUE?arg->count:arg->type==STARPU_R?starpu_data_get_size(arg->handle):-1);
+	_STARPU_MPI_FT_STATS_RECV_CP_DATA(starpu_data_get_size(arg->handle));
 	// an handle has specifically been created, Let's get the value back, and unregister the handle
 	arg->copy_handle = starpu_data_handle_to_pointer(arg->handle, STARPU_MAIN_RAM);
 	starpu_data_unregister_submit(arg->handle);
@@ -105,7 +97,7 @@ void _recv_cp_external_data_cb(void* _args)
 void _send_cp_external_data_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	_STARPU_MPI_FT_STATS_SEND_CP_DATA(arg->type==STARPU_VALUE?arg->count:arg->type==STARPU_R?starpu_data_get_size(arg->handle):-1);
+	_STARPU_MPI_FT_STATS_SEND_CP_DATA(starpu_data_get_size(arg->handle));
 	free(starpu_data_handle_to_pointer(arg->handle, STARPU_MAIN_RAM));
 	starpu_data_unregister_submit(arg->handle);
 	_starpu_mpi_push_cp_ack_recv_cb(arg);
@@ -113,18 +105,39 @@ void _send_cp_external_data_cb(void* _args)
 
 void _send_cp_internal_data_cb(void* _args) {
 	struct _starpu_mpi_cp_ack_arg_cb *arg = (struct _starpu_mpi_cp_ack_arg_cb *) _args;
+	_starpu_mpi_push_cp_ack_recv_cb(arg);
+}
+
+void _send_internal_data_stats(struct _starpu_mpi_cp_ack_arg_cb* arg)
+{
 	if (arg->cache_flag) {
-		_STARPU_MPI_FT_STATS_SEND_CACHED_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle): -1);
+		_STARPU_MPI_FT_STATS_SEND_CACHED_CP_DATA( starpu_data_get_size(arg->handle));
 	}
 	else
 	{
-		struct _starpu_mpi_data* mpi_data = _starpu_mpi_data_get(arg->handle);
-		mpi_data->cache_sent[arg->rank].ft_induced_cache = 1;
-		_STARPU_MPI_FT_STATS_SEND_CP_DATA(arg->type == STARPU_VALUE ? arg->count : arg->type == STARPU_R ? starpu_data_get_size(arg->handle): -1);
+		_STARPU_MPI_FT_STATS_SEND_CP_DATA(starpu_data_get_size(arg->handle));
 	}
-	_starpu_mpi_push_cp_ack_recv_cb(arg);
 }
 
+#ifdef STARPU_USE_MPI_FT_STATS
+void _recv_internal_data_stats(struct _starpu_mpi_cp_ack_arg_cb* arg)
+{
+	if (arg->cache_flag) {
+		_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA( starpu_data_get_size(arg->handle));
+	}
+	else
+	{
+		_STARPU_MPI_FT_STATS_RECV_CP_DATA(starpu_data_get_size(arg->handle));
+		starpu_mpi_cache_set_ft_induced_cache_receive(arg->handle);
+	}
+}
+#else
+void _recv_internal_data_stats(STARPU_ATTRIBUTE_UNUSED struct _starpu_mpi_cp_ack_arg_cb* arg)
+{
+	return;
+}
+#endif
+
 int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_template)
 {
 	starpu_data_handle_t handle;
@@ -188,7 +201,7 @@ int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_te
 					_starpu_mpi_isend_cache_aware(handle, item->backupped_by, starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0, 0,
 					                              &_send_cp_internal_data_cb, (void*)arg, 1, &arg->cache_flag);
 					// the callbacks need to post ack recv. The cache one needs to release the handle.
-
+					_send_internal_data_stats(arg);
 				}
 				else if (item->backup_of == starpu_mpi_data_get_rank(handle))
 				{
@@ -204,6 +217,7 @@ int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_te
 					_starpu_mpi_irecv_cache_aware(handle, starpu_mpi_data_get_rank(handle), starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0,
 												  NULL, NULL, 1, 0, 1, &arg->cache_flag);
 					// The callback needs to do nothing. The cached one must release the handle.
+					_recv_internal_data_stats(arg);
 					starpu_data_dup_ro(&arg->copy_handle, arg->handle, 1);
 					starpu_data_acquire_cb(arg->copy_handle, STARPU_R, _recv_internal_dup_ro_cb, arg);
 					// The callback need to store the data and post ack send.

+ 34 - 35
mpi/src/starpu_mpi_cache.c

@@ -130,13 +130,12 @@ void _starpu_mpi_cache_data_init(starpu_data_handle_t data_handle)
 		return;
 
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
-	mpi_data->cache_received.in_cache         = 0;
-	mpi_data->cache_received.ft_induced_cache = 0;
+	mpi_data->cache_received = 0;
+	mpi_data->ft_induced_cache_received = 0;
 	_STARPU_MALLOC(mpi_data->cache_sent, _starpu_cache_comm_size*sizeof(mpi_data->cache_sent[0]));
 	for(i=0 ; i<_starpu_cache_comm_size ; i++)
 	{
-		mpi_data->cache_sent[i].in_cache         = 0;
-		mpi_data->cache_sent[i].ft_induced_cache = 0;
+		mpi_data->cache_sent[i] = 0;
 	}
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 }
@@ -187,14 +186,14 @@ void starpu_mpi_cached_receive_clear(starpu_data_handle_t data_handle)
 	STARPU_ASSERT(mpi_data->magic == 42);
 	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
 
-	if (mpi_data->cache_received.in_cache == 1)
+	if (mpi_data->cache_received == 1)
 	{
 #ifdef STARPU_DEVEL
 #  warning TODO: Somebody else will write to the data, so discard our cached copy if any. starpu_mpi could just remember itself.
 #endif
 		_STARPU_MPI_DEBUG(2, "Clearing receive cache for data %p\n", data_handle);
-		mpi_data->cache_received.in_cache         = 0;
-		mpi_data->cache_received.ft_induced_cache = 0;
+		mpi_data->cache_received = 0;
+		mpi_data->ft_induced_cache_received = 0;
 		starpu_data_invalidate_submit(data_handle);
 		_starpu_mpi_cache_data_remove_nolock(data_handle);
 		_starpu_mpi_cache_stats_dec(mpi_rank, data_handle);
@@ -202,6 +201,12 @@ void starpu_mpi_cached_receive_clear(starpu_data_handle_t data_handle)
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 }
 
+int starpu_mpi_cache_set_ft_induced_cache_receive(starpu_data_handle_t data_handle)
+{
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
+	mpi_data->ft_induced_cache_received = 1;
+}
+
 int starpu_mpi_cached_receive_set(starpu_data_handle_t data_handle)
 {
 	int mpi_rank = starpu_mpi_data_get_rank(data_handle);
@@ -214,22 +219,24 @@ int starpu_mpi_cached_receive_set(starpu_data_handle_t data_handle)
 	STARPU_ASSERT(mpi_data->magic == 42);
 	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
 
-	int already_received = mpi_data->cache_received.in_cache;
+	int already_received = mpi_data->cache_received;
 	if (already_received == 0)
 	{
 		_STARPU_MPI_DEBUG(2, "Noting that data %p has already been received by %d\n", data_handle, mpi_rank);
-		mpi_data->cache_received.in_cache = 1;
+		mpi_data->cache_received = 1;
 		_starpu_mpi_cache_data_add_nolock(data_handle);
 		_starpu_mpi_cache_stats_inc(mpi_rank, data_handle);
 	}
 	else
 	{
-		if (mpi_data->cache_received.ft_induced_cache == 1)
-		{
-			_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(starpu_data_get_size(data_handle));
-			_STARPU_MPI_FT_STATS_CANCEL_RECV_CP_DATA(starpu_data_get_size(data_handle));
-			mpi_data->cache_received.ft_induced_cache = 0;
-		}
+		#ifdef STARPU_USE_MPI_FT_STATS
+			if (mpi_data->ft_induced_cache_received == 1)
+			{
+				_STARPU_MPI_FT_STATS_RECV_CACHED_CP_DATA(starpu_data_get_size(data_handle));
+				_STARPU_MPI_FT_STATS_CANCEL_RECV_CP_DATA(starpu_data_get_size(data_handle));
+				mpi_data->ft_induced_cache_received = 0;
+			}
+		#endif //STARPU_USE_MPI_FT_STATS
 		_STARPU_MPI_DEBUG(2, "Do not receive data %p from node %d as it is already available\n", data_handle, mpi_rank);
 	}
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
@@ -246,7 +253,7 @@ int starpu_mpi_cached_receive(starpu_data_handle_t data_handle)
 
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
 	STARPU_ASSERT(mpi_data->magic == 42);
-	already_received = mpi_data->cache_received.in_cache;
+	already_received = mpi_data->cache_received;
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 	return already_received;
 }
@@ -266,11 +273,10 @@ void starpu_mpi_cached_send_clear(starpu_data_handle_t data_handle)
 	starpu_mpi_comm_size(mpi_data->node_tag.node.comm, &size);
 	for(n=0 ; n<size ; n++)
 	{
-		if (mpi_data->cache_sent[n].in_cache == 1)
+		if (mpi_data->cache_sent[n] == 1)
 		{
 			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data_handle);
-			mpi_data->cache_sent[n].in_cache = 0;
-			mpi_data->cache_sent[n].ft_induced_cache = 0;
+			mpi_data->cache_sent[n] = 0;
 			_starpu_mpi_cache_data_remove_nolock(data_handle);
 		}
 	}
@@ -287,21 +293,15 @@ int starpu_mpi_cached_send_set(starpu_data_handle_t data_handle, int dest)
 	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
 
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
-	int already_sent = mpi_data->cache_sent[dest].in_cache;
-	if (mpi_data->cache_sent[dest].in_cache == 0)
+	int already_sent = mpi_data->cache_sent[dest];
+	if (mpi_data->cache_sent[dest] == 0)
 	{
-		mpi_data->cache_sent[dest].in_cache = 1;
+		mpi_data->cache_sent[dest] = 1;
 		_starpu_mpi_cache_data_add_nolock(data_handle);
 		_STARPU_MPI_DEBUG(2, "Noting that data %p has already been sent to %d\n", data_handle, dest);
 	}
 	else
 	{
-		if (mpi_data->cache_sent[dest].ft_induced_cache == 1)
-		{
-			_STARPU_MPI_FT_STATS_SEND_CACHED_CP_DATA(starpu_data_get_size(data_handle));
-			_STARPU_MPI_FT_STATS_CANCEL_SEND_CP_DATA(starpu_data_get_size(data_handle));
-			mpi_data->cache_sent[dest].ft_induced_cache = 0;
-		}
 		_STARPU_MPI_DEBUG(2, "Do not send data %p to node %d as it has already been sent\n", data_handle, dest);
 	}
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
@@ -318,7 +318,7 @@ int starpu_mpi_cached_send(starpu_data_handle_t data_handle, int dest)
 
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
 	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
-	already_sent = mpi_data->cache_sent[dest].in_cache;
+	already_sent = mpi_data->cache_sent[dest];
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 	return already_sent;
 }
@@ -334,21 +334,20 @@ static void _starpu_mpi_cache_flush_nolock(starpu_data_handle_t data_handle)
 	starpu_mpi_comm_size(mpi_data->node_tag.node.comm, &nb_nodes);
 	for(i=0 ; i<nb_nodes ; i++)
 	{
-		if (mpi_data->cache_sent[i].in_cache == 1)
+		if (mpi_data->cache_sent[i] == 1)
 		{
 			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data_handle);
-			mpi_data->cache_sent[i].in_cache         = 0;
-			mpi_data->cache_sent[i].ft_induced_cache = 0;
+			mpi_data->cache_sent[i] = 0;
 			_starpu_mpi_cache_stats_dec(i, data_handle);
 		}
 	}
 
-	if (mpi_data->cache_received.in_cache == 1)
+	if (mpi_data->cache_received == 1)
 	{
 		int mpi_rank = starpu_mpi_data_get_rank(data_handle);
 		_STARPU_MPI_DEBUG(2, "Clearing received cache for data %p\n", data_handle);
-		mpi_data->cache_received.in_cache         = 0;
-		mpi_data->cache_received.ft_induced_cache = 0;
+		mpi_data->cache_received = 0;
+		mpi_data->ft_induced_cache_received = 0;
 		_starpu_mpi_cache_stats_dec(mpi_rank, data_handle);
 	}
 }

+ 3 - 2
mpi/src/starpu_mpi_private.h

@@ -206,8 +206,9 @@ struct _starpu_mpi_data
 {
 	int                         magic;
 	struct _starpu_mpi_node_tag node_tag;
-	struct _starpu_cache_info   *cache_sent;
-	struct _starpu_cache_info   cache_received;
+	unsigned int *cache_sent;
+	unsigned int cache_received;
+	unsigned int ft_induced_cache_received:1;
 
 	/** Rendez-vous data for opportunistic cooperative sends */
 	/** Needed to synchronize between submit thread and workers */

+ 0 - 1
src/util/starpu_data_cpy.c

@@ -198,7 +198,6 @@ int starpu_data_dup_ro(starpu_data_handle_t *dst_handle, starpu_data_handle_t sr
 	_starpu_spin_unlock(&src_handle->header_lock);
 
 	starpu_data_register_same(dst_handle, src_handle);
-	(*dst_handle)->mpi_data = src_handle->mpi_data;
 	_starpu_data_cpy(*dst_handle, src_handle, asynchronous, NULL, NULL, 0, NULL);
 	(*dst_handle)->readonly = 1;