Browse Source

port r19150 from 1.2: make ws and lws correctly call can_execute on push, to avoid pushing a task to a worker which cannot process it

Samuel Thibault authored 8 years ago · parent · commit def2eaea6e
1 changed file with 24 additions and 18 deletions:
  1. src/sched_policies/work_stealing_policy.c (+24, −18)

+ 24 - 18
src/sched_policies/work_stealing_policy.c

@@ -157,18 +157,21 @@ static int select_victim_round_robin(struct _starpu_work_stealing_data *ws, unsi
  * Return a worker to whom add a task.
  * Selecting a worker is done in a round-robin fashion.
  */
-static unsigned select_worker_round_robin(struct _starpu_work_stealing_data *ws, unsigned sched_ctx_id)
+static unsigned select_worker_round_robin(struct _starpu_work_stealing_data *ws, struct starpu_task *task, unsigned sched_ctx_id)
 {
-	unsigned worker = ws->last_push_worker;
+	unsigned worker;
 	unsigned nworkers;
 	int *workerids;
 	nworkers = starpu_sched_ctx_get_workers_list_raw(sched_ctx_id, &workerids);
 
-	ws->last_push_worker = (ws->last_push_worker + 1) % nworkers;
+	worker = ws->last_push_worker;
+	do
+		worker = (worker + 1) % nworkers;
+	while (!starpu_worker_can_execute_task_first_impl(workerids[worker], task, NULL));
 
-	worker = workerids[worker];
+	ws->last_push_worker = worker;
 
-	return worker;
+	return workerids[worker];
 }
 
 #ifdef USE_LOCALITY
@@ -206,7 +209,7 @@ static unsigned select_worker_locality(struct _starpu_work_stealing_data *ws, st
 		while(workers->has_next(workers, &it))
 		{
 			int workerid = workers->get_next(workers, &it);
-			if (ndata[workerid] > best_ndata && ws->per_worker[worker].busy)
+			if (ndata[workerid] > best_ndata && ws->per_worker[workerid].busy)
 			{
 				best_worker = workerid;
 				best_ndata = ndata[workerid];
@@ -385,10 +388,10 @@ static float overload_metric(struct _starpu_work_stealing_data *ws, unsigned sch
 	unsigned njobs = _starpu_get_deque_njobs(ws->per_worker[id].queue_array);
 
 	/* Did we get enough information ? */
-	if (performed_total > 0 && nprocessed > 0)
+	if (ws->performed_total > 0 && nprocessed > 0)
 	{
 		/* How fast or slow is the worker compared to the other workers */
-		execution_ratio = (float) nprocessed / performed_total;
+		execution_ratio = (float) nprocessed / ws->performed_total;
 		/* How replete is its queue */
 		current_ratio = (float) njobs / nprocessed;
 	}
@@ -416,7 +419,7 @@ static int select_victim_overload(struct _starpu_work_stealing_data *ws, unsigne
 
 	/* Don't try to play smart until we get
 	 * enough informations. */
-	if (performed_total < calibration_value)
+	if (ws->performed_total < calibration_value)
 		return select_victim_round_robin(ws, sched_ctx_id);
 
 	struct starpu_worker_collection *workers = starpu_sched_ctx_get_worker_collection(sched_ctx_id);
@@ -446,8 +449,9 @@ static int select_victim_overload(struct _starpu_work_stealing_data *ws, unsigne
  * by the tasks are taken into account to select the most suitable
  * worker to add a task to.
  */
-static unsigned select_worker_overload(struct _starpu_work_stealing_data *ws, unsigned sched_ctx_id)
+static unsigned select_worker_overload(struct _starpu_work_stealing_data *ws, struct starpu_task *task, unsigned sched_ctx_id)
 {
+	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	unsigned worker;
 	float  worker_ratio;
 	unsigned best_worker = 0;
@@ -455,8 +459,8 @@ static unsigned select_worker_overload(struct _starpu_work_stealing_data *ws, un
 
 	/* Don't try to play smart until we get
 	 * enough informations. */
-	if (performed_total < calibration_value)
-		return select_worker_round_robin(sched_ctx_id);
+	if (ws->performed_total < calibration_value)
+		return select_worker_round_robin(ws, task, sched_ctx_id);
 
 	struct starpu_worker_collection *workers = starpu_sched_ctx_get_worker_collection(sched_ctx_id);
 
@@ -469,7 +473,8 @@ static unsigned select_worker_overload(struct _starpu_work_stealing_data *ws, un
 
 		worker_ratio = overload_metric(ws, sched_ctx_id, worker);
 
-		if (worker_ratio < best_ratio)
+		if (worker_ratio < best_ratio &&
+			starpu_worker_can_execute_task_first_impl(worker, task, NULL))
 		{
 			best_worker = worker;
 			best_ratio = worker_ratio;
@@ -502,12 +507,12 @@ static inline int select_victim(struct _starpu_work_stealing_data *ws, unsigned
  * This is a phony function used to call the right
  * function depending on the value of USE_OVERLOAD.
  */
-static inline unsigned select_worker(struct _starpu_work_stealing_data *ws, unsigned sched_ctx_id)
+static inline unsigned select_worker(struct _starpu_work_stealing_data *ws, struct starpu_task *task, unsigned sched_ctx_id)
 {
 #ifdef USE_OVERLOAD
-	return select_worker_overload(ws, sched_ctx_id);
+	return select_worker_overload(ws, task, sched_ctx_id);
 #else
-	return select_worker_round_robin(ws, sched_ctx_id);
+	return select_worker_round_robin(ws, task, sched_ctx_id);
 #endif /* USE_OVERLOAD */
 }
 
@@ -599,8 +604,9 @@ int ws_push_task(struct starpu_task *task)
 	/* If the current thread is not a worker but
 	 * the main thread (-1) or the current worker is not in the target
 	 * context, we find the better one to put task on its queue */
-	if (workerid == -1 || !starpu_sched_ctx_contains_worker(workerid, sched_ctx_id))
-		workerid = select_worker(ws, sched_ctx_id);
+	if (workerid == -1 || !starpu_sched_ctx_contains_worker(workerid, sched_ctx_id) ||
+			!starpu_worker_can_execute_task_first_impl(workerid, task, NULL))
+		workerid = select_worker(ws, task, sched_ctx_id);
 
 	starpu_pthread_mutex_t *sched_mutex;
 	starpu_pthread_cond_t *sched_cond;