Browse Source

ws&lws: factorize per-worker data

Samuel Thibault 9 years ago
parent
commit
48b959ab37
1 changed files with 30 additions and 29 deletions
  1. 30 29
      src/sched_policies/work_stealing_policy.c

+ 30 - 29
src/sched_policies/work_stealing_policy.c

@@ -32,15 +32,19 @@
 /* Experimental (dead) code which needs to be tested, fixed... */
 /* #define USE_OVERLOAD */
 
+struct _starpu_work_stealing_data_per_worker
+{
+	struct _starpu_fifo_taskq *queue_array;
+	int *proxlist;
+};
+
 /* Experimental code for improving data cache locality */
 //#define USE_LOCALITY
 
 struct _starpu_work_stealing_data
 {
 	unsigned (*select_victim)(unsigned, int);
-
-	struct _starpu_fifo_taskq **queue_array;
-	int **proxlist;
+	struct _starpu_work_stealing_data_per_worker *per_worker;
 	/* keep track of the work performed from the beginning of the algorithm to make
 	 * better decisions about which queue to select when stealing or deferring work
 	 */
@@ -82,7 +86,7 @@ static unsigned select_victim_round_robin(unsigned sched_ctx_id)
 		/* Here helgrind would shout that this is unprotected, but we
 		 * are fine with getting outdated values, this is just an
 		 * estimation */
-		ntasks = ws->queue_array[workerids[worker]]->ntasks;
+		ntasks = ws->per_worker[workerids[worker]].queue_array->ntasks;
 
 		if (ntasks)
 			break;
@@ -202,8 +206,8 @@ static float overload_metric(unsigned sched_ctx_id, unsigned id)
 	float execution_ratio = 0.0f;
 	float current_ratio = 0.0f;
 
-	int nprocessed = _starpu_get_deque_nprocessed(ws->queue_array[id]);
-	unsigned njobs = _starpu_get_deque_njobs(ws->queue_array[id]);
+	int nprocessed = _starpu_get_deque_nprocessed(ws->per_worker[id].queue_array);
+	unsigned njobs = _starpu_get_deque_njobs(ws->per_worker[id].queue_array);
 
 	/* Did we get enough information ? */
 	if (performed_total > 0 && nprocessed > 0)
@@ -343,7 +347,7 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 
 	STARPU_ASSERT(workerid != -1);
 
-	task = _starpu_fifo_pop_task(ws->queue_array[workerid], workerid);
+	task = _starpu_fifo_pop_task(ws->per_worker[workerid].queue_array, workerid);
 	if (task)
 	{
 		/* there was a local task */
@@ -365,8 +369,8 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	starpu_worker_get_sched_condition(victim, &victim_sched_mutex, &victim_sched_cond);
 	STARPU_PTHREAD_MUTEX_LOCK_SCHED(victim_sched_mutex);
 
-	if (ws->queue_array[victim] != NULL && ws->queue_array[victim]->ntasks > 0)
-		task = _starpu_fifo_pop_task(ws->queue_array[victim], workerid);
+	if (ws->per_worker[victim].queue_array != NULL && ws->per_worker[victim].queue_array->ntasks > 0)
+		task = _starpu_fifo_pop_task(ws->per_worker[victim].queue_array, workerid);
 	if (task)
 	{
 		_STARPU_TRACE_WORK_STEALING(workerid, victim);
@@ -378,8 +382,8 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	STARPU_PTHREAD_MUTEX_LOCK_SCHED(worker_sched_mutex);
 	if(!task)
 	{
-		if (ws->queue_array[workerid] != NULL && ws->queue_array[workerid]->ntasks > 0)
-			task = _starpu_fifo_pop_task(ws->queue_array[workerid], workerid);
+		if (ws->per_worker[workerid].queue_array != NULL && ws->per_worker[workerid].queue_array->ntasks > 0)
+			task = _starpu_fifo_pop_task(ws->per_worker[workerid].queue_array, workerid);
 		if (task)
 		{
 			/* there was a local task */
@@ -416,7 +420,7 @@ pick_worker:
 	STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
 
 	/* Maybe the worker we selected was removed before we picked the mutex */
-	if (ws->queue_array[workerid] == NULL)
+	if (ws->per_worker[workerid].queue_array == NULL)
 		goto pick_worker;
 
 	record_worker_locality(task, workerid);
@@ -430,7 +434,7 @@ pick_worker:
 	}
 #endif
 
-	_starpu_fifo_push_task(ws->queue_array[workerid], task);
+	_starpu_fifo_push_task(ws->per_worker[workerid].queue_array, task);
 
 	starpu_push_task_end(task);
 
@@ -459,11 +463,11 @@ static void ws_add_workers(unsigned sched_ctx_id, int *workerids,unsigned nworke
 	{
 		workerid = workerids[i];
 		starpu_sched_ctx_worker_shares_tasks_lists(workerid, sched_ctx_id);
-		ws->queue_array[workerid] = _starpu_create_fifo();
+		ws->per_worker[workerid].queue_array = _starpu_create_fifo();
 
 		/* Tell helgrind that we are fine with getting outdated values,
 		 * this is just an estimation */
-		STARPU_HG_DISABLE_CHECKING(ws->queue_array[workerid]->ntasks);
+		STARPU_HG_DISABLE_CHECKING(ws->per_worker[workerid].queue_array->ntasks);
 	}
 }
 
@@ -482,13 +486,13 @@ static void ws_remove_workers(unsigned sched_ctx_id, int *workerids, unsigned nw
 		workerid = workerids[i];
 		starpu_worker_get_sched_condition(workerid, &sched_mutex, &sched_cond);
 		STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
-		if (ws->queue_array[workerid] != NULL)
+		if (ws->per_worker[workerid].queue_array != NULL)
 		{
-			_starpu_destroy_fifo(ws->queue_array[workerid]);
-			ws->queue_array[workerid] = NULL;
+			_starpu_destroy_fifo(ws->per_worker[workerid].queue_array);
+			ws->per_worker[workerid].queue_array = NULL;
 		}
-		if (ws->proxlist != NULL)
-			free(ws->proxlist[workerid]);
+		free(ws->per_worker[workerid].proxlist);
+		ws->per_worker[workerid].proxlist = NULL;
 		STARPU_PTHREAD_MUTEX_UNLOCK_SCHED(sched_mutex);
 	}
 }
@@ -500,19 +504,17 @@ static void initialize_ws_policy(unsigned sched_ctx_id)
 
 	ws->last_pop_worker = 0;
 	ws->last_push_worker = 0;
-	ws->proxlist = NULL;
 	ws->select_victim = select_victim;
 
 	unsigned nw = starpu_worker_get_count();
-	ws->queue_array = (struct _starpu_fifo_taskq**)malloc(nw*sizeof(struct _starpu_fifo_taskq*));
+	ws->per_worker = calloc(nw, sizeof(struct _starpu_work_stealing_data_per_worker));
 }
 
 static void deinit_ws_policy(unsigned sched_ctx_id)
 {
 	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 
-	free(ws->queue_array);
-	free(ws->proxlist);
+	free(ws->per_worker);
 	free(ws);
 }
 
@@ -545,14 +547,14 @@ static unsigned lws_select_victim(unsigned sched_ctx_id, int workerid)
 
 	for (i = 0; i < nworkers; i++)
 	{
-		neighbor = ws->proxlist[workerid][i];
+		neighbor = ws->per_worker[workerid].proxlist[i];
 		/* if a worker was removed, then nothing tells us that the proxlist is correct */
 		if (!starpu_sched_ctx_contains_worker(neighbor, sched_ctx_id))
 		{
 			i--;
 			continue;
 		}
-		int ntasks = ws->queue_array[neighbor]->ntasks;
+		int ntasks = ws->per_worker[neighbor].queue_array->ntasks;
 		if (ntasks)
 			return neighbor;
 	}
@@ -574,13 +576,12 @@ static void lws_add_workers(unsigned sched_ctx_id, int *workerids,
 	 * build this once and then use it for popping tasks rather
 	 * than traversing the hwloc tree every time a task must be
 	 * stolen */
-	ws->proxlist = (int**)malloc(starpu_worker_get_count()*sizeof(int*));
 	struct starpu_worker_collection *workers = starpu_sched_ctx_get_worker_collection(sched_ctx_id);
 	struct starpu_tree *tree = (struct starpu_tree*)workers->workerids;
 	for (i = 0; i < nworkers; i++)
 	{
 		workerid = workerids[i];
-		ws->proxlist[workerid] = (int*)malloc(nworkers*sizeof(int));
+		ws->per_worker[workerid].proxlist = (int*)malloc(nworkers*sizeof(int));
 		int bindid;
 
 		struct starpu_tree *neighbour = NULL;
@@ -601,7 +602,7 @@ static void lws_add_workers(unsigned sched_ctx_id, int *workerids,
 			{
 				if(!it.visited[neigh_workerids[w]] && workers->present[neigh_workerids[w]])
 				{
-					ws->proxlist[workerid][cnt++] = neigh_workerids[w];
+					ws->per_worker[workerid].proxlist[cnt++] = neigh_workerids[w];
 					it.visited[neigh_workerids[w]] = 1;
 				}
 			}