Przeglądaj źródła

mpi/src: functions to free datatype are now defined per type, it allows to make sure complex datatypes are freed properly

Nathalie Furmento 12 lat temu
rodzic
commit
4ea8619303

+ 3 - 1
mpi/src/starpu_mpi.c

@@ -582,7 +582,9 @@ static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req)
 				free(req->ptr);
 		}
 		else
-			MPI_Type_free(&req->datatype);
+		{
+			_starpu_mpi_handle_free_datatype(req->data_handle, &req->datatype);
+		}
 		starpu_data_release(req->data_handle);
 	}
 

+ 55 - 0
mpi/src/starpu_mpi_datatype.c

@@ -18,6 +18,7 @@
 #include <starpu_mpi_datatype.h>
 
 typedef int (*handle_to_datatype_func)(starpu_data_handle_t, MPI_Datatype *);
+typedef int (*handle_free_datatype_func)(MPI_Datatype *);
 
 /*
  * 	Matrix
@@ -145,3 +146,57 @@ void _starpu_mpi_handle_allocate_datatype(starpu_data_handle_t data_handle, MPI_
 		*user_datatype = 1;
 	}
 }
+
+static int _starpu_mpi_handle_free_simple_datatype(MPI_Datatype *datatype)
+{
+	MPI_Type_free(datatype);
+	return 0;
+}
+
+static int _starpu_mpi_handle_free_complex_datatype(MPI_Datatype *datatype)
+{
+	int num_ints, num_adds, num_datatypes, combiner, i;
+	int *array_of_ints;
+	MPI_Aint *array_of_adds;
+	MPI_Datatype *array_of_datatypes;
+
+	MPI_Type_get_envelope(*datatype, &num_ints, &num_adds, &num_datatypes, &combiner);
+	if (combiner != MPI_COMBINER_NAMED)
+	{
+		array_of_ints = (int *) malloc(num_ints * sizeof(int));
+		array_of_adds = (MPI_Aint *) malloc(num_ints * sizeof(MPI_Aint));
+		array_of_datatypes = (MPI_Datatype *) malloc(num_ints * sizeof(MPI_Datatype));
+		MPI_Type_get_contents(*datatype, num_ints, num_adds, num_datatypes, array_of_ints, array_of_adds, array_of_datatypes);
+		for(i=0 ; i<num_datatypes ; i++)
+		{
+			_starpu_mpi_handle_free_complex_datatype(&array_of_datatypes[i]);
+		}
+		MPI_Type_free(datatype);
+	}
+	return 0;
+}
+
+static handle_free_datatype_func handle_free_datatype_funcs[STARPU_MAX_INTERFACE_ID] =
+{
+	[STARPU_MATRIX_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
+	[STARPU_BLOCK_INTERFACE_ID]	= _starpu_mpi_handle_free_complex_datatype,
+	[STARPU_VECTOR_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
+	[STARPU_CSR_INTERFACE_ID]	= NULL,
+	[STARPU_BCSR_INTERFACE_ID]	= NULL,
+	[STARPU_VARIABLE_INTERFACE_ID]	= _starpu_mpi_handle_free_simple_datatype,
+	[STARPU_VOID_INTERFACE_ID]      = NULL,
+	[STARPU_MULTIFORMAT_INTERFACE_ID] = NULL,
+};
+
+void _starpu_mpi_handle_free_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype)
+{
+	enum starpu_data_interface_id id = starpu_handle_get_interface_id(data_handle);
+
+	if (id < STARPU_MAX_INTERFACE_ID)
+	{
+		handle_free_datatype_func func = handle_free_datatype_funcs[id];
+		STARPU_ASSERT(func);
+		func(datatype);
+	}
+	/* else the datatype is not predefined by StarPU */
+}

+ 1 - 0
mpi/src/starpu_mpi_datatype.h

@@ -25,6 +25,7 @@ extern "C" {
 #endif
 
 void _starpu_mpi_handle_allocate_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype, int *user_datatype);
+void _starpu_mpi_handle_free_datatype(starpu_data_handle_t data_handle, MPI_Datatype *datatype);
 
 #ifdef __cplusplus
 }