|
@@ -61,6 +61,7 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
|
|
|
int arg_type;
|
|
|
va_list varg_list;
|
|
|
int me, do_execute, xrank, nb_nodes;
|
|
|
+ size_t *size_on_nodes;
|
|
|
size_t arg_buffer_size = 0;
|
|
|
char *arg_buffer;
|
|
|
int dest=0, inconsistent_execute;
|
|
@@ -70,6 +71,8 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
|
|
|
MPI_Comm_rank(comm, &me);
|
|
|
MPI_Comm_size(comm, &nb_nodes);
|
|
|
|
|
|
+ size_on_nodes = (size_t *)calloc(1, nb_nodes * sizeof(size_t));
|
|
|
+
|
|
|
_starpu_mpi_tables_init();
|
|
|
|
|
|
/* Get the number of buffers and the size of the arguments */
|
|
@@ -97,9 +100,19 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
|
|
|
STARPU_ASSERT(xrank <= nb_nodes);
|
|
|
do_execute = 1;
|
|
|
}
|
|
|
- if (arg_type==STARPU_R || arg_type==STARPU_W || arg_type==STARPU_REDUX || arg_type==STARPU_RW || arg_type == STARPU_SCRATCH) {
|
|
|
+ else if (arg_type == STARPU_REDUX) {
|
|
|
+ starpu_data_handle_t data = va_arg(varg_list, starpu_data_handle_t);
|
|
|
+ if (data) {
|
|
|
+ int rank = starpu_data_get_rank(data);
|
|
|
+ struct starpu_data_interface_ops *ops;
|
|
|
+ ops = data->ops;
|
|
|
+ size_on_nodes[rank] += ops->get_size(data);
|
|
|
+ do_execute = 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else if (arg_type==STARPU_R || arg_type==STARPU_W || arg_type==STARPU_RW || arg_type == STARPU_SCRATCH) {
|
|
|
starpu_data_handle_t data = va_arg(varg_list, starpu_data_handle_t);
|
|
|
- if (arg_type & STARPU_W || arg_type & STARPU_REDUX) {
|
|
|
+ if (arg_type & STARPU_W) {
|
|
|
if (!data) {
|
|
|
/* We don't have anything allocated for this.
|
|
|
* The application knows we won't do anything
|
|
@@ -109,6 +122,7 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
|
|
|
* safeguard. */
|
|
|
_STARPU_MPI_DEBUG("oh oh\n");
|
|
|
_STARPU_MPI_LOG_OUT();
|
|
|
+ free(size_on_nodes);
|
|
|
return -EINVAL;
|
|
|
}
|
|
|
int mpi_rank = starpu_data_get_rank(data);
|
|
@@ -165,6 +179,7 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
|
|
|
if (inconsistent_execute == 1) {
|
|
|
if (xrank == -1) {
|
|
|
_STARPU_MPI_DEBUG("Different tasks are owning W data. Needs to specify which one is to execute the codelet, using STARPU_EXECUTE_ON_NODE or STARPU_EXECUTE_ON_DATA\n");
|
|
|
+ free(size_on_nodes);
|
|
|
return -EINVAL;
|
|
|
}
|
|
|
else {
|
|
@@ -177,6 +192,22 @@ int starpu_mpi_insert_task(MPI_Comm comm, struct starpu_codelet *codelet, ...)
|
|
|
do_execute = (me == xrank);
|
|
|
dest = xrank;
|
|
|
}
|
|
|
+ else {
|
|
|
+ int i;
|
|
|
+ size_t max_size = size_on_nodes[0];
|
|
|
+ xrank = 0;
|
|
|
+ for(i=1 ; i<nb_nodes ; i++) {
|
|
|
+ if (size_on_nodes[i] > max_size)
|
|
|
+ {
|
|
|
+ max_size = size_on_nodes[i];
|
|
|
+ xrank = i;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ free(size_on_nodes);
|
|
|
+ _STARPU_MPI_DEBUG("Node %d is having the most REDUX data\n", xrank);
|
|
|
+ do_execute = (me == xrank);
|
|
|
+ dest = xrank;
|
|
|
+ }
|
|
|
|
|
|
_STARPU_MPI_DEBUG("Executing %d - Sending to node %d\n", do_execute, dest);
|
|
|
|