Browse Source

mpi/src/starpu_mpi_cache.c: protect acces to hashtables with mutexes

Nathalie Furmento 11 years ago
parent
commit
56a874e16b
1 changed files with 54 additions and 5 deletions
  1. 54 5
      mpi/src/starpu_mpi_cache.c

+ 54 - 5
mpi/src/starpu_mpi_cache.c

@@ -30,6 +30,8 @@ struct _starpu_data_entry
 	starpu_data_handle_t data;
 };
 
+static starpu_pthread_mutex_t *_cache_sent_mutex;
+static starpu_pthread_mutex_t *_cache_received_mutex;
 static struct _starpu_data_entry **_cache_sent_data = NULL;
 static struct _starpu_data_entry **_cache_received_data = NULL;
 int _cache_enabled=1;
@@ -53,11 +55,19 @@ void _starpu_mpi_cache_init(MPI_Comm comm)
 
 	MPI_Comm_size(comm, &nb_nodes);
 	_STARPU_MPI_DEBUG(2, "Initialising htable for cache\n");
+
 	_cache_sent_data = malloc(nb_nodes * sizeof(struct _starpu_data_entry *));
-	for(i=0 ; i<nb_nodes ; i++) _cache_sent_data[i] = NULL;
 	_cache_received_data = malloc(nb_nodes * sizeof(struct _starpu_data_entry *));
-	for(i=0 ; i<nb_nodes ; i++) _cache_received_data[i] = NULL;
-	_starpu_mpi_cache_stats_init(comm);
+	_cache_sent_mutex = malloc(nb_nodes * sizeof(starpu_pthread_mutex_t));
+	_cache_received_mutex = malloc(nb_nodes * sizeof(starpu_pthread_mutex_t));
+
+	for(i=0 ; i<nb_nodes ; i++)
+	{
+		_cache_sent_data[i] = NULL;
+		_cache_received_data[i] = NULL;
+		STARPU_PTHREAD_MUTEX_INIT(&_cache_sent_mutex[i], NULL);
+		STARPU_PTHREAD_MUTEX_INIT(&_cache_received_mutex[i], NULL);
+	}
 }
 
 static
@@ -72,27 +82,44 @@ void _starpu_mpi_cache_empty_tables(int world_size)
 	for(i=0 ; i<world_size ; i++)
 	{
 		struct _starpu_data_entry *entry, *tmp;
+
+		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[i]);
 		HASH_ITER(hh, _cache_sent_data[i], entry, tmp)
 		{
 			HASH_DEL(_cache_sent_data[i], entry);
 			free(entry);
 		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[i]);
+
+		STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[i]);
 		HASH_ITER(hh, _cache_received_data[i], entry, tmp)
 		{
 			HASH_DEL(_cache_received_data[i], entry);
 			_starpu_mpi_cache_stats_dec(-1, i, entry->data);
 			free(entry);
 		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[i]);
 	}
 }
 
 void _starpu_mpi_cache_free(int world_size)
 {
+	int i;
+
 	if (_cache_enabled == 0) return;
 
 	_starpu_mpi_cache_empty_tables(world_size);
 	free(_cache_sent_data);
 	free(_cache_received_data);
+
+	for(i=0 ; i<world_size ; i++)
+	{
+		STARPU_PTHREAD_MUTEX_DESTROY(&_cache_sent_mutex[i]);
+		STARPU_PTHREAD_MUTEX_DESTROY(&_cache_received_mutex[i]);
+	}
+	free(_cache_sent_mutex);
+	free(_cache_received_mutex);
+
 	_starpu_mpi_cache_stats_free();
 }
 
@@ -104,6 +131,8 @@ void _starpu_mpi_cache_flush_sent(MPI_Comm comm, starpu_data_handle_t data)
 	for(n=0 ; n<size ; n++)
 	{
 		struct _starpu_data_entry *already_sent;
+
+		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[n]);
 		HASH_FIND_PTR(_cache_sent_data[n], &data, already_sent);
 		if (already_sent)
 		{
@@ -111,6 +140,7 @@ void _starpu_mpi_cache_flush_sent(MPI_Comm comm, starpu_data_handle_t data)
 			HASH_DEL(_cache_sent_data[n], already_sent);
 			free(already_sent);
 		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[n]);
 	}
 }
 
@@ -119,6 +149,7 @@ void _starpu_mpi_cache_flush_recv(starpu_data_handle_t data, int me)
 	int mpi_rank = starpu_data_get_rank(data);
 	struct _starpu_data_entry *already_received;
 
+	STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[mpi_rank]);
 	HASH_FIND_PTR(_cache_received_data[mpi_rank], &data, already_received);
 	if (already_received)
 	{
@@ -131,6 +162,7 @@ void _starpu_mpi_cache_flush_recv(starpu_data_handle_t data, int me)
 		free(already_received);
 		starpu_data_invalidate_submit(data);
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[mpi_rank]);
 }
 
 void starpu_mpi_cache_flush_all_data(MPI_Comm comm)
@@ -146,6 +178,8 @@ void starpu_mpi_cache_flush_all_data(MPI_Comm comm)
 	for(i=0 ; i<nb_nodes ; i++)
 	{
 		struct _starpu_data_entry *entry, *tmp;
+
+		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[i]);
 		HASH_ITER(hh, _cache_sent_data[i], entry, tmp)
 		{
 			mpi_rank = starpu_data_get_rank(entry->data);
@@ -154,6 +188,9 @@ void starpu_mpi_cache_flush_all_data(MPI_Comm comm)
 			HASH_DEL(_cache_sent_data[i], entry);
 			free(entry);
 		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[i]);
+
+		STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[i]);
 		HASH_ITER(hh, _cache_received_data[i], entry, tmp)
 		{
 			mpi_rank = starpu_data_get_rank(entry->data);
@@ -163,6 +200,7 @@ void starpu_mpi_cache_flush_all_data(MPI_Comm comm)
 			_starpu_mpi_cache_stats_dec(my_rank, i, entry->data);
 			free(entry);
 		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[i]);
 	}
 }
 
@@ -180,6 +218,7 @@ void starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
 
 	for(i=0 ; i<nb_nodes ; i++)
 	{
+		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[i]);
 		HASH_FIND_PTR(_cache_sent_data[i], &data_handle, avail);
 		if (avail)
 		{
@@ -187,6 +226,9 @@ void starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
 			HASH_DEL(_cache_sent_data[i], avail);
 			free(avail);
 		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[i]);
+
+		STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[i]);
 		HASH_FIND_PTR(_cache_received_data[i], &data_handle, avail);
 		if (avail)
 		{
@@ -195,6 +237,7 @@ void starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
 			_starpu_mpi_cache_stats_dec(my_rank, i, data_handle);
 			free(avail);
 		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[i]);
 	}
 
 	if (mpi_rank != my_rank && mpi_rank != -1)
@@ -203,9 +246,11 @@ void starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
 
 void *_starpu_mpi_already_received(int src, starpu_data_handle_t data, int mpi_rank)
 {
+	struct _starpu_data_entry *already_received;
+
 	if (_cache_enabled == 0) return NULL;
 
-	struct _starpu_data_entry *already_received;
+	STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[mpi_rank]);
 	HASH_FIND_PTR(_cache_received_data[mpi_rank], &data, already_received);
 	if (already_received == NULL)
 	{
@@ -218,14 +263,17 @@ void *_starpu_mpi_already_received(int src, starpu_data_handle_t data, int mpi_r
 	{
 		_STARPU_MPI_DEBUG(2, "Do not receive data %p from node %d as it is already available\n", data, mpi_rank);
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[mpi_rank]);
 	return already_received;
 }
 
 void *_starpu_mpi_already_sent(starpu_data_handle_t data, int dest)
 {
+	struct _starpu_data_entry *already_sent;
+
 	if (_cache_enabled == 0) return NULL;
 
-	struct _starpu_data_entry *already_sent;
+	STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[dest]);
 	HASH_FIND_PTR(_cache_sent_data[dest], &data, already_sent);
 	if (already_sent == NULL)
 	{
@@ -238,6 +286,7 @@ void *_starpu_mpi_already_sent(starpu_data_handle_t data, int dest)
 	{
 		_STARPU_MPI_DEBUG(2, "Do not send data %p to node %d as it has already been sent\n", data, dest);
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[dest]);
 	return already_sent;
 }