Edit on GitHub

sqlmesh.core.dialect

  1from __future__ import annotations
  2
  3import functools
  4import re
  5import typing as t
  6from difflib import unified_diff
  7
  8import pandas as pd
  9from jinja2.meta import find_undeclared_variables
 10from sqlglot import Dialect, Generator, Parser, TokenType, exp
 11
 12from sqlmesh.utils.jinja import ENVIRONMENT
 13
 14
 15class Model(exp.Expression):
 16    arg_types = {"expressions": True}
 17
 18
 19class Audit(exp.Expression):
 20    arg_types = {"expressions": True}
 21
 22
 23class Jinja(exp.Func):
 24    arg_types = {"this": True, "expressions": False}
 25    is_var_len_args = True
 26
 27
 28class ModelKind(exp.Expression):
 29    arg_types = {"this": True, "expressions": False}
 30
 31
 32class MacroVar(exp.Var):
 33    pass
 34
 35
 36class MacroFunc(exp.Func):
 37    @property
 38    def name(self) -> str:
 39        return self.this.name
 40
 41
 42class MacroDef(MacroFunc):
 43    arg_types = {"this": True, "expression": True}
 44
 45
 46class MacroSQL(MacroFunc):
 47    arg_types = {"this": True, "into": False}
 48
 49
 50class MacroStrReplace(MacroFunc):
 51    pass
 52
 53
 54class PythonCode(exp.Expression):
 55    arg_types = {"expressions": True}
 56
 57
 58class DColonCast(exp.Cast):
 59    pass
 60
 61
 62@t.no_type_check
 63def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
 64    if self._curr is None:
 65        return None
 66
 67    parser = PARSERS.get(self._curr.text.upper())
 68
 69    if parser:
 70        # Capture any available description in the form of a comment
 71        comments = self._curr.comments
 72
 73        self._advance()
 74        meta = self._parse_wrapped(lambda: parser(self))
 75
 76        meta.comments = comments
 77        return meta
 78    return self.__parse_statement()
 79
 80
 81@t.no_type_check
 82def _parse_lambda(self: Parser) -> t.Optional[exp.Expression]:
 83    node = self.__parse_lambda()
 84    if isinstance(node, exp.Lambda):
 85        node.set("this", self._parse_alias(node.this))
 86    return node
 87
 88
 89def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expression]:
 90    index = self._index
 91    field = self._parse_primary() or self._parse_function({}) or self._parse_id_var()
 92
 93    if isinstance(field, exp.Func):
 94        macro_name = field.name.upper()
 95        if macro_name != keyword_macro and macro_name in KEYWORD_MACROS:
 96            self._retreat(index)
 97            return None
 98        if isinstance(field, exp.Anonymous):
 99            name = field.name.upper()
100            if name == "DEF":
101                return self.expression(
102                    MacroDef, this=field.expressions[0], expression=field.expressions[1]
103                )
104            if name == "SQL":
105                into = field.expressions[1].this.lower() if len(field.expressions) > 1 else None
106                return self.expression(MacroSQL, this=field.expressions[0], into=into)
107        return self.expression(MacroFunc, this=field)
108
109    if field is None:
110        return None
111
112    if field.is_string or (isinstance(field, exp.Identifier) and field.quoted):
113        return self.expression(MacroStrReplace, this=exp.Literal.string(field.this))
114    return self.expression(MacroVar, this=field.this)
115
116
# Macro names that prefix a SQL clause (e.g. ``@WHERE(...) ...``); multi-word
# clause keywords are spelled with underscores.
KEYWORD_MACROS = {
    "WITH",
    "JOIN",
    "WHERE",
    "GROUP_BY",
    "HAVING",
    "ORDER_BY",
}
118
119
def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]:
    """Parse an ``@NAME`` keyword macro if it is what the token stream holds next.

    Returns None without consuming anything unless the upcoming tokens are a
    parameter followed by a var whose text equals ``name`` case-insensitively.
    """
    if not self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False):
        return None
    following = self._next
    if following and following.text.upper() != name.upper():
        return None

    # Step over the "@" and let _parse_macro consume the keyword macro itself.
    self._advance(1)
    return _parse_macro(self, keyword_macro=name)
128
129
@t.no_type_check
def _parse_with(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a WITH clause, supporting an ``@WITH(...)`` macro prefix.

    With the macro present, the clause parsed after it is appended to the macro's
    wrapped function arguments and the macro is returned in the clause's place.
    """
    macro = _parse_matching_macro(self, "WITH")
    if macro:
        # NOTE(review): the True flag presumably tells sqlglot the keyword was
        # already consumed — confirm against sqlglot's _parse_with signature.
        macro.this.append("expressions", self.__parse_with(True))
        return macro
    return self.__parse_with()
138
139
@t.no_type_check
def _parse_join(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a JOIN, supporting an ``@JOIN(...)`` macro after the join modifiers.

    NATURAL/side/kind tokens read before the macro are re-applied to the join
    parsed after it, and that join is appended to the macro's arguments.
    """
    start = self._index
    natural, side, kind = self._parse_join_side_and_kind()
    macro = _parse_matching_macro(self, "JOIN")
    if not macro:
        # No macro: rewind over the modifiers so sqlglot can read them itself.
        self._retreat(start)
        return self.__parse_join()

    join = self.__parse_join(True)
    if natural:
        join.set("natural", True)
    for arg_name, token in (("side", side), ("kind", kind)):
        if token:
            join.set(arg_name, token.text)

    macro.this.append("expressions", join)
    return macro
159
160
@t.no_type_check
def _parse_where(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a WHERE clause, supporting an ``@WHERE(...)`` macro prefix.

    With the macro present, the clause parsed after it is appended to the macro's
    wrapped function arguments and the macro is returned in the clause's place.
    """
    macro = _parse_matching_macro(self, "WHERE")
    if macro:
        # NOTE(review): True presumably means "keyword already consumed" — confirm.
        macro.this.append("expressions", self.__parse_where(True))
        return macro
    return self.__parse_where()
169
170
@t.no_type_check
def _parse_group(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a GROUP BY clause, supporting an ``@GROUP_BY(...)`` macro prefix.

    With the macro present, the clause parsed after it is appended to the macro's
    wrapped function arguments and the macro is returned in the clause's place.
    """
    macro = _parse_matching_macro(self, "GROUP_BY")
    if macro:
        # NOTE(review): True presumably means "keyword already consumed" — confirm.
        macro.this.append("expressions", self.__parse_group(True))
        return macro
    return self.__parse_group()
179
180
@t.no_type_check
def _parse_having(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a HAVING clause, supporting an ``@HAVING(...)`` macro prefix.

    With the macro present, the clause parsed after it is appended to the macro's
    wrapped function arguments and the macro is returned in the clause's place.
    """
    macro = _parse_matching_macro(self, "HAVING")
    if macro:
        # NOTE(review): True presumably means "keyword already consumed" — confirm.
        macro.this.append("expressions", self.__parse_having(True))
        return macro
    return self.__parse_having()
189
190
@t.no_type_check
def _parse_order(self: Parser, this: exp.Expression = None) -> t.Optional[exp.Expression]:
    """Parse an ORDER BY clause for ``this``, supporting an ``@ORDER_BY(...)`` macro.

    With the macro present, the clause parsed after it is appended to the macro's
    wrapped function arguments and the macro is returned in the clause's place.
    """
    macro = _parse_matching_macro(self, "ORDER_BY")
    if macro:
        # NOTE(review): True presumably means "keyword already consumed" — confirm.
        macro.this.append("expressions", self.__parse_order(this, True))
        return macro
    return self.__parse_order(this)
199
200
def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
    """Parse one ``key value`` property inside a model kind's parentheses.

    A parenthesized value becomes a tuple of strings/identifiers; any other value
    is parsed as a (possibly subscripted) field. Returns None when no key follows.
    """
    key = self._parse_id_var(True)
    if not key:
        return None

    # Peek for "(" and rewind: the wrapped-csv helper is given the paren itself.
    start = self._index
    value: t.Optional[exp.Expression]
    if self._match(TokenType.L_PAREN):
        self._retreat(start)
        items = self._parse_wrapped_csv(lambda: self._parse_string() or self._parse_id_var())
        value = self.expression(exp.Tuple, expressions=items)
    else:
        value = self._parse_bracket(self._parse_field(any_token=True))

    return self.expression(exp.Property, this=key.name.lower(), value=value)
220
221
def _create_parser(parser_type: t.Type[exp.Expression], table_keys: t.List[str]) -> t.Callable:
    """Build a parser for a ``MODEL``/``AUDIT``-style comma-separated property list.

    Args:
        parser_type: Expression class that wraps the parsed ``exp.Property`` list.
        table_keys: Property keys whose value is parsed as a table and stored as its name.

    Returns:
        A function taking a sqlglot ``Parser`` and returning a ``parser_type`` node.
    """

    def parse(self: Parser) -> t.Optional[exp.Expression]:
        # Imported lazily, presumably to avoid a circular import — verify.
        from sqlmesh.core.model.kind import ModelKindName

        def parse_kind() -> t.Optional[exp.Expression]:
            # Parse "kind NAME" with optional "(prop value, ...)" for kinds that take props.
            id_var = self._parse_id_var(any_token=True)
            if not id_var:
                return None
            start = self._index
            kind = ModelKindName[id_var.name.upper()]
            props = None
            if kind in (
                ModelKindName.INCREMENTAL_BY_TIME_RANGE,
                ModelKindName.INCREMENTAL_BY_UNIQUE_KEY,
                ModelKindName.SEED,
            ) and self._match(TokenType.L_PAREN):
                self._retreat(start)
                props = self._parse_wrapped_csv(functools.partial(_parse_props, self))
            return self.expression(ModelKind, this=kind.value, expressions=props)

        properties = []
        while True:
            key_expression = self._parse_id_var(any_token=True)
            if not key_expression:
                break

            key = key_expression.name.lower()
            value: t.Optional[exp.Expression | str]
            if key in table_keys:
                value = exp.table_name(self._parse_table())
            elif key == "columns":
                value = self._parse_schema()
            elif key == "kind":
                value = parse_kind()
            else:
                value = self._parse_bracket(self._parse_field(any_token=True))

            properties.append(self.expression(exp.Property, this=key, value=value))
            if not self._match(TokenType.COMMA):
                break

        return self.expression(parser_type, expressions=properties)

    return parse
275
276
# Parsers for SQLMesh's definition statements, keyed by their leading keyword.
_parse_model = _create_parser(Model, ["name"])
_parse_audit = _create_parser(Audit, ["model"])
PARSERS = {
    "MODEL": _parse_model,
    "AUDIT": _parse_audit,
}
280
281
282def _model_sql(self: Generator, expression: exp.Expression) -> str:
283    props = ",\n".join(
284        self.indent(f"{prop.name} {self.sql(prop, 'value')}") for prop in expression.expressions
285    )
286    return "\n".join(["MODEL (", props, ")"])
287
288
289def _model_kind_sql(self: Generator, expression: ModelKind) -> str:
290    props = ",\n".join(
291        self.indent(f"{prop.this} {self.sql(prop, 'value')}") for prop in expression.expressions
292    )
293    if props:
294        return "\n".join([f"{expression.this} (", props, ")"])
295    return expression.name.upper()
296
297
298def _macro_keyword_func_sql(self: Generator, expression: exp.Expression) -> str:
299    name = expression.name
300    keyword = name.replace("_", " ")
301    *args, clause = expression.expressions
302    macro = f"@{name}({self.format_args(*args)})"
303    return self.sql(clause).replace(keyword, macro, 1)
304
305
def _macro_func_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a ``MacroFunc`` as ``@NAME(args)``, delegating clause keyword macros."""
    func = expression.this
    func_name = func.name
    if func_name not in KEYWORD_MACROS:
        return f"@{func_name}({self.format_args(*func.expressions)})"
    return _macro_keyword_func_sql(self, func)
312
313
314def _override(klass: t.Type[Parser], func: t.Callable) -> None:
315    name = func.__name__
316    setattr(klass, f"_{name}", getattr(klass, name))
317    setattr(klass, name, func)
318
319
def format_model_expressions(
    expressions: t.List[exp.Expression], dialect: t.Optional[str] = None
) -> str:
    """Format a model's expressions into a standardized format.

    The last expression (the query) is copied, each of its projections gets an
    explicit alias when it has a name, and eligible casts are rendered with the
    ``::`` shorthand. The preceding expressions are rendered as they are.

    Args:
        expressions: The model's expressions, must be at least model def + query.
        dialect: The dialect to render the expressions as.
    Returns:
        A string with the formatted model.
    """
    if len(expressions) == 1:
        return expressions[0].sql(pretty=True, dialect=dialect)

    *statements, query = expressions
    query = query.copy()
    selects = []

    for column in query.expressions:
        # Detach comments so they can be re-attached to the final column node,
        # which may be a newly created alias.
        comments = column.comments
        column.comments = None

        if not isinstance(column, exp.Alias) and column.name:
            column = column.replace(exp.alias_(column.copy(), column.name))

        projection = column.this
        if isinstance(projection, exp.Cast):
            operand = projection.this
            # Operator operands keep CAST(...) (presumably to avoid precedence
            # ambiguity with "::"); parenthesized operands are explicitly allowed.
            if not isinstance(operand, (exp.Binary, exp.Unary)) or isinstance(operand, exp.Paren):
                projection.replace(DColonCast(this=operand, to=projection.to))

        column.comments = comments
        selects.append(column)

    query.set("expressions", selects)

    return ";\n\n".join(
        [
            *(statement.sql(pretty=True, dialect=dialect) for statement in statements),
            query.sql(pretty=True, dialect=dialect),
        ]
    ).strip()
365
366
def text_diff(
    a: t.Optional[exp.Expression],
    b: t.Optional[exp.Expression],
    dialect: t.Optional[str] = None,
) -> str:
    """Return the unified diff of the pretty-printed SQL of two expressions.

    Each side is rendered with ``comments=False``; a missing (falsy) side
    contributes no lines.
    """

    def _lines(node):
        if not node:
            return ""
        return node.sql(pretty=True, comments=False, dialect=dialect).split("\n")

    before = _lines(a)
    after = _lines(b)
    return "\n".join(unified_diff(before, after))
379
380
# Locates the dialect declared in a MODEL/AUDIT header; group 2 captures its name.
DIALECT_PATTERN = re.compile(
    r"(model|audit).*?\(.*?dialect[^a-z,]+([a-z]*|,)",
    re.IGNORECASE | re.DOTALL,
)
384
385
def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression]:
    """Parse a sql string.

    Supports parsing model definition.
    If a jinja block is detected, the query is stored as raw string in a Jinja node.

    Args:
        sql: The sql based definition.
        default_dialect: The dialect to use if the model does not specify one.

    Returns:
        A list of the expressions, [Model, *Statements, Query | Jinja]
    """
    match = DIALECT_PATTERN.search(sql)
    dialect = Dialect.get_or_raise(match.group(2) if match else default_dialect)()

    tokens = dialect.tokenizer.tokenize(sql)
    # Split the token stream on semicolons, flagging chunks that contain Jinja
    # syntax (a block start or two consecutive "{" tokens).
    chunks: t.List[t.Tuple[t.List, bool]] = [([], False)]
    total = len(tokens)

    for i, token in enumerate(tokens):
        if token.token_type == TokenType.SEMICOLON:
            # A trailing semicolon does not open a new, empty chunk.
            if i < total - 1:
                chunks.append(([], False))
        else:
            if token.token_type == TokenType.BLOCK_START or (
                i < total - 1
                and token.token_type == TokenType.L_BRACE
                and tokens[i + 1].token_type == TokenType.L_BRACE
            ):
                chunks[-1] = (chunks[-1][0], True)
            chunks[-1][0].append(token)

    expressions: t.List[exp.Expression] = []
    sql_lines = None

    for chunk, is_jinja in chunks:
        if is_jinja:
            # Slice the raw text between the chunk's first and last tokens
            # (which may be the same token for a one-token chunk).
            start, end = chunk[0], chunk[-1]
            sql_lines = sql_lines or sql.split("\n")
            lines = sql_lines[start.line - 1 : end.line]
            # Trim the end first: its offset is relative to the untrimmed line,
            # which matters when the segment starts and ends on the same line.
            lines[-1] = lines[-1][: end.col + len(end.text) - 1]
            lines[0] = lines[0][start.col - 1 :]
            segment = "\n".join(lines)
            # Sorted so the Jinja node's arguments do not depend on set ordering.
            variables = [
                exp.Literal.string(var)
                for var in sorted(find_undeclared_variables(ENVIRONMENT.parse(segment)))
            ]
            expressions.append(Jinja(this=exp.Literal.string(segment), expressions=variables))
        else:
            for expression in dialect.parser().parse(chunk, sql):
                if expression:
                    expressions.append(expression)

    return expressions
441
442
@t.no_type_check
def extend_sqlglot() -> None:
    """Extend SQLGlot with SQLMesh's custom macro aware dialect.

    For the base ``Parser``/``Generator`` and those of every registered dialect,
    registers SQL renderers for SQLMesh's custom nodes (skipped for generators
    that already know ``MacroFunc``), the ``JINJA`` function and ``@`` placeholder
    parsing. Then swaps several base ``Parser`` methods for macro-aware versions.
    """
    parser_classes = {Parser}
    generator_classes = {Generator}
    for dialect_class in Dialect.classes.values():
        if hasattr(dialect_class, "Parser"):
            parser_classes.add(dialect_class.Parser)
        if hasattr(dialect_class, "Generator"):
            generator_classes.add(dialect_class.Generator)

    for generator_class in generator_classes:
        # MacroFunc serves as the marker that this generator was already extended.
        if MacroFunc in generator_class.TRANSFORMS:
            continue
        # NOTE(review): MacroSQL's optional `into` argument is not rendered below —
        # confirm that round-tripping @SQL(..., into) is not required.
        generator_class.TRANSFORMS.update(
            {
                DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
                MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
                MacroFunc: _macro_func_sql,
                MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}",
                MacroSQL: lambda self, e: f"@SQL({self.sql(e.this)})",
                MacroVar: lambda self, e: f"@{e.name}",
                Model: _model_sql,
                Jinja: lambda self, e: e.name,
                ModelKind: _model_kind_sql,
                PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
            }
        )
        generator_class.WITH_SEPARATED_COMMENTS = (
            *generator_class.WITH_SEPARATED_COMMENTS,
            Model,
        )

    for parser_class in parser_classes:
        parser_class.FUNCTIONS.update({"JINJA": Jinja.from_arg_list})
        parser_class.PLACEHOLDER_PARSERS.update({TokenType.PARAMETER: _parse_macro})

    for replacement in (
        _parse_statement,
        _parse_join,
        _parse_order,
        _parse_where,
        _parse_group,
        _parse_with,
        _parse_having,
        _parse_lambda,
    ):
        _override(Parser, replacement)
496
497
def select_from_values(
    values: t.Iterable[t.Tuple[t.Any, ...]],
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Generator[exp.Select, None, None]:
    """Generate a VALUES expression that has a select wrapped around it to cast the values to their correct types.

    Args:
        values: List of values to use for the VALUES expression.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch, if <= 0 then no batching will occur.
        alias: The alias to assign to the values expression. If not provided then will default to "t"

    Returns:
        This method operates as a generator and yields a VALUES expression.
    """
    casted_columns = [
        exp.alias_(exp.cast(column, to=kind), column) for column, kind in columns_to_types.items()
    ]

    def _select(rows: t.List[t.Tuple[t.Any, ...]]) -> exp.Select:
        # Wrap one batch of rows in SELECT <casted columns> FROM (VALUES ...) AS alias.
        values_exp = exp.values(rows, alias=alias, columns=columns_to_types)
        return exp.select(*casted_columns).from_(values_exp)

    batch: t.List[t.Tuple[t.Any, ...]] = []
    for row in values:
        batch.append(row)
        # Flush as soon as the batch is full so no batch exceeds batch_size
        # (the previous ">" check emitted batch_size + 1 rows per batch).
        if batch_size > 0 and len(batch) >= batch_size:
            yield _select(batch)
            batch = []
    if batch:
        yield _select(batch)
528
529
def pandas_to_sql(
    df: pd.DataFrame,
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Generator[exp.Select, None, None]:
    """Turn a pandas DataFrame into typed ``SELECT ... FROM (VALUES ...)`` statements.

    Args:
        df: A pandas dataframe to convert.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch, if <= 0 then no batching will occur.
        alias: The alias to assign to the values expression. If not provided then will default to "t"

    Returns:
        A generator that yields one select per batch of rows.
    """
    # Rows are plain tuples in column order, without the index.
    rows = df.itertuples(index=False, name=None)
    yield from select_from_values(
        values=rows,
        columns_to_types=columns_to_types,
        batch_size=batch_size,
        alias=alias,
    )
class Model(sqlglot.expressions.Expression):
class Model(exp.Expression):
    """A ``MODEL (...)`` definition block; ``expressions`` holds its properties."""

    arg_types = {"expressions": True}
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
class Audit(sqlglot.expressions.Expression):
class Audit(exp.Expression):
    """An ``AUDIT (...)`` definition block; ``expressions`` holds its properties."""

    arg_types = {"expressions": True}
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
class Jinja(sqlglot.expressions.Func):
class Jinja(exp.Func):
    """A raw Jinja-templated SQL segment.

    ``this`` holds the segment text as a string literal; ``expressions`` optionally
    lists the undeclared template variables referenced by the segment.
    """

    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
sqlglot.expressions.Func
from_arg_list
sql_names
sql_name
default_parser_mappings
sqlglot.expressions.Condition
and_
or_
not_
class ModelKind(sqlglot.expressions.Expression):
class ModelKind(exp.Expression):
    """The ``kind`` of a model: its name in ``this`` plus optional nested properties."""

    arg_types = {"this": True, "expressions": False}
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
class MacroVar(sqlglot.expressions.Var):
class MacroVar(exp.Var):
    """A macro variable reference such as ``@foo``."""
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
class MacroFunc(sqlglot.expressions.Func):
class MacroFunc(exp.Func):
    """A macro function call; ``this`` is the underlying parsed function."""

    @property
    def name(self) -> str:
        # Report the wrapped function's name rather than this wrapper's.
        return self.this.name
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
sqlglot.expressions.Func
from_arg_list
sql_names
sql_name
default_parser_mappings
sqlglot.expressions.Condition
and_
or_
not_
class MacroDef(MacroFunc):
class MacroDef(MacroFunc):
    """An ``@DEF(name, value)`` macro definition."""

    arg_types = {"this": True, "expression": True}
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
sqlglot.expressions.Func
from_arg_list
sql_names
sql_name
default_parser_mappings
sqlglot.expressions.Condition
and_
or_
not_
class MacroSQL(MacroFunc):
class MacroSQL(MacroFunc):
    """An ``@SQL(...)`` macro; ``into`` is the optional lower-cased second argument."""

    arg_types = {"this": True, "into": False}
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
sqlglot.expressions.Func
from_arg_list
sql_names
sql_name
default_parser_mappings
sqlglot.expressions.Condition
and_
or_
not_
class MacroStrReplace(MacroFunc):
class MacroStrReplace(MacroFunc):
    """A string-replacement macro built from a string or quoted identifier after ``@``."""
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
sqlglot.expressions.Func
from_arg_list
sql_names
sql_name
default_parser_mappings
sqlglot.expressions.Condition
and_
or_
not_
class PythonCode(sqlglot.expressions.Expression):
class PythonCode(exp.Expression):
    """A container for Python code; rendered one expression per line."""

    arg_types = {"expressions": True}
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
output_name
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
class DColonCast(sqlglot.expressions.Cast):
class DColonCast(exp.Cast):
    """A cast rendered with the ``::`` shorthand (``value::TYPE``)."""
Inherited Members
sqlglot.expressions.Expression
Expression
this
expression
expressions
text
is_string
is_number
is_int
is_star
alias
copy
append
set
depth
iter_expressions
find
find_all
find_ancestor
parent_select
same_parent
root
walk
dfs
bfs
unnest
unalias
unnest_operands
flatten
sql
transform
replace
pop
assert_is
error_messages
dump
load
sqlglot.expressions.Cast
output_name
is_type
sqlglot.expressions.Func
from_arg_list
sql_names
sql_name
default_parser_mappings
sqlglot.expressions.Condition
and_
or_
not_
def format_model_expressions( expressions: List[sqlglot.expressions.Expression], dialect: Optional[str] = None) -> str:
def format_model_expressions(
    expressions: t.List[exp.Expression], dialect: t.Optional[str] = None
) -> str:
    """Format a model's expressions into a standardized format.

    The last expression (the query) is copied, each of its projections gets an
    explicit alias when it has a name, and eligible casts are rendered with the
    ``::`` shorthand. The preceding expressions are rendered as they are.

    Args:
        expressions: The model's expressions, must be at least model def + query.
        dialect: The dialect to render the expressions as.
    Returns:
        A string with the formatted model.
    """
    if len(expressions) == 1:
        return expressions[0].sql(pretty=True, dialect=dialect)

    *statements, query = expressions
    query = query.copy()
    selects = []

    for column in query.expressions:
        # Detach comments so they can be re-attached to the final column node,
        # which may be a newly created alias.
        comments = column.comments
        column.comments = None

        if not isinstance(column, exp.Alias) and column.name:
            column = column.replace(exp.alias_(column.copy(), column.name))

        projection = column.this
        if isinstance(projection, exp.Cast):
            operand = projection.this
            # Operator operands keep CAST(...) (presumably to avoid precedence
            # ambiguity with "::"); parenthesized operands are explicitly allowed.
            if not isinstance(operand, (exp.Binary, exp.Unary)) or isinstance(operand, exp.Paren):
                projection.replace(DColonCast(this=operand, to=projection.to))

        column.comments = comments
        selects.append(column)

    query.set("expressions", selects)

    return ";\n\n".join(
        [
            *(statement.sql(pretty=True, dialect=dialect) for statement in statements),
            query.sql(pretty=True, dialect=dialect),
        ]
    ).strip()

Format a model's expressions into a standardized format.

Arguments:
  • expressions: The model's expressions, must be at least model def + query.
  • dialect: The dialect to render the expressions as.
Returns:

A string with the formatted model.

def text_diff( a: Optional[sqlglot.expressions.Expression], b: Optional[sqlglot.expressions.Expression], dialect: Optional[str] = None) -> str:
def text_diff(
    a: t.Optional[exp.Expression],
    b: t.Optional[exp.Expression],
    dialect: t.Optional[str] = None,
) -> str:
    """Return the unified diff of the pretty-printed SQL of two expressions.

    Each side is rendered with ``comments=False``; a missing (falsy) side
    contributes no lines.
    """

    def _lines(node):
        if not node:
            return ""
        return node.sql(pretty=True, comments=False, dialect=dialect).split("\n")

    before = _lines(a)
    after = _lines(b)
    return "\n".join(unified_diff(before, after))

Find the unified text diff between two expressions.

def parse( sql: str, default_dialect: str | None = None) -> List[sqlglot.expressions.Expression]:
def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression]:
    """Parse a sql string.

    Supports parsing model definition.
    If a jinja block is detected, the query is stored as raw string in a Jinja node.

    Args:
        sql: The sql based definition.
        default_dialect: The dialect to use if the model does not specify one.

    Returns:
        A list of the expressions, [Model, *Statements, Query | Jinja]
    """
    match = DIALECT_PATTERN.search(sql)
    dialect = Dialect.get_or_raise(match.group(2) if match else default_dialect)()

    tokens = dialect.tokenizer.tokenize(sql)
    # Split the token stream on semicolons, flagging chunks that contain Jinja
    # syntax (a block start or two consecutive "{" tokens).
    chunks: t.List[t.Tuple[t.List, bool]] = [([], False)]
    total = len(tokens)

    for i, token in enumerate(tokens):
        if token.token_type == TokenType.SEMICOLON:
            # A trailing semicolon does not open a new, empty chunk.
            if i < total - 1:
                chunks.append(([], False))
        else:
            if token.token_type == TokenType.BLOCK_START or (
                i < total - 1
                and token.token_type == TokenType.L_BRACE
                and tokens[i + 1].token_type == TokenType.L_BRACE
            ):
                chunks[-1] = (chunks[-1][0], True)
            chunks[-1][0].append(token)

    expressions: t.List[exp.Expression] = []
    sql_lines = None

    for chunk, is_jinja in chunks:
        if is_jinja:
            # Slice the raw text between the chunk's first and last tokens
            # (which may be the same token for a one-token chunk).
            start, end = chunk[0], chunk[-1]
            sql_lines = sql_lines or sql.split("\n")
            lines = sql_lines[start.line - 1 : end.line]
            # Trim the end first: its offset is relative to the untrimmed line,
            # which matters when the segment starts and ends on the same line.
            lines[-1] = lines[-1][: end.col + len(end.text) - 1]
            lines[0] = lines[0][start.col - 1 :]
            segment = "\n".join(lines)
            # Sorted so the Jinja node's arguments do not depend on set ordering.
            variables = [
                exp.Literal.string(var)
                for var in sorted(find_undeclared_variables(ENVIRONMENT.parse(segment)))
            ]
            expressions.append(Jinja(this=exp.Literal.string(segment), expressions=variables))
        else:
            for expression in dialect.parser().parse(chunk, sql):
                if expression:
                    expressions.append(expression)

    return expressions

Parse a sql string.

Supports parsing model definition. If a jinja block is detected, the query is stored as raw string in a Jinja node.

Arguments:
  • sql: The sql based definition.
  • default_dialect: The dialect to use if the model does not specify one.
Returns:

A list of the expressions, [Model, *Statements, Query | Jinja]

@t.no_type_check
def extend_sqlglot() -> None:
@t.no_type_check
def extend_sqlglot() -> None:
    """Extend SQLGlot with SQLMesh's custom macro-aware dialect.

    This mutates SQLGlot classes in place:

    * Every ``Generator`` (the base class plus each dialect's nested ``Generator``) gets SQL
      rendering for SQLMesh's custom nodes (``DColonCast``, the macro nodes, ``Model``, ``Jinja``,
      ``ModelKind``, ``PythonCode``), and ``Model`` is appended to its
      ``WITH_SEPARATED_COMMENTS`` tuple.
    * Every ``Parser`` (the base class plus each dialect's nested ``Parser``) registers a
      ``JINJA`` function and parses ``TokenType.PARAMETER`` placeholders with ``_parse_macro``.
    * A fixed set of base ``Parser`` methods is replaced with this module's macro-aware versions
      via ``_override``.

    Each generator mutation is skipped when it is already present, so running this over a
    generator more than once does not duplicate entries.
    """
    parsers = {Parser}
    generators = {Generator}

    for dialect in Dialect.classes.values():
        if hasattr(dialect, "Parser"):
            parsers.add(dialect.Parser)
        if hasattr(dialect, "Generator"):
            generators.add(dialect.Generator)

    for generator in generators:
        # The two mutations below are guarded independently. A dialect generator may share its
        # TRANSFORMS dict with another generator that was already processed while still defining
        # its own WITH_SEPARATED_COMMENTS, so "transforms already registered" must not be taken
        # to mean "Model already appended".
        if MacroFunc not in generator.TRANSFORMS:
            generator.TRANSFORMS.update(
                {
                    DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
                    MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
                    MacroFunc: _macro_func_sql,
                    MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}",
                    MacroSQL: lambda self, e: f"@SQL({self.sql(e.this)})",
                    MacroVar: lambda self, e: f"@{e.name}",
                    Model: _model_sql,
                    Jinja: lambda self, e: e.name,
                    ModelKind: _model_kind_sql,
                    PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
                }
            )
        if Model not in generator.WITH_SEPARATED_COMMENTS:
            generator.WITH_SEPARATED_COMMENTS = (
                *generator.WITH_SEPARATED_COMMENTS,
                Model,
            )

    # These are plain dict updates with fixed keys, so repeating them is harmless.
    for parser in parsers:
        parser.FUNCTIONS.update(
            {
                "JINJA": Jinja.from_arg_list,
            }
        )
        parser.PLACEHOLDER_PARSERS.update(
            {
                TokenType.PARAMETER: _parse_macro,
            }
        )

    # NOTE(review): `_override` is defined elsewhere in this module and these calls are not
    # guarded; confirm extend_sqlglot is only invoked once, or that `_override` tolerates
    # repeated calls.
    _override(Parser, _parse_statement)
    _override(Parser, _parse_join)
    _override(Parser, _parse_order)
    _override(Parser, _parse_where)
    _override(Parser, _parse_group)
    _override(Parser, _parse_with)
    _override(Parser, _parse_having)
    _override(Parser, _parse_lambda)

Extend SQLGlot with SQLMesh's custom macro-aware dialect.

def select_from_values( values: Iterable[Tuple[Any, ...]], columns_to_types: Dict[str, sqlglot.expressions.DataType], batch_size: int = 0, alias: str = 't') -> Generator[sqlglot.expressions.Select, NoneType, NoneType]:
def select_from_values(
    values: t.Iterable[t.Tuple[t.Any, ...]],
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Generator[exp.Select, None, None]:
    """Generate SELECT expressions over VALUES that cast each column to its declared type.

    Args:
        values: Rows (tuples) to use for the VALUES expression.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch. If <= 0, no batching occurs and all
            rows go into a single expression.
        alias: The alias to assign to the values expression. Defaults to "t".

    Returns:
        A generator yielding one SELECT expression per batch. Each selects every column cast
        to its type in ``columns_to_types`` from a VALUES expression aliased as ``alias``.
        Nothing is yielded when ``values`` is empty.
    """
    casted_columns = [
        exp.alias_(exp.cast(column, to=kind), column) for column, kind in columns_to_types.items()
    ]

    def _build_select(rows: t.List[t.Tuple[t.Any, ...]]) -> exp.Select:
        # Wrap one batch of rows in VALUES and select the casted columns from it.
        values_exp = exp.values(rows, alias=alias, columns=columns_to_types)
        return exp.select(*casted_columns).from_(values_exp)

    batch: t.List[t.Tuple[t.Any, ...]] = []
    for row in values:
        batch.append(row)
        # Flush as soon as the batch reaches batch_size rows so no batch exceeds the limit.
        if batch_size > 0 and len(batch) >= batch_size:
            yield _build_select(batch)
            # Start a fresh list rather than clearing the one just passed to exp.values.
            batch = []
    if batch:
        yield _build_select(batch)

Generate a VALUES expression wrapped in a SELECT that casts the values to their correct types.

Arguments:
  • values: List of values to use for the VALUES expression.
  • columns_to_types: Mapping of column names to types to assign to the values.
  • batch_size: The maximum number of tuples per batch. If <= 0, no batching occurs.
  • alias: The alias to assign to the values expression. Defaults to "t".
Returns:

This function is a generator that yields SELECT expressions, each wrapping a VALUES expression.

def pandas_to_sql( df: pandas.core.frame.DataFrame, columns_to_types: Dict[str, sqlglot.expressions.DataType], batch_size: int = 0, alias: str = 't') -> Generator[sqlglot.expressions.Select, NoneType, NoneType]:
def pandas_to_sql(
    df: pd.DataFrame,
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Generator[exp.Select, None, None]:
    """Turn the rows of a pandas DataFrame into SELECT-over-VALUES expressions.

    The DataFrame index is dropped. Each row becomes a plain tuple of its column values, and the
    rows are handed to ``select_from_values``, which performs the casting and batching.

    Args:
        df: The DataFrame whose rows should be emitted.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch. If <= 0, no batching occurs.
        alias: The alias to assign to the values expression. Defaults to "t".

    Returns:
        A generator re-yielding every expression produced by ``select_from_values``.
    """
    rows = df.itertuples(index=False, name=None)
    yield from select_from_values(rows, columns_to_types, batch_size=batch_size, alias=alias)

Convert a pandas DataFrame into VALUES SQL statements, each wrapped in a SELECT.

Arguments:
  • df: A pandas dataframe to convert.
  • columns_to_types: Mapping of column names to types to assign to the values.
  • batch_size: The maximum number of tuples per batch. If <= 0, no batching occurs.
  • alias: The alias to assign to the values expression. Defaults to "t".
Returns:

This function is a generator that yields SELECT expressions, each wrapping a VALUES expression.