浏览代码

handle termination properly

Cédric Augonnet 15 年之前
父节点
当前提交
eee6350711
共有 2 个文件被更改,包括 49 次插入6 次删除
  1. 2 2
      mpi/examples/mpi_lu/plu_example.c
  2. 47 4
      mpi/examples/mpi_lu/pxlu.c

+ 2 - 2
mpi/examples/mpi_lu/plu_example.c

@@ -239,8 +239,8 @@ int main(int argc, char **argv)
 
 	barrier_ret = MPI_Barrier(MPI_COMM_WORLD);
 	STARPU_ASSERT(barrier_ret == MPI_SUCCESS);
-	fprintf(stderr, "Rank %d PID %d\n", rank, getpid());
-	sleep(10);
+//	fprintf(stderr, "Rank %d PID %d\n", rank, getpid());
+//	sleep(10);
 
 	STARPU_PLU(plu_main)(nblocks, rank, world_size);
 

+ 47 - 4
mpi/examples/mpi_lu/pxlu.c

@@ -611,7 +611,52 @@ static void create_task_22(unsigned k, unsigned i, unsigned j)
 	}
 }
 
+static void wait_tag_and_fetch_handle(starpu_tag_t tag, starpu_data_handle handle)
+{	/* Wait for `tag`, then sync `handle` with memory in STARPU_R mode; progress is traced on stderr. */
+	/* NOTE(review): %lx assumes starpu_tag_t matches unsigned long -- confirm against its typedef. */
+	fprintf(stderr, "Rank %d : waiting tag %lx\n", rank, tag);
+	starpu_tag_wait(tag);
+	fprintf(stderr, "Rank %d : tag %lx is done\n", rank, tag);
+	/* NOTE(review): presumably makes the block's final contents readable from main memory -- confirm. */
+	starpu_sync_data_with_mem(handle, STARPU_R);
+
+//	starpu_delete_data(handle);
+}
+
+static void wait_termination(void)
+{	/* For each step k, wait on the *_SAVE tag of every block for which get_block_rank() equals `rank`. */
+	unsigned k, i, j;
+	for (k = 0; k < nblocks; k++)
+	{
+		/* Wait task 11k if needed */
+		if (get_block_rank(k, k) == rank)
+		{
+			starpu_data_handle diag_block = STARPU_PLU(get_block_handle)(k, k);
+			wait_tag_and_fetch_handle(TAG11_SAVE(k), diag_block);
+		}
+		/* NOTE(review): below, rank is tested with (i, k) but the handle is fetched with (k, i) -- confirm helper index order. */
 
+		for (i = k + 1; i < nblocks; i++)
+		{
+			/* Wait task 21ki if needed */
+			if (get_block_rank(i, k) == rank)
+			{
+				starpu_data_handle block21 = STARPU_PLU(get_block_handle)(k, i);
+				wait_tag_and_fetch_handle(TAG21_SAVE(k, i), block21);
+			}
+		}
+		/* NOTE(review): same pattern here -- rank tested with (k, j), handle fetched with (j, k); confirm. */
+		for (j = k + 1; j < nblocks; j++)
+		{
+			/* Wait task 12kj if needed */
+			if (get_block_rank(k, j) == rank)
+			{
+				starpu_data_handle block12 = STARPU_PLU(get_block_handle)(j, k);
+				wait_tag_and_fetch_handle(TAG12_SAVE(k, j), block12);
+			}
+		}
+	}	
+}
 
 /*
  *	code to bootstrap the factorization 
@@ -656,10 +701,8 @@ void STARPU_PLU(plu_main)(unsigned _nblocks, int _rank, int _world_size)
 	fprintf(stderr, "Rank %d GO\n", rank);
 	starpu_tag_notify_from_apps(STARPU_TAG_INIT);
 
-	/* stall the application until the end of computations : note that it
-	 * may be liberated explicitely by MPI */
-	starpu_tag_wait(TAG11(nblocks-1));
-
+	wait_termination();
+	
 	gettimeofday(&end, NULL);
 
 	double timing = (double)((end.tv_sec - start.tv_sec)*1000000 + (end.tv_usec - start.tv_usec));