Browse Source

Fix work stealing policy with contexts

Terry Cojean 9 years ago
parent
commit
7f9c179b81
1 changed files with 29 additions and 13 deletions
  1. 29 13
      src/sched_policies/work_stealing_policy.c

+ 29 - 13
src/sched_policies/work_stealing_policy.c

@@ -67,6 +67,8 @@ static unsigned select_victim_round_robin(unsigned sched_ctx_id)
 	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	unsigned worker = ws->last_pop_worker;
 	unsigned nworkers = starpu_sched_ctx_get_nworkers(sched_ctx_id);
+	int *workerids;
+	starpu_sched_ctx_get_workers_list(sched_ctx_id, &workerids);
 
 	/* If the worker's queue is empty, let's try
 	 * the next ones */
@@ -77,7 +79,7 @@ static unsigned select_victim_round_robin(unsigned sched_ctx_id)
 		/* Here helgrind would shout that this is unprotected, but we
 		 * are fine with getting outdated values, this is just an
 		 * estimation */
-		ntasks = ws->queue_array[worker]->ntasks;
+		ntasks = ws->queue_array[workerids[worker]]->ntasks;
 
 		if (ntasks)
 			break;
@@ -93,7 +95,7 @@ static unsigned select_victim_round_robin(unsigned sched_ctx_id)
 
 	ws->last_pop_worker = (worker + 1) % nworkers;
 
-	return worker;
+	return workerids[worker];
 }
 
 /**
@@ -105,10 +107,12 @@ static unsigned select_worker_round_robin(unsigned sched_ctx_id)
 	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	unsigned worker = ws->last_push_worker;
 	unsigned nworkers = starpu_sched_ctx_get_nworkers(sched_ctx_id);
+	int *workerids;
+	starpu_sched_ctx_get_workers_list(sched_ctx_id, &workerids);
 
 	ws->last_push_worker = (ws->last_push_worker + 1) % nworkers;
 
-	return worker;
+	return workerids[worker];
 }
 
 #ifdef USE_OVERLOAD
@@ -278,20 +282,20 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	starpu_pthread_cond_t *worker_sched_cond;
 	starpu_worker_get_sched_condition(workerid, &worker_sched_mutex, &worker_sched_cond);
 
-	/* Note: Releasing this mutex before taking the victim mutex, to avoid interlock*/
-	STARPU_PTHREAD_MUTEX_UNLOCK_SCHED(worker_sched_mutex);
-       
-
 	/* we need to steal someone's job */
 	unsigned victim = ws->select_victim(sched_ctx_id, workerid);
 
+	/* Note: Releasing this mutex before taking the victim mutex, to avoid interlock*/
+	STARPU_PTHREAD_MUTEX_UNLOCK_SCHED(worker_sched_mutex);
+
 	starpu_pthread_mutex_t *victim_sched_mutex;
 	starpu_pthread_cond_t *victim_sched_cond;
 
 	starpu_worker_get_sched_condition(victim, &victim_sched_mutex, &victim_sched_cond);
 	STARPU_PTHREAD_MUTEX_LOCK_SCHED(victim_sched_mutex);
 
-	task = _starpu_fifo_pop_task(ws->queue_array[victim], workerid);
+	if (ws->queue_array[victim] != NULL && ws->queue_array[victim]->ntasks > 0)
+		task = _starpu_fifo_pop_task(ws->queue_array[victim], workerid);
 	if (task)
 	{
 		_STARPU_TRACE_WORK_STEALING(workerid, victim);
@@ -302,7 +306,8 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	STARPU_PTHREAD_MUTEX_LOCK_SCHED(worker_sched_mutex);
 	if(!task)
 	{
-		task = _starpu_fifo_pop_task(ws->queue_array[workerid], workerid);
+		if (ws->queue_array[workerid] != NULL && ws->queue_array[workerid]->ntasks > 0)
+			task = _starpu_fifo_pop_task(ws->queue_array[workerid], workerid);
 		if (task)
 		{
 			/* there was a local task */
@@ -370,7 +375,8 @@ static void ws_add_workers(unsigned sched_ctx_id, int *workerids,unsigned nworke
 	{
 		workerid = workerids[i];
 		starpu_sched_ctx_worker_shares_tasks_lists(workerid, sched_ctx_id);
-		ws->queue_array[workerid] = _starpu_create_fifo();
+		if (ws->queue_array[workerid] == NULL)
+			ws->queue_array[workerid] = _starpu_create_fifo();
 
 		/* Tell helgrid that we are fine with getting outdated values,
 		 * this is just an estimation */
@@ -401,11 +407,11 @@ static void initialize_ws_policy(unsigned sched_ctx_id)
 
 	ws->last_pop_worker = 0;
 	ws->last_push_worker = 0;
-	ws->proxlist = NULL;
 	ws->select_victim = select_victim;
 
 	unsigned nw = starpu_worker_get_count();
-	ws->queue_array = (struct _starpu_fifo_taskq**)malloc(nw*sizeof(struct _starpu_fifo_taskq*));
+	if (ws->queue_array == NULL)
+		ws->queue_array = (struct _starpu_fifo_taskq**)malloc(nw*sizeof(struct _starpu_fifo_taskq*));
 }
 
 static void deinit_ws_policy(unsigned sched_ctx_id)
@@ -469,7 +475,6 @@ static void lws_add_workers(unsigned sched_ctx_id, int *workerids,
 	 * build this once and then use it for popping tasks rather
 	 * than traversing the hwloc tree every time a task must be
 	 * stolen */
-	ws->proxlist = (int**)malloc(nworkers*sizeof(int*));
 	struct starpu_worker_collection *workers = starpu_sched_ctx_get_worker_collection(sched_ctx_id);
 	struct starpu_tree *tree = (struct starpu_tree*)workers->workerids;
 	for (i = 0; i < nworkers; i++)
@@ -517,6 +522,17 @@ static void initialize_lws_policy(unsigned sched_ctx_id)
 #ifdef STARPU_HAVE_HWLOC
 	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data *)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	ws->select_victim = lws_select_victim;
+
+	int nworkers = starpu_worker_get_count(), i;
+
+	if (ws->proxlist == NULL)
+	{
+		ws->proxlist = (int**)malloc(nworkers*sizeof(int*));
+		for (i = 0; i < nworkers; i++) {
+			if (ws->proxlist[i] == NULL)
+				ws->proxlist[i] = (int*)malloc(nworkers*sizeof(int));
+		}
+	}
 #endif
 }