소스 검색

Add support for limiting gpu memory in simgrid

Samuel Thibault 12 년 전
부모
커밋
8e4534513a
1개의 변경된 파일: 20줄 추가, 19줄 삭제
  1. 20 19
      src/drivers/cuda/driver_cuda.c

+ 20 - 19
src/drivers/cuda/driver_cuda.c

@@ -38,6 +38,8 @@
 /* the number of CUDA devices */
 /* the number of CUDA devices */
 static int ncudagpus;
 static int ncudagpus;
 
 
+static size_t global_mem[STARPU_NMAXWORKERS];
+
 #ifdef STARPU_USE_CUDA
 #ifdef STARPU_USE_CUDA
 static cudaStream_t streams[STARPU_NMAXWORKERS];
 static cudaStream_t streams[STARPU_NMAXWORKERS];
 static cudaStream_t out_transfer_streams[STARPU_NMAXWORKERS];
 static cudaStream_t out_transfer_streams[STARPU_NMAXWORKERS];
@@ -64,16 +66,17 @@ _starpu_cuda_discover_devices (struct _starpu_machine_config *config)
 #endif
 #endif
 }
 }
 
 
-#ifdef STARPU_USE_CUDA
 /* In case we want to cap the amount of memory available on the GPUs by the
 /* In case we want to cap the amount of memory available on the GPUs by the
  * mean of the STARPU_LIMIT_CUDA_MEM, we decrease the value of
  * mean of the STARPU_LIMIT_CUDA_MEM, we decrease the value of
- * props[devid].totalGlobalMem which is the value returned by
+ * global_mem[devid] which is the value returned by
  * _starpu_cuda_get_global_mem_size() to indicate how much memory can
  * _starpu_cuda_get_global_mem_size() to indicate how much memory can
  * be allocated on the device
  * be allocated on the device
  */
  */
 static void _starpu_cuda_limit_gpu_mem_if_needed(unsigned devid)
 static void _starpu_cuda_limit_gpu_mem_if_needed(unsigned devid)
 {
 {
-	int limit;
+	ssize_t limit;
+	size_t STARPU_ATTRIBUTE_UNUSED totalGlobalMem = 0;
+	size_t STARPU_ATTRIBUTE_UNUSED to_waste = 0;
 	char name[30];
 	char name[30];
 
 
 	limit = starpu_get_env_number("STARPU_LIMIT_CUDA_MEM");
 	limit = starpu_get_env_number("STARPU_LIMIT_CUDA_MEM");
@@ -87,19 +90,24 @@ static void _starpu_cuda_limit_gpu_mem_if_needed(unsigned devid)
 		return;
 		return;
 	}
 	}
 
 
+	global_mem[devid] = limit * 1024*1024;
+
+#ifdef STARPU_USE_CUDA
 	/* Find the size of the memory on the device */
 	/* Find the size of the memory on the device */
-	size_t totalGlobalMem = props[devid].totalGlobalMem;
+	totalGlobalMem = props[devid].totalGlobalMem;
 
 
 	/* How much memory to waste ? */
 	/* How much memory to waste ? */
-	size_t to_waste = totalGlobalMem - (size_t)limit*1024*1024;
+	to_waste = totalGlobalMem - global_mem[devid];
 
 
 	props[devid].totalGlobalMem -= to_waste;
 	props[devid].totalGlobalMem -= to_waste;
+#endif
 
 
 	_STARPU_DEBUG("CUDA device %u: Wasting %ld MB / Limit %ld MB / Total %ld MB / Remains %ld MB\n",
 	_STARPU_DEBUG("CUDA device %u: Wasting %ld MB / Limit %ld MB / Total %ld MB / Remains %ld MB\n",
-			devid, (size_t)to_waste/(1024*1024), (size_t)limit, (size_t)totalGlobalMem/(1024*1024),
-			(size_t)(totalGlobalMem - to_waste)/(1024*1024));
+			devid, (long) to_waste/(1024*1024), (long) limit, (long) totalGlobalMem/(1024*1024),
+			(long) (totalGlobalMem - to_waste)/(1024*1024));
 }
 }
 
 
+#ifdef STARPU_USE_CUDA
 cudaStream_t starpu_cuda_get_local_in_transfer_stream(void)
 cudaStream_t starpu_cuda_get_local_in_transfer_stream(void)
 {
 {
 	int worker = starpu_worker_get_id();
 	int worker = starpu_worker_get_id();
@@ -271,11 +279,7 @@ static void deinit_context(int workerid)
 
 
 static size_t _starpu_cuda_get_global_mem_size(unsigned devid)
 static size_t _starpu_cuda_get_global_mem_size(unsigned devid)
 {
 {
-#ifdef STARPU_USE_CUDA
-	return (size_t)props[devid].totalGlobalMem;
-#else
-	return 0;
-#endif
+	return global_mem[devid];
 }
 }
 
 
 
 
@@ -391,9 +395,7 @@ int _starpu_cuda_driver_init(struct starpu_driver *d)
 	init_context(devid);
 	init_context(devid);
 #endif
 #endif
 
 
-#ifdef STARPU_USE_CUDA
 	_starpu_cuda_limit_gpu_mem_if_needed(devid);
 	_starpu_cuda_limit_gpu_mem_if_needed(devid);
-#endif
 	_starpu_memory_manager_set_global_memory_size(args->memory_node, _starpu_cuda_get_global_mem_size(devid));
 	_starpu_memory_manager_set_global_memory_size(args->memory_node, _starpu_cuda_get_global_mem_size(devid));
 
 
 	/* one more time to avoid hacks from third party lib :) */
 	/* one more time to avoid hacks from third party lib :) */
@@ -401,17 +403,17 @@ int _starpu_cuda_driver_init(struct starpu_driver *d)
 
 
 	args->status = STATUS_UNKNOWN;
 	args->status = STATUS_UNKNOWN;
 
 
+	float size = (float) global_mem[devid] / (1<<30);
 #ifdef STARPU_SIMGRID
 #ifdef STARPU_SIMGRID
 	const char *devname = "Simgrid";
 	const char *devname = "Simgrid";
-	snprintf(args->name, sizeof(args->name), "CUDA %u (%s TODO GiB)", devid, devname);
 #else
 #else
 	/* get the device's name */
 	/* get the device's name */
 	char devname[128];
 	char devname[128];
 	strncpy(devname, props[devid].name, 128);
 	strncpy(devname, props[devid].name, 128);
-	float size = (float) props[devid].totalGlobalMem / (1<<30);
+#endif
 
 
-#ifdef STARPU_HAVE_BUSID
-#ifdef STARPU_HAVE_DOMAINID
+#if defined(STARPU_HAVE_BUSID) && !defined(STARPU_SIMGRID)
+#if defined(STARPU_HAVE_DOMAINID) && !defined(STARPU_SIMGRID)
 	if (props[devid].pciDomainID)
 	if (props[devid].pciDomainID)
 		snprintf(args->name, sizeof(args->name), "CUDA %u (%s %.1f GiB %04x:%02x:%02x.0)", devid, devname, size, props[devid].pciDomainID, props[devid].pciBusID, props[devid].pciDeviceID);
 		snprintf(args->name, sizeof(args->name), "CUDA %u (%s %.1f GiB %04x:%02x:%02x.0)", devid, devname, size, props[devid].pciDomainID, props[devid].pciBusID, props[devid].pciDeviceID);
 	else
 	else
@@ -420,7 +422,6 @@ int _starpu_cuda_driver_init(struct starpu_driver *d)
 #else
 #else
 	snprintf(args->name, sizeof(args->name), "CUDA %u (%s %.1f GiB)", devid, devname, size);
 	snprintf(args->name, sizeof(args->name), "CUDA %u (%s %.1f GiB)", devid, devname, size);
 #endif
 #endif
-#endif
 	snprintf(args->short_name, sizeof(args->short_name), "CUDA %u", devid);
 	snprintf(args->short_name, sizeof(args->short_name), "CUDA %u", devid);
 	_STARPU_DEBUG("cuda (%s) dev id %u thread is ready to run on CPU %d !\n", devname, devid, args->bindid);
 	_STARPU_DEBUG("cuda (%s) dev id %u thread is ready to run on CPU %d !\n", devname, devid, args->bindid);