|
@@ -1,6 +1,6 @@
|
|
|
/* StarPU --- Runtime system for heterogeneous multicore architectures.
|
|
|
*
|
|
|
- * Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016 CNRS
|
|
|
+ * Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016, 2017 CNRS
|
|
|
*
|
|
|
* StarPU is free software; you can redistribute it and/or modify
|
|
|
* it under the terms of the GNU Lesser General Public License as published by
|
|
@@ -39,25 +39,23 @@ void _callback_collective(void *arg)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, int root, MPI_Comm comm, void (*scallback)(void *), void *sarg, void (*rcallback)(void *), void *rarg)
|
|
|
+static
|
|
|
+int _callback_set(int rank, starpu_data_handle_t *data_handles, int count, int root, void (*scallback)(void *), void *sarg, void (*rcallback)(void *), void *rarg, void (**callback_func)(void *), struct _callback_arg **callback_arg)
|
|
|
{
|
|
|
- int rank;
|
|
|
- int x;
|
|
|
- struct _callback_arg *callback_arg = NULL;
|
|
|
- void (*callback_func)(void *) = NULL;
|
|
|
void (*callback)(void *);
|
|
|
|
|
|
- starpu_mpi_comm_rank(comm, &rank);
|
|
|
-
|
|
|
callback = (rank == root) ? scallback : rcallback;
|
|
|
- if (callback)
|
|
|
+ if (*callback)
|
|
|
{
|
|
|
- callback_func = _callback_collective;
|
|
|
- _STARPU_MPI_MALLOC(callback_arg, sizeof(struct _callback_arg));
|
|
|
- callback_arg->count = 0;
|
|
|
- callback_arg->nb = 0;
|
|
|
- callback_arg->callback = (rank == root) ? scallback : rcallback;
|
|
|
- callback_arg->arg = (rank == root) ? sarg : rarg;
|
|
|
+ int x;
|
|
|
+
|
|
|
+ *callback_func = _callback_collective;
|
|
|
+
|
|
|
+ _STARPU_MPI_MALLOC(*callback_arg, sizeof(struct _callback_arg));
|
|
|
+ (*callback_arg)->count = 0;
|
|
|
+ (*callback_arg)->nb = 0;
|
|
|
+ (*callback_arg)->callback = (rank == root) ? scallback : rcallback;
|
|
|
+ (*callback_arg)->arg = (rank == root) ? sarg : rarg;
|
|
|
|
|
|
for(x = 0; x < count ; x++)
|
|
|
{
|
|
@@ -68,22 +66,38 @@ int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, i
|
|
|
STARPU_ASSERT_MSG(data_tag >= 0, "Invalid tag for data handle");
|
|
|
if ((rank == root) && (owner != root))
|
|
|
{
|
|
|
- callback_arg->count ++;
|
|
|
+ (*callback_arg)->count ++;
|
|
|
}
|
|
|
if ((rank != root) && (owner == rank))
|
|
|
{
|
|
|
- callback_arg->count ++;
|
|
|
+ (*callback_arg)->count ++;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if (!callback_arg->count)
|
|
|
+ if (!(*callback_arg)->count)
|
|
|
{
|
|
|
- free(callback_arg);
|
|
|
- return 0;
|
|
|
+ free(*callback_arg);
|
|
|
+ return 1;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, int root, MPI_Comm comm, void (*scallback)(void *), void *sarg, void (*rcallback)(void *), void *rarg)
|
|
|
+{
|
|
|
+ int rank;
|
|
|
+ int x;
|
|
|
+ struct _callback_arg *callback_arg = NULL;
|
|
|
+ void (*callback_func)(void *) = NULL;
|
|
|
+
|
|
|
+ starpu_mpi_comm_rank(comm, &rank);
|
|
|
+
|
|
|
+ x = _callback_set(rank, data_handles, count, root, scallback, sarg, rcallback, rarg, &callback_func, &callback_arg);
|
|
|
+ if (x == 1)
|
|
|
+ return 0;
|
|
|
+
|
|
|
for(x = 0; x < count ; x++)
|
|
|
{
|
|
|
if (data_handles[x])
|
|
@@ -112,45 +126,12 @@ int starpu_mpi_gather_detached(starpu_data_handle_t *data_handles, int count, in
|
|
|
int x;
|
|
|
struct _callback_arg *callback_arg = NULL;
|
|
|
void (*callback_func)(void *) = NULL;
|
|
|
- void (*callback)(void *);
|
|
|
|
|
|
starpu_mpi_comm_rank(comm, &rank);
|
|
|
|
|
|
- callback = (rank == root) ? scallback : rcallback;
|
|
|
- if (callback)
|
|
|
- {
|
|
|
- callback_func = _callback_collective;
|
|
|
-
|
|
|
- _STARPU_MPI_MALLOC(callback_arg, sizeof(struct _callback_arg));
|
|
|
- callback_arg->count = 0;
|
|
|
- callback_arg->nb = 0;
|
|
|
- callback_arg->callback = callback;
|
|
|
- callback_arg->arg = (rank == root) ? sarg : rarg;
|
|
|
-
|
|
|
- for(x = 0; x < count ; x++)
|
|
|
- {
|
|
|
- if (data_handles[x])
|
|
|
- {
|
|
|
- int owner = starpu_mpi_data_get_rank(data_handles[x]);
|
|
|
- int data_tag = starpu_mpi_data_get_tag(data_handles[x]);
|
|
|
- STARPU_ASSERT_MSG(data_tag >= 0, "Invalid tag for data handle");
|
|
|
- if ((rank == root) && (owner != root))
|
|
|
- {
|
|
|
- callback_arg->count ++;
|
|
|
- }
|
|
|
- if ((rank != root) && (owner == rank))
|
|
|
- {
|
|
|
- callback_arg->count ++;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if (!callback_arg->count)
|
|
|
- {
|
|
|
- free(callback_arg);
|
|
|
- return 0;
|
|
|
- }
|
|
|
- }
|
|
|
+ x = _callback_set(rank, data_handles, count, root, scallback, sarg, rcallback, rarg, &callback_func, &callback_arg);
|
|
|
+ if (x == 1)
|
|
|
+ return 0;
|
|
|
|
|
|
for(x = 0; x < count ; x++)
|
|
|
{
|