
Optimize priority support of ws/lws by using an rbtree

Samuel Thibault, 8 years ago · commit 13ac05aa06
1 changed file with 32 additions and 25 deletions

src/sched_policies/work_stealing_policy.c  (+32 −25)
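
Background: ws/lws previously kept a plain FIFO per worker, so a task's priority was ignored once it landed in a worker's queue. This commit replaces each per-worker FIFO with a _starpu_prio_deque, whose rbtree-backed storage keeps tasks ordered by priority, making push and pop-highest O(log n). The sketch below only illustrates the ordering contract the scheduler relies on; it is not StarPU's prio_deque (that one uses an rbtree), and its names are invented for the example, with a sorted singly linked list standing in for the tree.

#include <stdio.h>
#include <stdlib.h>

/* Sketch only: higher priority values pop first, tasks of equal
 * priority keep FIFO order. */
struct task
{
	int priority;
	struct task *next;
};

struct prio_deque
{
	struct task *head;
	unsigned ntasks;
};

/* Insert in decreasing priority order (an rbtree insert in StarPU) */
static void prio_deque_push_task(struct prio_deque *q, struct task *t)
{
	struct task **p = &q->head;
	while (*p && (*p)->priority >= t->priority)
		p = &(*p)->next;
	t->next = *p;
	*p = t;
	q->ntasks++;
}

/* Pop the highest-priority task, as ws_pick_task() needs */
static struct task *prio_deque_pop_task(struct prio_deque *q)
{
	struct task *t = q->head;
	if (t)
	{
		q->head = t->next;
		q->ntasks--;
	}
	return t;
}

int main(void)
{
	struct prio_deque q = { NULL, 0 };
	int prios[] = { 0, 5, -3, 2 };
	unsigned i;
	for (i = 0; i < sizeof(prios)/sizeof(prios[0]); i++)
	{
		struct task *t = malloc(sizeof(*t));
		t->priority = prios[i];
		prio_deque_push_task(&q, t);
	}
	struct task *t;
	while ((t = prio_deque_pop_task(&q)))
	{
		printf("%d ", t->priority); /* prints: 5 2 0 -3 */
		free(t);
	}
	printf("\n");
	return 0;
}

Among equal priorities the sketch preserves FIFO order, which matches the previous behaviour for applications that never set task->priority.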

@@ -21,7 +21,7 @@
 #include <float.h>
 
 #include <core/workers.h>
-#include <sched_policies/fifo_queues.h>
+#include <sched_policies/prio_deque.h>
 #include <core/debug.h>
 #include <starpu_scheduler.h>
 #include <core/sched_policy.h>
@@ -72,12 +72,13 @@ struct locality_entry
 
 struct _starpu_work_stealing_data_per_worker
 {
-	struct _starpu_fifo_taskq *queue_array;
+	struct _starpu_prio_deque queue;
+	int running;
 	int *proxlist;
 	int busy;
 
 #ifdef USE_LOCALITY_TASKS
-	/* This records the same as queue_array, but hashed by data accessed with locality flag.  */
+	/* This records the same as queue, but hashed by data accessed with locality flag.  */
 	/* FIXME: we record only one task per data, assuming that the access is
 	 * RW, and thus only one task is ready to write to it. Do we really need to handle the R case too? */
 	struct locality_entry *queued_tasks_per_data;
@@ -131,7 +132,7 @@ static int select_victim_round_robin(struct _starpu_work_stealing_data *ws, unsi
 		/* Here helgrind would shout that this is unprotected, but we
 		 * are fine with getting outdated values, this is just an
 		 * estimation */
-		ntasks = ws->per_worker[workerids[worker]].queue_array->ntasks;
+		ntasks = ws->per_worker[workerids[worker]].queue.ntasks;
 
 		if (ntasks && (ws->per_worker[workerids[worker]].busy
 					   || starpu_worker_is_blocked_in_parallel(workerids[worker])))
@@ -171,7 +172,7 @@ static unsigned select_worker_round_robin(struct _starpu_work_stealing_data *ws,
 	worker = ws->last_push_worker;
 	do
 		worker = (worker + 1) % nworkers;
-	while (ws->per_worker[worker].queue_array && !starpu_worker_can_execute_task_first_impl(workerids[worker], task, NULL));
+	while (!ws->per_worker[worker].running || !starpu_worker_can_execute_task_first_impl(workerids[worker], task, NULL));
 
 	ws->last_push_worker = worker;
 
@@ -214,7 +215,7 @@ static unsigned select_worker_locality(struct _starpu_work_stealing_data *ws, st
 		while(workers->has_next(workers, &it))
 		{
 			int workerid = workers->get_next(workers, &it);
-			if (ndata[workerid] > best_ndata && ws->per_worker[worker].queue_array && ws->per_worker[workerid].busy)
+			if (ndata[workerid] > best_ndata && ws->per_worker[workerid].running && ws->per_worker[workerid].busy)
 			{
 				best_worker = workerid;
 				best_ndata = ndata[workerid];
@@ -326,12 +327,13 @@ static struct starpu_task *ws_pick_task(struct _starpu_work_stealing_data *ws, i
 	if (best_n > 0)
 	{
 		/* found an interesting task, try to pick it! */
-		if (_starpu_fifo_pop_this_task(data_source->queue_array, target, best_task))
+		if (_starpu_prio_deque_pop_this_task(&data_source->queue, target, best_task))
 			return best_task;
 	}
 
 	/* Didn't find an interesting task, or couldn't run it :( */
-	return _starpu_fifo_pop_task(data_source->queue_array, target);
+	int skipped;
+	return _starpu_prio_deque_pop_task_for_worker(&data_source->queue, target, &skipped);
 }
 
 /* Called when popping a task from a queue */
@@ -366,7 +368,8 @@ static void locality_pushed_task(struct _starpu_work_stealing_data *ws STARPU_AT
 /* Pick a task from workerid's queue, for execution on target */
 static struct starpu_task *ws_pick_task(struct _starpu_work_stealing_data *ws, int source, int target)
 {
-	return _starpu_fifo_pop_task(ws->per_worker[source].queue_array, target);
+	int skipped;
+	return _starpu_prio_deque_pop_task_for_worker(&ws->per_worker[source].queue, target, &skipped);
 }
 /* Called when popping a task from a queue */
 static void locality_popped_task(struct _starpu_work_stealing_data *ws STARPU_ATTRIBUTE_UNUSED, struct starpu_task *task STARPU_ATTRIBUTE_UNUSED, int workerid STARPU_ATTRIBUTE_UNUSED, unsigned sched_ctx_id STARPU_ATTRIBUTE_UNUSED)
@@ -389,8 +392,8 @@ static float overload_metric(struct _starpu_work_stealing_data *ws, unsigned sch
 	float execution_ratio = 0.0f;
 	float current_ratio = 0.0f;
 
-	int nprocessed = _starpu_get_deque_nprocessed(ws->per_worker[id].queue_array);
-	unsigned njobs = _starpu_get_deque_njobs(ws->per_worker[id].queue_array);
+	int nprocessed = ws->per_worker[id].queue.nprocessed;
+	unsigned njobs = ws->per_worker[id].queue.ntasks;
 
 	/* Did we get enough information ? */
 	if (ws->performed_total > 0 && nprocessed > 0)
@@ -434,7 +437,7 @@ static int select_victim_overload(struct _starpu_work_stealing_data *ws, unsigne
                 unsigned worker = workers->get_next(workers, &it);
 		float worker_ratio = overload_metric(ws, sched_ctx_id, worker);
 
-		if (worker_ratio > best_ratio && ws->per_worker[worker].queue_array && ws->per_worker[worker].busy)
+		if (worker_ratio > best_ratio && ws->per_worker[worker].running && ws->per_worker[worker].busy)
 		{
 			best_worker = worker;
 			best_ratio = worker_ratio;
@@ -471,7 +474,7 @@ static unsigned select_worker_overload(struct _starpu_work_stealing_data *ws, st
 		unsigned worker = workers->get_next(workers, &it);
 		float worker_ratio = overload_metric(ws, sched_ctx_id, worker);
 
-		if (worker_ratio < best_ratio && ws->per_worker[worker].queue_array && starpu_worker_can_execute_task_first_impl(worker, task, NULL))
+		if (worker_ratio < best_ratio && ws->per_worker[worker].running && starpu_worker_can_execute_task_first_impl(worker, task, NULL))
 		{
 			best_worker = worker;
 			best_ratio = worker_ratio;
@@ -525,7 +528,7 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	ws->per_worker[workerid].busy = 0;
 
 #ifdef STARPU_NON_BLOCKING_DRIVERS
-	if (STARPU_RUNNING_ON_VALGRIND || !_starpu_fifo_empty(ws->per_worker[workerid].queue_array))
+	if (STARPU_RUNNING_ON_VALGRIND || !_starpu_prio_deque_is_empty(&ws->per_worker[workerid].queue))
 #endif
 	{
 		task = ws_pick_task(ws, workerid, workerid);
@@ -566,7 +569,7 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 		/* victim is busy, don't bother it, come back later */
 		return NULL;
 	}
-	if (ws->per_worker[victim].queue_array != NULL && ws->per_worker[victim].queue_array->ntasks > 0)
+	if (ws->per_worker[victim].running && ws->per_worker[victim].queue.ntasks > 0)
 	{
 		task = ws_pick_task(ws, victim, workerid);
 	}
@@ -644,8 +647,8 @@ int ws_push_task(struct starpu_task *task)
 	STARPU_AYU_ADDTOTASKQUEUE(starpu_task_get_job_id(task), workerid);
 	_STARPU_TASK_BREAK_ON(task, sched);
 	record_data_locality(task, workerid);
-	STARPU_ASSERT_MSG(ws->per_worker[workerid].queue_array, "workerid=%d, ws=%p\n", workerid, ws);
-	_starpu_fifo_push_task(ws->per_worker[workerid].queue_array, task);
+	STARPU_ASSERT_MSG(ws->per_worker[workerid].running, "workerid=%d, ws=%p\n", workerid, ws);
+	_starpu_prio_deque_push_task(&ws->per_worker[workerid].queue, task);
 	locality_pushed_task(ws, task, workerid, sched_ctx_id);
 
 	starpu_push_task_end(task);
@@ -673,11 +676,12 @@ static void ws_add_workers(unsigned sched_ctx_id, int *workerids,unsigned nworke
 	{
 		int workerid = workerids[i];
 		starpu_sched_ctx_worker_shares_tasks_lists(workerid, sched_ctx_id);
-		ws->per_worker[workerid].queue_array = _starpu_create_fifo();
+		_starpu_prio_deque_init(&ws->per_worker[workerid].queue);
+		ws->per_worker[workerid].running = 1;
 
 		/* Tell helgrind that we are fine with getting outdated values,
 		 * this is just an estimation */
-		STARPU_HG_DISABLE_CHECKING(ws->per_worker[workerid].queue_array->ntasks);
+		STARPU_HG_DISABLE_CHECKING(ws->per_worker[workerid].queue.ntasks);
 		ws->per_worker[workerid].busy = 0;
 		STARPU_HG_DISABLE_CHECKING(ws->per_worker[workerid].busy);
 	}
@@ -692,11 +696,8 @@ static void ws_remove_workers(unsigned sched_ctx_id, int *workerids, unsigned nw
 	{
 		int workerid = workerids[i];
 
-		if (ws->per_worker[workerid].queue_array != NULL)
-		{
-			_starpu_destroy_fifo(ws->per_worker[workerid].queue_array);
-			ws->per_worker[workerid].queue_array = NULL;
-		}
+		_starpu_prio_deque_destroy(&ws->per_worker[workerid].queue);
+		ws->per_worker[workerid].running = 0;
 		free(ws->per_worker[workerid].proxlist);
 		ws->per_worker[workerid].proxlist = NULL;
 	}
@@ -716,6 +717,12 @@ static void initialize_ws_policy(unsigned sched_ctx_id)
 
 	unsigned nw = starpu_worker_get_count();
 	_STARPU_CALLOC(ws->per_worker, nw, sizeof(struct _starpu_work_stealing_data_per_worker));
+
+	/* The application may use any integer */
+	if (starpu_sched_ctx_min_priority_is_set(sched_ctx_id) == 0)
+		starpu_sched_ctx_set_min_priority(sched_ctx_id, INT_MIN);
+	if (starpu_sched_ctx_max_priority_is_set(sched_ctx_id) == 0)
+		starpu_sched_ctx_set_max_priority(sched_ctx_id, INT_MAX);
 }
 
 static void deinit_ws_policy(unsigned sched_ctx_id)
@@ -754,7 +761,7 @@ static int lws_select_victim(struct _starpu_work_stealing_data *ws, unsigned sch
 	for (i = 0; i < nworkers; i++)
 	{
 		int neighbor = ws->per_worker[workerid].proxlist[i];
-		int ntasks = ws->per_worker[neighbor].queue_array->ntasks;
+		int ntasks = ws->per_worker[neighbor].queue.ntasks;
 		if (ntasks && (ws->per_worker[neighbor].busy
 					   || starpu_worker_is_blocked_in_parallel(neighbor)))
 			return neighbor;
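
Since initialize_ws_policy() now widens the scheduling context's priority bounds to INT_MIN/INT_MAX, an application run under this policy (e.g. with STARPU_SCHED=ws or STARPU_SCHED=lws) can use any integer as a task priority, with larger values served first. A minimal usage sketch against StarPU's public API; the codelet and kernel below are placeholders, not part of this commit:

#include <starpu.h>

/* Hypothetical no-op CPU kernel, just to have something to schedule */
static void noop_func(void *buffers[], void *cl_arg)
{
	(void) buffers;
	(void) cl_arg;
}

static struct starpu_codelet noop_cl =
{
	.cpu_funcs = { noop_func },
	.nbuffers = 0,
};

int main(void)
{
	if (starpu_init(NULL) != 0)
		return 1;

	struct starpu_task *task = starpu_task_create();
	task->cl = &noop_cl;
	/* Any integer is a valid priority now; higher runs first */
	task->priority = 42;
	starpu_task_submit(task);

	starpu_task_wait_for_all();
	starpu_shutdown();
	return 0;
}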