|
@@ -16,7 +16,17 @@
|
|
|
|
|
|
#include <starpu_mpi_datatype.h>
|
|
|
|
|
|
+/*
|
|
|
+ * MPI_* functions usually requires both a pointer to the first element of
|
|
|
+ * a datatype and the datatype itself, so we need to provide both.
|
|
|
+ */
|
|
|
+
|
|
|
typedef int (*handle_to_datatype_func)(starpu_data_handle, MPI_Datatype *);
|
|
|
+typedef void *(*handle_to_ptr_func)(starpu_data_handle);
|
|
|
+
|
|
|
+/*
|
|
|
+ * Vector
|
|
|
+ */
|
|
|
|
|
|
static int handle_to_datatype_vector(starpu_data_handle data_handle, MPI_Datatype *datatype)
|
|
|
{
|
|
@@ -29,6 +39,15 @@ static int handle_to_datatype_vector(starpu_data_handle data_handle, MPI_Datatyp
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
+static void *handle_to_ptr_vector(starpu_data_handle data_handle)
|
|
|
+{
|
|
|
+ return (void *)starpu_get_vector_local_ptr(data_handle);
|
|
|
+}
|
|
|
+
|
|
|
+/*
|
|
|
+ * Generic
|
|
|
+ */
|
|
|
+
|
|
|
static handle_to_datatype_func handle_to_datatype_funcs[STARPU_NINTERFACES_ID] = {
|
|
|
[STARPU_BLAS_INTERFACE_ID] = NULL,
|
|
|
[STARPU_BLOCK_INTERFACE_ID] = NULL,
|
|
@@ -38,6 +57,15 @@ static handle_to_datatype_func handle_to_datatype_funcs[STARPU_NINTERFACES_ID] =
|
|
|
[STARPU_BCSCR_INTERFACE_ID] = NULL
|
|
|
};
|
|
|
|
|
|
+static handle_to_ptr_func handle_to_ptr_funcs[STARPU_NINTERFACES_ID] = {
|
|
|
+ [STARPU_BLAS_INTERFACE_ID] = NULL,
|
|
|
+ [STARPU_BLOCK_INTERFACE_ID] = NULL,
|
|
|
+ [STARPU_VECTOR_INTERFACE_ID] = handle_to_ptr_vector,
|
|
|
+ [STARPU_CSR_INTERFACE_ID] = NULL,
|
|
|
+ [STARPU_CSC_INTERFACE_ID] = NULL,
|
|
|
+ [STARPU_BCSCR_INTERFACE_ID] = NULL
|
|
|
+};
|
|
|
+
|
|
|
int starpu_mpi_handle_to_datatype(starpu_data_handle data_handle, MPI_Datatype *datatype)
|
|
|
{
|
|
|
unsigned id = starpu_get_handle_interface_id(data_handle);
|
|
@@ -48,3 +76,14 @@ int starpu_mpi_handle_to_datatype(starpu_data_handle data_handle, MPI_Datatype *
|
|
|
|
|
|
return func(data_handle, datatype);
|
|
|
}
|
|
|
+
|
|
|
+void *starpu_mpi_handle_to_ptr(starpu_data_handle data_handle)
|
|
|
+{
|
|
|
+ unsigned id = starpu_get_handle_interface_id(data_handle);
|
|
|
+
|
|
|
+ handle_to_ptr_func func = handle_to_ptr_funcs[id];
|
|
|
+
|
|
|
+ STARPU_ASSERT(func);
|
|
|
+
|
|
|
+ return func(data_handle);
|
|
|
+}
|