Forráskód Böngészése

Allow contexts without a scheduler (parallel workers) to be tracked by their parent's scheduler (dmda)

Terry Cojean 10 éve
szülő
commit
e07df2c78f

+ 2 - 0
include/starpu_sched_ctx.h

@@ -73,6 +73,8 @@ unsigned starpu_sched_ctx_contains_type_of_worker(enum starpu_worker_archtype ar
 
 unsigned starpu_sched_ctx_worker_get_id(unsigned sched_ctx_id);
 
+unsigned starpu_sched_ctx_get_ctx_for_task(struct starpu_task *task);
+
 unsigned starpu_sched_ctx_overlapping_ctxs_on_worker(int workerid);
 
 int starpu_sched_get_min_priority(void);

+ 10 - 0
src/core/sched_ctx.c

@@ -1675,6 +1675,16 @@ unsigned starpu_sched_ctx_worker_get_id(unsigned sched_ctx_id)
 	return -1;
 }
 
+unsigned starpu_sched_ctx_get_ctx_for_task(struct starpu_task *task)
+{
+	struct _starpu_sched_ctx *sched_ctx = _starpu_get_sched_ctx_struct(task->sched_ctx);
+	unsigned ret_sched_ctx = task->sched_ctx;
+	if (task->possibly_parallel && !sched_ctx->sched_policy
+	    && sched_ctx->nesting_sched_ctx != STARPU_NMAX_SCHED_CTXS)
+		 ret_sched_ctx = sched_ctx->nesting_sched_ctx;
+	return ret_sched_ctx;
+}
+
 unsigned starpu_sched_ctx_overlapping_ctxs_on_worker(int workerid)
 {
 	struct _starpu_worker *worker = _starpu_get_worker_struct(workerid);

+ 4 - 3
src/core/sched_policy.c

@@ -933,7 +933,8 @@ struct starpu_task *_starpu_pop_every_task(struct _starpu_sched_ctx *sched_ctx)
 
 void _starpu_sched_pre_exec_hook(struct starpu_task *task)
 {
-	struct _starpu_sched_ctx *sched_ctx = _starpu_get_sched_ctx_struct(task->sched_ctx);
+	unsigned sched_ctx_id = starpu_sched_ctx_get_ctx_for_task(task);
+	struct _starpu_sched_ctx *sched_ctx = _starpu_get_sched_ctx_struct(sched_ctx_id);
 	if (sched_ctx->sched_policy && sched_ctx->sched_policy->pre_exec_hook)
 	{
 		_STARPU_TRACE_WORKER_SCHEDULING_PUSH;
@@ -944,8 +945,8 @@ void _starpu_sched_pre_exec_hook(struct starpu_task *task)
 
 void _starpu_sched_post_exec_hook(struct starpu_task *task)
 {
-	struct _starpu_sched_ctx *sched_ctx = _starpu_get_sched_ctx_struct(task->sched_ctx);
-
+	unsigned sched_ctx_id = starpu_sched_ctx_get_ctx_for_task(task);
+	struct _starpu_sched_ctx *sched_ctx = _starpu_get_sched_ctx_struct(sched_ctx_id);
 	if (sched_ctx->sched_policy && sched_ctx->sched_policy->post_exec_hook)
 	{
 		_STARPU_TRACE_WORKER_SCHEDULING_PUSH;

+ 7 - 3
src/drivers/driver_common/driver_common.c

@@ -49,7 +49,7 @@ void _starpu_driver_start_job(struct _starpu_worker *worker, struct _starpu_job
 	/* If the job is executed on a combined worker there is no need for the
 	 * scheduler to process it : it doesn't contain any valuable data
 	 * as it's not linked to an actual worker */
-	if (j->task_size == 1)
+	if (j->task_size == 1 && rank == 0)
 		_starpu_sched_pre_exec_hook(task);
 
 	_starpu_set_worker_status(worker, STATUS_EXECUTING);
@@ -85,6 +85,7 @@ void _starpu_driver_start_job(struct _starpu_worker *worker, struct _starpu_job
 		{
 			struct starpu_worker_collection *workers = sched_ctx->workers;
 			struct starpu_sched_ctx_iterator it;
+			int new_rank = 0;
 
 			if (workers->init_iterator)
 				workers->init_iterator(workers, &it);
@@ -93,8 +94,9 @@ void _starpu_driver_start_job(struct _starpu_worker *worker, struct _starpu_job
 				int _workerid = workers->get_next(workers, &it);
 				if (_workerid != workerid)
 				{
+					new_rank++;
 					struct _starpu_worker *_worker = _starpu_get_worker_struct(_workerid);
-					_starpu_driver_start_job(_worker, j, &_worker->perf_arch, codelet_start, rank, profiling);
+					_starpu_driver_start_job(_worker, j, &_worker->perf_arch, codelet_start, new_rank, profiling);
 				}
 			}
 		}
@@ -151,6 +153,7 @@ void _starpu_driver_end_job(struct _starpu_worker *worker, struct _starpu_job *j
 	{
 		struct starpu_worker_collection *workers = sched_ctx->workers;
 		struct starpu_sched_ctx_iterator it;
+		int new_rank = 0;
 
 		if (workers->init_iterator)
 			workers->init_iterator(workers, &it);
@@ -159,8 +162,9 @@ void _starpu_driver_end_job(struct _starpu_worker *worker, struct _starpu_job *j
 			int _workerid = workers->get_next(workers, &it);
 			if (_workerid != workerid)
 			{
+				new_rank++;
 				struct _starpu_worker *_worker = _starpu_get_worker_struct(_workerid);
-				_starpu_driver_end_job(_worker, j, &_worker->perf_arch, codelet_end, rank, profiling);
+				_starpu_driver_end_job(_worker, j, &_worker->perf_arch, codelet_end, new_rank, profiling);
 			}
 		}
 	}

+ 3 - 3
src/sched_policies/deque_modeling_policy_data_aware.c

@@ -1062,7 +1062,7 @@ static void deinitialize_dmda_policy(unsigned sched_ctx_id)
  * value of the expected start, end, length, etc... */
 static void dmda_pre_exec_hook(struct starpu_task *task)
 {
-	unsigned sched_ctx_id = task->sched_ctx;
+	unsigned sched_ctx_id = starpu_sched_ctx_get_ctx_for_task(task);
 	int workerid = starpu_worker_get_id();
 	struct _starpu_dmda_data *dt = (struct _starpu_dmda_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	struct _starpu_fifo_taskq *fifo = dt->queue_array[workerid];
@@ -1174,8 +1174,8 @@ static void dmda_push_task_notify(struct starpu_task *task, int workerid, int pe
 
 static void dmda_post_exec_hook(struct starpu_task * task)
 {
-
-	struct _starpu_dmda_data *dt = (struct _starpu_dmda_data*)starpu_sched_ctx_get_policy_data(task->sched_ctx);
+	unsigned sched_ctx_id = starpu_sched_ctx_get_ctx_for_task(task);
+	struct _starpu_dmda_data *dt = (struct _starpu_dmda_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	int workerid = starpu_worker_get_id();
 	struct _starpu_fifo_taskq *fifo = dt->queue_array[workerid];
 	starpu_pthread_mutex_t *sched_mutex;