Browse Source

mpi: call MPI_Unpack() for data which have been received as raw early
data and for whose a datatype is available after the post of the
application recv

Nathalie Furmento 4 years ago
parent
commit
e56bfe49c9

+ 7 - 1
mpi/examples/Makefile.am

@@ -336,7 +336,8 @@ starpu_mpi_EXAMPLES	+=			\
 
 examplebin_PROGRAMS +=				\
 	user_datatype/user_datatype		\
-	user_datatype/user_datatype2
+	user_datatype/user_datatype2		\
+	user_datatype/user_datatype_early
 
 user_datatype_user_datatype_SOURCES =		\
 	user_datatype/user_datatype.c		\
@@ -346,9 +347,14 @@ user_datatype_user_datatype2_SOURCES =		\
 	user_datatype/user_datatype2.c		\
 	user_datatype/my_interface.c
 
+user_datatype_user_datatype_early_SOURCES =	\
+	user_datatype/user_datatype_early.c	\
+	user_datatype/my_interface.c
+
 if !STARPU_SIMGRID
 starpu_mpi_EXAMPLES	+=			\
 	user_datatype/user_datatype2		\
+	user_datatype/user_datatype_early	\
 	user_datatype/user_datatype
 endif
 

+ 92 - 0
mpi/examples/user_datatype/user_datatype_early.c

@@ -0,0 +1,92 @@
+/* StarPU --- Runtime system for heterogeneous multicore architectures.
+ *
+ * Copyright (C) 2015-2020  Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
+ *
+ * StarPU is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or (at
+ * your option) any later version.
+ *
+ * StarPU is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ *
+ * See the GNU Lesser General Public License in COPYING.LGPL for more details.
+ */
+
+#include <starpu_mpi.h>
+#include "my_interface.h"
+
+#define FPRINTF(ofile, fmt, ...) do { if (!getenv("STARPU_SSILENT")) {fprintf(ofile, fmt, ## __VA_ARGS__); }} while(0)
+
+int main(int argc, char **argv)
+{
+	int rank, nodes;
+	int ret=0;
+
+	ret = starpu_mpi_init_conf(&argc, &argv, 1, MPI_COMM_WORLD, NULL);
+	STARPU_CHECK_RETURN_VALUE(ret, "starpu_mpi_init_conf");
+	starpu_mpi_comm_rank(MPI_COMM_WORLD, &rank);
+	starpu_mpi_comm_size(MPI_COMM_WORLD, &nodes);
+
+	if (nodes < 2 || (starpu_cpu_worker_get_count() == 0))
+	{
+		if (rank == 0)
+		{
+			if (nodes < 2)
+				fprintf(stderr, "We need at least 2 processes.\n");
+			else
+				fprintf(stderr, "We need at least 1 CPU.\n");
+		}
+		starpu_mpi_shutdown();
+		return 77;
+	}
+
+	struct starpu_my_data my0 = {.d = 42 , .c = 'n'};
+	struct starpu_my_data my1 = {.d = 11 , .c = 'a'};
+
+	if (rank == 1)
+	{
+		my0.d *= 2;
+		my0.c += 1;
+		my1.d *= 2;
+		my1.c += 1;
+	}
+
+	starpu_data_handle_t handle0;
+	starpu_data_handle_t handle1;
+	starpu_my_data_register(&handle0, STARPU_MAIN_RAM, &my0);
+	starpu_my_data_register(&handle1, STARPU_MAIN_RAM, &my1);
+
+	if (rank == 0)
+	{
+		starpu_mpi_send(handle0, 1, 10, MPI_COMM_WORLD);
+		starpu_mpi_send(handle1, 1, 20, MPI_COMM_WORLD);
+	}
+	else if (rank == 1)
+	{
+		// We want handle0 to be received as early_data and as starpu_mpi_data_register() has not be called, it will be received as raw memory, and then unpacked with MPI_Unpack()
+		starpu_task_insert(&starpu_my_data_display_codelet, STARPU_VALUE, "node1 handle0 init value", strlen("node1 handle0 init value")+1, STARPU_R, handle0, 0);
+		starpu_task_insert(&starpu_my_data_display_codelet, STARPU_VALUE, "node1 handle1 init value", strlen("node1 handle1 init value")+1, STARPU_R, handle1, 0);
+		starpu_mpi_recv(handle1, 0, 20, MPI_COMM_WORLD, NULL);
+		starpu_mpi_recv(handle0, 0, 10, MPI_COMM_WORLD, NULL);
+		starpu_task_insert(&starpu_my_data_display_codelet, STARPU_VALUE, "node1 handle0 received value", strlen("node1 handle0 received value")+1, STARPU_R, handle0, 0);
+		starpu_task_insert(&starpu_my_data_display_codelet, STARPU_VALUE, "node1 handle1 received value", strlen("node1 handle1 received value")+1, STARPU_R, handle1, 0);
+	}
+
+	starpu_mpi_wait_for_all(MPI_COMM_WORLD);
+	starpu_mpi_barrier(MPI_COMM_WORLD);
+
+	starpu_data_unregister(handle0);
+	starpu_data_unregister(handle1);
+
+	if (rank == 1)
+	{
+		STARPU_ASSERT_MSG(my0.d == 42 && my0.c == 'n' && my1.d == 11 && my1.c == 'a', "Incorrect received values");
+	}
+
+	starpu_my_data_shutdown();
+	starpu_mpi_shutdown();
+
+	return 0;
+}

+ 14 - 3
mpi/src/mpi/starpu_mpi_mpi.c

@@ -915,9 +915,20 @@ static void _starpu_mpi_early_data_cb(void* arg)
 		/* Data has been received as a raw memory, it has to be unpacked */
 		struct starpu_data_interface_ops *itf_src = starpu_data_get_interface_ops(args->early_handle);
 		struct starpu_data_interface_ops *itf_dst = starpu_data_get_interface_ops(args->data_handle);
-		STARPU_MPI_ASSERT_MSG(itf_dst->unpack_data, "The data interface does not define an unpack function\n");
-		itf_dst->unpack_data(args->data_handle, STARPU_MAIN_RAM, args->buffer, itf_src->get_size(args->early_handle));
-		args->buffer = NULL;
+		MPI_Datatype datatype = _starpu_mpi_datatype_get_user_defined_datatype(args->data_handle);
+
+		if (datatype)
+		{
+			int position=0;
+			void *ptr = starpu_data_get_local_ptr(args->data_handle);
+			MPI_Unpack(args->buffer, itf_src->get_size(args->early_handle), &position, ptr, 1, datatype, args->req->node_tag.node.comm);
+		}
+		else
+		{
+			STARPU_MPI_ASSERT_MSG(itf_dst->unpack_data, "The data interface does not define an unpack function\n");
+			itf_dst->unpack_data(args->data_handle, STARPU_MAIN_RAM, args->buffer, itf_src->get_size(args->early_handle));
+			args->buffer = NULL;
+		}
 	}
 	else
 	{

+ 21 - 0
mpi/src/starpu_mpi_datatype.c

@@ -208,6 +208,27 @@ static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_I
 	[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,
 };
 
+MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t data_handle)
+{
+	enum starpu_data_interface_id id = starpu_data_get_interface_id(data_handle);
+	if (id < STARPU_MAX_INTERFACE_ID) return 0;
+
+	struct _starpu_mpi_datatype_funcs *table;
+	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_datatype_funcs_table_mutex);
+	HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
+	STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_datatype_funcs_table_mutex);
+	if (table && table->allocate_datatype_func)
+	{
+		MPI_Datatype datatype;
+		int ret = table->allocate_datatype_func(data_handle, &datatype);
+		if (ret == 0)
+			return datatype;
+		else
+			return 0;
+	}
+	return 0;
+}
+
 void _starpu_mpi_datatype_allocate(starpu_data_handle_t data_handle, struct _starpu_mpi_req *req)
 {
 	enum starpu_data_interface_id id = starpu_data_get_interface_id(data_handle);

+ 2 - 0
mpi/src/starpu_mpi_datatype.h

@@ -31,6 +31,8 @@ void _starpu_mpi_datatype_shutdown(void);
 void _starpu_mpi_datatype_allocate(starpu_data_handle_t data_handle, struct _starpu_mpi_req *req);
 void _starpu_mpi_datatype_free(starpu_data_handle_t data_handle, MPI_Datatype *datatype);
 
+MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t data_handle);
+
 #ifdef __cplusplus
 }
 #endif