Coverage for tests\unit\test_dbg.py: 100%
154 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-05 19:24 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-05 19:24 -0700
1import inspect
2import tempfile
3from pathlib import Path
4import importlib
5from typing import Any, Callable, Dict, Optional, List, Tuple
7from muutils.dbg import (
8 dbg,
9 dbg_tensor,
10 tensor_info,
11 tensor_info_dict,
12 _NoExpPassed,
13 _process_path,
14 _CWD,
15 # we do use this as a global in `test_dbg_counter_increments`
16 _COUNTER, # noqa: F401
17)
19import pytest
# Dotted import path of the module under test. Tests resolve it with
# ``importlib.import_module`` so they can monkeypatch that module's own globals
# (e.g. ``PATH_MODE``, ``_COUNTER``) rather than names copied into this file.
DBG_MODULE_NAME: str = "muutils.dbg"
24# ============================================================================
25# Dummy Tensor classes for testing tensor_info* functions
26# ============================================================================
class DummyTensor:
    """Tensor look-alike whose ``sum()`` is NaN.

    Carries the attributes the ``tensor_info*`` tests expect to see reported:
    ``shape``, ``dtype``, ``device`` and ``requires_grad``.
    """

    shape: Tuple[int, ...] = (2, 3)
    dtype: str = "float32"
    device: str = "cpu"
    requires_grad: bool = False

    def sum(self) -> float:
        # A fresh NaN each call; compare with isnan / x != x, never with ==.
        not_a_number: float = float("nan")
        return not_a_number
class DummyTensorNormal:
    """Tensor look-alike with a finite ``sum()`` of 20.0.

    Uses ``dtype`` "int32", ``device`` "cuda" and ``requires_grad`` True, so
    every reported attribute differs from :class:`DummyTensor`.
    """

    shape: Tuple[int, ...] = (4, 5)
    dtype: str = "int32"
    device: str = "cuda"
    requires_grad: bool = True

    def sum(self) -> float:
        total: float = 20.0
        return total
class DummyTensorPartial:
    """Tensor look-alike exposing nothing but a one-dimensional ``shape``.

    It has no ``dtype``, ``device``, ``requires_grad`` or ``sum``.
    """

    shape: Tuple[int, ...] = (7,)
59# ============================================================================
60# Additional Tests for dbg and tensor_info functionality
61# ============================================================================
# --- Tests for _process_path ---
def test_process_path_absolute(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``PATH_MODE == "absolute"`` the path is rendered as absolute POSIX text."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "absolute")

    target: Path = Path("somefile.txt")
    want: str = target.absolute().as_posix()
    got: str = _process_path(target)
    assert got == want
def test_process_path_relative_inside_common(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``PATH_MODE == "relative"``, a path under ``_CWD`` is shown relative to it."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "relative")

    inside: Path = _CWD / "file.txt"
    assert _process_path(inside) == "file.txt"
def test_process_path_relative_outside_common(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``PATH_MODE == "relative"``, a path in a temporary directory (presumably
    outside ``_CWD``) is rendered as its absolute POSIX form minus any leading "/".
    """
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "relative")

    with tempfile.TemporaryDirectory() as scratch_dir:
        outside: Path = Path(scratch_dir) / "file.txt"
        absolute_text: str = outside.absolute().as_posix()
        got: str = _process_path(outside)
        assert got == absolute_text.lstrip("/")
def test_process_path_invalid_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unrecognised ``PATH_MODE`` makes ``_process_path`` raise ``ValueError``."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "PATH_MODE", "invalid")

    # Prefix of the expected error message (matched with re.search by pytest).
    expected_message: str = "PATH_MODE must be either 'relative' or 'absolute"
    with pytest.raises(ValueError, match=expected_message):
        _process_path(Path("anything.txt"))
106# --- Tests for dbg ---
def test_dbg_with_expression(capsys: pytest.CaptureFixture) -> None:
    """``dbg`` echoes its argument's source text and value to stderr and returns the value."""
    value: int = dbg(1 + 2)
    stderr_text: str = capsys.readouterr().err

    assert "= 3" in stderr_text
    # Compare with spaces removed so "1+2" and "1 + 2" are both accepted.
    assert "1+2" in stderr_text.replace(" ", "")
    assert value == 3
def test_dbg_without_expression(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """A bare ``dbg()`` prints a counter tag and returns the ``_NoExpPassed`` sentinel."""
    dbg_module = importlib.import_module(DBG_MODULE_NAME)
    monkeypatch.setattr(dbg_module, "_COUNTER", 0)

    returned: Any = dbg()
    stderr_text: str = capsys.readouterr().err.strip()

    assert "(dbg 0)" in stderr_text
    assert returned is _NoExpPassed
def test_dbg_custom_formatter(capsys: pytest.CaptureFixture) -> None:
    """A caller-supplied ``formatter`` shapes the printed output, while ``dbg``
    still hands back the original, unformatted value."""

    def custom_formatter(x: Any) -> str:
        return "custom"

    result: str = dbg("anything", formatter=custom_formatter)
    stderr_text: str = capsys.readouterr().err

    assert "custom" in stderr_text
    assert result == "anything"
def test_dbg_complex_expression(capsys: pytest.CaptureFixture) -> None:
    """``dbg`` copes with an immediately-invoked lambda.

    The echoed source snippet should mention ``lambda``, and both the printed
    and the returned value should be 25.
    """
    squared: int = dbg((lambda x: x * x)(5))
    stderr_text: str = capsys.readouterr().err

    for fragment in ("lambda", "25"):
        assert fragment in stderr_text
    assert squared == 25
def test_dbg_multiline_code_context(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """``dbg`` picks the expression text from the fake frame containing ``dbg``.

    ``inspect.stack`` is replaced with two fake frames. The first frame's code
    context does not mention ``dbg``. The second frame's context is two lines,
    beginning with ``dbg(2+2)``. The real call below is written ``dbg(2 + 2)``
    with spaces, so a spaceless "2+2" in the output can only have come from the
    fake frame.
    """

    class FakeFrame:
        """Minimal stand-in for a frame record: code_context, filename, lineno."""

        def __init__(
            self, code_context: Optional[List[str]], filename: str, lineno: int
        ) -> None:
            self.code_context = code_context
            self.filename = filename
            self.lineno = lineno

    def fake_inspect_stack() -> List[Any]:
        return [
            FakeFrame(["not line"], "frame1.py", 20),
            FakeFrame(["dbg(2+2)", "ignored line"], "frame2.py", 30),
        ]

    monkeypatch.setattr(inspect, "stack", fake_inspect_stack)
    result: int = dbg(2 + 2)
    captured: str = capsys.readouterr().err
    # (A leftover debug ``print(captured)`` was removed here; it only added noise.)
    assert "2+2" in captured
    assert "4" in captured
    assert result == 4
def test_dbg_counter_increments(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """Each bare ``dbg()`` call prints the current counter, and the next call shows it advanced.

    The counter lives in the ``muutils.dbg`` module namespace, so it must be
    reset there. ``global _COUNTER; _COUNTER = 0`` would only rebind the copy
    imported into this test module and leave the value ``dbg`` reads untouched.
    ``monkeypatch`` also restores the original value after the test.
    """
    monkeypatch.setattr(importlib.import_module(DBG_MODULE_NAME), "_COUNTER", 0)

    dbg()
    out1: str = capsys.readouterr().err
    dbg()
    out2: str = capsys.readouterr().err

    assert "(dbg 0)" in out1
    assert "(dbg 1)" in out2
def test_dbg_formatter_exception() -> None:
    """An exception raised inside ``formatter`` propagates out of ``dbg``."""
    message: str = "formatter error"

    def bad_formatter(x: Any) -> str:
        raise ValueError(message)

    with pytest.raises(ValueError, match=message):
        dbg(123, formatter=bad_formatter)
def test_dbg_incomplete_expression(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """``dbg`` still returns its value when the call-site source line is truncated.

    ``inspect.stack`` is replaced by a single fake frame whose code context is
    ``dbg(42``, an expression with no closing parenthesis.
    """

    class StubFrame:
        """Bare-bones frame record exposing code_context, filename and lineno."""

        def __init__(
            self, code_context: Optional[List[str]], filename: str, lineno: int
        ) -> None:
            self.code_context = code_context
            self.filename = filename
            self.lineno = lineno

    def stub_stack() -> List[Any]:
        return [StubFrame(["dbg(42"], "fake_incomplete.py", 100)]

    monkeypatch.setattr(inspect, "stack", stub_stack)
    value: int = dbg(42)
    stderr_text: str = capsys.readouterr().err

    # NOTE(review): "42" is also the printed value, so this check alone does not
    # show which expression text was extracted from the truncated line.
    assert "42" in stderr_text
    assert value == 42
def test_dbg_non_callable_formatter() -> None:
    """Supplying a non-callable ``formatter`` makes ``dbg`` raise ``TypeError``."""
    not_a_function: Any = "not callable"
    with pytest.raises(TypeError):
        dbg(42, formatter=not_a_function)
221# --- Tests for tensor_info_dict and tensor_info ---
def test_tensor_info_dict_with_nan() -> None:
    """A NaN ``sum`` is included next to shape, dtype, device and requires_grad.

    Every value is reported as its ``repr`` string.
    """
    info: Dict[str, str] = tensor_info_dict(DummyTensor())
    expected: Dict[str, str] = dict(
        shape=repr((2, 3)),
        sum=repr(float("nan")),
        dtype=repr("float32"),
        device=repr("cpu"),
        requires_grad=repr(False),
    )
    assert info == expected
def test_tensor_info_dict_normal() -> None:
    """The non-NaN ``sum`` (20.0) is left out; the other attributes appear as reprs."""
    info: Dict[str, str] = tensor_info_dict(DummyTensorNormal())
    expected: Dict[str, str] = dict(
        shape=repr((4, 5)),
        dtype=repr("int32"),
        device=repr("cuda"),
        requires_grad=repr(True),
    )
    assert info == expected
def test_tensor_info_dict_partial() -> None:
    """Only the attributes the object actually has end up in the result."""
    assert tensor_info_dict(DummyTensorPartial()) == {"shape": repr((7,))}
def test_tensor_info() -> None:
    """``tensor_info`` renders ``key=repr(value)`` pairs joined by ", "."""
    pairs: List[Tuple[str, str]] = [
        ("shape", repr((4, 5))),
        ("dtype", repr("int32")),
        ("device", repr("cuda")),
        ("requires_grad", repr(True)),
    ]
    expected: str = ", ".join(f"{name}={text}" for name, text in pairs)
    assert tensor_info(DummyTensorNormal()) == expected
def test_tensor_info_dict_no_attributes() -> None:
    """An object with none of the inspected attributes yields an empty dict."""

    class Bare:
        pass

    assert tensor_info_dict(Bare()) == {}
def test_tensor_info_no_attributes() -> None:
    """An object with none of the inspected attributes renders as an empty string."""

    class Bare:
        pass

    assert tensor_info(Bare()) == ""
def test_dbg_tensor(capsys: pytest.CaptureFixture) -> None:
    """``dbg_tensor`` prints the tensor's info to stderr and returns the same object."""
    tensor: DummyTensorPartial = DummyTensorPartial()
    returned: DummyTensorPartial = dbg_tensor(tensor)  # type: ignore
    stderr_text: str = capsys.readouterr().err

    assert "shape=(7,)" in stderr_text
    assert returned is tensor