# StarPU --- Runtime system for heterogeneous multicore architectures.
#
# Copyright (C) 2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
#
# StarPU is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation; either version 2.1 of the License, or (at
# your option) any later version.
#
# StarPU is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
# See the GNU Lesser General Public License in COPYING.LGPL for more details.
#
import sys
try:
    sys.path.remove('/usr/local/lib/python3.8/site-packages/starpu')
except ValueError:
    pass
import types
import joblib as jl
from joblib import logger
from starpu import starpupy
import starpu
import asyncio
import math
import functools
import numpy as np
import inspect
import threading

BACKENDS = {}
_backend = threading.local()

# get the number of CPUs controlled by StarPU
def cpu_count():
    n_cpus = starpupy.cpu_worker_get_count()
    return n_cpus

# split the list ls into n_block sub-lists
def partition(ls, n_block):
    if len(ls) >= n_block:
        # n1 sub-lists contain q1 elements each, and (n_block-n1) sub-lists
        # contain q2 elements each (n1 may be 0)
        q1 = math.ceil(len(ls)/n_block)
        q2 = math.floor(len(ls)/n_block)
        n1 = len(ls) % n_block
        # generate the n1 larger sub-lists in L1, then the remaining
        # (n_block-n1) sub-lists in L2
        L1 = [ls[i:i+q1] for i in range(0, n1*q1, q1)]
        L2 = [ls[i:i+q2] for i in range(n1*q1, len(ls), q2)]
        L = L1 + L2
    else:
        # if the block number is larger than the length of the list,
        # each element of the list becomes its own sub-list
        L = [ls[i:i+1] for i in range(len(ls))]
    return L
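
# A quick illustration of partition (a hedged example, values checked by
# hand): with seven elements split into three blocks, the remainder is
# absorbed by the first block; with more blocks than elements, each element
# becomes its own sub-list.
#   partition(list(range(7)), 3)  ->  [[0, 1, 2], [3, 4], [5, 6]]
#   partition([1, 2], 4)          ->  [[1], [2]]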

def future_generator(iterable, n_jobs, dict_task):
    # iterable is produced by the delayed function: either a single tuple
    # (function, args, kwargs), or a generator of such tuples
    # compute the number of blocks
    if n_jobs < -cpu_count()-1 or n_jobs > cpu_count():
        raise SystemExit('Error: n_jobs is out of range')
    elif n_jobs < 0:
        n_block = cpu_count()+1+n_jobs
    else:
        n_block = n_jobs
    # if the arguments are passed as a single (function, args, kwargs) tuple
    if type(iterable) is tuple:
        # the function is always the first element
        f = iterable[0]
        # get the names of the formal arguments of f
        formal_args = inspect.getfullargspec(f).args
        # the positional arguments are in iterable[1]
        args = list(iterable[1])
        # the keyword arguments are in iterable[2]; append them following the
        # order of the formal argument list
        for i in range(len(formal_args)):
            for j in iterable[2].keys():
                if j == formal_args[i]:
                    args.append(iterable[2][j])
        # lengths of the array arguments, to check they all have the same size
        l_arr = []
        # list of Future results
        L_fut = []
        # split each array argument into n_block pieces
        args_split = []
        for i in range(len(args)):
            args_split.append([])
            # if the argument is a numpy array
            if type(args[i]) is np.ndarray:
                # split the numpy array
                args_split[i] = np.array_split(args[i], n_block)
                # record the length of the numpy array
                l_arr.append(args[i].size)
            # if the argument is a generator
            elif isinstance(args[i], types.GeneratorType):
                # split the generator
                args_split[i] = partition(list(args[i]), n_block)
                # record the total length of the generator
                l_arr.append(sum(len(args_split[i][j]) for j in range(len(args_split[i]))))
        if len(set(l_arr)) > 1:
            raise SystemExit('Error: all arrays should have the same size')
  103. #print("args list is", args_split)
  104. for i in range(n_block):
  105. # generate the argument list
  106. L_args=[]
  107. for j in range(len(args)):
  108. if type(args[j]) is np.ndarray or isinstance(args[j],types.GeneratorType):
  109. L_args.append(args_split[j][i])
  110. else:
  111. L_args.append(args[j])
  112. #print("L_args is", L_args)
  113. fut=starpu.task_submit(name=dict_task['name'], synchronous=dict_task['synchronous'], priority=dict_task['priority'],\
  114. color=dict_task['color'], flops=dict_task['flops'], perfmodel=dict_task['perfmodel'])\
  115. (f, *L_args)
  116. L_fut.append(fut)
  117. return L_fut
    # if iterable is a generator or a list of (function, args, kwargs) tuples
    else:
        L = list(iterable)
        # helper executed inside a StarPU task: call each function on its
        # arguments and collect the results
        def lf(ls):
            L_func = []
            for i in range(len(ls)):
                # the first element is the function
                f = ls[i][0]
                # the second element is the tuple of positional arguments
                L_args = list(ls[i][1])
                # the third element, when present, holds the keyword arguments
                kwargs = ls[i][2] if len(ls[i]) > 2 else {}
                # apply the function and store its result
                L_func.append(f(*L_args, **kwargs))
            return L_func
        # split the list of calls into n_block sub-lists
        L_split = partition(L, n_block)
        # submit one StarPU task per sub-list
        L_fut = []
        for i in range(len(L_split)):
            fut = starpu.task_submit(name=dict_task['name'], synchronous=dict_task['synchronous'],
                                     priority=dict_task['priority'], color=dict_task['color'],
                                     flops=dict_task['flops'],
                                     perfmodel=dict_task['perfmodel'])(lf, L_split[i])
            L_fut.append(fut)
        return L_fut
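
# future_generator therefore accepts either shape produced by delayed()
# below (a hedged illustration; f is a placeholder user function):
#   delayed(f)(np.arange(8), 2)        -> (f, (array, 2), {})       tuple branch
#   (delayed(f)(i) for i in range(8))  -> generator of such tuples  else branch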

class Parallel(object):
    def __init__(self, mode="normal", perfmodel=None, end_msg=None,
                 name=None, synchronous=0, priority=0, color=None, flops=None,
                 n_jobs=None, backend=None, verbose=0, timeout=None, pre_dispatch='2 * n_jobs',
                 batch_size='auto', temp_folder=None, max_nbytes='1M',
                 mmap_mode='r', prefer=None, require=None):
        active_backend, context_n_jobs = get_active_backend(prefer=prefer, require=require, verbose=verbose)
        nesting_level = active_backend.nesting_level

        if backend is None:
            backend = active_backend
        else:
            try:
                backend_factory = BACKENDS[backend]
            except KeyError as e:
                raise ValueError("Invalid backend: %s, expected one of %r"
                                 % (backend, sorted(BACKENDS.keys()))) from e
            backend = backend_factory(nesting_level=nesting_level)

        if n_jobs is None:
            n_jobs = 1

        self.mode = mode
        self.perfmodel = perfmodel
        self.end_msg = end_msg
        self.name = name
        self.synchronous = synchronous
        self.priority = priority
        self.color = color
        self.flops = flops
        self.n_jobs = n_jobs
        self._backend = backend
    def print_progress(self):
        # show the number of tasks submitted to StarPU and not yet completed
        print("", starpupy.task_nsubmitted())
    def __call__(self, iterable):
        # gather the task_submit parameters into a dictionary
        dict_task = {'name': self.name, 'synchronous': self.synchronous, 'priority': self.priority,
                     'color': self.color, 'flops': self.flops, 'perfmodel': self.perfmodel}
        if hasattr(self._backend, 'start_call'):
            self._backend.start_call()
        # in "normal" mode the call blocks until all results are available,
        # so the caller does not need to deal with asyncio directly
        if self.mode == "normal":
            async def asy_main():
                L_fut = future_generator(iterable, self.n_jobs, dict_task)
                res = []
                for i in range(len(L_fut)):
                    L_res = await L_fut[i]
                    res.extend(L_res)
                return res
            loop = asyncio.get_event_loop()
            results = loop.run_until_complete(asy_main())
            retVal = results
        # in "future" mode the caller gets an asyncio Future and must await
        # it from a coroutine
        elif self.mode == "future":
            L_fut = future_generator(iterable, self.n_jobs, dict_task)
            fut = asyncio.gather(*L_fut)
            if self.end_msg is not None:
                fut.add_done_callback(functools.partial(print, self.end_msg))
            retVal = fut
        else:
            raise SystemExit('Error: mode should be either "normal" or "future"')
        if hasattr(self._backend, 'stop_call'):
            self._backend.stop_call()
        return retVal

def delayed(function):
    # capture the function and its arguments without calling it; the actual
    # submission is done later by Parallel
    def delayed_function(*args, **kwargs):
        return function, args, kwargs
    return delayed_function
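
# A hedged end-to-end sketch (not part of the module API); scal is a
# placeholder user function, and a StarPU runtime is assumed to be running:
#
#   def scal(a, t):
#       return a * t
#
#   # "normal" mode blocks and returns the list of results directly
#   res = Parallel(mode="normal", n_jobs=2)(delayed(scal)(2, np.arange(8)))
#
#   # "future" mode returns an asyncio future to be awaited by the caller
#   async def main():
#       fut = Parallel(mode="future", n_jobs=2, end_msg="done")(
#             delayed(scal)(2, np.arange(8)))
#       return await fut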

######################################################################

__version__ = jl.__version__

class Memory(jl.Memory):
    def __init__(self, location=None, backend='local', cachedir=None,
                 mmap_mode=None, compress=False, verbose=1, bytes_limit=None,
                 backend_options=None):
        # thin wrapper around joblib.Memory, forwarding all arguments
        super(Memory, self).__init__(location=location, backend=backend, cachedir=cachedir,
                                     mmap_mode=mmap_mode, compress=compress, verbose=verbose,
                                     bytes_limit=bytes_limit, backend_options=backend_options)

def dump(value, filename, compress=0, protocol=None, cache_size=None):
    return jl.dump(value, filename, compress, protocol, cache_size)

def load(filename, mmap_mode=None):
    return jl.load(filename, mmap_mode)

def hash(obj, hash_name='md5', coerce_mmap=False):
    return jl.hash(obj, hash_name, coerce_mmap)

def register_compressor(compressor_name, compressor, force=False):
    return jl.register_compressor(compressor_name, compressor, force)
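
# These wrappers delegate directly to their joblib counterparts, e.g.
# (a hedged sketch; the path is arbitrary):
#   dump(np.arange(10), '/tmp/data.pkl')
#   arr = load('/tmp/data.pkl')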

def effective_n_jobs(n_jobs=-1):
    return cpu_count()

def get_active_backend(prefer=None, require=None, verbose=0):
    return jl.parallel.get_active_backend(prefer, require, verbose)

class parallel_backend(object):
    def __init__(self, backend, n_jobs=-1, inner_max_num_threads=None,
                 **backend_params):
        if isinstance(backend, str):
            backend = BACKENDS[backend](**backend_params)

        current_backend_and_jobs = getattr(_backend, 'backend_and_jobs', None)
        if backend.nesting_level is None:
            if current_backend_and_jobs is None:
                nesting_level = 0
            else:
                nesting_level = current_backend_and_jobs[0].nesting_level
            backend.nesting_level = nesting_level

        # save the previous backend and set the new active one
        self.old_backend_and_jobs = current_backend_and_jobs
        self.new_backend_and_jobs = (backend, n_jobs)
        _backend.backend_and_jobs = (backend, n_jobs)

    def __enter__(self):
        return self.new_backend_and_jobs

    def __exit__(self, exc_type, exc_value, traceback):
        self.unregister()

    def unregister(self):
        if self.old_backend_and_jobs is None:
            if getattr(_backend, 'backend_and_jobs', None) is not None:
                del _backend.backend_and_jobs
        else:
            _backend.backend_and_jobs = self.old_backend_and_jobs

def register_parallel_backend(name, factory):
    BACKENDS[name] = factory
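
# A hedged usage sketch (not part of the module API): register a backend
# factory under a name, then activate it within a scope. MyBackend is a
# placeholder class; it only needs the nesting_level attribute used above.
#
#   class MyBackend(object):
#       def __init__(self, nesting_level=None):
#           self.nesting_level = nesting_level
#
#   register_parallel_backend('my_backend', MyBackend)
#   with parallel_backend('my_backend', n_jobs=2) as (backend, n_jobs):
#       # backend is the MyBackend instance, n_jobs == 2
#       ...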