
Fix scalability of lws for small tasks by avoiding stealing from a worker that is not running a task, and by releasing the sched mutex while stealing
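In short: each worker now keeps a per-worker busy flag, cleared when the worker starts looking for work and set again once it obtains a task, and victim selection skips workers whose flag is clear, since an idle worker will drain its own queue by itself and probing it only adds contention. A minimal sketch of that pattern, with simplified names rather than the actual StarPU structures:

	/* Simplified stand-in for the per-worker FIFO. */
	struct task_queue { int ntasks; };

	struct worker_state
	{
		struct task_queue queue; /* tasks waiting on this worker */
		int busy;                /* 1 while the worker is executing a task */
	};

	/* Round-robin victim selection that skips idle workers: a victim is
	 * only worth probing if it both has queued tasks and is busy running
	 * one; an idle worker will pick up its own queued tasks anyway. */
	static int pick_victim(const struct worker_state *workers, int nworkers, int start)
	{
		int i, w = start;
		for (i = 0; i < nworkers; i++, w = (w + 1) % nworkers)
			if (workers[w].queue.ntasks && workers[w].busy)
				return w;
		return -1; /* no suitable victim */
	}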

Samuel Thibault, 8 years ago
commit 8715eba93c
1 file changed, 32 insertions(+), 30 deletions(-)

src/sched_policies/work_stealing_policy.c (+32, -30)

@@ -71,6 +71,7 @@ struct _starpu_work_stealing_data_per_worker
 	struct _starpu_fifo_taskq *queue_array;
 	int *proxlist;
 	starpu_pthread_mutex_t worker_mutex;
+	int busy;
 
 #ifdef USE_LOCALITY_TASKS
 	/* This records the same as queue_array, but hashed by data accessed with locality flag.  */
@@ -130,7 +131,7 @@ static int select_victim_round_robin(unsigned sched_ctx_id)
 		 * estimation */
 		ntasks = ws->per_worker[workerids[worker]].queue_array->ntasks;
 
-		if (ntasks)
+		if (ntasks && ws->per_worker[workerids[worker]].busy)
 			break;
 
 		worker = (worker + 1) % nworkers;
@@ -138,6 +139,7 @@ static int select_victim_round_robin(unsigned sched_ctx_id)
 		{
 			/* We got back to the first worker,
 			 * don't go in infinite loop */
+			ntasks = 0;
 			break;
 		}
 	}
@@ -175,6 +177,7 @@ static unsigned select_worker_round_robin(unsigned sched_ctx_id)
 /* Select a worker according to the locality of the data of the task to be scheduled */
 static unsigned select_worker_locality(struct starpu_task *task, unsigned sched_ctx_id)
 {
+	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	unsigned nbuffers = STARPU_TASK_GET_NBUFFERS(task);
 	if (nbuffers == 0)
 		return -1;
@@ -206,7 +209,7 @@ static unsigned select_worker_locality(struct starpu_task *task, unsigned sched_
 		while(workers->has_next(workers, &it))
 		{
 			int workerid = workers->get_next(workers, &it);
-			if (ndata[workerid] > best_ndata)
+			if (ndata[workerid] > best_ndata && ws->per_worker[workerid].busy)
 			{
 				best_worker = workerid;
 				best_ndata = ndata[workerid];
@@ -275,9 +278,8 @@ static void locality_pushed_task(struct starpu_task *task, int workerid, unsigne
 }
 
 /* Pick a task from workerid's queue, for execution on target */
-static struct starpu_task *ws_pick_task(int source, int target, unsigned sched_ctx_id)
+static struct starpu_task *ws_pick_task(struct _starpu_work_stealing_data *ws, int source, int target, unsigned sched_ctx_id)
 {
-	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	struct _starpu_work_stealing_data_per_worker *data_source = &ws->per_worker[source];
 	struct _starpu_work_stealing_data_per_worker *data_target = &ws->per_worker[target];
 	unsigned i, j, n = data_target->nlast_locality;
@@ -360,9 +362,8 @@ static void locality_pushed_task(struct starpu_task *task STARPU_ATTRIBUTE_UNUSE
 {
 }
 /* Pick a task from workerid's queue, for execution on target */
-static struct starpu_task *ws_pick_task(int source, int target, unsigned sched_ctx_id)
+static struct starpu_task *ws_pick_task(struct _starpu_work_stealing_data *ws, int source, int target, unsigned sched_ctx_id)
 {
-	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	return _starpu_fifo_pop_task(ws->per_worker[source].queue_array, target);
 }
 /* Called when popping a task from a queue */
@@ -415,6 +416,7 @@ static float overload_metric(unsigned sched_ctx_id, unsigned id)
  */
 static int select_victim_overload(unsigned sched_ctx_id)
 {
+	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	unsigned worker;
 	float  worker_ratio;
 	unsigned best_worker = 0;
@@ -435,7 +437,7 @@ static int select_victim_overload(unsigned sched_ctx_id)
                 worker = workers->get_next(workers, &it);
 		worker_ratio = overload_metric(sched_ctx_id, worker);
 
-		if (worker_ratio > best_ratio)
+		if (worker_ratio > best_ratio && ws->per_worker[worker].busy)
 		{
 			best_worker = worker;
 			best_ratio = worker_ratio;
@@ -526,12 +528,14 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	struct starpu_task *task = NULL;
 	unsigned workerid = starpu_worker_get_id_check();
 
+	ws->per_worker[workerid].busy = 0;
+
 #ifdef STARPU_NON_BLOCKING_DRIVERS
 	if (STARPU_RUNNING_ON_VALGRIND || !_starpu_fifo_empty(ws->per_worker[workerid].queue_array))
 #endif
 	{
 		STARPU_PTHREAD_MUTEX_LOCK(&ws->per_worker[workerid].worker_mutex);
-		task = ws_pick_task(workerid, workerid, sched_ctx_id);
+		task = ws_pick_task(ws, workerid, workerid, sched_ctx_id);
 		if (task)
 			locality_popped_task(task, workerid, sched_ctx_id);
 		STARPU_PTHREAD_MUTEX_UNLOCK(&ws->per_worker[workerid].worker_mutex);
@@ -540,21 +544,34 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	if (task)
 	{
 		/* there was a local task */
+		ws->per_worker[workerid].busy = 1;
 		return task;
 	}
 
+	starpu_pthread_mutex_t *sched_mutex;
+	starpu_pthread_cond_t *sched_cond;
+	starpu_worker_get_sched_condition(workerid, &sched_mutex, &sched_cond);
+	/* While stealing, release the mutex used to synchronize with pushers */
+	STARPU_PTHREAD_MUTEX_UNLOCK_SCHED(sched_mutex);
+
 
 	/* we need to steal someone's job */
 	int victim = ws->select_victim(sched_ctx_id, workerid);
 	if (victim == -1)
+	{
+		STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
 		return NULL;
+	}
 
 	if (STARPU_PTHREAD_MUTEX_TRYLOCK(&ws->per_worker[victim].worker_mutex))
+	{
 		/* victim is busy, don't bother it, come back later */
+		STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
 		return NULL;
+	}
 	if (ws->per_worker[victim].queue_array != NULL && ws->per_worker[victim].queue_array->ntasks > 0)
 	{
-		task = ws_pick_task(victim, workerid, sched_ctx_id);
+		task = ws_pick_task(ws, victim, workerid, sched_ctx_id);
 	}
 
 	if (task)
@@ -567,27 +584,10 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	}
 	STARPU_PTHREAD_MUTEX_UNLOCK(&ws->per_worker[victim].worker_mutex);
 
-	if(!task
-#ifdef STARPU_NON_BLOCKING_DRIVERS
-		&& (STARPU_RUNNING_ON_VALGRIND || !_starpu_fifo_empty(ws->per_worker[workerid].queue_array))
-#endif
-		)
-	{
-		STARPU_PTHREAD_MUTEX_LOCK(&ws->per_worker[workerid].worker_mutex);
-		if (ws->per_worker[workerid].queue_array != NULL && ws->per_worker[workerid].queue_array->ntasks > 0)
-			task = ws_pick_task(workerid, workerid, sched_ctx_id);
-
-		if (task)
-			locality_popped_task(task, workerid, sched_ctx_id);
-		STARPU_PTHREAD_MUTEX_UNLOCK(&ws->per_worker[workerid].worker_mutex);
-
-		if (task)
-		{
-			/* there was a local task */
-			return task;
-		}
-	}
+	/* Done with stealing, resynchronize with core */
+	STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
 
+	ws->per_worker[workerid].busy = !!task;
 	return task;
 }
 
@@ -654,6 +654,8 @@ static void ws_add_workers(unsigned sched_ctx_id, int *workerids,unsigned nworke
 		 * this is just an estimation */
 		STARPU_HG_DISABLE_CHECKING(ws->per_worker[workerid].queue_array->ntasks);
 		STARPU_PTHREAD_MUTEX_INIT(&ws->per_worker[workerid].worker_mutex, NULL);
+		ws->per_worker[workerid].busy = 0;
+		STARPU_HG_DISABLE_CHECKING(ws->per_worker[workerid].busy);
 	}
 }
 
@@ -734,7 +736,7 @@ static int lws_select_victim(unsigned sched_ctx_id, int workerid)
 	{
 		neighbor = ws->per_worker[workerid].proxlist[i];
 		int ntasks = ws->per_worker[neighbor].queue_array->ntasks;
-		if (ntasks)
+		if (ntasks && ws->per_worker[neighbor].busy)
 			return neighbor;
 	}
 	return -1;
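
The other half of the change is the locking pattern in ws_pop_task: the sched mutex, which pushers use to synchronize with the worker, is held when pop is called, and keeping it across the whole steal would serialize every idle worker. Instead it is released for the duration of the steal and re-taken before returning, and the victim's own per-worker mutex is only probed with a trylock so a contended victim is simply skipped. A sketch of that pattern using plain pthreads rather than the STARPU_PTHREAD_* wrappers; struct task, struct task_queue and queue_pop are hypothetical placeholders:

	#include <pthread.h>
	#include <stddef.h>

	struct task;
	struct task_queue { int ntasks; };
	struct task *queue_pop(struct task_queue *q); /* hypothetical pop helper */

	/* Caller holds sched_mutex on entry; it is held again on return. */
	struct task *steal_task(pthread_mutex_t *sched_mutex,
	                        pthread_mutex_t *victim_mutex,
	                        struct task_queue *victim_queue)
	{
		struct task *task = NULL;

		/* Release the mutex used to synchronize with pushers
		 * while we go stealing. */
		pthread_mutex_unlock(sched_mutex);

		if (pthread_mutex_trylock(victim_mutex) == 0)
		{
			/* Victim not contended: pop from its queue. */
			if (victim_queue->ntasks > 0)
				task = queue_pop(victim_queue);
			pthread_mutex_unlock(victim_mutex);
		}
		/* else: victim is busy or already being stolen from;
		 * don't bother it, come back later. */

		/* Done with stealing, resynchronize before returning. */
		pthread_mutex_lock(sched_mutex);
		return task;
	}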