Procházet zdrojové kódy

Add MPI type for tensor data interface

Samuel Thibault před 5 roky
rodič
revize
739978461a
1 změnil soubory, kde provedl 41 přidání a 1 odebrání
  1. 41 1
      mpi/src/starpu_mpi_datatype.c

+ 41 - 1
mpi/src/starpu_mpi_datatype.c

@@ -2,7 +2,7 @@
  *
  * Copyright (C) 2010-2017,2019                           CNRS
  * Copyright (C) 2011,2012,2015                           Inria
- * Copyright (C) 2009-2011,2014,2015,2018                 Université de Bordeaux
+ * Copyright (C) 2009-2011,2014,2015,2018,2020            Université de Bordeaux
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -91,6 +91,44 @@ static void handle_to_datatype_block(starpu_data_handle_t data_handle, MPI_Datat
 }
 
 /*
+ * 	Tensor
+ */
+
+static void handle_to_datatype_tensor(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+{
+	int ret;
+
+	unsigned nx = starpu_tensor_get_nx(data_handle);
+	unsigned ny = starpu_tensor_get_ny(data_handle);
+	unsigned nz = starpu_tensor_get_nz(data_handle);
+	unsigned nt = starpu_tensor_get_nt(data_handle);
+	unsigned ldy = starpu_tensor_get_local_ldy(data_handle);
+	unsigned ldz = starpu_tensor_get_local_ldz(data_handle);
+	unsigned ldt = starpu_tensor_get_local_ldt(data_handle);
+	size_t elemsize = starpu_block_get_elemsize(data_handle);
+
+	MPI_Datatype datatype_3dlayer;
+	ret = MPI_Type_vector(ny, nx*elemsize, ldy*elemsize, MPI_BYTE, &datatype_3dlayer);
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_vector failed");
+
+	ret = MPI_Type_commit(&datatype_3dlayer);
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_commit failed");
+
+	MPI_Datatype datatype_2dlayer;
+	ret = MPI_Type_create_hvector(nz, 1, ldz*elemsize, datatype_3dlayer, &datatype_2dlayer);
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_hvector failed");
+
+	ret = MPI_Type_commit(&datatype_2dlayer);
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_commit failed");
+
+	ret = MPI_Type_create_hvector(nt, 1, ldt*elemsize, datatype_2dlayer, datatype);
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_hvector failed");
+
+	ret = MPI_Type_commit(datatype);
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_commit failed");
+}
+
+/*
  * 	Vector
  */
 
@@ -149,6 +187,7 @@ static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_I
 {
 	[STARPU_MATRIX_INTERFACE_ID]	= handle_to_datatype_matrix,
 	[STARPU_BLOCK_INTERFACE_ID]	= handle_to_datatype_block,
+	[STARPU_TENSOR_INTERFACE_ID]	= handle_to_datatype_tensor,
 	[STARPU_VECTOR_INTERFACE_ID]	= handle_to_datatype_vector,
 	[STARPU_CSR_INTERFACE_ID]	= NULL, /* Sent through pack/unpack operations */
 	[STARPU_BCSR_INTERFACE_ID]	= NULL, /* Sent through pack/unpack operations */
@@ -245,6 +284,7 @@ static starpu_mpi_datatype_free_func_t handle_free_datatype_funcs[STARPU_MAX_INT
 {
 	[STARPU_MATRIX_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_BLOCK_INTERFACE_ID]	= _starpu_mpi_handle_free_complex_datatype,
+	[STARPU_TENSOR_INTERFACE_ID]	= _starpu_mpi_handle_free_complex_datatype,
 	[STARPU_VECTOR_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_CSR_INTERFACE_ID]	= NULL,  /* Sent through pack/unpack operations */
 	[STARPU_BCSR_INTERFACE_ID]	= NULL,  /* Sent through pack/unpack operations */