Kaynağa Gözat

Move mpi_data initialization to separate _starpu_mpi_data_get

Samuel Thibault 7 yıl önce
ebeveyn
işleme
ec1e3d1133
2 değiştirilmiş dosya ile 16 ekleme ve 8 silme
  1. 13 8
      mpi/src/starpu_mpi.c
  2. 3 0
      mpi/src/starpu_mpi_private.h

+ 13 - 8
mpi/src/starpu_mpi.c

@@ -251,12 +251,11 @@ void _starpu_mpi_data_clear(starpu_data_handle_t data_handle)
 	free(data_handle->mpi_data);
 }
 
-void starpu_mpi_data_register_comm(starpu_data_handle_t data_handle, starpu_mpi_tag_t data_tag, int rank, MPI_Comm comm)
-{
-	struct _starpu_mpi_data *mpi_data;
-	if (data_handle->mpi_data)
+struct _starpu_mpi_data *_starpu_mpi_data_get(starpu_data_handle_t data_handle) {
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
+	if (mpi_data)
 	{
-		mpi_data = data_handle->mpi_data;
+		STARPU_ASSERT(mpi_data->magic == 42);
 	}
 	else
 	{
@@ -266,15 +265,21 @@ void starpu_mpi_data_register_comm(starpu_data_handle_t data_handle, starpu_mpi_
 		mpi_data->node_tag.rank = -1;
 		mpi_data->node_tag.comm = MPI_COMM_WORLD;
 		data_handle->mpi_data = mpi_data;
-#if defined(STARPU_USE_MPI_MPI)
-		_starpu_mpi_tag_data_register(data_handle, data_tag);
-#endif
 		_starpu_mpi_cache_data_init(data_handle);
 		_starpu_data_set_unregister_hook(data_handle, _starpu_mpi_data_clear);
 	}
+	return mpi_data;
+}
+
+void starpu_mpi_data_register_comm(starpu_data_handle_t data_handle, starpu_mpi_tag_t data_tag, int rank, MPI_Comm comm)
+{
+	struct _starpu_mpi_data *mpi_data = _starpu_mpi_data_get(data_handle);
 
 	if (data_tag != -1)
 	{
+#if defined(STARPU_USE_MPI_MPI)
+		_starpu_mpi_tag_data_register(data_handle, data_tag);
+#endif
 		mpi_data->node_tag.data_tag = data_tag;
 	}
 	if (rank != -1)

+ 3 - 0
mpi/src/starpu_mpi_private.h

@@ -195,6 +195,7 @@ struct _starpu_mpi_node_tag
 	starpu_mpi_tag_t data_tag;
 };
 
+/* Initialized in starpu_mpi_data_register_comm */
 struct _starpu_mpi_data
 {
 	int magic;
@@ -203,6 +204,8 @@ struct _starpu_mpi_data
 	int cache_received;
 };
 
+struct _starpu_mpi_data *_starpu_mpi_data_get(starpu_data_handle_t data_handle);
+
 struct _starpu_mpi_req;
 LIST_TYPE(_starpu_mpi_req,
 	/* description of the data at StarPU level */