Coverage for tests\unit\test_dbg.py: 100%

154 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-02-14 01:33 -0700

1import inspect 

2import tempfile 

3from pathlib import Path 

4import importlib 

5from typing import Any, Callable, Dict, Optional, List, Tuple 

6 

7from muutils.dbg import ( 

8 dbg, 

9 dbg_tensor, 

10 tensor_info, 

11 tensor_info_dict, 

12 _NoExpPassed, 

13 _process_path, 

14 _CWD, 

15 # we do use this as a global in `test_dbg_counter_increments` 

16 _COUNTER, # noqa: F401 

17) 

18 

19import pytest 

20 

21 

# Dotted import path of the module under test. Tests resolve it with
# ``importlib.import_module`` so they can monkeypatch its module-level globals.
DBG_MODULE_NAME: str = "muutils.dbg"

23 

24# ============================================================================ 

25# Dummy Tensor classes for testing tensor_info* functions 

26# ============================================================================ 

27 

28 

class DummyTensor:
    """Tensor-like stand-in whose ``sum()`` evaluates to NaN.

    Carries ``shape``, ``dtype``, ``device`` and ``requires_grad`` as class
    attributes so it can be passed to ``tensor_info_dict`` / ``tensor_info``.
    """

    shape: Tuple[int, ...] = (2, 3)
    dtype: str = "float32"
    device: str = "cpu"
    requires_grad: bool = False

    def sum(self) -> float:
        """Return NaN to simulate a tensor whose sum is not a number."""
        not_a_number: float = float("nan")
        return not_a_number

39 

40 

class DummyTensorNormal:
    """Tensor-like stand-in with a finite ``sum()``.

    Uses a different shape/dtype/device/requires_grad combination than
    ``DummyTensor`` so every reported field is distinguishable in assertions.
    """

    shape: Tuple[int, ...] = (4, 5)
    dtype: str = "int32"
    device: str = "cuda"
    requires_grad: bool = True

    def sum(self) -> float:
        """Return a fixed, finite total."""
        finite_total: float = 20.0
        return finite_total

51 

52 

class DummyTensorPartial:
    """Tensor-like stand-in that defines ``shape`` and nothing else.

    It has no ``dtype``, ``device``, ``requires_grad`` or ``sum`` attribute.
    """

    # the only attribute this stub provides
    shape: Tuple[int, ...] = (7,)

57 

58 

59# ============================================================================ 

60# Additional Tests for dbg and tensor_info functionality 

61# ============================================================================ 

62 

63 

# --- Tests for _process_path ---

def test_process_path_absolute(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE='absolute', _process_path yields the absolute POSIX path."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "absolute")

    target: Path = Path("somefile.txt")
    want: str = target.absolute().as_posix()
    assert _process_path(target) == want

73 

74 

def test_process_path_relative_inside_common(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE='relative', a path directly under ``_CWD`` is shown relative to it."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "relative")

    assert _process_path(_CWD / "file.txt") == "file.txt"

83 

84 

def test_process_path_relative_outside_common(monkeypatch: pytest.MonkeyPatch) -> None:
    """With PATH_MODE='relative', a path in a fresh temp dir is rendered as its
    absolute POSIX form with any leading ``/`` stripped.

    Assumes the temporary directory does not lie under ``_CWD``.
    """
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "relative")

    with tempfile.TemporaryDirectory() as scratch_dir:
        outside: Path = Path(scratch_dir) / "file.txt"
        want: str = outside.absolute().as_posix().lstrip("/")
        assert _process_path(outside) == want

94 

95 

def test_process_path_invalid_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unrecognised PATH_MODE makes _process_path raise ValueError."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "invalid")

    # `match` is a regex searched in the message; this prefix is sufficient
    message_pattern = "PATH_MODE must be either 'relative' or 'absolute"
    with pytest.raises(ValueError, match=message_pattern):
        _process_path(Path("anything.txt"))

104 

105 

106# --- Tests for dbg --- 

def test_dbg_with_expression(capsys: pytest.CaptureFixture) -> None:
    """dbg(expr) writes the expression text and its value to stderr and returns the value."""
    value: int = dbg(1 + 2)
    stderr_text: str = capsys.readouterr().err

    assert "= 3" in stderr_text
    # whitespace-insensitive: matches both "1+2" and "1 + 2" in the echoed source
    assert "1+2" in stderr_text.replace(" ", "")
    assert value == 3

114 

115 

def test_dbg_without_expression(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """A bare dbg() prints a numbered marker and returns the _NoExpPassed sentinel."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "_COUNTER", 0)

    returned: Any = dbg()
    assert "(dbg 0)" in capsys.readouterr().err.strip()

    # bound through an `Any`-typed name so the identity check type-checks cleanly
    sentinel: Any = _NoExpPassed
    assert returned is sentinel

125 

126 

def test_dbg_custom_formatter(capsys: pytest.CaptureFixture) -> None:
    """A user-supplied formatter controls the printed text; dbg still returns the input."""

    def always_custom(_value: Any) -> str:
        return "custom"

    returned: str = dbg("anything", formatter=always_custom)
    assert "custom" in capsys.readouterr().err
    assert returned == "anything"

133 

134 

def test_dbg_complex_expression(capsys: pytest.CaptureFixture) -> None:
    """dbg handles a call-site expression that invokes an inline lambda."""
    squared: int = dbg((lambda x: x * x)(5))
    stderr_text: str = capsys.readouterr().err

    assert "lambda" in stderr_text  # echoed source snippet keeps the lambda
    assert "25" in stderr_text  # evaluated value, 5 * 5
    assert squared == 25

144 

145 

def test_dbg_multiline_code_context(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """dbg extracts the expression text from a faked ``inspect.stack()``.

    The fake stack has two frames. The first frame's code context does not
    contain "dbg". The second frame's does, and it is followed by an extra
    line that should not affect the result.
    """

    class FakeFrame:
        # minimal frame stand-in exposing only the attributes populated here
        def __init__(
            self, code_context: Optional[List[str]], filename: str, lineno: int
        ) -> None:
            self.code_context = code_context
            self.filename = filename
            self.lineno = lineno

    def fake_inspect_stack() -> List[Any]:
        return [
            FakeFrame(["not line"], "frame1.py", 20),
            FakeFrame(["dbg(2+2)", "ignored line"], "frame2.py", 30),
        ]

    monkeypatch.setattr(inspect, "stack", fake_inspect_stack)
    result: int = dbg(2 + 2)
    captured: str = capsys.readouterr().err
    assert "2+2" in captured
    assert "4" in captured
    assert result == 4

172 

173 

def test_dbg_counter_increments(capsys: pytest.CaptureFixture) -> None:
    """Successive bare dbg() calls print consecutive counter values.

    The counter that dbg reads is the ``_COUNTER`` attribute of the
    ``muutils.dbg`` module, so it must be reset on that module. Rebinding
    this test module's imported ``_COUNTER`` name (e.g. via ``global``) does
    not touch it. The patch is scoped with ``MonkeyPatch.context()``, so the
    module's counter is restored when the block exits and later tests are
    unaffected.
    """
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(importlib.import_module(DBG_MODULE_NAME), "_COUNTER", 0)
        dbg()
        out1: str = capsys.readouterr().err
        dbg()
        out2: str = capsys.readouterr().err
    assert "(dbg 0)" in out1
    assert "(dbg 1)" in out2

183 

184 

def test_dbg_formatter_exception() -> None:
    """An exception raised by the formatter propagates out of dbg as-is."""

    def exploding_formatter(_value: Any) -> str:
        raise ValueError("formatter error")

    with pytest.raises(ValueError, match="formatter error"):
        dbg(123, formatter=exploding_formatter)

191 

192 

def test_dbg_incomplete_expression(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """dbg still prints and returns the value when the call-site line is unbalanced.

    ``inspect.stack`` is replaced with a single fake frame whose code context
    is ``"dbg(42"``, i.e. it is missing the closing parenthesis.
    """

    class StubFrame:
        # carries only the three attributes this fake stack populates
        def __init__(
            self, code_context: Optional[List[str]], filename: str, lineno: int
        ) -> None:
            self.code_context = code_context
            self.filename = filename
            self.lineno = lineno

    def stack_with_unclosed_call() -> List[Any]:
        only_frame = StubFrame(["dbg(42"], "fake_incomplete.py", 100)
        return [only_frame]

    monkeypatch.setattr(inspect, "stack", stack_with_unclosed_call)
    answer: int = dbg(42)
    stderr_text: str = capsys.readouterr().err

    # the extracted expression text is expected to be "42" despite the missing ")"
    assert "42" in stderr_text
    assert answer == 42

214 

215 

def test_dbg_non_callable_formatter() -> None:
    """Passing a non-callable formatter makes dbg raise TypeError."""
    not_a_function = "not callable"
    with pytest.raises(TypeError):
        dbg(42, formatter=not_a_function)  # type: ignore

219 

220 

221# --- Tests for tensor_info_dict and tensor_info --- 

def test_tensor_info_dict_with_nan() -> None:
    """A NaN sum is reported under "sum" together with every other attribute."""
    raw_fields = [
        ("shape", (2, 3)),
        ("sum", float("nan")),
        ("dtype", "float32"),
        ("device", "cpu"),
        ("requires_grad", False),
    ]
    # values are compared as repr strings, so "nan" == "nan" holds
    want: Dict[str, str] = {key: repr(val) for key, val in raw_fields}
    assert tensor_info_dict(DummyTensor()) == want

233 

234 

def test_tensor_info_dict_normal() -> None:
    """With a finite sum, no "sum" entry appears; the other attributes are repr'd."""
    want: Dict[str, str] = dict(
        shape=repr((4, 5)),
        dtype=repr("int32"),
        device=repr("cuda"),
        requires_grad=repr(True),
    )
    assert tensor_info_dict(DummyTensorNormal()) == want

245 

246 

def test_tensor_info_dict_partial() -> None:
    """Only the attributes actually present on the object are reported."""
    reported: Dict[str, str] = tensor_info_dict(DummyTensorPartial())
    assert reported == dict(shape=repr((7,)))

252 

253 

def test_tensor_info() -> None:
    """tensor_info renders the dict fields as comma-separated ``name=repr`` pairs, in order."""
    ordered_fields = [
        ("shape", (4, 5)),
        ("dtype", "int32"),
        ("device", "cuda"),
        ("requires_grad", True),
    ]
    want: str = ", ".join(f"{name}={value!r}" for name, value in ordered_fields)
    assert tensor_info(DummyTensorNormal()) == want

266 

267 

def test_tensor_info_dict_no_attributes() -> None:
    """An object with none of the tensor attributes yields an empty dict."""

    class Bare:
        pass

    reported: Dict[str, str] = tensor_info_dict(Bare())
    assert not reported and reported == {}

275 

276 

def test_tensor_info_no_attributes() -> None:
    """An object with none of the tensor attributes renders as an empty string."""

    class Bare:
        pass

    rendered: str = tensor_info(Bare())
    assert rendered == ""

284 

285 

286def test_dbg_tensor(capsys: pytest.CaptureFixture) -> None: 

287 tensor: DummyTensorPartial = DummyTensorPartial() 

288 result: DummyTensorPartial = dbg_tensor(tensor) # type: ignore 

289 captured: str = capsys.readouterr().err 

290 assert "shape=(7,)" in captured 

291 assert result is tensor