浏览代码

starpupy: return common Future in parallel Future version

HE Kun 4 年之前
父节点
当前提交
d8f092c9c3
共有 2 个文件被更改,包括 6 次插入11 次删除
  1. 2 1
      starpupy/src/starpu/joblib.py
  2. 4 10
      starpupy/tests/starpu_py_parallel.py

+ 2 - 1
starpupy/src/starpu/joblib.py

@@ -85,7 +85,8 @@ def parallel(*, mode, n_jobs=None):
 	elif mode=="future":
 		def parallel_future(g):
 			L_fut=future_generator(g, n_jobs)
-			return L_fut
+			fut=asyncio.gather(*L_fut)
+			return fut
 		return parallel_future
 
 def delayed(f):

+ 4 - 10
starpupy/tests/starpu_py_parallel.py

@@ -74,18 +74,12 @@ print("parallel Future version:")
 print("************************")
 async def main():
 	print("--input is iterable argument list")
-	L_fut1=starpu.joblib.parallel(mode="future", n_jobs=-3)(starpu.joblib.delayed(sqrt)(i**2)for i in range(10))
-	res1=[]
-	for i in range(len(L_fut1)):
-		L_res1=await L_fut1[i]
-		res1.extend(L_res1)
+	fut1=starpu.joblib.parallel(mode="future", n_jobs=-3)(starpu.joblib.delayed(sqrt)(i**2)for i in range(10))
+	res1=await fut1
 	print(res1)
 
 	print("--input is iterable function list")
-	L_fut2=starpu.joblib.parallel(mode="future", n_jobs=2)(g_func)
-	res2=[]
-	for i in range(len(L_fut2)):
-		L_res2=await L_fut2[i]
-		res2.extend(L_res2)
+	fut2=starpu.joblib.parallel(mode="future", n_jobs=2)(g_func)
+	res2=await fut2
 	print(res2)
 asyncio.run(main())