瀏覽代碼

Add mutexes.

Romain LION 5 年之前
父節點
當前提交
7a5bde6222

+ 7 - 2
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint_package.c

@@ -32,8 +32,8 @@ int checkpoint_package_init()
 
 int checkpoint_package_shutdown()
 {
-	starpu_pthread_mutex_destroy(&package_package_mutex);
 	_checkpoint_package_data_delete_all();
+	starpu_pthread_mutex_destroy(&package_package_mutex);
 	return 0;
 }
 
@@ -47,7 +47,9 @@ int checkpoint_package_data_add(int cp_id, int cp_inst, int rank, starpu_mpi_tag
 	checkpoint_data->type = type;
 	checkpoint_data->ptr = ptr;
 	checkpoint_data->count = count;
+	starpu_pthread_mutex_lock(&package_package_mutex);
 	_starpu_mpi_checkpoint_data_list_push_back(checkpoint_data_list, checkpoint_data);
+	starpu_pthread_mutex_unlock(&package_package_mutex);
 	_STARPU_MPI_FT_STATS_STORE_CP_DATA(type==STARPU_VALUE?count:type==STARPU_R?starpu_data_get_size((starpu_data_handle_t) ptr):-1);
 	_STARPU_MPI_DEBUG(8, "CP data (%p) added - cpid:%d - cpinst:%d - rank:%d - tag:%ld\n", checkpoint_data->ptr, checkpoint_data->cp_id, checkpoint_data->cp_inst, checkpoint_data->rank, checkpoint_data->tag);
 	return 0;
@@ -83,6 +85,7 @@ int checkpoint_package_data_del(int cp_id, int cp_inst, int rank)
 	int done = 0;
 	size_t size = 0;
 	struct _starpu_mpi_checkpoint_data* next_checkpoint_data = NULL;
+	starpu_pthread_mutex_lock(&package_package_mutex);
 	struct _starpu_mpi_checkpoint_data* checkpoint_data = _starpu_mpi_checkpoint_data_list_begin(checkpoint_data_list);
 	while (checkpoint_data != _starpu_mpi_checkpoint_data_list_end(checkpoint_data_list))
 	{
@@ -96,6 +99,7 @@ int checkpoint_package_data_del(int cp_id, int cp_inst, int rank)
 		}
 		checkpoint_data = next_checkpoint_data;
 	}
+	starpu_pthread_mutex_unlock(&package_package_mutex);
 	_STARPU_MPI_FT_STATS_DISCARD_CP_DATA(size);
 	_STARPU_MPI_DEBUG(0, "cleared %d data from checkpoint database (%ld bytes).\n", done, size);
 
@@ -107,6 +111,7 @@ int _checkpoint_package_data_delete_all()
 	int done = 0;
 	size_t size = 0;
 	struct _starpu_mpi_checkpoint_data* next_checkpoint_data = NULL;
+	starpu_pthread_mutex_lock(&package_package_mutex);
 	struct _starpu_mpi_checkpoint_data* checkpoint_data = _starpu_mpi_checkpoint_data_list_begin(checkpoint_data_list);
 	while (checkpoint_data != _starpu_mpi_checkpoint_data_list_end(checkpoint_data_list))
 	{
@@ -116,7 +121,7 @@ int _checkpoint_package_data_delete_all()
 		done++;
 		checkpoint_data = next_checkpoint_data;
 	}
-
+	starpu_pthread_mutex_unlock(&package_package_mutex);
 	_STARPU_MPI_FT_STATS_DISCARD_CP_DATA(size);
 	_STARPU_MPI_DEBUG(0, "cleared %d data from checkpoint database (%ld bytes).\n", done, size);
 

+ 25 - 5
mpi/src/mpi_failure_tolerance/starpu_mpi_checkpoint_tracker.c

@@ -28,6 +28,7 @@
 //	int                              ack_msg_count;
 
 struct _starpu_mpi_checkpoint_domain_tracker_index_list* domain_tracker_list;
+starpu_pthread_mutex_t tracker_mutex;
 
 struct _starpu_mpi_checkpoint_domain_tracker_entry
 {
@@ -139,32 +140,43 @@ static inline int _domain_tracker_delete_all()
 
 
 
+
+
 int _starpu_mpi_checkpoint_tracker_init()
 {
 	domain_tracker_list = _starpu_mpi_checkpoint_domain_tracker_index_list_new();
+	starpu_pthread_mutex_init(&tracker_mutex, NULL);
 	return 0;
 }
 
 int _starpu_mpi_checkpoint_tracker_shutdown()
 {
 	_domain_tracker_delete_all();
+	starpu_pthread_mutex_destroy(&tracker_mutex);
 	free(domain_tracker_list);
 	return 0;
 }
 
+/* Look up the tracker stored for instance `cp_inst` of domain `cp_domain`.
+ * The lookup runs with tracker_mutex held.
+ * Returns NULL when the domain has no index or the index has no entry for
+ * that instance; otherwise returns the address of the entry's tracker. */
 struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_template_get_tracking_inst_by_id_inst(int cp_domain, int cp_inst)
 {
+	starpu_pthread_mutex_lock(&tracker_mutex);
 	struct _starpu_mpi_checkpoint_domain_tracker_index* index = get_domain_tracker_index(cp_domain);
 	if (NULL == index)
+	{
+		starpu_pthread_mutex_unlock(&tracker_mutex);
 		return NULL;
+	}
 	struct _starpu_mpi_checkpoint_domain_tracker_entry* entry = get_tracker_entry(index, cp_inst);
 	if (NULL == entry)
+	{
+		starpu_pthread_mutex_unlock(&tracker_mutex);
 		return NULL;
+	}
+	starpu_pthread_mutex_unlock(&tracker_mutex);
+	/* NOTE(review): the returned pointer is handed out after tracker_mutex is released. */
 	return &entry->tracker;
 }
 
 struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_template_create_instance_tracker(starpu_mpi_checkpoint_template_t cp_template, int cp_id, int cp_domain, int cp_inst)
 {
+	starpu_pthread_mutex_lock(&tracker_mutex);
 	struct _starpu_mpi_checkpoint_domain_tracker_entry *entry;
 	struct _starpu_mpi_checkpoint_domain_tracker_index *index = get_domain_tracker_index(cp_domain);
 	if (NULL == index)
@@ -172,11 +184,13 @@ struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_template_create_in
 	entry     = get_tracker_entry(index, cp_inst);
 	if (NULL == entry)
 		entry = add_tracker_entry(index, cp_id, cp_inst, cp_domain, cp_template);
+	starpu_pthread_mutex_unlock(&tracker_mutex);
 	return &entry->tracker;
 }
 
 struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_update(starpu_mpi_checkpoint_template_t cp_template, int cp_id, int cp_domain, int cp_instance)
 {
+	starpu_pthread_mutex_lock(&tracker_mutex);
 	struct _starpu_mpi_checkpoint_domain_tracker_entry* entry;
 	struct _starpu_mpi_checkpoint_domain_tracker_index* index = get_domain_tracker_index(cp_domain);
 	if (NULL == index)
@@ -189,6 +203,7 @@ struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_update(sta
 	}
 	STARPU_ASSERT_MSG(entry->tracker.ack_msg_count>0, "Error. Trying to count ack message while all have already been received. id:%d, inst:%d, remaining_ack_messages:%d\n", entry->tracker.cp_id, entry->instance, entry->tracker.ack_msg_count);
 	entry->tracker.ack_msg_count--;
+	starpu_pthread_mutex_unlock(&tracker_mutex);
 	return &entry->tracker;
 }
 
@@ -203,6 +218,7 @@ int _starpu_mpi_checkpoint_check_tracker(struct _starpu_mpi_checkpoint_tracker*
 
 struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_validate_instance(struct _starpu_mpi_checkpoint_tracker* tracker)
 {
+	starpu_pthread_mutex_lock(&tracker_mutex);
 	// Here we validate a checkpoint and return the old cp info that must be discarded
 	struct _starpu_mpi_checkpoint_tracker* temp_tracker;
 	struct _starpu_mpi_checkpoint_domain_tracker_index* index = get_domain_tracker_index(tracker->cp_domain);
@@ -217,6 +233,7 @@ struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_validate_i
 		{
 			temp_tracker->old = 1;
 		}
+		starpu_pthread_mutex_unlock(&tracker_mutex);
 		return temp_tracker;
 	}
 	else
@@ -227,12 +244,15 @@ struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_validate_i
 		// The checkpoint to validate is older than the latest validated, just return it to discard it
 		tracker->valid = 1;
 		tracker->old =1;
+		starpu_pthread_mutex_unlock(&tracker_mutex);
 		return tracker;
 	}
 }
 
 struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_get_last_valid_tracker(int domain)
 {
+	starpu_pthread_mutex_lock(&tracker_mutex);
 	struct _starpu_mpi_checkpoint_domain_tracker_index* index = get_domain_tracker_index(domain);
+	starpu_pthread_mutex_unlock(&tracker_mutex);
 	return index->last_valid_instance;
 }