瀏覽代碼

Fix starpu_mpi_comm_get_attr calling convention for STARPU_MPI_TAG_UB according to MPI standard

Samuel Thibault 7 年之前
父節點
當前提交
9ab3f077e2
共有 2 個文件被更改,包括 6 次插入4 次删除
  1. 2 1
      mpi/src/mpi/starpu_mpi_mpi.c
  2. 4 3
      mpi/src/nmad/starpu_mpi_nmad.c

+ 2 - 1
mpi/src/mpi/starpu_mpi_mpi.c

@@ -1527,6 +1527,7 @@ void _starpu_mpi_progress_shutdown(void **value)
         STARPU_PTHREAD_COND_DESTROY(&barrier_cond);
 }
 
+static int64_t _starpu_mpi_tag_max = INT64_MAX;
 
 int starpu_mpi_comm_get_attr(MPI_Comm comm, int keyval, void *attribute_val, int *flag)
 {
@@ -1534,7 +1535,7 @@ int starpu_mpi_comm_get_attr(MPI_Comm comm, int keyval, void *attribute_val, int
 	if (keyval == STARPU_MPI_TAG_UB)
 	{
 		*flag = 1;
-		*(int64_t *)attribute_val = INT64_MAX;
+		*(int64_t **)attribute_val = &_starpu_mpi_tag_max;
 	}
 	else
 	{

+ 4 - 3
mpi/src/nmad/starpu_mpi_nmad.c

@@ -671,16 +671,17 @@ void _starpu_mpi_progress_shutdown(void **value)
         STARPU_PTHREAD_COND_DESTROY(&progress_cond);
 }
 
+static int64_t _starpu_mpi_tag_max = INT64_MAX;
 
 int starpu_mpi_comm_get_attr(MPI_Comm comm, int keyval, void *attribute_val, int *flag)
 {
 	(void) comm;
 	if (keyval == STARPU_MPI_TAG_UB)
 	{
-		const int64_t starpu_tag_max = INT64_MAX;
-		const nm_tag_t nm_tag_max = NM_TAG_MAX;
+		if ((uint64_t) _starpu_mpi_tag_max > NM_TAG_MAX)
+			_starpu_mpi_tag_max = NM_TAG_MAX;
 		/* manage case where nmad max tag causes overflow if represented as starpu tag */
-		*(int64_t *)attribute_val = (nm_tag_max > starpu_tag_max) ? starpu_tag_max : nm_tag_max;
+		*(int64_t **)attribute_val = &_starpu_mpi_tag_max;
 		*flag = 1;
 	}
 	else