Parcourir la source

Factorize reading matrix/block/tensor dimensions

Samuel Thibault il y a 5 ans
Parent
commit
0bde71ed6c

+ 38 - 24
src/datawizard/interfaces/block_interface.c

@@ -227,7 +227,14 @@ static int pack_block_handle(starpu_data_handle_t handle, unsigned node, void **
 	struct starpu_block_interface *block_interface = (struct starpu_block_interface *)
 		starpu_data_get_interface_on_node(handle, node);
 
-	*count = block_interface->nx*block_interface->ny*block_interface->nz*block_interface->elemsize;
+	uint32_t ldy = block_interface->ldy;
+	uint32_t ldz = block_interface->ldz;
+	uint32_t nx = block_interface->nx;
+	uint32_t ny = block_interface->ny;
+	uint32_t nz = block_interface->nz;
+	size_t elemsize = block_interface->elemsize;
+
+	*count = nx*ny*nz*elemsize;
 
 	if (ptr != NULL)
 	{
@@ -238,29 +245,29 @@ static int pack_block_handle(starpu_data_handle_t handle, unsigned node, void **
 
 		char *cur = *ptr;
 
-		if (block_interface->nx * block_interface->ny == block_interface->ldz && block_interface->nx == block_interface->ldy)
-			memcpy(cur, block, block_interface->nx * block_interface->ny * block_interface->nz * block_interface->elemsize);
+		if (nx * ny == ldz && nx == ldy)
+			memcpy(cur, block, nx * ny * nz * elemsize);
 		else
 		{
 			char *block_z = block;
-			for(z=0 ; z<block_interface->nz ; z++)
+			for(z=0 ; z<nz ; z++)
 			{
-				if (block_interface->nx == block_interface->ldy)
+				if (nx == ldy)
 				{
-					memcpy(cur, block_z, block_interface->nx * block_interface->ny * block_interface->elemsize);
-					cur += block_interface->nx*block_interface->ny*block_interface->elemsize;
+					memcpy(cur, block_z, nx * ny * elemsize);
+					cur += nx*ny*elemsize;
 				}
 				else
 				{
 					char *block_y = block_z;
-					for(y=0 ; y<block_interface->ny ; y++)
+					for(y=0 ; y<ny ; y++)
 					{
-						memcpy(cur, block_y, block_interface->nx*block_interface->elemsize);
-						cur += block_interface->nx*block_interface->elemsize;
-						block_y += block_interface->ldy * block_interface->elemsize;
+						memcpy(cur, block_y, nx*elemsize);
+						cur += nx*elemsize;
+						block_y += ldy * elemsize;
 					}
 				}
-				block_z += block_interface->ldz * block_interface->elemsize;
+				block_z += ldz * elemsize;
 			}
 		}
 	}
@@ -275,35 +282,42 @@ static int unpack_block_handle(starpu_data_handle_t handle, unsigned node, void
 	struct starpu_block_interface *block_interface = (struct starpu_block_interface *)
 		starpu_data_get_interface_on_node(handle, node);
 
-	STARPU_ASSERT(count == block_interface->elemsize * block_interface->nx * block_interface->ny * block_interface->nz);
+	uint32_t ldy = block_interface->ldy;
+	uint32_t ldz = block_interface->ldz;
+	uint32_t nx = block_interface->nx;
+	uint32_t ny = block_interface->ny;
+	uint32_t nz = block_interface->nz;
+	size_t elemsize = block_interface->elemsize;
+
+	STARPU_ASSERT(count == elemsize * nx * ny * nz);
 
 	uint32_t z, y;
 	char *cur = ptr;
 	char *block = (void *)block_interface->ptr;
 
-	if (block_interface->nx * block_interface->ny == block_interface->ldz && block_interface->nx == block_interface->ldy)
-		memcpy(block, cur, block_interface->nx * block_interface->ny * block_interface->nz * block_interface->elemsize);
+	if (nx * ny == ldz && nx == ldy)
+		memcpy(block, cur, nx * ny * nz * elemsize);
 	else
 	{
 		char *block_z = block;
-		for(z=0 ; z<block_interface->nz ; z++)
+		for(z=0 ; z<nz ; z++)
 		{
-			if (block_interface->nx == block_interface->ldy)
+			if (nx == ldy)
 			{
-				memcpy(block_z, cur, block_interface->nx * block_interface->ny * block_interface->elemsize);
-				cur += block_interface->nx*block_interface->ny*block_interface->elemsize;
+				memcpy(block_z, cur, nx * ny * elemsize);
+				cur += nx*ny*elemsize;
 			}
 			else
 			{
 				char *block_y = block_z;
-				for(y=0 ; y<block_interface->ny ; y++)
+				for(y=0 ; y<ny ; y++)
 				{
-					memcpy(block_y, cur, block_interface->nx*block_interface->elemsize);
-					cur += block_interface->nx*block_interface->elemsize;
-					block_y += block_interface->ldy * block_interface->elemsize;
+					memcpy(block_y, cur, nx*elemsize);
+					cur += nx*elemsize;
+					block_y += ldy * elemsize;
 				}
 			}
-			block_z += block_interface->ldz * block_interface->elemsize;
+			block_z += ldz * elemsize;
 		}
 	}
 

+ 24 - 14
src/datawizard/interfaces/matrix_interface.c

@@ -266,7 +266,12 @@ static int pack_matrix_handle(starpu_data_handle_t handle, unsigned node, void *
 	struct starpu_matrix_interface *matrix_interface = (struct starpu_matrix_interface *)
 		starpu_data_get_interface_on_node(handle, node);
 
-	*count = matrix_interface->nx*matrix_interface->ny*matrix_interface->elemsize;
+	uint32_t ld = matrix_interface->ld;
+	uint32_t nx = matrix_interface->nx;
+	uint32_t ny = matrix_interface->ny;
+	size_t elemsize = matrix_interface->elemsize;
+
+	*count = nx*ny*elemsize;
 
 	if (ptr != NULL)
 	{
@@ -275,16 +280,16 @@ static int pack_matrix_handle(starpu_data_handle_t handle, unsigned node, void *
 		*ptr = (void *)starpu_malloc_on_node_flags(node, *count, 0);
 		char *cur = *ptr;
 
-		if (matrix_interface->ld == matrix_interface->nx)
-			memcpy(cur, matrix, matrix_interface->nx*matrix_interface->ny*matrix_interface->elemsize);
+		if (ld == nx)
+			memcpy(cur, matrix, nx*ny*elemsize);
 		else
 		{
 			uint32_t y;
-			for(y=0 ; y<matrix_interface->ny ; y++)
+			for(y=0 ; y<ny ; y++)
 			{
-				memcpy(cur, matrix, matrix_interface->nx*matrix_interface->elemsize);
-				cur += matrix_interface->nx*matrix_interface->elemsize;
-				matrix += matrix_interface->ld * matrix_interface->elemsize;
+				memcpy(cur, matrix, nx*elemsize);
+				cur += nx*elemsize;
+				matrix += ld * elemsize;
 			}
 		}
 	}
@@ -299,21 +304,26 @@ static int unpack_matrix_handle(starpu_data_handle_t handle, unsigned node, void
 	struct starpu_matrix_interface *matrix_interface = (struct starpu_matrix_interface *)
 		starpu_data_get_interface_on_node(handle, node);
 
-	STARPU_ASSERT(count == matrix_interface->elemsize * matrix_interface->nx * matrix_interface->ny);
+	uint32_t ld = matrix_interface->ld;
+	uint32_t nx = matrix_interface->nx;
+	uint32_t ny = matrix_interface->ny;
+	size_t elemsize = matrix_interface->elemsize;
+
+	STARPU_ASSERT(count == elemsize * nx * ny);
 
 	char *matrix = (void *)matrix_interface->ptr;
 
-	if (matrix_interface->ld == matrix_interface->nx)
-		memcpy(matrix, ptr, matrix_interface->nx*matrix_interface->ny*matrix_interface->elemsize);
+	if (ld == nx)
+		memcpy(matrix, ptr, nx*ny*elemsize);
 	else
 	{
 		uint32_t y;
 		char *cur = ptr;
-		for(y=0 ; y<matrix_interface->ny ; y++)
+		for(y=0 ; y<ny ; y++)
 		{
-			memcpy(matrix, cur, matrix_interface->nx*matrix_interface->elemsize);
-			cur += matrix_interface->nx*matrix_interface->elemsize;
-			matrix += matrix_interface->ld * matrix_interface->elemsize;
+			memcpy(matrix, cur, nx*elemsize);
+			cur += nx*elemsize;
+			matrix += ld * elemsize;
 		}
 	}
 

+ 58 - 40
src/datawizard/interfaces/tensor_interface.c

@@ -237,7 +237,16 @@ static int pack_tensor_handle(starpu_data_handle_t handle, unsigned node, void *
 	struct starpu_tensor_interface *tensor_interface = (struct starpu_tensor_interface *)
 		starpu_data_get_interface_on_node(handle, node);
 
-	*count = tensor_interface->nx*tensor_interface->ny*tensor_interface->nz*tensor_interface->nt*tensor_interface->elemsize;
+	uint32_t ldy = tensor_interface->ldy;
+	uint32_t ldz = tensor_interface->ldz;
+	uint32_t ldt = tensor_interface->ldt;
+	uint32_t nx = tensor_interface->nx;
+	uint32_t ny = tensor_interface->ny;
+	uint32_t nz = tensor_interface->nz;
+	uint32_t nt = tensor_interface->nt;
+	size_t elemsize = tensor_interface->elemsize;
+
+	*count = nx*ny*nz*nt*elemsize;
 
 	if (ptr != NULL)
 	{
@@ -247,45 +256,45 @@ static int pack_tensor_handle(starpu_data_handle_t handle, unsigned node, void *
 		*ptr = (void *)starpu_malloc_on_node_flags(node, *count, 0);
 
 		char *cur = *ptr;
-		if (tensor_interface->nx * tensor_interface->ny * tensor_interface->nz == tensor_interface->ldt &&
-		    tensor_interface->nx * tensor_interface->ny == tensor_interface->ldz &&
-		    tensor_interface->nx == tensor_interface->ldy)
-			memcpy(cur, block, tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->nt * tensor_interface->elemsize);
+		if (nx * ny * nz == ldt &&
+		    nx * ny == ldz &&
+		    nx == ldy)
+			memcpy(cur, block, nx * ny * nz * nt * elemsize);
 		else
 		{
 			char *block_t = block;
-			for(t=0 ; t<tensor_interface->nt ; t++)
+			for(t=0 ; t<nt ; t++)
 			{
-				if (tensor_interface->nx * tensor_interface->ny == tensor_interface->ldz &&
-				    tensor_interface->nx == tensor_interface->ldy)
+				if (nx * ny == ldz &&
+				    nx == ldy)
 				{
-					memcpy(cur, block_t, tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->elemsize);
-					cur += tensor_interface->nx*tensor_interface->ny*tensor_interface->nz*tensor_interface->elemsize;
+					memcpy(cur, block_t, nx * ny * nz * elemsize);
+					cur += nx*ny*nz*elemsize;
 				}
 				else
 				{
 					char *block_z = block_t;
-					for(z=0 ; z<tensor_interface->nz ; z++)
+					for(z=0 ; z<nz ; z++)
 					{
-						if (tensor_interface->nx == tensor_interface->ldy)
+						if (nx == ldy)
 						{
-							memcpy(cur, block_z, tensor_interface->nx * tensor_interface->ny * tensor_interface->elemsize);
-							cur += tensor_interface->nx*tensor_interface->ny*tensor_interface->elemsize;
+							memcpy(cur, block_z, nx * ny * elemsize);
+							cur += nx*ny*elemsize;
 						}
 						else
 						{
 							char *block_y = block_z;
-							for(y=0 ; y<tensor_interface->ny ; y++)
+							for(y=0 ; y<ny ; y++)
 							{
-								memcpy(cur, block_y, tensor_interface->nx*tensor_interface->elemsize);
-								cur += tensor_interface->nx*tensor_interface->elemsize;
-								block_y += tensor_interface->ldy * tensor_interface->elemsize;
+								memcpy(cur, block_y, nx*elemsize);
+								cur += nx*elemsize;
+								block_y += ldy * elemsize;
 							}
 						}
-						block_z += tensor_interface->ldz * tensor_interface->elemsize;
+						block_z += ldz * elemsize;
 					}
 				}
-				block_t += tensor_interface->ldt * tensor_interface->elemsize;
+				block_t += ldt * elemsize;
 			}
 		}
 	}
@@ -300,51 +309,60 @@ static int unpack_tensor_handle(starpu_data_handle_t handle, unsigned node, void
 	struct starpu_tensor_interface *tensor_interface = (struct starpu_tensor_interface *)
 		starpu_data_get_interface_on_node(handle, node);
 
-	STARPU_ASSERT(count == tensor_interface->elemsize * tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->nt);
+	uint32_t ldy = tensor_interface->ldy;
+	uint32_t ldz = tensor_interface->ldz;
+	uint32_t ldt = tensor_interface->ldt;
+	uint32_t nx = tensor_interface->nx;
+	uint32_t ny = tensor_interface->ny;
+	uint32_t nz = tensor_interface->nz;
+	uint32_t nt = tensor_interface->nt;
+	size_t elemsize = tensor_interface->elemsize;
+
+	STARPU_ASSERT(count == elemsize * nx * ny * nz * nt);
 
 	uint32_t t, z, y;
 	char *cur = ptr;
 	char *block = (void *)tensor_interface->ptr;
 
-	if (tensor_interface->nx * tensor_interface->ny * tensor_interface->nz == tensor_interface->ldt &&
-	    tensor_interface->nx * tensor_interface->ny == tensor_interface->ldz &&
-	    tensor_interface->nx == tensor_interface->ldy)
-		memcpy(block, cur, tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->nt * tensor_interface->elemsize);
+	if (nx * ny * nz == ldt &&
+	    nx * ny == ldz &&
+	    nx == ldy)
+		memcpy(block, cur, nx * ny * nz * nt * elemsize);
 	else
 	{
 		char *block_t = block;
-		for(t=0 ; t<tensor_interface->nt ; t++)
+		for(t=0 ; t<nt ; t++)
 		{
-			if (tensor_interface->nx * tensor_interface->ny == tensor_interface->ldz &&
-			    tensor_interface->nx == tensor_interface->ldy)
+			if (nx * ny == ldz &&
+			    nx == ldy)
 			{
-				memcpy(block_t, cur, tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->elemsize);
-				cur += tensor_interface->nx*tensor_interface->ny*tensor_interface->nz*tensor_interface->elemsize;
+				memcpy(block_t, cur, nx * ny * nz * elemsize);
+				cur += nx*ny*nz*elemsize;
 			}
 			else
 			{
 				char *block_z = block_t;
-				for(z=0 ; z<tensor_interface->nz ; z++)
+				for(z=0 ; z<nz ; z++)
 				{
-					if (tensor_interface->nx == tensor_interface->ldy)
+					if (nx == ldy)
 					{
-						memcpy(block_z, cur, tensor_interface->nx * tensor_interface->ny * tensor_interface->elemsize);
-						cur += tensor_interface->nx*tensor_interface->ny*tensor_interface->elemsize;
+						memcpy(block_z, cur, nx * ny * elemsize);
+						cur += nx*ny*elemsize;
 					}
 					else
 					{
 						char *block_y = block_z;
-						for(y=0 ; y<tensor_interface->ny ; y++)
+						for(y=0 ; y<ny ; y++)
 						{
-							memcpy(block_y, cur, tensor_interface->nx*tensor_interface->elemsize);
-							cur += tensor_interface->nx*tensor_interface->elemsize;
-							block_y += tensor_interface->ldy * tensor_interface->elemsize;
+							memcpy(block_y, cur, nx*elemsize);
+							cur += nx*elemsize;
+							block_y += ldy * elemsize;
 						}
 					}
-					block_z += tensor_interface->ldz * tensor_interface->elemsize;
+					block_z += ldz * elemsize;
 				}
 			}
-			block_t += tensor_interface->ldt * tensor_interface->elemsize;
+			block_t += ldt * elemsize;
 		}
 	}