|
@@ -22,6 +22,7 @@ struct _starpu_mpi_datatype_funcs
|
|
|
{
|
|
|
enum starpu_data_interface_id id;
|
|
|
starpu_mpi_datatype_allocate_func_t allocate_datatype_func;
|
|
|
+ starpu_mpi_datatype_node_allocate_func_t allocate_datatype_node_func;
|
|
|
starpu_mpi_datatype_free_func_t free_datatype_func;
|
|
|
UT_hash_handle hh;
|
|
|
};
|
|
@@ -42,14 +43,16 @@ void _starpu_mpi_datatype_shutdown(void)
|
|
|
* Matrix
|
|
|
*/
|
|
|
|
|
|
-static int handle_to_datatype_matrix(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
|
|
|
+static int handle_to_datatype_matrix(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
|
|
|
{
|
|
|
+ struct starpu_matrix_interface *matrix_interface = starpu_data_get_interface_on_node(data_handle, node);
|
|
|
+
|
|
|
int ret;
|
|
|
|
|
|
- unsigned nx = starpu_matrix_get_nx(data_handle);
|
|
|
- unsigned ny = starpu_matrix_get_ny(data_handle);
|
|
|
- unsigned ld = starpu_matrix_get_local_ld(data_handle);
|
|
|
- size_t elemsize = starpu_matrix_get_elemsize(data_handle);
|
|
|
+ unsigned nx = STARPU_MATRIX_GET_NX(matrix_interface);
|
|
|
+ unsigned ny = STARPU_MATRIX_GET_NY(matrix_interface);
|
|
|
+ unsigned ld = STARPU_MATRIX_GET_LD(matrix_interface);
|
|
|
+ size_t elemsize = STARPU_MATRIX_GET_ELEMSIZE(matrix_interface);
|
|
|
|
|
|
ret = MPI_Type_vector(ny, nx*elemsize, ld*elemsize, MPI_BYTE, datatype);
|
|
|
STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_vector failed");
|
|
@@ -64,16 +67,18 @@ static int handle_to_datatype_matrix(starpu_data_handle_t data_handle, MPI_Datat
|
|
|
* Block
|
|
|
*/
|
|
|
|
|
|
-static int handle_to_datatype_block(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
|
|
|
+static int handle_to_datatype_block(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
|
|
|
{
|
|
|
+ struct starpu_block_interface *block_interface = starpu_data_get_interface_on_node(data_handle, node);
|
|
|
+
|
|
|
int ret;
|
|
|
|
|
|
- unsigned nx = starpu_block_get_nx(data_handle);
|
|
|
- unsigned ny = starpu_block_get_ny(data_handle);
|
|
|
- unsigned nz = starpu_block_get_nz(data_handle);
|
|
|
- unsigned ldy = starpu_block_get_local_ldy(data_handle);
|
|
|
- unsigned ldz = starpu_block_get_local_ldz(data_handle);
|
|
|
- size_t elemsize = starpu_block_get_elemsize(data_handle);
|
|
|
+ unsigned nx = STARPU_BLOCK_GET_NX(block_interface);
|
|
|
+ unsigned ny = STARPU_BLOCK_GET_NY(block_interface);
|
|
|
+ unsigned nz = STARPU_BLOCK_GET_NZ(block_interface);
|
|
|
+ unsigned ldy = STARPU_BLOCK_GET_LDY(block_interface);
|
|
|
+ unsigned ldz = STARPU_BLOCK_GET_LDZ(block_interface);
|
|
|
+ size_t elemsize = STARPU_BLOCK_GET_ELEMSIZE(block_interface);
|
|
|
|
|
|
MPI_Datatype datatype_2dlayer;
|
|
|
ret = MPI_Type_vector(ny, nx*elemsize, ldy*elemsize, MPI_BYTE, &datatype_2dlayer);
|
|
@@ -95,18 +100,20 @@ static int handle_to_datatype_block(starpu_data_handle_t data_handle, MPI_Dataty
|
|
|
* Tensor
|
|
|
*/
|
|
|
|
|
|
-static int handle_to_datatype_tensor(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
|
|
|
+static int handle_to_datatype_tensor(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
|
|
|
{
|
|
|
+ struct starpu_tensor_interface *tensor_interface = starpu_data_get_interface_on_node(data_handle, node);
|
|
|
+
|
|
|
int ret;
|
|
|
|
|
|
- unsigned nx = starpu_tensor_get_nx(data_handle);
|
|
|
- unsigned ny = starpu_tensor_get_ny(data_handle);
|
|
|
- unsigned nz = starpu_tensor_get_nz(data_handle);
|
|
|
- unsigned nt = starpu_tensor_get_nt(data_handle);
|
|
|
- unsigned ldy = starpu_tensor_get_local_ldy(data_handle);
|
|
|
- unsigned ldz = starpu_tensor_get_local_ldz(data_handle);
|
|
|
- unsigned ldt = starpu_tensor_get_local_ldt(data_handle);
|
|
|
- size_t elemsize = starpu_tensor_get_elemsize(data_handle);
|
|
|
+ unsigned nx = STARPU_TENSOR_GET_NX(tensor_interface);
|
|
|
+ unsigned ny = STARPU_TENSOR_GET_NY(tensor_interface);
|
|
|
+ unsigned nz = STARPU_TENSOR_GET_NZ(tensor_interface);
|
|
|
+ unsigned nt = STARPU_TENSOR_GET_NT(tensor_interface);
|
|
|
+ unsigned ldy = STARPU_TENSOR_GET_LDY(tensor_interface);
|
|
|
+ unsigned ldz = STARPU_TENSOR_GET_LDZ(tensor_interface);
|
|
|
+ unsigned ldt = STARPU_TENSOR_GET_LDT(tensor_interface);
|
|
|
+ size_t elemsize = STARPU_TENSOR_GET_ELEMSIZE(tensor_interface);
|
|
|
|
|
|
MPI_Datatype datatype_3dlayer;
|
|
|
ret = MPI_Type_vector(ny, nx*elemsize, ldy*elemsize, MPI_BYTE, &datatype_3dlayer);
|
|
@@ -135,12 +142,14 @@ static int handle_to_datatype_tensor(starpu_data_handle_t data_handle, MPI_Datat
|
|
|
* Vector
|
|
|
*/
|
|
|
|
|
|
-static int handle_to_datatype_vector(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
|
|
|
+static int handle_to_datatype_vector(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
|
|
|
{
|
|
|
+ struct starpu_vector_interface *vector_interface = starpu_data_get_interface_on_node(data_handle, node);
|
|
|
+
|
|
|
int ret;
|
|
|
|
|
|
- unsigned nx = starpu_vector_get_nx(data_handle);
|
|
|
- size_t elemsize = starpu_vector_get_elemsize(data_handle);
|
|
|
+ unsigned nx = STARPU_VECTOR_GET_NX(vector_interface);
|
|
|
+ size_t elemsize = STARPU_VECTOR_GET_ELEMSIZE(vector_interface);
|
|
|
|
|
|
ret = MPI_Type_contiguous(nx*elemsize, MPI_BYTE, datatype);
|
|
|
STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_contiguous failed");
|
|
@@ -155,11 +164,13 @@ static int handle_to_datatype_vector(starpu_data_handle_t data_handle, MPI_Datat
|
|
|
* Variable
|
|
|
*/
|
|
|
|
|
|
-static int handle_to_datatype_variable(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
|
|
|
+static int handle_to_datatype_variable(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
|
|
|
{
|
|
|
+ struct starpu_variable_interface *variable_interface = starpu_data_get_interface_on_node(data_handle, node);
|
|
|
+
|
|
|
int ret;
|
|
|
|
|
|
- size_t elemsize = starpu_variable_get_elemsize(data_handle);
|
|
|
+ size_t elemsize = STARPU_VARIABLE_GET_ELEMSIZE(variable_interface);
|
|
|
|
|
|
ret = MPI_Type_contiguous(elemsize, MPI_BYTE, datatype);
|
|
|
STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_contiguous failed");
|
|
@@ -174,10 +185,11 @@ static int handle_to_datatype_variable(starpu_data_handle_t data_handle, MPI_Dat
|
|
|
* Void
|
|
|
*/
|
|
|
|
|
|
-static int handle_to_datatype_void(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
|
|
|
+static int handle_to_datatype_void(starpu_data_handle_t data_handle, unsigned node, MPI_Datatype *datatype)
|
|
|
{
|
|
|
int ret;
|
|
|
(void)data_handle;
|
|
|
+ (void)node;
|
|
|
|
|
|
ret = MPI_Type_contiguous(0, MPI_BYTE, datatype);
|
|
|
STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_contiguous failed");
|
|
@@ -192,7 +204,7 @@ static int handle_to_datatype_void(starpu_data_handle_t data_handle, MPI_Datatyp
|
|
|
* Generic
|
|
|
*/
|
|
|
|
|
|
-static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
|
|
|
+static starpu_mpi_datatype_node_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
|
|
|
{
|
|
|
//#define DYNAMIC_MATRICES
|
|
|
#ifndef DYNAMIC_MATRICES
|
|
@@ -208,7 +220,7 @@ static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_I
|
|
|
[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,
|
|
|
};
|
|
|
|
|
|
-MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t data_handle)
|
|
|
+MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t data_handle, unsigned node)
|
|
|
{
|
|
|
enum starpu_data_interface_id id = starpu_data_get_interface_id(data_handle);
|
|
|
if (id < STARPU_MAX_INTERFACE_ID) return 0;
|
|
@@ -217,10 +229,14 @@ MPI_Datatype _starpu_mpi_datatype_get_user_defined_datatype(starpu_data_handle_t
|
|
|
STARPU_PTHREAD_MUTEX_LOCK(&_starpu_mpi_datatype_funcs_table_mutex);
|
|
|
HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
|
|
|
STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_datatype_funcs_table_mutex);
|
|
|
- if (table && table->allocate_datatype_func)
|
|
|
+ if (table && (table->allocate_datatype_node_func || table->allocate_datatype_func))
|
|
|
{
|
|
|
MPI_Datatype datatype;
|
|
|
- int ret = table->allocate_datatype_func(data_handle, &datatype);
|
|
|
+ int ret;
|
|
|
+ if (table->allocate_datatype_node_func)
|
|
|
+ ret = table->allocate_datatype_node_func(data_handle, node, &datatype);
|
|
|
+ else
|
|
|
+ ret = table->allocate_datatype_func(data_handle, &datatype);
|
|
|
if (ret == 0)
|
|
|
return datatype;
|
|
|
else
|
|
@@ -235,10 +251,10 @@ void _starpu_mpi_datatype_allocate(starpu_data_handle_t data_handle, struct _sta
|
|
|
|
|
|
if (id < STARPU_MAX_INTERFACE_ID)
|
|
|
{
|
|
|
- starpu_mpi_datatype_allocate_func_t func = handle_to_datatype_funcs[id];
|
|
|
+ starpu_mpi_datatype_node_allocate_func_t func = handle_to_datatype_funcs[id];
|
|
|
if (func)
|
|
|
{
|
|
|
- func(data_handle, &req->datatype);
|
|
|
+ func(data_handle, req->node, &req->datatype);
|
|
|
req->registered_datatype = 1;
|
|
|
}
|
|
|
else
|
|
@@ -256,8 +272,12 @@ void _starpu_mpi_datatype_allocate(starpu_data_handle_t data_handle, struct _sta
|
|
|
STARPU_PTHREAD_MUTEX_UNLOCK(&_starpu_mpi_datatype_funcs_table_mutex);
|
|
|
if (table)
|
|
|
{
|
|
|
- STARPU_ASSERT_MSG(table->allocate_datatype_func, "Handle To Datatype Function not defined for StarPU data interface %d", id);
|
|
|
- int ret = table->allocate_datatype_func(data_handle, &req->datatype);
|
|
|
+ STARPU_ASSERT_MSG(table->allocate_datatype_node_func || table->allocate_datatype_func, "Handle To Datatype Function not defined for StarPU data interface %d", id);
|
|
|
+ int ret;
|
|
|
+ if (table->allocate_datatype_node_func)
|
|
|
+ ret = table->allocate_datatype_node_func(data_handle, req->node, &req->datatype);
|
|
|
+ else
|
|
|
+ ret = table->allocate_datatype_func(data_handle, &req->datatype);
|
|
|
if (ret == 0)
|
|
|
req->registered_datatype = 1;
|
|
|
else
|
|
@@ -362,7 +382,7 @@ void _starpu_mpi_datatype_free(starpu_data_handle_t data_handle, MPI_Datatype *d
|
|
|
/* else the datatype is not predefined by StarPU */
|
|
|
}
|
|
|
|
|
|
-int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
|
|
|
+int _starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, starpu_mpi_datatype_node_allocate_func_t allocate_datatype_node_func, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
|
|
|
{
|
|
|
struct _starpu_mpi_datatype_funcs *table;
|
|
|
|
|
@@ -372,6 +392,7 @@ int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, sta
|
|
|
HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
|
|
|
if (table)
|
|
|
{
|
|
|
+ table->allocate_datatype_node_func = allocate_datatype_node_func;
|
|
|
table->allocate_datatype_func = allocate_datatype_func;
|
|
|
table->free_datatype_func = free_datatype_func;
|
|
|
}
|
|
@@ -379,6 +400,7 @@ int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, sta
|
|
|
{
|
|
|
_STARPU_MPI_MALLOC(table, sizeof(struct _starpu_mpi_datatype_funcs));
|
|
|
table->id = id;
|
|
|
+ table->allocate_datatype_node_func = allocate_datatype_node_func;
|
|
|
table->allocate_datatype_func = allocate_datatype_func;
|
|
|
table->free_datatype_func = free_datatype_func;
|
|
|
HASH_ADD_INT(_starpu_mpi_datatype_funcs_table, id, table);
|
|
@@ -387,6 +409,25 @@ int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, sta
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
+int starpu_mpi_interface_datatype_node_register(enum starpu_data_interface_id id, starpu_mpi_datatype_node_allocate_func_t allocate_datatype_node_func, starpu_mpi_datatype_free_func_t free_datatype_func)
|
|
|
+{
|
|
|
+ return _starpu_mpi_interface_datatype_register(id, allocate_datatype_node_func, NULL, free_datatype_func);
|
|
|
+}
|
|
|
+
|
|
|
+int starpu_mpi_interface_datatype_register(enum starpu_data_interface_id id, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
|
|
|
+{
|
|
|
+ return _starpu_mpi_interface_datatype_register(id, NULL, allocate_datatype_func, free_datatype_func);
|
|
|
+}
|
|
|
+
|
|
|
+int starpu_mpi_datatype_node_register(starpu_data_handle_t handle, starpu_mpi_datatype_node_allocate_func_t allocate_datatype_node_func, starpu_mpi_datatype_free_func_t free_datatype_func)
|
|
|
+{
|
|
|
+ enum starpu_data_interface_id id = starpu_data_get_interface_id(handle);
|
|
|
+ int ret;
|
|
|
+ ret = starpu_mpi_interface_datatype_node_register(id, allocate_datatype_node_func, free_datatype_func);
|
|
|
+ STARPU_ASSERT_MSG(handle->ops->handle_to_pointer || handle->ops->to_pointer, "The data interface must define the operation 'to_pointer'\n");
|
|
|
+ return ret;
|
|
|
+}
|
|
|
+
|
|
|
int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
|
|
|
{
|
|
|
enum starpu_data_interface_id id = starpu_data_get_interface_id(handle);
|