Browse Source

Always use starpu_mpi_tag_t for the starpu mpi tags

This notably fixes starpu_mpi_get_data_on_node with tags beyond 32bits.
Samuel Thibault 4 years ago
parent
commit
807642e4ad

+ 1 - 1
mpi/examples/matrix_mult/mm.c

@@ -128,7 +128,7 @@ static void register_matrices()
 	int mr = (comm_rank == 0) ? STARPU_MAIN_RAM : -1;
 
 	/* mpi tag used for the block */
-	int tag = 0;
+	starpu_mpi_tag_t tag = 0;
 
 	int b_row,b_col;
 

+ 9 - 9
mpi/examples/mpi_lu/pxlu.c

@@ -90,7 +90,7 @@ static struct starpu_task *create_task(starpu_tag_t id)
 
 /* Send handle to every node appearing in the mask, and unlock tag once the
  * transfers are done. */
-static void send_data_to_mask(starpu_data_handle_t handle, int *rank_mask, int mpi_tag, starpu_tag_t tag)
+static void send_data_to_mask(starpu_data_handle_t handle, int *rank_mask, starpu_mpi_tag_t mpi_tag, starpu_tag_t tag)
 {
 	unsigned cnt = 0;
 
@@ -134,7 +134,7 @@ static void send_data_to_mask(starpu_data_handle_t handle, int *rank_mask, int m
 struct recv_when_done_callback_arg
 {
 	int source;
-	int mpi_tag;
+	starpu_mpi_tag_t mpi_tag;
 	starpu_data_handle_t handle;
 	starpu_tag_t unlocked_tag;
 };
@@ -150,7 +150,7 @@ static void callback_receive_when_done(void *_arg)
 }
 
 static void receive_when_deps_are_done(unsigned ndeps, starpu_tag_t *deps_tags,
-				int source, int mpi_tag,
+				int source, starpu_mpi_tag_t mpi_tag,
 				starpu_data_handle_t handle,
 				starpu_tag_t partial_tag,
 				starpu_tag_t unlocked_tag)
@@ -218,7 +218,7 @@ static void create_task_11_recv(unsigned k)
 #else
 	starpu_data_handle_t block_handle = STARPU_PLU(get_tmp_11_block_handle)(k);
 #endif
-	int mpi_tag = MPI_TAG11(k);
+	starpu_mpi_tag_t mpi_tag = MPI_TAG11(k);
 	starpu_tag_t partial_tag = TAG11_SAVE_PARTIAL(k);
 	starpu_tag_t unlocked_tag = TAG11_SAVE(k);
 
@@ -260,7 +260,7 @@ static void callback_task_11_real(void *_arg)
 	/* Send the block to those nodes */
 	starpu_data_handle_t block_handle = STARPU_PLU(get_block_handle)(k, k);
 	starpu_tag_t tag = TAG11_SAVE(k);
-	int mpi_tag = MPI_TAG11(k);
+	starpu_mpi_tag_t mpi_tag = MPI_TAG11(k);
 	send_data_to_mask(block_handle, rank_mask, mpi_tag, tag);
 
 	free(arg);
@@ -380,7 +380,7 @@ static void create_task_12_recv(unsigned k, unsigned j)
 #else
 	starpu_data_handle_t block_handle = STARPU_PLU(get_tmp_12_block_handle)(j,k);
 #endif
-	int mpi_tag = MPI_TAG12(k, j);
+	starpu_mpi_tag_t mpi_tag = MPI_TAG12(k, j);
 	starpu_tag_t partial_tag = TAG12_SAVE_PARTIAL(k, j);
 	starpu_tag_t unlocked_tag = TAG12_SAVE(k, j);
 
@@ -415,7 +415,7 @@ static void callback_task_12_real(void *_arg)
 	/* Send the block to those nodes */
 	starpu_data_handle_t block_handle = STARPU_PLU(get_block_handle)(k, j);
 	starpu_tag_t tag = TAG12_SAVE(k, j);
-	int mpi_tag = MPI_TAG12(k, j);
+	starpu_mpi_tag_t mpi_tag = MPI_TAG12(k, j);
 	send_data_to_mask(block_handle, rank_mask, mpi_tag, tag);
 
 	free(arg);
@@ -564,7 +564,7 @@ static void create_task_21_recv(unsigned k, unsigned i)
 #else
 	starpu_data_handle_t block_handle = STARPU_PLU(get_tmp_21_block_handle)(i, k);
 #endif
-	int mpi_tag = MPI_TAG21(k, i);
+	starpu_mpi_tag_t mpi_tag = MPI_TAG21(k, i);
 	starpu_tag_t partial_tag = TAG21_SAVE_PARTIAL(k, i);
 	starpu_tag_t unlocked_tag = TAG21_SAVE(k, i);
 
@@ -600,7 +600,7 @@ static void callback_task_21_real(void *_arg)
 	/* Send the block to those nodes */
 	starpu_data_handle_t block_handle = STARPU_PLU(get_block_handle)(i, k);
 	starpu_tag_t tag = TAG21_SAVE(k, i);
-	int mpi_tag = MPI_TAG21(k, i);
+	starpu_mpi_tag_t mpi_tag = MPI_TAG21(k, i);
 	send_data_to_mask(block_handle, rank_mask, mpi_tag, tag);
 
 	free(arg);

+ 16 - 15
mpi/src/load_balancer/policy/data_movements_interface.c

@@ -23,7 +23,7 @@
 
 #if defined(STARPU_USE_MPI_MPI)
 
-int **data_movements_get_ref_tags_table(starpu_data_handle_t handle)
+starpu_mpi_tag_t **data_movements_get_ref_tags_table(starpu_data_handle_t handle)
 {
 	struct data_movements_interface *dm_interface =
 		(struct data_movements_interface *) starpu_data_get_interface_on_node(handle, STARPU_MAIN_RAM);
@@ -45,7 +45,7 @@ int **data_movements_get_ref_ranks_table(starpu_data_handle_t handle)
 		return NULL;
 }
 
-int *data_movements_get_tags_table(starpu_data_handle_t handle)
+starpu_mpi_tag_t *data_movements_get_tags_table(starpu_data_handle_t handle)
 {
 	struct data_movements_interface *dm_interface =
 		(struct data_movements_interface *) starpu_data_get_interface_on_node(handle, STARPU_MAIN_RAM);
@@ -94,8 +94,8 @@ int data_movements_reallocate_tables(starpu_data_handle_t handle, int size)
 
 	if (dm_interface->size)
 	{
-		_STARPU_MPI_MALLOC(dm_interface->tags, size*sizeof(int));
-		_STARPU_MPI_MALLOC(dm_interface->ranks, size*sizeof(int));
+		_STARPU_MPI_MALLOC(dm_interface->tags, size*sizeof(*dm_interface->tags));
+		_STARPU_MPI_MALLOC(dm_interface->ranks, size*sizeof(*dm_interface->ranks));
 	}
 
 	return 0 ;
@@ -129,7 +129,7 @@ static starpu_ssize_t data_movements_allocate_data_on_node(void *data_interface,
 {
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) data_interface;
 
-	int *addr_tags;
+	starpu_mpi_tag_t *addr_tags;
 	int *addr_ranks;
 	starpu_ssize_t requested_memory = dm_interface->size * sizeof(int);
 
@@ -155,10 +155,11 @@ fail_tags:
 static void data_movements_free_data_on_node(void *data_interface, unsigned node)
 {
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) data_interface;
-	starpu_ssize_t requested_memory = dm_interface->size * sizeof(int);
+	starpu_ssize_t requested_memory_tags = dm_interface->size * sizeof(starpu_mpi_tag_t);
+	starpu_ssize_t requested_memory_ranks = dm_interface->size * sizeof(int);
 
-	starpu_free_on_node(node, (uintptr_t) dm_interface->tags, requested_memory);
-	starpu_free_on_node(node, (uintptr_t) dm_interface->ranks, requested_memory);
+	starpu_free_on_node(node, (uintptr_t) dm_interface->tags, requested_memory_tags);
+	starpu_free_on_node(node, (uintptr_t) dm_interface->ranks, requested_memory_ranks);
 }
 
 static size_t data_movements_get_size(starpu_data_handle_t handle)
@@ -166,7 +167,7 @@ static size_t data_movements_get_size(starpu_data_handle_t handle)
 	size_t size;
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) starpu_data_get_interface_on_node(handle, STARPU_MAIN_RAM);
 
-	size = (dm_interface->size * 2 * sizeof(int)) + sizeof(int);
+	size = (dm_interface->size * sizeof(starpu_mpi_tag_t)) + (dm_interface->size * sizeof(int)) + sizeof(int);
 	return size;
 }
 
@@ -192,8 +193,8 @@ static int data_movements_pack_data(starpu_data_handle_t handle, unsigned node,
 		memcpy(data, &dm_interface->size, sizeof(int));
 		if (dm_interface->size)
 		{
-			memcpy(data+sizeof(int), dm_interface->tags, (dm_interface->size*sizeof(int)));
-			memcpy(data+sizeof(int)+(dm_interface->size*sizeof(int)), dm_interface->ranks, dm_interface->size*sizeof(int));
+			memcpy(data+sizeof(int), dm_interface->tags, (dm_interface->size*sizeof(starpu_mpi_tag_t)));
+			memcpy(data+sizeof(int)+(dm_interface->size*sizeof(starpu_mpi_tag_t)), dm_interface->ranks, dm_interface->size*sizeof(int));
 		}
 	}
 
@@ -216,8 +217,8 @@ static int data_movements_unpack_data(starpu_data_handle_t handle, unsigned node
 
 	if (dm_interface->size)
 	{
-		memcpy(dm_interface->tags, data+sizeof(int), dm_interface->size*sizeof(int));
-		memcpy(dm_interface->ranks, data+sizeof(int)+(dm_interface->size*sizeof(int)), dm_interface->size*sizeof(int));
+		memcpy(dm_interface->tags, data+sizeof(int), dm_interface->size*sizeof(starpu_mpi_tag_t));
+		memcpy(dm_interface->ranks, data+sizeof(int)+(dm_interface->size*sizeof(starpu_mpi_tag_t)), dm_interface->size*sizeof(int));
 	}
 
     return 0;
@@ -233,7 +234,7 @@ static int copy_any_to_any(void *src_interface, unsigned src_node,
 
 	if (starpu_interface_copy((uintptr_t) src_data_movements->tags, 0, src_node,
 				    (uintptr_t) dst_data_movements->tags, 0, dst_node,
-				     src_data_movements->size*sizeof(int),
+				     src_data_movements->size*sizeof(starpu_mpi_tag_t),
 				     async_data))
 		ret = -EAGAIN;
 	if (starpu_interface_copy((uintptr_t) src_data_movements->ranks, 0, src_node,
@@ -265,7 +266,7 @@ static struct starpu_data_interface_ops interface_data_movements_ops =
 	.describe = NULL
 };
 
-void data_movements_data_register(starpu_data_handle_t *handleptr, unsigned home_node, int *ranks, int *tags, int size)
+void data_movements_data_register(starpu_data_handle_t *handleptr, unsigned home_node, int *ranks, starpu_mpi_tag_t *tags, int size)
 {
 	struct data_movements_interface data_movements =
 	{

+ 4 - 4
mpi/src/load_balancer/policy/data_movements_interface.h

@@ -25,20 +25,20 @@
 struct data_movements_interface
 {
 	/** Data tags table */
-	int *tags;
+	starpu_mpi_tag_t *tags;
 	/** Ranks table (where to move the corresponding data) */
 	int *ranks;
 	/** Size of the tables */
 	int size;
 };
 
-void data_movements_data_register(starpu_data_handle_t *handle, unsigned home_node, int *ranks, int *tags, int size);
+void data_movements_data_register(starpu_data_handle_t *handle, unsigned home_node, int *ranks, starpu_mpi_tag_t *tags, int size);
 
-int **data_movements_get_ref_tags_table(starpu_data_handle_t handle);
+starpu_mpi_tag_t **data_movements_get_ref_tags_table(starpu_data_handle_t handle);
 int **data_movements_get_ref_ranks_table(starpu_data_handle_t handle);
 int data_movements_reallocate_tables(starpu_data_handle_t handle, int size);
 
-int *data_movements_get_tags_table(starpu_data_handle_t handle);
+starpu_mpi_tag_t *data_movements_get_tags_table(starpu_data_handle_t handle);
 int *data_movements_get_ranks_table(starpu_data_handle_t handle);
 int data_movements_get_size_tables(starpu_data_handle_t handle);
 

+ 4 - 4
mpi/src/load_balancer/policy/load_heat_propagation.c

@@ -27,12 +27,12 @@
 
 #if defined(STARPU_USE_MPI_MPI)
 
-static int TAG_LOAD(int n)
+static starpu_mpi_tag_t TAG_LOAD(int n)
 {
 	return (n+1) << 24;
 }
 
-static int TAG_MOV(int n)
+static starpu_mpi_tag_t TAG_MOV(int n)
 {
 	return (n+1) << 20;
 }
@@ -132,7 +132,7 @@ static void balance(starpu_data_handle_t load_data_cpy)
 
 			if (nhandles)
 			{
-				int *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
+				starpu_mpi_tag_t *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
 				int *ranks = data_movements_get_ranks_table(data_movements_handles[my_rank]);
 
 				for (n = 0; n < nhandles; n++)
@@ -564,7 +564,7 @@ static int deinit_heat()
 		_STARPU_DEBUG("Move back %u data on node %d ..\n", ndata_to_move_back, my_rank);
 		data_movements_reallocate_tables(data_movements_handles[my_rank], ndata_to_move_back);
 
-		int *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
+		starpu_mpi_tag_t *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
 		int *ranks = data_movements_get_ranks_table(data_movements_handles[my_rank]);
 
 		int n = 0;

+ 4 - 2
mpi/src/starpu_mpi.c

@@ -325,7 +325,8 @@ starpu_mpi_tag_t starpu_mpi_data_get_tag(starpu_data_handle_t data)
 
 void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg)
 {
-	int me, rank, tag;
+	int me, rank;
+	starpu_mpi_tag_t tag;
 
 	rank = starpu_mpi_data_get_rank(data_handle);
 	if (rank == -1)
@@ -367,7 +368,8 @@ void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t da
 
 void starpu_mpi_get_data_on_node(MPI_Comm comm, starpu_data_handle_t data_handle, int node)
 {
-	int me, rank, tag;
+	int me, rank;
+	starpu_mpi_tag_t tag;
 
 	rank = starpu_mpi_data_get_rank(data_handle);
 	if (rank == -1)

+ 1 - 1
mpi/tests/ring.c

@@ -99,7 +99,7 @@ int main(int argc, char **argv)
 
 	for (loop = 0; loop < nloops; loop++)
 	{
-		int tag = loop*size + rank;
+		starpu_mpi_tag_t tag = loop*size + rank;
 
 		if (loop == 0 && rank == 0)
 		{

+ 1 - 1
mpi/tests/ring_async.c

@@ -99,7 +99,7 @@ int main(int argc, char **argv)
 
 	for (loop = 0; loop < nloops; loop++)
 	{
-		int tag = loop*size + rank;
+		starpu_mpi_tag_t tag = loop*size + rank;
 
 		if (loop == 0 && rank == 0)
 		{

+ 1 - 1
mpi/tests/ring_async_implicit.c

@@ -92,7 +92,7 @@ int main(int argc, char **argv)
 
 	for (loop = 0; loop < nloops; loop++)
 	{
-		int tag = loop*size + rank;
+		starpu_mpi_tag_t tag = loop*size + rank;
 
 		if (loop == 0 && rank == 0)
 		{

+ 1 - 1
mpi/tests/ring_sync.c

@@ -99,7 +99,7 @@ int main(int argc, char **argv)
 
 	for (loop = 0; loop < nloops; loop++)
 	{
-		int tag = loop*size + rank;
+		starpu_mpi_tag_t tag = loop*size + rank;
 
 		if (loop == 0 && rank == 0)
 		{

+ 1 - 1
mpi/tests/ring_sync_detached.c

@@ -112,7 +112,7 @@ int main(int argc, char **argv)
 
 	for (loop = 0; loop < nloops; loop++)
 	{
-		int tag = loop*size + rank;
+		starpu_mpi_tag_t tag = loop*size + rank;
 
 		if (loop == 0 && rank == 0)
 		{

+ 3 - 3
mpi/tests/user_defined_datatype.c

@@ -26,9 +26,9 @@
 #  define ELEMENTS 1000
 #endif
 
-typedef void (*test_func)(starpu_data_handle_t *, int, int, int);
+typedef void (*test_func)(starpu_data_handle_t *, int, int, starpu_mpi_tag_t);
 
-void test_handle_irecv_isend_detached(starpu_data_handle_t *handles, int nb_handles, int rank, int tag)
+void test_handle_irecv_isend_detached(starpu_data_handle_t *handles, int nb_handles, int rank, starpu_mpi_tag_t tag)
 {
 	int i;
 	(void)rank;
@@ -42,7 +42,7 @@ void test_handle_irecv_isend_detached(starpu_data_handle_t *handles, int nb_hand
 		starpu_mpi_get_data_on_node_detached(MPI_COMM_WORLD, handles[i], 0, NULL, NULL);
 }
 
-void test_handle_recv_send(starpu_data_handle_t *handles, int nb_handles, int rank, int tag)
+void test_handle_recv_send(starpu_data_handle_t *handles, int nb_handles, int rank, starpu_mpi_tag_t tag)
 {
 	int i;