Преглед на файлове

reorganize CUDA initialization to make it more orthogonal

Samuel Thibault преди 10 години
родител
ревизия
4224463759
променен е 1 файл, в който са добавени 38 реда и са изтрити 31 реда
  1. 38 31
      src/drivers/cuda/driver_cuda.c

+ 38 - 31
src/drivers/cuda/driver_cuda.c

@@ -233,23 +233,12 @@ done:
 #endif
 }
 
-static void init_context(struct _starpu_worker_set *worker_set, unsigned devid STARPU_ATTRIBUTE_UNUSED)
+#ifndef STARPU_SIMGRID
+static void init_device_context(unsigned devid)
 {
 	int workerid;
-	unsigned i, j;
+	unsigned i;
 
-#ifdef STARPU_SIMGRID
-	for (i = 0; i < worker_set->nworkers; i++)
-	{
-		workerid = worker_set->workers[i].workerid;
-		for (j = 0; j < STARPU_MAX_PIPELINE; j++)
-		{
-			task_finished[workerid][j] = 0;
-			STARPU_PTHREAD_MUTEX_INIT(&task_mutex[workerid][j], NULL);
-			STARPU_PTHREAD_COND_INIT(&task_cond[workerid][j], NULL);
-		}
-	}
-#else /* !SIMGRID */
 	cudaError_t cures;
 
 	/* TODO: cudaSetDeviceFlag(cudaDeviceMapHost) */
@@ -301,20 +290,6 @@ static void init_context(struct _starpu_worker_set *worker_set, unsigned devid S
 	}
 #endif
 
-	for (i = 0; i < worker_set->nworkers; i++)
-	{
-		workerid = worker_set->workers[i].workerid;
-
-		for (j = 0; j < STARPU_MAX_PIPELINE; j++)
-			cures = cudaEventCreateWithFlags(&task_events[workerid][j], cudaEventDisableTiming);
-		if (STARPU_UNLIKELY(cures))
-			STARPU_CUDA_REPORT_ERROR(cures);
-
-		cures = cudaStreamCreate(&streams[workerid]);
-		if (STARPU_UNLIKELY(cures))
-			STARPU_CUDA_REPORT_ERROR(cures);
-	}
-
 	cures = cudaStreamCreate(&in_transfer_streams[devid]);
 	if (STARPU_UNLIKELY(cures))
 		STARPU_CUDA_REPORT_ERROR(cures);
@@ -332,7 +307,34 @@ static void init_context(struct _starpu_worker_set *worker_set, unsigned devid S
 		if (STARPU_UNLIKELY(cures))
 			STARPU_CUDA_REPORT_ERROR(cures);
 	}
-#endif /* !SIMGRID */
+}
+#endif /* !STARPU_SIMGRID */
+
+static void init_worker_context(unsigned workerid)
+{
+	int j;
+#ifdef STARPU_SIMGRID
+	for (j = 0; j < STARPU_MAX_PIPELINE; j++)
+	{
+		task_finished[workerid][j] = 0;
+		STARPU_PTHREAD_MUTEX_INIT(&task_mutex[workerid][j], NULL);
+		STARPU_PTHREAD_COND_INIT(&task_cond[workerid][j], NULL);
+	}
+#else /* !STARPU_SIMGRID */
+	cudaError_t cures;
+
+	for (j = 0; j < STARPU_MAX_PIPELINE; j++)
+	{
+		cures = cudaEventCreateWithFlags(&task_events[workerid][j], cudaEventDisableTiming);
+		if (STARPU_UNLIKELY(cures))
+			STARPU_CUDA_REPORT_ERROR(cures);
+	}
+
+	cures = cudaStreamCreate(&streams[workerid]);
+	if (STARPU_UNLIKELY(cures))
+		STARPU_CUDA_REPORT_ERROR(cures);
+
+#endif /* !STARPU_SIMGRID */
 }
 
 static void deinit_context(struct _starpu_worker_set *worker_set)
@@ -579,7 +581,9 @@ int _starpu_cuda_driver_init(struct _starpu_worker_set *worker_set)
 	}
 #endif
 
-	init_context(worker_set, devid);
+#ifndef STARPU_SIMGRID
+	init_device_context(devid);
+#endif
 
 #ifdef STARPU_SIMGRID
 	STARPU_ASSERT_MSG (worker_set->nworkers = 1, "Simgrid mode does not support concurrent kernel execution yet\n");
@@ -608,6 +612,7 @@ int _starpu_cuda_driver_init(struct _starpu_worker_set *worker_set)
 	for (i = 0; i < worker_set->nworkers; i++)
 	{
 		struct _starpu_worker *worker = &worker_set->workers[i];
+		unsigned workerid = worker_set->workers[i].workerid;
 #if defined(STARPU_HAVE_BUSID) && !defined(STARPU_SIMGRID)
 #if defined(STARPU_HAVE_DOMAINID) && !defined(STARPU_SIMGRID)
 		if (props[devid].pciDomainID)
@@ -646,7 +651,9 @@ int _starpu_cuda_driver_init(struct _starpu_worker_set *worker_set)
 			worker->pipeline_length = 0;
 		}
 #endif
-		_STARPU_TRACE_WORKER_INIT_END(worker_set->workers[i].workerid);
+		init_worker_context(workerid);
+
+		_STARPU_TRACE_WORKER_INIT_END(workerid);
 	}
 
 	/* tell the main thread that this one is ready */