Browse Source

mpi: perform reduction only over contributing nodes

- added a long to describe a redux_map of contributing nodes
- added a test program to assess the behaviour

Fixes #3
Antoine JEGO 4 years ago
parent
commit
7a7abdc1c7

+ 27 - 1
mpi/examples/Makefile.am

@@ -375,7 +375,32 @@ starpu_mpi_EXAMPLES +=				\
 endif
 endif
 endif
- 						
+
+########################################
+# Native Fortran MPI STARPU_REDUX test #
+########################################
+
+if STARPU_HAVE_MPIFORT
+if !STARPU_SANITIZE
+examplebin_PROGRAMS +=		\
+	native_fortran/nf_redux_test
+
+native_fortran_nf_redux_test_SOURCES	=			\
+	native_fortran/fstarpu_mpi_mod.f90	\
+	native_fortran/fstarpu_mod.f90		\
+	native_fortran/nf_redux_test.f90	
+
+native_fortran_nf_redux_test_LDADD =					\
+	-lm
+
+if !STARPU_SIMGRID
+starpu_mpi_EXAMPLES +=				\
+	native_fortran/nf_redux_test
+endif
+endif
+endif
+
+
 ###################
 # complex example #
 ###################
@@ -467,6 +492,7 @@ native_fortran/nf_mm_cl.o: fstarpu_mod.mod
 native_fortran/nf_mm.o: nf_mm_cl.mod fstarpu_mpi_mod.mod fstarpu_mod.mod
 native_fortran/nf_mm_task_build.o: nf_mm_cl.mod fstarpu_mpi_mod.mod fstarpu_mod.mod
 native_fortran/nf_basic_ring.o: fstarpu_mpi_mod.mod fstarpu_mod.mod
+native_fortran/nf_redux_test.o: fstarpu_mpi_mod.mod fstarpu_mod.mod
 native_fortran/nf_mpi_redux.o: fstarpu_mpi_mod.mod fstarpu_mod.mod
 endif
 endif

+ 226 - 0
mpi/examples/native_fortran/nf_redux_test.f90

@@ -0,0 +1,226 @@
+program main
+  use iso_c_binding
+  use fstarpu_mod
+  use fstarpu_mpi_mod
+
+  implicit none
+  
+  integer, target                         :: ret, np, i, j
+  type(c_ptr)                             :: task_cl, task_rw_cl, task_red_cl, task_ini_cl
+  character(kind=c_char,len=*), parameter :: name=C_CHAR_"task"//C_NULL_CHAR
+  character(kind=c_char,len=*), parameter :: namered=C_CHAR_"task_red"//C_NULL_CHAR
+  character(kind=c_char,len=*), parameter :: nameini=C_CHAR_"task_ini"//C_NULL_CHAR
+  real(kind(1.d0)), target                :: a1, a2, b1, b2
+  integer(kind=8)                          :: tag, err
+  type(c_ptr)                             :: a1hdl, a2hdl, b1hdl, b2hdl
+  integer, target                         :: comm, comm_world, comm_w_rank, comm_size
+  integer(c_int), target                  :: w_node
+  
+  call fstarpu_fxt_autostart_profiling(0)
+  ret = fstarpu_init(c_null_ptr)
+  ret = fstarpu_mpi_init(1)
+
+  comm_world = fstarpu_mpi_world_comm()
+  comm_w_rank  = fstarpu_mpi_world_rank()
+  comm_size  = fstarpu_mpi_world_size()
+  if (comm_size.ne.4) then
+    write(*,'(" ")')
+    write(*,'("This application is meant to run with 4 MPI")')
+    stop 1
+  end if
+  err   = fstarpu_mpi_barrier(comm_world)
+
+  if(comm_w_rank.eq.0) then
+    write(*,'(" ")')
+    a1 = 1.0
+    write(*,*) "init_a1", a1
+    b1 = 0.5
+    write(*,*) "init b1", b1
+  end if
+  if(comm_w_rank.eq.1) then
+    write(*,'(" ")')
+    a2 = 2.0
+    write(*,*) "init_a2", a2
+    b2 = 0.8
+    write(*,*) "init b2", b2
+  end if
+
+  ! allocate and fill codelet structs
+  task_cl = fstarpu_codelet_allocate()
+  call fstarpu_codelet_set_name(task_cl, name)
+  call fstarpu_codelet_add_cpu_func(task_cl, C_FUNLOC(cl_cpu_task))
+  call fstarpu_codelet_add_buffer(task_cl, FSTARPU_REDUX)
+  call fstarpu_codelet_add_buffer(task_cl, FSTARPU_R)
+
+  ! allocate and reduction codelets
+  task_red_cl = fstarpu_codelet_allocate()
+  call fstarpu_codelet_set_name(task_red_cl, namered)
+  call fstarpu_codelet_add_cpu_func(task_red_cl,C_FUNLOC(cl_cpu_task_red))
+  call fstarpu_codelet_add_buffer(task_red_cl, FSTARPU_RW)
+  call fstarpu_codelet_add_buffer(task_red_cl, FSTARPU_R)
+
+  task_ini_cl = fstarpu_codelet_allocate()
+  call fstarpu_codelet_set_name(task_ini_cl, nameini)
+  call fstarpu_codelet_add_cpu_func(task_ini_cl,C_FUNLOC(cl_cpu_task_ini))
+  call fstarpu_codelet_add_buffer(task_ini_cl, FSTARPU_W)
+
+  err = fstarpu_mpi_barrier(comm_world)
+
+  tag = 0
+  if(comm_w_rank.eq.0) then
+        call fstarpu_variable_data_register(a1hdl, 0, c_loc(a1),c_sizeof(a1))
+        call fstarpu_variable_data_register(b1hdl, 0, c_loc(b1),c_sizeof(b1))
+  else
+        call fstarpu_variable_data_register(a1hdl, -1, c_null_ptr,c_sizeof(a1))
+        call fstarpu_variable_data_register(b1hdl, -1, c_null_ptr,c_sizeof(b1))
+  end if
+  call fstarpu_mpi_data_register(a1hdl,tag,0)
+  call fstarpu_mpi_data_register(b1hdl, tag+1,0)
+
+  tag = tag + 2
+  if(comm_w_rank.eq.1) then
+        call fstarpu_variable_data_register(a2hdl, 0, c_loc(a2),c_sizeof(a2))
+        call fstarpu_variable_data_register(b2hdl, 0, c_loc(b2),c_sizeof(b2))
+  else
+        call fstarpu_variable_data_register(a2hdl, -1, c_null_ptr,c_sizeof(a2))
+        call fstarpu_variable_data_register(b2hdl, -1, c_null_ptr,c_sizeof(b2))
+  end if
+  call fstarpu_mpi_data_register(a2hdl,tag,1)
+  call fstarpu_mpi_data_register(b2hdl, tag+1, 1)
+  tag = tag + 2
+
+  call fstarpu_data_set_reduction_methods(a1hdl, task_red_cl,task_ini_cl)
+  call fstarpu_data_set_reduction_methods(a2hdl, task_red_cl,task_ini_cl)
+
+  err = fstarpu_mpi_barrier(comm_world)
+  
+  
+  call fstarpu_fxt_start_profiling()
+
+  w_node = 3
+  comm = comm_world
+  call fstarpu_mpi_task_insert( (/ c_loc(comm),   &
+             task_cl,                                         &
+             FSTARPU_REDUX, a1hdl,                            &
+             FSTARPU_R, b1hdl,                                &
+             FSTARPU_EXECUTE_ON_NODE, c_loc(w_node),          &
+             C_NULL_PTR /))
+  w_node = 2
+  comm = comm_world
+  call fstarpu_mpi_task_insert( (/ c_loc(comm),   &
+             task_cl,                                         &
+             FSTARPU_REDUX, a2hdl,                            &
+             FSTARPU_R, b2hdl,                                &
+             FSTARPU_EXECUTE_ON_NODE, c_loc(w_node),          &
+             C_NULL_PTR /))
+  
+  call fstarpu_mpi_redux_data(comm_world, a1hdl)
+  call fstarpu_mpi_redux_data(comm_world, a2hdl)
+  ! write(*,*) "waiting all tasks ..."
+  err = fstarpu_mpi_wait_for_all(comm_world)
+  
+  if(comm_w_rank.eq.0) then
+     write(*,*) 'computed result ---> ',a1, "expected =",4.5
+  end if
+  if(comm_w_rank.eq.1) then
+     write(*,*) 'computed result ---> ',a2, "expected=",5.8
+  end if
+  call fstarpu_data_unregister(a1hdl)
+  call fstarpu_data_unregister(a2hdl)
+  call fstarpu_data_unregister(b1hdl)
+  call fstarpu_data_unregister(b2hdl)
+  
+  call fstarpu_fxt_stop_profiling()
+  call fstarpu_codelet_free(task_cl)
+  call fstarpu_codelet_free(task_red_cl)
+  call fstarpu_codelet_free(task_ini_cl)
+
+  
+  err = fstarpu_mpi_shutdown()
+  call fstarpu_shutdown()
+  
+  stop
+
+contains
+
+  recursive subroutine cl_cpu_task (buffers, cl_args) bind(C)
+    use iso_c_binding       ! C interfacing module
+    use fstarpu_mod         ! StarPU interfacing module
+    implicit none
+    
+    type(c_ptr), value, intent(in) :: buffers, cl_args ! cl_args is unused
+    integer(c_int) :: ret, worker_id
+    integer        :: comm_rank
+    integer, target :: i
+    real(kind(1.d0)), pointer :: a, b
+    real(kind(1.d0))          :: old_a
+
+    worker_id = fstarpu_worker_get_id()
+    comm_rank  = fstarpu_mpi_world_rank()
+
+    call c_f_pointer(fstarpu_variable_get_ptr(buffers, 0), a)
+    call c_f_pointer(fstarpu_variable_get_ptr(buffers, 1), b)
+    call sleep(1.d0)
+    old_a = a
+    a = 3.0 + b
+    write(*,*) "task   (c_w_rank:",comm_rank,") from ",old_a,"to",a
+    
+    return
+  end subroutine cl_cpu_task
+
+
+  recursive subroutine cl_cpu_task_red (buffers, cl_args) bind(C)
+    use iso_c_binding       ! C interfacing module
+    use fstarpu_mod         ! StarPU interfacing module
+    implicit none
+    
+    type(c_ptr), value, intent(in) :: buffers, cl_args ! cl_args is unused
+    integer(c_int) :: ret
+    integer, target                         :: comm_rank
+    real(kind(1.d0)), pointer :: as, ad
+    real(kind(1.d0))           :: old_ad
+    
+    comm_rank  = fstarpu_mpi_world_rank()
+    call c_f_pointer(fstarpu_variable_get_ptr(buffers, 0), ad)
+    call c_f_pointer(fstarpu_variable_get_ptr(buffers, 1), as)
+    old_ad = ad
+    ad = ad + as
+    call sleep(1.d0)
+    write(*,*) "red_cl (c_w_rank:",comm_rank,")",as, old_ad, ' ---> ',ad
+    
+    return
+  end subroutine cl_cpu_task_red
+
+  recursive subroutine cl_cpu_task_ini (buffers, cl_args) bind(C)
+    use iso_c_binding       ! C interfacing module
+    use fstarpu_mod         ! StarPU interfacing module
+    implicit none
+    
+    type(c_ptr), value, intent(in) :: buffers, cl_args 
+        ! cl_args is unused
+    integer(c_int) :: ret
+    integer, target                         :: comm_rank
+    real(kind(1.d0)), pointer :: a
+
+    comm_rank  = fstarpu_mpi_world_rank()
+    call c_f_pointer(fstarpu_variable_get_ptr(buffers, 0), a)
+    call sleep(0.5d0)
+    a = 0.0
+    write(*,*) "ini_cl (c_w_rank:",comm_rank,")"
+    return
+  end subroutine cl_cpu_task_ini
+
+
+  subroutine sleep(t)
+    implicit none
+    integer :: t_start, t_end, t_rate
+    real(kind(1.d0))     :: ta, t
+    call system_clock(t_start)
+    do
+       call system_clock(t_end, t_rate)
+       ta = real(t_end-t_start)/real(t_rate)
+       if(ta.gt.t) return
+    end do
+  end subroutine sleep
+
+end program main

+ 4 - 0
mpi/include/starpu_mpi.h

@@ -561,6 +561,10 @@ int starpu_mpi_data_get_rank(starpu_data_handle_t handle);
    Return the tag of the given data.
 */
 starpu_mpi_tag_t starpu_mpi_data_get_tag(starpu_data_handle_t handle);
+/**
+   Return the redux map of the given data.
+*/
+char* starpu_mpi_data_get_redux_map(starpu_data_handle_t handle);
 
 /**
    Symbol kept for backward compatibility. Call function starpu_mpi_data_get_tag()

+ 6 - 0
mpi/src/starpu_mpi.c

@@ -448,6 +448,12 @@ starpu_mpi_tag_t starpu_mpi_data_get_tag(starpu_data_handle_t data)
 	return ((struct _starpu_mpi_data *)(data->mpi_data))->node_tag.data_tag;
 }
 
+char* starpu_mpi_data_get_redux_map(starpu_data_handle_t data)
+{
+	STARPU_ASSERT_MSG(data->mpi_data, "starpu_mpi_data_register MUST be called for data %p\n", data);
+	return ((struct _starpu_mpi_data *)(data->mpi_data))->redux_map;
+}
+
 void starpu_mpi_get_data_on_node_detached(MPI_Comm comm, starpu_data_handle_t data_handle, int node, void (*callback)(void*), void *arg)
 {
 	int me, rank;

+ 10 - 0
mpi/src/starpu_mpi_private.h

@@ -203,6 +203,12 @@ struct _starpu_mpi_coop_sends
 	long pre_sync_jobid;
 };
 
+/** cf. redux_map field : this is the value
+ * put in this field whenever a node contributes
+ * to the reduction of the data.
+ * Only the owning node keeps track of all the contributing nodes. */
+#define REDUX_CONTRIB ((char*) -1)
+
 /** Initialized in starpu_mpi_data_register_comm */
 struct _starpu_mpi_data
 {
@@ -211,6 +217,10 @@ struct _starpu_mpi_data
 	char *cache_sent;
 	int cache_received;
 
+	/** Array used to store the contributing nodes to this data
+	  * when it is accessed in REDUX mode. */
+	char* redux_map;
+
 	/** Rendez-vous data for opportunistic cooperative sends,
 	  * Needed to synchronize between submit thread and workers */
 	struct _starpu_spinlock coop_lock;

+ 46 - 13
mpi/src/starpu_mpi_task_insert.c

@@ -617,6 +617,18 @@ int _starpu_mpi_task_postbuild_v(MPI_Comm comm, int xrank, int do_execute, struc
 
 	for(i=0 ; i<nb_data ; i++)
 	{
+		struct _starpu_mpi_data *mpi_data = (struct _starpu_mpi_data *) descrs[i].handle->mpi_data;
+		if (descrs[i].mode & STARPU_REDUX || descrs[i].mode & STARPU_MPI_REDUX)
+		{
+			if (me == starpu_mpi_data_get_rank(descrs[i].handle))
+			{ 
+				if (mpi_data->redux_map == NULL)
+					_STARPU_CALLOC(mpi_data->redux_map, 0, STARPU_MAXNODES * sizeof(mpi_data->redux_map[0]));
+				mpi_data->redux_map [xrank] = 1;
+			}
+			else if (me == xrank)
+				mpi_data->redux_map = REDUX_CONTRIB;
+		} 
 		_starpu_mpi_exchange_data_after_execution(descrs[i].handle, descrs[i].mode, me, xrank, do_execute, prio, comm);
 		_starpu_mpi_clear_data_after_execution(descrs[i].handle, descrs[i].mode, me, do_execute);
 	}
@@ -813,6 +825,11 @@ void _starpu_mpi_redux_fill_post_sync_jobid(const void * const redux_data_args,
 
 /* TODO: this should rather be implicitly called by starpu_mpi_task_insert when
  * a data previously accessed in REDUX mode gets accessed in R mode. */
+/* FIXME: In order to prevent simultaneous receive submissions
+ * on the same handle, we need to wait that all the starpu_mpi
+ * tasks are done before submitting next tasks. The current
+ * version of the implementation does not support multiple
+ * simultaneous receive requests on the same handle.*/
 void starpu_mpi_redux_data_prio(MPI_Comm comm, starpu_data_handle_t data_handle, int prio)
 {
 	int me, rank, nb_nodes;
@@ -820,6 +837,7 @@ void starpu_mpi_redux_data_prio(MPI_Comm comm, starpu_data_handle_t data_handle,
 
 	rank = starpu_mpi_data_get_rank(data_handle);
 	data_tag = starpu_mpi_data_get_tag(data_handle);
+	struct _starpu_mpi_data *mpi_data = data_handle->mpi_data;
 	if (rank == -1)
 	{
 		_STARPU_ERROR("StarPU needs to be told the MPI rank of this data, using starpu_mpi_data_register\n");
@@ -832,12 +850,16 @@ void starpu_mpi_redux_data_prio(MPI_Comm comm, starpu_data_handle_t data_handle,
 	starpu_mpi_comm_rank(comm, &me);
 	starpu_mpi_comm_size(comm, &nb_nodes);
 
-	_STARPU_MPI_DEBUG(1, "Doing reduction for data %p on node %d with %d nodes ...\n", data_handle, rank, nb_nodes);
-
+	_STARPU_MPI_DEBUG(50, "Doing reduction for data %p on node %d with %d nodes ...\n", data_handle, rank, nb_nodes);
 	// need to count how many nodes have the data in redux mode
 	if (me == rank)
 	{
-		int i;
+		int i,j;
+		_STARPU_MPI_DEBUG(50, "Who is in the map ?\n");
+		for (j = 0; j<nb_nodes; j++)
+		{
+			_STARPU_MPI_DEBUG(50, "%d is in the map ? %d\n", j, mpi_data->redux_map[j]);
+		}
 
 		// taskC depends on all taskBs created
 		// Creating synchronization task and use its jobid for tracing
@@ -848,8 +870,9 @@ void starpu_mpi_redux_data_prio(MPI_Comm comm, starpu_data_handle_t data_handle,
 
 		for(i=0 ; i<nb_nodes ; i++)
 		{
-			if (i != rank)
+			if (i != rank && mpi_data->redux_map[i])
 			{
+				_STARPU_MPI_DEBUG(5, "%d takes part in the reduction of %p \n", i, data_handle);
 				/* We need to make sure all is
 				 * executed after data_handle finished
 				 * its last read access, we hence do
@@ -893,24 +916,34 @@ void starpu_mpi_redux_data_prio(MPI_Comm comm, starpu_data_handle_t data_handle,
 						   STARPU_CALLBACK_WITH_ARG_NFREE, _starpu_mpi_redux_data_recv_callback, args,
 						   0);
 			}
+			else
+			{
+				_STARPU_MPI_DEBUG(5, "%d is not in the map or is me\n", i);
+			}
 		}
 
 		int ret = starpu_task_submit(taskC);
 		STARPU_ASSERT(ret == 0);
 	}
-	else
+	else if (mpi_data->redux_map)
 	{
-		_STARPU_MPI_DEBUG(1, "Sending redux handle to %d ...\n", rank);
+		STARPU_ASSERT(mpi_data->redux_map == REDUX_CONTRIB);
+		_STARPU_MPI_DEBUG(5, "Sending redux handle to %d ...\n", rank);
 		starpu_mpi_isend_detached_prio(data_handle, rank, data_tag, prio, comm, NULL, NULL);
 		starpu_data_invalidate_submit(data_handle);
 	}
-	/* FIXME: In order to prevent simultaneous receive submissions
-	 * on the same handle, we need to wait that all the starpu_mpi
-	 * tasks are done before submitting next tasks. The current
-	 * version of the implementation does not support multiple
-	 * simultaneous receive requests on the same handle.*/
-	starpu_task_wait_for_all();
-
+	else
+	{
+		_STARPU_MPI_DEBUG(5, "I am not in the map of %d, I am %d ...\n", rank, me);
+	}
+	if (mpi_data->redux_map != NULL) 
+	{ 
+		_STARPU_MPI_DEBUG(100, "waiting for redux tasks with %d\n", rank);
+		starpu_task_wait_for_all();
+	}
+	if (me == rank)
+		free(mpi_data->redux_map);
+	mpi_data->redux_map = NULL;
 }
 void starpu_mpi_redux_data(MPI_Comm comm, starpu_data_handle_t data_handle)
 {