Coverage for muutils\dbg.py: 99%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-02-14 01:33 -0700

1""" 

2 

3an implementation of the Rust builtin `dbg!` for Python, originally from 

4https://github.com/tylerwince/pydbg/blob/master/pydbg.py 

5 

6licensed under MIT: 

7 

8Copyright (c) 2019 Tyler Wince 

9 

10Permission is hereby granted, free of charge, to any person obtaining a copy 

11of this software and associated documentation files (the "Software"), to deal 

12in the Software without restriction, including without limitation the rights 

13to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

14copies of the Software, and to permit persons to whom the Software is 

15furnished to do so, subject to the following conditions: 

16 

17The above copyright notice and this permission notice shall be included in 

18all copies or substantial portions of the Software. 

19 

20THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

21IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

22FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

23AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

24LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

25OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

26THE SOFTWARE. 

27 

28""" 

29 

30from __future__ import annotations 

31 

32import os 

33import inspect 

34import sys 

35import typing 

36from pathlib import Path 

37import functools 

38 

# type variable for the value handed to `dbg` and returned from it unchanged
_ExpType = typing.TypeVar("_ExpType")


class _NoExpPassedSentinel:
    """Marker type whose instance `_NoExpPassed` (below) means "no expression was given"."""


# sentinel default for `dbg`'s `exp` parameter; `dbg` checks it by identity (`is`)
_NoExpPassed = _NoExpPassedSentinel()

51 

# configuration: how `_process_path` renders file names
# ("relative" strips the prefix shared with `_CWD`, "absolute" keeps the full path)
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"

# module-level state
_COUNTER: int = 0  # how many times `dbg()` has been called with no expression
_CWD: Path = Path.cwd().absolute()  # working directory, captured once at import time

58 

59 

# path processing
def _process_path(path: Path) -> str:
    """Render `path` as a forward-slash string according to the global `PATH_MODE`.

    - ``"absolute"``: the absolute path of `path`.
    - ``"relative"``: the absolute path with the prefix it shares with `_CWD`
      removed. This is relative to the *common ancestor* of the file and
      `_CWD`, not strictly to `_CWD`: with `_CWD = /a/b`, the file `/a/c/x.py`
      is rendered as ``c/x.py``. If no common path can be computed, the
      absolute path is used instead.

    # Parameters:
     - `path : Path`
       the file path to render; relative paths are resolved with `Path.absolute`

    # Returns:
     - `str`
       the rendered path (via `Path.as_posix`, so always using ``/``)

    # Raises:
     - `ValueError` : if `PATH_MODE` is neither ``"relative"`` nor ``"absolute"``
    """
    path_abs: Path = path.absolute()
    fname: str
    if PATH_MODE == "absolute":
        fname = path_abs.as_posix()
    elif PATH_MODE == "relative":
        try:
            fname = path_abs.relative_to(
                Path(os.path.commonpath([path_abs, _CWD]))
            ).as_posix()
        except ValueError:
            # `os.path.commonpath` raises ValueError when there is no common path
            # (e.g. the paths are on different drives on Windows)
            fname = path_abs.as_posix()
    else:
        # previously the message had an unbalanced quote ('absolute without closing ')
        raise ValueError(
            f"PATH_MODE must be either 'relative' or 'absolute', got {PATH_MODE!r}"
        )

    return fname

76 

77 

def _call_args_source(line: str) -> str:
    """Return the argument source text of the `dbg`-style call on `line`.

    Locates the first ``(`` at or after the first occurrence of ``"dbg"`` and
    returns the text up to its matching ``)``, stripped of surrounding
    whitespace. Brackets inside single-line string literals are ignored, and
    scanning stops at a ``#`` comment. If the call is not closed on this line
    (e.g. it continues on the next line), everything after the ``(`` is
    returned. If there is no such ``(``, the whole stripped line is returned.
    """
    name_at: int = line.find("dbg")
    open_at: int = line.find("(", max(name_at, 0))
    if open_at == -1:
        return line.strip()

    end: int = len(line)
    depth: int = 0
    quote: typing.Optional[str] = None
    escaped: bool = False
    for idx in range(open_at + 1, len(line)):
        char: str = line[idx]
        if quote is not None:
            # inside a string literal: only watch for its unescaped closing quote
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
        elif char in "\"'":
            quote = char
        elif char == "#":
            # start of a trailing comment: the call was not closed before it
            end = idx
            break
        elif char in "([{":
            depth += 1
        elif char in ")]}":
            if depth == 0:
                # the bracket matching the call's opening parenthesis
                end = idx
                break
            depth -= 1

    return line[open_at + 1 : end].strip()


# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType, formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    The printed line has the form ``[ <file>:<lineno> ] <expression> = <value>``,
    where ``<file>`` is rendered by `_process_path` according to `PATH_MODE`.
    Called with no argument, it prints ``[ <file>:<lineno> ] (dbg <n>)``.
    ``<n>`` is a module-level counter that increases by one on each such call.

    The location and expression text come from the first stack frame (starting
    at the caller) whose current source line contains ``"dbg"``. If none is
    found, ``unknown`` is shown. Only that one source line is read, so a call
    spanning several lines may show a truncated expression.

    # Parameters:
     - `exp`
       the value to show and return; omit it to print only the location
     - `formatter : Callable[[Any], str] | None`
       renders `exp` for display; `repr` is used when this is `None`

    # Returns:
     - `exp`, unchanged (the `_NoExpPassed` sentinel when called with no argument)
    """
    global _COUNTER

    # get the context: walk outward from the caller, lazily reading one source
    # line per frame, and stop at the first line mentioning "dbg"
    fname: str = "unknown"
    line_exp: str = "unknown"
    # `inspect.currentframe` may return None on implementations without frame support
    current = inspect.currentframe()
    frame = current.f_back if current is not None else None
    try:
        while frame is not None:
            info = inspect.getframeinfo(frame, context=1)
            if info.code_context and "dbg" in info.code_context[0]:
                fname = f"{_process_path(Path(info.filename))}:{info.lineno}"
                line_exp = _call_args_source(info.code_context[0])
                break
            frame = frame.f_back
    finally:
        # drop frame references promptly so they cannot form reference cycles
        del current, frame

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] (dbg {_COUNTER})"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter is not None else repr(exp)
        msg = f"[ {fname} ] {line_exp} = {exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp

151 

152 

# formatted `dbg_*` functions with their helpers
def tensor_info_dict(tensor: typing.Any) -> typing.Dict[str, str]:
    """Collect short, ``repr``-formatted facts about a tensor-like object.

    Each entry is added only when `tensor` has the corresponding attribute,
    in this order:
     - ``"shape"``: ``repr(tuple(tensor.shape))``
     - ``"sum"``: ``repr(tensor.sum())``, only when that sum is NaN, +inf or -inf
     - ``"dtype"``, ``"device"``, ``"requires_grad"``: ``repr`` of the attribute

    Any object is accepted; one with none of these attributes gives ``{}``.

    # Parameters:
     - `tensor : Any`
       the object to describe (duck-typed; no particular library is required)

    # Returns:
     - `Dict[str, str]`
       mapping of attribute name to its ``repr`` string
    """
    output: typing.Dict[str, str] = dict()

    # shape
    if hasattr(tensor, "shape"):
        output["shape"] = repr(tuple(tensor.shape))

    # sum, reported only when it is non-finite: a NaN/inf sum means some element
    # is NaN/inf (or the sum overflowed), which is worth flagging while debugging
    if hasattr(tensor, "sum"):
        total = tensor.sum()
        try:
            # `x != x` holds only for NaN; the `abs` comparison catches +inf and -inf
            # (previously only NaN was detected, despite the stated intent)
            non_finite: bool = bool(total != total) or bool(
                abs(total) == float("inf")
            )
        except (TypeError, ValueError, RuntimeError):
            # `sum()` returned something that cannot be checked this way (e.g. a
            # non-numeric value, or a multi-element result with an ambiguous
            # truth value); skip the check instead of failing inside a debug helper
            non_finite = False
        if non_finite:
            output["sum"] = repr(total)

    # more info
    if hasattr(tensor, "dtype"):
        output["dtype"] = repr(tensor.dtype)
    if hasattr(tensor, "device"):
        output["device"] = repr(tensor.device)
    if hasattr(tensor, "requires_grad"):
        output["requires_grad"] = repr(tensor.requires_grad)

    return output

178 

179 

def tensor_info(tensor: typing.Any) -> str:
    """Render `tensor_info_dict(tensor)` as ``"key=value"`` pairs joined by ``", "``.

    Produces an empty string when no entries were collected.
    """
    # every key and value is a str, so joining each (key, value) pair with "="
    # yields exactly "key=value"
    pairs = tensor_info_dict(tensor).items()
    return ", ".join("=".join(pair) for pair in pairs)

183 

184 

# `dbg` preconfigured with `tensor_info` as its formatter: it prints the call
# location and argument text followed by the tensor summary (shape, dtype, ...)
# instead of the value's full `repr`, and still returns the value unchanged.
# NOTE: kept as a `functools.partial` rather than a `def` wrapper on purpose.
# `dbg` reports the first stack frame whose current line contains "dbg", and a
# Python-level wrapper would put its own frame (whose line calls `dbg`) in front
# of the real call site. In CPython, `functools.partial` is implemented in C and
# adds no Python frame.
dbg_tensor = functools.partial(dbg, formatter=tensor_info)