|
@@ -2,7 +2,7 @@
|
|
|
*
|
|
|
* Copyright (C) 2010-2017,2019 CNRS
|
|
|
* Copyright (C) 2011,2012,2015 Inria
|
|
|
- * Copyright (C) 2009-2011,2014,2015,2018 Université de Bordeaux
|
|
|
+ * Copyright (C) 2009-2011,2014,2015,2018,2020 Université de Bordeaux
|
|
|
*
|
|
|
* StarPU is free software; you can redistribute it and/or modify
|
|
|
* it under the terms of the GNU Lesser General Public License as published by
|
|
@@ -91,6 +91,44 @@ static void handle_to_datatype_block(starpu_data_handle_t data_handle, MPI_Datat
|
|
|
}
|
|
|
|
|
|
/*
|
|
|
+ * Tensor
|
|
|
+ */
|
|
|
+
|
|
|
+static void handle_to_datatype_tensor(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
|
|
|
+{
|
|
|
+ int ret;
|
|
|
+
|
|
|
+ unsigned nx = starpu_tensor_get_nx(data_handle);
|
|
|
+ unsigned ny = starpu_tensor_get_ny(data_handle);
|
|
|
+ unsigned nz = starpu_tensor_get_nz(data_handle);
|
|
|
+ unsigned nt = starpu_tensor_get_nt(data_handle);
|
|
|
+ unsigned ldy = starpu_tensor_get_local_ldy(data_handle);
|
|
|
+ unsigned ldz = starpu_tensor_get_local_ldz(data_handle);
|
|
|
+ unsigned ldt = starpu_tensor_get_local_ldt(data_handle);
|
|
|
+ size_t elemsize = starpu_block_get_elemsize(data_handle);
|
|
|
+
|
|
|
+ MPI_Datatype datatype_3dlayer;
|
|
|
+ ret = MPI_Type_vector(ny, nx*elemsize, ldy*elemsize, MPI_BYTE, &datatype_3dlayer);
|
|
|
+ STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_vector failed");
|
|
|
+
|
|
|
+ ret = MPI_Type_commit(&datatype_3dlayer);
|
|
|
+ STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_commit failed");
|
|
|
+
|
|
|
+ MPI_Datatype datatype_2dlayer;
|
|
|
+ ret = MPI_Type_create_hvector(nz, 1, ldz*elemsize, datatype_3dlayer, &datatype_2dlayer);
|
|
|
+ STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_hvector failed");
|
|
|
+
|
|
|
+ ret = MPI_Type_commit(&datatype_2dlayer);
|
|
|
+ STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_commit failed");
|
|
|
+
|
|
|
+ ret = MPI_Type_create_hvector(nt, 1, ldt*elemsize, datatype_2dlayer, datatype);
|
|
|
+ STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_hvector failed");
|
|
|
+
|
|
|
+ ret = MPI_Type_commit(datatype);
|
|
|
+ STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_commit failed");
|
|
|
+}
|
|
|
+
|
|
|
+/*
|
|
|
* Vector
|
|
|
*/
|
|
|
|
|
@@ -149,6 +187,7 @@ static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_I
|
|
|
{
|
|
|
[STARPU_MATRIX_INTERFACE_ID] = handle_to_datatype_matrix,
|
|
|
[STARPU_BLOCK_INTERFACE_ID] = handle_to_datatype_block,
|
|
|
+ [STARPU_TENSOR_INTERFACE_ID] = handle_to_datatype_tensor,
|
|
|
[STARPU_VECTOR_INTERFACE_ID] = handle_to_datatype_vector,
|
|
|
[STARPU_CSR_INTERFACE_ID] = NULL, /* Sent through pack/unpack operations */
|
|
|
[STARPU_BCSR_INTERFACE_ID] = NULL, /* Sent through pack/unpack operations */
|
|
@@ -245,6 +284,7 @@ static starpu_mpi_datatype_free_func_t handle_free_datatype_funcs[STARPU_MAX_INT
|
|
|
{
|
|
|
[STARPU_MATRIX_INTERFACE_ID] = _starpu_mpi_handle_free_simple_datatype,
|
|
|
[STARPU_BLOCK_INTERFACE_ID] = _starpu_mpi_handle_free_complex_datatype,
|
|
|
+ [STARPU_TENSOR_INTERFACE_ID] = _starpu_mpi_handle_free_complex_datatype,
|
|
|
[STARPU_VECTOR_INTERFACE_ID] = _starpu_mpi_handle_free_simple_datatype,
|
|
|
[STARPU_CSR_INTERFACE_ID] = NULL, /* Sent through pack/unpack operations */
|
|
|
[STARPU_BCSR_INTERFACE_ID] = NULL, /* Sent through pack/unpack operations */
|