Sfoglia il codice sorgente

Update the ws and lws scheduling policies to the new synchronization scheme

Olivier Aumage 8 anni fa
parent
commit
30cf4c32fe

+ 3 - 5
src/core/sched_ctx.c

@@ -2197,8 +2197,6 @@ void starpu_sched_ctx_move_task_to_ctx_locked(struct starpu_task *task, unsigned
 {
 	/* TODO: make something cleaner which differentiates between calls
 	   from push or pop (have mutex or not) and from another worker or not */
-	int workerid = starpu_worker_get_id();
-
 	task->sched_ctx = sched_ctx;
 
 	struct _starpu_job *j = _starpu_get_job_associated_to_task(task);
@@ -2302,7 +2300,7 @@ void starpu_sched_ctx_list_task_counters_decrement_all_ctx_locked(struct starpu_
 			struct _starpu_worker *worker = _starpu_get_worker_struct(workerid);
 			if (worker->nsched_ctxs > 1)
 			{
-				_starpu_worker_lock_for_observation(workerid);
+				_starpu_worker_lock_for_observation_no_relax(workerid);
 				starpu_sched_ctx_list_task_counters_decrement(sched_ctx_id, workerid);
 				_starpu_worker_unlock_for_observation(workerid);
 			}
@@ -2324,7 +2322,7 @@ void starpu_sched_ctx_list_task_counters_decrement_all(struct starpu_task *task,
 			struct _starpu_worker *worker = _starpu_get_worker_struct(workerid);
 			if (worker->nsched_ctxs > 1)
 			{
-				_starpu_worker_lock_for_observation(workerid);
+				_starpu_worker_lock_for_observation_no_relax(workerid);
 				starpu_sched_ctx_list_task_counters_decrement(sched_ctx_id, workerid);
 				_starpu_worker_unlock_for_observation(workerid);
 			}
@@ -2347,7 +2345,7 @@ void starpu_sched_ctx_list_task_counters_reset_all(struct starpu_task *task, uns
 			struct _starpu_worker *worker = _starpu_get_worker_struct(workerid);
 			if (worker->nsched_ctxs > 1)
 			{
-				_starpu_worker_lock_for_observation(workerid);
+				_starpu_worker_lock_for_observation_no_relax(workerid);
 				starpu_sched_ctx_list_task_counters_reset(sched_ctx_id, workerid);
 				_starpu_worker_unlock_for_observation(workerid);
 			}

+ 2 - 2
src/core/sched_policy.c

@@ -65,8 +65,6 @@ static struct starpu_sched_policy *predefined_policies[] =
 	&_starpu_sched_modular_heft_policy,
 	&_starpu_sched_modular_heft_prio_policy,
 	&_starpu_sched_modular_heft2_policy,
-	&_starpu_sched_lws_policy,
-	&_starpu_sched_ws_policy,
 	&_starpu_sched_dm_policy,
 	&_starpu_sched_dmda_policy,
 	&_starpu_sched_dmda_ready_policy,
@@ -80,6 +78,8 @@ static struct starpu_sched_policy *predefined_policies[] =
 	&_starpu_sched_prio_policy,
 	&_starpu_sched_random_policy,
 	&_starpu_sched_peager_policy,
+	&_starpu_sched_ws_policy,
+	&_starpu_sched_lws_policy,
 #warning TODO: update sched policies with new synchro scheme
 #endif
 	NULL

+ 2 - 2
src/core/workers.c

@@ -1692,7 +1692,7 @@ unsigned starpu_worker_get_count(void)
 
 unsigned starpu_worker_is_blocked_in_parallel(int workerid)
 {
-	_starpu_worker_lock_for_observation(workerid);
+	_starpu_worker_lock_for_observation_no_relax(workerid);
 	unsigned ret = _starpu_config.workers[workerid].state_blocked_in_parallel;
 	_starpu_worker_unlock_for_observation(workerid);
 	return ret;
@@ -1700,7 +1700,7 @@ unsigned starpu_worker_is_blocked_in_parallel(int workerid)
 
 unsigned starpu_worker_is_slave_somewhere(int workerid)
 {
-	_starpu_worker_lock_for_observation(workerid);
+	_starpu_worker_lock_for_observation_no_relax(workerid);
 	unsigned ret = _starpu_config.workers[workerid].is_slave_somewhere;
 	_starpu_worker_unlock_for_observation(workerid);
 	return ret;

+ 43 - 1
src/core/workers.h

@@ -834,7 +834,33 @@ static inline void _starpu_worker_leave_changing_ctx_op(struct _starpu_worker *
  *
  * notes:
 * - if the observed worker is not in state_safe_for_observation, the function blocks until the state is reached */
-static inline void _starpu_worker_lock_for_observation(int workerid)
+static inline void _starpu_worker_lock_for_observation_relax(int workerid)
+{
+	struct _starpu_worker *worker = _starpu_get_worker_struct(workerid);
+	int cur_workerid = starpu_worker_get_id();
+	STARPU_ASSERT(worker != NULL);
+	STARPU_PTHREAD_MUTEX_LOCK_SCHED(&worker->sched_mutex);
+	if (workerid != cur_workerid)
+	{
+		struct _starpu_worker *cur_worker = cur_workerid<starpu_worker_get_count()?_starpu_get_worker_struct(cur_workerid):NULL;
+		int relax_own_observation_state = (cur_worker != NULL) && (cur_worker->state_safe_for_observation == 0);
+		if (relax_own_observation_state && !worker->state_safe_for_observation)
+		{
+			cur_worker->state_safe_for_observation = 1;
+			STARPU_PTHREAD_COND_BROADCAST(&cur_worker->sched_cond);
+		}
+		while (!worker->state_safe_for_observation)
+		{
+			STARPU_PTHREAD_COND_WAIT(&worker->sched_cond, &worker->sched_mutex);
+		}
+		if (relax_own_observation_state)
+		{
+			cur_worker->state_safe_for_observation = 0;
+		}
+	}
+}
+
+static inline void _starpu_worker_lock_for_observation_no_relax(int workerid)
 {
 	struct _starpu_worker *worker = _starpu_get_worker_struct(workerid);
 	STARPU_ASSERT(worker != NULL);
@@ -848,6 +874,22 @@ static inline void _starpu_worker_lock_for_observation(int workerid)
 	}
 }
 
+static inline int _starpu_worker_trylock_for_observation(int workerid)
+{
+	struct _starpu_worker *worker = _starpu_get_worker_struct(workerid);
+	int cur_workerid = starpu_worker_get_id();
+	STARPU_ASSERT(worker != NULL);
+	int ret = STARPU_PTHREAD_MUTEX_TRYLOCK_SCHED(&worker->sched_mutex);
+	if (ret)
+		return ret;
+	if (workerid != cur_workerid) {
+		ret = !worker->state_safe_for_observation;
+		if (ret)
+			STARPU_PTHREAD_MUTEX_UNLOCK_SCHED(&worker->sched_mutex);
+	}
+	return ret;
+}
+
 static inline void _starpu_worker_unlock_for_observation(int workerid)
 {
 	struct _starpu_worker *worker = _starpu_get_worker_struct(workerid);

+ 17 - 28
src/sched_policies/work_stealing_policy.c

@@ -74,7 +74,6 @@ struct _starpu_work_stealing_data_per_worker
 {
 	struct _starpu_fifo_taskq *queue_array;
 	int *proxlist;
-	starpu_pthread_mutex_t worker_mutex;
 	int busy;
 
 #ifdef USE_LOCALITY_TASKS
@@ -529,15 +528,15 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 	if (STARPU_RUNNING_ON_VALGRIND || !_starpu_fifo_empty(ws->per_worker[workerid].queue_array))
 #endif
 	{
-		STARPU_PTHREAD_MUTEX_LOCK(&ws->per_worker[workerid].worker_mutex);
 		task = ws_pick_task(ws, workerid, workerid);
 		if (task)
 			locality_popped_task(ws, task, workerid, sched_ctx_id);
-		STARPU_PTHREAD_MUTEX_UNLOCK(&ws->per_worker[workerid].worker_mutex);
 	}
 
 	if (task)
 	{
+		_starpu_worker_enter_section_safe_for_observation();
+		_starpu_sched_ctx_lock_write(sched_ctx_id);
 		/* there was a local task */
 		ws->per_worker[workerid].busy = 1;
 		starpu_sched_ctx_list_task_counters_decrement(sched_ctx_id, workerid);
@@ -546,30 +545,27 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 		{
 			starpu_sched_ctx_move_task_to_ctx(task, child_sched_ctx, 1, 1);
 			starpu_sched_ctx_revert_task_counters(sched_ctx_id, task->flops);
-			return NULL;
+			task = NULL;
 		}
+		_starpu_sched_ctx_unlock_write(sched_ctx_id);
+		_starpu_worker_leave_section_safe_for_observation();
 		return task;
 	}
 
-	starpu_pthread_mutex_t *sched_mutex;
-	starpu_pthread_cond_t *sched_cond;
-	starpu_worker_get_sched_condition(workerid, &sched_mutex, &sched_cond);
 	/* While stealing, relieve mutex used to synchronize with pushers */
-	STARPU_PTHREAD_MUTEX_UNLOCK_SCHED(sched_mutex);
-
+	_starpu_worker_enter_section_safe_for_observation();
 
 	/* we need to steal someone's job */
 	int victim = ws->select_victim(ws, sched_ctx_id, workerid);
+	_starpu_worker_leave_section_safe_for_observation();
 	if (victim == -1)
 	{
-		STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
 		return NULL;
 	}
 
-	if (STARPU_PTHREAD_MUTEX_TRYLOCK(&ws->per_worker[victim].worker_mutex))
+	if (_starpu_worker_trylock_for_observation(victim))
 	{
 		/* victim is busy, don't bother it, come back later */
-		STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
 		return NULL;
 	}
 	if (ws->per_worker[victim].queue_array != NULL && ws->per_worker[victim].queue_array->ntasks > 0)
@@ -586,34 +582,35 @@ static struct starpu_task *ws_pop_task(unsigned sched_ctx_id)
 		record_worker_locality(ws, task, workerid, sched_ctx_id);
 		locality_popped_task(ws, task, victim, sched_ctx_id);
 	}
-	STARPU_PTHREAD_MUTEX_UNLOCK(&ws->per_worker[victim].worker_mutex);
-
-	/* Done with stealing, resynchronize with core */
-	STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
+	_starpu_worker_unlock_for_observation(victim);
 
 #ifndef STARPU_NON_BLOCKING_DRIVERS
         /* While stealing, perhaps somebody actually gave us a task, don't miss
          * the opportunity to take it before going to sleep. */
 	if (!task)
 	{
-		STARPU_PTHREAD_MUTEX_LOCK(&ws->per_worker[workerid].worker_mutex);
 		task = ws_pick_task(ws, workerid, workerid);
 		if (task)
 			locality_popped_task(ws, task, workerid, sched_ctx_id);
-		STARPU_PTHREAD_MUTEX_UNLOCK(&ws->per_worker[workerid].worker_mutex);
 	}
 #endif
 
+	_starpu_worker_enter_section_safe_for_observation();
 	if (task)
 	{
+		_starpu_sched_ctx_lock_write(sched_ctx_id);
 		unsigned child_sched_ctx = starpu_sched_ctx_worker_is_master_for_child_ctx(workerid, sched_ctx_id);
 		if(child_sched_ctx != STARPU_NMAX_SCHED_CTXS)
 		{
 			starpu_sched_ctx_move_task_to_ctx(task, child_sched_ctx, 1, 1);
 			starpu_sched_ctx_revert_task_counters(sched_ctx_id, task->flops);
+			_starpu_sched_ctx_unlock_write(sched_ctx_id);
+			_starpu_worker_leave_section_safe_for_observation();
 			return NULL;
 		}
+		_starpu_sched_ctx_unlock_write(sched_ctx_id);
 	}
+	_starpu_worker_leave_section_safe_for_observation();
 	ws->per_worker[workerid].busy = !!task;
 	return task;
 }
@@ -639,21 +636,15 @@ int ws_push_task(struct starpu_task *task)
 	if (workerid == -1 || !starpu_sched_ctx_contains_worker(workerid, sched_ctx_id) ||
 			!starpu_worker_can_execute_task_first_impl(workerid, task, NULL))
 		workerid = select_worker(ws, task, sched_ctx_id);
-
-	starpu_pthread_mutex_t *sched_mutex;
-	starpu_pthread_cond_t *sched_cond;
-	starpu_worker_get_sched_condition(workerid, &sched_mutex, &sched_cond);
-	STARPU_PTHREAD_MUTEX_LOCK_SCHED(sched_mutex);
+	_starpu_worker_lock_for_observation_relax(workerid);
 	STARPU_AYU_ADDTOTASKQUEUE(starpu_task_get_job_id(task), workerid);
-	STARPU_PTHREAD_MUTEX_LOCK(&ws->per_worker[workerid].worker_mutex);
 	_STARPU_TASK_BREAK_ON(task, sched);
 	record_data_locality(task, workerid);
 	_starpu_fifo_push_task(ws->per_worker[workerid].queue_array, task);
 	locality_pushed_task(ws, task, workerid, sched_ctx_id);
-	STARPU_PTHREAD_MUTEX_UNLOCK(&ws->per_worker[workerid].worker_mutex);
 
 	starpu_push_task_end(task);
-	STARPU_PTHREAD_MUTEX_UNLOCK_SCHED(sched_mutex);
+	_starpu_worker_unlock_for_observation(workerid);
 	starpu_sched_ctx_list_task_counters_increment(sched_ctx_id, workerid);
 
 #if !defined(STARPU_NON_BLOCKING_DRIVERS) || defined(STARPU_SIMGRID)
@@ -682,7 +673,6 @@ static void ws_add_workers(unsigned sched_ctx_id, int *workerids,unsigned nworke
 		/* Tell helgrind that we are fine with getting outdated values,
 		 * this is just an estimation */
 		STARPU_HG_DISABLE_CHECKING(ws->per_worker[workerid].queue_array->ntasks);
-		STARPU_PTHREAD_MUTEX_INIT(&ws->per_worker[workerid].worker_mutex, NULL);
 		ws->per_worker[workerid].busy = 0;
 		STARPU_HG_DISABLE_CHECKING(ws->per_worker[workerid].busy);
 	}
@@ -704,7 +694,6 @@ static void ws_remove_workers(unsigned sched_ctx_id, int *workerids, unsigned nw
 		}
 		free(ws->per_worker[workerid].proxlist);
 		ws->per_worker[workerid].proxlist = NULL;
-		STARPU_PTHREAD_MUTEX_DESTROY(&ws->per_worker[workerid].worker_mutex);
 	}
 }