Coverage for muutils\dbg.py: 95%
75 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-05 19:24 -0700
1"""
an implementation of the Rust builtin `dbg!` for Python, originally from
4https://github.com/tylerwince/pydbg/blob/master/pydbg.py
6licensed under MIT:
8Copyright (c) 2019 Tyler Wince
10Permission is hereby granted, free of charge, to any person obtaining a copy
11of this software and associated documentation files (the "Software"), to deal
12in the Software without restriction, including without limitation the rights
13to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14copies of the Software, and to permit persons to whom the Software is
15furnished to do so, subject to the following conditions:
17The above copyright notice and this permission notice shall be included in
18all copies or substantial portions of the Software.
20THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
26THE SOFTWARE.
28"""
30from __future__ import annotations
32import os
33import inspect
34import sys
35import typing
36from pathlib import Path
37import functools
# type variable for the value passed through `dbg` unchanged
_ExpType = typing.TypeVar("_ExpType")


class _NoExpPassedSentinel:
    """Marker type; its single instance `_NoExpPassed` stands for "no expression given to `dbg`"."""


# the one sentinel instance, used as the default for `dbg`'s `exp` argument
_NoExpPassed = _NoExpPassedSentinel()
# global variables
# number of times `dbg()` has been called without an expression
_COUNTER: int = 0
# working directory, captured once when this module is imported;
# `_process_path` uses it as the reference directory in "relative" mode
_CWD: Path = Path.cwd().absolute()

# configuration
# how `dbg` renders source file paths (see `_process_path`)
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
# path processing
def _process_path(path: Path) -> str:
    """Render `path` as a POSIX-style string according to the global `PATH_MODE`.

    - ``"absolute"``: the absolute path.
    - ``"relative"``: the path relative to the deepest directory shared by
      `path` and `_CWD` (for files under `_CWD`, that is simply the path
      relative to `_CWD`). If computing that common directory or the relative
      path raises `ValueError` (e.g. `os.path.commonpath` on paths it cannot
      combine), the absolute path is returned instead.

    Raises `ValueError` if `PATH_MODE` is neither ``"relative"`` nor ``"absolute"``.
    """
    path_abs: Path = path.absolute()

    if PATH_MODE == "absolute":
        return path_abs.as_posix()

    if PATH_MODE == "relative":
        try:
            common_root = Path(os.path.commonpath([path_abs, _CWD]))
            return path_abs.relative_to(common_root).as_posix()
        except ValueError:
            # best effort: fall back to the absolute path rather than failing `dbg`
            return path_abs.as_posix()

    raise ValueError(
        f"PATH_MODE must be either 'relative' or 'absolute', got {PATH_MODE!r}"
    )
# actual dbg function
def _extract_call_args(line: str) -> str:
    """Return the argument text of the `dbg`-style call on one source line.

    Looks for the first call whose name starts with ``dbg`` at a word boundary
    (``dbg(``, ``dbg_tensor(``, ``module.dbg(``) and returns the text between
    its opening parenthesis and the matching closing one, with surrounding
    whitespace stripped. Parentheses inside simple quoted strings are ignored,
    and scanning stops at a ``#`` comment. If the call is not closed on this
    line, the remainder of the line (up to any comment) is returned. If no such
    call is found, the text between the first ``(`` and the last ``)`` of the
    line is returned.
    """
    import re

    match = re.search(r"\bdbg\w*\s*\(", line)
    if match is None:
        # no recognizable call: keep the coarse first-"(" / last-")" heuristic
        first_open = line.find("(") + 1
        last_close = line.rfind(")")
        if last_close == -1:
            last_close = len(line)
        return line[first_open:last_close].strip()

    start = match.end()
    depth = 1
    quote: typing.Optional[str] = None
    idx = start
    while idx < len(line):
        char = line[idx]
        if quote is not None:
            if char == "\\":
                idx += 1  # skip the escaped character
            elif char == quote:
                quote = None
        elif char in ("'", '"'):
            quote = char
        elif char == "#":
            break  # the rest of the line is a comment
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return line[start:idx].strip()
        idx += 1

    # call continues on a later line: report what is visible on this one
    return line[start:idx].strip()


def _find_call_site() -> typing.Tuple[str, str]:
    """Locate the source line that invoked `dbg`.

    Starts at the caller of `dbg` (skipping this helper's frame and `dbg`'s own
    frame) and walks outward to the first frame whose current source line
    contains ``"dbg"``.

    Returns ``(location, expression)``: location is ``"<path>:<lineno>"`` with
    the path rendered by `_process_path`, and expression is the call's argument
    text. Both are ``"unknown"`` when no matching frame is found or frame
    introspection is unavailable.
    """
    frame = inspect.currentframe()
    try:
        # skip this helper's own frame and the frame of `dbg` itself
        for _ in range(2):
            frame = frame.f_back if frame is not None else None

        while frame is not None:
            info = inspect.getframeinfo(frame, context=1)
            # `code_context` is None (or empty) when the source line is unavailable
            if info.code_context:
                line: str = info.code_context[0]
                if "dbg" in line:
                    location = f"{_process_path(Path(info.filename))}:{info.lineno}"
                    return location, _extract_call_args(line)
            frame = frame.f_back
    finally:
        # drop the frame reference so frames/locals are released promptly
        del frame

    return "unknown", "unknown"


@typing.overload
def dbg() -> _NoExpPassedSentinel: ...


@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> _NoExpPassedSentinel: ...


@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> _ExpType: ...


def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Print the call site, the expression passed, and its value to stderr; return the value.

    Modeled on Rust's `dbg!`. With an argument, prints
    ``[ <file>:<line> ] <expression source> = <value>``, where the expression
    source is read from the calling line of code and the value is
    ``repr(exp)``, or ``formatter(exp)`` if a formatter is given. Without an
    argument, prints ``[ <file>:<line> ] (dbg <n>)`` using a global counter
    that increments on each argument-less call.

    Example:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    Parameters:
     - `exp`: value to print and return; omit it to print only the location and counter
     - `formatter`: optional callable turning `exp` into the printed string (default: `repr`)

    Returns:
     - `exp` unchanged (the sentinel when called without an argument), so `dbg`
       can wrap an expression inline
    """
    global _COUNTER

    fname, line_exp = _find_call_site()

    msg: str
    if exp is _NoExpPassed:
        # no expression: show location and counter value only
        msg = f"[ {fname} ] (dbg {_COUNTER})"
        _COUNTER += 1
    else:
        exp_val: str = repr(exp) if formatter is None else formatter(exp)
        msg = f"[ {fname} ] {line_exp} = {exp_val}"

    print(msg, file=sys.stderr)

    return exp
# formatted `dbg_*` functions with their helpers
def _is_nan_or_inf(value: typing.Any) -> bool:
    """Return True if `value` is NaN or +/-infinity; False for anything else.

    Uses only `!=`, `abs` and `==` on the value, so scalar results of
    array/tensor `.sum()` calls can be checked as well as plain floats.
    """
    # NaN is the only value that compares unequal to itself
    if value != value:
        return True
    try:
        return bool(abs(value) == float("inf"))
    except TypeError:
        # not numeric (no `abs`, or not comparable with a float): nothing to flag
        return False


def tensor_info_dict(tensor: typing.Any) -> typing.Dict[str, str]:
    """Collect short, printable metadata about a tensor-like object.

    Each attribute is probed with `hasattr`, so any object is accepted; only
    the attributes that exist are reported. All values are `repr` strings.

    Keys, in this order when present:
     - ``shape``: ``repr(tuple(tensor.shape))``
     - ``sum``: ``repr(tensor.sum())``, included only when that sum is NaN or
       +/-inf, as a quick flag for bad values
     - ``dtype``, ``device``, ``requires_grad``: ``repr`` of the attribute
    """
    output: typing.Dict[str, str] = {}

    if hasattr(tensor, "shape"):
        output["shape"] = repr(tuple(tensor.shape))

    # report the sum only if it is NaN or inf
    if hasattr(tensor, "sum"):
        total = tensor.sum()
        if _is_nan_or_inf(total):
            output["sum"] = repr(total)

    for attr in ("dtype", "device", "requires_grad"):
        if hasattr(tensor, attr):
            output[attr] = repr(getattr(tensor, attr))

    return output
def tensor_info(tensor: typing.Any) -> str:
    """Render `tensor_info_dict(tensor)` as one ``key=value, key=value`` string.

    This is the formatter that `dbg_tensor` passes to `dbg`.
    """
    rendered: typing.List[str] = []
    for key, value in tensor_info_dict(tensor).items():
        rendered.append(f"{key}={value}")
    return ", ".join(rendered)
# `dbg` preconfigured with `tensor_info` as its formatter, so the value is printed
# as a compact metadata summary (shape, dtype, ...) instead of its full repr.
# NOTE(review): kept as a `functools.partial` rather than a wrapper `def`,
# presumably because `dbg` reports the nearest stack frame whose source line
# contains "dbg" -- a Python-level wrapper would add its own frame, and the
# reported location/expression would point inside the wrapper. Confirm before changing.
dbg_tensor = functools.partial(dbg, formatter=tensor_info)