Browse Source

mpi: Fix data_movement interface

Samuel Thibault 4 years ago
parent
commit
45572515f0

+ 29 - 13
mpi/src/load_balancer/policy/data_movements_interface.c

@@ -69,19 +69,15 @@ int data_movements_get_size_tables(starpu_data_handle_t handle)
 	return dm_interface->size;
 	return dm_interface->size;
 }
 }
 
 
-int data_movements_reallocate_tables(starpu_data_handle_t handle, int size)
-{
-	struct data_movements_interface *dm_interface =
-		(struct data_movements_interface *) starpu_data_get_interface_on_node(handle, STARPU_MAIN_RAM);
+static void data_movements_free_data_on_node(void *data_interface, unsigned node);
+static starpu_ssize_t data_movements_allocate_data_on_node(void *data_interface, unsigned node);
 
 
-	if (dm_interface->size)
+int data_movements_reallocate_tables_interface(struct data_movements_interface *dm_interface, unsigned node, int size)
+{
+	if (dm_interface->tags)
 	{
 	{
-		STARPU_ASSERT(dm_interface->tags);
-		free(dm_interface->tags);
+		data_movements_free_data_on_node(dm_interface, node);
 		dm_interface->tags = NULL;
 		dm_interface->tags = NULL;
-
-		STARPU_ASSERT(dm_interface->ranks);
-		free(dm_interface->ranks);
 		dm_interface->ranks = NULL;
 		dm_interface->ranks = NULL;
 	}
 	}
 	else
 	else
@@ -94,13 +90,20 @@ int data_movements_reallocate_tables(starpu_data_handle_t handle, int size)
 
 
 	if (dm_interface->size)
 	if (dm_interface->size)
 	{
 	{
-		_STARPU_MPI_MALLOC(dm_interface->tags, size*sizeof(*dm_interface->tags));
-		_STARPU_MPI_MALLOC(dm_interface->ranks, size*sizeof(*dm_interface->ranks));
+		starpu_ssize_t resize = data_movements_allocate_data_on_node(dm_interface, node);
+		STARPU_ASSERT(resize > 0);
 	}
 	}
 
 
 	return 0 ;
 	return 0 ;
 }
 }
 
 
+int data_movements_reallocate_tables(starpu_data_handle_t handle, unsigned node, int size)
+{
+	struct data_movements_interface *dm_interface =
+		(struct data_movements_interface *) starpu_data_get_interface_on_node(handle, node);
+	return data_movements_reallocate_tables_interface (dm_interface, node, size);
+}
+
 static void data_movements_register_data_handle(starpu_data_handle_t handle, unsigned home_node, void *data_interface)
 static void data_movements_register_data_handle(starpu_data_handle_t handle, unsigned home_node, void *data_interface)
 {
 {
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) data_interface;
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) data_interface;
@@ -129,6 +132,13 @@ static starpu_ssize_t data_movements_allocate_data_on_node(void *data_interface,
 {
 {
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) data_interface;
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) data_interface;
 
 
+	if (!dm_interface->size)
+	{
+		dm_interface->tags = NULL;
+		dm_interface->ranks = NULL;
+		return 0;
+	}
+
 	starpu_mpi_tag_t *addr_tags;
 	starpu_mpi_tag_t *addr_tags;
 	int *addr_ranks;
 	int *addr_ranks;
 	starpu_ssize_t requested_memory_tags = dm_interface->size * sizeof(starpu_mpi_tag_t);
 	starpu_ssize_t requested_memory_tags = dm_interface->size * sizeof(starpu_mpi_tag_t);
@@ -156,6 +166,10 @@ fail_tags:
 static void data_movements_free_data_on_node(void *data_interface, unsigned node)
 static void data_movements_free_data_on_node(void *data_interface, unsigned node)
 {
 {
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) data_interface;
 	struct data_movements_interface *dm_interface = (struct data_movements_interface *) data_interface;
+
+	if (! dm_interface->tags)
+		return;
+
 	starpu_ssize_t requested_memory_tags = dm_interface->size * sizeof(starpu_mpi_tag_t);
 	starpu_ssize_t requested_memory_tags = dm_interface->size * sizeof(starpu_mpi_tag_t);
 	starpu_ssize_t requested_memory_ranks = dm_interface->size * sizeof(int);
 	starpu_ssize_t requested_memory_ranks = dm_interface->size * sizeof(int);
 
 
@@ -213,7 +227,7 @@ static int data_movements_peek_data(starpu_data_handle_t handle, unsigned node,
 	memcpy(&size, data, sizeof(int));
 	memcpy(&size, data, sizeof(int));
 	STARPU_ASSERT(count == (2 * size * sizeof(int)) + sizeof(int));
 	STARPU_ASSERT(count == (2 * size * sizeof(int)) + sizeof(int));
 
 
-	data_movements_reallocate_tables(handle, size);
+	data_movements_reallocate_tables(handle, node, size);
 
 
 	if (dm_interface->size)
 	if (dm_interface->size)
 	{
 	{
@@ -240,6 +254,8 @@ static int copy_any_to_any(void *src_interface, unsigned src_node,
 	struct data_movements_interface *dst_data_movements = dst_interface;
 	struct data_movements_interface *dst_data_movements = dst_interface;
 	int ret = 0;
 	int ret = 0;
 
 
+	data_movements_reallocate_tables_interface(dst_data_movements, dst_node, src_data_movements->size);
+
 	if (starpu_interface_copy((uintptr_t) src_data_movements->tags, 0, src_node,
 	if (starpu_interface_copy((uintptr_t) src_data_movements->tags, 0, src_node,
 				    (uintptr_t) dst_data_movements->tags, 0, dst_node,
 				    (uintptr_t) dst_data_movements->tags, 0, dst_node,
 				     src_data_movements->size*sizeof(starpu_mpi_tag_t),
 				     src_data_movements->size*sizeof(starpu_mpi_tag_t),

+ 1 - 1
mpi/src/load_balancer/policy/data_movements_interface.h

@@ -36,7 +36,7 @@ void data_movements_data_register(starpu_data_handle_t *handle, unsigned home_no
 
 
 starpu_mpi_tag_t **data_movements_get_ref_tags_table(starpu_data_handle_t handle);
 starpu_mpi_tag_t **data_movements_get_ref_tags_table(starpu_data_handle_t handle);
 int **data_movements_get_ref_ranks_table(starpu_data_handle_t handle);
 int **data_movements_get_ref_ranks_table(starpu_data_handle_t handle);
-int data_movements_reallocate_tables(starpu_data_handle_t handle, int size);
+int data_movements_reallocate_tables(starpu_data_handle_t handle, unsigned node, int size);
 
 
 starpu_mpi_tag_t *data_movements_get_tags_table(starpu_data_handle_t handle);
 starpu_mpi_tag_t *data_movements_get_tags_table(starpu_data_handle_t handle);
 int *data_movements_get_ranks_table(starpu_data_handle_t handle);
 int *data_movements_get_ranks_table(starpu_data_handle_t handle);

+ 12 - 6
mpi/src/load_balancer/policy/load_heat_propagation.c

@@ -114,6 +114,7 @@ static void balance(starpu_data_handle_t load_data_cpy)
 		}
 		}
 	}
 	}
 
 
+	starpu_data_acquire_on_node(data_movements_handles[my_rank], STARPU_MAIN_RAM, STARPU_RW);
 	/* We found it */
 	/* We found it */
 	if (less_loaded >= 0)
 	if (less_loaded >= 0)
 	{
 	{
@@ -128,7 +129,7 @@ static void balance(starpu_data_handle_t load_data_cpy)
 			int nhandles = 0;
 			int nhandles = 0;
 			user_itf->get_data_unit_to_migrate(&handles, &nhandles, less_loaded);
 			user_itf->get_data_unit_to_migrate(&handles, &nhandles, less_loaded);
 
 
-			data_movements_reallocate_tables(data_movements_handles[my_rank], nhandles);
+			data_movements_reallocate_tables(data_movements_handles[my_rank], STARPU_MAIN_RAM, nhandles);
 
 
 			if (nhandles)
 			if (nhandles)
 			{
 			{
@@ -145,10 +146,11 @@ static void balance(starpu_data_handle_t load_data_cpy)
 			}
 			}
 		}
 		}
 		else
 		else
-			data_movements_reallocate_tables(data_movements_handles[my_rank], 0);
+			data_movements_reallocate_tables(data_movements_handles[my_rank], STARPU_MAIN_RAM, 0);
 	}
 	}
 	else
 	else
-		data_movements_reallocate_tables(data_movements_handles[my_rank], 0);
+		data_movements_reallocate_tables(data_movements_handles[my_rank], STARPU_MAIN_RAM, 0);
+	starpu_data_release_on_node(data_movements_handles[my_rank], STARPU_MAIN_RAM);
 }
 }
 
 
 static void exchange_load_data_infos(starpu_data_handle_t load_data_cpy)
 static void exchange_load_data_infos(starpu_data_handle_t load_data_cpy)
@@ -559,10 +561,11 @@ static int deinit_heat()
 
 
 	unsigned int ndata_to_move_back = HASH_COUNT(mdh);
 	unsigned int ndata_to_move_back = HASH_COUNT(mdh);
 
 
+	starpu_data_acquire_on_node(data_movements_handles[my_rank], STARPU_MAIN_RAM, STARPU_RW);
 	if (ndata_to_move_back)
 	if (ndata_to_move_back)
 	{
 	{
 		_STARPU_DEBUG("Move back %u data on node %d ..\n", ndata_to_move_back, my_rank);
 		_STARPU_DEBUG("Move back %u data on node %d ..\n", ndata_to_move_back, my_rank);
-		data_movements_reallocate_tables(data_movements_handles[my_rank], ndata_to_move_back);
+		data_movements_reallocate_tables(data_movements_handles[my_rank], STARPU_MAIN_RAM, ndata_to_move_back);
 
 
 		starpu_mpi_tag_t *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
 		starpu_mpi_tag_t *tags = data_movements_get_tags_table(data_movements_handles[my_rank]);
 		int *ranks = data_movements_get_ranks_table(data_movements_handles[my_rank]);
 		int *ranks = data_movements_get_ranks_table(data_movements_handles[my_rank]);
@@ -577,7 +580,8 @@ static int deinit_heat()
 		}
 		}
 	}
 	}
 	else
 	else
-		data_movements_reallocate_tables(data_movements_handles[my_rank], 0);
+		data_movements_reallocate_tables(data_movements_handles[my_rank], STARPU_MAIN_RAM, 0);
+	starpu_data_release_on_node(data_movements_handles[my_rank], STARPU_MAIN_RAM);
 
 
 	exchange_data_movements_infos();
 	exchange_data_movements_infos();
 	move_back_data();
 	move_back_data();
@@ -612,7 +616,9 @@ static int deinit_heat()
 	for (i = 0; i < world_size; i++)
 	for (i = 0; i < world_size; i++)
 	{
 	{
 		starpu_mpi_cache_flush(MPI_COMM_WORLD, data_movements_handles[i]);
 		starpu_mpi_cache_flush(MPI_COMM_WORLD, data_movements_handles[i]);
-		data_movements_reallocate_tables(data_movements_handles[i], 0);
+		starpu_data_acquire_on_node(data_movements_handles[i], STARPU_MAIN_RAM, STARPU_W);
+		data_movements_reallocate_tables(data_movements_handles[i], STARPU_MAIN_RAM, 0);
+		starpu_data_release_on_node(data_movements_handles[i], STARPU_MAIN_RAM);
 		starpu_data_unregister(data_movements_handles[i]);
 		starpu_data_unregister(data_movements_handles[i]);
 	}
 	}
 	free(data_movements_handles);
 	free(data_movements_handles);