Browse Source

mpi: Add the capability to define specific MPI datatypes for StarPU user-defined interfaces.

Nathalie Furmento 10 years ago
parent
commit
a26be674fb

+ 2 - 0
ChangeLog

@@ -101,6 +101,8 @@ New features:
     which worker to schedule
   * Add data access arbiters, to improve parallelism of concurrent data
     accesses, notably with STARPU_COMMUTE.
+  * Add the capability to define specific MPI datatypes for
+    StarPU user-defined interfaces.
 
 Small features:
   * Tasks can now have a name (via the field const char *name of

+ 8 - 0
doc/doxygen/chapters/api/mpi.doxy

@@ -192,6 +192,14 @@ todo
 \ingroup API_MPI_Support
 todo
 
+\fn int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
+\ingroup API_MPI_Support
+Register functions to create and free a MPI datatype for the given handle.
+
+\fn int starpu_mpi_datatype_unregister(starpu_data_handle_t handle);
+\ingroup API_MPI_Support
+Unregister the MPI datatype functions stored for the interface of the given handle.
+
 @name Communication Cache
 \ingroup API_MPI_Support
 

+ 5 - 0
mpi/include/starpu_mpi.h

@@ -108,6 +108,11 @@ int starpu_mpi_node_selection_set_current_policy(int policy);
 int starpu_mpi_cache_is_enabled();
 int starpu_mpi_cache_set(int enabled);
 
+typedef void (*starpu_mpi_datatype_allocate_func_t)(starpu_data_handle_t, MPI_Datatype *);
+typedef void (*starpu_mpi_datatype_free_func_t)(MPI_Datatype *);
+int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func);
+int starpu_mpi_datatype_unregister(starpu_data_handle_t handle);
+
 #ifdef __cplusplus
 }
 #endif

+ 23 - 22
mpi/src/starpu_mpi.c

@@ -94,7 +94,7 @@ static void _starpu_mpi_request_init(struct _starpu_mpi_req **req)
 	(*req)->datatype = 0;
 	(*req)->ptr = NULL;
 	(*req)->count = -1;
-	(*req)->user_datatype = -1;
+	(*req)->registered_datatype = -1;
 
 	(*req)->node_tag.rank = -1;
 	(*req)->node_tag.data_tag = -1;
@@ -167,8 +167,8 @@ static void _starpu_mpi_submit_ready_request(void *arg)
 		 * before the next submission of the envelope-catching request. */
 		if (req->is_internal_req)
 		{
-			_starpu_mpi_handle_allocate_datatype(req->data_handle, &req->datatype, &req->user_datatype);
-			if (req->user_datatype == 0)
+			_starpu_mpi_handle_allocate_datatype(req->data_handle, &req->datatype, &req->registered_datatype);
+			if (req->registered_datatype == 1)
 			{
 				req->count = 1;
 				req->ptr = starpu_data_get_local_ptr(req->data_handle);
@@ -180,9 +180,9 @@ static void _starpu_mpi_submit_ready_request(void *arg)
 				STARPU_MPI_ASSERT_MSG(req->ptr, "cannot allocate message of size %ld\n", req->count);
 			}
 
-			_STARPU_MPI_DEBUG(3, "Pushing internal starpu_mpi_irecv request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d user_datatype %d \n",
+			_STARPU_MPI_DEBUG(3, "Pushing internal starpu_mpi_irecv request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d registered_datatype %d \n",
 					  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr,
-					  _starpu_mpi_datatype(req->datatype), (int)req->count, req->user_datatype);
+					  _starpu_mpi_datatype(req->datatype), (int)req->count, req->registered_datatype);
 			_starpu_mpi_req_list_push_front(ready_requests, req);
 
 			/* inform the starpu mpi thread that the request has been pushed in the ready_requests list */
@@ -233,8 +233,8 @@ static void _starpu_mpi_submit_ready_request(void *arg)
 				if (sync_req)
 				{
 					req->sync = 1;
-					_starpu_mpi_handle_allocate_datatype(req->data_handle, &req->datatype, &req->user_datatype);
-					if (req->user_datatype == 0)
+					_starpu_mpi_handle_allocate_datatype(req->data_handle, &req->datatype, &req->registered_datatype);
+					if (req->registered_datatype == 1)
 					{
 						req->count = 1;
 						req->ptr = starpu_data_get_local_ptr(req->data_handle);
@@ -260,8 +260,9 @@ static void _starpu_mpi_submit_ready_request(void *arg)
 	else
 	{
 		_starpu_mpi_req_list_push_front(ready_requests, req);
-		_STARPU_MPI_DEBUG(3, "Pushing new request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d user_datatype %d \n",
-				  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->user_datatype);
+		_STARPU_MPI_DEBUG(3, "Pushing new request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d registered_datatype %d \n",
+				  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr,
+				  _starpu_mpi_datatype(req->datatype), (int)req->count, req->registered_datatype);
 	}
 
 	newer_requests = 1;
@@ -321,7 +322,7 @@ static void _starpu_mpi_isend_data_func(struct _starpu_mpi_req *req)
 {
 	_STARPU_MPI_LOG_IN();
 
-	_STARPU_MPI_DEBUG(20, "post MPI isend request %p type %s tag %d src %d data %p datasize %ld ptr %p datatype '%s' count %d user_datatype %d sync %d\n", req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, starpu_data_get_size(req->data_handle), req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->user_datatype, req->sync);
+	_STARPU_MPI_DEBUG(20, "post MPI isend request %p type %s tag %d src %d data %p datasize %ld ptr %p datatype '%s' count %d registered_datatype %d sync %d\n", req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, starpu_data_get_size(req->data_handle), req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->registered_datatype, req->sync);
 
 	_starpu_mpi_comm_amounts_inc(req->node_tag.comm, req->node_tag.rank, req->datatype, req->count);
 
@@ -355,14 +356,14 @@ static void _starpu_mpi_isend_data_func(struct _starpu_mpi_req *req)
 
 static void _starpu_mpi_isend_size_func(struct _starpu_mpi_req *req)
 {
-	_starpu_mpi_handle_allocate_datatype(req->data_handle, &req->datatype, &req->user_datatype);
+	_starpu_mpi_handle_allocate_datatype(req->data_handle, &req->datatype, &req->registered_datatype);
 
 	req->envelope = calloc(1,sizeof(struct _starpu_mpi_envelope));
 	req->envelope->mode = _STARPU_MPI_ENVELOPE_DATA;
 	req->envelope->data_tag = req->node_tag.data_tag;
 	req->envelope->sync = req->sync;
 
-	if (req->user_datatype == 0)
+	if (req->registered_datatype == 1)
 	{
 		int size;
 		req->count = 1;
@@ -505,7 +506,7 @@ static void _starpu_mpi_irecv_data_func(struct _starpu_mpi_req *req)
 {
 	_STARPU_MPI_LOG_IN();
 
-	_STARPU_MPI_DEBUG(20, "post MPI irecv request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d user_datatype %d \n", req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->user_datatype);
+	_STARPU_MPI_DEBUG(20, "post MPI irecv request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d registered_datatype %d \n", req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->registered_datatype);
 
 	_STARPU_MPI_TRACE_IRECV_SUBMIT_BEGIN(req->node_tag.rank, req->node_tag.data_tag);
 
@@ -707,8 +708,8 @@ static void _starpu_mpi_test_func(struct _starpu_mpi_req *testing_req)
 	/* Which is the mpi request we are testing for ? */
 	struct _starpu_mpi_req *req = testing_req->other_request;
 
-	_STARPU_MPI_DEBUG(2, "Test request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d user_datatype %d \n",
-			  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->user_datatype);
+	_STARPU_MPI_DEBUG(2, "Test request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d registered_datatype %d \n",
+			  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->registered_datatype);
 
 	_STARPU_MPI_TRACE_UTESTING_BEGIN(req->node_tag.rank, req->node_tag.data_tag);
 
@@ -893,9 +894,9 @@ static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req)
 {
 	_STARPU_MPI_LOG_IN();
 
-	_STARPU_MPI_DEBUG(2, "complete MPI request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d user_datatype %d internal_req %p\n",
+	_STARPU_MPI_DEBUG(2, "complete MPI request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d registered_datatype %d internal_req %p\n",
 			  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr,
-			  _starpu_mpi_datatype(req->datatype), (int)req->count, req->user_datatype, req->internal_req);
+			  _starpu_mpi_datatype(req->datatype), (int)req->count, req->registered_datatype, req->internal_req);
 
 	if (req->internal_req)
 	{
@@ -909,7 +910,7 @@ static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req)
 	{
 		if (req->request_type == RECV_REQ || req->request_type == SEND_REQ)
 		{
-			if (req->user_datatype == 1)
+			if (req->registered_datatype == 0)
 			{
 				if (req->request_type == SEND_REQ)
 				{
@@ -1118,8 +1119,8 @@ static void _starpu_mpi_handle_ready_request(struct _starpu_mpi_req *req)
 	STARPU_MPI_ASSERT_MSG(req, "Invalid request");
 
 	/* submit the request to MPI */
-	_STARPU_MPI_DEBUG(2, "Handling new request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d user_datatype %d \n",
-			  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->user_datatype);
+	_STARPU_MPI_DEBUG(2, "Handling new request %p type %s tag %d src %d data %p ptr %p datatype '%s' count %d registered_datatype %d \n",
+			  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.rank, req->data_handle, req->ptr, _starpu_mpi_datatype(req->datatype), (int)req->count, req->registered_datatype);
 	req->func(req);
 
 	_STARPU_MPI_LOG_OUT();
@@ -1401,8 +1402,8 @@ static void *_starpu_mpi_progress_thread_func(void *arg)
 						_STARPU_MPI_DEBUG(2000, "Request sync %d\n", envelope->sync);
 
 						early_request->sync = envelope->sync;
-						_starpu_mpi_handle_allocate_datatype(early_request->data_handle, &early_request->datatype, &early_request->user_datatype);
-						if (early_request->user_datatype == 0)
+						_starpu_mpi_handle_allocate_datatype(early_request->data_handle, &early_request->datatype, &early_request->registered_datatype);
+						if (early_request->registered_datatype == 1)
 						{
 							early_request->count = 1;
 							early_request->ptr = starpu_data_get_local_ptr(early_request->data_handle);

+ 115 - 44
mpi/src/starpu_mpi_datatype.c

@@ -1,7 +1,7 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
  * Copyright (C) 2009-2011, 2015  Université de Bordeaux
- * Copyright (C) 2010, 2011, 2012, 2013, 2014  CNRS
+ * Copyright (C) 2010, 2011, 2012, 2013, 2014, 2015  CNRS
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -16,9 +16,18 @@
  */
 
 #include <starpu_mpi_datatype.h>
+#include <common/uthash.h>
+#include <datawizard/coherency.h>
 
-typedef void (*handle_to_datatype_func)(starpu_data_handle_t, MPI_Datatype *);
-typedef void (*handle_free_datatype_func)(MPI_Datatype *);
+struct _starpu_mpi_datatype_funcs
+{
+	enum starpu_data_interface_id id;
+	starpu_mpi_datatype_allocate_func_t allocate_datatype_func;
+	starpu_mpi_datatype_free_func_t free_datatype_func;
+	UT_hash_handle hh;
+};
+
+static struct _starpu_mpi_datatype_funcs *_starpu_mpi_datatype_funcs_table = NULL;
 
 /*
  * 	Matrix
@@ -123,7 +132,7 @@ static void handle_to_datatype_void(starpu_data_handle_t data_handle STARPU_ATTR
  *	Generic
  */
 
-static handle_to_datatype_func handle_to_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
+static starpu_mpi_datatype_allocate_func_t handle_to_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
 {
 	[STARPU_MATRIX_INTERFACE_ID]	= handle_to_datatype_matrix,
 	[STARPU_BLOCK_INTERFACE_ID]	= handle_to_datatype_block,
@@ -135,22 +144,33 @@ static handle_to_datatype_func handle_to_datatype_funcs[STARPU_MAX_INTERFACE_ID]
 	[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,
 };
 
-void _starpu_mpi_handle_allocate_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype, int *user_datatype)
+void _starpu_mpi_handle_allocate_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype, int *registered_datatype)
 {
 	enum starpu_data_interface_id id = starpu_data_get_interface_id(data_handle);
 
 	if (id < STARPU_MAX_INTERFACE_ID)
 	{
-		handle_to_datatype_func func = handle_to_datatype_funcs[id];
+		starpu_mpi_datatype_allocate_func_t func = handle_to_datatype_funcs[id];
 		STARPU_ASSERT_MSG(func, "Handle To Datatype Function not defined for StarPU data interface %d", id);
 		func(data_handle, datatype);
-		*user_datatype = 0;
+		*registered_datatype = 1;
 	}
 	else
 	{
-		/* The datatype is not predefined by StarPU */
-		*datatype = MPI_BYTE;
-		*user_datatype = 1;
+		struct _starpu_mpi_datatype_funcs *table;
+		HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
+		if (table)
+		{
+			STARPU_ASSERT_MSG(table->allocate_datatype_func, "Handle To Datatype Function not defined for StarPU data interface %d", id);
+			table->allocate_datatype_func(data_handle, datatype);
+			*registered_datatype = 1;
+		}
+		else
+		{
+			/* The datatype is not predefined by StarPU */
+			*datatype = MPI_BYTE;
+			*registered_datatype = 0;
+		}
 	}
 }
 
@@ -184,7 +204,7 @@ static void _starpu_mpi_handle_free_complex_datatype(MPI_Datatype *datatype)
 	}
 }
 
-static handle_free_datatype_func handle_free_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
+static starpu_mpi_datatype_free_func_t handle_free_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
 {
 	[STARPU_MATRIX_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
 	[STARPU_BLOCK_INTERFACE_ID]	= _starpu_mpi_handle_free_complex_datatype,
@@ -202,45 +222,96 @@ void _starpu_mpi_handle_free_datatype(starpu_data_handle_t data_handle, MPI_Data
 
 	if (id < STARPU_MAX_INTERFACE_ID)
 	{
-		handle_free_datatype_func func = handle_free_datatype_funcs[id];
+		starpu_mpi_datatype_free_func_t func = handle_free_datatype_funcs[id];
 		STARPU_ASSERT_MSG(func, "Handle free datatype function not defined for StarPU data interface %d", id);
 		func(datatype);
 	}
+	else
+	{
+		struct _starpu_mpi_datatype_funcs *table;
+		HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
+		if (table)
+		{
+			STARPU_ASSERT_MSG(table->free_datatype_func, "Free Datatype Function not defined for StarPU data interface %d", id);
+			table->free_datatype_func(datatype);
+		}
+
+	}
 	/* else the datatype is not predefined by StarPU */
 }
 
 char *_starpu_mpi_datatype(MPI_Datatype datatype)
 {
-     if (datatype == MPI_DATATYPE_NULL) return "MPI_DATATYPE_NULL";
-     if (datatype == MPI_CHAR) return "MPI_CHAR";
-     if (datatype == MPI_UNSIGNED_CHAR) return "MPI_UNSIGNED_CHAR";
-     if (datatype == MPI_BYTE) return "MPI_BYTE";
-     if (datatype == MPI_SHORT) return "MPI_SHORT";
-     if (datatype == MPI_UNSIGNED_SHORT) return "MPI_UNSIGNED_SHORT";
-     if (datatype == MPI_INT) return "MPI_INT";
-     if (datatype == MPI_UNSIGNED) return "MPI_UNSIGNED";
-     if (datatype == MPI_LONG) return "MPI_LONG";
-     if (datatype == MPI_UNSIGNED_LONG) return "MPI_UNSIGNED_LONG";
-     if (datatype == MPI_FLOAT) return "MPI_FLOAT";
-     if (datatype == MPI_DOUBLE) return "MPI_DOUBLE";
-     if (datatype == MPI_LONG_DOUBLE) return "MPI_LONG_DOUBLE";
-     if (datatype == MPI_LONG_LONG) return "MPI_LONG_LONG";
-     if (datatype == MPI_LONG_INT) return "MPI_LONG_INT";
-     if (datatype == MPI_SHORT_INT) return "MPI_SHORT_INT";
-     if (datatype == MPI_FLOAT_INT) return "MPI_FLOAT_INT";
-     if (datatype == MPI_DOUBLE_INT) return "MPI_DOUBLE_INT";
-     if (datatype == MPI_2INT) return "MPI_2INT";
-     if (datatype == MPI_2DOUBLE_PRECISION) return "MPI_2DOUBLE_PRECISION";
-     if (datatype == MPI_COMPLEX) return "MPI_COMPLEX";
-     if (datatype == MPI_DOUBLE_COMPLEX) return "MPI_DOUBLE_COMPLEX";
-     if (datatype == MPI_LOGICAL) return "MPI_LOGICAL";
-     if (datatype == MPI_REAL) return "MPI_REAL";
-     if (datatype == MPI_REAL4) return "MPI_REAL4";
-     if (datatype == MPI_REAL8) return "MPI_REAL8";
-     if (datatype == MPI_DOUBLE_PRECISION) return "MPI_DOUBLE_PRECISION";
-     if (datatype == MPI_INTEGER) return "MPI_INTEGER";
-     if (datatype == MPI_INTEGER4) return "MPI_INTEGER4";
-     if (datatype == MPI_PACKED) return "MPI_PACKED";
-     if (datatype == 0) return "Unknown datatype";
-     return "User defined MPI Datatype";
+	if (datatype == MPI_DATATYPE_NULL) return "MPI_DATATYPE_NULL";
+	if (datatype == MPI_CHAR) return "MPI_CHAR";
+	if (datatype == MPI_UNSIGNED_CHAR) return "MPI_UNSIGNED_CHAR";
+	if (datatype == MPI_BYTE) return "MPI_BYTE";
+	if (datatype == MPI_SHORT) return "MPI_SHORT";
+	if (datatype == MPI_UNSIGNED_SHORT) return "MPI_UNSIGNED_SHORT";
+	if (datatype == MPI_INT) return "MPI_INT";
+	if (datatype == MPI_UNSIGNED) return "MPI_UNSIGNED";
+	if (datatype == MPI_LONG) return "MPI_LONG";
+	if (datatype == MPI_UNSIGNED_LONG) return "MPI_UNSIGNED_LONG";
+	if (datatype == MPI_FLOAT) return "MPI_FLOAT";
+	if (datatype == MPI_DOUBLE) return "MPI_DOUBLE";
+	if (datatype == MPI_LONG_DOUBLE) return "MPI_LONG_DOUBLE";
+	if (datatype == MPI_LONG_LONG) return "MPI_LONG_LONG";
+	if (datatype == MPI_LONG_INT) return "MPI_LONG_INT";
+	if (datatype == MPI_SHORT_INT) return "MPI_SHORT_INT";
+	if (datatype == MPI_FLOAT_INT) return "MPI_FLOAT_INT";
+	if (datatype == MPI_DOUBLE_INT) return "MPI_DOUBLE_INT";
+	if (datatype == MPI_2INT) return "MPI_2INT";
+	if (datatype == MPI_2DOUBLE_PRECISION) return "MPI_2DOUBLE_PRECISION";
+	if (datatype == MPI_COMPLEX) return "MPI_COMPLEX";
+	if (datatype == MPI_DOUBLE_COMPLEX) return "MPI_DOUBLE_COMPLEX";
+	if (datatype == MPI_LOGICAL) return "MPI_LOGICAL";
+	if (datatype == MPI_REAL) return "MPI_REAL";
+	if (datatype == MPI_REAL4) return "MPI_REAL4";
+	if (datatype == MPI_REAL8) return "MPI_REAL8";
+	if (datatype == MPI_DOUBLE_PRECISION) return "MPI_DOUBLE_PRECISION";
+	if (datatype == MPI_INTEGER) return "MPI_INTEGER";
+	if (datatype == MPI_INTEGER4) return "MPI_INTEGER4";
+	if (datatype == MPI_PACKED) return "MPI_PACKED";
+	if (datatype == 0) return "Unknown datatype";
+	return "User defined MPI Datatype";
+}
+
+int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func)
+{
+	enum starpu_data_interface_id id = starpu_data_get_interface_id(handle);
+	struct _starpu_mpi_datatype_funcs *table;
+
+	STARPU_ASSERT_MSG(id >= STARPU_MAX_INTERFACE_ID, "Cannot redefine the MPI datatype for a predefined StarPU datatype");
+
+	HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
+	if (table)
+	{
+		table->allocate_datatype_func = allocate_datatype_func;
+		table->free_datatype_func = free_datatype_func;
+	}
+	else
+	{
+		table = malloc(sizeof(struct _starpu_mpi_datatype_funcs));
+		table->id = id;
+		table->allocate_datatype_func = allocate_datatype_func;
+		table->free_datatype_func = free_datatype_func;
+		HASH_ADD_INT(_starpu_mpi_datatype_funcs_table, id, table);
+	}
+	STARPU_ASSERT_MSG(handle->ops->handle_to_pointer, "The data interface must define the operation 'handle_to_pointer'\n");
+	return 0;
+}
+
+int starpu_mpi_datatype_unregister(starpu_data_handle_t handle)
+{
+	enum starpu_data_interface_id id = starpu_data_get_interface_id(handle);
+	struct _starpu_mpi_datatype_funcs *table;
+
+	STARPU_ASSERT_MSG(id >= STARPU_MAX_INTERFACE_ID, "Cannot redefine the MPI datatype for a predefined StarPU datatype");
+
+	HASH_FIND_INT(_starpu_mpi_datatype_funcs_table, &id, table);
+	if (table)
+	{
+		HASH_DEL(_starpu_mpi_datatype_funcs_table, table);
+		free(table);
+	}
 }

+ 2 - 2
mpi/src/starpu_mpi_datatype.h

@@ -1,7 +1,7 @@
 /* StarPU --- Runtime system for heterogeneous multicore architectures.
  *
  * Copyright (C) 2009-2011  Université de Bordeaux
- * Copyright (C) 2010, 2012, 2013  CNRS
+ * Copyright (C) 2010, 2012, 2013, 2015  CNRS
  *
  * StarPU is free software; you can redistribute it and/or modify
  * it under the terms of the GNU Lesser General Public License as published by
@@ -24,7 +24,7 @@
 extern "C" {
 #endif
 
-void _starpu_mpi_handle_allocate_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype, int *user_datatype);
+void _starpu_mpi_handle_allocate_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype, int *registered_datatype);
 void _starpu_mpi_handle_free_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype);
 char *_starpu_mpi_datatype(MPI_Datatype datatype);
 

+ 1 - 1
mpi/src/starpu_mpi_private.h

@@ -166,7 +166,7 @@ LIST_TYPE(_starpu_mpi_req,
 	MPI_Datatype datatype;
 	void *ptr;
 	starpu_ssize_t count;
-	int user_datatype;
+	int registered_datatype;
 
 	/* who are we talking to ? */
 	struct _starpu_mpi_node_tag node_tag;