Przeglądaj źródła

Change service msg submission mechs

Romain LION 4 lat temu
rodzic
commit
1a5f218f25

+ 4 - 3
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint.c

@@ -64,15 +64,16 @@ void _starpu_mpi_store_data_and_send_ack_cb(struct _starpu_mpi_cp_ack_arg_cb* ar
 {
 	checkpoint_package_data_add(arg->msg.checkpoint_id, arg->msg.checkpoint_instance, arg->rank, arg->tag, arg->type, arg->copy_handle, arg->count);
 	_STARPU_MPI_DEBUG(3,"Send ack msg to %d: id=%d inst=%d\n", arg->rank, arg->msg.checkpoint_id, arg->msg.checkpoint_instance);
-	_ft_service_msg_isend_cb((void *) &arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank,
+	_starpu_mpi_ft_service_post_send((void *) &arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank,
 	                         _STARPU_MPI_TAG_CP_ACK, MPI_COMM_WORLD, _ack_msg_send_cb, arg);
 }
 
 void _starpu_mpi_push_cp_ack_recv_cb(struct _starpu_mpi_cp_ack_arg_cb* arg)
 {
 	_STARPU_MPI_DEBUG(3, "Posting ack recv cb from %d\n", arg->rank);
-	_ft_service_msg_irecv_cb((void *) &arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank,
-	                         _STARPU_MPI_TAG_CP_ACK, MPI_COMM_WORLD, _ack_msg_recv_cb, arg);
+	_starpu_mpi_ft_service_post_special_recv(_STARPU_MPI_TAG_CP_ACK);
+//	_ft_service_msg_irecv_cb((void *) &arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank,
+//	                         _STARPU_MPI_TAG_CP_ACK, MPI_COMM_WORLD, _ack_msg_recv_cb, arg);
 }
 
 void _recv_internal_dup_ro_cb(void* _args)

+ 1 - 0
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint.h

@@ -61,6 +61,7 @@ struct _starpu_mpi_cp_discard_arg_cb
 	struct _starpu_mpi_cp_info_msg msg;
 };
 
+void _ack_msg_recv_cb(void* _args);
 
 #ifdef __cplusplus
 }

+ 17 - 6
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint_template.c

@@ -177,6 +177,7 @@ void _cp_discard_message_recv_cb(void* _args)
 	_STARPU_MPI_FT_STATS_RECV_FT_SERVICE_MSG(sizeof(struct _starpu_mpi_cp_ack_msg));
 	_STARPU_MPI_DEBUG(0, "DISCARDING OLD CHECKPOINT DATA of rank %d - new one is CPID:%d - CPINST:%d\n", arg->rank, arg->msg.checkpoint_id, arg->msg.checkpoint_instance);
 	checkpoint_package_data_del(arg->msg.checkpoint_id, arg->msg.checkpoint_instance, arg->rank);
+	// TODO free _args
 }
 
 
@@ -195,16 +196,20 @@ int _starpu_mpi_checkpoint_post_cp_discard_recv(starpu_mpi_checkpoint_template_t
 		starpu_malloc((void**)&arg, sizeof(struct _starpu_mpi_cp_discard_arg_cb));
 		arg->rank = cp_template->backup_of_array[i];
 		_STARPU_MPI_DEBUG(10, "Post DISCARD msg reception from %d\n", arg->rank);
-		_ft_service_msg_irecv_cb(&arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank, _STARPU_MPI_TAG_CP_INFO,
-		                         MPI_COMM_WORLD, _cp_discard_message_recv_cb, (void *) arg);
+
+		_starpu_mpi_ft_service_post_special_recv(_STARPU_MPI_TAG_CP_INFO);
+//		_ft_service_msg_irecv_cb(&arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank, _STARPU_MPI_TAG_CP_INFO,
+//		                         MPI_COMM_WORLD, _cp_discard_message_recv_cb, (void *) arg);
 	}
 	return i;
 }
 
 void _cp_discard_message_send_cb(void* _args)
 {
+	struct _starpu_mpi_cp_discard_arg_cb* arg = (struct _starpu_mpi_cp_discard_arg_cb*) _args;
 	_STARPU_MPI_FT_STATS_SEND_FT_SERVICE_MSG(sizeof(struct _starpu_mpi_cp_ack_msg));
-	starpu_free(_args);
+	fprintf(stderr, "free_args\n");
+	free(_args);
 }
 
 int _starpu_mpi_checkpoint_post_cp_discard_send(starpu_mpi_checkpoint_template_t cp_template, int cp_id, int cp_instance)
@@ -219,15 +224,15 @@ int _starpu_mpi_checkpoint_post_cp_discard_send(starpu_mpi_checkpoint_template_t
 
 	for (i=0 ; i < cp_template->backupped_by_array_used_size ; i++)
 	{
-		starpu_malloc((void**)&arg, sizeof(struct _starpu_mpi_cp_discard_arg_cb));
+		arg = malloc(sizeof(struct _starpu_mpi_cp_discard_arg_cb));
 		arg->rank = cp_template->backupped_by_array[i];
 		_STARPU_MPI_DEBUG(10, "Post CP DISCARD msg sending to %d\n", arg->rank);
 		arg->msg.discard=1;
 		arg->msg.validation=0;
 		arg->msg.checkpoint_id = cp_id;
 		arg->msg.checkpoint_instance = cp_instance;
-		_ft_service_msg_isend_cb(&arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank, _STARPU_MPI_TAG_CP_INFO,
-		                         MPI_COMM_WORLD, _cp_discard_message_send_cb, (void *) arg);
+		_starpu_mpi_ft_service_post_send(&arg->msg, sizeof(struct _starpu_mpi_cp_ack_msg), arg->rank,
+				_STARPU_MPI_TAG_CP_INFO, MPI_COMM_WORLD, _cp_discard_message_send_cb, (void *) arg);
 	}
 
 	return 0;
@@ -503,6 +508,12 @@ int _checkpoint_template_digest_ack_reception(int checkpoint_id, int checkpoint_
 	return 0;
 }
 
+void _checkpoint_template_digest_ack_reception_cb(void* _arg)
+{
+	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _arg;
+	_checkpoint_template_digest_ack_reception(arg->msg.checkpoint_id, arg->msg.checkpoint_instance);
+}
+
 // For test purpose
 int _starpu_mpi_checkpoint_template_print(starpu_mpi_checkpoint_template_t cp_template)
 {

+ 2 - 0
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint_template.h

@@ -45,6 +45,8 @@ void checkpoint_template_lib_init(void);
 void checkpoint_template_lib_quit(void);
 
 int _checkpoint_template_digest_ack_reception(int checkpoint_id, int checkpoint_instance);
+void _checkpoint_template_digest_ack_reception_cb(void* _arg);
+void _cp_discard_message_recv_cb(void* _args);
 
 starpu_mpi_checkpoint_template_t _starpu_mpi_get_checkpoint_template_by_id(int checkpoint_id);
 int _starpu_mpi_checkpoint_post_cp_discard_recv(starpu_mpi_checkpoint_template_t cp_template);

+ 2 - 2
mpi/src/mpi_failure_tolerance/starpu_mpi_ft.c

@@ -28,7 +28,7 @@ int starpu_mpi_ft_turn_on(void)
 {
 	STARPU_PTHREAD_MUTEX_INIT(&ft_mutex, NULL);
 	starpu_mpi_comm_rank(MPI_COMM_WORLD, &_my_rank); //TODO: check compatibility with several Comms behaviour
-	starpu_mpi_ft_service_lib_init();
+	starpu_mpi_ft_service_lib_init(_ack_msg_recv_cb, _cp_discard_message_recv_cb);
 	checkpoint_template_lib_init();
 	_starpu_mpi_checkpoint_tracker_init();
 	checkpoint_package_init();
@@ -49,7 +49,7 @@ int starpu_mpi_ft_turn_off(void)
 
 void starpu_mpi_ft_progress(void)
 {
-	starpu_mpi_test_ft_detached_service_requests();
+	starpu_mpi_ft_service_progress();
 }
 
 int starpu_mpi_ft_busy()

+ 186 - 75
mpi/src/mpi_failure_tolerance/starpu_mpi_ft_service_comms.c

@@ -28,10 +28,27 @@
 #include <mpi/starpu_mpi_mpi.h>
 #include "starpu_mpi_cache.h"
 
+#define SIMULTANEOUS_ACK_MSG_RECV_MAX 2
+#define SIMULTANEOUS_CP_INFO_RECV_MAX 2
+#define SIMULTANEOUS_PENDING_SEND_MAX 40
 
 static struct _starpu_mpi_req_list detached_ft_service_requests;
+static struct _starpu_mpi_req_list ready_send_ft_service_requests;
 static unsigned detached_send_n_ft_service_requests;
 static starpu_pthread_mutex_t detached_ft_service_requests_mutex;
+static starpu_pthread_mutex_t ready_send_ft_service_requests_mutex;
+static starpu_pthread_mutex_t ft_service_requests_mutex;
+
+int ready_ack_msgs_recv;
+int pending_ack_msgs_recv;
+int ready_cp_info_msgs_recv;
+int pending_cp_info_msgs_recv;
+int ready_send_ft_service_msg;
+int pending_send_ft_service_msg;
+
+typedef void (*cb_fn_type)(void*);
+cb_fn_type ack_msg_recv_cb;
+cb_fn_type cp_info_recv_cb;
 
 #ifdef STARPU_MPI_VERBOSE
 static char *_starpu_mpi_request_type(enum _starpu_mpi_request_type request_type)
@@ -49,108 +66,161 @@ static char *_starpu_mpi_request_type(enum _starpu_mpi_request_type request_type
 }
 #endif
 
-int _ft_service_msg_recv_send_common(void* ptr, int count, int rank, int tag, int req_type, MPI_Comm comm, void (*callback)(void *), void* arg)
+
+int _starpu_mpi_ft_service_submit_rdy()
 {
+	int i;
 	struct _starpu_mpi_req* req;
+	int max_loop;
 
-	/* Check if the tag is a service message */
-	STARPU_ASSERT_MSG(tag==_STARPU_MPI_TAG_CP_ACK || tag == _STARPU_MPI_TAG_CP_INFO, "Only _STARPU_MPI_TAG_CP_ACK or _STARPU_MPI_TAG_CP_INFO are service msgs.");
-
-	/* Initialize the request structure */
-	_starpu_mpi_request_init(&req);
-	req->request_type = req_type;
-//	/* prio_list is sorted by increasing values */
-//	//TODO: Check compatibility with prio
-//	if (_starpu_mpi_use_prio)
-//		req->prio = 0;
-	req->data_handle = NULL;
-	req->node_tag.node.rank = rank;
-	req->node_tag.data_tag = tag;
-	req->node_tag.node.comm = comm;
-	req->detached = 1;
-	req->ptr = ptr;
-	req->sync = 0;
-	req->datatype = MPI_BYTE;
-	req->callback = callback;
-	req->callback_arg = arg;
-	req->func = NULL;
-	req->sequential_consistency = 1;
-	req->count = count;
-
-	_mpi_backend._starpu_mpi_backend_request_fill(req, comm, 0);
+	STARPU_PTHREAD_MUTEX_LOCK(&ft_service_requests_mutex);
+	max_loop = MIN(SIMULTANEOUS_ACK_MSG_RECV_MAX-pending_ack_msgs_recv, ready_ack_msgs_recv);
+	for (i=0 ; i<max_loop ; i++)
+	{
+		struct _starpu_mpi_cp_ack_arg_cb* arg = malloc(sizeof(struct _starpu_mpi_cp_ack_arg_cb));
+		req = _starpu_mpi_request_fill(NULL, MPI_ANY_SOURCE, _STARPU_MPI_TAG_CP_ACK, MPI_COMM_WORLD,
+				1, 0, 0, ack_msg_recv_cb, arg, RECV_REQ, NULL,
+				1, 0, sizeof(arg->msg));
+		req->ptr = (void*)&arg->msg;
+		req->datatype = MPI_BYTE;
+		req->status = malloc(sizeof(MPI_Status));
 
-	STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
-	if (req_type==SEND_REQ) {
-		MPI_Isend(req->ptr, req->count, req->datatype, req->node_tag.node.rank, req->node_tag.data_tag,
+		STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
+		MPI_Irecv(req->ptr, req->count, req->datatype, req->node_tag.node.rank, req->node_tag.data_tag,
 		          req->node_tag.node.comm, &req->backend->data_request);
+		_STARPU_MPI_DEBUG(5, "Posting MPI_Irecv ft service msg: req %p tag %"PRIi64" src %d comm %ld ptr %p\n", req,  req->node_tag.data_tag, req->node_tag.node.rank, (long int)req->node_tag.node.comm, req->ptr);
+		_starpu_mpi_req_list_push_back(&detached_ft_service_requests, req);
+		pending_ack_msgs_recv++;
+		ready_ack_msgs_recv--;
+		req->submitted = 1;
+		STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
 	}
-	else if (req_type==RECV_REQ) {
+
+	max_loop = MIN(SIMULTANEOUS_CP_INFO_RECV_MAX-pending_cp_info_msgs_recv, ready_cp_info_msgs_recv);
+	for (i=0 ; i<max_loop ; i++)
+	{
+		struct _starpu_mpi_cp_discard_arg_cb* arg = malloc(sizeof(struct _starpu_mpi_cp_discard_arg_cb));
+		req = _starpu_mpi_request_fill(NULL, MPI_ANY_SOURCE, _STARPU_MPI_TAG_CP_INFO, MPI_COMM_WORLD,
+		                         1, 0, 0, cp_info_recv_cb, arg, RECV_REQ, NULL,
+		                         1, 0, sizeof(arg->msg));
+		req->ptr = (void*)&arg->msg;
+		req->datatype = MPI_BYTE;
+		req->status = malloc(sizeof(MPI_Status));
+
+		STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
 		MPI_Irecv(req->ptr, req->count, req->datatype, req->node_tag.node.rank, req->node_tag.data_tag,
 		          req->node_tag.node.comm, &req->backend->data_request);
+		_STARPU_MPI_DEBUG(5, "Posting MPI_Irecv ft service msg: req %p tag %"PRIi64" src %d comm %ld ptr %p\n", req,  req->node_tag.data_tag, req->node_tag.node.rank, (long int)req->node_tag.node.comm, req->ptr);
+		_starpu_mpi_req_list_push_back(&detached_ft_service_requests, req);
+		pending_cp_info_msgs_recv++;
+		ready_cp_info_msgs_recv--;
+		req->submitted = 1;
+		STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
 	}
-	else {
-		STARPU_ASSERT_MSG(1, "Unrecognized req type: Only RECV_REQ and SEND_REQ are accepeted\n");
-	}
-	_starpu_mpi_req_list_push_back(&detached_ft_service_requests, req);
-	_STARPU_MPI_DEBUG(2, "pushed service req: %p in list %p - prev: %p - next: %p - dest:%d - tag:%d - type:%s\n", req, &detached_ft_service_requests, _starpu_mpi_req_list_prev(req), _starpu_mpi_req_list_next(req), req->node_tag.node.rank, (int)req->node_tag.data_tag, req_type ? "recv" : "send");
-	if (req_type==SEND_REQ) {
-		detached_send_n_ft_service_requests++;
-	}
-	req->submitted = 1;
 
-	STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
+	max_loop = MIN(SIMULTANEOUS_PENDING_SEND_MAX-pending_send_ft_service_msg, ready_send_ft_service_msg);
+	for (i=0 ; i<max_loop ; i++)
+	{
+		STARPU_PTHREAD_MUTEX_LOCK(&ready_send_ft_service_requests_mutex);
+		req = _starpu_mpi_req_list_pop_front(&ready_send_ft_service_requests);
+		STARPU_PTHREAD_MUTEX_UNLOCK(&ready_send_ft_service_requests_mutex);
+		STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
+		MPI_Isend(req->ptr, req->count, req->datatype, req->node_tag.node.rank, req->node_tag.data_tag,
+		          req->node_tag.node.comm, &req->backend->data_request);
 
-	_starpu_mpi_wake_up_progress_thread();
+		_STARPU_MPI_DEBUG(5, "Posting MPI_Isend ft service msg: req %p tag %"PRIi64" src %d comm %ld ptr %p\n", req,  req->node_tag.data_tag, req->node_tag.node.rank, (long int)req->node_tag.node.comm, req->ptr);
+		_starpu_mpi_req_list_push_back(&detached_ft_service_requests, req);
+		pending_send_ft_service_msg++;
+		ready_send_ft_service_msg--;
+		req->submitted = 1;
+		STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
+	}
 
-	return 0;
+	STARPU_PTHREAD_MUTEX_UNLOCK(&ft_service_requests_mutex);
 }
 
-int _ft_service_msg_isend_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
+int _starpu_mpi_ft_service_post_special_recv(int tag)
 {
-	return _ft_service_msg_recv_send_common(msg, count, rank, tag, SEND_REQ, comm, callback, arg);
+
+	_STARPU_MPI_DEBUG(5, "Pushing ft service msg: %s tag %"PRIi64" ANYSOURCE\n", _starpu_mpi_request_type(RECV_REQ), tag);
+
+	if (tag==_STARPU_MPI_TAG_CP_ACK)
+	{
+		STARPU_PTHREAD_MUTEX_LOCK(&ft_service_requests_mutex);
+		ready_ack_msgs_recv++;
+		STARPU_PTHREAD_MUTEX_UNLOCK(&ft_service_requests_mutex);
+	}
+	else if (tag==_STARPU_MPI_TAG_CP_INFO)
+	{
+		STARPU_PTHREAD_MUTEX_LOCK(&ft_service_requests_mutex);
+		ready_cp_info_msgs_recv++;
+		STARPU_PTHREAD_MUTEX_UNLOCK(&ft_service_requests_mutex);
+	}
+	else
+	{
+		STARPU_ABORT_MSG("Only _STARPU_MPI_TAG_CP_ACK or _STARPU_MPI_TAG_CP_INFO are service msgs.\n");
+	}
+	_starpu_mpi_wake_up_progress_thread();
+	return 0;
 }
 
-int _ft_service_msg_irecv_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
+int _starpu_mpi_ft_service_post_send(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg)
 {
-	return _ft_service_msg_recv_send_common(msg, count, rank, tag, RECV_REQ, comm, callback, arg);
+	struct _starpu_mpi_req* req;
+
+	/* Check if the tag is a service message */
+	STARPU_ASSERT_MSG(tag==_STARPU_MPI_TAG_CP_ACK || tag == _STARPU_MPI_TAG_CP_INFO, "Only _STARPU_MPI_TAG_CP_ACK or _STARPU_MPI_TAG_CP_INFO are service msgs.");
+
+	/* Initialize the request structure */
+	req = _starpu_mpi_request_fill(NULL, rank, tag, comm, 1, 0, 0, callback, arg, SEND_REQ, NULL, 1, 0, count);
+//	TODO: Check compatibility with prio
+	req->ptr = msg;
+	req->datatype = MPI_BYTE;
+	req->status = malloc(sizeof(MPI_Status));
+
+	_STARPU_MPI_DEBUG(5, "Pushing ft service msg: %s req %p tag %"PRIi64" src %d ptr %p\n", _starpu_mpi_request_type(SEND_REQ), req, tag, rank, msg);
+
+	STARPU_PTHREAD_MUTEX_LOCK(&ready_send_ft_service_requests_mutex);
+	ready_send_ft_service_msg++;
+	_starpu_mpi_req_list_push_back(&ready_send_ft_service_requests, req);
+	STARPU_PTHREAD_MUTEX_UNLOCK(&ready_send_ft_service_requests_mutex);
+
+	_starpu_mpi_wake_up_progress_thread();
+
+	return 0;
 }
 
 
-static void _starpu_mpi_handle_ft_request_termination(struct _starpu_mpi_req *req)
-{
-	_STARPU_MPI_LOG_IN();
-	_STARPU_MPI_DEBUG(2, "complete MPI request %p type %s tag %"PRIi64" src %d data %p ptr %p datatype '%s' count %d registered_datatype %d internal_req %p\n",
-			req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.node.rank, req->data_handle, req->ptr,
-			req->datatype_name, (int)req->count, req->registered_datatype, req->backend->internal_req);
 
-	if (req->backend->internal_req)
-	{
-		free(req->backend->early_data_handle);
-		req->backend->early_data_handle = NULL;
+static void _starpu_mpi_handle_ft_request_termination(struct _starpu_mpi_req *req) {
+	_STARPU_MPI_LOG_IN();
+	_STARPU_MPI_DEBUG(2,
+	                  "complete MPI request %p type %s tag %"PRIi64" src %d data %p ptr %p datatype '%s' count %d registered_datatype %d internal_req %p\n",
+	                  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.node.rank,
+	                  req->data_handle, req->ptr,
+	                  req->datatype_name, (int) req->count, req->registered_datatype, req->backend->internal_req);
+
+	if (req->backend->internal_req) {
+//		free(req->backend->early_data_handle);
+//		req->backend->early_data_handle = NULL;
 	}
-	else
-	{
-		if (req->request_type == RECV_REQ || req->request_type == SEND_REQ)
-		{
-			if (req->registered_datatype == 0)
-			{
-				if (req->request_type == SEND_REQ)
-				{
+	else {
+		if (req->request_type == RECV_REQ || req->request_type == SEND_REQ) {
+			if (req->registered_datatype == 0) {
+				if (req->request_type == SEND_REQ) {
 					// We need to make sure the communication for sending the size
 					// has completed, as MPI can re-order messages, let's call
 					// MPI_Wait to make sure data have been sent
-					starpu_free_on_node_flags(STARPU_MAIN_RAM, (uintptr_t)req->ptr, req->count, 0);
+					starpu_free_on_node_flags(STARPU_MAIN_RAM, (uintptr_t) req->ptr, req->count, 0);
 					req->ptr = NULL;
 				}
-				else if (req->request_type == RECV_REQ)
-				{
+				else if (req->request_type == RECV_REQ) {
 					// req->ptr is freed by starpu_data_unpack
 					starpu_data_unpack(req->data_handle, req->ptr, req->count);
 					starpu_memory_deallocate(STARPU_MAIN_RAM, req->count);
 				}
 			}
-			else
-			{
+			else {
 				//_starpu_mpi_datatype_free(req->data_handle, &req->datatype);
 			}
 		}
@@ -159,16 +229,29 @@ static void _starpu_mpi_handle_ft_request_termination(struct _starpu_mpi_req *re
 
 	_starpu_mpi_release_req_data(req);
 
-	if (req->backend->envelope)
-	{
+	if (req->backend->envelope) {
 		free(req->backend->envelope);
 		req->backend->envelope = NULL;
 	}
 
 	/* Execute the specified callback, if any */
 	if (req->callback)
+	{
+		if (req->request_type == RECV_REQ)
+		{
+		    if (req->node_tag.data_tag == _STARPU_MPI_TAG_CP_ACK)
+			{
+				struct _starpu_mpi_cp_ack_arg_cb* tmp = (struct _starpu_mpi_cp_ack_arg_cb *) req->callback_arg;
+				tmp->rank = req->status->MPI_SOURCE;
+			}
+		    else if (req->node_tag.data_tag == _STARPU_MPI_TAG_CP_INFO)
+		    {
+			    struct _starpu_mpi_cp_discard_arg_cb* tmp = (struct _starpu_mpi_cp_discard_arg_cb *) req->callback_arg;
+			    tmp->rank = req->status->MPI_SOURCE;
+		    }
+		}
 		req->callback(req->callback_arg);
-
+	}
 	/* tell anyone potentially waiting on the request that it is
 	 * terminated now */
 	STARPU_PTHREAD_MUTEX_LOCK(&req->backend->req_mutex);
@@ -205,7 +288,7 @@ void starpu_mpi_test_ft_detached_service_requests(void)
 		req->ret = _starpu_mpi_simgrid_mpi_test(&req->done, &flag);
 #else
 		STARPU_MPI_ASSERT_MSG(req->backend->data_request != MPI_REQUEST_NULL, "Cannot test completion of the request MPI_REQUEST_NULL");
-		req->ret = MPI_Test(&req->backend->data_request, &flag, MPI_STATUS_IGNORE);
+		req->ret = MPI_Test(&req->backend->data_request, &flag, req->status);
 #endif
 
 		STARPU_MPI_ASSERT_MSG(req->ret == MPI_SUCCESS, "MPI_Test returning %s", _starpu_mpi_get_mpi_error_code(req->ret));
@@ -224,8 +307,17 @@ void starpu_mpi_test_ft_detached_service_requests(void)
 			//_STARPU_MPI_TRACE_COMPLETE_BEGIN(req->request_type, req->node_tag.node.rank, req->node_tag.data_tag);
 
 			STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
+			STARPU_PTHREAD_MUTEX_LOCK(&ft_service_requests_mutex);
 			if (req->request_type == SEND_REQ)
-				detached_send_n_ft_service_requests--;
+				pending_send_ft_service_msg--;
+			if (req->request_type == RECV_REQ)
+			{
+				if (req->node_tag.data_tag == _STARPU_MPI_TAG_CP_ACK)
+					pending_ack_msgs_recv--;
+				else if (req->node_tag.data_tag == _STARPU_MPI_TAG_CP_INFO)
+					pending_cp_info_msgs_recv--;
+			}
+			STARPU_PTHREAD_MUTEX_UNLOCK(&ft_service_requests_mutex);
 			_starpu_mpi_req_list_erase(&detached_ft_service_requests, req);
 			STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
 			_starpu_mpi_handle_ft_request_termination(req);
@@ -260,10 +352,29 @@ void starpu_mpi_test_ft_detached_service_requests(void)
 	//_STARPU_MPI_LOG_OUT();
 }
 
-int starpu_mpi_ft_service_lib_init()
+int starpu_mpi_ft_service_progress()
+{
+	starpu_mpi_test_ft_detached_service_requests();
+	_starpu_mpi_ft_service_submit_rdy();
+	return 0;
+}
+
+int starpu_mpi_ft_service_lib_init(void(*_ack_msg_recv_cb)(void*), void(*_cp_info_recv_cb)(void*))
 {
 	_starpu_mpi_req_list_init(&detached_ft_service_requests);
+	_starpu_mpi_req_list_init(&ready_send_ft_service_requests);
 	STARPU_PTHREAD_MUTEX_INIT(&detached_ft_service_requests_mutex, NULL);
+	STARPU_PTHREAD_MUTEX_INIT(&ready_send_ft_service_requests_mutex, NULL);
+	STARPU_PTHREAD_MUTEX_INIT(&ft_service_requests_mutex, NULL);
+	ready_ack_msgs_recv = 0;
+	pending_ack_msgs_recv = 0;
+	ready_cp_info_msgs_recv = 0;
+	pending_cp_info_msgs_recv = 0;
+	ready_send_ft_service_msg = 0;
+	pending_send_ft_service_msg = 0;
+
+	ack_msg_recv_cb = _ack_msg_recv_cb;
+	cp_info_recv_cb = _cp_info_recv_cb;
 
 	return 0;
 }

+ 4 - 4
mpi/src/mpi_failure_tolerance/starpu_mpi_ft_service_comms.h

@@ -22,12 +22,12 @@ extern "C"
 {
 #endif
 
-int _ft_service_msg_recv_send_common(void* ptr, int count, int rank, int tag, int req_type, MPI_Comm comm, void (*callback)(void *), void* arg);
-int _ft_service_msg_isend_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg);
-int _ft_service_msg_irecv_cb(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg);
+int _starpu_mpi_ft_service_post_special_recv(int tag);
+int _starpu_mpi_ft_service_post_send(void* msg, int count, int rank, int tag, MPI_Comm comm, void (*callback)(void *), void* arg);
 
 void starpu_mpi_test_ft_detached_service_requests(void);
-int starpu_mpi_ft_service_lib_init();
+int starpu_mpi_ft_service_progress();
+int starpu_mpi_ft_service_lib_init(void(*_ack_msg_recv_cb)(void*), void(*cp_info_recv_cb)(void*));
 int starpu_mpi_ft_service_lib_busy();
 
 #ifdef __cplusplus

+ 6 - 10
mpi/tests/checkpoints.c

@@ -170,26 +170,22 @@ int test_checkpoint_submit(int argc, char* argv[])
 	fprintf(stderr, "\n\n");
 	usleep(150000);
 
+	starpu_data_acquire(handle0, STARPU_RW);
 	if (me==0)
-	{
-		starpu_data_acquire(handle0, STARPU_RW);
-		val0*=2;
-		starpu_data_release(handle0);
-	}
+		val0 *= 2;
+	starpu_data_release(handle0);
 
+	starpu_data_acquire(handle1, STARPU_RW);
 	if (me==1)
-	{
-		starpu_data_acquire(handle1, STARPU_RW);
 		val1*=2;
-		starpu_data_release(handle1);
-	}
+	starpu_data_release(handle1);
 
 	FPRINTF_MPI(stderr, "Submitting\n");
 	starpu_mpi_submit_checkpoint_template(cp_template, 0);
 
 	FPRINTF_MPI(stderr, "Submitted\n");
 
-	usleep(150000);
+	sleep(2);
 	fprintf(stderr, "\n\n");
 	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
 	FPRINTF_MPI(stderr, "Bye!\n");