Coverage for muutils\dbg.py: 95%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-02-05 19:24 -0700

1""" 

2 

3an implementation of the Rust builtin `dbg!` for Python, originally from

4https://github.com/tylerwince/pydbg/blob/master/pydbg.py 

5 

6licensed under MIT: 

7 

8Copyright (c) 2019 Tyler Wince 

9 

10Permission is hereby granted, free of charge, to any person obtaining a copy 

11of this software and associated documentation files (the "Software"), to deal 

12in the Software without restriction, including without limitation the rights 

13to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

14copies of the Software, and to permit persons to whom the Software is 

15furnished to do so, subject to the following conditions: 

16 

17The above copyright notice and this permission notice shall be included in 

18all copies or substantial portions of the Software. 

19 

20THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

21IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

22FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

23AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

24LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

25OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

26THE SOFTWARE. 

27 

28""" 

29 

30from __future__ import annotations 

31 

32import os 

33import inspect 

34import sys 

35import typing 

36from pathlib import Path 

37import functools 

38 

# type defs
# the type of whatever value is handed to `dbg`; it is returned unchanged
_ExpType = typing.TypeVar("_ExpType")


# marker for "dbg() was called without an expression"
class _NoExpPassedSentinel:
    """Marker type meaning that `dbg` received no expression.

    Using a dedicated class instead of ``None`` means that ``dbg(None)`` can
    still be told apart from a bare ``dbg()`` call (``dbg`` compares its
    argument against the single instance below with ``is``).
    """


# the one shared instance, used as `dbg`'s default argument
_NoExpPassed = _NoExpPassedSentinel()

51 

# configuration
# how `dbg` shows source file paths: "absolute" prints the full path, while
# "relative" strips the prefix the path shares with `_CWD`
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"

# global variables
_CWD: Path = Path.cwd().absolute()  # working directory, captured once at import time
_COUNTER: int = 0  # number of bare `dbg()` calls made so far

58 

59 

# path processing
def _process_path(path: Path) -> str:
    """Format ``path`` for display in `dbg` output, according to ``PATH_MODE``.

    The path is first made absolute with ``Path.absolute`` (symlinks and ``..``
    components are not resolved). Then:

    - ``"absolute"``: the absolute path in POSIX form.
    - ``"relative"``: the absolute path with the prefix it shares with ``_CWD``
      removed, in POSIX form. This is relative to the *common ancestor* of the
      path and ``_CWD``, not to ``_CWD`` itself, so a file outside the working
      directory is shown without any ``..`` components. If no common path can
      be computed (``os.path.commonpath`` raises ``ValueError``, e.g. for paths
      on different Windows drives), the absolute path is used instead.

    Args:
        path: the file path to format.

    Returns:
        The formatted path string.

    Raises:
        ValueError: if ``PATH_MODE`` is neither ``"relative"`` nor ``"absolute"``.
    """
    path_abs: Path = path.absolute()

    if PATH_MODE == "absolute":
        return path_abs.as_posix()

    if PATH_MODE == "relative":
        try:
            common: Path = Path(os.path.commonpath([path_abs, _CWD]))
            return path_abs.relative_to(common).as_posix()
        except ValueError:
            # no usable common prefix: fall back to the absolute path
            return path_abs.as_posix()

    raise ValueError(
        f"PATH_MODE must be either 'relative' or 'absolute', got {PATH_MODE!r}"
    )

76 

77 

# expression-text extraction for `dbg`
def _call_arguments_text(line: str) -> str:
    """Return the source text between the parentheses of the `dbg` call on ``line``.

    Scanning starts just after the first ``(`` that follows the first
    occurrence of ``"dbg"`` and stops at the matching ``)``. So an enclosing
    call is excluded (``str(dbg(x))`` gives ``x``, not ``dbg(x)``), and so is a
    trailing comment that contains parentheses. Parentheses inside string
    literals are not understood. If the matching ``)`` is not on this line
    (e.g. a call spanning several lines), the rest of the line is used. The
    result is stripped of surrounding whitespace, including the newline.
    """
    open_idx: int = line.find("(", max(line.find("dbg"), 0))
    start: int = open_idx + 1  # 0 when the line has no "(" at all
    depth: int = 1
    for idx, char in enumerate(line[start:], start=start):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return line[start:idx].strip()
    return line[start:].strip()


# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

        from muutils.dbg import dbg

        a = 2
        b = 5

        dbg(a+b)

        def square(x: int) -> int:
            return x * x

        dbg(square(a))

    With an expression, the printed line is ``[ <file>:<line> ] <expr> = <value>``.
    The file part is formatted by ``_process_path`` (see ``PATH_MODE``).
    With no argument, the line is ``[ <file>:<line> ] (dbg <n>)``, where ``<n>``
    counts the bare ``dbg()`` calls made so far.

    The call site is the nearest frame (other than dbg's own) whose current
    source line contains the text ``"dbg"``. The expression text is read from
    that single line, so a call spanning several lines is shown truncated.
    If no such frame is found, both location and expression are ``"unknown"``.

    Args:
        exp: the value to show and return. Leave it unset for a location-only marker.
        formatter: turns ``exp`` into the displayed string. Defaults to ``repr``.

    Returns:
        ``exp`` unchanged (the ``_NoExpPassed`` sentinel for a bare call), so
        dbg can wrap any expression in place.
    """
    global _COUNTER

    # get the context: skip dbg's own frame, then take the first frame whose
    # current line mentions "dbg"
    fname: str = "unknown"
    line_exp: str = "unknown"
    for frame_info in inspect.stack()[1:]:
        # code_context is None (or empty) when the source line is unavailable
        if not frame_info.code_context:
            continue
        line: str = frame_info.code_context[0]
        if "dbg" in line:
            fname = f"{_process_path(Path(frame_info.filename))}:{frame_info.lineno}"
            line_exp = _call_arguments_text(line)
            break

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # no expression: show only location and counter value
        msg = f"[ {fname} ] (dbg {_COUNTER})"
        _COUNTER += 1
    else:
        # expression passed: show location, expression text, and formatted value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp} = {exp_val}"

    print(msg, file=sys.stderr)

    # return the expression itself so dbg can be used inline
    return exp

149 

150 

# formatted `dbg_*` functions with their helpers
def tensor_info_dict(tensor: typing.Any) -> typing.Dict[str, str]:
    """Collect short ``repr`` strings describing a tensor-like object.

    Each attribute is probed with ``hasattr``, so any object is accepted. Only
    the attributes it actually has appear in the result, with keys in this order:

    - ``"shape"``: ``repr(tuple(tensor.shape))``
    - ``"sum"``: ``repr(tensor.sum())``, included only when that sum is NaN,
      ``inf`` or ``-inf``. A non-finite sum means the object holds non-finite
      values, or that the sum overflowed.
    - ``"dtype"``, ``"device"``, ``"requires_grad"``: ``repr`` of the attribute

    The NaN/inf check is best-effort. The ``"sum"`` entry is omitted if
    ``tensor.sum()`` raises ``TypeError``, ``ValueError`` or ``RuntimeError``,
    or if its result has no single truth value (e.g. ``sum()`` needs
    arguments, or returns a non-scalar).

    Args:
        tensor: any object; typically a numpy array or torch tensor.

    Returns:
        An ordered mapping from attribute name to its ``repr``.
    """
    output: typing.Dict[str, str] = {}

    if hasattr(tensor, "shape"):
        output["shape"] = repr(tuple(tensor.shape))

    # report the sum only when it is NaN or +/-inf
    if hasattr(tensor, "sum"):
        try:
            total = tensor.sum()
            # NaN is the only value that is not equal to itself
            is_non_finite: bool = bool(
                total != total or total == float("inf") or total == float("-inf")
            )
        except (TypeError, ValueError, RuntimeError):
            # sum() needs arguments or has no single truth value: skip this entry
            pass
        else:
            if is_non_finite:
                output["sum"] = repr(total)

    if hasattr(tensor, "dtype"):
        output["dtype"] = repr(tensor.dtype)
    if hasattr(tensor, "device"):
        output["device"] = repr(tensor.device)
    if hasattr(tensor, "requires_grad"):
        output["requires_grad"] = repr(tensor.requires_grad)

    return output

176 

177 

def tensor_info(tensor: typing.Any) -> str:
    """Render ``tensor_info_dict(tensor)`` as one ``key=value, key=value`` string."""
    pieces: typing.List[str] = []
    for key, value in tensor_info_dict(tensor).items():
        pieces.append(f"{key}={value}")
    return ", ".join(pieces)

181 

182 

# `dbg`, preconfigured to print a `tensor_info` summary instead of the full repr.
# It is kept as a `functools.partial` (implemented in C in CPython, so it adds no
# Python frame) rather than a wrapper `def`. dbg scans the stack for the first
# line containing "dbg", so a wrapper's own call line would match and the
# reported location would be the wrapper instead of the caller.
dbg_tensor = functools.partial(
    dbg,
    formatter=tensor_info,
)