Przeglądaj źródła

mpi/src/starpu_mpi_collective.c: factorize duplicated code

Nathalie Furmento 8 lat temu
rodzic
commit
3711794767
1 zmienionych plików z 37 dodań i 56 usunięć
  1. 37 56
      mpi/src/starpu_mpi_collective.c

+ 37 - 56
mpi/src/starpu_mpi_collective.c

@@ -1,6 +1,6 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
- * Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016  CNRS
+ * Copyright (C) 2011, 2012, 2013, 2014, 2015, 2016, 2017  CNRS
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -39,25 +39,23 @@ void _callback_collective(void *arg)
 	}
 }
 
-int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, int root, MPI_Comm comm, void (*scallback)(void *), void *sarg, void (*rcallback)(void *), void *rarg)
+static
+int _callback_set(int rank, starpu_data_handle_t *data_handles, int count, int root, void (*scallback)(void *), void *sarg, void (*rcallback)(void *), void *rarg, void (**callback_func)(void *), struct _callback_arg **callback_arg)
 {
-	int rank;
-	int x;
-	struct _callback_arg *callback_arg = NULL;
-	void (*callback_func)(void *) = NULL;
 	void (*callback)(void *);
 
-	starpu_mpi_comm_rank(comm, &rank);
-
 	callback = (rank == root) ? scallback : rcallback;
-	if (callback)
+	if (*callback)
 	{
-		callback_func = _callback_collective;
-		_STARPU_MPI_MALLOC(callback_arg, sizeof(struct _callback_arg));
-		callback_arg->count = 0;
-		callback_arg->nb = 0;
-		callback_arg->callback = (rank == root) ? scallback : rcallback;
-		callback_arg->arg = (rank == root) ? sarg : rarg;
+		int x;
+
+		*callback_func = _callback_collective;
+
+		_STARPU_MPI_MALLOC(*callback_arg, sizeof(struct _callback_arg));
+		(*callback_arg)->count = 0;
+		(*callback_arg)->nb = 0;
+		(*callback_arg)->callback = (rank == root) ? scallback : rcallback;
+		(*callback_arg)->arg = (rank == root) ? sarg : rarg;
 
 		for(x = 0; x < count ; x++)
 		{
@@ -68,22 +66,38 @@ int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, i
 				STARPU_ASSERT_MSG(data_tag >= 0, "Invalid tag for data handle");
 				if ((rank == root) && (owner != root))
 				{
-					callback_arg->count ++;
+					(*callback_arg)->count ++;
 				}
 				if ((rank != root) && (owner == rank))
 				{
-					callback_arg->count ++;
+					(*callback_arg)->count ++;
 				}
 			}
 		}
 
-		if (!callback_arg->count)
+		if (!(*callback_arg)->count)
 		{
-			free(callback_arg);
-			return 0;
+			free(*callback_arg);
+			return 1;
 		}
 	}
 
+	return 0;
+}
+
+int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, int root, MPI_Comm comm, void (*scallback)(void *), void *sarg, void (*rcallback)(void *), void *rarg)
+{
+	int rank;
+	int x;
+	struct _callback_arg *callback_arg = NULL;
+	void (*callback_func)(void *) = NULL;
+
+	starpu_mpi_comm_rank(comm, &rank);
+
+	x = _callback_set(rank, data_handles, count, root, scallback, sarg, rcallback, rarg, &callback_func, &callback_arg);
+	if (x == 1)
+		return 0;
+
 	for(x = 0; x < count ; x++)
 	{
 		if (data_handles[x])
@@ -112,45 +126,12 @@ int starpu_mpi_gather_detached(starpu_data_handle_t *data_handles, int count, in
 	int x;
 	struct _callback_arg *callback_arg = NULL;
 	void (*callback_func)(void *) = NULL;
-	void (*callback)(void *);
 
 	starpu_mpi_comm_rank(comm, &rank);
 
-	callback = (rank == root) ? scallback : rcallback;
-	if (callback)
-	{
-		callback_func = _callback_collective;
-
-		_STARPU_MPI_MALLOC(callback_arg, sizeof(struct _callback_arg));
-		callback_arg->count = 0;
-		callback_arg->nb = 0;
-		callback_arg->callback = callback;
-		callback_arg->arg = (rank == root) ? sarg : rarg;
-
-		for(x = 0; x < count ; x++)
-		{
-			if (data_handles[x])
-			{
-				int owner = starpu_mpi_data_get_rank(data_handles[x]);
-				int data_tag = starpu_mpi_data_get_tag(data_handles[x]);
-				STARPU_ASSERT_MSG(data_tag >= 0, "Invalid tag for data handle");
-				if ((rank == root) && (owner != root))
-				{
-					callback_arg->count ++;
-				}
-				if ((rank != root) && (owner == rank))
-				{
-					callback_arg->count ++;
-				}
-			}
-		}
-
-		if (!callback_arg->count)
-		{
-			free(callback_arg);
-			return 0;
-		}
-	}
+	x = _callback_set(rank, data_handles, count, root, scallback, sarg, rcallback, rarg, &callback_func, &callback_arg);
+	if (x == 1)
+		return 0;
 
 	for(x = 0; x < count ; x++)
 	{