Bladeren bron

Checkpoint data cpy implementation

Romain LION 5 jaren geleden
bovenliggende
commit
a045a36dc1

+ 5 - 0
mpi/src/mpi/starpu_mpi_mpi.c

@@ -1693,6 +1693,11 @@ void _starpu_mpi_driver_init(struct starpu_conf *conf)
 	}
 }
 
+void _starpu_mpi_wake_up_progress_thread() /* Signals progress_cond; judging by the name, this is meant to wake up the MPI progress thread. */
+{
+	STARPU_PTHREAD_COND_SIGNAL(&progress_cond); /* NOTE(review): no mutex is taken in this function around the signal -- confirm the waiter cannot miss the wake-up. */
+}
+
 void _starpu_mpi_driver_shutdown()
 {
 	if (mpi_driver)

+ 2 - 0
mpi/src/mpi/starpu_mpi_mpi.h

@@ -42,6 +42,8 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm);
 int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status);
 int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status);
 
+void _starpu_mpi_wake_up_progress_thread();
+
 void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req);
 void _starpu_mpi_irecv_size_func(struct _starpu_mpi_req *req);
 

+ 66 - 19
mpi/src/starpu_mpi_checkpoint.c

@@ -23,6 +23,7 @@
 #include <sys/param.h>
 #include <starpu_mpi_private.h>
 #include <mpi/starpu_mpi_mpi_backend.h> // Should be deduced at preprocessing (Nmad vs MPI)
+#include <mpi/starpu_mpi_mpi.h>
 #include "starpu_mpi_cache.h"
 
 #define MAX_CP_TEMPLATE_NUMBER 32 // Arbitrary limit
@@ -36,8 +37,13 @@ static struct _starpu_mpi_req_list detached_ft_service_requests;
 static unsigned detached_send_n_ft_service_requests;
 static starpu_pthread_mutex_t detached_ft_service_requests_mutex;
 
-void _starpu_mpi_post_cp_ack_recv_cb(void* _args);
-void _starpu_mpi_post_cp_ack_send_cb(void* _args);
+void _starpu_checkpoint_cached_data_recv_copy_and_ack(void* _arg);
+void _starpu_checkpoint_data_recv_copy_and_ack(void* _arg);
+void _starpu_checkpoint_cached_data_send_copy_and_ack(void* _arg);
+void _starpu_checkpoint_data_send_copy_and_ack(void* _arg);
+
+void _starpu_mpi_push_cp_ack_recv_cb(void* _args);
+void _starpu_mpi_push_cp_ack_send_cb(void* _args);
 void _starpu_mpi_treat_cache_ack_no_lock_cb(void* args);
 
 extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
@@ -109,7 +115,7 @@ static int _starpu_mpi_checkpoint_template_register(starpu_mpi_checkpoint_templa
 	return 0;
 }
 
-struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, void (*alt_callback)(void *), void *alt_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count)
+struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *_arg, void (*alt_callback)(void *), void *_alt_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count)
 {
 	struct _starpu_mpi_req* req = NULL;
 	int already_received = _starpu_mpi_cache_received_data_set(data_handle);
@@ -118,17 +124,17 @@ struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_
 		if (data_tag == -1)
 			_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
 		_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, source);
-		req = _starpu_mpi_irecv_common(data_handle, source, data_tag, comm, detached, sync, callback, (void*)arg, sequential_consistency, is_internal_req, count);
+		req = _starpu_mpi_irecv_common(data_handle, source, data_tag, comm, detached, sync, callback, _arg, sequential_consistency, is_internal_req, count);
 	}
 	else
 	{
 		fprintf(stderr, "STARPU CACHE: Data already received\n");
-		alt_callback(alt_arg);
+		starpu_data_acquire_cb(data_handle, STARPU_R, alt_callback, _alt_arg);
 	}
 	return req;
 }
 
-struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg, void (*alt_callback)(void *), void *alt_arg, int sequential_consistency)
+struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *_arg, void (*alt_callback)(void *), void *_alt_arg, int sequential_consistency)
 {
 	struct _starpu_mpi_req* req = NULL;
 	int already_sent = _starpu_mpi_cache_sent_data_set(data_handle, dest);
@@ -137,19 +143,20 @@ struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_
 		if (data_tag == -1)
 			_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
 		_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, mpi_rank);
-		req = _starpu_mpi_isend_common(data_handle, dest, data_tag, comm, detached, sync, prio, callback, (void*)arg, sequential_consistency);
+		req = _starpu_mpi_isend_common(data_handle, dest, data_tag, comm, detached, sync, prio, callback, _arg, sequential_consistency);
 	}
 	else
 	{
 		fprintf(stderr, "STARPU CACHE: Data already sent\n");
-		alt_callback(alt_arg);
+		starpu_data_acquire_cb(data_handle, STARPU_R, alt_callback, _alt_arg);
 	}
 	return req;
 }
 
 int _starpu_mpi_checkpoint_template_submit(starpu_mpi_checkpoint_template_t cp_template)
 {
-	starpu_data_handle_t* handle;
+	// TODO: For now checkpoints are not taken asynchronously. They will be later, and we will then have to acquire READ access to the data through StarPU so that it cannot be corrupted in the meantime.
+	starpu_data_handle_t handle;
 	struct _starpu_mpi_checkpoint_template_item* item;
 	//MPI_Comm comm;
 
@@ -160,7 +167,6 @@ int _starpu_mpi_checkpoint_template_submit(starpu_mpi_checkpoint_template_t cp_t
 	cp_template->remaining_ack_awaited = cp_template->message_number;
 
 	item = _starpu_mpi_checkpoint_template_get_first_data(cp_template);
-	fprintf(stderr, "begin iter\n");
 
 	while (item != _starpu_mpi_checkpoint_template_end(cp_template))
 	{
@@ -173,24 +179,28 @@ int _starpu_mpi_checkpoint_template_submit(starpu_mpi_checkpoint_template_t cp_t
 //				starpu_mpi_send
 				break;
 			case STARPU_R:
-				handle = (starpu_data_handle_t*)item->ptr;
-				if (starpu_mpi_data_get_rank(*handle)==my_rank)
+				handle = (starpu_data_handle_t)item->ptr;
+				if (starpu_mpi_data_get_rank(handle)==my_rank)
 				{
-					fprintf(stderr,"sending to %d (tag %d)\n", item->backup_rank, (int)starpu_mpi_data_get_tag(*handle));
+					fprintf(stderr,"sending to %d (tag %d)\n", item->backup_rank, (int)starpu_mpi_data_get_tag(handle));
 					struct _starpu_mpi_cp_ack_arg_cb* arg = calloc(1, sizeof(struct _starpu_mpi_cp_ack_arg_cb));
 					arg->rank = item->backup_rank;
+					arg->handle = handle;
 					arg->msg.checkpoint_id = cp_template->cp_template_id;
 					arg->msg.checkpoint_instance = cp_template->cp_template_current_instance;
-					_starpu_mpi_isend_cache_aware(*handle, item->backup_rank, starpu_mpi_data_get_tag(*handle), MPI_COMM_WORLD, 1, 0, 0, &_starpu_mpi_post_cp_ack_recv_cb, (void*)arg, &_starpu_mpi_treat_cache_ack_no_lock_cb, (void*)cp_template, 1);
+					_starpu_mpi_isend_cache_aware(handle, item->backup_rank, starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0, 0,
+					                              &_starpu_checkpoint_data_send_copy_and_ack, (void*)arg, &_starpu_checkpoint_cached_data_send_copy_and_ack, (void*)cp_template, 1);
 				}
 				else if (item->backup_rank==my_rank)
 				{
-					fprintf(stderr,"recving from %d (tag %d)\n", starpu_mpi_data_get_rank(*handle), (int)starpu_mpi_data_get_tag(*handle));
+					fprintf(stderr,"recving from %d (tag %d)\n", starpu_mpi_data_get_rank(handle), (int)starpu_mpi_data_get_tag(handle));
 					struct _starpu_mpi_cp_ack_arg_cb* arg = calloc(1, sizeof(struct _starpu_mpi_cp_ack_arg_cb));
-					arg->rank = starpu_mpi_data_get_rank(*handle);
+					arg->rank = starpu_mpi_data_get_rank(handle);
+					arg->handle = handle;
 					arg->msg.checkpoint_id = cp_template->cp_template_id;
 					arg->msg.checkpoint_instance = cp_template->cp_template_current_instance;
-					_starpu_mpi_irecv_cache_aware(*handle, starpu_mpi_data_get_rank(*handle), starpu_mpi_data_get_tag(*handle), MPI_COMM_WORLD, 1, 0, &_starpu_mpi_post_cp_ack_send_cb, (void*)arg, NULL, NULL, 1, 1, 1);
+					_starpu_mpi_irecv_cache_aware(handle, starpu_mpi_data_get_rank(handle), starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0,
+					                              &_starpu_checkpoint_data_recv_copy_and_ack, (void*)arg, &_starpu_checkpoint_cached_data_recv_copy_and_ack, (void*)arg, 1, 1, 1);
 				}
 				break;
 		}
@@ -304,6 +314,36 @@ void _print_ack_sent_cb(void* _args)
 	free(_args);
 }
 
+/* Receive-side checkpoint callback for data that is already in the MPI cache.
+ * _starpu_mpi_irecv_cache_aware() hands this function to starpu_data_acquire_cb()
+ * (STARPU_R mode) instead of posting a receive, so the handle is acquired when
+ * this runs and must be released here.
+ *
+ * _arg is a struct _starpu_mpi_cp_ack_arg_cb*. The function registers
+ * arg->copy_handle from arg->handle with starpu_data_register_same(), submits an
+ * asynchronous starpu_data_cpy() into it with _starpu_mpi_push_cp_ack_send_cb(_arg)
+ * as callback, then releases the acquired handle. */
+void _starpu_checkpoint_cached_data_recv_copy_and_ack(void* _arg)
+{
+	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _arg;
+	/* Keep our own copy of the handle: once _arg has been handed to the
+	 * asynchronous copy, its callback chain may run (and possibly free _arg)
+	 * before the release below, so _arg must not be dereferenced afterwards. */
+	starpu_data_handle_t handle = arg->handle;
+	starpu_data_register_same(&arg->copy_handle, handle);
+	starpu_data_cpy(arg->copy_handle, handle, 1, _starpu_mpi_push_cp_ack_send_cb, _arg);
+	starpu_data_release(handle);
+}
+
+/* Callback attached to the checkpoint receive posted through
+ * _starpu_mpi_irecv_cache_aware() / _starpu_mpi_irecv_common().
+ * _arg is a struct _starpu_mpi_cp_ack_arg_cb*: a backup handle is registered
+ * from the received one, and an asynchronous copy into it is submitted with
+ * _starpu_mpi_push_cp_ack_send_cb(_arg) as its callback. */
+void _starpu_checkpoint_data_recv_copy_and_ack(void* _arg)
+{
+	struct _starpu_mpi_cp_ack_arg_cb* ack = _arg;
+	starpu_data_handle_t received = ack->handle;
+	starpu_data_handle_t* backup = &ack->copy_handle;
+
+	starpu_data_register_same(backup, received);
+	starpu_data_cpy(*backup, received, 1, _starpu_mpi_push_cp_ack_send_cb, _arg);
+}
+
+/* Send-side checkpoint callback for data already recorded as sent in the MPI cache.
+ * _starpu_mpi_isend_cache_aware() hands this function to starpu_data_acquire_cb()
+ * (STARPU_R mode) instead of posting a send, so the handle is acquired when this
+ * runs and must be released here.
+ *
+ * _arg is interpreted as a struct _starpu_mpi_cp_ack_arg_cb*. The function registers
+ * arg->copy_handle from arg->handle with starpu_data_register_same(), submits an
+ * asynchronous starpu_data_cpy() into it with _starpu_mpi_push_cp_ack_recv_cb(_arg)
+ * as callback, then releases the acquired handle.
+ *
+ * NOTE(review): the call in _starpu_mpi_checkpoint_template_submit() passes
+ * cp_template as the argument of this callback, not a
+ * struct _starpu_mpi_cp_ack_arg_cb* -- that call site needs to be checked. */
+void _starpu_checkpoint_cached_data_send_copy_and_ack(void* _arg)
+{
+	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _arg;
+	/* Keep our own copy of the handle: once _arg has been handed to the
+	 * asynchronous copy, its callback chain may run (and possibly free _arg)
+	 * before the release below, so _arg must not be dereferenced afterwards. */
+	starpu_data_handle_t handle = arg->handle;
+	starpu_data_register_same(&arg->copy_handle, handle);
+	starpu_data_cpy(arg->copy_handle, handle, 1, _starpu_mpi_push_cp_ack_recv_cb, _arg);
+	starpu_data_release(handle);
+}
+
+/* Callback attached to the checkpoint send posted through
+ * _starpu_mpi_isend_cache_aware() / _starpu_mpi_isend_common().
+ * _arg is a struct _starpu_mpi_cp_ack_arg_cb*: a backup handle is registered
+ * from the sent one, and an asynchronous copy into it is submitted with
+ * _starpu_mpi_push_cp_ack_recv_cb(_arg) as its callback. */
+void _starpu_checkpoint_data_send_copy_and_ack(void* _arg)
+{
+	struct _starpu_mpi_cp_ack_arg_cb* ack = _arg;
+	starpu_data_handle_t sent = ack->handle;
+	starpu_data_handle_t* backup = &ack->copy_handle;
+
+	starpu_data_register_same(backup, sent);
+	starpu_data_cpy(*backup, sent, 1, _starpu_mpi_push_cp_ack_recv_cb, _arg);
+}
+
 void _starpu_mpi_treat_cache_ack_no_lock_cb(void* args)
 {
 	starpu_mpi_checkpoint_template_t cp_template = (starpu_mpi_checkpoint_template_t)args;
@@ -336,17 +376,20 @@ void _starpu_mpi_treat_ack_receipt_cb(void* _args)
 	starpu_pthread_mutex_unlock(&cp_template_mutex);
 }
 
-void _starpu_mpi_post_cp_ack_send_cb(void* _args)
+void _starpu_mpi_push_cp_ack_send_cb(void* _args)
 {
 	struct _starpu_mpi_req* req;
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
 
 	fprintf(stderr, "Send cb\n");
 
+	//starpu_data_acquire(arg->copy_handle, STARPU_R); //Kept in R mode until use or when checkpoint becomes out-of-date
+
 	/* Initialize the request structure */
 	_starpu_mpi_request_init(&req);
 	req->request_type = SEND_REQ;
 	/* prio_list is sorted by increasing values */
+	//TODO: Check compatibility with prio
 	if (_starpu_mpi_use_prio)
 		req->prio = 0;
 	req->data_handle = NULL;
@@ -374,9 +417,11 @@ void _starpu_mpi_post_cp_ack_send_cb(void* _args)
 	req->submitted = 1;
 
 	STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
+
+	_starpu_mpi_wake_up_progress_thread();
 }
 
-void _starpu_mpi_post_cp_ack_recv_cb(void* _args)
+void _starpu_mpi_push_cp_ack_recv_cb(void* _args)
 {
 	struct _starpu_mpi_req* req;
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
@@ -410,6 +455,8 @@ void _starpu_mpi_post_cp_ack_recv_cb(void* _args)
 	req->submitted = 1;
 
 	STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
+
+	_starpu_mpi_wake_up_progress_thread();
 }
 
 static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req)

+ 4 - 2
mpi/src/starpu_mpi_checkpoint.h

@@ -34,7 +34,9 @@ struct _starpu_mpi_cp_ack_msg
 
 struct _starpu_mpi_cp_ack_arg_cb
 {
-	int rank;
+	int                           rank; /* peer rank: set to the backup rank on the sending side, to the data owner's rank on the receiving side */
+	starpu_data_handle_t          handle; /* data handle being checkpointed */
+	starpu_data_handle_t          copy_handle; /* handle registered from `handle` with starpu_data_register_same(); destination of the starpu_data_cpy() */
 	struct _starpu_mpi_cp_ack_msg msg;
 };
 
@@ -125,7 +127,7 @@ static inline int _starpu_mpi_checkpoint_template_freeze(starpu_mpi_checkpoint_t
 				cp_template->message_number++;
 				break;
 			case STARPU_R:
-				if (starpu_mpi_data_get_rank(*(starpu_data_handle_t *) item->ptr))
+				if (starpu_mpi_data_get_rank((starpu_data_handle_t) item->ptr))
 				{
 					cp_template->message_number++;
 				}

+ 2 - 2
mpi/tests/checkpoints.c

@@ -120,8 +120,8 @@ int test_checkpoint_submit(int argc, char* argv[])
 	starpu_mpi_data_register(handle1, 200, 1);
 
 	starpu_mpi_checkpoint_template_register(&cp_template, 321,
-			STARPU_R, &handle0, 1,
-			STARPU_R, &handle1, 0,
+			STARPU_R, handle0, 1,
+			STARPU_R, handle1, 0,
 			0);
 
 	switch (me)