Browse Source

Improve my solution to fix work stealing with contexts

Terry Cojean 9 years ago
parent
commit
b9cda4509e
1 changed file with 21 additions and 22 deletions
  1. 21 22
      src/sched_policies/work_stealing_policy.c

+ 21 - 22
src/sched_policies/work_stealing_policy.c

@@ -67,7 +67,7 @@ static unsigned select_victim_round_robin(unsigned sched_ctx_id)
 	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	unsigned worker = ws->last_pop_worker;
 	unsigned nworkers = starpu_sched_ctx_get_nworkers(sched_ctx_id);
-	int *workerids;
+	int *workerids = NULL;
 	starpu_sched_ctx_get_workers_list(sched_ctx_id, &workerids);
 
 	/* If the worker's queue is empty, let's try
@@ -95,7 +95,10 @@ static unsigned select_victim_round_robin(unsigned sched_ctx_id)
 
 	ws->last_pop_worker = (worker + 1) % nworkers;
 
-	return workerids[worker];
+	worker = workerids[worker];
+	free(workerids);
+
+	return worker;
 }
 
 /**
@@ -107,12 +110,15 @@ static unsigned select_worker_round_robin(unsigned sched_ctx_id)
 	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data*)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	unsigned worker = ws->last_push_worker;
 	unsigned nworkers = starpu_sched_ctx_get_nworkers(sched_ctx_id);
-	int *workerids;
+	int *workerids = NULL;
 	starpu_sched_ctx_get_workers_list(sched_ctx_id, &workerids);
 
 	ws->last_push_worker = (ws->last_push_worker + 1) % nworkers;
 
-	return workerids[worker];
+	worker = workerids[worker];
+	free(workerids);
+
+	return worker;
 }
 
 #ifdef USE_OVERLOAD
@@ -375,8 +381,7 @@ static void ws_add_workers(unsigned sched_ctx_id, int *workerids,unsigned nworke
 	{
 		workerid = workerids[i];
 		starpu_sched_ctx_worker_shares_tasks_lists(workerid, sched_ctx_id);
-		if (ws->queue_array[workerid] == NULL)
-			ws->queue_array[workerid] = _starpu_create_fifo();
+		ws->queue_array[workerid] = _starpu_create_fifo();
 
 		/* Tell helgrind that we are fine with getting outdated values,
 		 * this is just an estimation */
@@ -394,8 +399,12 @@ static void ws_remove_workers(unsigned sched_ctx_id, int *workerids, unsigned nw
 	for (i = 0; i < nworkers; i++)
 	{
 		workerid = workerids[i];
-		_starpu_destroy_fifo(ws->queue_array[workerid]);
-		if (ws->proxlist)
+		if (ws->queue_array[workerid] != NULL)
+		{
+			_starpu_destroy_fifo(ws->queue_array[workerid]);
+			ws->queue_array[workerid] = NULL;
+		}
+		if (ws->proxlist != NULL)
 			free(ws->proxlist[workerid]);
 	}
 }
@@ -407,11 +416,11 @@ static void initialize_ws_policy(unsigned sched_ctx_id)
 
 	ws->last_pop_worker = 0;
 	ws->last_push_worker = 0;
+	ws->proxlist = NULL;
 	ws->select_victim = select_victim;
 
 	unsigned nw = starpu_worker_get_count();
-	if (ws->queue_array == NULL)
-		ws->queue_array = (struct _starpu_fifo_taskq**)malloc(nw*sizeof(struct _starpu_fifo_taskq*));
+	ws->queue_array = (struct _starpu_fifo_taskq**)malloc(nw*sizeof(struct _starpu_fifo_taskq*));
 }
 
 static void deinit_ws_policy(unsigned sched_ctx_id)
@@ -420,7 +429,7 @@ static void deinit_ws_policy(unsigned sched_ctx_id)
 
 	free(ws->queue_array);
 	free(ws->proxlist);
-        free(ws);
+	free(ws);
 }
 
 struct starpu_sched_policy _starpu_sched_ws_policy =
@@ -475,6 +484,7 @@ static void lws_add_workers(unsigned sched_ctx_id, int *workerids,
 	 * build this once and then use it for popping tasks rather
 	 * than traversing the hwloc tree every time a task must be
 	 * stolen */
+	ws->proxlist = (int**)malloc(starpu_worker_get_count()*sizeof(int*));
 	struct starpu_worker_collection *workers = starpu_sched_ctx_get_worker_collection(sched_ctx_id);
 	struct starpu_tree *tree = (struct starpu_tree*)workers->workerids;
 	for (i = 0; i < nworkers; i++)
@@ -522,17 +532,6 @@ static void initialize_lws_policy(unsigned sched_ctx_id)
 #ifdef STARPU_HAVE_HWLOC
 	struct _starpu_work_stealing_data *ws = (struct _starpu_work_stealing_data *)starpu_sched_ctx_get_policy_data(sched_ctx_id);
 	ws->select_victim = lws_select_victim;
-
-	int nworkers = starpu_worker_get_count(), i;
-
-	if (ws->proxlist == NULL)
-	{
-		ws->proxlist = (int**)malloc(nworkers*sizeof(int*));
-		for (i = 0; i < nworkers; i++) {
-			if (ws->proxlist[i] == NULL)
-				ws->proxlist[i] = (int*)malloc(nworkers*sizeof(int));
-		}
-	}
 #endif
 }