Coverage for muutils\dbg.py: 99%
72 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-14 01:33 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-14 01:33 -0700
"""

an implementation of the Rust built-in `dbg!` macro for Python, originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

"""
30from __future__ import annotations
32import os
33import inspect
34import sys
35import typing
36from pathlib import Path
37import functools
# type defs
# generic type of the value handed to `dbg`, so the overloads can return it unchanged
_ExpType = typing.TypeVar("_ExpType")


class _NoExpPassedSentinel:
    """Marker type meaning "`dbg` was called without an expression".

    Using a dedicated class instead of `None` means a call like `dbg(None)`
    can still be told apart from a bare `dbg()` call (checked by identity).
    """


# the one instance used as the default value of `dbg`'s `exp` parameter
_NoExpPassed = _NoExpPassedSentinel()
# global variables
_CWD: Path = Path.cwd().absolute()
_COUNTER: int = 0

# configuration
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"


# path processing
def _process_path(path: Path) -> str:
    """Format `path` for display according to the module-level `PATH_MODE`.

    # Parameters:
     - `path : Path`
        path to format; it is made absolute (against the process cwd) first

    # Returns:
     - `str`
        POSIX-style path string:
        - `"absolute"` mode: the absolute path
        - `"relative"` mode: the path relative to the deepest directory shared
          by `path` and `_CWD` (the cwd captured at import time). If the two
          share no common path (e.g. different drives on Windows), the
          absolute path is returned instead.

    # Raises:
     - `ValueError` : if `PATH_MODE` is neither `"relative"` nor `"absolute"`
    """
    path_abs: Path = path.absolute()

    if PATH_MODE == "absolute":
        return path_abs.as_posix()

    if PATH_MODE == "relative":
        try:
            common: Path = Path(os.path.commonpath([path_abs, _CWD]))
            return path_abs.relative_to(common).as_posix()
        except ValueError:
            # `commonpath` raises when the paths share no root (e.g. mixed drives);
            # fall back to showing the absolute path
            return path_abs.as_posix()

    raise ValueError(
        f"PATH_MODE must be either 'relative' or 'absolute', got {PATH_MODE!r}"
    )
# actual dbg function
def _extract_call_args(line: str) -> str:
    """Return the source text inside the parentheses of the `dbg` call on `line`.

    Finds the first `(` after the first occurrence of `dbg`, then scans forward
    counting `()`, `[]` and `{}` nesting to locate its *matching* closing
    parenthesis. This way `print(dbg(a))` yields `a` and `dbg(a) + f(b)` yields
    `a`, rather than the span from the line's first `(` to its last `)`.
    Brackets inside string literals are not special-cased.

    If there is no `(` after `dbg`, the stripped line is returned. If the call
    is not closed on this line (e.g. a multi-line call), the rest of the line
    after the `(` is returned, with trailing whitespace removed.
    """
    name_idx: int = line.find("dbg")
    open_idx: int = line.find("(", max(name_idx, 0))
    if open_idx == -1:
        return line.strip()

    depth: int = 0
    for idx in range(open_idx, len(line)):
        char: str = line[idx]
        if char in "([{":
            depth += 1
        elif char in ")]}":
            depth -= 1
            if depth == 0:
                return line[open_idx + 1 : idx]
    return line[open_idx + 1 :].rstrip()


@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType, formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    The location and expression text come from the first frame on the call
    stack whose current source line contains `dbg`. If no such frame has source
    available, both are reported as `unknown`.

    # Parameters:
     - `exp`
        value to print and return. If omitted, only the location and a
        module-wide counter are printed, and the counter is incremented.
     - `formatter : typing.Optional[typing.Callable[[typing.Any], str]]`
        renders `exp` for printing; `repr` is used when `None`

    # Returns:
     `exp` unchanged, so `dbg(...)` can wrap an expression inline. A bare
     `dbg()` returns the `_NoExpPassed` sentinel.
    """
    global _COUNTER

    # get the context: scan outward through the stack and treat the first source
    # line mentioning "dbg" as the call site. This frame's own current line
    # (the `for` below) deliberately does not contain that text.
    fname: str = "unknown"
    line_exp: str = "unknown"
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            fname = f"{_process_path(Path(frame.filename))}:{frame.lineno}"
            line_exp = _extract_call_args(line)
            break

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] (dbg {_COUNTER})"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter is not None else repr(exp)
        msg = f"[ {fname} ] {line_exp} = {exp_val}"

    # print the message
    print(msg, file=sys.stderr)

    # return the expression itself
    return exp
# formatted `dbg_*` functions with their helpers
def tensor_info_dict(tensor: typing.Any) -> typing.Dict[str, str]:
    """Collect summary attributes of a tensor-like object as `repr` strings.

    Attributes are looked up by duck typing (`hasattr`), so any object exposing
    some of `shape` / `sum` / `dtype` / `device` / `requires_grad` works.
    Presumably this is aimed at torch tensors and numpy arrays.

    # Parameters:
     - `tensor : typing.Any`
        object to summarize

    # Returns:
     - `typing.Dict[str, str]`
        possibly-empty mapping, in insertion order, with keys drawn from
        `shape`, `sum`, `dtype`, `device`, `requires_grad`. `sum` is only
        included when the sum is nan or +/-inf, as a quick flag for bad values.
    """
    output: typing.Dict[str, str] = dict()

    # shape
    if hasattr(tensor, "shape"):
        output["shape"] = repr(tuple(tensor.shape))

    # report the sum only if it is nan or +/-inf
    if hasattr(tensor, "sum"):
        try:
            total = tensor.sum()
            nonfinite: bool = bool(
                total != total  # nan is the only value not equal to itself
                or total == float("inf")
                or total == float("-inf")
            )
        except (TypeError, ValueError, RuntimeError):
            # summing/comparing is undefined here (e.g. string data, or a
            # non-scalar result with ambiguous truth value); this is a debug
            # helper, so skip the field rather than crash the print
            nonfinite = False
        if nonfinite:
            output["sum"] = repr(total)

    # more info
    if hasattr(tensor, "dtype"):
        output["dtype"] = repr(tensor.dtype)
    if hasattr(tensor, "device"):
        output["device"] = repr(tensor.device)
    if hasattr(tensor, "requires_grad"):
        output["requires_grad"] = repr(tensor.requires_grad)

    return output
def tensor_info(tensor: typing.Any) -> str:
    """Render `tensor_info_dict(tensor)` as one `key=value, key=value` string.

    Used as the `formatter` for `dbg_tensor`. Entries keep the order in which
    `tensor_info_dict` inserted them.
    """
    pairs: typing.List[str] = [
        f"{key}={value}" for key, value in tensor_info_dict(tensor).items()
    ]
    return ", ".join(pairs)
# `dbg` preconfigured to print a tensor summary (via `tensor_info`) instead of `repr`.
# Kept as a `functools.partial` rather than a `def` wrapper on purpose: `dbg` finds
# its call site by scanning the stack for the first source line containing "dbg".
# In CPython, `partial` is implemented in C and adds no Python frame, so the
# caller's line is reported. A wrapper function's own body line would contain
# "dbg" and be reported instead.
dbg_tensor = functools.partial(dbg, formatter=tensor_info)