Parcourir la source

starpupy: put aside get_active_backend

HE Kun il y a 4 ans
Parent
commit
944c6a2f69
2 fichiers modifiés avec 33 ajouts et 25 suppressions
  1. 11 11
      starpupy/examples/starpu_py_parallel.py
  2. 22 14
      starpupy/src/joblib.py

+ 11 - 11
starpupy/examples/starpu_py_parallel.py

@@ -102,20 +102,20 @@ def log10_arr(t):
 ########################################################
 
 #################scikit test###################
-DEFAULT_JOBLIB_BACKEND = starpu.joblib.get_active_backend()[0].__class__
-class MyBackend(DEFAULT_JOBLIB_BACKEND):  # type: ignore
-        def __init__(self, *args, **kwargs):
-                self.count = 0
-                super().__init__(*args, **kwargs)
+# DEFAULT_JOBLIB_BACKEND = starpu.joblib.get_active_backend()[0].__class__
+# class MyBackend(DEFAULT_JOBLIB_BACKEND):  # type: ignore
+#         def __init__(self, *args, **kwargs):
+#                 self.count = 0
+#                 super().__init__(*args, **kwargs)
 
-        def start_call(self):
-                self.count += 1
-                return super().start_call()
+#         def start_call(self):
+#                 self.count += 1
+#                 return super().start_call()
 
-starpu.joblib.register_parallel_backend('testing', MyBackend)
+# starpu.joblib.register_parallel_backend('testing', MyBackend)
 
-with starpu.joblib.parallel_backend("testing") as (ba, n_jobs):
-	print("backend and n_jobs is", ba, n_jobs)
+# with starpu.joblib.parallel_backend("testing") as (ba, n_jobs):
+# 	print("backend and n_jobs is", ba, n_jobs)
 ###############################################
 
 N=100

+ 22 - 14
starpupy/src/joblib.py

@@ -18,6 +18,7 @@ import sys
 import types
 import joblib as jl
 from joblib import logger
+from joblib._parallel_backends import ParallelBackendBase
 from starpu import starpupy
 import starpu
 import asyncio
@@ -27,7 +28,9 @@ import numpy as np
 import inspect
 import threading
 
-BACKENDS={}
+BACKENDS={
+	#'loky': LokyBackend,
+}
 _backend = threading.local()
 
 # get the number of CPUs controlled by StarPU
@@ -178,19 +181,19 @@ class Parallel(object):
 	         n_jobs=None, backend=None, verbose=0, timeout=None, pre_dispatch='2 * n_jobs',\
 	         batch_size='auto', temp_folder=None, max_nbytes='1M',\
 	         mmap_mode='r', prefer=None, require=None):
-		active_backend, context_n_jobs = get_active_backend(prefer=prefer, require=require, verbose=verbose)
-		nesting_level = active_backend.nesting_level
+		#active_backend= get_active_backend()
+		# nesting_level = active_backend.nesting_level
 
-		if backend is None:
-			backend = active_backend
+		# if backend is None:
+		# 	backend = active_backend
 
-		else:
-			try:
-				backend_factory = BACKENDS[backend]
-			except KeyError as e:
-				raise ValueError("Invalid backend: %s, expected one of %r"
-                                 % (backend, sorted(BACKENDS.keys()))) from e
-			backend = backend_factory(nesting_level=nesting_level)
+		# else:
+		# 	try:
+		# 		backend_factory = BACKENDS[backend]
+		# 	except KeyError as e:
+		# 		raise ValueError("Invalid backend: %s, expected one of %r"
+  #                                % (backend, sorted(BACKENDS.keys()))) from e
+		# 	backend = backend_factory(nesting_level=nesting_level)
 
 		if n_jobs is None:
 			n_jobs = 1
@@ -275,8 +278,13 @@ def register_compressor(compressor_name, compressor, force=False):
 def effective_n_jobs(n_jobs=-1):
 	return cpu_count()
 
-def get_active_backend(prefer=None, require=None, verbose=0):
-	return jl.parallel.get_active_backend(prefer, require, verbose)
+def get_active_backend():
+	backend_and_jobs = getattr(_backend, 'backend_and_jobs', None)
+	if backend_and_jobs is not None:
+		backend,n_jobs=backend_and_jobs
+		return backend
+	backend = BACKENDS[loky](nesting_level=0)
+	return backend
 
 class parallel_backend(object):
 	def __init__(self, backend, n_jobs=-1, inner_max_num_threads=None,