浏览代码

mpi/src/starpu_mpi_collective.c: deal with NULL callback functions

Nathalie Furmento 12 年之前
父节点
当前提交
ed0d832366
共有 1 个文件被更改,包括 48 次插入26 次删除
  1. 48 26
      mpi/src/starpu_mpi_collective.c

+ 48 - 26
mpi/src/starpu_mpi_collective.c

@@ -42,32 +42,43 @@ int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, i
 	int rank;
 	int x;
 	struct _callback_arg *callback_arg;
+	void (*callback_func)(void *);
 
 	MPI_Comm_rank(comm, &rank);
 
 #ifdef STARPU_DEVEL
 #warning callback_arg needs to be free-ed
 #endif
+	callback_func = _callback_collective;
 	callback_arg = malloc(sizeof(struct _callback_arg));
 	callback_arg->count = 0;
 	callback_arg->nb = 0;
 	callback_arg->callback = (rank == root) ? scallback : rcallback;
 	callback_arg->arg = (rank == root) ? sarg : rarg;
+	if (callback_arg->callback == NULL)
+	{
+		free(callback_arg);
+		callback_arg = NULL;
+		callback_func = NULL;
+	}
 
-	for(x = 0; x < count ;  x++)
+	if (callback_arg)
 	{
-		if (data_handles[x])
+		for(x = 0; x < count ;  x++)
 		{
-			int owner = starpu_data_get_rank(data_handles[x]);
-			int mpi_tag = starpu_data_get_tag(data_handles[x]);
-			STARPU_ASSERT(mpi_tag >= 0);
-			if ((rank == root) && (owner != root))
-			{
-				callback_arg->count ++;
-			}
-			if ((rank != root) && (owner == rank))
+			if (data_handles[x])
 			{
-				callback_arg->count ++;
+				int owner = starpu_data_get_rank(data_handles[x]);
+				int mpi_tag = starpu_data_get_tag(data_handles[x]);
+				STARPU_ASSERT(mpi_tag >= 0);
+				if ((rank == root) && (owner != root))
+				{
+					callback_arg->count ++;
+				}
+				if ((rank != root) && (owner == rank))
+				{
+					callback_arg->count ++;
+				}
 			}
 		}
 	}
@@ -82,12 +93,12 @@ int starpu_mpi_scatter_detached(starpu_data_handle_t *data_handles, int count, i
 			if ((rank == root) && (owner != root))
 			{
 				//fprintf(stderr, "[%d] Sending data[%d] to %d\n", rank, x, owner);
-				starpu_mpi_isend_detached(data_handles[x], owner, mpi_tag, comm, _callback_collective, callback_arg);
+				starpu_mpi_isend_detached(data_handles[x], owner, mpi_tag, comm, callback_func, callback_arg);
 			}
 			if ((rank != root) && (owner == rank))
 			{
 				//fprintf(stderr, "[%d] Receiving data[%d] from %d\n", rank, x, root);
-				starpu_mpi_irecv_detached(data_handles[x], root, mpi_tag, comm, _callback_collective, callback_arg);
+				starpu_mpi_irecv_detached(data_handles[x], root, mpi_tag, comm, callback_func, callback_arg);
 			}
 		}
 	}
@@ -99,32 +110,43 @@ int starpu_mpi_gather_detached(starpu_data_handle_t *data_handles, int count, in
 	int rank;
 	int x;
 	struct _callback_arg *callback_arg;
+	void (*callback_func)(void *);
 
 	MPI_Comm_rank(comm, &rank);
 
 #ifdef STARPU_DEVEL
 #warning callback_arg needs to be free-ed
 #endif
+	callback_func = _callback_collective;
 	callback_arg = malloc(sizeof(struct _callback_arg));
 	callback_arg->count = 0;
 	callback_arg->nb = 0;
 	callback_arg->callback = (rank == root) ? scallback : rcallback;
 	callback_arg->arg = (rank == root) ? sarg : rarg;
+	if (callback_arg->callback == NULL)
+	{
+		free(callback_arg);
+		callback_arg = NULL;
+		callback_func = NULL;
+	}
 
-	for(x = 0; x < count ;  x++)
+	if (callback_arg)
 	{
-		if (data_handles[x])
+		for(x = 0; x < count ;  x++)
 		{
-			int owner = starpu_data_get_rank(data_handles[x]);
-			int mpi_tag = starpu_data_get_tag(data_handles[x]);
-			STARPU_ASSERT(mpi_tag >= 0);
-			if ((rank == root) && (owner != root))
-			{
-				callback_arg->count ++;
-			}
-			if ((rank != root) && (owner == rank))
+			if (data_handles[x])
 			{
-				callback_arg->count ++;
+				int owner = starpu_data_get_rank(data_handles[x]);
+				int mpi_tag = starpu_data_get_tag(data_handles[x]);
+				STARPU_ASSERT(mpi_tag >= 0);
+				if ((rank == root) && (owner != root))
+				{
+					callback_arg->count ++;
+				}
+				if ((rank != root) && (owner == rank))
+				{
+					callback_arg->count ++;
+				}
 			}
 		}
 	}
@@ -139,12 +161,12 @@ int starpu_mpi_gather_detached(starpu_data_handle_t *data_handles, int count, in
 			if ((rank == root) && (owner != root))
 			{
 				//fprintf(stderr, "[%d] Receiving data[%d] from %d\n", rank, x, owner);
-				starpu_mpi_irecv_detached(data_handles[x], owner, mpi_tag, comm, _callback_collective, callback_arg);
+				starpu_mpi_irecv_detached(data_handles[x], owner, mpi_tag, comm, callback_func, callback_arg);
 			}
 			if ((rank != root) && (owner == rank))
 			{
 				//fprintf(stderr, "[%d] Sending data[%d] to %d\n", rank, x, root);
-				starpu_mpi_isend_detached(data_handles[x], root, mpi_tag, comm, _callback_collective, callback_arg);
+				starpu_mpi_isend_detached(data_handles[x], root, mpi_tag, comm, callback_func, callback_arg);
 			}
 		}
 	}