Edit on GitHub

sqlmesh.utils.jinja

  1from __future__ import annotations
  2
  3import importlib
  4import typing as t
  5from collections import defaultdict
  6
  7from jinja2 import Environment, Template, nodes
  8from pydantic import validator
  9from sqlglot import Dialect, Parser, TokenType
 10
 11from sqlmesh.utils import AttributeDict
 12from sqlmesh.utils.pydantic import PydanticModel
 13
 14
 15def environment(**kwargs: t.Any) -> Environment:
 16    extensions = kwargs.pop("extensions", [])
 17    extensions.append("jinja2.ext.do")
 18    return Environment(extensions=extensions, **kwargs)
 19
 20
 21ENVIRONMENT = environment()
 22
 23
 24class MacroReference(PydanticModel, frozen=True):
 25    package: t.Optional[str]
 26    name: str
 27
 28    @property
 29    def reference(self) -> str:
 30        if self.package is None:
 31            return self.name
 32        return ".".join((self.package, self.name))
 33
 34    def __str__(self) -> str:
 35        return self.reference
 36
 37
 38class MacroInfo(PydanticModel):
 39    """Class to hold macro and its calls"""
 40
 41    definition: str
 42    depends_on: t.List[MacroReference]
 43
 44
 45class MacroReturnVal(Exception):
 46    def __init__(self, val: t.Any):
 47        self.value = val
 48
 49
 50def macro_return(macro: t.Callable) -> t.Callable:
 51    """Decorator to pass data back to the caller"""
 52
 53    def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
 54        try:
 55            return macro(*args, **kwargs)
 56        except MacroReturnVal as ret:
 57            return ret.value
 58
 59    return wrapper
 60
 61
 62class MacroExtractor(Parser):
 63    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
 64        """Extract a dictionary of macro definitions from a jinja string.
 65
 66        Args:
 67            jinja: The jinja string to extract from.
 68            dialect: The dialect of SQL.
 69
 70        Returns:
 71            A dictionary of macro name to macro definition.
 72        """
 73        self.reset()
 74        self.sql = jinja
 75        self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja)
 76        self._index = -1
 77        self._advance()
 78
 79        macros: t.Dict[str, MacroInfo] = {}
 80
 81        while self._curr:
 82            if self._curr.token_type == TokenType.BLOCK_START:
 83                macro_start = self._curr
 84            elif self._tag == "MACRO" and self._next:
 85                name = self._next.text
 86                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
 87                    self._advance()
 88
 89                while self._curr and self._tag != "ENDMACRO":
 90                    self._advance()
 91
 92                macro_str = self._find_sql(macro_start, self._next)
 93                macros[name] = MacroInfo(
 94                    definition=macro_str,
 95                    depends_on=list(extract_macro_references(macro_str)),
 96                )
 97
 98            self._advance()
 99
100        return macros
101
102    def _advance(self, times: int = 1) -> None:
103        super()._advance(times)
104        self._tag = (
105            self._curr.text.upper()
106            if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
107            else ""
108        )
109
110
def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
    """Return the dotted name of a callee expression as a tuple of parts.

    Attribute access contributes a part, while subscripts and calls are looked
    through. A `Name` root contributes its identifier and a `Const` root its value
    wrapped in single quotes. Any other root contributes nothing, so the result
    may be an empty tuple.
    """
    suffix: t.List[str] = []
    current = node
    while True:
        if isinstance(current, nodes.Getattr):
            suffix.append(current.attr)
            current = current.node
        elif isinstance(current, (nodes.Getitem, nodes.Call)):
            current = current.node
        elif isinstance(current, nodes.Name):
            return (current.name, *reversed(suffix))
        elif isinstance(current, nodes.Const):
            return (f"'{current.value}'", *reversed(suffix))
        else:
            return tuple(reversed(suffix))
121
122
def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str:
    """Render `query` as a Jinja template using the module-level environment.

    Args:
        query: The Jinja template source.
        methods: Variables and callables exposed to the template. Defaults to none.

    Returns:
        The rendered string.
    """
    # `Template.render(None)` raises TypeError (Jinja builds `dict(None)`), so the
    # default must be translated into an empty mapping.
    return ENVIRONMENT.from_string(query).render(methods or {})
125
126
def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[t.Tuple[str, ...]]:
    """Yield the names of calls under `node` that may refer to macros.

    Two kinds of call are skipped: calls rooted at a literal, and calls whose root
    name is bound locally (by a `target`, such as a for-loop or assignment variable,
    or by a macro argument).

    Args:
        node: The Jinja AST node to search.
        vars_in_scope: Names already bound in the enclosing scope. Not mutated.

    Yields:
        Name tuples as produced by `call_name`, e.g. ("package", "macro").
    """
    # Work on a copy so bindings found under this node do not leak to the caller.
    # Siblings share this copy, so a binding is visible to later siblings.
    vars_in_scope = vars_in_scope.copy()
    for child_node in node.iter_child_nodes():
        if "target" in child_node.fields:
            target = getattr(child_node, "target")
            if isinstance(target, nodes.Name):
                vars_in_scope.add(target.name)
            elif isinstance(target, nodes.Tuple):
                for item in target.items:
                    if isinstance(item, nodes.Name):
                        vars_in_scope.add(item.name)
        elif isinstance(child_node, nodes.Macro):
            for arg in child_node.args:
                vars_in_scope.add(arg.name)
        elif isinstance(child_node, nodes.Call):
            name = call_name(child_node)
            # `call_name` returns () for callees it cannot name (e.g. `(x|f)()`), so
            # guard before indexing. Names starting with a quote come from literals.
            if name and not name[0].startswith("'") and name[0] not in vars_in_scope:
                yield name
        yield from find_call_names(child_node, vars_in_scope)
146
147
def extract_call_names(jinja_str: str) -> t.List[t.Tuple[str, ...]]:
    """Parse `jinja_str` and list, in traversal order, every candidate macro call name."""
    ast = ENVIRONMENT.parse(jinja_str)
    names = find_call_names(ast, set())
    return list(names)
150
151
def extract_macro_references(jinja_str: str) -> t.Set[MacroReference]:
    """Collect references to the macros called in `jinja_str`.

    Only the `macro()` and `package.macro()` call shapes become references; longer
    attribute chains are ignored.
    """
    references: t.Set[MacroReference] = set()
    for parts in extract_call_names(jinja_str):
        if len(parts) == 1:
            references.add(MacroReference(name=parts[0]))
        elif len(parts) == 2:
            package_name, macro_name = parts
            references.add(MacroReference(package=package_name, name=macro_name))
    return references
160
161
# Value types accepted in `JinjaMacroRegistry.global_objs`.
JinjaGlobalAttribute = t.Union[str, int, float, bool, AttributeDict]


class JinjaMacroRegistry(PydanticModel):
    """Registry for Jinja macros.

    Args:
        packages: The mapping from package name to a collection of macro definitions.
        root_macros: The collection of top-level macro definitions.
        global_objs: The global objects.
        create_builtins_module: The name of a module which may define the
            `create_builtin_globals` and `create_builtin_filters` factory functions used to
            construct builtin variables, functions and filters.
    """

    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    root_macros: t.Dict[str, MacroInfo] = {}
    global_objs: t.Dict[str, JinjaGlobalAttribute] = {}
    create_builtins_module: t.Optional[str] = None

    # Parsed macro templates keyed by (package, macro name).
    _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {}
    # Created lazily by the `_environment` property.
    __environment: t.Optional[Environment] = None

    @validator("global_objs", pre=True)
    def _validate_attribute_dict(cls, value: t.Any) -> t.Any:
        # Wrap nested plain dicts in AttributeDict (presumably so templates can use
        # attribute-style access on them).
        if isinstance(value, dict):
            return {k: AttributeDict(v) if isinstance(v, dict) else v for k, v in value.items()}
        return value

    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
        """Adds macros to the target package, replacing existing macros with the same names.

        Args:
            macros: Macros that should be added.
            package: The name of the package the given macros belong to. If not specified, the
                provided macros will be added to the root namespace.
        """
        if package is not None:
            package_macros = self.packages.get(package, {})
            package_macros.update(macros)
            self.packages[package] = package_macros
        else:
            self.root_macros.update(macros)

        # Templates parsed from earlier definitions of these macros are now stale; drop
        # them so `_parse_macro` parses the new definitions on the next build.
        for name in macros:
            self._parser_cache.pop((package, name), None)

    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
        """Builds a Python callable for a macro with the given reference.

        Args:
            reference: The macro reference.
            **kwargs: Extra global variables. An `engine_adapter` entry is handed to the
                builtins factory separately (see `_create_builtin_globals`).

        Returns:
            The macro as a Python callable or None if not found.
        """
        if not self._macro_exists(reference.name, reference.package):
            return None

        global_vars = self._create_builtin_globals(kwargs)
        return self._make_callable(reference.name, reference.package, {}, global_vars)

    def build_environment(self, **kwargs: t.Any) -> Environment:
        """Builds a new Jinja environment based on this registry.

        The environment exposes three groups of globals:
        - non-private root macros, by name;
        - non-private package macros, under their package name;
        - the builtin globals, which win on name clashes.
        The builtin filters are copied over as well.
        """
        global_vars = self._create_builtin_globals(kwargs)

        # Shared by all macros so each macro's callable is built at most once.
        callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable] = {}

        root_macros = {
            name: self._make_callable(name, None, callable_cache, global_vars)
            for name in self.root_macros
            if not _is_private_macro(name)
        }

        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
        for package_name, macros in self.packages.items():
            for macro_name in macros:
                if not _is_private_macro(macro_name):
                    package_macros[package_name][macro_name] = self._make_callable(
                        macro_name, package_name, callable_cache, global_vars
                    )

        env = environment()
        env.globals.update(
            {
                **root_macros,
                **package_macros,
                **global_vars,
            }
        )
        env.filters.update(self._environment.filters)
        return env

    def trim(self, dependencies: t.Iterable[MacroReference]) -> JinjaMacroRegistry:
        """Trims the registry by keeping only macros with given references and their transitive dependencies.

        Args:
            dependencies: References to macros that should be kept.

        Returns:
            A new trimmed registry.
        """
        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
        for dep in dependencies:
            dependencies_by_package[dep.package].add(dep.name)

        result = JinjaMacroRegistry(
            global_objs=self.global_objs.copy(), create_builtins_module=self.create_builtins_module
        )
        # Shared across packages: every visited macro is already merged into `result`.
        visited: t.Set[t.Tuple[t.Optional[str], str]] = set()
        for package, names in dependencies_by_package.items():
            result = result.merge(self._trim_macros(names, package, visited))

        return result

    def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
        """Returns a copy of the registry which contains macros from both this and `other` instances.

        Entries from `other` win on name clashes.

        Args:
            other: The other registry instance.

        Returns:
            A new merged registry.
        """
        root_macros = {
            **self.root_macros,
            **other.root_macros,
        }

        packages = {
            package: {
                **self.packages.get(package, {}),
                **other.packages.get(package, {}),
            }
            for package in {*self.packages, *other.packages}
        }

        global_objs = {
            **self.global_objs,
            **other.global_objs,
        }

        return JinjaMacroRegistry(
            packages=packages,
            root_macros=root_macros,
            global_objs=global_objs,
            create_builtins_module=self.create_builtins_module or other.create_builtins_module,
        )

    def _make_callable(
        self,
        name: str,
        package: t.Optional[str],
        callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable],
        macro_vars: t.Dict[str, t.Any],
    ) -> t.Callable:
        """Build, or reuse from `callable_cache`, a Python callable for the given macro.

        Dependencies are built first and injected into the template module's
        variables. Unqualified dependencies are injected by name; package-qualified
        ones go into an AttributeDict keyed by package name.
        """
        cache_key = (package, name)
        if cache_key in callable_cache:
            return callable_cache[cache_key]

        # Copy so that injected dependency callables stay local to this macro.
        macro_vars = macro_vars.copy()
        macro = self._get_macro(name, package)

        package_macros: t.Dict[str, AttributeDict] = defaultdict(AttributeDict)
        for dependency in macro.depends_on:
            # Unqualified dependencies resolve within the current package (or the root).
            dependency_package = dependency.package or package
            # Skip direct self-recursion (the macro's own template defines it) and names
            # that are not registered macros in this registry.
            # NOTE(review): a cycle between two distinct macros is not skipped and would
            # recurse until RecursionError, since the cache entry is written only below.
            if (dependency.package is None and dependency.name == name) or not self._macro_exists(
                dependency.name, dependency_package
            ):
                continue

            upstream_callable = self._make_callable(
                dependency.name, dependency_package, callable_cache, macro_vars
            )
            if dependency.package is None:
                macro_vars[dependency.name] = upstream_callable
            else:
                package_macros[dependency.package][dependency.name] = upstream_callable

        macro_vars.update(package_macros)

        template = self._parse_macro(name, package)
        macro_callable = macro_return(
            getattr(template.make_module(vars=macro_vars), _non_private_name(name))
        )
        callable_cache[cache_key] = macro_callable
        return macro_callable

    def _parse_macro(self, name: str, package: t.Optional[str]) -> Template:
        """Return the (cached) compiled template containing the given macro."""
        cache_key = (package, name)
        if cache_key not in self._parser_cache:
            macro = self._get_macro(name, package)

            definition: t.Union[str, nodes.Template] = macro.definition
            if _is_private_macro(name):
                # A workaround to expose private jinja macros.
                definition = self._to_non_private_macro_def(name, macro.definition)

            self._parser_cache[cache_key] = self._environment.from_string(definition)
        return self._parser_cache[cache_key]

    @property
    def _environment(self) -> Environment:
        """Lazily create the environment used for parsing, with the builtin filters registered."""
        if self.__environment is None:
            self.__environment = environment()
            self.__environment.filters.update(self._create_builtin_filters())
        return self.__environment

    def _trim_macros(
        self,
        names: t.Set[str],
        package: t.Optional[str],
        visited: t.Optional[t.Set[t.Tuple[t.Optional[str], str]]] = None,
    ) -> JinjaMacroRegistry:
        """Return a registry holding `names` from `package` plus their transitive dependencies.

        Args:
            names: Names of the macros to keep.
            package: The package the names belong to, or None for root macros.
            visited: (package, name) pairs already collected. This guards against infinite
                recursion for self-recursive or mutually recursive macros.
        """
        if visited is None:
            visited = set()

        macros = self.packages.get(package, {}) if package is not None else self.root_macros
        trimmed_macros = {}

        dependencies: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)

        for name in names:
            if name not in macros or (package, name) in visited:
                continue
            visited.add((package, name))
            macro = macros[name]
            trimmed_macros[name] = macro
            for dependency in macro.depends_on:
                dependencies[dependency.package or package].add(dependency.name)

        if package is not None:
            result = JinjaMacroRegistry(packages={package: trimmed_macros})
        else:
            result = JinjaMacroRegistry(root_macros=trimmed_macros)

        for upstream_package, upstream_names in dependencies.items():
            result = result.merge(self._trim_macros(upstream_names, upstream_package, visited))

        return result

    def _macro_exists(self, name: str, package: t.Optional[str]) -> bool:
        return (
            name in self.packages.get(package, {})
            if package is not None
            else name in self.root_macros
        )

    def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo:
        return self.packages[package][name] if package is not None else self.root_macros[name]

    def _to_non_private_macro_def(self, name: str, definition: str) -> nodes.Template:
        """Parse a private macro's definition and rename the macro to its non-private name.

        Jinja does not export underscore-prefixed macros from a template module, so the
        macro is renamed before `_make_callable` looks it up by its non-private name.
        """
        template = self._environment.parse(definition)
        public_name = _non_private_name(name)

        for node in template.find_all((nodes.Macro, nodes.Call)):
            if isinstance(node, nodes.Macro):
                node.name = public_name
            elif (
                isinstance(node, nodes.Call)
                and isinstance(node.node, nodes.Name)
                and node.node.name == name
            ):
                # Only recursive self-calls follow the rename. Rewriting every plain call
                # would redirect other macros, `caller()` and builtins to this macro.
                node.node.name = public_name

        return template

    def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin globals using a factory function defined in the provided module.

        Pops `engine_adapter` from `global_vars` (mutating the given dict) and passes it
        to the factory separately. Without a factory, it returns `global_objs` overlaid
        with `global_vars`.
        """
        engine_adapter = global_vars.pop("engine_adapter", None)
        global_vars = {**self.global_objs, **global_vars}
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_globals"):
                return module.create_builtin_globals(self, global_vars, engine_adapter)
        return global_vars

    def _create_builtin_filters(self) -> t.Dict[str, t.Any]:
        """Creates Jinja builtin filters using a factory function defined in the provided module."""
        if self.create_builtins_module is not None:
            module = importlib.import_module(self.create_builtins_module)
            if hasattr(module, "create_builtin_filters"):
                return module.create_builtin_filters()
        return {}
430
431
432def _is_private_macro(name: str) -> bool:
433    return name.startswith("_")
434
435
436def _non_private_name(name: str) -> str:
437    return name.lstrip("_")
def environment(**kwargs: Any) -> jinja2.environment.Environment:
def environment(**kwargs: t.Any) -> Environment:
    """Create a Jinja environment with the `jinja2.ext.do` extension enabled.

    Args:
        **kwargs: Keyword arguments forwarded to `jinja2.Environment`. An optional
            `extensions` sequence is kept and "jinja2.ext.do" is added to it.

    Returns:
        A new `jinja2.Environment`.
    """
    # Build a fresh list rather than appending to the caller's object: appending in
    # place mutated a caller-supplied `extensions` list and failed on tuples.
    extensions = [*kwargs.pop("extensions", ()), "jinja2.ext.do"]
    return Environment(extensions=extensions, **kwargs)
class MacroReference(sqlmesh.utils.pydantic.PydanticModel):
class MacroReference(PydanticModel, frozen=True):
    """A reference to a macro, optionally qualified by the package it belongs to.

    Attributes:
        package: The package name, or None for a root-level macro.
        name: The macro name.
    """

    # Explicit default: unqualified references are built as `MacroReference(name=...)`.
    package: t.Optional[str] = None
    name: str

    @property
    def reference(self) -> str:
        """The dotted `package.name` form, or just `name` when there is no package."""
        if self.package is None:
            return self.name
        return ".".join((self.package, self.name))

    def __str__(self) -> str:
        return self.reference
Inherited Members
pydantic.main.BaseModel
BaseModel
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
Config
dict
json
missing_required_fields
extra_fields
all_fields
required_fields
class MacroInfo(sqlmesh.utils.pydantic.PydanticModel):
class MacroInfo(PydanticModel):
    """A macro's Jinja source together with the macros it calls.

    Attributes:
        definition: The full `{% macro %}...{% endmacro %}` source text.
        depends_on: References to the macros invoked from within `definition`.
    """

    definition: str
    depends_on: t.List[MacroReference]

Class to hold a macro's definition and references to the macros it calls.

Inherited Members
pydantic.main.BaseModel
BaseModel
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
Config
dict
json
missing_required_fields
extra_fields
all_fields
required_fields
class MacroReturnVal(builtins.Exception):
class MacroReturnVal(Exception):
    """Exception carrying a value that `macro_return` catches and returns to the caller.

    Args:
        val: The value to hand back to the caller of the wrapped macro.
    """

    def __init__(self, val: t.Any):
        # Populate `args` so the exception has a useful str() and can be unpickled.
        super().__init__(val)
        self.value = val

Exception that carries a value (in its `value` attribute) back to the caller of a macro wrapped by `macro_return`.

MacroReturnVal(val: Any)
47    def __init__(self, val: t.Any):
48        self.value = val
Inherited Members
builtins.BaseException
with_traceback
def macro_return(macro: Callable) -> Callable:
def macro_return(macro: t.Callable) -> t.Callable:
    """Wrap `macro` so that a raised `MacroReturnVal` becomes its return value.

    Args:
        macro: The callable to wrap.

    Returns:
        A callable returning either `macro`'s normal result or the `value` carried
        by a `MacroReturnVal` it raised.
    """

    def _call_with_early_return(*args: t.Any, **kwargs: t.Any) -> t.Any:
        try:
            result = macro(*args, **kwargs)
        except MacroReturnVal as early_return:
            return early_return.value
        return result

    return _call_with_early_return

Decorator that lets a macro pass data back to its caller: a `MacroReturnVal` raised by the macro is caught and its `value` is returned.

class MacroExtractor(sqlglot.parser.Parser):
 63class MacroExtractor(Parser):
 64    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
 65        """Extract a dictionary of macro definitions from a jinja string.
 66
 67        Args:
 68            jinja: The jinja string to extract from.
 69            dialect: The dialect of SQL.
 70
 71        Returns:
 72            A dictionary of macro name to macro definition.
 73        """
 74        self.reset()
 75        self.sql = jinja
 76        self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja)
 77        self._index = -1
 78        self._advance()
 79
 80        macros: t.Dict[str, MacroInfo] = {}
 81
 82        while self._curr:
 83            if self._curr.token_type == TokenType.BLOCK_START:
 84                macro_start = self._curr
 85            elif self._tag == "MACRO" and self._next:
 86                name = self._next.text
 87                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
 88                    self._advance()
 89
 90                while self._curr and self._tag != "ENDMACRO":
 91                    self._advance()
 92
 93                macro_str = self._find_sql(macro_start, self._next)
 94                macros[name] = MacroInfo(
 95                    definition=macro_str,
 96                    depends_on=list(extract_macro_references(macro_str)),
 97                )
 98
 99            self._advance()
100
101        return macros
102
103    def _advance(self, times: int = 1) -> None:
104        super()._advance(times)
105        self._tag = (
106            self._curr.text.upper()
107            if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
108            else ""
109        )

Parser consumes a list of tokens produced by the sqlglot.tokens.Tokenizer and produces a parsed syntax tree.

Arguments:
  • error_level: the desired error level. Default: ErrorLevel.RAISE
  • error_message_context: determines the amount of context to capture from a query string when displaying the error message (in number of characters). Default: 50.
  • index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list. Default: 0
  • alias_post_tablesample: If the table alias comes after tablesample. Default: False
  • max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3
  • null_ordering: Indicates the default null ordering method to use if not explicitly set. Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". Default: "nulls_are_small"
def extract( self, jinja: str, dialect: str = '') -> Dict[str, sqlmesh.utils.jinja.MacroInfo]:
 64    def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
 65        """Extract a dictionary of macro definitions from a jinja string.
 66
 67        Args:
 68            jinja: The jinja string to extract from.
 69            dialect: The dialect of SQL.
 70
 71        Returns:
 72            A dictionary of macro name to macro definition.
 73        """
 74        self.reset()
 75        self.sql = jinja
 76        self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja)
 77        self._index = -1
 78        self._advance()
 79
 80        macros: t.Dict[str, MacroInfo] = {}
 81
 82        while self._curr:
 83            if self._curr.token_type == TokenType.BLOCK_START:
 84                macro_start = self._curr
 85            elif self._tag == "MACRO" and self._next:
 86                name = self._next.text
 87                while self._curr and self._curr.token_type != TokenType.BLOCK_END:
 88                    self._advance()
 89
 90                while self._curr and self._tag != "ENDMACRO":
 91                    self._advance()
 92
 93                macro_str = self._find_sql(macro_start, self._next)
 94                macros[name] = MacroInfo(
 95                    definition=macro_str,
 96                    depends_on=list(extract_macro_references(macro_str)),
 97                )
 98
 99            self._advance()
100
101        return macros

Extract a dictionary of macro definitions from a jinja string.

Arguments:
  • jinja: The jinja string to extract from.
  • dialect: The dialect of SQL.
Returns:

A dictionary of macro name to macro definition.

Inherited Members
sqlglot.parser.Parser
Parser
reset
parse
parse_into
check_errors
raise_error
expression
validate_expression
def call_name(node: jinja2.nodes.Expr) -> Tuple[str, ...]:
def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
    """Return the dotted name of a callee expression as a tuple of parts.

    Attribute access contributes a part, while subscripts and calls are looked
    through. A `Name` root contributes its identifier and a `Const` root its quoted
    value. Any other root contributes nothing.
    """
    suffix: t.List[str] = []
    current = node
    while True:
        if isinstance(current, nodes.Getattr):
            suffix.append(current.attr)
            current = current.node
        elif isinstance(current, (nodes.Getitem, nodes.Call)):
            current = current.node
        elif isinstance(current, nodes.Name):
            return (current.name, *reversed(suffix))
        elif isinstance(current, nodes.Const):
            return (f"'{current.value}'", *reversed(suffix))
        else:
            return tuple(reversed(suffix))
def render_jinja(query: str, methods: Optional[Dict[str, Any]] = None) -> str:
def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str:
    """Render `query` as a Jinja template using the module-level environment.

    Args:
        query: The Jinja template source.
        methods: Variables and callables exposed to the template. Defaults to none.

    Returns:
        The rendered string.
    """
    # `Template.render(None)` raises TypeError, so map the default to an empty dict.
    return ENVIRONMENT.from_string(query).render(methods or {})
def find_call_names( node: jinja2.nodes.Node, vars_in_scope: Set[str]) -> Iterator[Tuple[str, ...]]:
def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[t.Tuple[str, ...]]:
    """Yield the names of calls under `node` that may refer to macros.

    Two kinds of call are skipped: calls rooted at a literal, and calls whose root
    name is bound locally (a `target` binding or a macro argument).

    Args:
        node: The Jinja AST node to search.
        vars_in_scope: Names already bound in the enclosing scope. Not mutated.

    Yields:
        Name tuples as produced by `call_name`.
    """
    # Copy so bindings found under this node do not leak to the caller.
    vars_in_scope = vars_in_scope.copy()
    for child_node in node.iter_child_nodes():
        if "target" in child_node.fields:
            target = getattr(child_node, "target")
            if isinstance(target, nodes.Name):
                vars_in_scope.add(target.name)
            elif isinstance(target, nodes.Tuple):
                for item in target.items:
                    if isinstance(item, nodes.Name):
                        vars_in_scope.add(item.name)
        elif isinstance(child_node, nodes.Macro):
            for arg in child_node.args:
                vars_in_scope.add(arg.name)
        elif isinstance(child_node, nodes.Call):
            name = call_name(child_node)
            # `call_name` can return (), e.g. for `(x|f)()`; guard before indexing.
            if name and not name[0].startswith("'") and name[0] not in vars_in_scope:
                yield name
        yield from find_call_names(child_node, vars_in_scope)
def extract_call_names(jinja_str: str) -> List[Tuple[str, ...]]:
def extract_call_names(jinja_str: str) -> t.List[t.Tuple[str, ...]]:
    """Parse `jinja_str` and list, in traversal order, every candidate macro call name."""
    ast = ENVIRONMENT.parse(jinja_str)
    names = find_call_names(ast, set())
    return list(names)
def extract_macro_references(jinja_str: str) -> Set[sqlmesh.utils.jinja.MacroReference]:
def extract_macro_references(jinja_str: str) -> t.Set[MacroReference]:
    """Collect references to the macros called in `jinja_str`.

    Only the `macro()` and `package.macro()` call shapes become references.
    """
    references: t.Set[MacroReference] = set()
    for parts in extract_call_names(jinja_str):
        if len(parts) == 1:
            references.add(MacroReference(name=parts[0]))
        elif len(parts) == 2:
            package_name, macro_name = parts
            references.add(MacroReference(package=package_name, name=macro_name))
    return references
class JinjaMacroRegistry(sqlmesh.utils.pydantic.PydanticModel):
166class JinjaMacroRegistry(PydanticModel):
167    """Registry for Jinja macros.
168
169    Args:
170        packages: The mapping from package name to a collection of macro definitions.
171        root_macros: The collection of top-level macro definitions.
172        global_objs: The global objects.
173        create_builtins_module: The name of a module which defines the `create_builtins` factory
174            function that will be used to construct builtin variables and functions.
175    """
176
177    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
178    root_macros: t.Dict[str, MacroInfo] = {}
179    global_objs: t.Dict[str, JinjaGlobalAttribute] = {}
180    create_builtins_module: t.Optional[str] = None
181
182    _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {}
183    __environment: t.Optional[Environment] = None
184
185    @validator("global_objs", pre=True)
186    def _validate_attribute_dict(cls, value: t.Any) -> t.Any:
187        if isinstance(value, t.Dict):
188            return {k: AttributeDict(v) if isinstance(v, dict) else v for k, v in value.items()}
189        return value
190
191    def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
192        """Adds macros to the target package.
193
194        Args:
195            macros: Macros that should be added.
196            package: The name of the package the given macros belong to. If not specified, the provided
197            macros will be added to the root namespace.
198        """
199
200        if package is not None:
201            package_macros = self.packages.get(package, {})
202            package_macros.update(macros)
203            self.packages[package] = package_macros
204        else:
205            self.root_macros.update(macros)
206
207    def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
208        """Builds a Python callable for a macro with the given reference.
209
210        Args:
211            reference: The macro reference.
212        Returns:
213            The macro as a Python callable or None if not found.
214        """
215        if reference.package is not None and reference.name not in self.packages.get(
216            reference.package, {}
217        ):
218            return None
219        if reference.package is None and reference.name not in self.root_macros:
220            return None
221
222        global_vars = self._create_builtin_globals(kwargs)
223        return self._make_callable(reference.name, reference.package, {}, global_vars)
224
225    def build_environment(self, **kwargs: t.Any) -> Environment:
226        """Builds a new Jinja environment based on this registry."""
227
228        global_vars = self._create_builtin_globals(kwargs)
229
230        callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable] = {}
231
232        root_macros = {
233            name: self._make_callable(name, None, callable_cache, global_vars)
234            for name, macro in self.root_macros.items()
235            if not _is_private_macro(name)
236        }
237
238        package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
239        for package_name, macros in self.packages.items():
240            for macro_name, macro in macros.items():
241                if not _is_private_macro(macro_name):
242                    package_macros[package_name][macro_name] = self._make_callable(
243                        macro_name, package_name, callable_cache, global_vars
244                    )
245
246        env = environment()
247        env.globals.update(
248            {
249                **root_macros,
250                **package_macros,
251                **global_vars,
252            }
253        )
254        env.filters.update(self._environment.filters)
255        return env
256
257    def trim(self, dependencies: t.Iterable[MacroReference]) -> JinjaMacroRegistry:
258        """Trims the registry by keeping only macros with given references and their transitive dependencies.
259
260        Args:
261            dependencies: References to macros that should be kept.
262
263        Returns:
264            A new trimmed registry.
265        """
266        dependencies_by_package: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
267        for dep in dependencies:
268            dependencies_by_package[dep.package].add(dep.name)
269
270        result = JinjaMacroRegistry(
271            global_objs=self.global_objs.copy(), create_builtins_module=self.create_builtins_module
272        )
273        for package, names in dependencies_by_package.items():
274            result = result.merge(self._trim_macros(names, package))
275
276        return result
277
def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
    """Combines this registry with `other` into a brand-new registry.

    Neither input is modified. When both registries define the same root macro, package
    macro or global object, the entry from `other` wins. The builtins module of this
    registry is preferred, falling back to the one from `other` when this one is falsy.

    Args:
        other: The registry whose contents are layered on top of this one.

    Returns:
        A new registry containing the macros and global objects of both instances.
    """
    merged_root = dict(self.root_macros)
    merged_root.update(other.root_macros)

    merged_packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    for package_name in dict.fromkeys([*self.packages, *other.packages]):
        combined = dict(self.packages.get(package_name, {}))
        combined.update(other.packages.get(package_name, {}))
        merged_packages[package_name] = combined

    merged_globals = dict(self.global_objs)
    merged_globals.update(other.global_objs)

    return JinjaMacroRegistry(
        packages=merged_packages,
        root_macros=merged_root,
        global_objs=merged_globals,
        create_builtins_module=self.create_builtins_module or other.create_builtins_module,
    )
311
def _make_callable(
    self,
    name: str,
    package: t.Optional[str],
    callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable],
    macro_vars: t.Dict[str, t.Any],
) -> t.Callable:
    """Compiles macro `name` of `package` into a Python callable, memoized in `callable_cache`.

    Every dependency that exists in this registry is compiled first (recursively) and made
    visible to the macro's template. Unqualified dependencies become top-level template
    variables. Package-qualified ones are grouped in an `AttributeDict` under their package
    name. An unqualified dependency on the macro itself is skipped.

    Args:
        name: Name of the macro to build.
        package: Package owning the macro, or None for a root macro. Unqualified dependencies
            are resolved against this package.
        callable_cache: Already-built callables keyed by `(package, name)`; updated in place.
        macro_vars: Variables made available to the template. Copied, not mutated.

    Returns:
        The macro callable wrapped by `macro_return`, so a value carried by a raised
        `MacroReturnVal` is returned to the caller.
    """
    cache_key = (package, name)
    if cache_key in callable_cache:
        return callable_cache[cache_key]

    # Dependencies built later in the loop see the callables of the ones built earlier.
    template_vars = macro_vars.copy()
    macro_info = self._get_macro(name, package)

    qualified: t.Dict[str, AttributeDict] = defaultdict(AttributeDict)
    for dep in macro_info.depends_on:
        resolved_package = dep.package or package
        is_self_reference = dep.package is None and dep.name == name
        if is_self_reference or not self._macro_exists(dep.name, resolved_package):
            continue

        # NOTE(review): the cache entry is written only after all dependencies are built, so a
        # dependency cycle other than the skipped self-reference would recurse without bound.
        dep_callable = self._make_callable(
            dep.name, resolved_package, callable_cache, template_vars
        )
        if dep.package is None:
            template_vars[dep.name] = dep_callable
        else:
            qualified[dep.package][dep.name] = dep_callable

    template_vars.update(qualified)

    module = self._parse_macro(name, package).make_module(vars=template_vars)
    built = macro_return(getattr(module, _non_private_name(name)))
    callable_cache[cache_key] = built
    return built
349
def _parse_macro(self, name: str, package: t.Optional[str]) -> Template:
    """Returns the compiled Jinja template of a macro, cached per `(package, name)`.

    Private macros are rewritten via `_to_non_private_macro_def` before compilation, so the
    macro can be fetched from the template module under its non-private name.
    """
    cache_key = (package, name)
    if cache_key in self._parser_cache:
        return self._parser_cache[cache_key]

    macro = self._get_macro(name, package)
    source: t.Union[str, nodes.Template] = (
        self._to_non_private_macro_def(name, macro.definition)
        if _is_private_macro(name)
        else macro.definition
    )
    self._parser_cache[cache_key] = self._environment.from_string(source)
    return self._parser_cache[cache_key]
362
@property
def _environment(self) -> Environment:
    """Lazily creates the Jinja environment used to parse this registry's macros.

    The environment is extended with the builtin filters returned by
    `_create_builtin_filters` and cached on the instance after the first successful access.
    """
    if self.__environment is None:
        # Configure the environment completely before caching it. If
        # `_create_builtin_filters` raises, nothing is cached, so a later access retries
        # instead of silently reusing an environment without the builtin filters.
        env = environment()
        env.filters.update(self._create_builtin_filters())
        self.__environment = env
    return self.__environment
369
def _trim_macros(self, names: t.Set[str], package: t.Optional[str]) -> JinjaMacroRegistry:
    """Builds a registry with the named macros of one namespace plus their transitive dependencies.

    Args:
        names: Names of the macros to keep. Names not registered in the namespace are ignored.
        package: The package the names belong to, or None for root macros.

    Returns:
        A new registry containing only the kept macros. Every package that was visited gets
        an entry, even if none of the requested names were found in it.
    """
    root_macros: t.Dict[str, MacroInfo] = {}
    packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    if package is not None:
        packages[package] = {}

    # Iterative traversal with a visited set. A recursive traversal loops forever when a
    # macro lists itself (or a cycle of macros) in `depends_on`, and it reprocesses shared
    # dependencies once per path.
    visited: t.Set[t.Tuple[t.Optional[str], str]] = set()
    pending: t.List[t.Tuple[t.Optional[str], str]] = [(package, name) for name in names]

    while pending:
        key = pending.pop()
        if key in visited:
            continue
        visited.add(key)
        current_package, name = key

        if current_package is None:
            source, target = self.root_macros, root_macros
        else:
            source = self.packages.get(current_package, {})
            target = packages.setdefault(current_package, {})

        if name not in source:
            continue
        macro = source[name]
        target[name] = macro
        for dependency in macro.depends_on:
            # Unqualified dependencies resolve to the package of the macro that references them.
            pending.append((dependency.package or current_package, dependency.name))

    return JinjaMacroRegistry(packages=packages, root_macros=root_macros)
392
393    def _macro_exists(self, name: str, package: t.Optional[str]) -> bool:
394        return (
395            name in self.packages.get(package, {})
396            if package is not None
397            else name in self.root_macros
398        )
399
400    def _get_macro(self, name: str, package: t.Optional[str]) -> MacroInfo:
401        return self.packages[package][name] if package is not None else self.root_macros[name]
402
def _to_non_private_macro_def(self, name: str, definition: str) -> nodes.Template:
    """Parses a private macro definition and renames the macro to its non-private name.

    This is a workaround to expose private Jinja macros: after the rewrite, the macro can be
    fetched from the template module under `_non_private_name(name)`. Recursive calls to the
    macro are renamed as well. Calls to any other function or macro are left untouched.

    Args:
        name: The registered (private) macro name.
        definition: The Jinja source of the macro definition.

    Returns:
        The parsed and rewritten template AST.
    """
    template = self._environment.parse(definition)
    public_name = _non_private_name(name)

    for node in template.find_all((nodes.Macro, nodes.Call)):
        if isinstance(node, nodes.Macro):
            node.name = public_name
        elif (
            isinstance(node, nodes.Call)
            and isinstance(node.node, nodes.Name)
            and node.node.name == name
        ):
            # Only self-referencing calls are renamed. Renaming every named call (e.g.
            # `return(...)` or calls to other macros) would turn them into recursive calls
            # to this macro.
            node.node.name = public_name

    return template
413
414    def _create_builtin_globals(self, global_vars: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
415        """Creates Jinja builtin globals using a factory function defined in the provided module."""
416        engine_adapter = global_vars.pop("engine_adapter", None)
417        global_vars = {**self.global_objs, **global_vars}
418        if self.create_builtins_module is not None:
419            module = importlib.import_module(self.create_builtins_module)
420            if hasattr(module, "create_builtin_globals"):
421                return module.create_builtin_globals(self, global_vars, engine_adapter)
422        return global_vars
423
424    def _create_builtin_filters(self) -> t.Dict[str, t.Any]:
425        """Creates Jinja builtin filters using a factory function defined in the provided module."""
426        if self.create_builtins_module is not None:
427            module = importlib.import_module(self.create_builtins_module)
428            if hasattr(module, "create_builtin_filters"):
429                return module.create_builtin_filters()
430        return {}

Registry for Jinja macros.

Arguments:
  • packages: The mapping from package name to a collection of macro definitions.
  • root_macros: The collection of top-level macro definitions.
  • global_objs: The global objects.
  • create_builtins_module: The name of a module that may define the create_builtin_globals and create_builtin_filters factory functions, which are used to construct builtin variables, functions and filters.
def add_macros(self, macros: Dict[str, sqlmesh.utils.jinja.MacroInfo], package: Optional[str] = None) -> None:
def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None:
    """Registers macros in place, in a package namespace or at the root.

    Macros already registered under the same name in the target namespace are replaced.

    Args:
        macros: Macro definitions keyed by macro name.
        package: Name of the package the macros belong to. When not specified, the macros
            are added to the root namespace.
    """
    if package is None:
        self.root_macros.update(macros)
        return
    self.packages.setdefault(package, {}).update(macros)

Adds macros to the target package.

Arguments:
  • macros: Macros that should be added.
  • package: The name of the package the given macros belong to. If not specified, the provided macros will be added to the root namespace.
def build_macro(self, reference: sqlmesh.utils.jinja.MacroReference, **kwargs: Any) -> Optional[Callable]:
def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
    """Compiles the referenced macro into a Python callable.

    Args:
        reference: Identifies the macro by optional package and name.
        **kwargs: Variables forwarded to `_create_builtin_globals` to build the macro's globals.

    Returns:
        The macro as a Python callable, or None when the registry has no such macro.
    """
    registered = (
        self.root_macros
        if reference.package is None
        else self.packages.get(reference.package, {})
    )
    if reference.name not in registered:
        return None

    return self._make_callable(
        reference.name, reference.package, {}, self._create_builtin_globals(kwargs)
    )

Builds a Python callable for a macro with the given reference.

Arguments:
  • reference: The macro reference.
Returns:

The macro as a Python callable or None if not found.

def build_environment(self, **kwargs: Any) -> jinja2.environment.Environment:
def build_environment(self, **kwargs: t.Any) -> Environment:
    """Builds a new Jinja environment that exposes the public macros of this registry.

    Public root macros become top-level globals, and public package macros are grouped in an
    `AttributeDict` under their package name. Private macros are not exposed directly. The
    builtin globals and the filters of the registry's internal environment are added too.

    Args:
        **kwargs: Variables forwarded to `_create_builtin_globals` to build the globals.

    Returns:
        A freshly created Jinja environment.
    """
    global_vars = self._create_builtin_globals(kwargs)

    # Shared by all macros so each macro is compiled at most once.
    callable_cache: t.Dict[t.Tuple[t.Optional[str], str], t.Callable] = {}

    root_macros = {
        name: self._make_callable(name, None, callable_cache, global_vars)
        for name in self.root_macros
        if not _is_private_macro(name)
    }

    package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
    for package_name, macros in self.packages.items():
        for macro_name in macros:
            if not _is_private_macro(macro_name):
                package_macros[package_name][macro_name] = self._make_callable(
                    macro_name, package_name, callable_cache, global_vars
                )

    env = environment()
    # Later entries win: builtin globals override macros with the same name.
    env.globals.update(
        {
            **root_macros,
            **package_macros,
            **global_vars,
        }
    )
    env.filters.update(self._environment.filters)
    return env

Builds a new Jinja environment based on this registry.

def trim(self, dependencies: Iterable[sqlmesh.utils.jinja.MacroReference]) -> sqlmesh.utils.jinja.JinjaMacroRegistry:
def trim(self, dependencies: t.Iterable[MacroReference]) -> JinjaMacroRegistry:
    """Returns a copy of this registry reduced to the requested macros.

    Only the macros named by `dependencies`, together with everything they transitively
    depend on, are kept. Global objects and the builtins module name are carried over.

    Args:
        dependencies: References to the macros that must survive trimming.

    Returns:
        A new, trimmed registry; this instance is not modified.
    """
    requested: t.Dict[t.Optional[str], t.Set[str]] = defaultdict(set)
    for reference in dependencies:
        requested[reference.package].add(reference.name)

    trimmed = JinjaMacroRegistry(
        global_objs=self.global_objs.copy(),
        create_builtins_module=self.create_builtins_module,
    )
    for package_name, macro_names in requested.items():
        trimmed = trimmed.merge(self._trim_macros(macro_names, package_name))
    return trimmed

Trims the registry by keeping only macros with given references and their transitive dependencies.

Arguments:
  • dependencies: References to macros that should be kept.
Returns:

A new trimmed registry.

def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry:
    """Combines this registry with `other` into a brand-new registry.

    Neither input is modified. When both registries define the same root macro, package
    macro or global object, the entry from `other` wins. The builtins module of this
    registry is preferred, falling back to the one from `other` when this one is falsy.

    Args:
        other: The registry whose contents are layered on top of this one.

    Returns:
        A new registry containing the macros and global objects of both instances.
    """
    merged_root = dict(self.root_macros)
    merged_root.update(other.root_macros)

    merged_packages: t.Dict[str, t.Dict[str, MacroInfo]] = {}
    for package_name in dict.fromkeys([*self.packages, *other.packages]):
        combined = dict(self.packages.get(package_name, {}))
        combined.update(other.packages.get(package_name, {}))
        merged_packages[package_name] = combined

    merged_globals = dict(self.global_objs)
    merged_globals.update(other.global_objs)

    return JinjaMacroRegistry(
        packages=merged_packages,
        root_macros=merged_root,
        global_objs=merged_globals,
        create_builtins_module=self.create_builtins_module or other.create_builtins_module,
    )

Returns a copy of the registry which contains macros from both this and other instances.

Arguments:
  • other: The other registry instance.
Returns:

A new merged registry.

Inherited Members
pydantic.main.BaseModel
BaseModel
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
sqlmesh.utils.pydantic.PydanticModel
Config
dict
json
missing_required_fields
extra_fields
all_fields
required_fields