Browse Source

mpi: move cache data in the starpu_data_handle_t

Nathalie Furmento 8 years ago
parent
commit
d8961fb495

+ 23 - 22
mpi/src/starpu_mpi.c

@@ -1679,40 +1679,41 @@ void _starpu_mpi_progress_shutdown(int *value)
         STARPU_PTHREAD_COND_DESTROY(&barrier_cond);
 }
 
-void _starpu_mpi_clear_cache(starpu_data_handle_t data_handle)
+void _starpu_mpi_data_clear(starpu_data_handle_t data_handle)
 {
 	_starpu_mpi_tag_data_release(data_handle);
-	struct _starpu_mpi_node_tag *mpi_data = data_handle->mpi_data;
-	_starpu_mpi_cache_flush(mpi_data->comm, data_handle);
+	_starpu_mpi_cache_data_clear(data_handle);
 	free(data_handle->mpi_data);
 }
 
 void starpu_mpi_data_register_comm(starpu_data_handle_t data_handle, int tag, int rank, MPI_Comm comm)
 {
-	struct _starpu_mpi_node_tag *mpi_data;
+	struct _starpu_mpi_data *mpi_data;
 	if (data_handle->mpi_data)
 	{
 		mpi_data = data_handle->mpi_data;
 	}
 	else
 	{
-		_STARPU_CALLOC(mpi_data, 1, sizeof(struct _starpu_mpi_node_tag));
-		mpi_data->data_tag = -1;
-		mpi_data->rank = -1;
-		mpi_data->comm = MPI_COMM_WORLD;
+		_STARPU_CALLOC(mpi_data, 1, sizeof(struct _starpu_mpi_data));
+		mpi_data->magic = 42;
+		mpi_data->node_tag.data_tag = -1;
+		mpi_data->node_tag.rank = -1;
+		mpi_data->node_tag.comm = MPI_COMM_WORLD;
 		data_handle->mpi_data = mpi_data;
+		_starpu_mpi_cache_data_init(data_handle);
 		_starpu_mpi_tag_data_register(data_handle, tag);
-		_starpu_data_set_unregister_hook(data_handle, _starpu_mpi_clear_cache);
+		_starpu_data_set_unregister_hook(data_handle, _starpu_mpi_data_clear);
 	}
 
 	if (tag != -1)
 	{
-		mpi_data->data_tag = tag;
+		mpi_data->node_tag.data_tag = tag;
 	}
 	if (rank != -1)
 	{
-		mpi_data->rank = rank;
-		mpi_data->comm = comm;
+		mpi_data->node_tag.rank = rank;
+		mpi_data->node_tag.comm = comm;
 		_starpu_mpi_comm_register(comm);
 	}
 }
@@ -1730,13 +1731,13 @@ void starpu_mpi_data_set_tag(starpu_data_handle_t handle, int tag)
 int starpu_mpi_data_get_rank(starpu_data_handle_t data)
 {
 	STARPU_ASSERT_MSG(data->mpi_data, "starpu_mpi_data_register MUST be called for data %p\n", data);
-	return ((struct _starpu_mpi_node_tag *)(data->mpi_data))->rank;
+	return ((struct _starpu_mpi_data *)(data->mpi_data))->node_tag.rank;
 }
 
 int starpu_mpi_data_get_tag(starpu_data_handle_t data)
 {
 	STARPU_ASSERT_MSG(data->mpi_data, "starpu_mpi_data_register MUST be called for data %p\n", data);
-	return ((struct _starpu_mpi_node_tag *)(data->mpi_data))->data_tag;
+	return ((struct _starpu_mpi_data *)(data->mpi_data))->node_tag.data_tag;
 }
 
 void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg)
@@ -1760,8 +1761,8 @@ void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t da
 	if (me == node)
 	{
 		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
-		void *already_received = _starpu_mpi_cache_received_data_set(data_handle, rank);
-		if (already_received == NULL)
+		int already_received = _starpu_mpi_cache_received_data_set(data_handle);
+		if (already_received == 0)
 		{
 			_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, rank);
 			starpu_mpi_irecv_detached(data_handle, rank, tag, comm, callback, arg);
@@ -1770,8 +1771,8 @@ void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t da
 	else if (me == rank)
 	{
 		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
-		void *already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
-		if (already_sent == NULL)
+		int already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
+		if (already_sent == 0)
 		{
 			_STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, node);
 			starpu_mpi_isend_detached(data_handle, node, tag, comm, NULL, NULL);
@@ -1801,8 +1802,8 @@ void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle
 	{
 		MPI_Status status;
 		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
-		void *already_received = _starpu_mpi_cache_received_data_set(data_handle, rank);
-		if (already_received == NULL)
+		int already_received = _starpu_mpi_cache_received_data_set(data_handle);
+		if (already_received == 0)
 		{
 			_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, rank);
 			starpu_mpi_recv(data_handle, rank, tag, comm, &status);
@@ -1811,8 +1812,8 @@ void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle
 	else if (me == rank)
 	{
 		_STARPU_MPI_DEBUG(1, "Migrating data %p from %d to %d\n", data_handle, rank, node);
-		void *already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
-		if (already_sent == NULL)
+		int already_sent = _starpu_mpi_cache_sent_data_set(data_handle, node);
+		if (already_sent == 0)
 		{
 			_STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, node);
 			starpu_mpi_send(data_handle, node, tag, comm);

+ 208 - 205
mpi/src/starpu_mpi_cache.c

@@ -18,6 +18,7 @@
 
 #include <starpu.h>
 #include <common/uthash.h>
+#include <datawizard/coherency.h>
 
 #include <starpu_mpi_cache.h>
 #include <starpu_mpi_cache_stats.h>
@@ -27,16 +28,14 @@
 struct _starpu_data_entry
 {
 	UT_hash_handle hh;
-	starpu_data_handle_t data;
+	starpu_data_handle_t data_handle;
 };
 
-static starpu_pthread_mutex_t *_cache_sent_mutex;
-static starpu_pthread_mutex_t *_cache_received_mutex;
-static struct _starpu_data_entry **_cache_sent_data = NULL;
-static struct _starpu_data_entry **_cache_received_data = NULL;
+static starpu_pthread_mutex_t _cache_mutex;
+static struct _starpu_data_entry *_cache_data = NULL;
 int _starpu_cache_enabled=1;
-MPI_Comm _starpu_cache_comm;
-int _starpu_cache_comm_size;
+static MPI_Comm _starpu_cache_comm;
+static int _starpu_cache_comm_size;
 
 int starpu_mpi_cache_is_enabled()
 {
@@ -55,7 +54,7 @@ int starpu_mpi_cache_set(int enabled)
 		{
 			// We need to clean the cache
 			starpu_mpi_cache_flush_all_data(_starpu_cache_comm);
-			_starpu_mpi_cache_shutdown(_starpu_cache_comm_size);
+			_starpu_mpi_cache_shutdown();
 		}
 		_starpu_cache_enabled = 0;
 	}
@@ -64,8 +63,6 @@ int starpu_mpi_cache_set(int enabled)
 
 void _starpu_mpi_cache_init(MPI_Comm comm)
 {
-	int i;
-
 	_starpu_cache_enabled = starpu_get_env_number("STARPU_MPI_CACHE");
 	if (_starpu_cache_enabled == -1)
 	{
@@ -80,295 +77,301 @@ void _starpu_mpi_cache_init(MPI_Comm comm)
 
 	_starpu_cache_comm = comm;
 	starpu_mpi_comm_size(comm, &_starpu_cache_comm_size);
-	_STARPU_MPI_DEBUG(2, "Initialising htable for cache\n");
+	_starpu_mpi_cache_stats_init(comm);
+	STARPU_PTHREAD_MUTEX_INIT(&_cache_mutex, NULL);
+}
 
-	_STARPU_MPI_MALLOC(_cache_sent_data, _starpu_cache_comm_size * sizeof(struct _starpu_data_entry *));
-	_STARPU_MPI_MALLOC(_cache_received_data, _starpu_cache_comm_size * sizeof(struct _starpu_data_entry *));
-	_STARPU_MPI_MALLOC(_cache_sent_mutex, _starpu_cache_comm_size * sizeof(starpu_pthread_mutex_t));
-	_STARPU_MPI_MALLOC(_cache_received_mutex, _starpu_cache_comm_size * sizeof(starpu_pthread_mutex_t));
+void _starpu_mpi_cache_shutdown()
+{
+	if (_starpu_cache_enabled == 0) return;
 
-	for(i=0 ; i<_starpu_cache_comm_size ; i++)
+	struct _starpu_data_entry *entry, *tmp;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
+	HASH_ITER(hh, _cache_data, entry, tmp)
 	{
-		_cache_sent_data[i] = NULL;
-		_cache_received_data[i] = NULL;
-		STARPU_PTHREAD_MUTEX_INIT(&_cache_sent_mutex[i], NULL);
-		STARPU_PTHREAD_MUTEX_INIT(&_cache_received_mutex[i], NULL);
+		HASH_DEL(_cache_data, entry);
+		free(entry);
 	}
-	_starpu_mpi_cache_stats_init(comm);
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
+	STARPU_PTHREAD_MUTEX_DESTROY(&_cache_mutex);
+	free(_cache_data);
+	_starpu_mpi_cache_stats_shutdown();
 }
 
-static
-void _starpu_mpi_cache_empty_tables(int world_size)
+void _starpu_mpi_cache_data_clear(starpu_data_handle_t data_handle)
 {
 	int i;
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
 
-	if (_starpu_cache_enabled == 0) return;
-
-	_STARPU_MPI_DEBUG(2, "Clearing htable for cache\n");
+	if (_starpu_cache_enabled == 0) return 0;
 
-	for(i=0 ; i<world_size ; i++)
+	_starpu_mpi_cache_flush(data_handle);
+	for(i=0 ; i<_starpu_cache_comm_size ; i++)
 	{
-		struct _starpu_data_entry *entry, *tmp;
-
-		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[i]);
-		HASH_ITER(hh, _cache_sent_data[i], entry, tmp)
-		{
-			HASH_DEL(_cache_sent_data[i], entry);
-			free(entry);
-		}
-		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[i]);
-
-		STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[i]);
-		HASH_ITER(hh, _cache_received_data[i], entry, tmp)
-		{
-			HASH_DEL(_cache_received_data[i], entry);
-			_starpu_mpi_cache_stats_dec(i, entry->data);
-			free(entry);
-		}
-		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[i]);
+		STARPU_PTHREAD_MUTEX_DESTROY(&mpi_data->cache_sent_mutex[i]);
 	}
+	free(mpi_data->cache_sent);
+	free(mpi_data->cache_sent_mutex);
 }
 
-void _starpu_mpi_cache_shutdown()
+void _starpu_mpi_cache_data_init(starpu_data_handle_t data_handle)
 {
 	int i;
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
 
-	if (_starpu_cache_enabled == 0) return;
-
-	_starpu_mpi_cache_empty_tables(_starpu_cache_comm_size);
-	free(_cache_sent_data);
-	free(_cache_received_data);
+	if (_starpu_cache_enabled == 0) return 0;
 
+	STARPU_PTHREAD_MUTEX_INIT(&mpi_data->cache_received_mutex, NULL);
+	mpi_data->cache_received = 0;
+	_STARPU_MALLOC(mpi_data->cache_sent_mutex, _starpu_cache_comm_size*sizeof(mpi_data->cache_sent_mutex[0]));
+	_STARPU_MALLOC(mpi_data->cache_sent, _starpu_cache_comm_size*sizeof(mpi_data->cache_sent[0]));
 	for(i=0 ; i<_starpu_cache_comm_size ; i++)
 	{
-		STARPU_PTHREAD_MUTEX_DESTROY(&_cache_sent_mutex[i]);
-		STARPU_PTHREAD_MUTEX_DESTROY(&_cache_received_mutex[i]);
+		STARPU_PTHREAD_MUTEX_INIT(&mpi_data->cache_sent_mutex[i], NULL);
+		mpi_data->cache_sent[i] = 0;
 	}
-	free(_cache_sent_mutex);
-	free(_cache_received_mutex);
-
-	_starpu_mpi_cache_stats_shutdown();
 }
 
-void _starpu_mpi_cache_sent_data_clear(MPI_Comm comm, starpu_data_handle_t data)
+static void _starpu_mpi_cache_data_add(starpu_data_handle_t data_handle)
 {
-	int n, size;
-	starpu_mpi_comm_size(comm, &size);
+	struct _starpu_data_entry *entry;
 
-	for(n=0 ; n<size ; n++)
+	if (_starpu_cache_enabled == 0) return 0;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
+	HASH_FIND_PTR(_cache_data, &data_handle, entry);
+	if (entry == NULL)
 	{
-		struct _starpu_data_entry *already_sent;
+		_STARPU_MPI_MALLOC(entry, sizeof(*entry));
+		entry->data_handle = data_handle;
+		HASH_ADD_PTR(_cache_data, data_handle, entry);
+	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
+}
 
-		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[n]);
-		HASH_FIND_PTR(_cache_sent_data[n], &data, already_sent);
-		if (already_sent)
-		{
-			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data);
-			HASH_DEL(_cache_sent_data[n], already_sent);
-			free(already_sent);
-		}
-		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[n]);
+static void _starpu_mpi_cache_data_remove(starpu_data_handle_t data_handle)
+{
+	struct _starpu_data_entry *entry;
+
+	if (_starpu_cache_enabled == 0) return 0;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
+	HASH_FIND_PTR(_cache_data, &data_handle, entry);
+	if (entry)
+	{
+		HASH_DEL(_cache_data, entry);
+		free(entry);
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
 }
 
-void _starpu_mpi_cache_received_data_clear(starpu_data_handle_t data)
+/**************************************
+ * Received cache
+ **************************************/
+void _starpu_mpi_cache_received_data_clear(starpu_data_handle_t data_handle)
 {
-	int mpi_rank = starpu_mpi_data_get_rank(data);
-	struct _starpu_data_entry *already_received;
+	int mpi_rank = starpu_mpi_data_get_rank(data_handle);
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
+
+	if (_starpu_cache_enabled == 0) return 0;
 
+	STARPU_ASSERT(mpi_data->magic == 42);
 	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
 
-	STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[mpi_rank]);
-	HASH_FIND_PTR(_cache_received_data[mpi_rank], &data, already_received);
-	if (already_received)
+	STARPU_PTHREAD_MUTEX_LOCK(&mpi_data->cache_received_mutex);
+	if (mpi_data->cache_received == 1)
 	{
 #ifdef STARPU_DEVEL
 #  warning TODO: Somebody else will write to the data, so discard our cached copy if any. starpu_mpi could just remember itself.
 #endif
-		_STARPU_MPI_DEBUG(2, "Clearing receive cache for data %p\n", data);
-		HASH_DEL(_cache_received_data[mpi_rank], already_received);
-		_starpu_mpi_cache_stats_dec(mpi_rank, data);
-		free(already_received);
-		starpu_data_invalidate_submit(data);
+		_STARPU_MPI_DEBUG(2, "Clearing receive cache for data %p\n", data_handle);
+		mpi_data->cache_received = 0;
+		starpu_data_invalidate_submit(data_handle);
+		_starpu_mpi_cache_data_remove(data_handle);
+		_starpu_mpi_cache_stats_dec(mpi_rank, data_handle);
 	}
-	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[mpi_rank]);
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_data->cache_received_mutex);
 }
 
-void starpu_mpi_cache_flush_all_data(MPI_Comm comm)
+int _starpu_mpi_cache_received_data_set(starpu_data_handle_t data_handle)
 {
-	int nb_nodes, i;
-	int mpi_rank, my_rank;
+	int mpi_rank = starpu_mpi_data_get_rank(data_handle);
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
 
-	if (_starpu_cache_enabled == 0) return;
+	if (_starpu_cache_enabled == 0) return 0;
 
-	starpu_mpi_comm_size(comm, &nb_nodes);
-	starpu_mpi_comm_rank(comm, &my_rank);
+	STARPU_ASSERT(mpi_data->magic == 42);
+	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
 
-	for(i=0 ; i<nb_nodes ; i++)
+	STARPU_PTHREAD_MUTEX_LOCK(&mpi_data->cache_received_mutex);
+	int already_received = mpi_data->cache_received;
+	if (already_received == 0)
 	{
-		struct _starpu_data_entry *entry, *tmp;
-
-		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[i]);
-		HASH_ITER(hh, _cache_sent_data[i], entry, tmp)
-		{
-			mpi_rank = starpu_mpi_data_get_rank(entry->data);
-			if (mpi_rank != my_rank && mpi_rank != -1)
-				starpu_data_invalidate_submit(entry->data);
-			HASH_DEL(_cache_sent_data[i], entry);
-			free(entry);
-		}
-		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[i]);
-
-		STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[i]);
-		HASH_ITER(hh, _cache_received_data[i], entry, tmp)
-		{
-			mpi_rank = starpu_mpi_data_get_rank(entry->data);
-			if (mpi_rank != my_rank && mpi_rank != -1)
-				starpu_data_invalidate_submit(entry->data);
-			HASH_DEL(_cache_received_data[i], entry);
-			_starpu_mpi_cache_stats_dec(i, entry->data);
-			free(entry);
-		}
-		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[i]);
+		_STARPU_MPI_DEBUG(2, "Noting that data %p has already been received by %d\n", data_handle, mpi_rank);
+		mpi_data->cache_received = 1;
+		_starpu_mpi_cache_data_add(data_handle);
+		_starpu_mpi_cache_stats_inc(mpi_rank, data_handle);
+	}
+	else
+	{
+		_STARPU_MPI_DEBUG(2, "Do not receive data %p from node %d as it is already available\n", data_handle, mpi_rank);
 	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_data->cache_received_mutex);
+	return already_received;
 }
 
-void _starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
+int _starpu_mpi_cache_received_data_get(starpu_data_handle_t data_handle)
 {
-	struct _starpu_data_entry *avail;
-	int i, nb_nodes;
+	int already_received;
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
 
-	if (_starpu_cache_enabled == 0) return;
+	if (_starpu_cache_enabled == 0) return 0;
 
-	starpu_mpi_comm_size(comm, &nb_nodes);
-	for(i=0 ; i<nb_nodes ; i++)
-	{
-		STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[i]);
-		HASH_FIND_PTR(_cache_sent_data[i], &data_handle, avail);
-		if (avail)
-		{
-			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data_handle);
-			HASH_DEL(_cache_sent_data[i], avail);
-			free(avail);
-		}
-		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[i]);
+	STARPU_ASSERT(mpi_data->magic == 42);
 
-		STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[i]);
-		HASH_FIND_PTR(_cache_received_data[i], &data_handle, avail);
-		if (avail)
-		{
-			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data_handle);
-			HASH_DEL(_cache_received_data[i], avail);
-			_starpu_mpi_cache_stats_dec(i, data_handle);
-			free(avail);
-		}
-		STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[i]);
-	}
+	STARPU_PTHREAD_MUTEX_LOCK(&mpi_data->cache_received_mutex);
+	already_received = mpi_data->cache_received;
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_data->cache_received_mutex);
+	return already_received;
 }
 
-void starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
+int starpu_mpi_cached_receive(starpu_data_handle_t data_handle)
 {
-	int my_rank, mpi_rank;
+	return _starpu_mpi_cache_received_data_get(data_handle);
+}
+
+/**************************************
+ * Send cache
+ **************************************/
+void _starpu_mpi_cache_sent_data_clear(starpu_data_handle_t data_handle)
+{
+	int n, size;
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
 
-	_starpu_mpi_cache_flush(comm, data_handle);
+	if (_starpu_cache_enabled == 0) return 0;
 
-	starpu_mpi_comm_rank(comm, &my_rank);
-	mpi_rank = starpu_mpi_data_get_rank(data_handle);
-	if (mpi_rank != my_rank && mpi_rank != -1)
-		starpu_data_invalidate_submit(data_handle);
+	starpu_mpi_comm_size(mpi_data->node_tag.comm, &size);
+	for(n=0 ; n<size ; n++)
+	{
+		STARPU_PTHREAD_MUTEX_LOCK(&mpi_data->cache_sent_mutex[n]);
+		if (mpi_data->cache_sent[n] == 1)
+		{
+			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data_handle);
+			mpi_data->cache_sent[n] = 0;
+			_starpu_mpi_cache_data_remove(data_handle);
+		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_data->cache_sent_mutex[n]);
+	}
 }
 
-void *_starpu_mpi_cache_received_data_set(starpu_data_handle_t data, int mpi_rank)
+int _starpu_mpi_cache_sent_data_set(starpu_data_handle_t data_handle, int dest)
 {
-	struct _starpu_data_entry *already_received;
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
 
-	if (_starpu_cache_enabled == 0) return NULL;
+	if (_starpu_cache_enabled == 0) return 0;
 
-	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
+	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
 
-	STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[mpi_rank]);
-	HASH_FIND_PTR(_cache_received_data[mpi_rank], &data, already_received);
-	if (already_received == NULL)
+	STARPU_PTHREAD_MUTEX_LOCK(&mpi_data->cache_sent_mutex[dest]);
+	int already_sent = mpi_data->cache_sent[dest];
+	if (mpi_data->cache_sent[dest] == 0)
 	{
-		struct _starpu_data_entry *entry;
-		_STARPU_MPI_MALLOC(entry, sizeof(*entry));
-		entry->data = data;
-		HASH_ADD_PTR(_cache_received_data[mpi_rank], data, entry);
-		_STARPU_MPI_DEBUG(2, "Noting that data %p has already been received by %d\n", data, mpi_rank);
-		_starpu_mpi_cache_stats_inc(mpi_rank, data);
+		mpi_data->cache_sent[dest] = 1;
+		_starpu_mpi_cache_data_add(data_handle);
+		_STARPU_MPI_DEBUG(2, "Noting that data %p has already been sent to %d\n", data_handle, dest);
 	}
 	else
 	{
-		_STARPU_MPI_DEBUG(2, "Do not receive data %p from node %d as it is already available\n", data, mpi_rank);
+		_STARPU_MPI_DEBUG(2, "Do not send data %p to node %d as it has already been sent\n", data_handle, dest);
 	}
-	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[mpi_rank]);
-	return already_received;
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_data->cache_sent_mutex[dest]);
+	return already_sent;
 }
 
-void *_starpu_mpi_cache_received_data_get(starpu_data_handle_t data, int mpi_rank)
+int _starpu_mpi_cache_sent_data_get(starpu_data_handle_t data_handle, int dest)
 {
-	struct _starpu_data_entry *already_received;
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
+	int already_sent;
 
-	if (_starpu_cache_enabled == 0) return NULL;
+	if (_starpu_cache_enabled == 0) return 0;
 
-	STARPU_MPI_ASSERT_MSG(mpi_rank < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", mpi_rank, _starpu_cache_comm_size);
+	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
 
-	STARPU_PTHREAD_MUTEX_LOCK(&_cache_received_mutex[mpi_rank]);
-	HASH_FIND_PTR(_cache_received_data[mpi_rank], &data, already_received);
-	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_received_mutex[mpi_rank]);
-	return already_received;
+	STARPU_PTHREAD_MUTEX_LOCK(&mpi_data->cache_sent_mutex[dest]);
+	already_sent = mpi_data->cache_sent[dest];
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_data->cache_sent_mutex[dest]);
+	return already_sent;
 }
 
-int starpu_mpi_cached_receive(starpu_data_handle_t data_handle)
+int starpu_mpi_cached_send(starpu_data_handle_t data_handle, int dest)
 {
-	int owner = starpu_mpi_data_get_rank(data_handle);
-	void *already_received = _starpu_mpi_cache_received_data_get(data_handle, owner);
-	return already_received != NULL;
+	return _starpu_mpi_cache_sent_data_get(data_handle, dest);
 }
 
-void *_starpu_mpi_cache_sent_data_set(starpu_data_handle_t data, int dest)
+void _starpu_mpi_cache_flush(starpu_data_handle_t data_handle)
 {
-	struct _starpu_data_entry *already_sent;
-
-	if (_starpu_cache_enabled == 0) return NULL;
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
+	int i, nb_nodes;
 
-	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
+	if (_starpu_cache_enabled == 0) return;
 
-	STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[dest]);
-	HASH_FIND_PTR(_cache_sent_data[dest], &data, already_sent);
-	if (already_sent == NULL)
+	starpu_mpi_comm_size(mpi_data->node_tag.comm, &nb_nodes);
+	for(i=0 ; i<nb_nodes ; i++)
 	{
-		struct _starpu_data_entry *entry;
-		_STARPU_MPI_MALLOC(entry, sizeof(*entry));
-		entry->data = data;
-		HASH_ADD_PTR(_cache_sent_data[dest], data, entry);
-		_STARPU_MPI_DEBUG(2, "Noting that data %p has already been sent to %d\n", data, dest);
+		STARPU_PTHREAD_MUTEX_LOCK(&mpi_data->cache_sent_mutex[i]);
+		if (mpi_data->cache_sent[i] == 1)
+		{
+			_STARPU_MPI_DEBUG(2, "Clearing send cache for data %p\n", data_handle);
+			mpi_data->cache_sent[i] = 0;
+			_starpu_mpi_cache_stats_dec(i, data_handle);
+		}
+		STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_data->cache_sent_mutex[i]);
 	}
-	else
+
+	STARPU_PTHREAD_MUTEX_LOCK(&mpi_data->cache_received_mutex);
+	if (mpi_data->cache_received == 1)
 	{
-		_STARPU_MPI_DEBUG(2, "Do not send data %p to node %d as it has already been sent\n", data, dest);
+		int mpi_rank = starpu_mpi_data_get_rank(data_handle);
+		_STARPU_MPI_DEBUG(2, "Clearing received cache for data %p\n", data_handle);
+		mpi_data->cache_received = 0;
+		_starpu_mpi_cache_stats_dec(mpi_rank, data_handle);
 	}
-	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[dest]);
-	return already_sent;
+	STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_data->cache_received_mutex);
 }
 
-void *_starpu_mpi_cache_sent_data_get(starpu_data_handle_t data, int dest)
+static void _starpu_mpi_cache_flush_and_invalidate(MPI_Comm comm, starpu_data_handle_t data_handle)
 {
-	struct _starpu_data_entry *already_sent;
-
-	if (_starpu_cache_enabled == 0) return NULL;
+	int my_rank, mpi_rank;
 
-	STARPU_MPI_ASSERT_MSG(dest < _starpu_cache_comm_size, "Node %d invalid. Max node is %d\n", dest, _starpu_cache_comm_size);
+	_starpu_mpi_cache_flush(data_handle);
 
-	STARPU_PTHREAD_MUTEX_LOCK(&_cache_sent_mutex[dest]);
-	HASH_FIND_PTR(_cache_sent_data[dest], &data, already_sent);
-	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_sent_mutex[dest]);
-	return already_sent;
+	starpu_mpi_comm_rank(comm, &my_rank);
+	mpi_rank = starpu_mpi_data_get_rank(data_handle);
+	if (mpi_rank != my_rank && mpi_rank != -1)
+		starpu_data_invalidate_submit(data_handle);
 }
 
-int starpu_mpi_cached_send(starpu_data_handle_t data_handle, int dest)
+void starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle)
 {
-	void *already_sent = _starpu_mpi_cache_sent_data_get(data_handle, dest);
-	return already_sent != NULL;
+	if (_starpu_cache_enabled == 0) return 0;
+
+	_starpu_mpi_cache_flush_and_invalidate(comm, data_handle);
+	_starpu_mpi_cache_data_remove(data_handle);
 }
 
+void starpu_mpi_cache_flush_all_data(MPI_Comm comm)
+{
+	struct _starpu_data_entry *entry, *tmp;
+
+	if (_starpu_cache_enabled == 0) return;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_cache_mutex);
+	HASH_ITER(hh, _cache_data, entry, tmp)
+	{
+		_starpu_mpi_cache_flush_and_invalidate(comm, entry->data_handle);
+		HASH_DEL(_cache_data, entry);
+		free(entry);
+	}
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_cache_mutex);
+}

+ 9 - 7
mpi/src/starpu_mpi_cache.h

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016  CNRS
+ * Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016, 2017  CNRS
  * Copyright (C) 2011-2014, 2017  Université de Bordeaux
  * Copyright (C) 2014 INRIA
  *
@@ -30,24 +30,26 @@ extern "C" {
 extern int _starpu_cache_enabled;
 void _starpu_mpi_cache_init(MPI_Comm comm);
 void _starpu_mpi_cache_shutdown();
+void _starpu_mpi_cache_data_init(starpu_data_handle_t data_handle);
+void _starpu_mpi_cache_data_clear(starpu_data_handle_t data_handle);
 
 /*
  * If the data is already available in the cache, return a pointer to the data
  * If the data is NOT available in the cache, add it to the cache and return NULL
  */
-void *_starpu_mpi_cache_received_data_set(starpu_data_handle_t data, int mpi_rank);
-void *_starpu_mpi_cache_received_data_get(starpu_data_handle_t data, int mpi_rank);
+int _starpu_mpi_cache_received_data_set(starpu_data_handle_t data);
+int _starpu_mpi_cache_received_data_get(starpu_data_handle_t data);
 void _starpu_mpi_cache_received_data_clear(starpu_data_handle_t data);
 
 /*
  * If the data is already available in the cache, return a pointer to the data
  * If the data is NOT available in the cache, add it to the cache and return NULL
  */
-void *_starpu_mpi_cache_sent_data_set(starpu_data_handle_t data, int dest);
-void *_starpu_mpi_cache_sent_data_get(starpu_data_handle_t data, int dest);
-void _starpu_mpi_cache_sent_data_clear(MPI_Comm comm, starpu_data_handle_t data);
+int _starpu_mpi_cache_sent_data_set(starpu_data_handle_t data, int dest);
+int _starpu_mpi_cache_sent_data_get(starpu_data_handle_t data, int dest);
+void _starpu_mpi_cache_sent_data_clear(starpu_data_handle_t data);
 
-void _starpu_mpi_cache_flush(MPI_Comm comm, starpu_data_handle_t data_handle);
+void _starpu_mpi_cache_flush(starpu_data_handle_t data_handle);
 
 #ifdef __cplusplus
 }

+ 10 - 0
mpi/src/starpu_mpi_private.h

@@ -186,6 +186,16 @@ struct _starpu_mpi_node_tag
 	int data_tag;
 };
 
+struct _starpu_mpi_data
+{
+	int magic;
+	struct _starpu_mpi_node_tag node_tag;
+	starpu_pthread_mutex_t *cache_sent_mutex;
+	int *cache_sent;
+	starpu_pthread_mutex_t cache_received_mutex;
+	int cache_received;
+};
+
 LIST_TYPE(_starpu_mpi_req,
 	/* description of the data at StarPU level */
 	starpu_data_handle_t data_handle;

+ 2 - 2
mpi/src/starpu_mpi_tag.c

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016  CNRS
+ * Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016, 2017  CNRS
  * Copyright (C) 2011-2015  Université de Bordeaux
  * Copyright (C) 2014 INRIA
  *
@@ -102,7 +102,7 @@ int _starpu_mpi_tag_data_release(starpu_data_handle_t handle)
 		struct handle_tag_entry *tag_entry;
 
 		_starpu_spin_lock(&registered_tag_handles_lock);
-		HASH_FIND_INT(registered_tag_handles, &(((struct _starpu_mpi_node_tag *)(handle->mpi_data))->data_tag), tag_entry);
+		HASH_FIND_INT(registered_tag_handles, &(((struct _starpu_mpi_data *)(handle->mpi_data))->node_tag.data_tag), tag_entry);
 		STARPU_ASSERT_MSG((tag_entry != NULL),"Data handle %p with tag %d isn't in the hashmap !",handle,tag);
 
 		HASH_DEL(registered_tag_handles, tag_entry);

+ 7 - 7
mpi/src/starpu_mpi_task_insert.c

@@ -110,8 +110,8 @@ void _starpu_mpi_exchange_data_before_execution(starpu_data_handle_t data, enum
 		if (do_execute && mpi_rank != me)
 		{
 			/* The node is going to execute the codelet, but it does not own the data, it needs to receive the data from the owner node */
-			void *already_received = _starpu_mpi_cache_received_data_set(data, mpi_rank);
-			if (already_received == NULL)
+			int already_received = _starpu_mpi_cache_received_data_set(data);
+			if (already_received == 0)
 			{
 				_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data, mpi_rank);
 				starpu_mpi_irecv_detached(data, mpi_rank, data_tag, comm, NULL, NULL);
@@ -122,8 +122,8 @@ void _starpu_mpi_exchange_data_before_execution(starpu_data_handle_t data, enum
 		if (!do_execute && mpi_rank == me)
 		{
 			/* The node owns the data, but another node is going to execute the codelet, the node needs to send the data to the executee node. */
-			void *already_sent = _starpu_mpi_cache_sent_data_set(data, xrank);
-			if (already_sent == NULL)
+			int already_sent = _starpu_mpi_cache_sent_data_set(data, xrank);
+			if (already_sent == 0)
 			{
 				_STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data, xrank);
 				_SEND_DATA(data, mode, xrank, data_tag, comm, NULL, NULL);
@@ -165,14 +165,14 @@ void _starpu_mpi_exchange_data_after_execution(starpu_data_handle_t data, enum s
 }
 
 static
-void _starpu_mpi_clear_data_after_execution(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int do_execute, MPI_Comm comm)
+void _starpu_mpi_clear_data_after_execution(starpu_data_handle_t data, enum starpu_data_access_mode mode, int me, int do_execute)
 {
 	if (_starpu_cache_enabled)
 	{
 		if (mode & STARPU_W || mode & STARPU_REDUX)
 		{
 			/* The data has been modified, it MUST be removed from the cache */
-			_starpu_mpi_cache_sent_data_clear(comm, data);
+			_starpu_mpi_cache_sent_data_clear(data);
 			_starpu_mpi_cache_received_data_clear(data);
 		}
 	}
@@ -503,7 +503,7 @@ int _starpu_mpi_task_postbuild_v(MPI_Comm comm, int xrank, int do_execute, struc
 	for(i=0 ; i<nb_data ; i++)
 	{
 		_starpu_mpi_exchange_data_after_execution(descrs[i].handle, descrs[i].mode, me, xrank, do_execute, comm);
-		_starpu_mpi_clear_data_after_execution(descrs[i].handle, descrs[i].mode, me, do_execute, comm);
+		_starpu_mpi_clear_data_after_execution(descrs[i].handle, descrs[i].mode, me, do_execute);
 	}
 
 	free(descrs);

+ 5 - 5
mpi/tests/cache.c

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2015, 2016  CNRS
+ * Copyright (C) 2015, 2016, 2017  CNRS
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -60,23 +60,23 @@ struct starpu_codelet mycodelet_rw =
 
 void test(struct starpu_codelet *codelet, enum starpu_data_access_mode mode, starpu_data_handle_t data, int rank, int in_cache)
 {
-	void *ptr;
+	int cache;
 	int ret;
 
 	ret = starpu_mpi_task_insert(MPI_COMM_WORLD, codelet, mode, data, STARPU_EXECUTE_ON_NODE, 1, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_task_insert");
 
-	ptr = _starpu_mpi_cache_received_data_get(data, 0);
+	cache = _starpu_mpi_cache_received_data_get(data);
 
 	if (rank == 1)
 	{
 	     if (in_cache)
 	     {
-		     STARPU_ASSERT_MSG(ptr != NULL, "Data should be in cache\n");
+		     STARPU_ASSERT_MSG(cache == 1, "Data should be in cache\n");
 	     }
 	     else
 	     {
-		     STARPU_ASSERT_MSG(ptr == NULL, "Data should NOT be in cache\n");
+		     STARPU_ASSERT_MSG(cache == 0, "Data should NOT be in cache\n");
 	     }
 	}
 }

+ 8 - 8
mpi/tests/cache_disable.c

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2015, 2016  CNRS
+ * Copyright (C) 2015, 2016, 2017  CNRS
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -48,7 +48,7 @@ int main(int argc, char **argv)
 	int ret;
 	unsigned *val;
 	starpu_data_handle_t data;
-	void *ptr = NULL;
+	int in_cache;
 	int cache;
 
 	ret = starpu_init(NULL);
@@ -73,28 +73,28 @@ int main(int argc, char **argv)
 	ret = starpu_mpi_task_insert(MPI_COMM_WORLD, &mycodelet_r, STARPU_R, data, STARPU_EXECUTE_ON_NODE, 1, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_task_insert");
 
-	ptr = _starpu_mpi_cache_received_data_get(data, 0);
+	in_cache = _starpu_mpi_cache_received_data_get(data);
 	if (rank == 1)
 	{
-		STARPU_ASSERT_MSG(ptr != NULL, "Data should be in cache\n");
+		STARPU_ASSERT_MSG(in_cache == 1, "Data should be in cache\n");
 	}
 
 	// We clean the cache
 	starpu_mpi_cache_set(0);
 
 	// We check the data is no longer in the cache
-	ptr = _starpu_mpi_cache_received_data_get(data, 0);
+	in_cache = _starpu_mpi_cache_received_data_get(data);
 	if (rank == 1)
 	{
-		STARPU_ASSERT_MSG(ptr == NULL, "Data should NOT be in cache\n");
+		STARPU_ASSERT_MSG(in_cache == 0, "Data should NOT be in cache\n");
 	}
 
 	ret = starpu_mpi_task_insert(MPI_COMM_WORLD, &mycodelet_r, STARPU_R, data, STARPU_EXECUTE_ON_NODE, 1, 0);
 	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_task_insert");
-	ptr = _starpu_mpi_cache_received_data_get(data, 0);
+	in_cache = _starpu_mpi_cache_received_data_get(data);
 	if (rank == 1)
 	{
-		STARPU_ASSERT_MSG(ptr == NULL, "Data should NOT be in cache\n");
+		STARPU_ASSERT_MSG(in_cache == 0, "Data should NOT be in cache\n");
 	}
 
 	FPRINTF(stderr, "Waiting ...\n");