Переглянути джерело

Abstract combined worker component into parallel worker component

Samuel Thibault 6 років тому
батько
коміт
7d6e6e0c79

+ 4 - 0
doc/doxygen/chapters/api/modularized_scheduler.doxy

@@ -223,6 +223,10 @@ The actual scheduler
 \ingroup API_Modularized_Scheduler
 	 return the struct starpu_sched_component corresponding to \p workerid. Undefined if \p workerid is not a valid workerid
 
+\fn struct starpu_sched_component *starpu_sched_component_parallel_worker_create(struct starpu_sched_tree *tree, unsigned nworkers, unsigned *workers)
+\ingroup API_Modularized_Scheduler
+	 Create a combined worker that pushes tasks in parallel to workers \p workers (size \p nworkers).
+
 \fn int starpu_sched_component_worker_get_workerid(struct starpu_sched_component *worker_component)
 \ingroup API_Modularized_Scheduler
 	 return the workerid of \p worker_component, undefined if starpu_sched_component_is_worker(worker_component) == 0

+ 1 - 0
include/starpu_sched_component.h

@@ -123,6 +123,7 @@ void starpu_sched_component_connect(struct starpu_sched_component *parent, struc
 
 struct starpu_sched_component *starpu_sched_component_worker_get(unsigned sched_ctx, int workerid);
 struct starpu_sched_component *starpu_sched_component_worker_new(unsigned sched_ctx, int workerid);
+struct starpu_sched_component *starpu_sched_component_parallel_worker_create(struct starpu_sched_tree *tree, unsigned nworkers, unsigned *workers);
 int starpu_sched_component_worker_get_workerid(struct starpu_sched_component *worker_component);
 int starpu_sched_component_is_worker(struct starpu_sched_component *component);
 int starpu_sched_component_is_simple_worker(struct starpu_sched_component *component);

+ 2 - 2
src/sched_policies/component_sched.c

@@ -508,7 +508,7 @@ struct starpu_sched_tree * starpu_sched_tree_get(unsigned sched_ctx_id)
 void starpu_sched_component_add_child(struct starpu_sched_component* component, struct starpu_sched_component * child)
 {
 	STARPU_ASSERT(component && child);
-	STARPU_ASSERT(!starpu_sched_component_is_worker(component));
+	STARPU_ASSERT(!starpu_sched_component_is_simple_worker(component));
 	unsigned i;
 	for(i = 0; i < component->nchildren; i++)
 	{
@@ -524,7 +524,7 @@ void starpu_sched_component_add_child(struct starpu_sched_component* component,
 static void starpu_sched_component_remove_child(struct starpu_sched_component * component, struct starpu_sched_component * child)
 {
 	STARPU_ASSERT(component && child);
-	STARPU_ASSERT(!starpu_sched_component_is_worker(component));
+	STARPU_ASSERT(!starpu_sched_component_is_simple_worker(component));
 	unsigned pos;
 	for(pos = 0; pos < component->nchildren; pos++)
 		if(component->children[pos] == child)

+ 39 - 15
src/sched_policies/component_worker.c

@@ -677,6 +677,11 @@ static int combined_worker_push_task(struct starpu_sched_component * component,
 	return 0;
 }
 
+static struct starpu_task *combined_worker_pull_task(struct starpu_sched_component * from STARPU_ATTRIBUTE_UNUSED, struct starpu_sched_component * to STARPU_ATTRIBUTE_UNUSED)
+{
+	return NULL;
+}
+
 static double combined_worker_estimated_end(struct starpu_sched_component * component)
 {
 	STARPU_ASSERT(starpu_sched_component_is_combined_worker(component));
@@ -705,35 +710,53 @@ static double combined_worker_estimated_load(struct starpu_sched_component * com
 	return load;
 }
 
-static struct starpu_sched_component  * starpu_sched_component_combined_worker_create(struct starpu_sched_tree *tree, int workerid)
+struct starpu_sched_component *starpu_sched_component_parallel_worker_create(struct starpu_sched_tree *tree, unsigned nworkers, unsigned *workers)
 {
-	STARPU_ASSERT(workerid >= 0 && workerid <  STARPU_NMAXWORKERS);
-
-	if(_worker_components[tree->sched_ctx_id][workerid])
-		return _worker_components[tree->sched_ctx_id][workerid];
-
-	struct _starpu_combined_worker * combined_worker = _starpu_get_combined_worker_struct(workerid);
-	if(combined_worker == NULL)
-		return NULL;
 	struct starpu_sched_component * component = starpu_sched_component_create(tree, "combined_worker");
+
 	struct _starpu_worker_component_data *data;
 	_STARPU_MALLOC(data, sizeof(*data));
 	memset(data, 0, sizeof(*data));
-	data->parallel_worker.worker_size = combined_worker->worker_size;
-	memcpy(data->parallel_worker.workerids, combined_worker->combined_workerid, combined_worker->worker_size * sizeof(unsigned));
+	STARPU_ASSERT(nworkers <= STARPU_NMAXWORKERS);
+	STARPU_ASSERT(nworkers <= starpu_worker_get_count());
+	data->parallel_worker.worker_size = nworkers;
+	memcpy(data->parallel_worker.workerids, workers, nworkers * sizeof(unsigned));
 
 	component->data = data;
 	component->push_task = combined_worker_push_task;
-	component->pull_task = NULL;
+	component->pull_task = combined_worker_pull_task;
 	component->estimated_end = combined_worker_estimated_end;
 	component->estimated_load = combined_worker_estimated_load;
 	component->can_pull = combined_worker_can_pull;
 	component->deinit_data = _worker_component_deinit_data;
+	
+	unsigned i;
+	for (i = 0; i < nworkers; i++)
+		starpu_sched_component_connect(component, starpu_sched_component_worker_get(tree->sched_ctx_id, workers[i]));
+
+	return component;
+}
+
+static struct starpu_sched_component  * starpu_sched_component_combined_worker_create(struct starpu_sched_tree *tree, int workerid)
+{
+	STARPU_ASSERT(workerid >= 0 && workerid <  STARPU_NMAXWORKERS);
+
+	if(_worker_components[tree->sched_ctx_id][workerid])
+		return _worker_components[tree->sched_ctx_id][workerid];
+
+	struct _starpu_combined_worker * combined_worker = _starpu_get_combined_worker_struct(workerid);
+	if(combined_worker == NULL)
+		return NULL;
+
+	struct starpu_sched_component *component = starpu_sched_component_parallel_worker_create(tree, combined_worker->worker_size, (unsigned *) combined_worker->combined_workerid);
+
 	starpu_bitmap_set(component->workers, workerid);
 	starpu_bitmap_or(component->workers_in_ctx, component->workers);
+
 	_worker_components[tree->sched_ctx_id][workerid] = component;
 
 #ifdef STARPU_HAVE_HWLOC
+	struct _starpu_worker_component_data * data = component->data;
 	struct _starpu_machine_config *config = _starpu_get_machine_config();
 	struct _starpu_machine_topology *topology = &config->topology;
 	hwloc_obj_t obj = hwloc_get_obj_by_depth(topology->hwtopology, config->cpu_depth, data->parallel_worker.workerids[0]);
@@ -744,7 +767,6 @@ static struct starpu_sched_component  * starpu_sched_component_combined_worker_c
 }
 
 
-
 /******************************************************************************
  *			Worker Components' Public Helper Functions (Part 2)			      *
  *****************************************************************************/
@@ -824,7 +846,8 @@ struct starpu_sched_component * starpu_sched_component_worker_get(unsigned sched
 {
 	STARPU_ASSERT(workerid >= 0 && workerid < STARPU_NMAXWORKERS);
 	/* we may need to take a mutex here */
-	STARPU_ASSERT(_worker_components[sched_ctx][workerid]);
+	if (!_worker_components[sched_ctx][workerid])
+		return starpu_sched_component_worker_new(sched_ctx, workerid);
 	return _worker_components[sched_ctx][workerid];
 }
 
@@ -832,7 +855,8 @@ struct starpu_sched_component * starpu_sched_component_worker_new(unsigned sched
 {
 	STARPU_ASSERT(workerid >= 0 && workerid < STARPU_NMAXWORKERS);
 	/* we may need to take a mutex here */
-	STARPU_ASSERT(!_worker_components[sched_ctx][workerid]);
+	if (_worker_components[sched_ctx][workerid])
+		return _worker_components[sched_ctx][workerid];
 	struct starpu_sched_component * component;
 	if(workerid < (int) starpu_worker_get_count())
 		component = starpu_sched_component_worker_create(starpu_sched_tree_get(sched_ctx), workerid);