|
@@ -1,6 +1,6 @@
|
|
/* StarPU --- Runtime system for heterogeneous multicore architectures.
|
|
/* StarPU --- Runtime system for heterogeneous multicore architectures.
|
|
*
|
|
*
|
|
- * Copyright (C) 2009-2019 Université de Bordeaux
|
|
|
|
|
|
+ * Copyright (C) 2009-2020 Université de Bordeaux
|
|
* Copyright (C) 2011,2012,2017 Inria
|
|
* Copyright (C) 2011,2012,2017 Inria
|
|
* Copyright (C) 2010-2017,2019 CNRS
|
|
* Copyright (C) 2010-2017,2019 CNRS
|
|
*
|
|
*
|
|
@@ -247,22 +247,46 @@ static int pack_tensor_handle(starpu_data_handle_t handle, unsigned node, void *
|
|
*ptr = (void *)starpu_malloc_on_node_flags(node, *count, 0);
|
|
*ptr = (void *)starpu_malloc_on_node_flags(node, *count, 0);
|
|
|
|
|
|
char *cur = *ptr;
|
|
char *cur = *ptr;
|
|
- char *block_t = block;
|
|
|
|
- for(t=0 ; t<tensor_interface->nt ; t++)
|
|
|
|
|
|
+ if (tensor_interface->nx * tensor_interface->ny * tensor_interface->nz == tensor_interface->ldt &&
|
|
|
|
+ tensor_interface->nx * tensor_interface->ny == tensor_interface->ldz &&
|
|
|
|
+ tensor_interface->nx == tensor_interface->ldy)
|
|
|
|
+ memcpy(cur, block, tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->nt * tensor_interface->elemsize);
|
|
|
|
+ else
|
|
{
|
|
{
|
|
- char *block_z = block_t;
|
|
|
|
- for(z=0 ; z<tensor_interface->nz ; z++)
|
|
|
|
- {
|
|
|
|
- char *block_y = block_z;
|
|
|
|
- for(y=0 ; y<tensor_interface->ny ; y++)
|
|
|
|
|
|
+ char *block_t = block;
|
|
|
|
+ for(t=0 ; t<tensor_interface->nt ; t++)
|
|
{
|
|
{
|
|
- memcpy(cur, block_y, tensor_interface->nx*tensor_interface->elemsize);
|
|
|
|
- cur += tensor_interface->nx*tensor_interface->elemsize;
|
|
|
|
- block_y += tensor_interface->ldy * tensor_interface->elemsize;
|
|
|
|
|
|
+ if (tensor_interface->nx * tensor_interface->ny == tensor_interface->ldz &&
|
|
|
|
+ tensor_interface->nx == tensor_interface->ldy)
|
|
|
|
+ {
|
|
|
|
+ memcpy(cur, block_t, tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->elemsize);
|
|
|
|
+ cur += tensor_interface->nx*tensor_interface->ny*tensor_interface->nz*tensor_interface->elemsize;
|
|
|
|
+ }
|
|
|
|
+ else
|
|
|
|
+ {
|
|
|
|
+ char *block_z = block_t;
|
|
|
|
+ for(z=0 ; z<tensor_interface->nz ; z++)
|
|
|
|
+ {
|
|
|
|
+ if (tensor_interface->nx == tensor_interface->ldy)
|
|
|
|
+ {
|
|
|
|
+ memcpy(cur, block_z, tensor_interface->nx * tensor_interface->ny * tensor_interface->elemsize);
|
|
|
|
+ cur += tensor_interface->nx*tensor_interface->ny*tensor_interface->elemsize;
|
|
|
|
+ }
|
|
|
|
+ else
|
|
|
|
+ {
|
|
|
|
+ char *block_y = block_z;
|
|
|
|
+ for(y=0 ; y<tensor_interface->ny ; y++)
|
|
|
|
+ {
|
|
|
|
+ memcpy(cur, block_y, tensor_interface->nx*tensor_interface->elemsize);
|
|
|
|
+ cur += tensor_interface->nx*tensor_interface->elemsize;
|
|
|
|
+ block_y += tensor_interface->ldy * tensor_interface->elemsize;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ block_z += tensor_interface->ldz * tensor_interface->elemsize;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ block_t += tensor_interface->ldt * tensor_interface->elemsize;
|
|
}
|
|
}
|
|
- block_z += tensor_interface->ldz * tensor_interface->elemsize;
|
|
|
|
- }
|
|
|
|
- block_t += tensor_interface->ldt * tensor_interface->elemsize;
|
|
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -281,22 +305,47 @@ static int unpack_tensor_handle(starpu_data_handle_t handle, unsigned node, void
|
|
uint32_t t, z, y;
|
|
uint32_t t, z, y;
|
|
char *cur = ptr;
|
|
char *cur = ptr;
|
|
char *block = (void *)tensor_interface->ptr;
|
|
char *block = (void *)tensor_interface->ptr;
|
|
- char *block_t = block;
|
|
|
|
- for(t=0 ; t<tensor_interface->nt ; t++)
|
|
|
|
|
|
+
|
|
|
|
+ if (tensor_interface->nx * tensor_interface->ny * tensor_interface->nz == tensor_interface->ldt &&
|
|
|
|
+ tensor_interface->nx * tensor_interface->ny == tensor_interface->ldz &&
|
|
|
|
+ tensor_interface->nx == tensor_interface->ldy)
|
|
|
|
+ memcpy(block, cur, tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->nt * tensor_interface->elemsize);
|
|
|
|
+ else
|
|
{
|
|
{
|
|
- char *block_z = block_t;
|
|
|
|
- for(z=0 ; z<tensor_interface->nz ; z++)
|
|
|
|
- {
|
|
|
|
- char *block_y = block_z;
|
|
|
|
- for(y=0 ; y<tensor_interface->ny ; y++)
|
|
|
|
|
|
+ char *block_t = block;
|
|
|
|
+ for(t=0 ; t<tensor_interface->nt ; t++)
|
|
{
|
|
{
|
|
- memcpy(block_y, cur, tensor_interface->nx*tensor_interface->elemsize);
|
|
|
|
- cur += tensor_interface->nx*tensor_interface->elemsize;
|
|
|
|
- block_y += tensor_interface->ldy * tensor_interface->elemsize;
|
|
|
|
|
|
+ if (tensor_interface->nx * tensor_interface->ny == tensor_interface->ldz &&
|
|
|
|
+ tensor_interface->nx == tensor_interface->ldy)
|
|
|
|
+ {
|
|
|
|
+ memcpy(block_t, cur, tensor_interface->nx * tensor_interface->ny * tensor_interface->nz * tensor_interface->elemsize);
|
|
|
|
+ cur += tensor_interface->nx*tensor_interface->ny*tensor_interface->nz*tensor_interface->elemsize;
|
|
|
|
+ }
|
|
|
|
+ else
|
|
|
|
+ {
|
|
|
|
+ char *block_z = block_t;
|
|
|
|
+ for(z=0 ; z<tensor_interface->nz ; z++)
|
|
|
|
+ {
|
|
|
|
+ if (tensor_interface->nx == tensor_interface->ldy)
|
|
|
|
+ {
|
|
|
|
+ memcpy(block_z, cur, tensor_interface->nx * tensor_interface->ny * tensor_interface->elemsize);
|
|
|
|
+ cur += tensor_interface->nx*tensor_interface->ny*tensor_interface->elemsize;
|
|
|
|
+ }
|
|
|
|
+ else
|
|
|
|
+ {
|
|
|
|
+ char *block_y = block_z;
|
|
|
|
+ for(y=0 ; y<tensor_interface->ny ; y++)
|
|
|
|
+ {
|
|
|
|
+ memcpy(block_y, cur, tensor_interface->nx*tensor_interface->elemsize);
|
|
|
|
+ cur += tensor_interface->nx*tensor_interface->elemsize;
|
|
|
|
+ block_y += tensor_interface->ldy * tensor_interface->elemsize;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ block_z += tensor_interface->ldz * tensor_interface->elemsize;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ block_t += tensor_interface->ldt * tensor_interface->elemsize;
|
|
}
|
|
}
|
|
- block_z += tensor_interface->ldz * tensor_interface->elemsize;
|
|
|
|
- }
|
|
|
|
- block_t += tensor_interface->ldt * tensor_interface->elemsize;
|
|
|
|
}
|
|
}
|
|
|
|
|
|
starpu_free_on_node_flags(node, (uintptr_t)ptr, count, 0);
|
|
starpu_free_on_node_flags(node, (uintptr_t)ptr, count, 0);
|