Kaynağa Gözat

Revert "mpi: simplify test for early data/request and remove un-needed functions"

This reverts commit f9d21d11af3797a6de5a5fff619aef7a0498acf4.

+ fix memory leaks in test
Nathalie Furmento 4 yıl önce
ebeveyn
işleme
a82071ff73

+ 23 - 0
mpi/src/mpi/starpu_mpi_early_data.c

@@ -136,6 +136,29 @@ struct _starpu_mpi_early_data_handle *_starpu_mpi_early_data_find(struct _starpu
 	return early_data_handle;
 }
 
+struct _starpu_mpi_early_data_handle_tag_hashlist *_starpu_mpi_early_data_extract(struct _starpu_mpi_node_tag *node_tag)
+{
+	struct _starpu_mpi_early_data_handle_hashlist *hashlist;
+	struct _starpu_mpi_early_data_handle_tag_hashlist *tag_hashlist = NULL;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_early_data_handle_mutex);
+	_STARPU_MPI_DEBUG(60, "Looking for hashlist for (comm %ld, source %d)\n", (long int)node_tag->node.comm, node_tag->node.rank);
+	HASH_FIND(hh, _starpu_mpi_early_data_handle_hashmap, &node_tag->node, sizeof(struct _starpu_mpi_node), hashlist);
+	if (hashlist)
+	{
+		_STARPU_MPI_DEBUG(60, "Looking for hashlist for (tag %ld)\n", node_tag->data_tag);
+		HASH_FIND(hh, hashlist->datahash, &node_tag->data_tag, sizeof(starpu_mpi_tag_t), tag_hashlist);
+		if (tag_hashlist)
+		{
+			_starpu_mpi_early_data_handle_hashmap_count -= _starpu_mpi_early_data_handle_list_size(&tag_hashlist->list);
+			HASH_DEL(hashlist->datahash, tag_hashlist);
+		}
+	}
+	_STARPU_MPI_DEBUG(60, "Found hashlist %p for (comm %ld, source %d) and (tag %ld)\n", tag_hashlist, (long int)node_tag->node.comm, node_tag->node.rank, node_tag->data_tag);
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_early_data_handle_mutex);
+	return tag_hashlist;
+}
+
 void _starpu_mpi_early_data_add(struct _starpu_mpi_early_data_handle *early_data_handle)
 {
 	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_early_data_handle_mutex);

+ 2 - 0
mpi/src/mpi/starpu_mpi_early_data.h

@@ -63,6 +63,8 @@ struct _starpu_mpi_early_data_handle *_starpu_mpi_early_data_find(struct _starpu
 void _starpu_mpi_early_data_add(struct _starpu_mpi_early_data_handle *early_data_handle);
 void _starpu_mpi_early_data_delete(struct _starpu_mpi_early_data_handle *early_data_handle);
 
+struct _starpu_mpi_early_data_handle_tag_hashlist *_starpu_mpi_early_data_extract(struct _starpu_mpi_node_tag *node_tag);
+
 #ifdef __cplusplus
 }
 #endif

+ 28 - 0
mpi/src/mpi/starpu_mpi_early_request.c

@@ -112,6 +112,34 @@ struct _starpu_mpi_req* _starpu_mpi_early_request_dequeue(starpu_mpi_tag_t data_
 	return found;
 }
 
+struct _starpu_mpi_early_request_tag_hashlist *_starpu_mpi_early_request_extract(starpu_mpi_tag_t data_tag, int source, MPI_Comm comm)
+{
+	struct _starpu_mpi_node_tag node_tag;
+	struct _starpu_mpi_early_request_hashlist *hashlist;
+	struct _starpu_mpi_early_request_tag_hashlist *tag_hashlist = NULL;
+
+	memset(&node_tag, 0, sizeof(struct _starpu_mpi_node_tag));
+	node_tag.node.comm = comm;
+	node_tag.node.rank = source;
+	node_tag.data_tag = data_tag;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_early_request_mutex);
+	_STARPU_MPI_DEBUG(100, "Looking for early_request with comm %ld source %d tag %ld\n", (long int)node_tag.node.comm, node_tag.node.rank, node_tag.data_tag);
+	HASH_FIND(hh, _starpu_mpi_early_request_hash, &node_tag.node, sizeof(struct _starpu_mpi_node), hashlist);
+	if (hashlist)
+	{
+		HASH_FIND(hh, hashlist->datahash, &node_tag.data_tag, sizeof(starpu_mpi_tag_t), tag_hashlist);
+		if (tag_hashlist)
+		{
+			_starpu_mpi_early_request_hash_count -= _starpu_mpi_req_list_size(&tag_hashlist->list);
+			HASH_DEL(hashlist->datahash, tag_hashlist);
+		}
+	}
+	_STARPU_MPI_DEBUG(100, "Found hashlist %p with comm %ld source %d tag %ld\n", hashlist, (long int)node_tag.node.comm, node_tag.node.rank, node_tag.data_tag);
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_early_request_mutex);
+	return tag_hashlist;
+}
+
 void _starpu_mpi_early_request_enqueue(struct _starpu_mpi_req *req)
 {
 	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_early_request_mutex);

+ 2 - 0
mpi/src/mpi/starpu_mpi_early_request.h

@@ -47,6 +47,8 @@ void _starpu_mpi_early_request_check_termination(void);
 void _starpu_mpi_early_request_enqueue(struct _starpu_mpi_req *req);
 struct _starpu_mpi_req* _starpu_mpi_early_request_dequeue(starpu_mpi_tag_t data_tag, int source, MPI_Comm comm);
 
+struct _starpu_mpi_early_request_tag_hashlist *_starpu_mpi_early_request_extract(starpu_mpi_tag_t data_tag, int source, MPI_Comm comm);
+
 #ifdef __cplusplus
 }
 #endif

+ 12 - 2
mpi/tests/early_stuff.c

@@ -42,6 +42,7 @@ void early_data()
 	struct _starpu_mpi_envelope envelope[2];
 	struct _starpu_mpi_node_tag node_tag[2];
 	struct _starpu_mpi_early_data_handle *early;
+	struct _starpu_mpi_early_data_handle_tag_hashlist *hash;
 
 	memset(&node_tag[0], 0, sizeof(struct _starpu_mpi_node_tag));
 	node_tag[0].node.rank = 1;
@@ -62,9 +63,13 @@ void early_data()
 	_starpu_mpi_early_data_add(edh[0]);
 	_starpu_mpi_early_data_add(edh[1]);
 
-	early = _starpu_mpi_early_data_find(&node_tag[0]);
+	hash = _starpu_mpi_early_data_extract(&node_tag[0]);
+	STARPU_ASSERT(_starpu_mpi_early_data_handle_list_size(&hash->list) == 1);
+	early = _starpu_mpi_early_data_handle_list_pop_front(&hash->list);
 	STARPU_ASSERT(early->node_tag.node.comm == node_tag[0].node.comm && early->node_tag.node.rank == node_tag[0].node.rank && early->node_tag.data_tag == node_tag[0].data_tag);
+	STARPU_ASSERT(_starpu_mpi_early_data_handle_list_size(&hash->list) == 0);
 	_starpu_mpi_early_data_delete(early);
+	free(hash);
 
 	early = _starpu_mpi_early_data_find(&node_tag[1]);
 	STARPU_ASSERT(early->node_tag.node.comm == node_tag[1].node.comm && early->node_tag.node.rank == node_tag[1].node.rank && early->node_tag.data_tag == node_tag[1].data_tag);
@@ -75,6 +80,7 @@ void early_request()
 {
 	struct _starpu_mpi_req req[2];
 	struct _starpu_mpi_req *early;
+	struct _starpu_mpi_early_request_tag_hashlist *hash;
 
 	memset(&req[0].node_tag, 0, sizeof(struct _starpu_mpi_node_tag));
 	req[0].node_tag.node.rank = 1;
@@ -92,8 +98,12 @@ void early_request()
 	early = _starpu_mpi_early_request_dequeue(req[0].node_tag.data_tag, req[0].node_tag.node.rank, req[0].node_tag.node.comm);
 	STARPU_ASSERT(early->node_tag.data_tag == req[0].node_tag.data_tag && early->node_tag.node.rank == req[0].node_tag.node.rank && early->node_tag.node.comm == req[0].node_tag.node.comm);
 
-	early = _starpu_mpi_early_request_dequeue(req[1].node_tag.data_tag, req[1].node_tag.node.rank, req[1].node_tag.node.comm);
+	hash = _starpu_mpi_early_request_extract(req[1].node_tag.data_tag, req[1].node_tag.node.rank, req[1].node_tag.node.comm);
+	STARPU_ASSERT(_starpu_mpi_req_list_size(&hash->list) == 1);
+	early = _starpu_mpi_req_list_pop_front(&hash->list);
+	STARPU_ASSERT(_starpu_mpi_req_list_size(&hash->list) == 0);
 	STARPU_ASSERT(early->node_tag.data_tag == req[1].node_tag.data_tag && early->node_tag.node.rank == req[1].node_tag.node.rank && early->node_tag.node.comm == req[1].node_tag.node.comm);
+	free(hash);
 }
 
 int main(int argc, char **argv)