Pārlūkot izejas kodu

mpi/burst_gemm: add case where workers alternate between tasks and polling

Philippe SWARTVAGHER 5 gadi atpakaļ
vecāks
revīzija
1c5693560b
3 mainīti faili ar 75 papildinājumiem un 0 dzēšanām
  1. 25 0
      mpi/tests/burst_gemm.c
  2. 48 0
      mpi/tests/gemm_helper.c
  3. 2 0
      mpi/tests/gemm_helper.h

+ 25 - 0
mpi/tests/burst_gemm.c

@@ -146,19 +146,25 @@ int main(int argc, char **argv)
 	FPRINTF(stderr, "** Burst warmup **\n");
 	burst_all(mpi_rank);
 
+
 	starpu_sleep(0.3); // sleep to easily distinguish different bursts in traces
 
+
 	FPRINTF(stderr, "** Burst while there is no task available, but workers are polling **\n");
 	burst_all(mpi_rank);
 
+
 	starpu_sleep(0.3); // sleep to easily distinguish different bursts in traces
 
+
 	FPRINTF(stderr, "** Burst while there is no task available, workers are paused **\n");
 	starpu_pause();
 	burst_all(mpi_rank);
 
+
 	starpu_sleep(0.3); // sleep to easily distinguish different bursts in traces
 
+
 	FPRINTF(stderr, "** Burst while workers are really working **\n");
 	if(gemm_submit_tasks() == -ENODEV)
 		goto enodev;
@@ -172,8 +178,10 @@ int main(int argc, char **argv)
 	starpu_task_wait_for_all();
 	starpu_mpi_barrier(MPI_COMM_WORLD);
 
+
 	starpu_sleep(0.3); // sleep to easily distinguish different parts in traces
 
+
 	FPRINTF(stderr, "** Workers are computing, without communications **\n");
 	starpu_pause();
 	if(gemm_submit_tasks() == -ENODEV)
@@ -184,6 +192,23 @@ int main(int argc, char **argv)
 	starpu_task_wait_for_all();
 	starpu_mpi_barrier(MPI_COMM_WORLD);
 
+
+	starpu_sleep(0.3); // sleep to easily distinguish different parts in traces
+
+
+	FPRINTF(stderr, "** Burst while workers are computing, but polling a moment between each task **\n");
+	starpu_pause();
+	gemm_add_polling_dependencies();
+	if(gemm_submit_tasks_with_tags(/* enable task tags */ 1) == -ENODEV)
+		goto enodev;
+	starpu_resume();
+
+	burst_all(mpi_rank);
+
+	/* Wait for everything and everybody: */
+	starpu_task_wait_for_all();
+	starpu_mpi_barrier(MPI_COMM_WORLD);
+
 enodev:
 	gemm_release();
 	burst_free_data(mpi_rank);

+ 48 - 0
mpi/tests/gemm_helper.c

@@ -240,8 +240,14 @@ int gemm_init_data()
 /* Submit tasks to compute the GEMM */
 int gemm_submit_tasks()
 {
+	return gemm_submit_tasks_with_tags(/* by default, disable task tags */ 0);
+}
+
+int gemm_submit_tasks_with_tags(int with_tags)
+{
 	int ret;
 	unsigned x, y;
+	starpu_tag_t task_tag = 0;
 
 	for (x = 0; x < nslices; x++)
 	for (y = 0; y < nslices; y++)
@@ -253,6 +259,12 @@ int gemm_submit_tasks()
 		task->handles[2] = starpu_data_get_sub_data(C_handle, 2, x, y);
 		task->flops = 2ULL * (matrix_dim/nslices) * (matrix_dim/nslices) * matrix_dim;
 
+		if (with_tags)
+		{
+			task->use_tag = 1;
+			task->tag_id = ++task_tag;
+		}
+
 		ret = starpu_task_submit(task);
 		CHECK_TASK_SUBMIT(ret);
 		starpu_data_wont_use(starpu_data_get_sub_data(C_handle, 2, x, y));
@@ -261,6 +273,42 @@ int gemm_submit_tasks()
 	return 0;
 }
 
+/* Add dependencies between GEMM tasks to see the impact of polling workers which will at the end get a task.
+ * The new dependency graph has the following shape:
+ * - the same number of GEMMs as the number of workers are executed in parallel on all workers ("a column of tasks")
+ * - then a GEMM waits all tasks of the previous column of tasks, and is executed on a worker
+ * - the next column of tasks waits for the previous GEMM
+ * - and so on...
+ *
+ * worker 0 |  1  |  4  |  5  |  8  |  9  |
+ * worker 1 |  2  |     |  6  |     | 10  |  ...
+ * worker 2 |  3  |     |  7  |     | 11  |
+ *
+ * This function has to be called before gemm_submit_tasks_with_tags(1).
+ */
+void gemm_add_polling_dependencies()
+{
+	/* Count tasks in tag-width type: tags are starpu_tag_t (unsigned),
+	 * so using int here would mix signed/unsigned in the loop bounds
+	 * below and could overflow for large nslices. */
+	starpu_tag_t nb_tasks = (starpu_tag_t) nslices * nslices;
+	unsigned nb_workers = starpu_worker_get_count();
+
+	for (starpu_tag_t synchro_tag = nb_workers+1; synchro_tag <= nb_tasks; synchro_tag += (nb_workers+1))
+	{
+		// this synchro tag depends on tasks of previous column of tasks:
+		for (starpu_tag_t previous_tag = synchro_tag - nb_workers; previous_tag < synchro_tag; previous_tag++)
+		{
+			starpu_tag_declare_deps(synchro_tag, 1, previous_tag);
+		}
+
+		// tasks of the next column of tasks depend on this synchro tag:
+		// this actually allows workers to poll for new tasks, while no task is available
+		for (starpu_tag_t next_tag = synchro_tag+1; next_tag < (synchro_tag + nb_workers + 1) && next_tag <= nb_tasks; next_tag++)
+		{
+			starpu_tag_declare_deps(next_tag, 1, synchro_tag);
+		}
+	}
+}
+
 void gemm_release()
 {
 	starpu_data_unpartition(C_handle, STARPU_MAIN_RAM);

+ 2 - 0
mpi/tests/gemm_helper.h

@@ -29,5 +29,7 @@ void gemm_alloc_data();
 int gemm_init_data();
 int gemm_submit_tasks();
 void gemm_release();
+void gemm_add_polling_dependencies();
+int gemm_submit_tasks_with_tags(int with_tags);
 
 #endif /* __MPI_TESTS_GEMM_HELPER__ */