소스 검색

Allow bcsr/csr data to become completely zero

While this would look of little practical interest, it's better to
handle the case, and it can actually happen when partitioning some
matrix with huge parts being empty.
Samuel Thibault 5 년 전
부모
커밋
299c803d23
2개의 변경된 파일100개의 추가작업 그리고 52개의 파일을 삭제
  1. 50 26
      src/datawizard/interfaces/bcsr_interface.c
  2. 50 26
      src/datawizard/interfaces/csr_interface.c

+ 50 - 26
src/datawizard/interfaces/bcsr_interface.c

@@ -133,10 +133,13 @@ void starpu_bcsr_data_register(starpu_data_handle_t *handleptr, int home_node,
 #ifndef STARPU_SIMGRID
 	if (home_node >= 0 && starpu_node_get_kind(home_node) == STARPU_CPU_RAM)
 	{
-		STARPU_ASSERT_ACCESSIBLE(nzval);
-		STARPU_ASSERT_ACCESSIBLE(nzval + nnz*elemsize*r*c - 1);
-		STARPU_ASSERT_ACCESSIBLE(colind);
-		STARPU_ASSERT_ACCESSIBLE((uintptr_t) colind + nnz*sizeof(uint32_t) - 1);
+		if (nnz)
+		{
+			STARPU_ASSERT_ACCESSIBLE(nzval);
+			STARPU_ASSERT_ACCESSIBLE(nzval + nnz*elemsize*r*c - 1);
+			STARPU_ASSERT_ACCESSIBLE(colind);
+			STARPU_ASSERT_ACCESSIBLE((uintptr_t) colind + nnz*sizeof(uint32_t) - 1);
+		}
 		STARPU_ASSERT_ACCESSIBLE(rowptr);
 		STARPU_ASSERT_ACCESSIBLE((uintptr_t) rowptr + (nrow+1)*sizeof(uint32_t) - 1);
 	}
@@ -327,12 +330,19 @@ static starpu_ssize_t allocate_bcsr_buffer_on_node(void *data_interface_, unsign
 
 	STARPU_ASSERT_MSG(r && c, "partitioning bcsr with several memory nodes is not supported yet");
 
-	addr_nzval = starpu_malloc_on_node(dst_node, nnz*r*c*elemsize);
-	if (!addr_nzval)
-		goto fail_nzval;
-	addr_colind = starpu_malloc_on_node(dst_node, nnz*sizeof(uint32_t));
-	if (!addr_colind)
-		goto fail_colind;
+	if (nnz)
+	{
+		addr_nzval = starpu_malloc_on_node(dst_node, nnz*r*c*elemsize);
+		if (!addr_nzval)
+			goto fail_nzval;
+		addr_colind = starpu_malloc_on_node(dst_node, nnz*sizeof(uint32_t));
+		if (!addr_colind)
+			goto fail_colind;
+	}
+	else
+	{
+		addr_nzval = addr_colind = NULL;
+	}
 	addr_rowptr = starpu_malloc_on_node(dst_node, (nrow+1)*sizeof(uint32_t));
 	if (!addr_rowptr)
 		goto fail_rowptr;
@@ -349,9 +359,11 @@ static starpu_ssize_t allocate_bcsr_buffer_on_node(void *data_interface_, unsign
 	return allocated_memory;
 
 fail_rowptr:
-	starpu_free_on_node(dst_node, addr_colind, nnz*sizeof(uint32_t));
+	if (nnz)
+		starpu_free_on_node(dst_node, addr_colind, nnz*sizeof(uint32_t));
 fail_colind:
-	starpu_free_on_node(dst_node, addr_nzval, nnz*r*c*elemsize);
+	if (nnz)
+		starpu_free_on_node(dst_node, addr_nzval, nnz*r*c*elemsize);
 fail_nzval:
 	/* allocation failed */
 	return -ENOMEM;
@@ -366,8 +378,11 @@ static void free_bcsr_buffer_on_node(void *data_interface, unsigned node)
 	uint32_t r = bcsr_interface->r;
 	uint32_t c = bcsr_interface->c;
 
-	starpu_free_on_node(node, bcsr_interface->nzval, nnz*r*c*elemsize);
-	starpu_free_on_node(node, (uintptr_t) bcsr_interface->colind, nnz*sizeof(uint32_t));
+	if (nnz)
+	{
+		starpu_free_on_node(node, bcsr_interface->nzval, nnz*r*c*elemsize);
+		starpu_free_on_node(node, (uintptr_t) bcsr_interface->colind, nnz*sizeof(uint32_t));
+	}
 	starpu_free_on_node(node, (uintptr_t) bcsr_interface->rowptr, (nrow+1)*sizeof(uint32_t));
 }
 
@@ -385,11 +400,14 @@ static int copy_any_to_any(void *src_interface, unsigned src_node, void *dst_int
 
 	int ret = 0;
 
-	if (starpu_interface_copy(src_bcsr->nzval, 0, src_node, dst_bcsr->nzval, 0, dst_node, nnz*elemsize*r*c, async_data))
-		ret = -EAGAIN;
+	if (nnz)
+	{
+		if (starpu_interface_copy(src_bcsr->nzval, 0, src_node, dst_bcsr->nzval, 0, dst_node, nnz*elemsize*r*c, async_data))
+			ret = -EAGAIN;
 
-	if (starpu_interface_copy((uintptr_t)src_bcsr->colind, 0, src_node, (uintptr_t)dst_bcsr->colind, 0, dst_node, nnz*sizeof(uint32_t), async_data))
-		ret = -EAGAIN;
+		if (starpu_interface_copy((uintptr_t)src_bcsr->colind, 0, src_node, (uintptr_t)dst_bcsr->colind, 0, dst_node, nnz*sizeof(uint32_t), async_data))
+			ret = -EAGAIN;
+	}
 
 	if (starpu_interface_copy((uintptr_t)src_bcsr->rowptr, 0, src_node, (uintptr_t)dst_bcsr->rowptr, 0, dst_node, (nrow+1)*sizeof(uint32_t), async_data))
 		ret = -EAGAIN;
@@ -427,10 +445,13 @@ static int pack_data(starpu_data_handle_t handle, unsigned node, void **ptr, sta
 	{
 		*ptr = (void *)starpu_malloc_on_node_flags(node, *count, 0);
 		char *tmp = *ptr;
-		memcpy(tmp, (void*)bcsr->colind, bcsr->nnz * sizeof(bcsr->colind[0]));
-		tmp += bcsr->nnz * sizeof(bcsr->colind[0]);
-		memcpy(tmp, (void*)bcsr->rowptr, (bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]));
-		tmp += (bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]);
+		if (nnz)
+		{
+			memcpy(tmp, (void*)bcsr->colind, bcsr->nnz * sizeof(bcsr->colind[0]));
+			tmp += bcsr->nnz * sizeof(bcsr->colind[0]);
+			memcpy(tmp, (void*)bcsr->rowptr, (bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]));
+			tmp += (bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]);
+		}
 		memcpy(tmp, (void*)bcsr->nzval, bcsr->r * bcsr->c * bcsr->nnz * bcsr->elemsize);
 	}
 
@@ -446,10 +467,13 @@ static int unpack_data(starpu_data_handle_t handle, unsigned node, void *ptr, si
 	STARPU_ASSERT(count == (bcsr->nnz * sizeof(bcsr->colind[0]))+((bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]))+(bcsr->r * bcsr->c * bcsr->nnz * bcsr->elemsize));
 
 	char *tmp = ptr;
-	memcpy((void*)bcsr->colind, tmp, bcsr->nnz * sizeof(bcsr->colind[0]));
-	tmp += bcsr->nnz * sizeof(bcsr->colind[0]);
-	memcpy((void*)bcsr->rowptr, tmp, (bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]));
-	tmp += (bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]);
+	if (nnz)
+	{
+		memcpy((void*)bcsr->colind, tmp, bcsr->nnz * sizeof(bcsr->colind[0]));
+		tmp += bcsr->nnz * sizeof(bcsr->colind[0]);
+		memcpy((void*)bcsr->rowptr, tmp, (bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]));
+		tmp += (bcsr->nrow + 1) * sizeof(bcsr->rowptr[0]);
+	}
 	memcpy((void*)bcsr->nzval, tmp, bcsr->r * bcsr->c * bcsr->nnz * bcsr->elemsize);
 
 	starpu_free_on_node_flags(node, (uintptr_t)ptr, count, 0);

+ 50 - 26
src/datawizard/interfaces/csr_interface.c

@@ -115,10 +115,13 @@ void starpu_csr_data_register(starpu_data_handle_t *handleptr, int home_node,
 #ifndef STARPU_SIMGRID
 	if (home_node >= 0 && starpu_node_get_kind(home_node) == STARPU_CPU_RAM)
 	{
-		STARPU_ASSERT_ACCESSIBLE(nzval);
-		STARPU_ASSERT_ACCESSIBLE(nzval + nnz*elemsize - 1);
-		STARPU_ASSERT_ACCESSIBLE(colind);
-		STARPU_ASSERT_ACCESSIBLE((uintptr_t) colind + nnz*sizeof(uint32_t) - 1);
+		if (nnz)
+		{
+			STARPU_ASSERT_ACCESSIBLE(nzval);
+			STARPU_ASSERT_ACCESSIBLE(nzval + nnz*elemsize - 1);
+			STARPU_ASSERT_ACCESSIBLE(colind);
+			STARPU_ASSERT_ACCESSIBLE((uintptr_t) colind + nnz*sizeof(uint32_t) - 1);
+		}
 		STARPU_ASSERT_ACCESSIBLE(rowptr);
 		STARPU_ASSERT_ACCESSIBLE((uintptr_t) rowptr + (nrow+1)*sizeof(uint32_t) - 1);
 	}
@@ -272,12 +275,19 @@ static starpu_ssize_t allocate_csr_buffer_on_node(void *data_interface_, unsigne
 	uint32_t nrow = csr_interface->nrow;
 	size_t elemsize = csr_interface->elemsize;
 
-	addr_nzval = starpu_malloc_on_node(dst_node, nnz*elemsize);
-	if (!addr_nzval)
-		goto fail_nzval;
-	addr_colind = (uint32_t*) starpu_malloc_on_node(dst_node, nnz*sizeof(uint32_t));
-	if (!addr_colind)
-		goto fail_colind;
+	if (nnz)
+	{
+		addr_nzval = starpu_malloc_on_node(dst_node, nnz*elemsize);
+		if (!addr_nzval)
+			goto fail_nzval;
+		addr_colind = (uint32_t*) starpu_malloc_on_node(dst_node, nnz*sizeof(uint32_t));
+		if (!addr_colind)
+			goto fail_colind;
+	}
+	else
+	{
+		addr_nzval = addr_colind = NULL;
+	}
 	addr_rowptr = (uint32_t*) starpu_malloc_on_node(dst_node, (nrow+1)*sizeof(uint32_t));
 	if (!addr_rowptr)
 		goto fail_rowptr;
@@ -294,9 +304,11 @@ static starpu_ssize_t allocate_csr_buffer_on_node(void *data_interface_, unsigne
 	return allocated_memory;
 
 fail_rowptr:
-	starpu_free_on_node(dst_node, (uintptr_t) addr_colind, nnz*sizeof(uint32_t));
+	if (nnz)
+		starpu_free_on_node(dst_node, (uintptr_t) addr_colind, nnz*sizeof(uint32_t));
 fail_colind:
-	starpu_free_on_node(dst_node, addr_nzval, nnz*elemsize);
+	if (nnz)
+		starpu_free_on_node(dst_node, addr_nzval, nnz*elemsize);
 fail_nzval:
 	/* allocation failed */
 	return -ENOMEM;
@@ -309,8 +321,11 @@ static void free_csr_buffer_on_node(void *data_interface, unsigned node)
 	uint32_t nrow = csr_interface->nrow;
 	size_t elemsize = csr_interface->elemsize;
 
-	starpu_free_on_node(node, csr_interface->nzval, nnz*elemsize);
-	starpu_free_on_node(node, (uintptr_t) csr_interface->colind, nnz*sizeof(uint32_t));
+	if (nnz)
+	{
+		starpu_free_on_node(node, csr_interface->nzval, nnz*elemsize);
+		starpu_free_on_node(node, (uintptr_t) csr_interface->colind, nnz*sizeof(uint32_t));
+	}
 	starpu_free_on_node(node, (uintptr_t) csr_interface->rowptr, (nrow+1)*sizeof(uint32_t));
 }
 
@@ -325,11 +340,14 @@ static int copy_any_to_any(void *src_interface, unsigned src_node, void *dst_int
 	size_t elemsize = src_csr->elemsize;
 	int ret = 0;
 
-	if (starpu_interface_copy(src_csr->nzval, 0, src_node, dst_csr->nzval, 0, dst_node, nnz*elemsize, async_data))
-		ret = -EAGAIN;
+	if (nnz)
+	{
+		if (starpu_interface_copy(src_csr->nzval, 0, src_node, dst_csr->nzval, 0, dst_node, nnz*elemsize, async_data))
+			ret = -EAGAIN;
 
-	if (starpu_interface_copy((uintptr_t)src_csr->colind, 0, src_node, (uintptr_t)dst_csr->colind, 0, dst_node, nnz*sizeof(uint32_t), async_data))
-		ret = -EAGAIN;
+		if (starpu_interface_copy((uintptr_t)src_csr->colind, 0, src_node, (uintptr_t)dst_csr->colind, 0, dst_node, nnz*sizeof(uint32_t), async_data))
+			ret = -EAGAIN;
+	}
 
 	if (starpu_interface_copy((uintptr_t)src_csr->rowptr, 0, src_node, (uintptr_t)dst_csr->rowptr, 0, dst_node, (nrow+1)*sizeof(uint32_t), async_data))
 		ret = -EAGAIN;
@@ -365,10 +383,13 @@ static int pack_data(starpu_data_handle_t handle, unsigned node, void **ptr, sta
 	{
 		*ptr = (void *)starpu_malloc_on_node_flags(node, *count, 0);
 		char *tmp = *ptr;
-		memcpy(tmp, (void*)csr->colind, csr->nnz * sizeof(csr->colind[0]));
-		tmp += csr->nnz * sizeof(csr->colind[0]);
-		memcpy(tmp, (void*)csr->rowptr, (csr->nrow + 1) * sizeof(csr->rowptr[0]));
-		tmp += (csr->nrow + 1) * sizeof(csr->rowptr[0]);
+		if (nnz)
+		{
+			memcpy(tmp, (void*)csr->colind, csr->nnz * sizeof(csr->colind[0]));
+			tmp += csr->nnz * sizeof(csr->colind[0]);
+			memcpy(tmp, (void*)csr->rowptr, (csr->nrow + 1) * sizeof(csr->rowptr[0]));
+			tmp += (csr->nrow + 1) * sizeof(csr->rowptr[0]);
+		}
 		memcpy(tmp, (void*)csr->nzval, csr->nnz * csr->elemsize);
 	}
 
@@ -384,10 +405,13 @@ static int unpack_data(starpu_data_handle_t handle, unsigned node, void *ptr, si
 	STARPU_ASSERT(count == (csr->nnz * sizeof(csr->colind[0]))+((csr->nrow + 1) * sizeof(csr->rowptr[0]))+(csr->nnz * csr->elemsize));
 
 	char *tmp = ptr;
-	memcpy((void*)csr->colind, tmp, csr->nnz * sizeof(csr->colind[0]));
-	tmp += csr->nnz * sizeof(csr->colind[0]);
-	memcpy((void*)csr->rowptr, tmp, (csr->nrow + 1) * sizeof(csr->rowptr[0]));
-	tmp += (csr->nrow + 1) * sizeof(csr->rowptr[0]);
+	if (nnz)
+	{
+		memcpy((void*)csr->colind, tmp, csr->nnz * sizeof(csr->colind[0]));
+		tmp += csr->nnz * sizeof(csr->colind[0]);
+		memcpy((void*)csr->rowptr, tmp, (csr->nrow + 1) * sizeof(csr->rowptr[0]));
+		tmp += (csr->nrow + 1) * sizeof(csr->rowptr[0]);
+	}
 	memcpy((void*)csr->nzval, tmp, csr->nnz * csr->elemsize);
 
 	starpu_free_on_node_flags(node, (uintptr_t)ptr, count, 0);