Browse Source

Fix bugs due to refactorization

Romain LION 5 years ago
parent
commit
8b30f99eef

+ 29 - 29
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint.c

@@ -61,6 +61,7 @@ int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_te
 	STARPU_ASSERT_MSG(cp_template->pending==0, "Can not submit a checkpoint while previous instance has not succeeded.\n");
 
 	cp_template->pending               = 1;
+	cp_template->cp_template_current_instance++;
 	cp_template->remaining_ack_awaited = cp_template->message_number;
 
 	item = _starpu_mpi_checkpoint_template_get_first_data(cp_template);
@@ -86,7 +87,7 @@ int starpu_mpi_submit_checkpoint_template(starpu_mpi_checkpoint_template_t cp_te
 					arg->msg.checkpoint_id = cp_template->cp_template_id;
 					arg->msg.checkpoint_instance = cp_template->cp_template_current_instance;
 					_starpu_mpi_isend_cache_aware(handle, item->backup_rank, starpu_mpi_data_get_tag(handle), MPI_COMM_WORLD, 1, 0, 0,
-					                              &_starpu_checkpoint_data_send_copy_and_ack, (void*)arg, &_starpu_checkpoint_cached_data_send_copy_and_ack, (void*)cp_template, 1);
+					                              &_starpu_mpi_push_cp_ack_recv_cb, (void*)arg, &_starpu_mpi_push_cp_ack_recv_cb, (void*)arg, 1);
 				}
 				else if (item->backup_rank==my_rank)
 				{
@@ -145,42 +146,41 @@ void _starpu_checkpoint_data_recv_copy_and_ack(void* _arg)
 	starpu_data_register_same(&arg->copy_handle, arg->handle);
 	starpu_data_cpy(arg->copy_handle, arg->handle, 1, _starpu_mpi_push_cp_ack_send_cb, _arg);
 }
-
-void _starpu_checkpoint_cached_data_send_copy_and_ack(void* _arg)
-{
-	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _arg;
-	starpu_data_register_same(&arg->copy_handle, arg->handle);
-	starpu_data_cpy(arg->copy_handle, arg->handle, 1, _starpu_mpi_push_cp_ack_recv_cb, _arg);
-	starpu_data_release(arg->handle);
-}
-
-void _starpu_checkpoint_data_send_copy_and_ack(void* _args)
-{
-	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-	starpu_data_register_same(&arg->copy_handle, arg->handle);
-	starpu_data_cpy(arg->copy_handle, arg->handle, 1, _starpu_mpi_push_cp_ack_recv_cb, _args);
-}
-
-void _starpu_mpi_treat_cache_ack_no_lock_cb(void* _args)
-{
-	starpu_mpi_checkpoint_template_t cp_template = (starpu_mpi_checkpoint_template_t)_args;
-	cp_template->remaining_ack_awaited--;
-}
+//
+//void _starpu_checkpoint_cached_data_send_copy_and_ack(void* _arg)
+//{
+//	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _arg;
+//	starpu_data_register_same(&arg->copy_handle, arg->handle);
+//	starpu_data_cpy(arg->copy_handle, arg->handle, 1, _starpu_mpi_push_cp_ack_recv_cb, _arg);
+//	starpu_data_release(arg->handle);
+//}
+//
+//void _starpu_checkpoint_data_send_copy_and_ack(void* _args)
+//{
+//	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
+//	starpu_data_register_same(&arg->copy_handle, arg->handle);
+//	starpu_data_cpy(arg->copy_handle, arg->handle, 1, _starpu_mpi_push_cp_ack_recv_cb, _args);
+//}
+//
+//void _starpu_mpi_treat_cache_ack_no_lock_cb(void* _args)
+//{
+//	starpu_mpi_checkpoint_template_t cp_template = (starpu_mpi_checkpoint_template_t)_args;
+//	cp_template->remaining_ack_awaited--;
+//}
 
 void _starpu_mpi_treat_ack_receipt_cb(void* _args)
 {
-	struct _starpu_mpi_cp_ack_msg* msg = (struct _starpu_mpi_cp_ack_msg*) _args;
-	if (_checkpoint_template_digest_ack_reception(msg->checkpoint_id, msg->checkpoint_instance) == 0) {
-		free(msg);
+	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
+	fprintf(stderr, "ack msg recved id:%d inst:%d\n", arg->msg.checkpoint_id, arg->msg.checkpoint_instance);
+	if (_checkpoint_template_digest_ack_reception(arg->msg.checkpoint_id, arg->msg.checkpoint_instance) == 0) {
+		free(arg);
 	}
 }
 
 void _starpu_mpi_push_cp_ack_send_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-
-	fprintf(stderr, "Send cb\n");
-
+	fprintf(stderr, "Send ack msg to %d: id=%d inst=%d\n", arg->rank, arg->msg.checkpoint_id, arg->msg.checkpoint_instance);
 	_ft_service_msg_isend_cb((void*)&arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank, _STARPU_MPI_TAG_CP_ACK, MPI_COMM_WORLD, _print_ack_sent_cb, _args);
 
 }
@@ -188,6 +188,6 @@ void _starpu_mpi_push_cp_ack_send_cb(void* _args)
 void _starpu_mpi_push_cp_ack_recv_cb(void* _args)
 {
 	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
-
+	fprintf(stderr, "Posting ack recv cb from %d\n", arg->rank);
 	_ft_service_msg_irecv_cb((void*)&arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank, _STARPU_MPI_TAG_CP_ACK, MPI_COMM_WORLD, _starpu_mpi_treat_ack_receipt_cb, _args);
 }

+ 3 - 7
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint_template.c

@@ -16,17 +16,11 @@
 
 #include <stdarg.h>
 #include <stdlib.h>
-#include <common/utils.h>
 
-#include <mpi_failure_tolerance/starpu_mpi_checkpoint.h>
-#include <mpi_failure_tolerance/starpu_mpi_checkpoint_template.h>
-#include <mpi_failure_tolerance/starpu_mpi_ft_service_comms.h>
-#include <mpi_failure_tolerance/starpu_mpi_checkpoint_package.h>
 #include <sys/param.h>
 #include <starpu_mpi_private.h>
-#include <mpi/starpu_mpi_mpi_backend.h> // Should be deduced at preprocessing (Nmad vs MPI)
-#include <mpi/starpu_mpi_mpi.h>
 #include <starpu_mpi_cache.h>
+#include <mpi_failure_tolerance/starpu_mpi_checkpoint_template.h>
 
 
 #define MAX_CP_TEMPLATE_NUMBER 32 // Arbitrary limit
@@ -128,11 +122,13 @@ int starpu_mpi_checkpoint_template_register(starpu_mpi_checkpoint_template_t* cp
 
 int _checkpoint_template_digest_ack_reception(int checkpoint_id, int checkpoint_instance) {
 	starpu_pthread_mutex_lock(&cp_template_mutex);
+	fprintf(stderr, "Digesting ack recv: id=%d, inst=%d\n", checkpoint_id, checkpoint_instance);
 	for (int i=0 ; i<cp_template_number ; i++)
 	{
 		starpu_pthread_mutex_lock(&cp_template_array[i]->mutex);
 		if (cp_template_array[i]->cp_template_id == checkpoint_id && cp_template_array[i]->cp_template_current_instance == checkpoint_instance)
 		{
+			fprintf(stderr, "Inst found, remaining ack msg awaited:%d\n", cp_template_array[i]->remaining_ack_awaited);
 			cp_template_array[i]->remaining_ack_awaited--;
 			if (cp_template_array[i]->remaining_ack_awaited == 0)
 			{

+ 1 - 1
mpi/src/mpi_failure_tolerance/starpu_mpi_ft.c

@@ -26,8 +26,8 @@ int                              my_rank;
 int starpu_mpi_ft_turn_on(void)
 {
 	starpu_pthread_mutex_init(&ft_mutex, NULL);
-	starpu_mpi_ft_service_lib_init();
 	starpu_mpi_comm_rank(MPI_COMM_WORLD, &my_rank); //TODO: check compatibility with several Comms behaviour
+	starpu_mpi_ft_service_lib_init();
 	checkpoint_template_lib_init();
 	return 0;
 }

+ 18 - 16
mpi/src/mpi_failure_tolerance/starpu_mpi_ft_service_comms.c

@@ -33,7 +33,7 @@ static struct _starpu_mpi_req_list detached_ft_service_requests;
 static unsigned detached_send_n_ft_service_requests;
 static starpu_pthread_mutex_t detached_ft_service_requests_mutex;
 
-int _ft_service_msg_recv_send_common(void* msg, int count, int rank, int tag, int req_type, MPI_Comm comm, void (*callback)(void *), void* arg)
+int _ft_service_msg_recv_send_common(void* ptr, int count, int rank, int tag, int req_type, MPI_Comm comm, void (*callback)(void *), void* arg)
 {
 	struct _starpu_mpi_req* req;
 
@@ -52,19 +52,20 @@ int _ft_service_msg_recv_send_common(void* msg, int count, int rank, int tag, in
 	req->node_tag.data_tag = tag;
 	req->node_tag.node.comm = comm;
 	req->detached = 1;
-	req->ptr = msg;
+	req->ptr = ptr;
 	req->sync = 0;
 	req->datatype = MPI_BYTE;
 	req->callback = callback;
 	req->callback_arg = arg;
 	req->func = NULL;
 	req->sequential_consistency = 1;
-	req->count = sizeof(struct _starpu_mpi_cp_ack_msg);
+	req->count = count;
 
 	_mpi_backend._starpu_mpi_backend_request_fill(req, comm, 0);
 
 	STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
 	if (req_type==SEND_REQ) {
+		fprintf(stderr, "data:%d/%d", *(int*)(req->ptr), *(int*)(req->ptr+4));
 		MPI_Isend(req->ptr, req->count, req->datatype, req->node_tag.node.rank, req->node_tag.data_tag,
 		          req->node_tag.node.comm, &req->backend->data_request);
 	}
@@ -72,6 +73,9 @@ int _ft_service_msg_recv_send_common(void* msg, int count, int rank, int tag, in
 		MPI_Irecv(req->ptr, req->count, req->datatype, req->node_tag.node.rank, req->node_tag.data_tag,
 		          req->node_tag.node.comm, &req->backend->data_request);
 	}
+	else {
+		STARPU_ASSERT_MSG(1, "Unrecognized req type: Only RECV_REQ and SEND_REQ are accepeted\n");
+	}
 	_starpu_mpi_req_list_push_back(&detached_ft_service_requests, req);
 	fprintf(stderr, "pushed service req: %p in list %p - prev: %p - next: %p - dest:%d - tag:%d\n", req, &detached_ft_service_requests, _starpu_mpi_req_list_prev(req), _starpu_mpi_req_list_next(req), req->node_tag.node.rank, (int)req->node_tag.data_tag);
 	if (req_type==SEND_REQ) {
@@ -82,19 +86,19 @@ int _ft_service_msg_recv_send_common(void* msg, int count, int rank, int tag, in
 	STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
 
 	_starpu_mpi_wake_up_progress_thread();
-}
-
-inline int _ft_service_msg_isend_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
-{
 
-	return _ft_service_msg_recv_send_common(msg, count, rank, tag, SEND_REQ, comm, callback, arg);
+	return 0;
 }
 
-int _ft_service_msg_irecv_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
-{
-
-	return _ft_service_msg_recv_send_common(msg, count, rank, tag, SEND_REQ, comm, callback, arg);
-}
+//inline int _ft_service_msg_isend_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
+//{
+//	return _ft_service_msg_recv_send_common(msg, count, rank, tag, SEND_REQ, comm, callback, arg);
+//}
+//
+//inline int _ft_service_msg_irecv_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
+//{
+//	return _ft_service_msg_recv_send_common(msg, count, rank, tag, SEND_REQ, comm, callback, arg);
+//}
 
 static void _starpu_mpi_handle_ft_request_termination(struct _starpu_mpi_req *req)
 {
@@ -121,8 +125,6 @@ static void _starpu_mpi_handle_ft_request_termination(struct _starpu_mpi_req *re
 					// has completed, as MPI can re-order messages, let's call
 					// MPI_Wait to make sure data have been sent
 					int ret;
-					ret = MPI_Wait(&req->backend->size_req, MPI_STATUS_IGNORE);
-					STARPU_MPI_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Wait returning %s", _starpu_mpi_get_mpi_error_code(ret));
 					starpu_free_on_node_flags(STARPU_MAIN_RAM, (uintptr_t)req->ptr, req->count, 0);
 					req->ptr = NULL;
 				}
@@ -201,7 +203,7 @@ void starpu_mpi_test_ft_detached_service_requests(void)
 		}
 		else
 		{
-			fprintf(stderr, "req success: %d\n", detached_send_n_ft_service_requests);
+			fprintf(stderr, "req success: %p\n", req);
 			_STARPU_MPI_TRACE_POLLING_END();
 			struct _starpu_mpi_req *next_req;
 			next_req = _starpu_mpi_req_list_next(req);

+ 11 - 2
mpi/src/mpi_failure_tolerance/starpu_mpi_ft_service_comms.h

@@ -22,8 +22,17 @@ extern "C"
 {
 #endif
 
-int _ft_service_msg_isend_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg);
-int _ft_service_msg_irecv_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg);
+int _ft_service_msg_recv_send_common(void* ptr, int count, int rank, int tag, int req_type, MPI_Comm comm, void (*callback)(void *), void* arg);
+
+inline int _ft_service_msg_isend_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
+{
+	return _ft_service_msg_recv_send_common(msg, count, rank, tag, SEND_REQ, comm, callback, arg);
+}
+
+inline int _ft_service_msg_irecv_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
+{
+	return _ft_service_msg_recv_send_common(msg, count, rank, tag, RECV_REQ, comm, callback, arg);
+}
 
 void starpu_mpi_test_ft_detached_service_requests(void);
 int starpu_mpi_ft_service_lib_init();