sqlmesh.utils.jinja
from __future__ import annotations

import importlib
import typing as t
from collections import defaultdict

from jinja2 import Environment, Template, nodes
from pydantic import validator
from sqlglot import Dialect, Parser, TokenType

from sqlmesh.utils import AttributeDict
from sqlmesh.utils.pydantic import PydanticModel


def environment(**kwargs: t.Any) -> Environment:
    """Create a Jinja environment with the ``jinja2.ext.do`` extension enabled.

    Args:
        **kwargs: Forwarded to ``jinja2.Environment``. If ``extensions`` is given, it is
            copied and ``jinja2.ext.do`` is appended to the copy.

    Returns:
        A new Jinja environment.
    """
    # Copy rather than append in place so a list supplied by the caller is never mutated
    # (this also lets callers pass a tuple).
    extensions = list(kwargs.pop("extensions", []))
    extensions.append("jinja2.ext.do")
    return Environment(extensions=extensions, **kwargs)


# Shared module-level environment used by render_jinja and extract_call_names.
ENVIRONMENT = environment()


class MacroReference(PydanticModel, frozen=True):
    """A reference to a Jinja macro, optionally qualified by the package that owns it.

    Instances are frozen so they can be collected in sets (see extract_macro_references).
    """

    # The owning package; None refers to the root namespace. The explicit default keeps
    # `MacroReference(name=...)` valid under pydantic versions that treat un-defaulted
    # Optional fields as required.
    package: t.Optional[str] = None
    # The macro's name within its package or the root namespace.
    name: str

    @property
    def reference(self) -> str:
        """The reference as a string: ``name`` for root macros, ``package.name`` otherwise."""
        if self.package is None:
            return self.name
        return ".".join((self.package, self.name))

    def __str__(self) -> str:
        return self.reference


class MacroInfo(PydanticModel):
    """Class to hold a macro's definition and the macros it calls."""

    # The macro's Jinja source text.
    definition: str
    # References to the macros called from within the definition.
    depends_on: t.List[MacroReference]


class MacroReturnVal(Exception):
    """Carries a value out of a macro; ``macro_return`` catches it and returns ``value``."""

    def __init__(self, val: t.Any):
        self.value = val


def macro_return(macro: t.Callable) -> t.Callable:
    """Decorator to pass data back to the caller.

    If the wrapped callable raises ``MacroReturnVal``, the exception's ``value`` is returned
    instead; otherwise the callable's own return value is passed through.
    """

    def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
        try:
            return macro(*args, **kwargs)
        except MacroReturnVal as ret:
            return ret.value

    return wrapper


class MacroExtractor(Parser):
    """Scans a Jinja string with a sqlglot tokenizer to locate macro definitions."""

    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
        """Extract a dictionary of macro definitions from a jinja string.

        Args:
            jinja: The jinja string to extract from.
            dialect: The dialect of SQL.

        Returns:
            A dictionary of macro name to macro definition.
        """
        # Reuse the sqlglot Parser's token cursor (_curr/_next/_prev) over the input.
        self.reset()
        self.sql = jinja
        self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja)
        self._index = -1
        self._advance()

        macros: t.Dict[str, MacroInfo] = {}

        while self._curr:
            if self._curr.token_type == TokenType.BLOCK_START:
                # Remember where the most recent block opened; it becomes the start of
                # the macro's source text if this block is a macro.
                macro_start = self._curr
            elif self._tag == "MACRO" and self._next:
                # The token after the `macro` keyword is the macro's name.
                name = self._next.text
                # Skip to the end of the opening macro tag.
                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                    self._advance()

                # Skip the macro body up to the `endmacro` keyword.
                while self._curr and self._tag != "ENDMACRO":
                    self._advance()

                # Slice the source from the opening block through the token that follows
                # `endmacro` (presumably the block end).
                # NOTE(review): if `endmacro` is missing, self._next is None here — confirm
                # how _find_sql handles that.
                macro_str = self._find_sql(macro_start, self._next)
                macros[name] = MacroInfo(
                    definition=macro_str,
                    depends_on=list(extract_macro_references(macro_str)),
                )

            self._advance()

        return macros

    def _advance(self, times: int = 1) -> None:
        """Advance the cursor and record the current block tag keyword, if any.

        ``_tag`` is the current token's upper-cased text when the previous token was a
        block start (i.e. the keyword of a block tag such as ``MACRO``), else ``""``.
        """
        super()._advance(times)
        self._tag = (
            self._curr.text.upper()
            if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
            else ""
        )


def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
    """Return the dotted name of a call target as a tuple of parts.

    Names and attribute chains yield their parts (``pkg.foo`` -> ``("pkg", "foo")``);
    constants are rendered quoted; subscripts and calls are looked through; any other
    node contributes no parts, so the result may be empty.
    """
    if isinstance(node, nodes.Name):
        return (node.name,)
    if isinstance(node, nodes.Const):
        return (f"'{node.value}'",)
    if isinstance(node, nodes.Getattr):
        return call_name(node.node) + (node.attr,)
    if isinstance(node, (nodes.Getitem, nodes.Call)):
        return call_name(node.node)
    return ()


def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str:
    """Render ``query`` with the shared environment, using ``methods`` as the context.

    Args:
        query: The Jinja template string.
        methods: Variables/callables made available to the template. Defaults to none.

    Returns:
        The rendered string.
    """
    # Template.render builds its context via dict(*args), so passing None would raise
    # TypeError; fall back to an empty context instead.
    return ENVIRONMENT.from_string(query).render(methods or {})


def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[t.Tuple[str, ...]]:
    """Yield the call names found under ``node`` whose head is not a local variable.

    Loop/assignment targets and macro arguments encountered while walking are added to
    the set of local names; calls whose first part is such a name, or a string constant,
    are not reported.

    Args:
        node: The Jinja AST node to walk.
        vars_in_scope: Names already bound locally; copied, never mutated.
    """
    vars_in_scope = vars_in_scope.copy()
    for child_node in node.iter_child_nodes():
        if "target" in child_node.fields:
            target = getattr(child_node, "target")
            if isinstance(target, nodes.Name):
                vars_in_scope.add(target.name)
            elif isinstance(target, nodes.Tuple):
                for item in target.items:
                    if isinstance(item, nodes.Name):
                        vars_in_scope.add(item.name)
        elif isinstance(child_node, nodes.Macro):
            for arg in child_node.args:
                vars_in_scope.add(arg.name)
        elif isinstance(child_node, nodes.Call):
            name = call_name(child_node)
            # call_name() returns () for callees it cannot name (e.g. `(a or b)()`);
            # guard before indexing instead of raising IndexError.
            if name and not name[0].startswith("'") and name[0] not in vars_in_scope:
                yield name
        yield from find_call_names(child_node, vars_in_scope)


def extract_call_names(jinja_str: str) -> t.List[t.Tuple[str, ...]]:
    """Parse ``jinja_str`` and return the names of all non-local calls in it."""
    return list(find_call_names(ENVIRONMENT.parse(jinja_str), set()))


def extract_macro_references(jinja_str: str) -> t.Set[MacroReference]:
    """Return references for one-part (``foo()``) and two-part (``pkg.foo()``) calls.

    Longer dotted call names are ignored.
    """
    result = set()
    for parts in extract_call_names(jinja_str):
        if len(parts) == 1:
            result.add(MacroReference(name=parts[0]))
        elif len(parts) == 2:
            result.add(MacroReference(package=parts[0], name=parts[1]))
    return result


JinjaGlobalAttribute = t.Union[str, int, float, bool, AttributeDict]


class JinjaMacroRegistry(PydanticModel):
    """Registry for Jinja macros.

    Args:
        packages: The mapping from package name to a collection of macro definitions.
        root_macros: The collection of top-level macro definitions.
        global_objs: The global objects.
        create_builtins_module: The name of a module which may define the
            `create_builtin_globals` and `create_builtin_filters` factory functions that
            will be used to construct builtin variables and filters.
    """

    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    root_macros: t.Dict[str, MacroInfo] = {}
    global_objs: t.Dict[str, JinjaGlobalAttribute] = {}
    create_builtins_module: t.Optional[str] = None

    # Parsed macro templates keyed by (package, name).
    # NOTE(review): assumes PydanticModel treats underscore-prefixed attributes as
    # per-instance private attributes; otherwise this is a class-level cache shared by all
    # registries — confirm against sqlmesh.utils.pydantic.
    _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {}
    # Lazily created by the `_environment` property.
    __environment: t.Optional[Environment] = None

    @validator("global_objs", pre=True)
    def _validate_attribute_dict(cls, value: t.Any) -> t.Any:
        """Wrap nested dict values of ``global_objs`` in ``AttributeDict``."""
        if isinstance(value, dict):
            return {k: AttributeDict(v) if isinstance(v, dict) else v for k, v in value.items()}
        return value

    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
        """Adds macros to the target package.

        Args:
            macros: Macros that should be added.
            package: The name of the package the given macros belong to. If not specified, the
                provided macros will be added to the root namespace.
        """

        if package is not None:
            package_macros = self.packages.get(package, {})
            package_macros.update(macros)
            self.packages[package] = package_macros
        else:
            self.root_macros.update(macros)

    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
        """Builds a Python callable for a macro with the given reference.

        Args:
            reference: The macro reference.
            **kwargs: Extra global variables passed to `_create_builtin_globals`.

        Returns:
            The macro as a Python callable or None if not found.
        """
        if not self._macro_exists(reference.name, reference.package):
            return None

        global_vars = self._create_builtin_globals(kwargs)
        return self._make_callable(reference.name, reference.package, {}, global_vars)

    def build_environment(self, **kwargs: t.Any) -> Environment:
        """Builds a new Jinja environment based on this registry.

        Public root macros become globals; public package macros are exposed as attributes
        of a per-package ``AttributeDict`` global. Builtin globals and filters are added too.
        """

        global_vars = self._create_builtin_globals(kwargs)

        callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable] = {}

        root_macros = {
            name: self._make_callable(name, None, callable_cache, global_vars)
            for name in self.root_macros
            if not _is_private_macro(name)
        }

        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
        for package_name, macros in self.packages.items():
            for macro_name in macros:
                if not _is_private_macro(macro_name):
                    package_macros[package_name][macro_name] = self._make_callable(
                        macro_name, package_name, callable_cache, global_vars
                    )

        env = environment()
        env.globals.update(
            {
                **root_macros,
                **package_macros,
                **global_vars,
            }
        )
        env.filters.update(self._environment.filters)
        return env

    def trim(self, dependencies: t.Iterable[MacroReference]) -> JinjaMacroRegistry:
        """Trims the registry by keeping only macros with given references and their transitive dependencies.

        Args:
            dependencies: References to macros that should be kept.

        Returns:
            A new trimmed registry.
        """
        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
        for dep in dependencies:
            dependencies_by_package[dep.package].add(dep.name)

        result = JinjaMacroRegistry(
            global_objs=self.global_objs.copy(), create_builtins_module=self.create_builtins_module
        )
        for package, names in dependencies_by_package.items():
            result = result.merge(self._trim_macros(names, package))

        return result

    def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
        """Returns a copy of the registry which contains macros from both this and `other` instances.

        On name clashes, entries from `other` win.

        Args:
            other: The other registry instance.

        Returns:
            A new merged registry.
        """

        root_macros = {
            **self.root_macros,
            **other.root_macros,
        }

        packages = {}
        # dict.fromkeys deduplicates while keeping insertion order (this registry's packages
        # first), unlike a set whose order depends on string hash randomization.
        for package in dict.fromkeys([*self.packages, *other.packages]):
            packages[package] = {
                **self.packages.get(package, {}),
                **other.packages.get(package, {}),
            }

        global_objs = {
            **self.global_objs,
            **other.global_objs,
        }

        return JinjaMacroRegistry(
            packages=packages,
            root_macros=root_macros,
            global_objs=global_objs,
            create_builtins_module=self.create_builtins_module or other.create_builtins_module,
        )

    def _make_callable(
        self,
        name: str,
        package: t.Optional[str],
        callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable],
        macro_vars: t.Dict[str, t.Any],
    ) -> t.Callable:
        """Build a Python callable for the macro ``name`` in ``package``.

        Dependencies that exist in this registry are built first and exposed to the macro's
        template module: unqualified ones as variables, package-qualified ones as attributes
        of a per-package ``AttributeDict``. Results are memoized in ``callable_cache`` by
        ``(package, name)``; a cached callable is returned as-is.
        """
        cache_key = (package, name)
        if cache_key in callable_cache:
            return callable_cache[cache_key]

        macro_vars = macro_vars.copy()
        macro = self._get_macro(name, package)

        package_macros: t.Dict[str, AttributeDict] = defaultdict(AttributeDict)
        for dependency in macro.depends_on:
            # Unqualified references resolve within this macro's own package.
            dependency_package = dependency.package or package
            # Skip self-references (recursive calls, qualified or not) — building them would
            # recurse forever — and references that are not macros in this registry.
            if (dependency_package, dependency.name) == cache_key or not self._macro_exists(
                dependency.name, dependency_package
            ):
                continue

            # NOTE(review): mutual recursion between macros (A -> B -> A) is not detected
            # here and would still recurse without bound.
            upstream_callable = self._make_callable(
                dependency.name, dependency_package, callable_cache, macro_vars
            )
            if dependency.package is None:
                macro_vars[dependency.name] = upstream_callable
            else:
                package_macros[dependency.package][dependency.name] = upstream_callable

        macro_vars.update(package_macros)

        template = self._parse_macro(name, package)
        macro_callable = macro_return(
            getattr(template.make_module(vars=macro_vars), _non_private_name(name))
        )
        callable_cache[cache_key] = macro_callable
        return macro_callable

    def _parse_macro(self, name: str, package: t.Optional[str]) -> Template:
        """Return the compiled template for a macro, parsing it on first use."""
        cache_key = (package, name)
        if cache_key not in self._parser_cache:
            macro = self._get_macro(name, package)

            definition: t.Union[str, nodes.Template] = macro.definition
            if _is_private_macro(name):
                # A workaround to expose private jinja macros.
                definition = self._to_non_private_macro_def(name, macro.definition)

            self._parser_cache[cache_key] = self._environment.from_string(definition)
        return self._parser_cache[cache_key]

    @property
    def _environment(self) -> Environment:
        """Lazily created environment carrying the builtin filters, used to parse macros."""
        if self.__environment is None:
            self.__environment = environment()
            self.__environment.filters.update(self._create_builtin_filters())
        return self.__environment

    def _trim_macros(self, names: t.Set[str], package: t.Optional[str]) -> JinjaMacroRegistry:
        """Collect macros ``names`` from ``package`` plus all of their transitive dependencies.

        The dependency graph is walked breadth-first with a visited set, so recursive or
        mutually recursive macros terminate instead of recursing without bound.

        Args:
            names: Names of the macros to keep.
            package: The package the names belong to, or None for the root namespace.

        Returns:
            A new registry containing only the collected macros.
        """
        root_macros: t.Dict[str, MacroInfo] = {}
        packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
        if package is not None:
            # The requested package is always present, even if none of its names are found.
            packages[package] = {}

        visited: t.Set[t.Tuple[t.Optional[str], str]] = set()
        frontier = [(package, name) for name in names]
        while frontier:
            next_frontier: t.List[t.Tuple[t.Optional[str], str]] = []
            for macro_package, macro_name in frontier:
                if (macro_package, macro_name) in visited:
                    continue
                visited.add((macro_package, macro_name))

                if macro_package is not None:
                    source = self.packages.get(macro_package, {})
                    target = packages.setdefault(macro_package, {})
                else:
                    source = self.root_macros
                    target = root_macros

                if macro_name not in source:
                    continue
                macro = source[macro_name]
                target[macro_name] = macro
                for dependency in macro.depends_on:
                    # Unqualified dependencies resolve within the current package.
                    next_frontier.append((dependency.package or macro_package, dependency.name))
            frontier = next_frontier

        return JinjaMacroRegistry(packages=packages, root_macros=root_macros)

    def _macro_exists(self, name: str, package: t.Optional[str]) -> bool:
        """Return True if ``name`` is defined in ``package`` (or the root namespace if None)."""
        return (
            name in self.packages.get(package, {})
            if package is not None
            else name in self.root_macros
        )

    def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo:
        """Return the macro definition; raises KeyError if it does not exist."""
        return self.packages[package][name] if package is not None else self.root_macros[name]

    def _to_non_private_macro_def(self, name: str, definition: str) -> nodes.Template:
        """Parse a private macro's definition and rename it to its non-private name.

        The macro node is renamed, as are calls whose callee is exactly ``name`` (recursive
        calls), so they keep resolving to the renamed macro. Calls to any other name — other
        macros, ``caller()``, builtins — are left untouched.
        """
        template = self._environment.parse(definition)
        non_private_name = _non_private_name(name)

        for node in template.find_all((nodes.Macro, nodes.Call)):
            if isinstance(node, nodes.Macro):
                node.name = non_private_name
            elif (
                isinstance(node, nodes.Call)
                and isinstance(node.node, nodes.Name)
                and node.node.name == name
            ):
                node.node.name = non_private_name

        return template

    def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin globals using a factory function defined in the provided module.

        ``engine_adapter`` is removed from ``global_vars`` and passed to the factory
        separately; the remaining values override ``global_objs``.
        """
        engine_adapter = global_vars.pop("engine_adapter", None)
        global_vars = {**self.global_objs, **global_vars}
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_globals"):
                return module.create_builtin_globals(self, global_vars, engine_adapter)
        return global_vars

    def _create_builtin_filters(self) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin filters using a factory function defined in the provided module."""
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_filters"):
                return module.create_builtin_filters()
        return {}


def _is_private_macro(name: str) -> bool:
    """Return True for macro names that start with an underscore."""
    return name.startswith("_")


def _non_private_name(name: str) -> str:
    """Strip all leading underscores from a macro name."""
    return name.lstrip("_")
class MacroReference(PydanticModel, frozen=True):
    """A reference to a Jinja macro, optionally qualified by the package that owns it.

    Instances are frozen so they can be collected in sets.
    """

    # The owning package; None refers to the root namespace. The explicit default keeps
    # `MacroReference(name=...)` valid under pydantic versions that treat un-defaulted
    # Optional fields as required.
    package: t.Optional[str] = None
    # The macro's name within its package or the root namespace.
    name: str

    @property
    def reference(self) -> str:
        """The reference as a string: ``name`` for root macros, ``package.name`` otherwise."""
        if self.package is None:
            return self.name
        return ".".join((self.package, self.name))

    def __str__(self) -> str:
        return self.reference
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs
class MacroInfo(PydanticModel):
    """Class to hold a macro's definition and the macros it calls."""

    # The macro's Jinja source text.
    definition: str
    # References to the macros called from within the definition.
    depends_on: t.List[MacroReference]
Class to hold a macro's definition and the macros it calls.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs
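Exception used to pass a macro's return value back to the caller: `macro_return` catches it and returns its `value`. (Inherits from `Exception`, the common base class for all non-exit exceptions.)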
Inherited Members
- builtins.BaseException
- with_traceback
def macro_return(macro: t.Callable) -> t.Callable:
    """Wrap ``macro`` so a ``MacroReturnVal`` raised inside it becomes its return value.

    Args:
        macro: The callable to wrap.

    Returns:
        A callable taking the same arguments. It returns ``macro``'s result, or the
        ``value`` carried by a ``MacroReturnVal`` that ``macro`` raised.
    """

    def _call_with_early_return(*args: t.Any, **kwargs: t.Any) -> t.Any:
        try:
            result = macro(*args, **kwargs)
        except MacroReturnVal as early_return:
            result = early_return.value
        return result

    return _call_with_early_return
Decorator to pass data back to the caller: if the wrapped macro raises `MacroReturnVal`, the exception's `value` is returned instead.
class MacroExtractor(Parser):
    """Scans a Jinja string with a sqlglot tokenizer to locate macro definitions."""

    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
        """Extract a dictionary of macro definitions from a jinja string.

        Args:
            jinja: The jinja string to extract from.
            dialect: The dialect of SQL.

        Returns:
            A dictionary of macro name to macro definition.
        """
        # Reuse the sqlglot Parser's token cursor (_curr/_next/_prev) over the input.
        self.reset()
        self.sql = jinja
        self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja)
        self._index = -1
        self._advance()

        macros: t.Dict[str, MacroInfo] = {}

        while self._curr:
            if self._curr.token_type == TokenType.BLOCK_START:
                # Remember where the most recent block opened; it becomes the start of
                # the macro's source text if this block is a macro.
                macro_start = self._curr
            elif self._tag == "MACRO" and self._next:
                # The token after the `macro` keyword is the macro's name.
                name = self._next.text
                # Skip to the end of the opening macro tag.
                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                    self._advance()

                # Skip the macro body up to the `endmacro` keyword.
                while self._curr and self._tag != "ENDMACRO":
                    self._advance()

                # Slice the source from the opening block through the token that follows
                # `endmacro` (presumably the block end).
                # NOTE(review): if `endmacro` is missing, self._next is None here — confirm
                # how _find_sql handles that.
                macro_str = self._find_sql(macro_start, self._next)
                macros[name] = MacroInfo(
                    definition=macro_str,
                    depends_on=list(extract_macro_references(macro_str)),
                )

            self._advance()

        return macros

    def _advance(self, times: int = 1) -> None:
        """Advance the cursor and record the current block tag keyword, if any.

        ``_tag`` is the current token's upper-cased text when the previous token was a
        block start (i.e. the keyword of a block tag such as ``MACRO``), else ``""``.
        """
        super()._advance(times)
        self._tag = (
            self._curr.text.upper()
            if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
            else ""
        )
Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces a parsed syntax tree.
Arguments:
- error_level: the desired error level. Default: ErrorLevel.RAISE
- error_message_context: determines the amount of context to capture from a query string when displaying the error message (in number of characters). Default: 50.
- index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list. Default: 0
- alias_post_tablesample: If the table alias comes after tablesample. Default: False
- max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3
- null_ordering: Indicates the default null ordering method to use if not explicitly set. Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". Default: "nulls_are_small"
def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
    """Extract a dictionary of macro definitions from a jinja string.

    Args:
        jinja: The jinja string to extract from.
        dialect: The dialect of SQL.

    Returns:
        A dictionary of macro name to macro definition.
    """
    # Reuse the sqlglot Parser's token cursor (_curr/_next/_prev) over the input.
    self.reset()
    self.sql = jinja
    self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja)
    self._index = -1
    self._advance()

    macros: t.Dict[str, MacroInfo] = {}

    while self._curr:
        if self._curr.token_type == TokenType.BLOCK_START:
            # Remember where the most recent block opened; it becomes the start of
            # the macro's source text if this block is a macro.
            macro_start = self._curr
        elif self._tag == "MACRO" and self._next:
            # The token after the `macro` keyword is the macro's name.
            name = self._next.text
            # Skip to the end of the opening macro tag.
            while self._curr and self._curr.token_type != TokenType.BLOCK_END:
                self._advance()

            # Skip the macro body up to the `endmacro` keyword.
            while self._curr and self._tag != "ENDMACRO":
                self._advance()

            # Slice the source from the opening block through the token that follows
            # `endmacro` (presumably the block end).
            # NOTE(review): if `endmacro` is missing, self._next is None here — confirm
            # how _find_sql handles that.
            macro_str = self._find_sql(macro_start, self._next)
            macros[name] = MacroInfo(
                definition=macro_str,
                depends_on=list(extract_macro_references(macro_str)),
            )

        self._advance()

    return macros
Extract a dictionary of macro definitions from a jinja string.
Arguments:
- jinja: The jinja string to extract from.
- dialect: The dialect of SQL.
Returns:
A dictionary of macro name to macro definition.
Inherited Members
- sqlglot.parser.Parser
- Parser
- reset
- parse
- parse_into
- check_errors
- raise_error
- expression
- validate_expression
def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
    """Return the dotted name of a call target as a tuple of parts.

    Names and attribute chains yield their parts (``pkg.foo`` -> ``("pkg", "foo")``);
    constants are rendered quoted (``'abc'.upper`` -> ``("'abc'", "upper")``);
    subscripts and calls are looked through; any other node contributes no parts.
    """
    attributes: t.List[str] = []
    current = node
    while True:
        if isinstance(current, nodes.Name):
            head: t.Tuple[str, ...] = (current.name,)
            break
        if isinstance(current, nodes.Const):
            head = (f"'{current.value}'",)
            break
        if isinstance(current, nodes.Getattr):
            attributes.append(current.attr)
            current = current.node
        elif isinstance(current, (nodes.Getitem, nodes.Call)):
            current = current.node
        else:
            head = ()
            break
    # Attributes were collected from the outermost access inwards, so reverse them.
    return head + tuple(reversed(attributes))
def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[t.Tuple[str, ...]]:
    """Yield the call names found under ``node`` whose head is not a local variable.

    Loop/assignment targets and macro arguments encountered while walking are added to
    the set of local names; calls whose first part is such a name, or a string constant,
    are not reported.

    Args:
        node: The Jinja AST node to walk.
        vars_in_scope: Names already bound locally; copied, never mutated.
    """
    vars_in_scope = vars_in_scope.copy()
    for child_node in node.iter_child_nodes():
        if "target" in child_node.fields:
            target = getattr(child_node, "target")
            if isinstance(target, nodes.Name):
                vars_in_scope.add(target.name)
            elif isinstance(target, nodes.Tuple):
                for item in target.items:
                    if isinstance(item, nodes.Name):
                        vars_in_scope.add(item.name)
        elif isinstance(child_node, nodes.Macro):
            for arg in child_node.args:
                vars_in_scope.add(arg.name)
        elif isinstance(child_node, nodes.Call):
            name = call_name(child_node)
            # call_name() returns () for callees it cannot name (e.g. `(a or b)()`);
            # guard before indexing instead of raising IndexError.
            if name and not name[0].startswith("'") and name[0] not in vars_in_scope:
                yield name
        yield from find_call_names(child_node, vars_in_scope)
def extract_macro_references(jinja_str: str) -> t.Set[MacroReference]:
    """Return the macro references implied by the calls in ``jinja_str``.

    A one-part call name (``foo()``) becomes an unqualified reference and a two-part
    one (``pkg.foo()``) a package-qualified reference; longer dotted names are skipped.
    """
    references: t.Set[MacroReference] = set()
    for parts in extract_call_names(jinja_str):
        if len(parts) == 1:
            (macro_name,) = parts
            references.add(MacroReference(name=macro_name))
        elif len(parts) == 2:
            package_name, macro_name = parts
            references.add(MacroReference(package=package_name, name=macro_name))
    return references
class JinjaMacroRegistry(PydanticModel):
    """Registry for Jinja macros.

    Args:
        packages: The mapping from package name to a collection of macro definitions.
        root_macros: The collection of top-level macro definitions.
        global_objs: The global objects.
        create_builtins_module: The name of a module which may define the
            `create_builtin_globals` and `create_builtin_filters` factory functions that
            will be used to construct builtin variables and filters.
    """

    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    root_macros: t.Dict[str, MacroInfo] = {}
    global_objs: t.Dict[str, JinjaGlobalAttribute] = {}
    create_builtins_module: t.Optional[str] = None

    # Parsed macro templates keyed by (package, name).
    # NOTE(review): assumes PydanticModel treats underscore-prefixed attributes as
    # per-instance private attributes; otherwise this is a class-level cache shared by all
    # registries — confirm against sqlmesh.utils.pydantic.
    _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {}
    # Lazily created by the `_environment` property.
    __environment: t.Optional[Environment] = None

    @validator("global_objs", pre=True)
    def _validate_attribute_dict(cls, value: t.Any) -> t.Any:
        """Wrap nested dict values of ``global_objs`` in ``AttributeDict``."""
        if isinstance(value, dict):
            return {k: AttributeDict(v) if isinstance(v, dict) else v for k, v in value.items()}
        return value

    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
        """Adds macros to the target package.

        Args:
            macros: Macros that should be added.
            package: The name of the package the given macros belong to. If not specified, the
                provided macros will be added to the root namespace.
        """

        if package is not None:
            package_macros = self.packages.get(package, {})
            package_macros.update(macros)
            self.packages[package] = package_macros
        else:
            self.root_macros.update(macros)

    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
        """Builds a Python callable for a macro with the given reference.

        Args:
            reference: The macro reference.
            **kwargs: Extra global variables passed to `_create_builtin_globals`.

        Returns:
            The macro as a Python callable or None if not found.
        """
        if not self._macro_exists(reference.name, reference.package):
            return None

        global_vars = self._create_builtin_globals(kwargs)
        return self._make_callable(reference.name, reference.package, {}, global_vars)

    def build_environment(self, **kwargs: t.Any) -> Environment:
        """Builds a new Jinja environment based on this registry.

        Public root macros become globals; public package macros are exposed as attributes
        of a per-package ``AttributeDict`` global. Builtin globals and filters are added too.
        """

        global_vars = self._create_builtin_globals(kwargs)

        callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable] = {}

        root_macros = {
            name: self._make_callable(name, None, callable_cache, global_vars)
            for name in self.root_macros
            if not _is_private_macro(name)
        }

        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
        for package_name, macros in self.packages.items():
            for macro_name in macros:
                if not _is_private_macro(macro_name):
                    package_macros[package_name][macro_name] = self._make_callable(
                        macro_name, package_name, callable_cache, global_vars
                    )

        env = environment()
        env.globals.update(
            {
                **root_macros,
                **package_macros,
                **global_vars,
            }
        )
        env.filters.update(self._environment.filters)
        return env

    def trim(self, dependencies: t.Iterable[MacroReference]) -> JinjaMacroRegistry:
        """Trims the registry by keeping only macros with given references and their transitive dependencies.

        Args:
            dependencies: References to macros that should be kept.

        Returns:
            A new trimmed registry.
        """
        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
        for dep in dependencies:
            dependencies_by_package[dep.package].add(dep.name)

        result = JinjaMacroRegistry(
            global_objs=self.global_objs.copy(), create_builtins_module=self.create_builtins_module
        )
        for package, names in dependencies_by_package.items():
            result = result.merge(self._trim_macros(names, package))

        return result

    def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
        """Returns a copy of the registry which contains macros from both this and `other` instances.

        On name clashes, entries from `other` win.

        Args:
            other: The other registry instance.

        Returns:
            A new merged registry.
        """

        root_macros = {
            **self.root_macros,
            **other.root_macros,
        }

        packages = {}
        # dict.fromkeys deduplicates while keeping insertion order (this registry's packages
        # first), unlike a set whose order depends on string hash randomization.
        for package in dict.fromkeys([*self.packages, *other.packages]):
            packages[package] = {
                **self.packages.get(package, {}),
                **other.packages.get(package, {}),
            }

        global_objs = {
            **self.global_objs,
            **other.global_objs,
        }

        return JinjaMacroRegistry(
            packages=packages,
            root_macros=root_macros,
            global_objs=global_objs,
            create_builtins_module=self.create_builtins_module or other.create_builtins_module,
        )

    def _make_callable(
        self,
        name: str,
        package: t.Optional[str],
        callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable],
        macro_vars: t.Dict[str, t.Any],
    ) -> t.Callable:
        """Build a Python callable for the macro ``name`` in ``package``.

        Dependencies that exist in this registry are built first and exposed to the macro's
        template module: unqualified ones as variables, package-qualified ones as attributes
        of a per-package ``AttributeDict``. Results are memoized in ``callable_cache`` by
        ``(package, name)``; a cached callable is returned as-is.
        """
        cache_key = (package, name)
        if cache_key in callable_cache:
            return callable_cache[cache_key]

        macro_vars = macro_vars.copy()
        macro = self._get_macro(name, package)

        package_macros: t.Dict[str, AttributeDict] = defaultdict(AttributeDict)
        for dependency in macro.depends_on:
            # Unqualified references resolve within this macro's own package.
            dependency_package = dependency.package or package
            # Skip self-references (recursive calls, qualified or not) — building them would
            # recurse forever — and references that are not macros in this registry.
            if (dependency_package, dependency.name) == cache_key or not self._macro_exists(
                dependency.name, dependency_package
            ):
                continue

            # NOTE(review): mutual recursion between macros (A -> B -> A) is not detected
            # here and would still recurse without bound.
            upstream_callable = self._make_callable(
                dependency.name, dependency_package, callable_cache, macro_vars
            )
            if dependency.package is None:
                macro_vars[dependency.name] = upstream_callable
            else:
                package_macros[dependency.package][dependency.name] = upstream_callable

        macro_vars.update(package_macros)

        template = self._parse_macro(name, package)
        macro_callable = macro_return(
            getattr(template.make_module(vars=macro_vars), _non_private_name(name))
        )
        callable_cache[cache_key] = macro_callable
        return macro_callable

    def _parse_macro(self, name: str, package: t.Optional[str]) -> Template:
        """Return the compiled template for a macro, parsing it on first use."""
        cache_key = (package, name)
        if cache_key not in self._parser_cache:
            macro = self._get_macro(name, package)

            definition: t.Union[str, nodes.Template] = macro.definition
            if _is_private_macro(name):
                # A workaround to expose private jinja macros.
                definition = self._to_non_private_macro_def(name, macro.definition)

            self._parser_cache[cache_key] = self._environment.from_string(definition)
        return self._parser_cache[cache_key]

    @property
    def _environment(self) -> Environment:
        """Lazily created environment carrying the builtin filters, used to parse macros."""
        if self.__environment is None:
            self.__environment = environment()
            self.__environment.filters.update(self._create_builtin_filters())
        return self.__environment

    def _trim_macros(self, names: t.Set[str], package: t.Optional[str]) -> JinjaMacroRegistry:
        """Collect macros ``names`` from ``package`` plus all of their transitive dependencies.

        The dependency graph is walked breadth-first with a visited set, so recursive or
        mutually recursive macros terminate instead of recursing without bound.

        Args:
            names: Names of the macros to keep.
            package: The package the names belong to, or None for the root namespace.

        Returns:
            A new registry containing only the collected macros.
        """
        root_macros: t.Dict[str, MacroInfo] = {}
        packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
        if package is not None:
            # The requested package is always present, even if none of its names are found.
            packages[package] = {}

        visited: t.Set[t.Tuple[t.Optional[str], str]] = set()
        frontier = [(package, name) for name in names]
        while frontier:
            next_frontier: t.List[t.Tuple[t.Optional[str], str]] = []
            for macro_package, macro_name in frontier:
                if (macro_package, macro_name) in visited:
                    continue
                visited.add((macro_package, macro_name))

                if macro_package is not None:
                    source = self.packages.get(macro_package, {})
                    target = packages.setdefault(macro_package, {})
                else:
                    source = self.root_macros
                    target = root_macros

                if macro_name not in source:
                    continue
                macro = source[macro_name]
                target[macro_name] = macro
                for dependency in macro.depends_on:
                    # Unqualified dependencies resolve within the current package.
                    next_frontier.append((dependency.package or macro_package, dependency.name))
            frontier = next_frontier

        return JinjaMacroRegistry(packages=packages, root_macros=root_macros)

    def _macro_exists(self, name: str, package: t.Optional[str]) -> bool:
        """Return True if ``name`` is defined in ``package`` (or the root namespace if None)."""
        return (
            name in self.packages.get(package, {})
            if package is not None
            else name in self.root_macros
        )

    def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo:
        """Return the macro definition; raises KeyError if it does not exist."""
        return self.packages[package][name] if package is not None else self.root_macros[name]

    def _to_non_private_macro_def(self, name: str, definition: str) -> nodes.Template:
        """Parse a private macro's definition and rename it to its non-private name.

        The macro node is renamed, as are calls whose callee is exactly ``name`` (recursive
        calls), so they keep resolving to the renamed macro. Calls to any other name — other
        macros, ``caller()``, builtins — are left untouched.
        """
        template = self._environment.parse(definition)
        non_private_name = _non_private_name(name)

        for node in template.find_all((nodes.Macro, nodes.Call)):
            if isinstance(node, nodes.Macro):
                node.name = non_private_name
            elif (
                isinstance(node, nodes.Call)
                and isinstance(node.node, nodes.Name)
                and node.node.name == name
            ):
                node.node.name = non_private_name

        return template

    def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin globals using a factory function defined in the provided module.

        ``engine_adapter`` is removed from ``global_vars`` and passed to the factory
        separately; the remaining values override ``global_objs``.
        """
        engine_adapter = global_vars.pop("engine_adapter", None)
        global_vars = {**self.global_objs, **global_vars}
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_globals"):
                return module.create_builtin_globals(self, global_vars, engine_adapter)
        return global_vars

    def _create_builtin_filters(self) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin filters using a factory function defined in the provided module."""
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_filters"):
                return module.create_builtin_filters()
        return {}
Registry for Jinja macros.
Arguments:
- packages: The mapping from package name to a collection of macro definitions.
- root_macros: The collection of top-level macro definitions.
- global_objs: The global objects.
- create_builtins_module: The name of a module which may define the `create_builtin_globals` and `create_builtin_filters` factory functions that will be used to construct builtin variables and filters.
def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
    """Register ``macros`` in ``package``, or in the root namespace when no package is given.

    Same-named existing macros are overwritten; all other existing macros are kept.

    Args:
        macros: Macros that should be added.
        package: The name of the package the given macros belong to. If not specified, the
            provided macros will be added to the root namespace.
    """
    if package is None:
        self.root_macros.update(macros)
        return

    # Reuse (and update in place) the package's existing mapping when there is one.
    existing = self.packages.get(package, {})
    existing.update(macros)
    self.packages[package] = existing
Adds macros to the target package.
Arguments:
- macros: Macros that should be added.
- package: The name of the package the given macros belong to. If not specified, the provided macros will be added to the root namespace.
def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
    """Build a Python callable for the macro identified by ``reference``.

    Args:
        reference: The macro reference.
        **kwargs: Extra global variables passed to `_create_builtin_globals`.

    Returns:
        The macro as a Python callable, or None when the registry has no such macro.
    """
    if not self._macro_exists(reference.name, reference.package):
        return None

    builtin_globals = self._create_builtin_globals(kwargs)
    return self._make_callable(reference.name, reference.package, {}, builtin_globals)
Builds a Python callable for a macro with the given reference.
Arguments:
- reference: The macro reference.
Returns:
The macro as a Python callable or None if not found.
def build_environment(self, **kwargs: t.Any) -> Environment:
    """Build a fresh Jinja environment exposing every public macro in this registry.

    Public root macros become globals; public package macros are exposed as attributes
    of a per-package ``AttributeDict`` global. Builtin globals and filters are added too.
    """
    builtin_globals = self._create_builtin_globals(kwargs)
    # One cache for the whole environment so each macro's callable is built only once.
    cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable] = {}

    exposed_root: t.Dict[str, t.Any] = {}
    for macro_name in self.root_macros:
        if _is_private_macro(macro_name):
            continue
        exposed_root[macro_name] = self._make_callable(macro_name, None, cache, builtin_globals)

    by_package: t.Dict[str, t.Any] = defaultdict(AttributeDict)
    for package_name, package_contents in self.packages.items():
        for macro_name in package_contents:
            if _is_private_macro(macro_name):
                continue
            by_package[package_name][macro_name] = self._make_callable(
                macro_name, package_name, cache, builtin_globals
            )

    env = environment()
    env.globals.update({**exposed_root, **by_package, **builtin_globals})
    env.filters.update(self._environment.filters)
    return env
Builds a new Jinja environment based on this registry.
def trim(self, dependencies: t.Iterable[MacroReference]) -> JinjaMacroRegistry:
    """Return a registry holding only the referenced macros and their transitive dependencies.

    Global objects and the builtins module are carried over unchanged.

    Args:
        dependencies: References to macros that should be kept.

    Returns:
        A new trimmed registry.
    """
    names_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
    for reference in dependencies:
        names_by_package[reference.package].add(reference.name)

    trimmed = JinjaMacroRegistry(
        global_objs=self.global_objs.copy(),
        create_builtins_module=self.create_builtins_module,
    )
    for package_name, macro_names in names_by_package.items():
        trimmed = trimmed.merge(self._trim_macros(macro_names, package_name))
    return trimmed
Trims the registry by keeping only macros with given references and their transitive dependencies.
Arguments:
- dependencies: References to macros that should be kept.
Returns:
A new trimmed registry.
def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
    """Returns a copy of the registry which contains macros from both this and `other` instances.

    On name clashes, entries from `other` win.

    Args:
        other: The other registry instance.

    Returns:
        A new merged registry.
    """

    root_macros = {
        **self.root_macros,
        **other.root_macros,
    }

    packages = {}
    # dict.fromkeys deduplicates while keeping insertion order (this registry's packages
    # first), unlike a set whose order depends on string hash randomization.
    for package in dict.fromkeys([*self.packages, *other.packages]):
        packages[package] = {
            **self.packages.get(package, {}),
            **other.packages.get(package, {}),
        }

    global_objs = {
        **self.global_objs,
        **other.global_objs,
    }

    return JinjaMacroRegistry(
        packages=packages,
        root_macros=root_macros,
        global_objs=global_objs,
        create_builtins_module=self.create_builtins_module or other.create_builtins_module,
    )
Returns a copy of the registry which contains macros from both this and `other` instances.
Arguments:
- other: The other registry instance.
Returns:
A new merged registry.
Inherited Members
- pydantic.main.BaseModel
- BaseModel
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs