瀏覽代碼

ws, modular-ws: make last_pop_worker per worker

Sharing it means significant memory traffic only to keep it coherent when
all workers are idle. And sharing it would not be that useful.
Samuel Thibault 5 年之前
父節點
當前提交
aa8972ad0b
共有 2 個文件被更改,包括 40 次插入和 32 次刪除
  1. 30 25
      src/sched_policies/component_work_stealing.c
  2. 10 7
      src/sched_policies/work_stealing_policy.c

+ 30 - 25
src/sched_policies/component_work_stealing.c

@@ -30,14 +30,20 @@
 #warning TODO: locality work-stealing
 #endif
 
+struct _starpu_component_work_stealing_data_per_worker
+{
+	struct _starpu_prio_deque fifo;
+	unsigned last_pop_child;
+};
+
 struct _starpu_component_work_stealing_data
 {
 /* keep track of the work performed from the beginning of the algorithm to make
  * better decisions about which queue to child when stealing or deferring work
  */
-	unsigned performed_total, last_pop_child, last_push_child;
+	struct _starpu_component_work_stealing_data_per_worker *per_worker;
+	unsigned performed_total, last_push_child;
 
-	struct _starpu_prio_deque * fifos;
 	starpu_pthread_mutex_t ** mutexes;
 	unsigned size;
 };
@@ -50,16 +56,14 @@ struct _starpu_component_work_stealing_data
 static struct starpu_task *  steal_task_round_robin(struct starpu_sched_component *component, int workerid)
 {
 	struct _starpu_component_work_stealing_data *wsd = component->data;
-	STARPU_HG_DISABLE_CHECKING(wsd->last_pop_child);
-	unsigned i = wsd->last_pop_child;
-	wsd->last_pop_child = (i + 1) % component->nchildren;
-	STARPU_HG_ENABLE_CHECKING(wsd->last_pop_child);
+	unsigned i = wsd->per_worker[workerid].last_pop_child;
+	wsd->per_worker[workerid].last_pop_child = (i + 1) % component->nchildren;
 	/* If the worker's queue have no suitable tasks, let's try
 	 * the next ones */
 	struct starpu_task * task = NULL;
 	while (1)
 	{
-		struct _starpu_prio_deque * fifo = &wsd->fifos[i];
+		struct _starpu_prio_deque * fifo = &wsd->per_worker[i].fifo;
 
 		STARPU_COMPONENT_MUTEX_LOCK(wsd->mutexes[i]);
 		task = _starpu_prio_deque_deque_task_for_worker(fifo, workerid, NULL);
@@ -75,7 +79,7 @@ static struct starpu_task *  steal_task_round_robin(struct starpu_sched_componen
 			break;
 		}
 
-		if (i == wsd->last_pop_child)
+		if (i == wsd->per_worker[workerid].last_pop_child)
 		{
 			/* We got back to the first worker,
 			 * don't go in infinite loop */
@@ -141,17 +145,17 @@ static struct starpu_task * pull_task(struct starpu_sched_component * component,
 	struct _starpu_component_work_stealing_data * wsd = component->data;
 	const double now = starpu_timing_now();
 	STARPU_COMPONENT_MUTEX_LOCK(wsd->mutexes[i]);
-	struct starpu_task * task = _starpu_prio_deque_pop_task(&wsd->fifos[i]);
+	struct starpu_task * task = _starpu_prio_deque_pop_task(&wsd->per_worker[i].fifo);
 	if(task)
 	{
 		if(!isnan(task->predicted))
 		{
-			wsd->fifos[i].exp_len -= task->predicted;
-			wsd->fifos[i].exp_start = now + task->predicted;
+			wsd->per_worker[i].fifo.exp_len -= task->predicted;
+			wsd->per_worker[i].fifo.exp_start = now + task->predicted;
 		}
 	}
 	else
-		wsd->fifos[i].exp_len = 0.0;
+		wsd->per_worker[i].fifo.exp_len = 0.0;
 
 	STARPU_COMPONENT_MUTEX_UNLOCK(wsd->mutexes[i]);
 	if(task)
@@ -163,7 +167,7 @@ static struct starpu_task * pull_task(struct starpu_sched_component * component,
 	if(task)
 	{
 		STARPU_COMPONENT_MUTEX_LOCK(wsd->mutexes[i]);
-		wsd->fifos[i].nprocessed++;
+		wsd->per_worker[i].fifo.nprocessed++;
 		STARPU_COMPONENT_MUTEX_UNLOCK(wsd->mutexes[i]);
 
 		return task;
@@ -196,9 +200,9 @@ double _ws_estimated_end(struct starpu_sched_component * component)
 	for(i = 0; i < component->nchildren; i++)
 	{
 		STARPU_COMPONENT_MUTEX_LOCK(wsd->mutexes[i]);
-		sum_len += wsd->fifos[i].exp_len;
-		wsd->fifos[i].exp_start = STARPU_MAX(now, wsd->fifos[i].exp_start);
-		sum_start += wsd->fifos[i].exp_start;
+		sum_len += wsd->per_worker[i].fifo.exp_len;
+		wsd->per_worker[i].fifo.exp_start = STARPU_MAX(now, wsd->per_worker[i].fifo.exp_start);
+		sum_start += wsd->per_worker[i].fifo.exp_start;
 		STARPU_COMPONENT_MUTEX_UNLOCK(wsd->mutexes[i]);
 
 	}
@@ -216,7 +220,7 @@ double _ws_estimated_load(struct starpu_sched_component * component)
 	for(i = 0; i < component->nchildren; i++)
 	{
 		STARPU_COMPONENT_MUTEX_LOCK(wsd->mutexes[i]);
-		ntasks += wsd->fifos[i].ntasks;
+		ntasks += wsd->per_worker[i].fifo.ntasks;
 		STARPU_COMPONENT_MUTEX_UNLOCK(wsd->mutexes[i]);
 	}
 	double speedup = 0.0;
@@ -265,7 +269,7 @@ static int push_task(struct starpu_sched_component * component, struct starpu_ta
 
 	STARPU_COMPONENT_MUTEX_LOCK(wsd->mutexes[i]);
 	starpu_sched_task_break(task);
-	ret = _starpu_prio_deque_push_front_task(&wsd->fifos[i], task);
+	ret = _starpu_prio_deque_push_front_task(&wsd->per_worker[i].fifo, task);
 	STARPU_COMPONENT_MUTEX_UNLOCK(wsd->mutexes[i]);
 
 	wsd->last_push_child = i;
@@ -308,9 +312,9 @@ int starpu_sched_tree_work_stealing_push_task(struct starpu_task *task)
 
 			struct _starpu_component_work_stealing_data * wsd = component->data;
 			STARPU_COMPONENT_MUTEX_LOCK(wsd->mutexes[i]);
-			int ret = _starpu_prio_deque_push_front_task(&wsd->fifos[i] , task);
+			int ret = _starpu_prio_deque_push_front_task(&wsd->per_worker[i].fifo , task);
 			if(ret == 0 && !isnan(task->predicted))
-				wsd->fifos[i].exp_len += task->predicted;
+				wsd->per_worker[i].fifo.exp_len += task->predicted;
 			STARPU_COMPONENT_MUTEX_UNLOCK(wsd->mutexes[i]);
 
 			component->can_pull(component);
@@ -329,12 +333,13 @@ void _ws_add_child(struct starpu_sched_component * component, struct starpu_sche
 	if(wsd->size < component->nchildren)
 	{
 		STARPU_ASSERT(wsd->size == component->nchildren - 1);
-		_STARPU_REALLOC(wsd->fifos, component->nchildren * sizeof(*wsd->fifos));
+		_STARPU_REALLOC(wsd->per_worker, component->nchildren * sizeof(*wsd->per_worker));
 		_STARPU_REALLOC(wsd->mutexes, component->nchildren * sizeof(*wsd->mutexes));
 		wsd->size = component->nchildren;
 	}
 
-	_starpu_prio_deque_init(&wsd->fifos[component->nchildren - 1]);
+	wsd->per_worker[component->nchildren - 1].last_pop_child = 0;
+	_starpu_prio_deque_init(&wsd->per_worker[component->nchildren - 1].fifo);
 
 	starpu_pthread_mutex_t *mutex;
 	_STARPU_MALLOC(mutex, sizeof(*mutex));
@@ -356,8 +361,8 @@ void _ws_remove_child(struct starpu_sched_component * component, struct starpu_s
 			break;
 	}
 	STARPU_ASSERT(i_component != component->nchildren);
-	struct _starpu_prio_deque tmp_fifo = wsd->fifos[i_component];
-	wsd->fifos[i_component] = wsd->fifos[component->nchildren - 1];
+	struct _starpu_prio_deque tmp_fifo = wsd->per_worker[i_component].fifo;
+	wsd->per_worker[i_component].fifo = wsd->per_worker[component->nchildren - 1].fifo;
 
 
 	component->children[i_component] = component->children[component->nchildren - 1];
@@ -372,7 +377,7 @@ void _ws_remove_child(struct starpu_sched_component * component, struct starpu_s
 void _work_stealing_component_deinit_data(struct starpu_sched_component * component)
 {
 	struct _starpu_component_work_stealing_data * wsd = component->data;
-	free(wsd->fifos);
+	free(wsd->per_worker);
 	free(wsd->mutexes);
 	free(wsd);
 }

+ 10 - 7
src/sched_policies/work_stealing_policy.c

@@ -82,6 +82,11 @@ struct _starpu_work_stealing_data_per_worker
 	int *proxlist;
 	int busy;	/* Whether this worker is working on a task */
 
+	/* keep track of the work performed from the beginning of the algorithm to make
+	 * better decisions about which queue to select when deferring work
+	 */
+	unsigned last_pop_worker;
+
 #ifdef USE_LOCALITY_TASKS
 	/* This records the same as queue, but hashed by data accessed with locality flag.  */
 	/* FIXME: we record only one task per data, assuming that the access is
@@ -99,9 +104,8 @@ struct _starpu_work_stealing_data
 	int (*select_victim)(struct _starpu_work_stealing_data *, unsigned, int);
 	struct _starpu_work_stealing_data_per_worker *per_worker;
 	/* keep track of the work performed from the beginning of the algorithm to make
-	 * better decisions about which queue to select when stealing or deferring work
+	 * better decisions about which queue to select when deferring work
 	 */
-	unsigned last_pop_worker;
 	unsigned last_push_worker;
 };
 
@@ -124,7 +128,8 @@ static int calibration_value = 0;
  */
 static int select_victim_round_robin(struct _starpu_work_stealing_data *ws, unsigned sched_ctx_id)
 {
-	unsigned worker = ws->last_pop_worker;
+	unsigned workerid = starpu_worker_get_id_check();
+	unsigned worker = ws->per_worker[workerid].last_pop_worker;
 	unsigned nworkers;
 	int *workerids = NULL;
 	nworkers = starpu_sched_ctx_get_workers_list_raw(sched_ctx_id, &workerids);
@@ -147,7 +152,7 @@ static int select_victim_round_robin(struct _starpu_work_stealing_data *ws, unsi
 		}
 
 		worker = (worker + 1) % nworkers;
-		if (worker == ws->last_pop_worker)
+		if (worker == ws->per_worker[workerid].last_pop_worker)
 		{
 			/* We got back to the first worker,
 			 * don't go in infinite loop */
@@ -156,7 +161,7 @@ static int select_victim_round_robin(struct _starpu_work_stealing_data *ws, unsi
 		}
 	}
 
-	ws->last_pop_worker = (worker + 1) % nworkers;
+	ws->per_worker[workerid].last_pop_worker = (worker + 1) % nworkers;
 
 	worker = workerids[worker];
 
@@ -750,9 +755,7 @@ static void initialize_ws_policy(unsigned sched_ctx_id)
 	_STARPU_MALLOC(ws, sizeof(struct _starpu_work_stealing_data));
 	starpu_sched_ctx_set_policy_data(sched_ctx_id, (void*)ws);
 
-	ws->last_pop_worker = 0;
 	ws->last_push_worker = 0;
-	STARPU_HG_DISABLE_CHECKING(ws->last_pop_worker);
 	STARPU_HG_DISABLE_CHECKING(ws->last_push_worker);
 	ws->select_victim = select_victim;