Pārlūkot izejas kodu

mpi/src/starpu_mpi_cache.c: check rank is valid

Nathalie Furmento 10 gadi atpakaļ
vecāks
revīzija
553aca2bb9
1 mainītis faili ar 10 papildinājumiem un 6 dzēšanām
  1. 10 6
      mpi/src/starpu_mpi_cache.c

+ 10 - 6
mpi/src/starpu_mpi_cache.c

@@ -176,6 +176,8 @@ void _starpu_mpi_cache_received_data_clear(starpu_data_handle_t data)
 	int mpi_rank = starpu_mpi_data_get_rank(data);
 	struct _starpu_data_entry *already_received;
 
+	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
+
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[mpi_rank]);
 	HASH_FIND_PTR(_cache_received_data[mpi_rank], &data, already_received);
 	if (already_received)
@@ -234,15 +236,11 @@ void starpu_mpi_cache_flush_all_data(MPI_Comm comm)
 void _starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
 {
 	struct _starpu_data_entry *avail;
-	int i, my_rank, nb_nodes;
-	int mpi_rank;
+	int i, nb_nodes;
 
 	if (_starpu_cache_enabled == 0) return;
 
 	starpu_mpi_comm_size(comm, &nb_nodes);
-	starpu_mpi_comm_rank(comm, &my_rank);
-	mpi_rank = starpu_mpi_data_get_rank(data_handle);
-
 	for(i=0 ; i<nb_nodes ; i++)
 	{
 		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[i]);
@@ -286,6 +284,8 @@ void *_starpu_mpi_cache_received_data_set(starpu_data_handle_t data, int mpi_ran
 
 	if (_starpu_cache_enabled == 0) return NULL;
 
+	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
+
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[mpi_rank]);
 	HASH_FIND_PTR(_cache_received_data[mpi_rank], &data, already_received);
 	if (already_received == NULL)
@@ -308,6 +308,9 @@ void *_starpu_mpi_cache_received_data_get(starpu_data_handle_t data, int mpi_ran
 	struct _starpu_data_entry *already_received;
 
 	if (_starpu_cache_enabled == 0) return NULL;
+
+	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
+
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[mpi_rank]);
 	HASH_FIND_PTR(_cache_received_data[mpi_rank], &data, already_received);
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[mpi_rank]);
@@ -320,6 +323,8 @@ void *_starpu_mpi_cache_sent_data_set(starpu_data_handle_t data, int dest)
 
 	if (_starpu_cache_enabled == 0) return NULL;
 
+	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
+
 	STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[dest]);
 	HASH_FIND_PTR(_cache_sent_data[dest], &data, already_sent);
 	if (already_sent == NULL)
@@ -336,4 +341,3 @@ void *_starpu_mpi_cache_sent_data_set(starpu_data_handle_t data, int dest)
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[dest]);
 	return already_sent;
 }
-