Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/_pytest/assertion/rewrite.py : 14%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Rewrite assertion AST to produce nice error messages"""
2import ast
3import errno
4import functools
5import importlib.abc
6import importlib.machinery
7import importlib.util
8import io
9import itertools
10import marshal
11import os
12import struct
13import sys
14import tokenize
15import types
16from typing import Callable
17from typing import Dict
18from typing import IO
19from typing import List
20from typing import Optional
21from typing import Sequence
22from typing import Set
23from typing import Tuple
24from typing import Union
26import py
28from _pytest._io.saferepr import saferepr
29from _pytest._version import version
30from _pytest.assertion import util
31from _pytest.assertion.util import ( # noqa: F401
32 format_explanation as _format_explanation,
33)
34from _pytest.compat import fspath
35from _pytest.compat import TYPE_CHECKING
36from _pytest.config import Config
37from _pytest.main import Session
38from _pytest.pathlib import fnmatch_ex
39from _pytest.pathlib import Path
40from _pytest.pathlib import PurePath
41from _pytest.store import StoreKey
43if TYPE_CHECKING:
44 from _pytest.assertion import AssertionState # noqa: F401
assertstate_key = StoreKey["AssertionState"]()


# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
# ".pyc" normally, ".pyo" when running with -O (where __debug__ is False).
PYC_EXT = ".pyc" if __debug__ else ".pyo"
# Suffix appended to a module's base name to form its cached pyc file name.
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts.

    It acts as a meta path finder (``find_spec``) that claims modules whose
    asserts should be rewritten, and as their loader (``create_module`` /
    ``exec_module``), executing rewritten code that is cached in
    pytest-specific pyc files.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # Reading the "python_files" ini option failed (presumably the
            # option is not registered); use these file name patterns instead.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session = None  # type: Optional[Session]
        # Names of modules executed by this hook (added in exec_module).
        self._rewritten_names = set()  # type: Set[str]
        # Module/package names registered through mark_rewrite().
        self._must_rewrite = set()  # type: Set[str]
        # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
        # which might result in infinite recursion (#3506)
        self._writing_pyc = False
        # Last dotted-name components that disable the early bailout;
        # basenames of the session's initial paths are added lazily.
        self._basenames_to_check_rewrite = {"conftest"}
        # Memo for _is_marked_for_rewrite(); cleared by mark_rewrite().
        self._marked_for_rewrite_cache = {}  # type: Dict[str, bool]
        # Whether the session's initial paths have already been added to
        # _basenames_to_check_rewrite.
        self._session_paths_checked = False

    def set_session(self, session: Optional[Session]) -> None:
        """Set the current session (may be None).

        Also re-arms the one-time collection of the session's initial paths
        performed by _early_rewrite_bailout.
        """
        self.session = session
        self._session_paths_checked = False

    # Indirection so we can mock calls to find_spec originated from the hook during testing
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Optional[Sequence[Union[str, bytes]]] = None,
        target: Optional[types.ModuleType] = None,
    ) -> Optional[importlib.machinery.ModuleSpec]:
        """Return a spec loaded by this hook if module *name* should be rewritten.

        Returns None, leaving the import to other finders, when a pyc is
        currently being written, when the module is skipped early by name,
        when it is not an existing source file, or when _should_rewrite()
        declines it. *target* is not used.
        """
        if self._writing_pyc:
            return None
        state = self.config._store[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace("find_module called for: %s" % name)

        # Type ignored because mypy is confused about the `self` binding here.
        spec = self._find_spec(name, path)  # type: ignore
        if (
            # the import machinery could not find a file to import
            spec is None
            # this is a namespace package (without `__init__.py`)
            # there's nothing to rewrite there
            # python3.5 - python3.6: `namespace`
            # python3.7+: `None`
            or spec.origin == "namespace"
            or spec.origin is None
            # we can only rewrite source files
            or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
            # if the file doesn't exist, we can't rewrite it
            or not os.path.exists(spec.origin)
        ):
            return None
        else:
            fn = spec.origin

        if not self._should_rewrite(name, fn, state):
            return None

        # Build a new spec for the same file with this hook as its loader,
        # keeping the package search locations that PathFinder found.
        return importlib.util.spec_from_file_location(
            name,
            fn,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> Optional[types.ModuleType]:
        return None  # default behaviour is fine

    def exec_module(self, module: types.ModuleType) -> None:
        """Execute *module* with rewritten code, using or filling the pyc cache."""
        assert module.__spec__ is not None
        assert module.__spec__.origin is not None
        fn = Path(module.__spec__.origin)
        state = self.config._store[assertstate_key]

        self._rewritten_names.add(module.__name__)

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write:
            ok = try_makedirs(cache_dir)
            if not ok:
                write = False
                state.trace("read only directory: {}".format(cache_dir))

        # Drop the last 3 characters (assumed to be ".py") and append the
        # pytest-specific tail, e.g. "test_x" + ".cpython-36-pytest-X.pyc".
        cache_name = fn.name[:-3] + PYC_TAIL
        pyc = cache_dir / cache_name
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        co = _read_pyc(fn, pyc, state.trace)
        if co is None:
            state.trace("rewriting {!r}".format(fn))
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                # While set, find_spec returns None (see the #3506 note in __init__).
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        else:
            state.trace("found cached rewritten pyc for {}".format(fn))
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
        """This is a fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it.

        Returns True if module *name* can be skipped. Returns False if it
        might need rewriting; find_spec then performs the full check via
        _should_rewrite.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # Make something as c:/projects/my_project/path.py ->
                # ['c:', 'projects', 'my_project', 'path.py']
                parts = str(initial_path).split(os.path.sep)
                # add 'path' to basenames to be checked.
                self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])

        # Note: conftest already by default in _basenames_to_check_rewrite.
        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        path = PurePath(os.path.sep.join(parts) + ".py")

        for pat in self.fnpats:
            # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
            # on the name alone because we need to match against the full path
            if os.path.dirname(pat):
                return False
            if fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace("early skip of rewriting module: {}".format(name))
        return True

    def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
        """Decide whether module *name*, loaded from file *fn*, gets rewritten.

        True for conftest.py files, for files that are initial paths of the
        session, for files matching the python_files patterns, and for
        modules registered through mark_rewrite().
        """
        # always rewrite conftest files
        if os.path.basename(fn) == "conftest.py":
            state.trace("rewriting conftest file: {!r}".format(fn))
            return True

        if self.session is not None:
            if self.session.isinitpath(py.path.local(fn)):
                state.trace(
                    "matched test file (was specified on cmdline): {!r}".format(fn)
                )
                return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        fn_path = PurePath(fn)
        for pat in self.fnpats:
            if fnmatch_ex(pat, fn_path):
                state.trace("matched test file {!r}".format(fn))
                return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
        """Return whether *name* equals, or is nested in, a name given to mark_rewrite().

        Results are memoized per name in _marked_for_rewrite_cache.
        """
        try:
            return self._marked_for_rewrite_cache[name]
        except KeyError:
            for marked in self._must_rewrite:
                if name == marked or name.startswith(marked + "."):
                    state.trace(
                        "matched marked file {!r} (from {!r})".format(name, marked)
                    )
                    self._marked_for_rewrite_cache[name] = True
                    return True

            self._marked_for_rewrite_cache[name] = False
            return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.
        """
        # Names already in sys.modules that this hook did not execute.
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            mod = sys.modules[name]
            # Warn unless the module opts out via PYTEST_DONT_REWRITE in its
            # docstring or was loaded by an instance of this hook class.
            if not AssertionRewriter.is_rewrite_disabled(
                mod.__doc__ or ""
            ) and not isinstance(mod.__loader__, type(self)):
                self._warn_already_imported(name)
        self._must_rewrite.update(names)
        # New marks may change earlier answers, so drop the memo.
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        """Issue a PytestAssertRewriteWarning saying module *name* can't be rewritten."""
        from _pytest.warning_types import PytestAssertRewriteWarning
        from _pytest.warnings import _issue_warning_captured

        _issue_warning_captured(
            PytestAssertRewriteWarning(
                "Module already imported so cannot be rewritten: %s" % name
            ),
            self.config.hook,
            stacklevel=5,
        )

    def get_data(self, pathname: Union[str, bytes]) -> bytes:
        """Optional PEP302 get_data API."""
        with open(pathname, "rb") as f:
            return f.read()
285def _write_pyc_fp(
286 fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
287) -> None:
288 # Technically, we don't have to have the same pyc format as
289 # (C)Python, since these "pycs" should never be seen by builtin
290 # import. However, there's little reason deviate.
291 fp.write(importlib.util.MAGIC_NUMBER)
292 # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
293 mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
294 size = source_stat.st_size & 0xFFFFFFFF
295 # "<LL" stands for 2 unsigned longs, little-ending
296 fp.write(struct.pack("<LL", mtime, size))
297 fp.write(marshal.dumps(co))
if sys.platform == "win32":
    from atomicwrites import atomic_write

    def _write_pyc(
        state: "AssertionState",
        co: types.CodeType,
        source_stat: os.stat_result,
        pyc: Path,
    ) -> bool:
        """Atomically write the rewritten code *co* to *pyc* (Windows variant).

        Returns True on success, False if the cache file could not be
        written. Failures are only traced, never raised.
        """
        try:
            with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
                _write_pyc_fp(fp, source_stat, co)
        except OSError as e:
            state.trace("error writing pyc file at {}: {}".format(pyc, e))
            # we ignore any failure to write the cache file
            # there are many reasons, permission-denied, pycache dir being a
            # file etc.
            return False
        return True


else:

    def _write_pyc(
        state: "AssertionState",
        co: types.CodeType,
        source_stat: os.stat_result,
        pyc: Path,
    ) -> bool:
        """Write *co* to *pyc* through a per-process temp file and an atomic rename.

        The data is first written to ``<pyc>.<pid>``. Only after that file is
        complete and closed is it renamed over *pyc*, so concurrent pytest
        processes never see a partially written pyc. Returns True on success
        and False (after tracing) on any OSError. On failure the temporary
        file is removed on a best-effort basis.
        """
        proc_pyc = "{}.{}".format(pyc, os.getpid())
        try:
            fp = open(proc_pyc, "wb")
        except OSError as e:
            state.trace(
                "error writing pyc file at {}: errno={}".format(proc_pyc, e.errno)
            )
            return False

        try:
            # Close (and thereby flush) the temp file *before* renaming it.
            # Renaming while data is still buffered could expose a truncated
            # pyc to another process.
            with fp:
                _write_pyc_fp(fp, source_stat, co)
            # os.replace atomically overwrites an existing destination.
            os.replace(proc_pyc, fspath(pyc))
        except OSError as e:
            state.trace("error writing pyc file at {}: {}".format(pyc, e))
            # we ignore any failure to write the cache file
            # there are many reasons, permission-denied, pycache dir being a
            # file etc.
            try:
                # Don't leave the per-process temp file behind.
                os.unlink(proc_pyc)
            except OSError:
                # Best effort only: nothing more to do if it can't be removed.
                pass
            return False
        return True
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
    """Read the source file *fn*, rewrite its asserts and compile it.

    Returns a pair of the source's ``os.stat`` result (taken before the file
    is read) and the compiled, rewritten module code object.
    """
    filename = fspath(fn)
    source_stat = os.stat(filename)
    with open(filename, "rb") as source_file:
        source_bytes = source_file.read()
    module_tree = ast.parse(source_bytes, filename=filename)
    rewrite_asserts(module_tree, source_bytes, filename, config)
    code = compile(module_tree, filename, "exec", dont_inherit=True)
    return source_stat, code
def _read_pyc(
    source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
) -> Optional[types.CodeType]:
    """Possibly read a pytest pyc containing rewritten code.

    The cached file is accepted only if its 12-byte header holds the current
    interpreter's magic number and the source file's mtime and size, each
    truncated to 32 bits. Problems are reported through *trace*.

    Return rewritten code if successful or None if not.
    """
    try:
        handle = open(fspath(pyc), "rb")
    except OSError:
        return None
    with handle:
        try:
            source_info = os.stat(fspath(source))
            expected_stamp = (
                int(source_info.st_mtime) & 0xFFFFFFFF,
                source_info.st_size & 0xFFFFFFFF,
            )
            header = handle.read(12)
        except OSError as e:
            trace("_read_pyc({}): OSError {}".format(source, e))
            return None
        # The length check must come first: struct.unpack needs exactly 8 bytes.
        header_ok = (
            len(header) == 12
            and header[:4] == importlib.util.MAGIC_NUMBER
            and struct.unpack("<LL", header[4:]) == expected_stamp
        )
        if not header_ok:
            trace("_read_pyc(%s): invalid or out of date pyc" % source)
            return None
        try:
            code = marshal.load(handle)
        except Exception as e:
            trace("_read_pyc({}): marshal.load error {}".format(source, e))
            return None
        if isinstance(code, types.CodeType):
            return code
        trace("_read_pyc(%s): not a code object" % source)
        return None
def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: Optional[str] = None,
    config: Optional[Config] = None,
) -> None:
    """Rewrite, in place, every assert statement of the module tree *mod*.

    *source* is the module's raw source, *module_path* its file name (used
    for warnings) and *config* the pytest config, if any.
    """
    rewriter = AssertionRewriter(module_path=module_path, config=config, source=source)
    rewriter.run(mod)
def _saferepr(obj: object) -> str:
    """Get a safe repr of an object for assertion error messages.

    util.format_explanation() treats newlines as special, so they must be
    escaped. format_explanation() normally does that itself, but a custom
    __repr__ may contain one of its special sequences (a newline followed
    by ``{`` or ``}`` is common in JSON-like reprs), so every newline in
    the repr is replaced by a literal backslash-n here.
    """
    text = saferepr(obj)
    return text.replace("\n", "\\n")
427def _format_assertmsg(obj: object) -> str:
428 """Format the custom assertion message given.
430 For strings this simply replaces newlines with '\n~' so that
431 util.format_explanation() will preserve them instead of escaping
432 newlines. For other objects saferepr() is used first.
434 """
435 # reprlib appears to have a bug which means that if a string
436 # contains a newline it gets escaped, however if an object has a
437 # .__repr__() which contains newlines it does not get escaped.
438 # However in either case we want to preserve the newline.
439 replaces = [("\n", "\n~"), ("%", "%%")]
440 if not isinstance(obj, str):
441 obj = saferepr(obj)
442 replaces.append(("\\n", "\n~"))
444 for r1, r2 in replaces:
445 obj = obj.replace(r1, r2)
447 return obj
450def _should_repr_global_name(obj: object) -> bool:
451 if callable(obj):
452 return False
454 try:
455 return not hasattr(obj, "__name__")
456 except Exception:
457 return True
460def _format_boolop(explanations, is_or: bool):
461 explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
462 if isinstance(explanation, str):
463 return explanation.replace("%", "%%")
464 else:
465 return explanation.replace(b"%", b"%%")
def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Return the explanation for the first failing comparison of a chain.

    Walks the parallel *ops*/*results*/*expls* sequences and stops at the
    first result that is falsy, or whose truth value cannot be determined
    because evaluating it raised. If ``util._reprcompare`` is set and returns
    something other than None for that comparison (operator plus its two
    operands from *each_obj*), that is returned. Otherwise the default
    explanation is returned.
    """
    for idx, (_op, outcome, explanation) in enumerate(zip(ops, results, expls)):
        try:
            failed = not outcome
        except Exception:
            # An operand whose truth value raises counts as a failure.
            failed = True
        if failed:
            break
    hook = util._reprcompare
    if hook is not None:
        custom = hook(ops[idx], each_obj[idx], each_obj[idx + 1])
        if custom is not None:
            return custom
    return explanation
def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward a passed assertion to ``util._assertion_pass``, if one is set."""
    hook = util._assertion_pass
    if hook is None:
        return
    hook(lineno, orig, expl)
def _check_if_assertion_pass_impl() -> bool:
    """Check whether any plugin implements the pytest_assertion_pass hook.

    This lets rewritten code skip building the (possibly expensive)
    explanation when nobody would receive it.
    """
    return bool(util._assertion_pass)
# Explanation templates for unary operators. The "%s" slot receives the
# operand's explanation (visit_UnaryOp applies the template with %).
UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}

# Source-level symbols for binary and comparison operator node types, used to
# build explanation strings. Explanations are later %-formatted, so the
# modulo symbol is stored pre-escaped.
BINOP_MAP = {
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.MatMult: "@",
}
def set_location(node, lineno, col_offset):
    """Stamp *lineno* and *col_offset* on *node* and all of its descendants.

    Only nodes whose ``_attributes`` declare the field are touched (e.g.
    expression contexts like ``ast.Load`` carry no location). Returns *node*.
    """
    for descendant in ast.walk(node):
        fields = descendant._attributes
        if "lineno" in fields:
            descendant.lineno = lineno
        if "col_offset" in fields:
            descendant.col_offset = col_offset
    return node
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
    """Returns a mapping from {lineno: "assertion test expression"}

    *src* is tokenized, and for each ``assert`` keyword the source text of
    its test expression (without the optional ``, message`` part) is
    collected. The key is the line number of the ``assert`` keyword.
    """
    ret = {}  # type: Dict[int, str]

    # Bracket nesting depth inside the current assert; a "," only ends the
    # test expression at depth 0.
    depth = 0
    # Source fragments collected for the current assert's test expression.
    lines = []  # type: List[str]
    # Line of the current `assert` keyword, or None when not inside one.
    assert_lineno = None  # type: Optional[int]
    # Physical line numbers already added to `lines`.
    seen_lines = set()  # type: Set[int]

    def _write_and_reset() -> None:
        # Store the collected text (minus trailing whitespace and any
        # trailing line-continuation backslash) and start over.
        nonlocal depth, lines, assert_lineno, seen_lines
        assert assert_lineno is not None
        ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
        depth = 0
        lines = []
        assert_lineno = None
        seen_lines = set()

    tokens = tokenize.tokenize(io.BytesIO(src).readline)
    for tp, source, (lineno, offset), _, line in tokens:
        if tp == tokenize.NAME and source == "assert":
            assert_lineno = lineno
        elif assert_lineno is not None:
            # keep track of depth for the assert-message `,` lookup
            if tp == tokenize.OP and source in "([{":
                depth += 1
            elif tp == tokenize.OP and source in ")]}":
                depth -= 1

            if not lines:
                # First token after `assert`: keep its line from this token on.
                lines.append(line[offset:])
                seen_lines.add(lineno)
            # a non-nested comma separates the expression from the message
            elif depth == 0 and tp == tokenize.OP and source == ",":
                # one line assert with message
                if lineno in seen_lines and len(lines) == 1:
                    # lines[-1] was trimmed at the first token's column, so
                    # translate the comma's column into that trimmed string.
                    offset_in_trimmed = offset + len(lines[-1]) - len(line)
                    lines[-1] = lines[-1][:offset_in_trimmed]
                # multi-line assert with message
                elif lineno in seen_lines:
                    lines[-1] = lines[-1][:offset]
                # multi line assert with escaped newline before message
                else:
                    lines.append(line[:offset])
                _write_and_reset()
            elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
                # End of the logical line: the whole text is the expression.
                _write_and_reset()
            elif lines and lineno not in seen_lines:
                # Continuation line of a multi-line assert: keep it whole.
                lines.append(line)
                seen_lines.add(lineno)
            # NOTE(review): ";" is not treated as a terminator, so for
            # `assert x; y = 1` the recorded text presumably includes the
            # trailing statement -- confirm whether that matters.

    return ret
597class AssertionRewriter(ast.NodeVisitor):
598 """Assertion rewriting implementation.
600 The main entrypoint is to call .run() with an ast.Module instance,
601 this will then find all the assert statements and rewrite them to
602 provide intermediate values and a detailed assertion error. See
603 http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
604 for an overview of how this works.
606 The entry point here is .run() which will iterate over all the
607 statements in an ast.Module and for each ast.Assert statement it
608 finds call .visit() with it. Then .visit_Assert() takes over and
609 is responsible for creating new ast statements to replace the
610 original assert statement: it rewrites the test of an assertion
611 to provide intermediate values and replace it with an if statement
612 which raises an assertion error with a detailed explanation in
613 case the expression is false and calls pytest_assertion_pass hook
614 if expression is true.
616 For this .visit_Assert() uses the visitor pattern to visit all the
617 AST nodes of the ast.Assert.test field, each visit call returning
618 an AST node and the corresponding explanation string. During this
619 state is kept in several instance attributes:
621 :statements: All the AST statements which will replace the assert
622 statement.
624 :variables: This is populated by .variable() with each variable
625 used by the statements so that they can all be set to None at
626 the end of the statements.
628 :variable_counter: Counter to create new unique variables needed
629 by statements. Variables are created using .variable() and
630 have the form of "@py_assert0".
632 :expl_stmts: The AST statements which will be executed to get
633 data from the assertion. This is the code which will construct
634 the detailed assertion message that is used in the AssertionError
635 or for the pytest_assertion_pass hook.
637 :explanation_specifiers: A dict filled by .explanation_param()
638 with %-formatting placeholders and their corresponding
639 expressions to use in the building of an assertion message.
640 This is used by .pop_format_context() to build a message.
642 :stack: A stack of the explanation_specifiers dicts maintained by
643 .push_format_context() and .pop_format_context() which allows
644 to build another %-formatted string while already building one.
646 This state is reset on every new assert statement visited and used
647 by the other visitors.
649 """
651 def __init__(
652 self, module_path: Optional[str], config: Optional[Config], source: bytes
653 ) -> None:
654 super().__init__()
655 self.module_path = module_path
656 self.config = config
657 if config is not None:
658 self.enable_assertion_pass_hook = config.getini(
659 "enable_assertion_pass_hook"
660 )
661 else:
662 self.enable_assertion_pass_hook = False
663 self.source = source
665 @functools.lru_cache(maxsize=1)
666 def _assert_expr_to_lineno(self) -> Dict[int, str]:
667 return _get_assertion_exprs(self.source)
    def run(self, mod: ast.Module) -> None:
        """Find all assert statements in *mod* and rewrite them.

        Does nothing for an empty module or when the module docstring
        contains PYTEST_DONT_REWRITE. Otherwise it inserts the helper
        imports (``builtins`` as ``@py_builtins`` and this module as
        ``@pytest_ar``) and replaces every ``ast.Assert`` in statement lists
        with the statements produced by visit_Assert.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [
            ast.alias("builtins", "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
        # Some AST versions expose the docstring as an attribute; if not,
        # look for it as the first statement below.
        doc = getattr(mod, "docstring", None)
        expect_docstring = doc is None
        if doc is not None and self.is_rewrite_disabled(doc):
            return
        pos = 0
        lineno = 1
        for item in mod.body:
            if (
                expect_docstring
                and isinstance(item, ast.Expr)
                and isinstance(item.value, ast.Str)
            ):
                doc = item.value.s
                if self.is_rewrite_disabled(doc):
                    return
                expect_docstring = False
            elif (
                not isinstance(item, ast.ImportFrom)
                or item.level > 0
                or item.module != "__future__"
            ):
                # First statement that is neither the docstring nor a
                # __future__ import: insert the imports before it.
                lineno = item.lineno
                break
            pos += 1
        else:
            # The module consists only of a docstring and/or __future__ imports.
            lineno = item.lineno
        imports = [
            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
        ]
        mod.body[pos:pos] = imports
        # Collect asserts.
        # Iterative depth-first walk; every list field is rebuilt with each
        # Assert replaced by its rewritten statements.
        nodes = [mod]  # type: List[ast.AST]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []  # type: List
                    for i, child in enumerate(field):
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (
                    isinstance(field, ast.AST)
                    # Don't recurse into expressions as they can't contain
                    # asserts.
                    and not isinstance(field, ast.expr)
                ):
                    nodes.append(field)
734 @staticmethod
735 def is_rewrite_disabled(docstring: str) -> bool:
736 return "PYTEST_DONT_REWRITE" in docstring
738 def variable(self) -> str:
739 """Get a new variable."""
740 # Use a character invalid in python identifiers to avoid clashing.
741 name = "@py_assert" + str(next(self.variable_counter))
742 self.variables.append(name)
743 return name
745 def assign(self, expr: ast.expr) -> ast.Name:
746 """Give *expr* a name."""
747 name = self.variable()
748 self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
749 return ast.Name(name, ast.Load())
751 def display(self, expr: ast.expr) -> ast.expr:
752 """Call saferepr on the expression."""
753 return self.helper("_saferepr", expr)
755 def helper(self, name: str, *args: ast.expr) -> ast.expr:
756 """Call a helper in this module."""
757 py_name = ast.Name("@pytest_ar", ast.Load())
758 attr = ast.Attribute(py_name, name, ast.Load())
759 return ast.Call(attr, list(args), [])
761 def builtin(self, name: str) -> ast.Attribute:
762 """Return the builtin called *name*."""
763 builtin_name = ast.Name("@py_builtins", ast.Load())
764 return ast.Attribute(builtin_name, name, ast.Load())
766 def explanation_param(self, expr: ast.expr) -> str:
767 """Return a new named %-formatting placeholder for expr.
769 This creates a %-formatting placeholder for expr in the
770 current formatting context, e.g. ``%(py0)s``. The placeholder
771 and expr are placed in the current format context so that it
772 can be used on the next call to .pop_format_context().
774 """
775 specifier = "py" + str(next(self.variable_counter))
776 self.explanation_specifiers[specifier] = expr
777 return "%(" + specifier + ")s"
779 def push_format_context(self) -> None:
780 """Create a new formatting context.
782 The format context is used for when an explanation wants to
783 have a variable value formatted in the assertion message. In
784 this case the value required can be added using
785 .explanation_param(). Finally .pop_format_context() is used
786 to format a string of %-formatted values as added by
787 .explanation_param().
789 """
790 self.explanation_specifiers = {} # type: Dict[str, ast.expr]
791 self.stack.append(self.explanation_specifiers)
793 def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
794 """Format the %-formatted string with current format context.
796 The expl_expr should be an str ast.expr instance constructed from
797 the %-placeholders created by .explanation_param(). This will
798 add the required code to format said string to .expl_stmts and
799 return the ast.Name instance of the formatted string.
801 """
802 current = self.stack.pop()
803 if self.stack:
804 self.explanation_specifiers = self.stack[-1]
805 keys = [ast.Str(key) for key in current.keys()]
806 format_dict = ast.Dict(keys, list(current.values()))
807 form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
808 name = "@py_format" + str(next(self.variable_counter))
809 if self.enable_assertion_pass_hook:
810 self.format_variables.append(name)
811 self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
812 return ast.Name(name, ast.Load())
814 def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
815 """Handle expressions we don't have custom code for."""
816 assert isinstance(node, ast.expr)
817 res = self.assign(node)
818 return res, self.explanation_param(self.display(res))
    def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false. With enable_assertion_pass_hook, the
        success branch additionally calls _call_assertion_pass when
        _check_if_assertion_pass_impl() is true.
        """
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            # `assert (x, msg)` tests a non-empty tuple, which is always true.
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            # TODO: This assert should not be needed.
            assert self.module_path is not None
            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=fspath(self.module_path),
                lineno=assert_.lineno,
            )

        # Reset the per-assert state (see the class docstring).
        self.statements = []  # type: List[ast.stmt]
        self.variables = []  # type: List[str]
        self.variable_counter = itertools.count()

        if self.enable_assertion_pass_hook:
            self.format_variables = []  # type: List[str]

        self.stack = []  # type: List[Dict[str, ast.expr]]
        self.expl_stmts = []  # type: List[ast.stmt]
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)

        negation = ast.UnaryOp(ast.Not(), top_condition)

        if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
            msg = self.pop_format_context(ast.Str(explanation))

            # Failed
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                gluestr = "\n>assert "
            else:
                assertmsg = ast.Str("")
                gluestr = "assert "
            err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
            err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
            err_name = ast.Name("AssertionError", ast.Load())
            fmt = self.helper("_format_explanation", err_msg)
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)
            statements_fail = []
            statements_fail.extend(self.expl_stmts)
            statements_fail.append(raise_)

            # Passed
            fmt_pass = self.helper("_format_explanation", msg)
            # Source text of the assert's test expression, taken from the
            # module source by line number.
            orig = self._assert_expr_to_lineno()[assert_.lineno]
            hook_call_pass = ast.Expr(
                self.helper(
                    "_call_assertion_pass",
                    ast.Num(assert_.lineno),
                    ast.Str(orig),
                    fmt_pass,
                )
            )
            # If any hooks implement assert_pass hook
            # (the explanation statements are only run in that case).
            hook_impl_test = ast.If(
                self.helper("_check_if_assertion_pass_impl"),
                self.expl_stmts + [hook_call_pass],
                [],
            )
            statements_pass = [hook_impl_test]

            # Test for assertion condition
            main_test = ast.If(negation, statements_fail, statements_pass)
            self.statements.append(main_test)
            if self.format_variables:
                variables = [
                    ast.Name(name, ast.Store()) for name in self.format_variables
                ]
                clear_format = ast.Assign(variables, ast.NameConstant(None))
                self.statements.append(clear_format)

        else:  # Original assertion rewriting
            # Create failure message.
            # `body` is the same list object as self.expl_stmts, so the
            # format assignment added by pop_format_context below also ends
            # up in the failure branch, before the raise.
            body = self.expl_stmts
            self.statements.append(ast.If(negation, body, []))
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                explanation = "\n>assert " + explanation
            else:
                assertmsg = ast.Str("")
                explanation = "assert " + explanation
            template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
            msg = self.pop_format_context(template)
            fmt = self.helper("_format_explanation", msg)
            err_name = ast.Name("AssertionError", ast.Load())
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)

            body.append(raise_)

        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, ast.NameConstant(None))
            self.statements.append(clear)
        # Fix line numbers.
        # Every generated node gets the location of the original assert.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements
937 def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
938 # Display the repr of the name if it's a local variable or
939 # _should_repr_global_name() thinks it's acceptable.
940 locs = ast.Call(self.builtin("locals"), [], [])
941 inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
942 dorepr = self.helper("_should_repr_global_name", name)
943 test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
944 expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
945 return name, self.explanation_param(expr)
    def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
        """Rewrite ``a and b ...`` / ``a or b ...`` while keeping short-circuiting.

        Each operand is evaluated inside an ``if`` nested under the previous
        operand's test, so later operands only run when Python itself would
        evaluate them. Explanations of the evaluated operands are appended to
        a runtime list and joined by _format_boolop.
        """
        res_var = self.variable()
        # Runtime list collecting one formatted explanation per evaluated operand.
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        body = save = self.statements
        fail_save = self.expl_stmts
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                # Explanation statements for later operands are nested under
                # the same condition as their evaluation.
                fail_inner = []  # type: List[ast.stmt]
                # cond is set in a prior loop iteration below
                self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
                self.expl_stmts = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast.Call(app, [expl_format], [])
            self.expl_stmts.append(ast.Expr(call))
            if i < levels:
                # Continue only if this operand doesn't decide the result:
                # truthy for `and`, falsy for `or`.
                cond = res  # type: ast.expr
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []  # type: List[ast.stmt]
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        self.statements = save
        self.expl_stmts = fail_save
        expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
982 def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
983 pattern = UNARY_MAP[unary.op.__class__]
984 operand_res, operand_expl = self.visit(unary.operand)
985 res = self.assign(ast.UnaryOp(unary.op, operand_res))
986 return res, pattern % (operand_expl,)
988 def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
989 symbol = BINOP_MAP[binop.op.__class__]
990 left_expr, left_expl = self.visit(binop.left)
991 right_expr, right_expl = self.visit(binop.right)
992 explanation = "({} {} {})".format(left_expl, symbol, right_expl)
993 res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
994 return res, explanation
996 def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
997 """
998 visit `ast.Call` nodes
999 """
1000 new_func, func_expl = self.visit(call.func)
1001 arg_expls = []
1002 new_args = []
1003 new_kwargs = []
1004 for arg in call.args:
1005 res, expl = self.visit(arg)
1006 arg_expls.append(expl)
1007 new_args.append(res)
1008 for keyword in call.keywords:
1009 res, expl = self.visit(keyword.value)
1010 new_kwargs.append(ast.keyword(keyword.arg, res))
1011 if keyword.arg:
1012 arg_expls.append(keyword.arg + "=" + expl)
1013 else: # **args have `arg` keywords with an .arg of None
1014 arg_expls.append("**" + expl)
1016 expl = "{}({})".format(func_expl, ", ".join(arg_expls))
1017 new_call = ast.Call(new_func, new_args, new_kwargs)
1018 res = self.assign(new_call)
1019 res_expl = self.explanation_param(self.display(res))
1020 outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
1021 return res, outer_expl
1023 def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
1024 # From Python 3.5, a Starred node can appear in a function call
1025 res, expl = self.visit(starred.value)
1026 new_starred = ast.Starred(res, starred.ctx)
1027 return new_starred, "*" + expl
1029 def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
1030 if not isinstance(attr.ctx, ast.Load):
1031 return self.generic_visit(attr)
1032 value, value_expl = self.visit(attr.value)
1033 res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
1034 res_expl = self.explanation_param(self.display(res))
1035 pat = "%s\n{%s = %s.%s\n}"
1036 expl = pat % (res_expl, res_expl, value_expl, attr.attr)
1037 return res, expl
1039 def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
1040 self.push_format_context()
1041 left_res, left_expl = self.visit(comp.left)
1042 if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
1043 left_expl = "({})".format(left_expl)
1044 res_variables = [self.variable() for i in range(len(comp.ops))]
1045 load_names = [ast.Name(v, ast.Load()) for v in res_variables]
1046 store_names = [ast.Name(v, ast.Store()) for v in res_variables]
1047 it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
1048 expls = []
1049 syms = []
1050 results = [left_res]
1051 for i, op, next_operand in it:
1052 next_res, next_expl = self.visit(next_operand)
1053 if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
1054 next_expl = "({})".format(next_expl)
1055 results.append(next_res)
1056 sym = BINOP_MAP[op.__class__]
1057 syms.append(ast.Str(sym))
1058 expl = "{} {} {}".format(left_expl, sym, next_expl)
1059 expls.append(ast.Str(expl))
1060 res_expr = ast.Compare(left_res, [op], [next_res])
1061 self.statements.append(ast.Assign([store_names[i]], res_expr))
1062 left_res, left_expl = next_res, next_expl
1063 # Use pytest.assertion.util._reprcompare if that's available.
1064 expl_call = self.helper(
1065 "_call_reprcompare",
1066 ast.Tuple(syms, ast.Load()),
1067 ast.Tuple(load_names, ast.Load()),
1068 ast.Tuple(expls, ast.Load()),
1069 ast.Tuple(results, ast.Load()),
1070 )
1071 if len(comp.ops) > 1:
1072 res = ast.BoolOp(ast.And(), load_names) # type: ast.expr
1073 else:
1074 res = load_names[0]
1075 return res, self.explanation_param(self.pop_format_context(expl_call))
def try_makedirs(cache_dir: Path) -> bool:
    """Attempt to create *cache_dir* together with any missing parent directories.

    Returns True if the directory was created or already exists.

    Returns False when the directory cannot be created for one of these
    expected reasons:

    - a path component is not a directory (for example, the path lies
      inside a zip file or points at an existing regular file);
    - permission is denied;
    - the file system is read-only.

    Any other OSError is re-raised.
    """
    try:
        os.makedirs(fspath(cache_dir), exist_ok=True)
    except (
        FileNotFoundError,
        NotADirectoryError,
        FileExistsError,
        PermissionError,
    ):
        # FileNotFoundError / NotADirectoryError / FileExistsError: one of the
        # path components was not a directory:
        # - we're in a zip file
        # - it is a file
        # PermissionError: we are not allowed to write there.
        return False
    except OSError as e:
        # As of now, EROFS (read-only file system) has no dedicated OSError
        # subclass, so it has to be recognised by errno.
        if e.errno == errno.EROFS:
            return False
        raise
    return True
def get_cache_dir(file_path: Path) -> Path:
    """Return the directory in which to write the .pyc for the .py at *file_path*.

    If ``sys.pycache_prefix`` is set (Python 3.8+), the source's directory
    tree is mirrored under that prefix.  Otherwise the classic
    ``__pycache__`` directory next to the source file is used.
    """
    if sys.version_info < (3, 8) or not sys.pycache_prefix:
        # Classic layout: a __pycache__ directory beside the source file.
        return file_path.parent / "__pycache__"
    # Mirror the source directory beneath the prefix, dropping the root
    # (first part) and the file name (last part). For example, with
    #   prefix = '/tmp/pycs' and path = '/home/user/proj/test_app.py'
    # the result is '/tmp/pycs/home/user/proj'.
    source_dirs = file_path.parts[1:-1]
    return Path(sys.pycache_prefix) / Path(*source_dirs)