muutils.parallel
import multiprocessing
import functools
from typing import (
    Any,
    Callable,
    Iterable,
    Literal,
    Optional,
    TypeVar,
    Dict,
    List,
    Union,
    Protocol,
)

# for no tqdm fallback
from muutils.spinner import SpinnerContext
from muutils.validate_type import get_fn_allowed_kwargs


# typevars for the items of the input iterable and of the mapped output
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")


class ProgressBarFunction(Protocol):
    """a protocol for a progress bar function

    called with an iterable plus arbitrary keyword options, returns an iterable
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...


# string (or `None`) selectors accepted wherever a progress bar can be chosen
ProgressBarOption = Literal["tqdm", "spinner", "none", None]


# default progress bar function, assigned by the `try`/`except` below
DEFAULT_PBAR_FN: Callable


# fallback to spinner option
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` inside a `SpinnerContext` and return its items as a list

    only the kwargs accepted by `SpinnerContext.__init__` are forwarded.
    if no `message` is forwarded, `desc` is used as the message when given,
    otherwise a generic message is built from `total` when given.
    """
    # look the allowed names up once instead of once per kwarg
    allowed_kwargs = get_fn_allowed_kwargs(SpinnerContext.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in allowed_kwargs}

    # translate the tqdm-style `desc` into the spinner's `message`
    if "desc" in kwargs and "message" not in mapped_kwargs:
        mapped_kwargs["message"] = kwargs.get("desc")

    if "message" not in mapped_kwargs and "total" in kwargs:
        mapped_kwargs["message"] = f"Processing {kwargs.get('total')} items"

    # the spinner cannot show per-item progress, so the whole iterable is
    # consumed while the spinner is active
    with SpinnerContext(**mapped_kwargs):
        output = list(x)

    return output


def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    "fallback to no progress bar: returns `x` unchanged and ignores all kwargs"
    return x


# set the default progress bar function
try:
    # use tqdm if it's available
    import tqdm  # type: ignore[import-untyped]

    @functools.wraps(tqdm.tqdm)
    def tqdm_wrap(x: Iterable, **kwargs) -> Iterable:
        # NOTE: `functools.wraps` copies tqdm's docstring onto this function,
        # so it is documented with comments rather than a docstring.
        # wraps `x` in `tqdm.tqdm`, forwarding only the kwargs tqdm accepts
        allowed_kwargs = get_fn_allowed_kwargs(tqdm.tqdm)
        mapped_kwargs: dict = {
            k: v for k, v in kwargs.items() if k in allowed_kwargs
        }
        # translate the spinner-style `message` into tqdm's `desc`.
        # the value must be read from `kwargs`: `desc` is by definition absent
        # from `mapped_kwargs` in this branch
        if "message" in kwargs and "desc" not in mapped_kwargs:
            mapped_kwargs["desc"] = kwargs.get("message")
        return tqdm.tqdm(x, **mapped_kwargs)

    DEFAULT_PBAR_FN = tqdm_wrap

except ImportError:
    # fall back to the spinner when tqdm is not installed
    DEFAULT_PBAR_FN = spinner_fn_wrap
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> ProgressBarFunction:
    """resolve `pbar` into a callable that wraps an iterable with a progress bar

    # Parameters:
     - `pbar : ProgressBarFunction | "tqdm" | "spinner" | "none" | None`
       `None` or `"none"` disables the progress bar, `"tqdm"` and `"spinner"`
       select the respective implementation, any other non-string value is
       treated as a callable and used as-is
     - `pbar_kwargs : Dict[str, Any] | None`
       keyword arguments bound to the progress bar function. if it contains a
       truthy `disable`, the progress bar is disabled regardless of `pbar`
     - `**extra_kwargs`
       additional keyword arguments to bind; entries in `pbar_kwargs` take
       precedence over these

    # Returns:
     - `ProgressBarFunction`
       `no_progress_fn_wrap` when disabled, otherwise a `functools.partial`
       of the selected function with the merged kwargs bound

    # Raises:
     - `ValueError` : if `pbar` is a string other than `"tqdm"`, `"spinner"` or `"none"`
     - `ImportError` : if `pbar` is `"tqdm"` but `tqdm` is not installed
    """
    pbar_fn: ProgressBarFunction

    # later entries win, so explicit `pbar_kwargs` override `extra_kwargs`
    merged_kwargs: Dict[str, Any] = {**extra_kwargs, **(pbar_kwargs or dict())}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or merged_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # import locally: the module-level `tqdm` name only exists when the
            # guarded import at the top of the module succeeded
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package, install it with `pip install tqdm` or pass `pbar='spinner'`"
                ) from e
            pbar_fn = functools.partial(tqdm.tqdm, **merged_kwargs)
        elif pbar == "spinner":
            pbar_fn = functools.partial(spinner_fn_wrap, **merged_kwargs)
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # the default value is a callable which will resolve to tqdm if available or spinner as a fallback. we pass kwargs to this
        pbar_fn = functools.partial(pbar, **merged_kwargs)

    return pbar_fn


def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if that number is 1, the function runs in serial even if `parallel` was requested

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. if it does not support
       `len`, it is first materialized into a list
     - `parallel : bool | int`
       whether to run in parallel, and optionally with how many processes
     - `pbar_kwargs : Dict[str, Any] | None`
       keyword arguments for the progress bar function. `total` is supplied
       automatically unless given here
     - `chunksize : int | None`
       chunk size for the pool's map function.
       defaults to `max(1, n_inputs // num_processes)`
     - `keep_ordered : bool`
       use `Pool.imap` if `True`, `Pool.imap_unordered` otherwise.
       only relevant when running in parallel
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing`.
       requires `parallel` to be requested
     - `pbar : ProgressBarFunction | "tqdm" | "spinner" | "none" | None`
       progress bar to use, see `set_up_progress_bar_fn`

    # Returns:
     - `List[OutputType]`
       results of applying `func` to every item of `iterable`.
       empty if `iterable` is empty

    # Raises:
     - `ValueError` : if `parallel` is an `int` smaller than 2, or neither `bool` nor `int`
     - `ValueError` : if `use_multiprocess` is set but `parallel` was not requested
     - `ImportError` : if `use_multiprocess` is set but `multiprocess` is not installed
    """

    # number of inputs in iterable
    n_inputs: int
    try:
        n_inputs = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        # generators and other unsized iterables: materialize once so the
        # length is known for the progress bar and chunking
        iterable = list(iterable)
        n_inputs = len(iterable)
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    # `bool` must be checked first since it is a subclass of `int`
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    parallel_requested: bool = bool(parallel)
    num_processes = min(num_processes, n_inputs)
    run_parallel: bool = num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        # validate against what the caller asked for, not against the collapsed
        # process count -- otherwise a single input (or a single CPU) would
        # turn a valid call into an error
        if not parallel_requested:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        # the package is only needed when a pool is actually created
        if run_parallel:
            try:
                import multiprocess  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
                ) from e

            mp = multiprocess

    # serial: plain `map` with a progress bar
    if not run_parallel:
        return list(pbar_fn(map(func, iterable)))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    output: List[OutputType]
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered
        output = list(pbar_fn(pool_map(func, iterable, chunksize=chunksize_int)))
    except BaseException:
        # don't wait for outstanding work if something went wrong
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        # always reap the workers so a failure does not leak processes
        pool.join()

    # return the output as a list
    return output
class ProgressBarFunction(Protocol):
    """structural type for progress bar functions

    an implementation is called with an iterable plus arbitrary keyword
    options, and returns an iterable
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        ...
a protocol for a progress bar function
def _no_init_or_replace_init(self, *args, **kwargs):
    # Placeholder `__init__`: refuses to instantiate protocol classes, and for
    # any other class replaces itself with the first real `__init__` found in
    # the MRO, then delegates to it.
    klass = type(self)

    if klass._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # The class already has its own `__init__`, so there is nothing to
    # resolve. Searching anyway can lead to RecursionError. See bpo-45121.
    if klass.__init__ is not _no_init_or_replace_init:
        return

    # First instantiation of the class: walk the MRO and take the first
    # `__init__` that is not this placeholder. Installing it on the class
    # means later instantiations call it directly and never come back here.
    inherited_inits = (
        base.__dict__.get('__init__', _no_init_or_replace_init)
        for base in klass.__mro__
    )
    klass.__init__ = next(
        (init for init in inherited_inits if init is not _no_init_or_replace_init),
        object.__init__,  # should not happen
    )

    klass.__init__(self, *args, **kwargs)
246class tqdm(Comparable): 247 """ 248 Decorate an iterable object, returning an iterator which acts exactly 249 like the original iterable, but prints a dynamically updating 250 progressbar every time a value is requested. 251 252 Parameters 253 ---------- 254 iterable : iterable, optional 255 Iterable to decorate with a progressbar. 256 Leave blank to manually manage the updates. 257 desc : str, optional 258 Prefix for the progressbar. 259 total : int or float, optional 260 The number of expected iterations. If unspecified, 261 len(iterable) is used if possible. If float("inf") or as a last 262 resort, only basic progress statistics are displayed 263 (no ETA, no progressbar). 264 If `gui` is True and this parameter needs subsequent updating, 265 specify an initial arbitrary large positive number, 266 e.g. 9e9. 267 leave : bool, optional 268 If [default: True], keeps all traces of the progressbar 269 upon termination of iteration. 270 If `None`, will leave only if `position` is `0`. 271 file : `io.TextIOWrapper` or `io.StringIO`, optional 272 Specifies where to output the progress messages 273 (default: sys.stderr). Uses `file.write(str)` and `file.flush()` 274 methods. For encoding, see `write_bytes`. 275 ncols : int, optional 276 The width of the entire output message. If specified, 277 dynamically resizes the progressbar to stay within this bound. 278 If unspecified, attempts to use environment width. The 279 fallback is a meter width of 10 and no limit for the counter and 280 statistics. If 0, will not print any meter (only stats). 281 mininterval : float, optional 282 Minimum progress display update interval [default: 0.1] seconds. 283 maxinterval : float, optional 284 Maximum progress display update interval [default: 10] seconds. 285 Automatically adjusts `miniters` to correspond to `mininterval` 286 after long display update lag. Only works if `dynamic_miniters` 287 or monitor thread is enabled. 
288 miniters : int or float, optional 289 Minimum progress display update interval, in iterations. 290 If 0 and `dynamic_miniters`, will automatically adjust to equal 291 `mininterval` (more CPU efficient, good for tight loops). 292 If > 0, will skip display of specified number of iterations. 293 Tweak this and `mininterval` to get very efficient loops. 294 If your progress is erratic with both fast and slow iterations 295 (network, skipping items, etc) you should set miniters=1. 296 ascii : bool or str, optional 297 If unspecified or False, use unicode (smooth blocks) to fill 298 the meter. The fallback is to use ASCII characters " 123456789#". 299 disable : bool, optional 300 Whether to disable the entire progressbar wrapper 301 [default: False]. If set to None, disable on non-TTY. 302 unit : str, optional 303 String that will be used to define the unit of each iteration 304 [default: it]. 305 unit_scale : bool or int or float, optional 306 If 1 or True, the number of iterations will be reduced/scaled 307 automatically and a metric prefix following the 308 International System of Units standard will be added 309 (kilo, mega, etc.) [default: False]. If any other non-zero 310 number, will scale `total` and `n`. 311 dynamic_ncols : bool, optional 312 If set, constantly alters `ncols` and `nrows` to the 313 environment (allowing for window resizes) [default: False]. 314 smoothing : float, optional 315 Exponential moving average smoothing factor for speed estimates 316 (ignored in GUI mode). Ranges from 0 (average speed) to 1 317 (current/instantaneous speed) [default: 0.3]. 318 bar_format : str, optional 319 Specify a custom bar string formatting. May impact performance. 
320 [default: '{l_bar}{bar}{r_bar}'], where 321 l_bar='{desc}: {percentage:3.0f}%|' and 322 r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' 323 '{rate_fmt}{postfix}]' 324 Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt, 325 percentage, elapsed, elapsed_s, ncols, nrows, desc, unit, 326 rate, rate_fmt, rate_noinv, rate_noinv_fmt, 327 rate_inv, rate_inv_fmt, postfix, unit_divisor, 328 remaining, remaining_s, eta. 329 Note that a trailing ": " is automatically removed after {desc} 330 if the latter is empty. 331 initial : int or float, optional 332 The initial counter value. Useful when restarting a progress 333 bar [default: 0]. If using float, consider specifying `{n:.3f}` 334 or similar in `bar_format`, or specifying `unit_scale`. 335 position : int, optional 336 Specify the line offset to print this bar (starting from 0) 337 Automatic if unspecified. 338 Useful to manage multiple bars at once (eg, from threads). 339 postfix : dict or *, optional 340 Specify additional stats to display at the end of the bar. 341 Calls `set_postfix(**postfix)` if possible (dict). 342 unit_divisor : float, optional 343 [default: 1000], ignored unless `unit_scale` is True. 344 write_bytes : bool, optional 345 Whether to write bytes. If (default: False) will write unicode. 346 lock_args : tuple, optional 347 Passed to `refresh` for intermediate output 348 (initialisation, iterating, and updating). 349 nrows : int, optional 350 The screen height. If specified, hides nested bars outside this 351 bound. If unspecified, attempts to use environment height. 352 The fallback is 20. 353 colour : str, optional 354 Bar colour (e.g. 'green', '#00ff00'). 355 delay : float, optional 356 Don't display until [default: 0] seconds have elapsed. 357 gui : bool, optional 358 WARNING: internal parameter - do not use. 359 Use tqdm.gui.tqdm(...) instead. If set, will attempt to use 360 matplotlib animations for a graphical output [default: False]. 
361 362 Returns 363 ------- 364 out : decorated iterator. 365 """ 366 367 monitor_interval = 10 # set to 0 to disable the thread 368 monitor = None 369 _instances = WeakSet() 370 371 @staticmethod 372 def format_sizeof(num, suffix='', divisor=1000): 373 """ 374 Formats a number (greater than unity) with SI Order of Magnitude 375 prefixes. 376 377 Parameters 378 ---------- 379 num : float 380 Number ( >= 1) to format. 381 suffix : str, optional 382 Post-postfix [default: '']. 383 divisor : float, optional 384 Divisor between prefixes [default: 1000]. 385 386 Returns 387 ------- 388 out : str 389 Number with Order of Magnitude SI unit postfix. 390 """ 391 for unit in ['', 'k', 'M', 'G', 'T', 'P', 'E', 'Z']: 392 if abs(num) < 999.5: 393 if abs(num) < 99.95: 394 if abs(num) < 9.995: 395 return f'{num:1.2f}{unit}{suffix}' 396 return f'{num:2.1f}{unit}{suffix}' 397 return f'{num:3.0f}{unit}{suffix}' 398 num /= divisor 399 return f'{num:3.1f}Y{suffix}' 400 401 @staticmethod 402 def format_interval(t): 403 """ 404 Formats a number of seconds as a clock time, [H:]MM:SS 405 406 Parameters 407 ---------- 408 t : int 409 Number of seconds. 410 411 Returns 412 ------- 413 out : str 414 [H:]MM:SS 415 """ 416 mins, s = divmod(int(t), 60) 417 h, m = divmod(mins, 60) 418 return f'{h:d}:{m:02d}:{s:02d}' if h else f'{m:02d}:{s:02d}' 419 420 @staticmethod 421 def format_num(n): 422 """ 423 Intelligent scientific notation (.3g). 424 425 Parameters 426 ---------- 427 n : int or float or Numeric 428 A Number. 429 430 Returns 431 ------- 432 out : str 433 Formatted number. 434 """ 435 f = f'{n:.3g}'.replace('e+0', 'e+').replace('e-0', 'e-') 436 n = str(n) 437 return f if len(f) < len(n) else n 438 439 @staticmethod 440 def status_printer(file): 441 """ 442 Manage the printing and in-place updating of a line of characters. 443 Note that if the string is longer than a line, then in-place 444 updating may not work (it will print a new line at each refresh). 
445 """ 446 fp = file 447 fp_flush = getattr(fp, 'flush', lambda: None) # pragma: no cover 448 if fp in (sys.stderr, sys.stdout): 449 getattr(sys.stderr, 'flush', lambda: None)() 450 getattr(sys.stdout, 'flush', lambda: None)() 451 452 def fp_write(s): 453 fp.write(str(s)) 454 fp_flush() 455 456 last_len = [0] 457 458 def print_status(s): 459 len_s = disp_len(s) 460 fp_write('\r' + s + (' ' * max(last_len[0] - len_s, 0))) 461 last_len[0] = len_s 462 463 return print_status 464 465 @staticmethod 466 def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False, unit='it', 467 unit_scale=False, rate=None, bar_format=None, postfix=None, 468 unit_divisor=1000, initial=0, colour=None, **extra_kwargs): 469 """ 470 Return a string-based progress bar given some parameters 471 472 Parameters 473 ---------- 474 n : int or float 475 Number of finished iterations. 476 total : int or float 477 The expected total number of iterations. If meaningless (None), 478 only basic progress statistics are displayed (no ETA). 479 elapsed : float 480 Number of seconds passed since start. 481 ncols : int, optional 482 The width of the entire output message. If specified, 483 dynamically resizes `{bar}` to stay within this bound 484 [default: None]. If `0`, will not print any bar (only stats). 485 The fallback is `{bar:10}`. 486 prefix : str, optional 487 Prefix message (included in total width) [default: '']. 488 Use as {desc} in bar_format string. 489 ascii : bool, optional or str, optional 490 If not set, use unicode (smooth blocks) to fill the meter 491 [default: False]. The fallback is to use ASCII characters 492 " 123456789#". 493 unit : str, optional 494 The iteration unit [default: 'it']. 495 unit_scale : bool or int or float, optional 496 If 1 or True, the number of iterations will be printed with an 497 appropriate SI metric prefix (k = 10^3, M = 10^6, etc.) 498 [default: False]. If any other non-zero number, will scale 499 `total` and `n`. 
500 rate : float, optional 501 Manual override for iteration rate. 502 If [default: None], uses n/elapsed. 503 bar_format : str, optional 504 Specify a custom bar string formatting. May impact performance. 505 [default: '{l_bar}{bar}{r_bar}'], where 506 l_bar='{desc}: {percentage:3.0f}%|' and 507 r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' 508 '{rate_fmt}{postfix}]' 509 Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt, 510 percentage, elapsed, elapsed_s, ncols, nrows, desc, unit, 511 rate, rate_fmt, rate_noinv, rate_noinv_fmt, 512 rate_inv, rate_inv_fmt, postfix, unit_divisor, 513 remaining, remaining_s, eta. 514 Note that a trailing ": " is automatically removed after {desc} 515 if the latter is empty. 516 postfix : *, optional 517 Similar to `prefix`, but placed at the end 518 (e.g. for additional stats). 519 Note: postfix is usually a string (not a dict) for this method, 520 and will if possible be set to postfix = ', ' + postfix. 521 However other types are supported (#382). 522 unit_divisor : float, optional 523 [default: 1000], ignored unless `unit_scale` is True. 524 initial : int or float, optional 525 The initial counter value [default: 0]. 526 colour : str, optional 527 Bar colour (e.g. 'green', '#00ff00'). 528 529 Returns 530 ------- 531 out : Formatted meter and stats, ready to display. 
532 """ 533 534 # sanity check: total 535 if total and n >= (total + 0.5): # allow float imprecision (#849) 536 total = None 537 538 # apply custom scale if necessary 539 if unit_scale and unit_scale not in (True, 1): 540 if total: 541 total *= unit_scale 542 n *= unit_scale 543 if rate: 544 rate *= unit_scale # by default rate = self.avg_dn / self.avg_dt 545 unit_scale = False 546 547 elapsed_str = tqdm.format_interval(elapsed) 548 549 # if unspecified, attempt to use rate = average speed 550 # (we allow manual override since predicting time is an arcane art) 551 if rate is None and elapsed: 552 rate = (n - initial) / elapsed 553 inv_rate = 1 / rate if rate else None 554 format_sizeof = tqdm.format_sizeof 555 rate_noinv_fmt = ((format_sizeof(rate) if unit_scale else f'{rate:5.2f}') 556 if rate else '?') + unit + '/s' 557 rate_inv_fmt = ( 558 (format_sizeof(inv_rate) if unit_scale else f'{inv_rate:5.2f}') 559 if inv_rate else '?') + 's/' + unit 560 rate_fmt = rate_inv_fmt if inv_rate and inv_rate > 1 else rate_noinv_fmt 561 562 if unit_scale: 563 n_fmt = format_sizeof(n, divisor=unit_divisor) 564 total_fmt = format_sizeof(total, divisor=unit_divisor) if total is not None else '?' 565 else: 566 n_fmt = str(n) 567 total_fmt = str(total) if total is not None else '?' 568 569 try: 570 postfix = ', ' + postfix if postfix else '' 571 except TypeError: 572 pass 573 574 remaining = (total - n) / rate if rate and total else 0 575 remaining_str = tqdm.format_interval(remaining) if rate else '?' 
576 try: 577 eta_dt = (datetime.now() + timedelta(seconds=remaining) 578 if rate and total else datetime.fromtimestamp(0, timezone.utc)) 579 except OverflowError: 580 eta_dt = datetime.max 581 582 # format the stats displayed to the left and right sides of the bar 583 if prefix: 584 # old prefix setup work around 585 bool_prefix_colon_already = (prefix[-2:] == ": ") 586 l_bar = prefix if bool_prefix_colon_already else prefix + ": " 587 else: 588 l_bar = '' 589 590 r_bar = f'| {n_fmt}/{total_fmt} [{elapsed_str}<{remaining_str}, {rate_fmt}{postfix}]' 591 592 # Custom bar formatting 593 # Populate a dict with all available progress indicators 594 format_dict = { 595 # slight extension of self.format_dict 596 'n': n, 'n_fmt': n_fmt, 'total': total, 'total_fmt': total_fmt, 597 'elapsed': elapsed_str, 'elapsed_s': elapsed, 598 'ncols': ncols, 'desc': prefix or '', 'unit': unit, 599 'rate': inv_rate if inv_rate and inv_rate > 1 else rate, 600 'rate_fmt': rate_fmt, 'rate_noinv': rate, 601 'rate_noinv_fmt': rate_noinv_fmt, 'rate_inv': inv_rate, 602 'rate_inv_fmt': rate_inv_fmt, 603 'postfix': postfix, 'unit_divisor': unit_divisor, 604 'colour': colour, 605 # plus more useful definitions 606 'remaining': remaining_str, 'remaining_s': remaining, 607 'l_bar': l_bar, 'r_bar': r_bar, 'eta': eta_dt, 608 **extra_kwargs} 609 610 # total is known: we can predict some stats 611 if total: 612 # fractional and percentage progress 613 frac = n / total 614 percentage = frac * 100 615 616 l_bar += f'{percentage:3.0f}%|' 617 618 if ncols == 0: 619 return l_bar[:-1] + r_bar[1:] 620 621 format_dict.update(l_bar=l_bar) 622 if bar_format: 623 format_dict.update(percentage=percentage) 624 625 # auto-remove colon for empty `{desc}` 626 if not prefix: 627 bar_format = bar_format.replace("{desc}: ", '') 628 else: 629 bar_format = "{l_bar}{bar}{r_bar}" 630 631 full_bar = FormatReplace() 632 nobar = bar_format.format(bar=full_bar, **format_dict) 633 if not full_bar.format_called: 634 return nobar # 
no `{bar}`; nothing else to do 635 636 # Formatting progress bar space available for bar's display 637 full_bar = Bar(frac, 638 max(1, ncols - disp_len(nobar)) if ncols else 10, 639 charset=Bar.ASCII if ascii is True else ascii or Bar.UTF, 640 colour=colour) 641 if not _is_ascii(full_bar.charset) and _is_ascii(bar_format): 642 bar_format = str(bar_format) 643 res = bar_format.format(bar=full_bar, **format_dict) 644 return disp_trim(res, ncols) if ncols else res 645 646 elif bar_format: 647 # user-specified bar_format but no total 648 l_bar += '|' 649 format_dict.update(l_bar=l_bar, percentage=0) 650 full_bar = FormatReplace() 651 nobar = bar_format.format(bar=full_bar, **format_dict) 652 if not full_bar.format_called: 653 return nobar 654 full_bar = Bar(0, 655 max(1, ncols - disp_len(nobar)) if ncols else 10, 656 charset=Bar.BLANK, colour=colour) 657 res = bar_format.format(bar=full_bar, **format_dict) 658 return disp_trim(res, ncols) if ncols else res 659 else: 660 # no total: no progressbar, ETA, just progress stats 661 return (f'{(prefix + ": ") if prefix else ""}' 662 f'{n_fmt}{unit} [{elapsed_str}, {rate_fmt}{postfix}]') 663 664 def __new__(cls, *_, **__): 665 instance = object.__new__(cls) 666 with cls.get_lock(): # also constructs lock if non-existent 667 cls._instances.add(instance) 668 # create monitoring thread 669 if cls.monitor_interval and (cls.monitor is None 670 or not cls.monitor.report()): 671 try: 672 cls.monitor = TMonitor(cls, cls.monitor_interval) 673 except Exception as e: # pragma: nocover 674 warn("tqdm:disabling monitor support" 675 " (monitor_interval = 0) due to:\n" + str(e), 676 TqdmMonitorWarning, stacklevel=2) 677 cls.monitor_interval = 0 678 return instance 679 680 @classmethod 681 def _get_free_pos(cls, instance=None): 682 """Skips specified instance.""" 683 positions = {abs(inst.pos) for inst in cls._instances 684 if inst is not instance and hasattr(inst, "pos")} 685 return min(set(range(len(positions) + 1)).difference(positions)) 
686 687 @classmethod 688 def _decr_instances(cls, instance): 689 """ 690 Remove from list and reposition another unfixed bar 691 to fill the new gap. 692 693 This means that by default (where all nested bars are unfixed), 694 order is not maintained but screen flicker/blank space is minimised. 695 (tqdm<=4.44.1 moved ALL subsequent unfixed bars up.) 696 """ 697 with cls._lock: 698 try: 699 cls._instances.remove(instance) 700 except KeyError: 701 # if not instance.gui: # pragma: no cover 702 # raise 703 pass # py2: maybe magically removed already 704 # else: 705 if not instance.gui: 706 last = (instance.nrows or 20) - 1 707 # find unfixed (`pos >= 0`) overflow (`pos >= nrows - 1`) 708 instances = list(filter( 709 lambda i: hasattr(i, "pos") and last <= i.pos, 710 cls._instances)) 711 # set first found to current `pos` 712 if instances: 713 inst = min(instances, key=lambda i: i.pos) 714 inst.clear(nolock=True) 715 inst.pos = abs(instance.pos) 716 717 @classmethod 718 def write(cls, s, file=None, end="\n", nolock=False): 719 """Print a message via tqdm (without overlap with bars).""" 720 fp = file if file is not None else sys.stdout 721 with cls.external_write_mode(file=file, nolock=nolock): 722 # Write the message 723 fp.write(s) 724 fp.write(end) 725 726 @classmethod 727 @contextmanager 728 def external_write_mode(cls, file=None, nolock=False): 729 """ 730 Disable tqdm within context and refresh tqdm when exits. 
731 Useful when writing to standard output stream 732 """ 733 fp = file if file is not None else sys.stdout 734 735 try: 736 if not nolock: 737 cls.get_lock().acquire() 738 # Clear all bars 739 inst_cleared = [] 740 for inst in getattr(cls, '_instances', []): 741 # Clear instance if in the target output file 742 # or if write output + tqdm output are both either 743 # sys.stdout or sys.stderr (because both are mixed in terminal) 744 if hasattr(inst, "start_t") and (inst.fp == fp or all( 745 f in (sys.stdout, sys.stderr) for f in (fp, inst.fp))): 746 inst.clear(nolock=True) 747 inst_cleared.append(inst) 748 yield 749 # Force refresh display of bars we cleared 750 for inst in inst_cleared: 751 inst.refresh(nolock=True) 752 finally: 753 if not nolock: 754 cls._lock.release() 755 756 @classmethod 757 def set_lock(cls, lock): 758 """Set the global lock.""" 759 cls._lock = lock 760 761 @classmethod 762 def get_lock(cls): 763 """Get the global lock. Construct it if it does not exist.""" 764 if not hasattr(cls, '_lock'): 765 cls._lock = TqdmDefaultWriteLock() 766 return cls._lock 767 768 @classmethod 769 def pandas(cls, **tqdm_kwargs): 770 """ 771 Registers the current `tqdm` class with 772 pandas.core. 773 ( frame.DataFrame 774 | series.Series 775 | groupby.(generic.)DataFrameGroupBy 776 | groupby.(generic.)SeriesGroupBy 777 ).progress_apply 778 779 A new instance will be created every time `progress_apply` is called, 780 and each instance will automatically `close()` upon completion. 
781 782 Parameters 783 ---------- 784 tqdm_kwargs : arguments for the tqdm instance 785 786 Examples 787 -------- 788 >>> import pandas as pd 789 >>> import numpy as np 790 >>> from tqdm import tqdm 791 >>> from tqdm.gui import tqdm as tqdm_gui 792 >>> 793 >>> df = pd.DataFrame(np.random.randint(0, 100, (100000, 6))) 794 >>> tqdm.pandas(ncols=50) # can use tqdm_gui, optional kwargs, etc 795 >>> # Now you can use `progress_apply` instead of `apply` 796 >>> df.groupby(0).progress_apply(lambda x: x**2) 797 798 References 799 ---------- 800 <https://stackoverflow.com/questions/18603270/\ 801 progress-indicator-during-pandas-operations-python> 802 """ 803 from warnings import catch_warnings, simplefilter 804 805 from pandas.core.frame import DataFrame 806 from pandas.core.series import Series 807 try: 808 with catch_warnings(): 809 simplefilter("ignore", category=FutureWarning) 810 from pandas import Panel 811 except ImportError: # pandas>=1.2.0 812 Panel = None 813 Rolling, Expanding = None, None 814 try: # pandas>=1.0.0 815 from pandas.core.window.rolling import _Rolling_and_Expanding 816 except ImportError: 817 try: # pandas>=0.18.0 818 from pandas.core.window import _Rolling_and_Expanding 819 except ImportError: # pandas>=1.2.0 820 try: # pandas>=1.2.0 821 from pandas.core.window.expanding import Expanding 822 from pandas.core.window.rolling import Rolling 823 _Rolling_and_Expanding = Rolling, Expanding 824 except ImportError: # pragma: no cover 825 _Rolling_and_Expanding = None 826 try: # pandas>=0.25.0 827 from pandas.core.groupby.generic import SeriesGroupBy # , NDFrameGroupBy 828 from pandas.core.groupby.generic import DataFrameGroupBy 829 except ImportError: # pragma: no cover 830 try: # pandas>=0.23.0 831 from pandas.core.groupby.groupby import DataFrameGroupBy, SeriesGroupBy 832 except ImportError: 833 from pandas.core.groupby import DataFrameGroupBy, SeriesGroupBy 834 try: # pandas>=0.23.0 835 from pandas.core.groupby.groupby import GroupBy 836 except 
ImportError: # pragma: no cover 837 from pandas.core.groupby import GroupBy 838 839 try: # pandas>=0.23.0 840 from pandas.core.groupby.groupby import PanelGroupBy 841 except ImportError: 842 try: 843 from pandas.core.groupby import PanelGroupBy 844 except ImportError: # pandas>=0.25.0 845 PanelGroupBy = None 846 847 tqdm_kwargs = tqdm_kwargs.copy() 848 deprecated_t = [tqdm_kwargs.pop('deprecated_t', None)] 849 850 def inner_generator(df_function='apply'): 851 def inner(df, func, *args, **kwargs): 852 """ 853 Parameters 854 ---------- 855 df : (DataFrame|Series)[GroupBy] 856 Data (may be grouped). 857 func : function 858 To be applied on the (grouped) data. 859 **kwargs : optional 860 Transmitted to `df.apply()`. 861 """ 862 863 # Precompute total iterations 864 total = tqdm_kwargs.pop("total", getattr(df, 'ngroups', None)) 865 if total is None: # not grouped 866 if df_function == 'applymap': 867 total = df.size 868 elif isinstance(df, Series): 869 total = len(df) 870 elif (_Rolling_and_Expanding is None or 871 not isinstance(df, _Rolling_and_Expanding)): 872 # DataFrame or Panel 873 axis = kwargs.get('axis', 0) 874 if axis == 'index': 875 axis = 0 876 elif axis == 'columns': 877 axis = 1 878 # when axis=0, total is shape[axis1] 879 total = df.size // df.shape[axis] 880 881 # Init bar 882 if deprecated_t[0] is not None: 883 t = deprecated_t[0] 884 deprecated_t[0] = None 885 else: 886 t = cls(total=total, **tqdm_kwargs) 887 888 if len(args) > 0: 889 # *args intentionally not supported (see #244, #299) 890 TqdmDeprecationWarning( 891 "Except func, normal arguments are intentionally" + 892 " not supported by" + 893 " `(DataFrame|Series|GroupBy).progress_apply`." 
+ 894 " Use keyword arguments instead.", 895 fp_write=getattr(t.fp, 'write', sys.stderr.write)) 896 897 try: # pandas>=1.3.0 898 from pandas.core.common import is_builtin_func 899 except ImportError: 900 is_builtin_func = df._is_builtin_func 901 try: 902 func = is_builtin_func(func) 903 except TypeError: 904 pass 905 906 # Define bar updating wrapper 907 def wrapper(*args, **kwargs): 908 # update tbar correctly 909 # it seems `pandas apply` calls `func` twice 910 # on the first column/row to decide whether it can 911 # take a fast or slow code path; so stop when t.total==t.n 912 t.update(n=1 if not t.total or t.n < t.total else 0) 913 return func(*args, **kwargs) 914 915 # Apply the provided function (in **kwargs) 916 # on the df using our wrapper (which provides bar updating) 917 try: 918 return getattr(df, df_function)(wrapper, **kwargs) 919 finally: 920 t.close() 921 922 return inner 923 924 # Monkeypatch pandas to provide easy methods 925 # Enable custom tqdm progress in pandas! 926 Series.progress_apply = inner_generator() 927 SeriesGroupBy.progress_apply = inner_generator() 928 Series.progress_map = inner_generator('map') 929 SeriesGroupBy.progress_map = inner_generator('map') 930 931 DataFrame.progress_apply = inner_generator() 932 DataFrameGroupBy.progress_apply = inner_generator() 933 DataFrame.progress_applymap = inner_generator('applymap') 934 DataFrame.progress_map = inner_generator('map') 935 DataFrameGroupBy.progress_map = inner_generator('map') 936 937 if Panel is not None: 938 Panel.progress_apply = inner_generator() 939 if PanelGroupBy is not None: 940 PanelGroupBy.progress_apply = inner_generator() 941 942 GroupBy.progress_apply = inner_generator() 943 GroupBy.progress_aggregate = inner_generator('aggregate') 944 GroupBy.progress_transform = inner_generator('transform') 945 946 if Rolling is not None and Expanding is not None: 947 Rolling.progress_apply = inner_generator() 948 Expanding.progress_apply = inner_generator() 949 elif 
_Rolling_and_Expanding is not None: 950 _Rolling_and_Expanding.progress_apply = inner_generator() 951 952 # override defaults via env vars 953 @envwrap("TQDM_", is_method=True, types={'total': float, 'ncols': int, 'miniters': float, 954 'position': int, 'nrows': int}) 955 def __init__(self, iterable=None, desc=None, total=None, leave=True, file=None, 956 ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None, 957 ascii=None, disable=False, unit='it', unit_scale=False, 958 dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0, 959 position=None, postfix=None, unit_divisor=1000, write_bytes=False, 960 lock_args=None, nrows=None, colour=None, delay=0.0, gui=False, 961 **kwargs): 962 """see tqdm.tqdm for arguments""" 963 if file is None: 964 file = sys.stderr 965 966 if write_bytes: 967 # Despite coercing unicode into bytes, py2 sys.std* streams 968 # should have bytes written to them. 969 file = SimpleTextIOWrapper( 970 file, encoding=getattr(file, 'encoding', None) or 'utf-8') 971 972 file = DisableOnWriteError(file, tqdm_instance=self) 973 974 if disable is None and hasattr(file, "isatty") and not file.isatty(): 975 disable = True 976 977 if total is None and iterable is not None: 978 try: 979 total = len(iterable) 980 except (TypeError, AttributeError): 981 total = None 982 if total == float("inf"): 983 # Infinite iterations, behave same as unknown 984 total = None 985 986 if disable: 987 self.iterable = iterable 988 self.disable = disable 989 with self._lock: 990 self.pos = self._get_free_pos(self) 991 self._instances.remove(self) 992 self.n = initial 993 self.total = total 994 self.leave = leave 995 return 996 997 if kwargs: 998 self.disable = True 999 with self._lock: 1000 self.pos = self._get_free_pos(self) 1001 self._instances.remove(self) 1002 raise ( 1003 TqdmDeprecationWarning( 1004 "`nested` is deprecated and automated.\n" 1005 "Use `position` instead for manual control.\n", 1006 fp_write=getattr(file, 'write', sys.stderr.write)) 1007 if 
"nested" in kwargs else 1008 TqdmKeyError("Unknown argument(s): " + str(kwargs))) 1009 1010 # Preprocess the arguments 1011 if ( 1012 (ncols is None or nrows is None) and (file in (sys.stderr, sys.stdout)) 1013 ) or dynamic_ncols: # pragma: no cover 1014 if dynamic_ncols: 1015 dynamic_ncols = _screen_shape_wrapper() 1016 if dynamic_ncols: 1017 ncols, nrows = dynamic_ncols(file) 1018 else: 1019 _dynamic_ncols = _screen_shape_wrapper() 1020 if _dynamic_ncols: 1021 _ncols, _nrows = _dynamic_ncols(file) 1022 if ncols is None: 1023 ncols = _ncols 1024 if nrows is None: 1025 nrows = _nrows 1026 1027 if miniters is None: 1028 miniters = 0 1029 dynamic_miniters = True 1030 else: 1031 dynamic_miniters = False 1032 1033 if mininterval is None: 1034 mininterval = 0 1035 1036 if maxinterval is None: 1037 maxinterval = 0 1038 1039 if ascii is None: 1040 ascii = not _supports_unicode(file) 1041 1042 if bar_format and ascii is not True and not _is_ascii(ascii): 1043 # Convert bar format into unicode since terminal uses unicode 1044 bar_format = str(bar_format) 1045 1046 if smoothing is None: 1047 smoothing = 0 1048 1049 # Store the arguments 1050 self.iterable = iterable 1051 self.desc = desc or '' 1052 self.total = total 1053 self.leave = leave 1054 self.fp = file 1055 self.ncols = ncols 1056 self.nrows = nrows 1057 self.mininterval = mininterval 1058 self.maxinterval = maxinterval 1059 self.miniters = miniters 1060 self.dynamic_miniters = dynamic_miniters 1061 self.ascii = ascii 1062 self.disable = disable 1063 self.unit = unit 1064 self.unit_scale = unit_scale 1065 self.unit_divisor = unit_divisor 1066 self.initial = initial 1067 self.lock_args = lock_args 1068 self.delay = delay 1069 self.gui = gui 1070 self.dynamic_ncols = dynamic_ncols 1071 self.smoothing = smoothing 1072 self._ema_dn = EMA(smoothing) 1073 self._ema_dt = EMA(smoothing) 1074 self._ema_miniters = EMA(smoothing) 1075 self.bar_format = bar_format 1076 self.postfix = None 1077 self.colour = colour 1078 
self._time = time 1079 if postfix: 1080 try: 1081 self.set_postfix(refresh=False, **postfix) 1082 except TypeError: 1083 self.postfix = postfix 1084 1085 # Init the iterations counters 1086 self.last_print_n = initial 1087 self.n = initial 1088 1089 # if nested, at initial sp() call we replace '\r' by '\n' to 1090 # not overwrite the outer progress bar 1091 with self._lock: 1092 # mark fixed positions as negative 1093 self.pos = self._get_free_pos(self) if position is None else -position 1094 1095 if not gui: 1096 # Initialize the screen printer 1097 self.sp = self.status_printer(self.fp) 1098 if delay <= 0: 1099 self.refresh(lock_args=self.lock_args) 1100 1101 # Init the time counter 1102 self.last_print_t = self._time() 1103 # NB: Avoid race conditions by setting start_t at the very end of init 1104 self.start_t = self.last_print_t 1105 1106 def __bool__(self): 1107 if self.total is not None: 1108 return self.total > 0 1109 if self.iterable is None: 1110 raise TypeError('bool() undefined when iterable == total == None') 1111 return bool(self.iterable) 1112 1113 def __len__(self): 1114 return ( 1115 self.total if self.iterable is None 1116 else self.iterable.shape[0] if hasattr(self.iterable, "shape") 1117 else len(self.iterable) if hasattr(self.iterable, "__len__") 1118 else self.iterable.__length_hint__() if hasattr(self.iterable, "__length_hint__") 1119 else getattr(self, "total", None)) 1120 1121 def __reversed__(self): 1122 try: 1123 orig = self.iterable 1124 except AttributeError: 1125 raise TypeError("'tqdm' object is not reversible") 1126 else: 1127 self.iterable = reversed(self.iterable) 1128 return self.__iter__() 1129 finally: 1130 self.iterable = orig 1131 1132 def __contains__(self, item): 1133 contains = getattr(self.iterable, '__contains__', None) 1134 return contains(item) if contains is not None else item in self.__iter__() 1135 1136 def __enter__(self): 1137 return self 1138 1139 def __exit__(self, exc_type, exc_value, traceback): 1140 try: 1141 
self.close() 1142 except AttributeError: 1143 # maybe eager thread cleanup upon external error 1144 if (exc_type, exc_value, traceback) == (None, None, None): 1145 raise 1146 warn("AttributeError ignored", TqdmWarning, stacklevel=2) 1147 1148 def __del__(self): 1149 self.close() 1150 1151 def __str__(self): 1152 return self.format_meter(**self.format_dict) 1153 1154 @property 1155 def _comparable(self): 1156 return abs(getattr(self, "pos", 1 << 31)) 1157 1158 def __hash__(self): 1159 return id(self) 1160 1161 def __iter__(self): 1162 """Backward-compatibility to use: for x in tqdm(iterable)""" 1163 1164 # Inlining instance variables as locals (speed optimisation) 1165 iterable = self.iterable 1166 1167 # If the bar is disabled, then just walk the iterable 1168 # (note: keep this check outside the loop for performance) 1169 if self.disable: 1170 for obj in iterable: 1171 yield obj 1172 return 1173 1174 mininterval = self.mininterval 1175 last_print_t = self.last_print_t 1176 last_print_n = self.last_print_n 1177 min_start_t = self.start_t + self.delay 1178 n = self.n 1179 time = self._time 1180 1181 try: 1182 for obj in iterable: 1183 yield obj 1184 # Update and possibly print the progressbar. 1185 # Note: does not call self.update(1) for speed optimisation. 1186 n += 1 1187 1188 if n - last_print_n >= self.miniters: 1189 cur_t = time() 1190 dt = cur_t - last_print_t 1191 if dt >= mininterval and cur_t >= min_start_t: 1192 self.update(n - last_print_n) 1193 last_print_n = self.last_print_n 1194 last_print_t = self.last_print_t 1195 finally: 1196 self.n = n 1197 self.close() 1198 1199 def update(self, n=1): 1200 """ 1201 Manually update the progress bar, useful for streams 1202 such as reading files. 1203 E.g.: 1204 >>> t = tqdm(total=filesize) # Initialise 1205 >>> for current_buffer in stream: 1206 ... ... 1207 ... 
t.update(len(current_buffer)) 1208 >>> t.close() 1209 The last line is highly recommended, but possibly not necessary if 1210 `t.update()` will be called in such a way that `filesize` will be 1211 exactly reached and printed. 1212 1213 Parameters 1214 ---------- 1215 n : int or float, optional 1216 Increment to add to the internal counter of iterations 1217 [default: 1]. If using float, consider specifying `{n:.3f}` 1218 or similar in `bar_format`, or specifying `unit_scale`. 1219 1220 Returns 1221 ------- 1222 out : bool or None 1223 True if a `display()` was triggered. 1224 """ 1225 if self.disable: 1226 return 1227 1228 if n < 0: 1229 self.last_print_n += n # for auto-refresh logic to work 1230 self.n += n 1231 1232 # check counter first to reduce calls to time() 1233 if self.n - self.last_print_n >= self.miniters: 1234 cur_t = self._time() 1235 dt = cur_t - self.last_print_t 1236 if dt >= self.mininterval and cur_t >= self.start_t + self.delay: 1237 cur_t = self._time() 1238 dn = self.n - self.last_print_n # >= n 1239 if self.smoothing and dt and dn: 1240 # EMA (not just overall average) 1241 self._ema_dn(dn) 1242 self._ema_dt(dt) 1243 self.refresh(lock_args=self.lock_args) 1244 if self.dynamic_miniters: 1245 # If no `miniters` was specified, adjust automatically to the 1246 # maximum iteration rate seen so far between two prints. 1247 # e.g.: After running `tqdm.update(5)`, subsequent 1248 # calls to `tqdm.update()` will only cause an update after 1249 # at least 5 more iterations. 
1250 if self.maxinterval and dt >= self.maxinterval: 1251 self.miniters = dn * (self.mininterval or self.maxinterval) / dt 1252 elif self.smoothing: 1253 # EMA miniters update 1254 self.miniters = self._ema_miniters( 1255 dn * (self.mininterval / dt if self.mininterval and dt 1256 else 1)) 1257 else: 1258 # max iters between two prints 1259 self.miniters = max(self.miniters, dn) 1260 1261 # Store old values for next call 1262 self.last_print_n = self.n 1263 self.last_print_t = cur_t 1264 return True 1265 1266 def close(self): 1267 """Cleanup and (if leave=False) close the progressbar.""" 1268 if self.disable: 1269 return 1270 1271 # Prevent multiple closures 1272 self.disable = True 1273 1274 # decrement instance pos and remove from internal set 1275 pos = abs(self.pos) 1276 self._decr_instances(self) 1277 1278 if self.last_print_t < self.start_t + self.delay: 1279 # haven't ever displayed; nothing to clear 1280 return 1281 1282 # GUI mode 1283 if getattr(self, 'sp', None) is None: 1284 return 1285 1286 # annoyingly, _supports_unicode isn't good enough 1287 def fp_write(s): 1288 self.fp.write(str(s)) 1289 1290 try: 1291 fp_write('') 1292 except ValueError as e: 1293 if 'closed' in str(e): 1294 return 1295 raise # pragma: no cover 1296 1297 leave = pos == 0 if self.leave is None else self.leave 1298 1299 with self._lock: 1300 if leave: 1301 # stats for overall rate (no weighted average) 1302 self._ema_dt = lambda: None 1303 self.display(pos=0) 1304 fp_write('\n') 1305 else: 1306 # clear previous display 1307 if self.display(msg='', pos=pos) and not pos: 1308 fp_write('\r') 1309 1310 def clear(self, nolock=False): 1311 """Clear current bar display.""" 1312 if self.disable: 1313 return 1314 1315 if not nolock: 1316 self._lock.acquire() 1317 pos = abs(self.pos) 1318 if pos < (self.nrows or 20): 1319 self.moveto(pos) 1320 self.sp('') 1321 self.fp.write('\r') # place cursor back at the beginning of line 1322 self.moveto(-pos) 1323 if not nolock: 1324 self._lock.release() 
def refresh(self, nolock=False, lock_args=None):
    """
    Force refresh the display of this bar.

    Parameters
    ----------
    nolock  : bool, optional
        If `True`, does not lock.
        If [default: `False`]: calls `acquire()` on internal lock.
    lock_args  : tuple, optional
        Passed to internal lock's `acquire()`.
        If specified, will only `display()` if `acquire()` returns `True`.

    Returns
    -------
    out  : bool or None
        `None` if the bar is disabled, `False` if the lock could not be
        acquired with `lock_args`, `True` after `display()` was called.
    """
    if self.disable:
        return

    if not nolock:
        if lock_args:
            # conditional acquire: give up (without displaying) on failure
            if not self._lock.acquire(*lock_args):
                return False
        else:
            self._lock.acquire()
    # NOTE(review): no try/finally here -- if `display()` raises, the lock
    # acquired above is not released by this method.
    self.display()
    if not nolock:
        self._lock.release()
    return True

def unpause(self):
    """Restart tqdm timer from last print time."""
    if self.disable:
        return
    cur_t = self._time()
    # shift the start time forward by the time elapsed since the last print
    self.start_t += cur_t - self.last_print_t
    self.last_print_t = cur_t

def reset(self, total=None):
    """
    Resets to 0 iterations for repeated use.

    Consider combining with `leave=True`.

    Parameters
    ----------
    total  : int or float, optional. Total to use for the new bar.
    """
    # the counter (and optionally the total) are reset even when disabled
    self.n = 0
    if total is not None:
        self.total = total
    if self.disable:
        return
    self.last_print_n = 0
    self.last_print_t = self.start_t = self._time()
    # fresh moving averages, built from the current smoothing factor
    self._ema_dn = EMA(self.smoothing)
    self._ema_dt = EMA(self.smoothing)
    self._ema_miniters = EMA(self.smoothing)
    self.refresh()

def set_description(self, desc=None, refresh=True):
    """
    Set/modify description of the progress bar.

    Parameters
    ----------
    desc  : str, optional
    refresh  : bool, optional
        Forces refresh [default: True].
    """
    # a non-empty description gets ': ' appended; a falsy one becomes ''
    self.desc = desc + ': ' if desc else ''
    if refresh:
        self.refresh()

def set_description_str(self, desc=None, refresh=True):
    """Set/modify description without ': ' appended."""
    self.desc = desc or ''
    if refresh:
        self.refresh()

def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
    """
    Set/modify postfix (additional stats)
    with automatic formatting based on datatype.

    Parameters
    ----------
    ordered_dict  : dict or OrderedDict, optional
    refresh  : bool, optional
        Forces refresh [default: True].
    kwargs  : dict, optional
    """
    # Sort in alphabetical order to be more deterministic
    postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
    for key in sorted(kwargs.keys()):
        postfix[key] = kwargs[key]
    # Preprocess stats according to datatype
    for key in postfix.keys():
        # Number: limit the length of the string
        if isinstance(postfix[key], Number):
            postfix[key] = self.format_num(postfix[key])
        # Else for any other type, try to get the string conversion
        elif not isinstance(postfix[key], str):
            postfix[key] = str(postfix[key])
        # Else if it's a string, don't need to preprocess anything
    # Stitch together to get the final postfix
    self.postfix = ', '.join(key + '=' + postfix[key].strip()
                             for key in postfix.keys())
    if refresh:
        self.refresh()

def set_postfix_str(self, s='', refresh=True):
    """
    Postfix without dictionary expansion, similar to prefix handling.
    """
    self.postfix = str(s)
    if refresh:
        self.refresh()

def moveto(self, n):
    # TODO: private method
    # positive `n` writes `n` newlines; negative `n` writes `-n` copies of
    # `_term_move_up()` (the other factor is an empty string in each case)
    self.fp.write('\n' * n + _term_move_up() * -n)
    # flush if the stream supports it, otherwise do nothing
    getattr(self.fp, 'flush', lambda: None)()

@property
def format_dict(self):
    """Public API for read-only member access."""
    # disabled and without a `unit` attribute: return a minimal mapping
    # whose missing keys default to None
    if self.disable and not hasattr(self, 'unit'):
        return defaultdict(lambda: None, {
            'n': self.n, 'total': self.total, 'elapsed': 0, 'unit': 'it'})
    if self.dynamic_ncols:
        # `dynamic_ncols` is called here with the output stream and its
        # result overwrites the stored width/height
        self.ncols, self.nrows = self.dynamic_ncols(self.fp)
    return {
        'n': self.n, 'total': self.total,
        'elapsed': self._time() - self.start_t if hasattr(self, 'start_t') else 0,
        'ncols': self.ncols, 'nrows': self.nrows, 'prefix': self.desc,
        'ascii': self.ascii, 'unit': self.unit, 'unit_scale': self.unit_scale,
        'rate': self._ema_dn() / self._ema_dt() if self._ema_dt() else None,
        'bar_format': self.bar_format, 'postfix': self.postfix,
        'unit_divisor': self.unit_divisor, 'initial': self.initial,
        'colour': self.colour}

def display(self, msg=None, pos=None):
    """
    Use `self.sp` to display `msg` in the specified `pos`.

    Consider overloading this function when inheriting to use e.g.:
    `self.some_frontend(**self.format_dict)` instead of `self.sp`.

    Parameters
    ----------
    msg  : str, optional. What to display (default: `repr(self)`).
    pos  : int, optional. Position to `moveto`
      (default: `abs(self.pos)`).

    Returns
    -------
    out  : bool
        `False` if `pos` is at or beyond `nrows` (nothing is printed),
        `True` otherwise.
    """
    if pos is None:
        pos = abs(self.pos)

    # fall back to a height of 20 rows when `nrows` is unset/zero
    nrows = self.nrows or 20
    if pos >= nrows - 1:
        if pos >= nrows:
            return False
        if msg or msg is None:  # override at `nrows - 1`
            msg = " ... (more hidden) ..."

    if not hasattr(self, "sp"):
        raise TqdmDeprecationWarning(
            "Please use `tqdm.gui.tqdm(...)`"
            " instead of `tqdm(..., gui=True)`\n",
            fp_write=getattr(self.fp, 'write', sys.stderr.write))

    # move to the bar's row, print, then move back by the same amount
    if pos:
        self.moveto(pos)
    self.sp(self.__str__() if msg is None else msg)
    if pos:
        self.moveto(-pos)
    return True

@classmethod
@contextmanager
def wrapattr(cls, stream, method, total=None, bytes=True, **tqdm_kwargs):
    """
    stream  : file-like object.
    method  : str, "read" or "write". The result of `read()` and
        the first argument of `write()` should have a `len()`.

    >>> with tqdm.wrapattr(file_obj, "read", total=file_obj.size) as fobj:
    ...     while True:
    ...         chunk = fobj.read(chunk_size)
    ...         if not chunk:
    ...             break
    """
    with cls(total=total, **tqdm_kwargs) as t:
        if bytes:
            # byte-oriented display: unit "B", scaled, divisor 1024
            t.unit = "B"
            t.unit_scale = True
            t.unit_divisor = 1024
        # the bar's `update` is handed to the wrapper as its callback
        yield CallbackIOWrapper(t.update, stream, method)
Decorate an iterable object, returning an iterator which acts exactly like the original iterable, but prints a dynamically updating progressbar every time a value is requested.
Parameters
iterable : iterable, optional
Iterable to decorate with a progressbar.
Leave blank to manually manage the updates.
desc : str, optional
Prefix for the progressbar.
total : int or float, optional
The number of expected iterations. If unspecified,
len(iterable) is used if possible. If float("inf") or as a last
resort, only basic progress statistics are displayed
(no ETA, no progressbar).
If `gui` is True and this parameter needs subsequent updating,
specify an initial arbitrary large positive number,
e.g. 9e9.
leave : bool, optional
If [default: True], keeps all traces of the progressbar
upon termination of iteration.
If `None`, will leave only if `position` is `0`.
file : `io.TextIOWrapper` or `io.StringIO`, optional
Specifies where to output the progress messages
(default: sys.stderr). Uses `file.write(str)` and `file.flush()`
methods. For encoding, see `write_bytes`.
ncols : int, optional
The width of the entire output message. If specified,
dynamically resizes the progressbar to stay within this bound.
If unspecified, attempts to use environment width. The
fallback is a meter width of 10 and no limit for the counter and
statistics. If 0, will not print any meter (only stats).
mininterval : float, optional
Minimum progress display update interval [default: 0.1] seconds.
maxinterval : float, optional
Maximum progress display update interval [default: 10] seconds.
Automatically adjusts `miniters` to correspond to `mininterval`
after long display update lag. Only works if `dynamic_miniters`
or monitor thread is enabled.
miniters : int or float, optional
Minimum progress display update interval, in iterations.
If 0 and `dynamic_miniters`, will automatically adjust to equal
`mininterval` (more CPU efficient, good for tight loops).
If > 0, will skip display of specified number of iterations.
Tweak this and `mininterval` to get very efficient loops.
If your progress is erratic with both fast and slow iterations
(network, skipping items, etc) you should set miniters=1.
ascii : bool or str, optional
If unspecified or False, use unicode (smooth blocks) to fill
the meter. The fallback is to use ASCII characters " 123456789#".
disable : bool, optional
Whether to disable the entire progressbar wrapper
[default: False]. If set to None, disable on non-TTY.
unit : str, optional
String that will be used to define the unit of each iteration
[default: it].
unit_scale : bool or int or float, optional
If 1 or True, the number of iterations will be reduced/scaled
automatically and a metric prefix following the
International System of Units standard will be added
(kilo, mega, etc.) [default: False]. If any other non-zero
number, will scale `total` and `n`.
dynamic_ncols : bool, optional
If set, constantly alters `ncols` and `nrows` to the
environment (allowing for window resizes) [default: False].
smoothing : float, optional
Exponential moving average smoothing factor for speed estimates
(ignored in GUI mode). Ranges from 0 (average speed) to 1
(current/instantaneous speed) [default: 0.3].
bar_format : str, optional
Specify a custom bar string formatting. May impact performance.
[default: '{l_bar}{bar}{r_bar}'], where
l_bar='{desc}: {percentage:3.0f}%|' and
r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, '
'{rate_fmt}{postfix}]'
Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt,
percentage, elapsed, elapsed_s, ncols, nrows, desc, unit,
rate, rate_fmt, rate_noinv, rate_noinv_fmt,
rate_inv, rate_inv_fmt, postfix, unit_divisor,
remaining, remaining_s, eta.
Note that a trailing ": " is automatically removed after {desc}
if the latter is empty.
initial : int or float, optional
The initial counter value. Useful when restarting a progress
bar [default: 0]. If using float, consider specifying {n:.3f}
or similar in `bar_format`, or specifying `unit_scale`.
position : int, optional
Specify the line offset to print this bar (starting from 0)
Automatic if unspecified.
Useful to manage multiple bars at once (eg, from threads).
postfix : dict or *, optional
Specify additional stats to display at the end of the bar.
Calls `set_postfix(**postfix)` if possible (dict).
unit_divisor : float, optional
[default: 1000], ignored unless `unit_scale` is True.
write_bytes : bool, optional
Whether to write bytes. If (default: False) will write unicode.
lock_args : tuple, optional
Passed to `refresh` for intermediate output
(initialisation, iterating, and updating).
nrows : int, optional
The screen height. If specified, hides nested bars outside this
bound. If unspecified, attempts to use environment height.
The fallback is 20.
colour : str, optional
Bar colour (e.g. 'green', '#00ff00').
delay : float, optional
Don't display until [default: 0] seconds have elapsed.
gui : bool, optional
WARNING: internal parameter - do not use.
Use tqdm.gui.tqdm(...) instead. If set, will attempt to use
matplotlib animations for a graphical output [default: False].
Returns
out : decorated iterator.
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """Consume `x` inside a `SpinnerContext` and return its items as a list.

    Fallback used when no real progress bar is available. Accepts the same
    keyword arguments a tqdm-style progress bar would get and keeps only the
    ones `SpinnerContext.__init__` allows.

    # Parameters:
     - `x : Iterable`
       iterable to consume while the spinner is shown
     - `**kwargs`
       progress bar keyword arguments. `desc` is used as the spinner
       `message` if no `message` is given; failing that, `total` is used to
       build a `"Processing {total} items"` message.

    # Returns:
     - `List`
       all items of `x`, fully materialized (the iterable is exhausted
       before this function returns)
    """
    # look the allowed names up once instead of once per candidate key
    # (assumes `get_fn_allowed_kwargs` returns a reusable container of names)
    allowed_kwargs = get_fn_allowed_kwargs(SpinnerContext.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in allowed_kwargs}

    # translate the tqdm-style `desc` into the spinner's `message`
    if "desc" in kwargs and "message" not in mapped_kwargs:
        mapped_kwargs["message"] = kwargs.get("desc")

    # still no message: fall back to a generic one based on `total`
    if "message" not in mapped_kwargs and "total" in kwargs:
        mapped_kwargs["message"] = f"Processing {kwargs.get('total')} items"

    with SpinnerContext(**mapped_kwargs):
        output = list(x)

    return output
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """No-op progress bar: hand `x` back untouched.

    Any keyword arguments are accepted (so this can stand in for a real
    progress bar function) and ignored.
    """
    passthrough = x
    return passthrough
fallback to no progress bar
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> ProgressBarFunction:
    """Resolve `pbar` into a callable that wraps an iterable in a progress bar.

    # Parameters:
     - `pbar : ProgressBarFunction | ProgressBarOption`
       `None` or `"none"` for no progress bar, `"tqdm"` or `"spinner"` to
       pick one by name, or any callable taking an iterable plus keyword
       arguments
     - `pbar_kwargs : Dict[str, Any] | None`
       keyword arguments bound to the progress bar function. A truthy
       `"disable"` entry turns the progress bar off entirely.
     - `**extra_kwargs`
       additional keyword arguments; entries in `pbar_kwargs` win on conflict

    # Returns:
     - `ProgressBarFunction`
       `no_progress_fn_wrap` if disabled, otherwise a `functools.partial`
       of the chosen function with the merged keyword arguments bound

    # Raises:
     - `ValueError` : if `pbar` is a string other than `"tqdm"`, `"spinner"` or `"none"`
     - `ImportError` : if `pbar == "tqdm"` but the `tqdm` package is not installed
    """
    # explicit `pbar_kwargs` take precedence over `extra_kwargs`
    merged_kwargs: Dict[str, Any] = {**extra_kwargs, **(pbar_kwargs or {})}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or merged_kwargs.get("disable", False):
        return no_progress_fn_wrap  # type: ignore[return-value]

    # if `pbar` is a different string, figure out which progress bar to use
    if isinstance(pbar, str):
        if pbar == "tqdm":
            # import locally: the module-level `tqdm` name is unbound when the
            # optional import failed, which would surface as a `NameError`
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package, which is not installed -- install it with `pip install tqdm` or use `pbar='spinner'`"
                ) from e
            return functools.partial(tqdm.tqdm, **merged_kwargs)
        if pbar == "spinner":
            return functools.partial(spinner_fn_wrap, **merged_kwargs)
        raise ValueError(
            f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
        )

    # the default value is a callable which will resolve to tqdm if available or spinner as a fallback. we pass kwargs to this
    return functools.partial(pbar, **merged_kwargs)
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if that works out to a single process, the work is done in serial.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. if it has no `len()`
       (e.g. a generator) it is materialized into a list first
     - `parallel : bool | int`
       whether to run in parallel, and with how many processes (see above)
     - `pbar_kwargs : Dict[str, Any] | None`
       keyword arguments passed on to the progress bar function
       (`total` is supplied automatically)
     - `chunksize : int | None`
       chunk size passed to `Pool.imap` / `Pool.imap_unordered`.
       defaults to `max(1, n_inputs // num_processes)`
     - `keep_ordered : bool`
       if `True` use `Pool.imap`, otherwise `Pool.imap_unordered`.
       only relevant when running in parallel
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing`.
       requires `parallel` to be truthy
     - `pbar : ProgressBarFunction | ProgressBarOption`
       which progress bar to use, see `set_up_progress_bar_fn`

    # Returns:
     - `List[OutputType]`
       the results of applying `func` to every element of `iterable`

    # Raises:
     - `ValueError` : if `parallel` is an `int` smaller than 2, is neither a `bool` nor an `int`,
       or if `use_multiprocess=True` is passed together with `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not installed
    """

    # number of inputs in iterable -- unsized iterables are materialized so
    # that we can size the progress bar and the pool
    n_inputs: int
    try:
        n_inputs = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        iterable = list(iterable)
        n_inputs = len(iterable)
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # what the caller asked for, before clamping: a request for parallelism
    # that gets reduced to one process (single input / single cpu) must not
    # be reported as "`parallel` was not set"
    parallel_requested: bool = bool(parallel)

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    run_parallel: bool = num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        if not parallel_requested:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # serial case: plain `map` under the progress bar
    if not run_parallel:
        return list(pbar_fn(map(func, iterable)))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered

        # run the map function with a progress bar
        output: List[OutputType] = list(
            pbar_fn(pool_map(func, iterable, chunksize=chunksize_int))
        )
        pool.close()
    except BaseException:
        # don't leave workers running (or wait on remaining tasks) on failure
        pool.terminate()
        raise
    finally:
        # `join` is valid after either `close` or `terminate`
        pool.join()

    # return the output as a list
    return output
a function to make it easier to sometimes parallelize an operation
- if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
- if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
- if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`
the maximum number of processes is given by the min(len(iterable), multiprocessing.cpu_count())
Parameters:
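func : Callable[[InputType], OutputType]
function passed to either `map` or `Pool.imap`
iterable : Iterable[InputType]
iterable passed to either `map` or `Pool.imap`
parallel : bool | int
whether to run in parallel, and with how many processes (see the list above)
pbar_kwargs : Dict[str, Any]
keyword arguments passed on to the progress bar function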
Returns:
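List[OutputType]
the results of applying `func` to every element of `iterable`
Raises:
ValueError
: if `parallel` is an integer smaller than 2 or is neither a boolean nor an integer, or if `use_multiprocess=True` is requested without parallel execution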