소스 검색

mpi/tests/datatype.c: test csr data communication

Nathalie Furmento 8 년 전
부모
커밋
5bf801b600
1개의 변경된 파일100개의 추가작업 그리고 0개의 파일을 삭제
  1. 100 0
      mpi/tests/datatypes.c

+ 100 - 0
mpi/tests/datatypes.c

@@ -471,6 +471,105 @@ void exchange_bcsr(int rank, int *error)
 	}
 }
 
+/*
+ * CSR
+ */
+void check_csr(starpu_data_handle_t handle_s, starpu_data_handle_t handle_r, int *error)
+{
+	STARPU_ASSERT(starpu_csr_get_elemsize(handle_s) == starpu_csr_get_elemsize(handle_r));
+	STARPU_ASSERT(starpu_csr_get_nnz(handle_s) == starpu_csr_get_nnz(handle_r));
+	STARPU_ASSERT(starpu_csr_get_nrow(handle_s) == starpu_csr_get_nrow(handle_r));
+	STARPU_ASSERT(starpu_csr_get_firstentry(handle_s) == starpu_csr_get_firstentry(handle_r));
+
+	starpu_data_acquire(handle_s, STARPU_R);
+	starpu_data_acquire(handle_r, STARPU_R);
+
+	uint32_t *colind_s = starpu_csr_get_local_colind(handle_s);
+	uint32_t *colind_r = starpu_csr_get_local_colind(handle_r);
+	uint32_t *rowptr_s = starpu_csr_get_local_rowptr(handle_s);
+	uint32_t *rowptr_r = starpu_csr_get_local_rowptr(handle_r);
+
+	int *csr_s = (int *)starpu_csr_get_local_nzval(handle_s);
+	int *csr_r = (int *)starpu_csr_get_local_nzval(handle_r);
+
+	int nnz = starpu_csr_get_nnz(handle_s);
+	int nrows = starpu_csr_get_nrow(handle_s);
+
+	int x;
+
+	for(x=0 ; x<nnz ; x++)
+	{
+		if (colind_s[x] == colind_r[x])
+		{
+			FPRINTF_MPI(stderr, "Success with colind[%d] value: %u == %u\n", x, colind_s[x], colind_r[x]);
+		}
+		else
+		{
+			*error = 1;
+			FPRINTF_MPI(stderr, "Error with colind[%d] value: %u != %u\n", x, colind_s[x], colind_r[x]);
+		}
+	}
+
+	for(x=0 ; x<nrows+1 ; x++)
+	{
+		if (rowptr_s[x] == rowptr_r[x])
+		{
+			FPRINTF_MPI(stderr, "Success with rowptr[%d] value: %u == %u\n", x, rowptr_s[x], rowptr_r[x]);
+		}
+		else
+		{
+			*error = 1;
+			FPRINTF_MPI(stderr, "Error with rowptr[%d] value: %u != %u\n", x, rowptr_s[x], rowptr_r[x]);
+		}
+	}
+
+	for(x=0 ; x<nnz ; x++)
+	{
+		if (csr_s[x] == csr_r[x])
+		{
+			FPRINTF_MPI(stderr, "Success with csr[%d] value: %d == %d\n", x, csr_s[x], csr_r[x]);
+		}
+		else
+		{
+			*error = 1;
+			FPRINTF_MPI(stderr, "Error with csr[%d] value: %d != %d\n", x, csr_s[x], csr_r[x]);
+		}
+	}
+
+	starpu_data_release(handle_s);
+	starpu_data_release(handle_r);
+}
+
+void exchange_csr(int rank, int *error)
+{
+	// the values are completely wrong, we just want to test that the communication is done correctly
+#define CSR_NROWS 2
+#define CSR_NNZ   5
+
+	if (rank == 0)
+	{
+		starpu_data_handle_t csr_handle[2];
+		uint32_t colind[CSR_NNZ] = {0, 1, 2, 3, 4};
+		uint32_t rowptr[CSR_NROWS+1] = {0, 1, CSR_NNZ};
+		int nzval[CSR_NNZ] = { 11, 22, 33, 44, 55 };
+
+		starpu_csr_data_register(&csr_handle[0], STARPU_MAIN_RAM, CSR_NNZ, CSR_NROWS, (uintptr_t) nzval, colind, rowptr, 0, sizeof(nzval[0]));
+		starpu_csr_data_register(&csr_handle[1], -1, CSR_NNZ, CSR_NROWS, (uintptr_t) NULL, (uint32_t *) NULL, (uint32_t *) NULL, 0, sizeof(nzval[0]));
+
+		send_recv_and_check(rank, 1, csr_handle[0], 0x84, csr_handle[1], 0x8765, error, check_csr);
+
+		starpu_data_unregister(csr_handle[0]);
+		starpu_data_unregister(csr_handle[1]);
+	}
+	else if (rank == 1)
+	{
+		starpu_data_handle_t csr_handle;
+		starpu_csr_data_register(&csr_handle, -1, CSR_NNZ, CSR_NROWS, (uintptr_t) NULL, (uint32_t *) NULL, (uint32_t *) NULL, 0, sizeof(int));
+		send_recv_and_check(rank, 0, csr_handle, 0x84, NULL, 0x8765, NULL, NULL);
+		starpu_data_unregister(csr_handle);
+	}
+}
+
 int main(int argc, char **argv)
 {
 	int ret, rank, size;
@@ -505,6 +604,7 @@ int main(int argc, char **argv)
 	exchange_matrix(rank, &error);
 	exchange_block(rank, &error);
 	exchange_bcsr(rank, &error);
+	exchange_csr(rank, &error);
 
 	starpu_mpi_shutdown();
 	starpu_shutdown();