Browse Source

mpi/src/starpu_mpi_comm.c: make communicator-related functions thread safe (that fixes bug #20665 reported by Jerome Roberts)

Nathalie Furmento 8 years ago
parent
commit
7f59393fc3
1 changed files with 20 additions and 0 deletions
  1. 20 0
      mpi/src/starpu_mpi_comm.c

+ 20 - 0
mpi/src/starpu_mpi_comm.c

@@ -34,6 +34,8 @@ struct _starpu_mpi_comm_hashtable
 	UT_hash_handle hh;
 	MPI_Comm comm;
 };
+
+static starpu_pthread_mutex_t _starpu_mpi_comms_mutex;
 struct _starpu_mpi_comm_hashtable *_starpu_mpi_comms_cache;
 struct _starpu_mpi_comm **_starpu_mpi_comms;
 int _starpu_mpi_comm_nb;
@@ -48,6 +50,7 @@ void _starpu_mpi_comm_init(MPI_Comm comm)
 	_starpu_mpi_comm_nb=0;
 	_starpu_mpi_comm_tested=0;
 	_starpu_mpi_comms_cache = NULL;
+	STARPU_PTHREAD_MUTEX_INIT(&_starpu_mpi_comms_mutex, NULL);
 
 	_starpu_mpi_comm_register(comm);
 }
@@ -69,12 +72,15 @@ void _starpu_mpi_comm_free()
 		HASH_DEL(_starpu_mpi_comms_cache, entry);
 		free(entry);
 	}
+
+	STARPU_PTHREAD_MUTEX_DESTROY(&_starpu_mpi_comms_mutex);
 }
 
 void _starpu_mpi_comm_register(MPI_Comm comm)
 {
 	struct _starpu_mpi_comm_hashtable *found;
 
+	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_comms_mutex);
 	HASH_FIND(hh, _starpu_mpi_comms_cache, &comm, sizeof(MPI_Comm), found);
 	if (found)
 	{
@@ -99,11 +105,14 @@ void _starpu_mpi_comm_register(MPI_Comm comm)
 		entry->comm = comm;
 		HASH_ADD(hh, _starpu_mpi_comms_cache, comm, sizeof(entry->comm), entry);
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_comms_mutex);
 }
 
 void _starpu_mpi_comm_post_recv()
 {
 	int i;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_comms_mutex);
 	for(i=0 ; i<_starpu_mpi_comm_nb ; i++)
 	{
 		struct _starpu_mpi_comm *_comm = _starpu_mpi_comms[i]; // get the ith _comm;
@@ -115,11 +124,14 @@ void _starpu_mpi_comm_post_recv()
 			_comm->posted = 1;
 		}
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_comms_mutex);
 }
 
 int _starpu_mpi_comm_test_recv(MPI_Status *status, struct _starpu_mpi_envelope **envelope, MPI_Comm *comm)
 {
 	int i=_starpu_mpi_comm_tested;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_comms_mutex);
 	while (1)
 	{
 		int flag, res;
@@ -142,21 +154,28 @@ int _starpu_mpi_comm_test_recv(MPI_Status *status, struct _starpu_mpi_envelope *
 					_starpu_mpi_comm_tested = 0;
 				*envelope = _comm->envelope;
 				*comm = _comm->comm;
+				STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_comms_mutex);
 				return 1;
 			}
 		}
 		i++;
 		if (i == _starpu_mpi_comm_nb) i=0;
 		if (i == _starpu_mpi_comm_tested)
+		{
 			// We have tested all the requests, none has completed
+			STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_comms_mutex);
 			return 0;
+		}
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_comms_mutex);
 	return 0;
 }
 
 void _starpu_mpi_comm_cancel_recv()
 {
 	int i;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_comms_mutex);
 	for(i=0 ; i<_starpu_mpi_comm_nb ; i++)
 	{
 		struct _starpu_mpi_comm *_comm = _starpu_mpi_comms[i]; // get the ith _comm;
@@ -168,4 +187,5 @@ void _starpu_mpi_comm_cancel_recv()
 			_comm->posted = 0;
 		}
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_comms_mutex);
 }