浏览代码

dm: Add missing push_task_notify

Samuel Thibault 4 年之前
父节点
当前提交
08428a46b1
共有 1 个文件被更改,包括 39 次插入25 次删除
  1. 39 25
      src/sched_policies/deque_modeling_policy_data_aware.c

+ 39 - 25
src/sched_policies/deque_modeling_policy_data_aware.c

@@ -1092,7 +1092,7 @@ static void dmda_pre_exec_hook(struct starpu_task *task, unsigned sched_ctx_id)
 	starpu_worker_unlock_self();
 }
 
-static void dmda_push_task_notify(struct starpu_task *task, int workerid, int perf_workerid, unsigned sched_ctx_id)
+static void _dm_push_task_notify(struct starpu_task *task, int workerid, int perf_workerid, unsigned sched_ctx_id, int da)
 {
 	struct _starpu_dmda_data *dt = (struct _starpu_dmda_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	struct _starpu_fifo_taskq *fifo = &dt->queue_array[workerid];
@@ -1101,7 +1101,6 @@ static void dmda_push_task_notify(struct starpu_task *task, int workerid, int pe
 	double predicted = starpu_task_worker_expected_length(task, perf_workerid, STARPU_NMAX_SCHED_CTXS,
 						       starpu_task_get_implementation(task));
 
-	double predicted_transfer = starpu_task_expected_data_transfer_time_for(task, workerid);
 	double now = starpu_timing_now();
 
 	/* Update the predictions */
@@ -1110,32 +1109,36 @@ static void dmda_push_task_notify(struct starpu_task *task, int workerid, int pe
 	fifo->exp_start = isnan(fifo->exp_start) ? now + fifo->pipeline_len : STARPU_MAX(fifo->exp_start, now);
 	fifo->exp_end = fifo->exp_start + fifo->exp_len;
 
-	/* If there is no prediction available, we consider the task has a null length */
-	if (!isnan(predicted_transfer))
+	if (da)
 	{
-		if (now + predicted_transfer < fifo->exp_end)
-		{
-			/* We may hope that the transfer will be finished by
-			 * the start of the task. */
-			predicted_transfer = 0;
-		}
-		else
-		{
-			/* The transfer will not be finished by then, take the
-			 * remainder into account */
-			predicted_transfer = (now + predicted_transfer) - fifo->exp_end;
-		}
-		task->predicted_transfer = predicted_transfer;
-		fifo->exp_end += predicted_transfer;
-		fifo->exp_len += predicted_transfer;
-		if(dt->num_priorities != -1)
+		double predicted_transfer = starpu_task_expected_data_transfer_time_for(task, workerid);
+		/* If there is no prediction available, we consider the task has a null length */
+		if (!isnan(predicted_transfer))
 		{
-			int i;
-			int task_prio = _starpu_normalize_prio(task->priority, dt->num_priorities, task->sched_ctx);
-			for(i = 0; i <= task_prio; i++)
-				fifo->exp_len_per_priority[i] += predicted_transfer;
-		}
+			if (now + predicted_transfer < fifo->exp_end)
+			{
+				/* We may hope that the transfer will be finished by
+				 * the start of the task. */
+				predicted_transfer = 0;
+			}
+			else
+			{
+				/* The transfer will not be finished by then, take the
+				 * remainder into account */
+				predicted_transfer = (now + predicted_transfer) - fifo->exp_end;
+			}
+			task->predicted_transfer = predicted_transfer;
+			fifo->exp_end += predicted_transfer;
+			fifo->exp_len += predicted_transfer;
+			if(dt->num_priorities != -1)
+			{
+				int i;
+				int task_prio = _starpu_normalize_prio(task->priority, dt->num_priorities, task->sched_ctx);
+				for(i = 0; i <= task_prio; i++)
+					fifo->exp_len_per_priority[i] += predicted_transfer;
+			}
 
+		}
 	}
 
 	/* If there is no prediction available, we consider the task has a null length */
@@ -1166,6 +1169,16 @@ static void dmda_push_task_notify(struct starpu_task *task, int workerid, int pe
 	starpu_worker_unlock(workerid);
 }
 
+static void dm_push_task_notify(struct starpu_task *task, int workerid, int perf_workerid, unsigned sched_ctx_id)
+{
+	_dm_push_task_notify(task, workerid, perf_workerid, sched_ctx_id, 0);
+}
+
+static void dmda_push_task_notify(struct starpu_task *task, int workerid, int perf_workerid, unsigned sched_ctx_id)
+{
+	_dm_push_task_notify(task, workerid, perf_workerid, sched_ctx_id, 1);
+}
+
 static void dmda_post_exec_hook(struct starpu_task * task, unsigned sched_ctx_id)
 {
 	struct _starpu_dmda_data *dt = (struct _starpu_dmda_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
@@ -1184,6 +1197,7 @@ struct starpu_sched_policy _starpu_sched_dm_policy =
 	.remove_workers = dmda_remove_workers,
 	.push_task = dm_push_task,
 	.simulate_push_task = NULL,
+	.push_task_notify = dm_push_task_notify,
 	.pop_task = dmda_pop_task,
 	.pre_exec_hook = dmda_pre_exec_hook,
 	.post_exec_hook = dmda_post_exec_hook,