소스 검색

Add starpu_mpi_datatype_node_register and starpu_mpi_interface_datatype_node_register

which will be needed for MPI/NUMA/GPUDirect
Samuel Thibault 4 년 전
부모
커밋
5ce089e2d4

+ 3 - 0
ChangeLog

@@ -47,6 +47,9 @@ New features:
   * Add an experimental python interface (not actually parallel yet)
   * Add task submission file+line in traces.
   * Add papi- and nvml-based energy measurement.
+  * Add starpu_mpi_datatype_node_register and
+    starpu_mpi_interface_datatype_node_register which will be needed for
+    MPI/NUMA/GPUDirect.
 
 Small changes:
   * Add a synthetic energy efficiency testcase.

+ 6 - 6
mpi/examples/user_datatype/my_interface.c

@@ -45,7 +45,7 @@ void starpu_my_data_compare_codelet_cpu(void *descr[], void *_args)
 	*compare = (d0 == d1 && c0 == c1);
 }
 
-void _starpu_my_data_datatype_allocate(MPI_Datatype *mpi_datatype)
+void _starpu_my_data_datatype_allocate(unsigned node, MPI_Datatype *mpi_datatype)
 {
 	int ret;
 	int blocklengths[2] = {1, 1};
@@ -68,10 +68,10 @@ void _starpu_my_data_datatype_allocate(MPI_Datatype *mpi_datatype)
 	free(myinterface);
 }
 
-int starpu_my_data_datatype_allocate(starpu_data_handle_t handle, MPI_Datatype *mpi_datatype)
+int starpu_my_data_datatype_allocate(starpu_data_handle_t handle, unsigned node, MPI_Datatype *mpi_datatype)
 {
 	(void)handle;
-	_starpu_my_data_datatype_allocate(mpi_datatype);
+	_starpu_my_data_datatype_allocate(node, mpi_datatype);
 	return 0;
 }
 
@@ -80,7 +80,7 @@ void starpu_my_data_datatype_free(MPI_Datatype *mpi_datatype)
 	MPI_Type_free(mpi_datatype);
 }
 
-int starpu_my_data2_datatype_allocate(starpu_data_handle_t handle, MPI_Datatype *mpi_datatype)
+int starpu_my_data2_datatype_allocate(starpu_data_handle_t handle, unsigned node, MPI_Datatype *mpi_datatype)
 {
 	(void)handle;
 	(void)mpi_datatype;
@@ -315,7 +315,7 @@ void starpu_my_data_register(starpu_data_handle_t *handleptr, unsigned home_node
 	if (interface_data_ops.interfaceid == STARPU_UNKNOWN_INTERFACE_ID)
 	{
 		interface_data_ops.interfaceid = starpu_data_interface_get_next_id();
-		starpu_mpi_interface_datatype_register(interface_data_ops.interfaceid, starpu_my_data_datatype_allocate, starpu_my_data_datatype_free);
+		starpu_mpi_interface_datatype_node_register(interface_data_ops.interfaceid, starpu_my_data_datatype_allocate, starpu_my_data_datatype_free);
 	}
 
 	struct starpu_my_data_interface data =
@@ -357,7 +357,7 @@ void starpu_my_data2_register(starpu_data_handle_t *handleptr, unsigned home_nod
 	if (interface_data2_ops.interfaceid == STARPU_UNKNOWN_INTERFACE_ID)
 	{
 		interface_data2_ops.interfaceid = starpu_data_interface_get_next_id();
-		starpu_mpi_interface_datatype_register(interface_data2_ops.interfaceid, starpu_my_data2_datatype_allocate, starpu_my_data2_datatype_free);
+		starpu_mpi_interface_datatype_node_register(interface_data2_ops.interfaceid, starpu_my_data2_datatype_allocate, starpu_my_data2_datatype_free);
 	}
 
 	struct starpu_my_data_interface data =

+ 3 - 3
mpi/examples/user_datatype/my_interface.h

@@ -47,10 +47,10 @@ int starpu_my_data_interface_get_int(void *interface);
 #define STARPU_MY_DATA_GET_CHAR(interface)	starpu_my_data_interface_get_char(interface)
 #define STARPU_MY_DATA_GET_INT(interface)	starpu_my_data_interface_get_int(interface)
 
-void _starpu_my_data_datatype_allocate(MPI_Datatype *mpi_datatype);
-int starpu_my_data_datatype_allocate(starpu_data_handle_t handle, MPI_Datatype *mpi_datatype);
+void _starpu_my_data_datatype_allocate(unsigned node, MPI_Datatype *mpi_datatype);
+int starpu_my_data_datatype_allocate(starpu_data_handle_t handle, unsigned node, MPI_Datatype *mpi_datatype);
 void starpu_my_data_datatype_free(MPI_Datatype *mpi_datatype);
-int starpu_my_data2_datatype_allocate(starpu_data_handle_t handle, MPI_Datatype *mpi_datatype);
+int starpu_my_data2_datatype_allocate(starpu_data_handle_t handle, unsigned node, MPI_Datatype *mpi_datatype);
 void starpu_my_data2_datatype_free(MPI_Datatype *mpi_datatype);
 
 void starpu_my_data_display_codelet_cpu(void *descr[], void *_args);

+ 2 - 2
mpi/examples/user_datatype/user_datatype.c

@@ -62,7 +62,7 @@ int main(int argc, char **argv)
 	if (rank == 0)
 	{
 		MPI_Datatype mpi_datatype;
-		_starpu_my_data_datatype_allocate(&mpi_datatype);
+		_starpu_my_data_datatype_allocate(STARPU_MAIN_RAM, &mpi_datatype);
 		MPI_Send(&my0, 1, mpi_datatype, 1, 42, MPI_COMM_WORLD);
 		starpu_my_data_datatype_free(&mpi_datatype);
 	}
@@ -71,7 +71,7 @@ int main(int argc, char **argv)
 		MPI_Datatype mpi_datatype;
 		MPI_Status status;
 		struct starpu_my_data myx;
-		_starpu_my_data_datatype_allocate(&mpi_datatype);
+		_starpu_my_data_datatype_allocate(STARPU_MAIN_RAM, &mpi_datatype);
 		MPI_Recv(&myx, 1, mpi_datatype, 0, 42, MPI_COMM_WORLD, &status);
 		FPRINTF(stderr, "[mpi] Received value: '%c' %d\n", myx.c, myx.d);
 		starpu_my_data_datatype_free(&mpi_datatype);

+ 21 - 0
mpi/include/starpu_mpi.h

@@ -344,6 +344,7 @@ int starpu_mpi_isend_array_detached_unlock_tag_prio(unsigned array_size, starpu_
 int starpu_mpi_irecv_array_detached_unlock_tag(unsigned array_size, starpu_data_handle_t *data_handle, int *source, starpu_mpi_tag_t *data_tag, MPI_Comm *comm, starpu_tag_t tag);
 
 typedef int (*starpu_mpi_datatype_allocate_func_t)(starpu_data_handle_t, MPI_Datatype *);
+typedef int (*starpu_mpi_datatype_node_allocate_func_t)(starpu_data_handle_t, unsigned node, MPI_Datatype *);
 typedef void (*starpu_mpi_datatype_free_func_t)(MPI_Datatype *);
 
 /**
@@ -367,6 +368,26 @@ int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatyp
 int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func);
 
 /**
+   Register functions to create and free a MPI datatype for the given
+   handle.
+   Similar to starpu_mpi_interface_datatype_register().
+   It is important that the function is called before any
+   communication can take place for a data with the given handle. See
+   \ref ExchangingUserDefinedDataInterface for an example.
+*/
+int starpu_mpi_datatype_node_register(starpu_data_handle_t handle, starpu_mpi_datatype_node_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func);
+
+/**
+   Register functions to create and free a MPI datatype for the given
+   interface id.
+   Similar to starpu_mpi_datatype_register().
+   It is important that the function is called before any
+   communication can take place for a data with the given handle. See
+   \ref ExchangingUserDefinedDataInterface for an example.
+*/
+int starpu_mpi_interface_datatype_node_register(enum starpu_data_interface_id id, starpu_mpi_datatype_node_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func);
+
+/**
    Unregister the MPI datatype functions stored for the interface of
    the given handle.
 */

+ 1 - 1
mpi/src/mpi/starpu_mpi_mpi.c

@@ -930,7 +930,7 @@ static void _starpu_mpi_early_data_cb(void* arg)
 		/* Data has been received as a raw memory, it has to be unpacked */
 		struct starpu_data_interface_ops *itf_src = starpu_data_get_interface_ops(args->early_handle);
 		struct starpu_data_interface_ops *itf_dst = starpu_data_get_interface_ops(args->data_handle);
-		MPI_Datatype datatype = _starpu_mpi_datatype_get_user_defined_datatype(args->data_handle);
+		MPI_Datatype datatype = _starpu_mpi_datatype_get_user_defined_datatype(args->data_handle, STARPU_MAIN_RAM);
 
 		if (datatype)
 		{

+ 77 - 36
mpi/src/starpu_mpi_datatype.c

@@ -22,6 +22,7 @@ struct _starpu_mpi_datatype_funcs
 {
 	enum starpu_data_interface_id id;
 	starpu_mpi_datatype_allocate_func_t allocate_datatype_func;
+	starpu_mpi_datatype_node_allocate_func_t allocate_datatype_node_func;
 	starpu_mpi_datatype_free_func_t free_datatype_func;
 	UT_hash_handle hh;
 };
@@ -42,14 +43,16 @@ void _starpu_mpi_datatype_shutdown(void)
  * 	Matrix
  */
 
-static int handle_to_datatype_matrix(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+static int handle_to_datatype_matrix(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
 {
+	struct starpu_matrix_interface *matrix_interface = starpu_data_get_interface_on_node(data_handle, node);
+
 	int ret;
 
-	unsigned nx = starpu_matrix_get_nx(data_handle);
-	unsigned ny = starpu_matrix_get_ny(data_handle);
-	unsigned ld = starpu_matrix_get_local_ld(data_handle);
-	size_t elemsize = starpu_matrix_get_elemsize(data_handle);
+	unsigned nx = STARPU_MATRIX_GET_NX(matrix_interface);
+	unsigned ny = STARPU_MATRIX_GET_NY(matrix_interface);
+	unsigned ld = STARPU_MATRIX_GET_LD(matrix_interface);
+	size_t elemsize = STARPU_MATRIX_GET_ELEMSIZE(matrix_interface);
 
 	ret = MPI_Type_vector(ny, nx*elemsize, ld*elemsize, MPI_BYTE, datatype);
 	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_vector failed");
@@ -64,16 +67,18 @@ static int handle_to_datatype_matrix(starpu_data_handle_t data_handle, MPI_Datat
  * 	Block
  */
 
-static int handle_to_datatype_block(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+static int handle_to_datatype_block(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
 {
+	struct starpu_block_interface *block_interface = starpu_data_get_interface_on_node(data_handle, node);
+
 	int ret;
 
-	unsigned nx = starpu_block_get_nx(data_handle);
-	unsigned ny = starpu_block_get_ny(data_handle);
-	unsigned nz = starpu_block_get_nz(data_handle);
-	unsigned ldy = starpu_block_get_local_ldy(data_handle);
-	unsigned ldz = starpu_block_get_local_ldz(data_handle);
-	size_t elemsize = starpu_block_get_elemsize(data_handle);
+	unsigned nx = STARPU_BLOCK_GET_NX(block_interface);
+	unsigned ny = STARPU_BLOCK_GET_NY(block_interface);
+	unsigned nz = STARPU_BLOCK_GET_NZ(block_interface);
+	unsigned ldy = STARPU_BLOCK_GET_LDY(block_interface);
+	unsigned ldz = STARPU_BLOCK_GET_LDZ(block_interface);
+	size_t elemsize = STARPU_BLOCK_GET_ELEMSIZE(block_interface);
 
 	MPI_Datatype datatype_2dlayer;
 	ret = MPI_Type_vector(ny, nx*elemsize, ldy*elemsize, MPI_BYTE, &datatype_2dlayer);
@@ -95,18 +100,20 @@ static int handle_to_datatype_block(starpu_data_handle_t data_handle, MPI_Dataty
  * 	Tensor
  */
 
-static int handle_to_datatype_tensor(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+static int handle_to_datatype_tensor(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
 {
+	struct starpu_tensor_interface *tensor_interface = starpu_data_get_interface_on_node(data_handle, node);
+
 	int ret;
 
-	unsigned nx = starpu_tensor_get_nx(data_handle);
-	unsigned ny = starpu_tensor_get_ny(data_handle);
-	unsigned nz = starpu_tensor_get_nz(data_handle);
-	unsigned nt = starpu_tensor_get_nt(data_handle);
-	unsigned ldy = starpu_tensor_get_local_ldy(data_handle);
-	unsigned ldz = starpu_tensor_get_local_ldz(data_handle);
-	unsigned ldt = starpu_tensor_get_local_ldt(data_handle);
-	size_t elemsize = starpu_tensor_get_elemsize(data_handle);
+	unsigned nx = STARPU_TENSOR_GET_NX(tensor_interface);
+	unsigned ny = STARPU_TENSOR_GET_NY(tensor_interface);
+	unsigned nz = STARPU_TENSOR_GET_NZ(tensor_interface);
+	unsigned nt = STARPU_TENSOR_GET_NT(tensor_interface);
+	unsigned ldy = STARPU_TENSOR_GET_LDY(tensor_interface);
+	unsigned ldz = STARPU_TENSOR_GET_LDZ(tensor_interface);
+	unsigned ldt = STARPU_TENSOR_GET_LDT(tensor_interface);
+	size_t elemsize = STARPU_TENSOR_GET_ELEMSIZE(tensor_interface);
 
 	MPI_Datatype datatype_3dlayer;
 	ret = MPI_Type_vector(ny, nx*elemsize, ldy*elemsize, MPI_BYTE, &datatype_3dlayer);
@@ -135,12 +142,14 @@ static int handle_to_datatype_tensor(starpu_data_handle_t data_handle, MPI_Datat
  * 	Vector
  */
 
-static int handle_to_datatype_vector(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+static int handle_to_datatype_vector(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
 {
+	struct starpu_vector_interface *vector_interface = starpu_data_get_interface_on_node(data_handle, node);
+
 	int ret;
 
-	unsigned nx = starpu_vector_get_nx(data_handle);
-	size_t elemsize = starpu_vector_get_elemsize(data_handle);
+	unsigned nx = STARPU_VECTOR_GET_NX(vector_interface);
+	size_t elemsize = STARPU_VECTOR_GET_ELEMSIZE(vector_interface);
 
 	ret = MPI_Type_contiguous(nx*elemsize, MPI_BYTE, datatype);
 	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_contiguous failed");
@@ -155,11 +164,13 @@ static int handle_to_datatype_vector(starpu_data_handle_t data_handle, MPI_Datat
  * 	Variable
  */
 
-static int handle_to_datatype_variable(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+static int handle_to_datatype_variable(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
 {
+	struct starpu_variable_interface *variable_interface = starpu_data_get_interface_on_node(data_handle, node);
+
 	int ret;
 
-	size_t elemsize = starpu_variable_get_elemsize(data_handle);
+	size_t elemsize = STARPU_VARIABLE_GET_ELEMSIZE(variable_interface);
 
 	ret = MPI_Type_contiguous(elemsize, MPI_BYTE, datatype);
 	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_contiguous failed");
@@ -174,10 +185,11 @@ static int handle_to_datatype_variable(starpu_data_handle_t data_handle, MPI_Dat
  * 	Void
  */
 
-static int handle_to_datatype_void(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+static int handle_to_datatype_void(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
 {
 	int ret;
 	(void)data_handle;
+	(void)node;
 
 	ret = MPI_Type_contiguous(0, MPI_BYTE, datatype);
 	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_contiguous failed");
@@ -192,7 +204,7 @@ static int handle_to_datatype_void(starpu_data_handle_t data_handle, MPI_Datatyp
  *	Generic
  */
 
-static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
+static starpu_mpi_datatype_node_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
 {
 //#define DYNAMIC_MATRICES
 #ifndef DYNAMIC_MATRICES
@@ -208,7 +220,7 @@ static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_I
 	[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,
 };
 
-MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t data_handle)
+MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t data_handle, unsigned node)
 {
 	enum starpu_data_interface_id id = starpu_data_get_interface_id(data_handle);
 	if (id < STARPU_MAX_INTERFACE_ID) return 0;
@@ -217,10 +229,14 @@ MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t
 	STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_datatype_funcs_table_mutex);
 	HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
 	STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_datatype_funcs_table_mutex);
-	if (table && table->allocate_datatype_func)
+	if (table && (table->allocate_datatype_node_func || table->allocate_datatype_func))
 	{
 		MPI_Datatype datatype;
-		int ret = table->allocate_datatype_func(data_handle, &datatype);
+		int ret;
+		if (table->allocate_datatype_node_func)
+			ret = table->allocate_datatype_node_func(data_handle, node, &datatype);
+		else
+			ret = table->allocate_datatype_func(data_handle, &datatype);
 		if (ret == 0)
 			return datatype;
 		else
@@ -235,10 +251,10 @@ void _starpu_mpi_datatype_allocate(starpu_data_handle_t data_handle, struct _sta
 
 	if (id < STARPU_MAX_INTERFACE_ID)
 	{
-		starpu_mpi_datatype_allocate_func_t func = handle_to_datatype_funcs[id];
+		starpu_mpi_datatype_node_allocate_func_t func = handle_to_datatype_funcs[id];
 		if (func)
 		{
-			func(data_handle, &req->datatype);
+			func(data_handle, req->node, &req->datatype);
 			req->registered_datatype = 1;
 		}
 		else
@@ -256,8 +272,12 @@ void _starpu_mpi_datatype_allocate(starpu_data_handle_t data_handle, struct _sta
 		STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_datatype_funcs_table_mutex);
 		if (table)
 		{
-			STARPU_ASSERT_MSG(table->allocate_datatype_func, "Handle To Datatype Function not defined for StarPU data interface %d", id);
-			int ret = table->allocate_datatype_func(data_handle, &req->datatype);
+			STARPU_ASSERT_MSG(table->allocate_datatype_node_func || table->allocate_datatype_func, "Handle To Datatype Function not defined for StarPU data interface %d", id);
+			int ret;
+			if (table->allocate_datatype_node_func)
+				ret = table->allocate_datatype_node_func(data_handle, req->node, &req->datatype);
+			else
+				ret = table->allocate_datatype_func(data_handle, &req->datatype);
 			if (ret == 0)
 				req->registered_datatype = 1;
 			else
@@ -362,7 +382,7 @@ void _starpu_mpi_datatype_free(starpu_data_handle_t data_handle, MPI_Datatype *d
 	/* else the datatype is not predefined by StarPU */
 }
 
-int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
+int _starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, starpu_mpi_datatype_node_allocate_func_t allocate_datatype_node_func, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
 {
 	struct _starpu_mpi_datatype_funcs *table;
 
@@ -372,6 +392,7 @@ int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, sta
 	HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
 	if (table)
 	{
+		table->allocate_datatype_node_func = allocate_datatype_node_func;
 		table->allocate_datatype_func = allocate_datatype_func;
 		table->free_datatype_func = free_datatype_func;
 	}
@@ -379,6 +400,7 @@ int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, sta
 	{
 		_STARPU_MPI_MALLOC(table, sizeof(struct _starpu_mpi_datatype_funcs));
 		table->id = id;
+		table->allocate_datatype_node_func = allocate_datatype_node_func;
 		table->allocate_datatype_func = allocate_datatype_func;
 		table->free_datatype_func = free_datatype_func;
 		HASH_ADD_INT(_starpu_mpi_datatype_funcs_table, id, table);
@@ -387,6 +409,25 @@ int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, sta
 	return 0;
 }
 
+int starpu_mpi_interface_datatype_node_register(enum starpu_data_interface_id id, starpu_mpi_datatype_node_allocate_func_t allocate_datatype_node_func, starpu_mpi_datatype_free_func_t free_datatype_func)
+{
+	return _starpu_mpi_interface_datatype_register(id, allocate_datatype_node_func, NULL, free_datatype_func);
+}
+
+int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
+{
+	return _starpu_mpi_interface_datatype_register(id, NULL, allocate_datatype_func, free_datatype_func);
+}
+
+int starpu_mpi_datatype_node_register(starpu_data_handle_t handle, starpu_mpi_datatype_node_allocate_func_t allocate_datatype_node_func, starpu_mpi_datatype_free_func_t free_datatype_func)
+{
+	enum starpu_data_interface_id id = starpu_data_get_interface_id(handle);
+	int ret;
+	ret = starpu_mpi_interface_datatype_node_register(id, allocate_datatype_node_func, free_datatype_func);
+	STARPU_ASSERT_MSG(handle->ops->handle_to_pointer || handle->ops->to_pointer, "The data interface must define the operation 'to_pointer'\n");
+	return ret;
+}
+
 int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
 {
 	enum starpu_data_interface_id id = starpu_data_get_interface_id(handle);

+ 1 - 1
mpi/src/starpu_mpi_datatype.h

@@ -33,7 +33,7 @@ void _starpu_mpi_datatype_shutdown(void);
 void _starpu_mpi_datatype_allocate(starpu_data_handle_t data_handle, struct _starpu_mpi_req *req);
 void _starpu_mpi_datatype_free(starpu_data_handle_t data_handle, MPI_Datatype *datatype);
 
-MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t data_handle);
+MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t data_handle, unsigned node);
 
 #ifdef __cplusplus
 }

+ 1 - 0
mpi/src/starpu_mpi_private.h

@@ -225,6 +225,7 @@ LIST_TYPE(_starpu_mpi_req,
 	starpu_data_handle_t data_handle;
 
 	int prio;
+	unsigned node;	/* Which StarPU memory node this will read from / write to */
 
 	/** description of the data to be sent/received */
 	MPI_Datatype datatype;

+ 1 - 0
mpi/src/starpu_mpi_req.c

@@ -25,6 +25,7 @@ void _starpu_mpi_request_init(struct _starpu_mpi_req **req)
 	/* Initialize the request structure */
 	//(*req)->data_handle = NULL;
 	//(*req)->prio = 0;
+	(*req)->node = STARPU_MAIN_RAM; // XXX For now
 
 	//(*req)->datatype = 0;
 	//(*req)->datatype_name = NULL;