|
@@ -28,6 +28,7 @@
|
|
|
// int ack_msg_count;
|
|
|
|
|
|
struct _starpu_mpi_checkpoint_domain_tracker_index_list* domain_tracker_list;
|
|
|
+starpu_pthread_mutex_t tracker_mutex;
|
|
|
|
|
|
struct _starpu_mpi_checkpoint_domain_tracker_entry
|
|
|
{
|
|
@@ -139,32 +140,43 @@ static inline int _domain_tracker_delete_all()
|
|
|
|
|
|
|
|
|
|
|
|
+
|
|
|
+
|
|
|
int _starpu_mpi_checkpoint_tracker_init()
|
|
|
{
|
|
|
domain_tracker_list = _starpu_mpi_checkpoint_domain_tracker_index_list_new();
|
|
|
+ starpu_pthread_mutex_init(&tracker_mutex, NULL);
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
int _starpu_mpi_checkpoint_tracker_shutdown()
|
|
|
{
|
|
|
_domain_tracker_delete_all();
|
|
|
+ starpu_pthread_mutex_destroy(&tracker_mutex);
|
|
|
free(domain_tracker_list);
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
-struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_template_get_tracking_inst_by_id_inst(int cp_domain, int cp_inst)
|
|
|
-{
|
|
|
- struct _starpu_mpi_checkpoint_domain_tracker_index* index = get_domain_tracker_index(cp_domain);
|
|
|
- if (NULL == index)
|
|
|
+struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_template_get_tracking_inst_by_id_inst(int cp_domain, int cp_inst) {
|
|
|
+ starpu_pthread_mutex_lock(&tracker_mutex);
|
|
|
+ struct _starpu_mpi_checkpoint_domain_tracker_index *index = get_domain_tracker_index(cp_domain);
|
|
|
+ if (NULL == index) {
|
|
|
+ starpu_pthread_mutex_unlock(&tracker_mutex);
|
|
|
return NULL;
|
|
|
- struct _starpu_mpi_checkpoint_domain_tracker_entry* entry = get_tracker_entry(index, cp_inst);
|
|
|
+ }
|
|
|
+ struct _starpu_mpi_checkpoint_domain_tracker_entry *entry = get_tracker_entry(index, cp_inst);
|
|
|
if (NULL == entry)
|
|
|
+ {
|
|
|
+ starpu_pthread_mutex_unlock(&tracker_mutex);
|
|
|
return NULL;
|
|
|
+ }
|
|
|
+ starpu_pthread_mutex_unlock(&tracker_mutex);
|
|
|
return &entry->tracker;
|
|
|
}
|
|
|
|
|
|
struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_template_create_instance_tracker(starpu_mpi_checkpoint_template_t cp_template, int cp_id, int cp_domain, int cp_inst)
|
|
|
{
|
|
|
+ starpu_pthread_mutex_lock(&tracker_mutex);
|
|
|
struct _starpu_mpi_checkpoint_domain_tracker_entry *entry;
|
|
|
struct _starpu_mpi_checkpoint_domain_tracker_index *index = get_domain_tracker_index(cp_domain);
|
|
|
if (NULL == index)
|
|
@@ -172,11 +184,13 @@ struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_template_create_in
|
|
|
entry = get_tracker_entry(index, cp_inst);
|
|
|
if (NULL == entry)
|
|
|
entry = add_tracker_entry(index, cp_id, cp_inst, cp_domain, cp_template);
|
|
|
+ starpu_pthread_mutex_unlock(&tracker_mutex);
|
|
|
return &entry->tracker;
|
|
|
}
|
|
|
|
|
|
struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_update(starpu_mpi_checkpoint_template_t cp_template, int cp_id, int cp_domain, int cp_instance)
|
|
|
{
|
|
|
+ starpu_pthread_mutex_lock(&tracker_mutex);
|
|
|
struct _starpu_mpi_checkpoint_domain_tracker_entry* entry;
|
|
|
struct _starpu_mpi_checkpoint_domain_tracker_index* index = get_domain_tracker_index(cp_domain);
|
|
|
if (NULL == index)
|
|
@@ -189,6 +203,7 @@ struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_update(sta
|
|
|
}
|
|
|
STARPU_ASSERT_MSG(entry->tracker.ack_msg_count>0, "Error. Trying to count ack message while all have already been received. id:%d, inst:%d, remaining_ack_messages:%d\n", entry->tracker.cp_id, entry->instance, entry->tracker.ack_msg_count);
|
|
|
entry->tracker.ack_msg_count--;
|
|
|
+ starpu_pthread_mutex_unlock(&tracker_mutex);
|
|
|
return &entry->tracker;
|
|
|
}
|
|
|
|
|
@@ -203,6 +218,7 @@ int _starpu_mpi_checkpoint_check_tracker(struct _starpu_mpi_checkpoint_tracker*
|
|
|
|
|
|
struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_validate_instance(struct _starpu_mpi_checkpoint_tracker* tracker)
|
|
|
{
|
|
|
+ starpu_pthread_mutex_lock(&tracker_mutex);
|
|
|
// Here we validate a checkpoint and return the old cp info that must be discarded
|
|
|
struct _starpu_mpi_checkpoint_tracker* temp_tracker;
|
|
|
struct _starpu_mpi_checkpoint_domain_tracker_index* index = get_domain_tracker_index(tracker->cp_domain);
|
|
@@ -217,6 +233,7 @@ struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_validate_i
|
|
|
{
|
|
|
temp_tracker->old = 1;
|
|
|
}
|
|
|
+ starpu_pthread_mutex_unlock(&tracker_mutex);
|
|
|
return temp_tracker;
|
|
|
}
|
|
|
else
|
|
@@ -227,12 +244,15 @@ struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_validate_i
|
|
|
// The checkpoint to validate is older than the latest validated, just return it to discard it
|
|
|
tracker->valid = 1;
|
|
|
tracker->old =1;
|
|
|
+ starpu_pthread_mutex_unlock(&tracker_mutex);
|
|
|
return tracker;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
struct _starpu_mpi_checkpoint_tracker* _starpu_mpi_checkpoint_tracker_get_last_valid_tracker(int domain)
|
|
|
{
|
|
|
+ starpu_pthread_mutex_lock(&tracker_mutex);
|
|
|
struct _starpu_mpi_checkpoint_domain_tracker_index* index = get_domain_tracker_index(domain);
|
|
|
+ starpu_pthread_mutex_unlock(&tracker_mutex);
|
|
|
return index->last_valid_instance;
|
|
|
}
|