|
@@ -18,6 +18,7 @@
|
|
|
#include <starpu_mpi_datatype.h>
|
|
|
|
|
|
typedef int (*handle_to_datatype_func)(starpu_data_handle_t, MPI_Datatype *);
|
|
|
+typedef int (*handle_free_datatype_func)(MPI_Datatype *);
|
|
|
|
|
|
/*
|
|
|
* Matrix
|
|
@@ -145,3 +146,57 @@ void _starpu_mpi_handle_allocate_datatype(starpu_data_handle_t data_handle, MPI_
|
|
|
*user_datatype = 1;
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+static int _starpu_mpi_handle_free_simple_datatype(MPI_Datatype *datatype)
|
|
|
+{
|
|
|
+ MPI_Type_free(datatype);
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+static int _starpu_mpi_handle_free_complex_datatype(MPI_Datatype *datatype)
|
|
|
+{
|
|
|
+ int num_ints, num_adds, num_datatypes, combiner, i;
|
|
|
+ int *array_of_ints;
|
|
|
+ MPI_Aint *array_of_adds;
|
|
|
+ MPI_Datatype *array_of_datatypes;
|
|
|
+
|
|
|
+ MPI_Type_get_envelope(*datatype, &num_ints, &num_adds, &num_datatypes, &combiner);
|
|
|
+ if (combiner != MPI_COMBINER_NAMED)
|
|
|
+ {
|
|
|
+ array_of_ints = (int *) malloc(num_ints * sizeof(int));
|
|
|
+ array_of_adds = (MPI_Aint *) malloc(num_ints * sizeof(MPI_Aint));
|
|
|
+ array_of_datatypes = (MPI_Datatype *) malloc(num_ints * sizeof(MPI_Datatype));
|
|
|
+ MPI_Type_get_contents(*datatype, num_ints, num_adds, num_datatypes, array_of_ints, array_of_adds, array_of_datatypes);
|
|
|
+ for(i=0 ; i<num_datatypes ; i++)
|
|
|
+ {
|
|
|
+ _starpu_mpi_handle_free_complex_datatype(&array_of_datatypes[i]);
|
|
|
+ }
|
|
|
+ MPI_Type_free(datatype);
|
|
|
+ }
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+static handle_free_datatype_func handle_free_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
|
|
|
+{
|
|
|
+ [STARPU_MATRIX_INTERFACE_ID] = _starpu_mpi_handle_free_simple_datatype,
|
|
|
+ [STARPU_BLOCK_INTERFACE_ID] = _starpu_mpi_handle_free_complex_datatype,
|
|
|
+ [STARPU_VECTOR_INTERFACE_ID] = _starpu_mpi_handle_free_simple_datatype,
|
|
|
+ [STARPU_CSR_INTERFACE_ID] = NULL,
|
|
|
+ [STARPU_BCSR_INTERFACE_ID] = NULL,
|
|
|
+ [STARPU_VARIABLE_INTERFACE_ID] = _starpu_mpi_handle_free_simple_datatype,
|
|
|
+ [STARPU_VOID_INTERFACE_ID] = NULL,
|
|
|
+ [STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,
|
|
|
+};
|
|
|
+
|
|
|
+void _starpu_mpi_handle_free_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
|
|
|
+{
|
|
|
+ enum starpu_data_interface_id id = starpu_handle_get_interface_id(data_handle);
|
|
|
+
|
|
|
+ if (id < STARPU_MAX_INTERFACE_ID)
|
|
|
+ {
|
|
|
+ handle_free_datatype_func func = handle_free_datatype_funcs[id];
|
|
|
+ STARPU_ASSERT(func);
|
|
|
+ func(datatype);
|
|
|
+ }
|
|
|
+ /* else the datatype is not predefined by StarPU */
|
|
|
+}
|