浏览代码

mpi: implement and test communication with BCSR data interfaces

Nathalie Furmento 8 年之前
父节点
当前提交
8046187292
共有 2 个文件被更改,包括 119 次插入3 次删除
  1. 23 3
      mpi/src/starpu_mpi_datatype.c
  2. 96 0
      mpi/tests/datatypes.c

+ 23 - 3
mpi/src/starpu_mpi_datatype.c

@@ -1,7 +1,7 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
  * Copyright (C) 2009-2011, 2015  Université de Bordeaux
- * Copyright (C) 2010, 2011, 2012, 2013, 2014, 2015, 2016  CNRS
+ * Copyright (C) 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017  CNRS
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -41,6 +41,26 @@ void _starpu_mpi_datatype_shutdown(void)
 }
 
 /*
+ * 	Bcsr
+ */
+
+static void handle_to_datatype_bcsr(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+{
+	int ret;
+
+	uint32_t r = starpu_bcsr_get_r(data_handle);
+	uint32_t c = starpu_bcsr_get_c(data_handle);
+	uint32_t nnz = starpu_bcsr_get_nnz(data_handle);
+	size_t elemsize = starpu_bcsr_get_elemsize(data_handle);
+
+	ret = MPI_Type_contiguous(r*c*nnz*elemsize, MPI_BYTE, datatype);
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_contiguous failed");
+
+	ret = MPI_Type_commit(datatype);
+	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_commit failed");
+}
+
+/*
  * 	Matrix
  */
 
@@ -149,7 +169,7 @@ static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_I
 	[STARPU_BLOCK_INTERFACE_ID]	= handle_to_datatype_block,
 	[STARPU_VECTOR_INTERFACE_ID]	= handle_to_datatype_vector,
 	[STARPU_CSR_INTERFACE_ID]	= NULL,
-	[STARPU_BCSR_INTERFACE_ID]	= NULL,
+	[STARPU_BCSR_INTERFACE_ID]	= handle_to_datatype_bcsr,
 	[STARPU_VARIABLE_INTERFACE_ID]	= handle_to_datatype_variable,
 	[STARPU_VOID_INTERFACE_ID]	= handle_to_datatype_void,
 	[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,
@@ -237,7 +257,7 @@ static starpu_mpi_datatype_free_func_t handle_free_datatype_funcs[STARPU_MAX_INT
 	[STARPU_BLOCK_INTERFACE_ID]	= _starpu_mpi_handle_free_complex_datatype,
 	[STARPU_VECTOR_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_CSR_INTERFACE_ID]	= NULL,
-	[STARPU_BCSR_INTERFACE_ID]	= NULL,
+	[STARPU_BCSR_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_VARIABLE_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_VOID_INTERFACE_ID]      = _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,

+ 96 - 0
mpi/tests/datatypes.c

@@ -152,6 +152,46 @@ void check_block(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, i
 	starpu_data_release(handle_r);
 }
 
+void check_bcsr(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
+{
+	STARPU_ASSERT(starpu_bcsr_get_elemsize(handle_s) == starpu_bcsr_get_elemsize(handle_r));
+	STARPU_ASSERT(starpu_bcsr_get_nnz(handle_s) == starpu_bcsr_get_nnz(handle_r));
+	STARPU_ASSERT(starpu_bcsr_get_nrow(handle_s) == starpu_bcsr_get_nrow(handle_r));
+	STARPU_ASSERT(starpu_bcsr_get_firstentry(handle_s) == starpu_bcsr_get_firstentry(handle_r));
+	STARPU_ASSERT(starpu_bcsr_get_r(handle_s) == starpu_bcsr_get_r(handle_r));
+	STARPU_ASSERT(starpu_bcsr_get_c(handle_s) == starpu_bcsr_get_c(handle_r));
+	//	STARPU_ASSERT(starpu_bcsr_get_local_colind(handle_s) == starpu_bcsr_get_local_colind(handle_r));
+	//	STARPU_ASSERT(starpu_bcsr_get_local_rowptr(handle_s) == starpu_bcsr_get_local_rowptr(handle_r));
+
+	starpu_data_acquire(handle_s, STARPU_R);
+	starpu_data_acquire(handle_r, STARPU_R);
+
+	int *bcsr_s = (int *)starpu_bcsr_get_local_nzval(handle_s);
+	int *bcsr_r = (int *)starpu_bcsr_get_local_nzval(handle_r);
+
+	int r = starpu_bcsr_get_r(handle_s);
+	int c = starpu_bcsr_get_c(handle_s);
+	int nnz = starpu_bcsr_get_nnz(handle_s);
+
+	int x;
+
+	for(x=0 ; x<r*c*nnz ; x++)
+	{
+		if (bcsr_s[x] == bcsr_r[x])
+		{
+			FPRINTF_MPI(stderr, "Success with bcsr[%d] value: %d == %d\n", x, bcsr_s[x], bcsr_r[x]);
+		}
+		else
+		{
+			*error = 1;
+			FPRINTF_MPI(stderr, "Error with bcsr[%d] value: %d != %d\n", x, bcsr_s[x], bcsr_r[x]);
+		}
+	}
+
+	starpu_data_release(handle_s);
+	starpu_data_release(handle_r);
+}
+
 void send_recv_and_check(int rank, int node, starpu_data_handle_t handle_s, int tag_s, starpu_data_handle_t handle_r, int tag_r, int *error, check_func func)
 {
 	int ret;
@@ -329,6 +369,61 @@ void exchange_block(int rank, int *error)
 	}
 }
 
+void exchange_bcsr(int rank, int *error)
+{
+	/*
+	 * We use the following matrix:
+	 *
+	 *   +----------------+
+	 *   |  0   1   0   0 |
+	 *   |  2   3   0   0 |
+	 *   |  4   5   8   9 |
+	 *   |  6   7  10  11 |
+	 *   +----------------+
+	 *
+	 * nzval  = [0, 1, 2, 3] ++ [4, 5, 6, 7] ++ [8, 9, 10, 11]
+	 * colind = [0, 0, 1]
+	 * rowptr = [0, 1 ]
+	 * r = c = 2
+	 */
+
+	/* Size of the blocks */
+#define BCSR_R 2
+#define BCSR_C 2
+#define BCSR_NROW 2
+#define BCSR_NNZ_BLOCKS 3     /* out of 4 */
+#define BCSR_NZVAL_SIZE (BCSR_R*BCSR_C*BCSR_NNZ_BLOCKS)
+
+	uint32_t colind[BCSR_NNZ_BLOCKS] = {0, 0, 1};
+	uint32_t rowptr[BCSR_NROW] = {0, 1};
+
+	if (rank == 0)
+	{
+		starpu_data_handle_t bcsr_handle[2];
+		int nzval[BCSR_NZVAL_SIZE]  =
+		{
+			0, 1, 2, 3,    /* First block  */
+			4, 5, 6, 7,    /* Second block */
+			8, 9, 10, 11   /* Third block  */
+		};
+
+		starpu_bcsr_data_register(&bcsr_handle[0], STARPU_MAIN_RAM, BCSR_NNZ_BLOCKS, BCSR_NROW, (uintptr_t) nzval, colind, rowptr, 0, BCSR_R, BCSR_C, sizeof(nzval[0]));
+		starpu_bcsr_data_register(&bcsr_handle[1], -1, BCSR_NNZ_BLOCKS, BCSR_NROW, (uintptr_t) NULL, colind, rowptr, 0, BCSR_R, BCSR_C, sizeof(nzval[0]));
+
+		send_recv_and_check(rank, 1, bcsr_handle[0], 0x73, bcsr_handle[1], 0x8337, error, check_bcsr);
+
+		starpu_data_unregister(bcsr_handle[0]);
+		starpu_data_unregister(bcsr_handle[1]);
+	}
+	else if (rank == 1)
+	{
+		starpu_data_handle_t bcsr_handle;
+		starpu_bcsr_data_register(&bcsr_handle, -1, BCSR_NNZ_BLOCKS, BCSR_NROW, (uintptr_t) NULL, colind, rowptr, 0, BCSR_R, BCSR_C, sizeof(int));
+		send_recv_and_check(rank, 0, bcsr_handle, 0x73, NULL, 0x8337, NULL, NULL);
+		starpu_data_unregister(bcsr_handle);
+	}
+}
+
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
@@ -362,6 +457,7 @@ int main(int argc, char **argv)
 	exchange_vector(rank, &error);
 	exchange_matrix(rank, &error);
 	exchange_block(rank, &error);
+	exchange_bcsr(rank, &error);
 
 	starpu_mpi_shutdown();
 	starpu_shutdown();