Bladeren bron

mpi: exchange bcsr data through pack/unpack operations

Nathalie Furmento 8 jaren geleden
bovenliggende
commit
e59ae9c57c
2 gewijzigde bestanden met toevoegingen van 41 en 32 verwijderingen
  1. 2 22
      mpi/src/starpu_mpi_datatype.c
  2. 39 10
      mpi/tests/datatypes.c

+ 2 - 22
mpi/src/starpu_mpi_datatype.c

@@ -41,26 +41,6 @@ void _starpu_mpi_datatype_shutdown(void)
 }
 
 /*
- * 	Bcsr
- */
-
-static void handle_to_datatype_bcsr(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
-{
-	int ret;
-
-	uint32_t r = starpu_bcsr_get_r(data_handle);
-	uint32_t c = starpu_bcsr_get_c(data_handle);
-	uint32_t nnz = starpu_bcsr_get_nnz(data_handle);
-	size_t elemsize = starpu_bcsr_get_elemsize(data_handle);
-
-	ret = MPI_Type_contiguous(r*c*nnz*elemsize, MPI_BYTE, datatype);
-	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_contiguous failed");
-
-	ret = MPI_Type_commit(datatype);
-	STARPU_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Type_commit failed");
-}
-
-/*
  * 	Matrix
  */
 
@@ -169,7 +149,7 @@ static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_I
 	[STARPU_BLOCK_INTERFACE_ID]	= handle_to_datatype_block,
 	[STARPU_VECTOR_INTERFACE_ID]	= handle_to_datatype_vector,
 	[STARPU_CSR_INTERFACE_ID]	= NULL,
-	[STARPU_BCSR_INTERFACE_ID]	= handle_to_datatype_bcsr,
+	[STARPU_BCSR_INTERFACE_ID]	= NULL,
 	[STARPU_VARIABLE_INTERFACE_ID]	= handle_to_datatype_variable,
 	[STARPU_VOID_INTERFACE_ID]	= handle_to_datatype_void,
 	[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,
@@ -265,7 +245,7 @@ static starpu_mpi_datatype_free_func_t handle_free_datatype_funcs[STARPU_MAX_INT
 	[STARPU_BLOCK_INTERFACE_ID]	= _starpu_mpi_handle_free_complex_datatype,
 	[STARPU_VECTOR_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_CSR_INTERFACE_ID]	= NULL,
-	[STARPU_BCSR_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
+	[STARPU_BCSR_INTERFACE_ID]	= NULL,
 	[STARPU_VARIABLE_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_VOID_INTERFACE_ID]      = _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,

+ 39 - 10
mpi/tests/datatypes.c

@@ -160,21 +160,51 @@ void check_bcsr(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, in
 	STARPU_ASSERT(starpu_bcsr_get_firstentry(handle_s) == starpu_bcsr_get_firstentry(handle_r));
 	STARPU_ASSERT(starpu_bcsr_get_r(handle_s) == starpu_bcsr_get_r(handle_r));
 	STARPU_ASSERT(starpu_bcsr_get_c(handle_s) == starpu_bcsr_get_c(handle_r));
-	//	STARPU_ASSERT(starpu_bcsr_get_local_colind(handle_s) == starpu_bcsr_get_local_colind(handle_r));
-	//	STARPU_ASSERT(starpu_bcsr_get_local_rowptr(handle_s) == starpu_bcsr_get_local_rowptr(handle_r));
 
 	starpu_data_acquire(handle_s, STARPU_R);
 	starpu_data_acquire(handle_r, STARPU_R);
 
+	uint32_t *colind_s = starpu_bcsr_get_local_colind(handle_s);
+	uint32_t *colind_r = starpu_bcsr_get_local_colind(handle_r);
+	uint32_t *rowptr_s = starpu_bcsr_get_local_rowptr(handle_s);
+	uint32_t *rowptr_r = starpu_bcsr_get_local_rowptr(handle_r);
+
 	int *bcsr_s = (int *)starpu_bcsr_get_local_nzval(handle_s);
 	int *bcsr_r = (int *)starpu_bcsr_get_local_nzval(handle_r);
 
 	int r = starpu_bcsr_get_r(handle_s);
 	int c = starpu_bcsr_get_c(handle_s);
 	int nnz = starpu_bcsr_get_nnz(handle_s);
+	int nrows = starpu_bcsr_get_nrow(handle_s);
 
 	int x;
 
+	for(x=0 ; x<nnz ; x++)
+	{
+		if (colind_s[x] == colind_r[x])
+		{
+			FPRINTF_MPI(stderr, "Success with colind[%d] value: %u == %u\n", x, colind_s[x], colind_r[x]);
+		}
+		else
+		{
+			*error = 1;
+			FPRINTF_MPI(stderr, "Error with colind[%d] value: %u != %u\n", x, colind_s[x], colind_r[x]);
+		}
+	}
+
+	for(x=0 ; x<nrows+1 ; x++)
+	{
+		if (rowptr_s[x] == rowptr_r[x])
+		{
+			FPRINTF_MPI(stderr, "Success with rowptr[%d] value: %u == %u\n", x, rowptr_s[x], rowptr_r[x]);
+		}
+		else
+		{
+			*error = 1;
+			FPRINTF_MPI(stderr, "Error with rowptr[%d] value: %u != %u\n", x, rowptr_s[x], rowptr_r[x]);
+		}
+	}
+
 	for(x=0 ; x<r*c*nnz ; x++)
 	{
 		if (bcsr_s[x] == bcsr_r[x])
@@ -383,23 +413,22 @@ void exchange_bcsr(int rank, int *error)
 	 *
 	 * nzval  = [0, 1, 2, 3] ++ [4, 5, 6, 7] ++ [8, 9, 10, 11]
 	 * colind = [0, 0, 1]
-	 * rowptr = [0, 1 ]
+	 * rowptr = [0, 1, 3]
 	 * r = c = 2
 	 */
 
 	/* Size of the blocks */
 #define BCSR_R 2
 #define BCSR_C 2
-#define BCSR_NROW 2
+#define BCSR_NROWS 2
 #define BCSR_NNZ_BLOCKS 3     /* out of 4 */
 #define BCSR_NZVAL_SIZE (BCSR_R*BCSR_C*BCSR_NNZ_BLOCKS)
 
-	uint32_t colind[BCSR_NNZ_BLOCKS] = {0, 0, 1};
-	uint32_t rowptr[BCSR_NROW] = {0, 1};
-
 	if (rank == 0)
 	{
 		starpu_data_handle_t bcsr_handle[2];
+		uint32_t colind[BCSR_NNZ_BLOCKS] = {0, 0, 1};
+		uint32_t rowptr[BCSR_NROWS+1] = {0, 1, BCSR_NNZ_BLOCKS};
 		int nzval[BCSR_NZVAL_SIZE]  =
 		{
 			0, 1, 2, 3,    /* First block  */
@@ -407,8 +436,8 @@ void exchange_bcsr(int rank, int *error)
 			8, 9, 10, 11   /* Third block  */
 		};
 
-		starpu_bcsr_data_register(&bcsr_handle[0], STARPU_MAIN_RAM, BCSR_NNZ_BLOCKS, BCSR_NROW, (uintptr_t) nzval, colind, rowptr, 0, BCSR_R, BCSR_C, sizeof(nzval[0]));
-		starpu_bcsr_data_register(&bcsr_handle[1], -1, BCSR_NNZ_BLOCKS, BCSR_NROW, (uintptr_t) NULL, colind, rowptr, 0, BCSR_R, BCSR_C, sizeof(nzval[0]));
+		starpu_bcsr_data_register(&bcsr_handle[0], STARPU_MAIN_RAM, BCSR_NNZ_BLOCKS, BCSR_NROWS, (uintptr_t) nzval, colind, rowptr, 0, BCSR_R, BCSR_C, sizeof(nzval[0]));
+		starpu_bcsr_data_register(&bcsr_handle[1], -1, BCSR_NNZ_BLOCKS, BCSR_NROWS, (uintptr_t) NULL, (uint32_t *) NULL, (uint32_t *) NULL, 0, BCSR_R, BCSR_C, sizeof(nzval[0]));
 
 		send_recv_and_check(rank, 1, bcsr_handle[0], 0x73, bcsr_handle[1], 0x8337, error, check_bcsr);
 
@@ -418,7 +447,7 @@ void exchange_bcsr(int rank, int *error)
 	else if (rank == 1)
 	{
 		starpu_data_handle_t bcsr_handle;
-		starpu_bcsr_data_register(&bcsr_handle, -1, BCSR_NNZ_BLOCKS, BCSR_NROW, (uintptr_t) NULL, colind, rowptr, 0, BCSR_R, BCSR_C, sizeof(int));
+		starpu_bcsr_data_register(&bcsr_handle, -1, BCSR_NNZ_BLOCKS, BCSR_NROWS, (uintptr_t) NULL, (uint32_t *) NULL, (uint32_t *) NULL, 0, BCSR_R, BCSR_C, sizeof(int));
 		send_recv_and_check(rank, 0, bcsr_handle, 0x73, NULL, 0x8337, NULL, NULL);
 		starpu_data_unregister(bcsr_handle);
 	}