Coverage for muutils\parallel.py: 93%
86 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-05 19:24 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-05 19:24 -0700
1import multiprocessing
2import functools
3from typing import (
4 Any,
5 Callable,
6 Iterable,
7 Literal,
8 Optional,
9 TypeVar,
10 Dict,
11 List,
12 Union,
13 Protocol,
14)
16# for no tqdm fallback
17from muutils.spinner import SpinnerContext
18from muutils.validate_type import get_fn_allowed_kwargs
# type variables: `InputType` for the elements of the iterable being mapped over,
# `OutputType` for the values the mapped function produces
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
class ProgressBarFunction(Protocol):
    """Structural type for a progress bar callable.

    A conforming object is called with an iterable plus arbitrary keyword
    arguments and returns an iterable.
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        ...
# progress bar function used when the caller does not choose one;
# only declared here, the value is assigned later in the module
DEFAULT_PBAR_FN: Callable

# string (or `None`) shorthands accepted in place of a progress bar callable
ProgressBarOption = Literal["tqdm", "spinner", "none", None]
# fallback to spinner option
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """Exhaust `x` into a list while a `SpinnerContext` is active.

    # Parameters:
     - `x : Iterable`
        iterable that is fully consumed inside the spinner context
     - `**kwargs`
        only the keys reported by `get_fn_allowed_kwargs(SpinnerContext.__init__)`
        are forwarded to `SpinnerContext`. if no `message` is forwarded, the
        value of `desc` is used as the message; failing that, a generic
        message is built from `total`

    # Returns:
     - `List`
        the items of `x`, in iteration order
    """
    # look the accepted keyword names up once rather than once per key
    allowed_keys = frozenset(get_fn_allowed_kwargs(SpinnerContext.__init__))
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in allowed_keys}

    # derive a spinner message from progress-bar style keywords when none was given
    if "message" not in mapped_kwargs:
        if "desc" in kwargs:
            mapped_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            mapped_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**mapped_kwargs):
        output = list(x)

    return output
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """Pass-through used when no progress bar is wanted.

    Keyword arguments are accepted and ignored; `x` itself is returned.
    """
    return x
# set the default progress bar function
try:
    # use tqdm if it's available
    import tqdm  # type: ignore[import-untyped]

    # NOTE: `functools.wraps` copies `tqdm.tqdm`'s metadata (including its docstring)
    # onto this wrapper, so the wrapper is documented with comments instead.
    # it drops any keyword argument that `get_fn_allowed_kwargs(tqdm.tqdm)` does not
    # report, and accepts the spinner-style `message` keyword as tqdm's `desc`
    @functools.wraps(tqdm.tqdm)
    def tqdm_wrap(x: Iterable, **kwargs) -> Iterable:
        # look the accepted keyword names up once rather than once per key
        allowed_keys = frozenset(get_fn_allowed_kwargs(tqdm.tqdm))
        mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in allowed_keys}
        # an explicit `desc` wins; otherwise fall back to the caller's `message`
        if "message" in kwargs and "desc" not in mapped_kwargs:
            mapped_kwargs["desc"] = kwargs["message"]
        return tqdm.tqdm(x, **mapped_kwargs)

    DEFAULT_PBAR_FN = tqdm_wrap

except ImportError:
    # tqdm is not installed: use the spinner as fallback
    DEFAULT_PBAR_FN = spinner_fn_wrap
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> ProgressBarFunction:
    """Resolve `pbar` into a progress bar callable with its keyword arguments bound.

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        - `None` or `"none"`: no progress bar (`no_progress_fn_wrap`)
        - `"tqdm"`: `tqdm.tqdm`
        - `"spinner"`: `spinner_fn_wrap`
        - anything that is not a string: treated as a callable and used directly
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        keyword arguments bound to the chosen progress bar. a truthy `"disable"`
        entry turns the progress bar off regardless of `pbar`
        (defaults to `None`, meaning no arguments)
     - `**extra_kwargs`
        additional keyword arguments to bind; entries in `pbar_kwargs` take
        precedence when the same key appears in both

    # Returns:
     - `ProgressBarFunction`
        `no_progress_fn_wrap` itself when the progress bar is disabled,
        otherwise a `functools.partial` of the chosen callable

    # Raises:
     - `ValueError` : if `pbar` is a string other than `"tqdm"`, `"spinner"` or `"none"`
     - `ImportError` : if `pbar` is `"tqdm"` but the `tqdm` package cannot be imported
    """
    pbar_fn: ProgressBarFunction

    if pbar_kwargs is None:
        pbar_kwargs = dict()

    # explicit `pbar_kwargs` override `extra_kwargs` on key collisions
    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # import here: the module-level `tqdm` name is unbound when tqdm is not
            # installed, which would surface as a confusing `NameError`
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package, which could not be imported"
                ) from e
            pbar_fn = functools.partial(tqdm.tqdm, **pbar_kwargs)
        elif pbar == "spinner":
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # the default value is a callable which will resolve to tqdm if available or spinner as a fallback. we pass kwargs to this
        pbar_fn = functools.partial(pbar, **pbar_kwargs)

    return pbar_fn
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if that leaves only a single process, the work is run in serial

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
        iterable passed to either `map` or `Pool.imap`. if it does not support
        `len()` (e.g. a generator) it is first collected into a list
     - `parallel : bool | int`
        whether to run in parallel, and with how many processes (see above)
     - `pbar_kwargs : Dict[str, Any]`
        keyword arguments for the progress bar, passed to `set_up_progress_bar_fn`
        along with `total=<number of inputs>`
        (defaults to `None`)
     - `chunksize : Optional[int]`
        `chunksize` passed to `Pool.imap` / `Pool.imap_unordered`. if `None`,
        `max(1, n_inputs // num_processes)` is used. ignored when running in serial
        (defaults to `None`)
     - `keep_ordered : bool`
        use `Pool.imap` if `True`, `Pool.imap_unordered` if `False`.
        ignored when running in serial
        (defaults to `True`)
     - `use_multiprocess : bool`
        use the `multiprocess` package instead of `multiprocessing`
        (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar callable or option string, resolved by `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        results of applying `func` to each item; empty if `iterable` is empty

    # Raises:
     - `ValueError` : if `parallel` is an `int` less than 2, if `parallel` is neither
       a `bool` nor an `int`, or if `use_multiprocess` is set while `parallel` is falsy
     - `ImportError` : if `use_multiprocess` is set but `multiprocess` cannot be imported
    """

    # number of inputs in iterable; materialize iterables that have no `len()`
    try:
        n_inputs: int = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        iterable = list(iterable)
        n_inputs = len(iterable)
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    # tracked separately so the caller's `parallel` is still available for validation below
    run_parallel: bool = num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        # validate against what the caller asked for, not against the clamped process
        # count -- otherwise `parallel=True` with a single input would be rejected
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # serial case: plain `map` with a progress bar
    if not run_parallel:
        return list(pbar_fn(map(func, iterable)))

    # figure out a smart chunksize if one is not given
    chunksize_int: int
    if chunksize is None:
        chunksize_int = max(1, n_inputs // num_processes)
    else:
        chunksize_int = chunksize

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    output: List[OutputType]
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered
        # run the map function with a progress bar
        output = list(pbar_fn(pool_map(func, iterable, chunksize=chunksize_int)))
    except BaseException:
        # stop the workers instead of leaking them when the mapping fails or is interrupted
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        # `join` is valid after either `close` or `terminate`
        pool.join()

    # return the output as a list
    return output