sqlmesh.core.dialect
from __future__ import annotations

import functools
import re
import typing as t
from difflib import unified_diff

import pandas as pd
from jinja2.meta import find_undeclared_variables
from sqlglot import Dialect, Generator, Parser, TokenType, exp

from sqlmesh.utils.jinja import ENVIRONMENT


class Model(exp.Expression):
    """A `MODEL (...)` definition block whose expressions are `exp.Property` nodes."""

    arg_types = {"expressions": True}


class Audit(exp.Expression):
    """An `AUDIT (...)` definition block whose expressions are `exp.Property` nodes."""

    arg_types = {"expressions": True}


class Jinja(exp.Func):
    """A raw Jinja-templated SQL segment.

    `this` is the template text as a string literal. When built by `parse`,
    `expressions` are string literals naming the template's undeclared variables.
    """

    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True


class ModelKind(exp.Expression):
    """A model's `kind` property: the kind name plus optional `exp.Property` settings."""

    arg_types = {"this": True, "expressions": False}


class MacroVar(exp.Var):
    """A macro variable reference, rendered as `@name`."""


class MacroFunc(exp.Func):
    """A macro function call; its `name` is delegated to the wrapped `this` node."""

    @property
    def name(self) -> str:
        return self.this.name


class MacroDef(MacroFunc):
    """`@DEF(var, value)`: `this` is the variable and `expression` its value."""

    arg_types = {"this": True, "expression": True}


class MacroSQL(MacroFunc):
    """`@SQL(expr[, into])`: `into` is an optional lowercased string."""

    arg_types = {"this": True, "into": False}


class MacroStrReplace(MacroFunc):
    """A macro applied to a quoted string or quoted identifier, rendered as `@<string>`."""


class PythonCode(exp.Expression):
    """A sequence of expressions rendered one per line."""

    arg_types = {"expressions": True}


class DColonCast(exp.Cast):
    """A cast rendered with the `x::type` shorthand instead of `CAST(x AS type)`."""


@t.no_type_check
def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a MODEL/AUDIT definition block, delegating any other statement to sqlglot."""
    if self._curr is None:
        return None

    parser = PARSERS.get(self._curr.text.upper())

    if parser:
        # Capture any available description in the form of a comment
        comments = self._curr.comments

        self._advance()
        meta = self._parse_wrapped(lambda: parser(self))

        meta.comments = comments
        return meta
    return self.__parse_statement()


@t.no_type_check
def _parse_lambda(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a lambda, additionally allowing an alias after the lambda's body."""
    node = self.__parse_lambda()
    if isinstance(node, exp.Lambda):
        node.set("this", self._parse_alias(node.this))
    return node


# Clause-level macros such as `@WHERE(cond) x = 1`; underscores stand for spaces.
KEYWORD_MACROS = {"WITH", "JOIN", "WHERE", "GROUP_BY", "HAVING", "ORDER_BY"}


def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expression]:
    """Parse what follows a PARAMETER (`@`) token into a macro node.

    Args:
        keyword_macro: The clause macro (e.g. "WHERE") the caller is expecting. Any other
            keyword macro is rejected here (the parser is rewound) so that the matching
            clause parser can handle it.

    Returns:
        A MacroDef, MacroSQL, MacroFunc, MacroStrReplace or MacroVar node, or None.
    """
    index = self._index
    field = self._parse_primary() or self._parse_function({}) or self._parse_id_var()

    if isinstance(field, exp.Func):
        macro_name = field.name.upper()
        if macro_name != keyword_macro and macro_name in KEYWORD_MACROS:
            self._retreat(index)
            return None
        if isinstance(field, exp.Anonymous):
            name = field.name.upper()
            if name == "DEF":
                return self.expression(
                    MacroDef, this=field.expressions[0], expression=field.expressions[1]
                )
            if name == "SQL":
                into = field.expressions[1].this.lower() if len(field.expressions) > 1 else None
                return self.expression(MacroSQL, this=field.expressions[0], into=into)
        return self.expression(MacroFunc, this=field)

    if field is None:
        return None

    if field.is_string or (isinstance(field, exp.Identifier) and field.quoted):
        return self.expression(MacroStrReplace, this=exp.Literal.string(field.this))
    return self.expression(MacroVar, this=field.this)


def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]:
    """Parse `@<name>(...)` if the upcoming tokens are exactly that keyword macro."""
    if not self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False) or (
        self._next and self._next.text.upper() != name.upper()
    ):
        return None

    self._advance(1)
    return _parse_macro(self, keyword_macro=name)


@t.no_type_check
def _parse_with(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a WITH clause, optionally prefixed by an `@WITH(...)` macro."""
    macro = _parse_matching_macro(self, "WITH")
    if not macro:
        return self.__parse_with()

    # The parsed clause becomes the macro's last argument.
    macro.this.append("expressions", self.__parse_with(True))
    return macro


@t.no_type_check
def _parse_join(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a JOIN, optionally written as `[side] [kind] @JOIN(...)`."""
    index = self._index
    natural, side, kind = self._parse_join_side_and_kind()
    macro = _parse_matching_macro(self, "JOIN")
    if not macro:
        self._retreat(index)
        return self.__parse_join()

    # The join modifiers were consumed before the macro, so reattach them.
    join = self.__parse_join(True)
    if natural:
        join.set("natural", True)
    if side:
        join.set("side", side.text)
    if kind:
        join.set("kind", kind.text)

    macro.this.append("expressions", join)
    return macro


@t.no_type_check
def _parse_where(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a WHERE clause, optionally prefixed by an `@WHERE(...)` macro."""
    macro = _parse_matching_macro(self, "WHERE")
    if not macro:
        return self.__parse_where()

    macro.this.append("expressions", self.__parse_where(True))
    return macro


@t.no_type_check
def _parse_group(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a GROUP BY clause, optionally prefixed by an `@GROUP_BY(...)` macro."""
    macro = _parse_matching_macro(self, "GROUP_BY")
    if not macro:
        return self.__parse_group()

    macro.this.append("expressions", self.__parse_group(True))
    return macro


@t.no_type_check
def _parse_having(self: Parser) -> t.Optional[exp.Expression]:
    """Parse a HAVING clause, optionally prefixed by an `@HAVING(...)` macro."""
    macro = _parse_matching_macro(self, "HAVING")
    if not macro:
        return self.__parse_having()

    macro.this.append("expressions", self.__parse_having(True))
    return macro


@t.no_type_check
def _parse_order(self: Parser, this: exp.Expression = None) -> t.Optional[exp.Expression]:
    """Parse an ORDER BY clause, optionally prefixed by an `@ORDER_BY(...)` macro."""
    macro = _parse_matching_macro(self, "ORDER_BY")
    if not macro:
        return self.__parse_order(this)

    macro.this.append("expressions", self.__parse_order(this, True))
    return macro


def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
    """Parse one `key value` model-kind property; a parenthesized value becomes a Tuple."""
    key = self._parse_id_var(True)

    if not key:
        return None

    index = self._index
    if self._match(TokenType.L_PAREN):
        self._retreat(index)
        value: t.Optional[exp.Expression] = self.expression(
            exp.Tuple,
            expressions=self._parse_wrapped_csv(
                lambda: self._parse_string() or self._parse_id_var()
            ),
        )
    else:
        value = self._parse_bracket(self._parse_field(any_token=True))

    return self.expression(exp.Property, this=key.name.lower(), value=value)


def _create_parser(parser_type: t.Type[exp.Expression], table_keys: t.List[str]) -> t.Callable:
    """Build a parser for a block of comma-separated `key value` properties.

    Args:
        parser_type: The expression type to build, e.g. Model or Audit.
        table_keys: Keys whose values are parsed as tables and stored as table-name strings.

    Returns:
        A parser function taking the sqlglot Parser instance.
    """

    def parse(self: Parser) -> t.Optional[exp.Expression]:
        # Imported lazily, presumably to avoid a circular import.
        from sqlmesh.core.model.kind import ModelKindName

        expressions = []

        while True:
            key_expression = self._parse_id_var(any_token=True)

            if not key_expression:
                break

            key = key_expression.name.lower()

            value: t.Optional[exp.Expression | str]

            if key in table_keys:
                value = exp.table_name(self._parse_table())
            elif key == "columns":
                value = self._parse_schema()
            elif key == "kind":
                id_var = self._parse_id_var(any_token=True)
                if not id_var:
                    value = None
                else:
                    index = self._index
                    kind = ModelKindName[id_var.name.upper()]

                    # Only these kinds are parsed with a parenthesized property list.
                    if kind in (
                        ModelKindName.INCREMENTAL_BY_TIME_RANGE,
                        ModelKindName.INCREMENTAL_BY_UNIQUE_KEY,
                        ModelKindName.SEED,
                    ) and self._match(TokenType.L_PAREN):
                        self._retreat(index)
                        props = self._parse_wrapped_csv(functools.partial(_parse_props, self))
                    else:
                        props = None
                    value = self.expression(
                        ModelKind,
                        this=kind.value,
                        expressions=props,
                    )
            else:
                value = self._parse_bracket(self._parse_field(any_token=True))

            expressions.append(self.expression(exp.Property, this=key, value=value))

            if not self._match(TokenType.COMMA):
                break

        return self.expression(parser_type, expressions=expressions)

    return parse


_parse_model = _create_parser(Model, ["name"])
_parse_audit = _create_parser(Audit, ["model"])
PARSERS = {"MODEL": _parse_model, "AUDIT": _parse_audit}


def _model_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a definition block as `MODEL (` + indented `key value` lines + `)`."""
    props = ",\n".join(
        self.indent(f"{prop.name} {self.sql(prop, 'value')}") for prop in expression.expressions
    )
    return "\n".join(["MODEL (", props, ")"])


def _model_kind_sql(self: Generator, expression: ModelKind) -> str:
    """Render a model kind, with a parenthesized property block when it has properties."""
    props = ",\n".join(
        self.indent(f"{prop.this} {self.sql(prop, 'value')}") for prop in expression.expressions
    )
    if props:
        return "\n".join([f"{expression.this} (", props, ")"])
    return expression.name.upper()


def _macro_keyword_func_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a clause macro by substituting `@NAME(args)` for the first clause keyword.

    The last argument is the clause itself; e.g. `GROUP BY x` becomes `@GROUP_BY(args) x`.
    """
    name = expression.name
    keyword = name.replace("_", " ")
    *args, clause = expression.expressions
    macro = f"@{name}({self.format_args(*args)})"
    return self.sql(clause).replace(keyword, macro, 1)


def _macro_func_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a MacroFunc as `@name(args)`, or as a clause macro for keyword macros."""
    expression = expression.this
    name = expression.name
    if name in KEYWORD_MACROS:
        return _macro_keyword_func_sql(self, expression)
    return f"@{name}({self.format_args(*expression.expressions)})"


def _override(klass: t.Type[Parser], func: t.Callable) -> None:
    """Install `func` as `klass.<func name>`, keeping the original as `klass._<func name>`.

    For example `_parse_join` is installed and the original stays reachable as
    `__parse_join`, which the overrides call (module-level code, so no name mangling).
    """
    name = func.__name__
    setattr(klass, f"_{name}", getattr(klass, name))
    setattr(klass, name, func)


def format_model_expressions(
    expressions: t.List[exp.Expression], dialect: t.Optional[str] = None
) -> str:
    """Format a model's expressions into a standardized format.

    Projections with an output name are aliased with it. A cast that is the direct
    child of a projection is rewritten to the `x::type` shorthand, unless its operand
    is an unparenthesized binary/unary expression.

    Args:
        expressions: The model's expressions, must be at least model def + query.
        dialect: The dialect to render the expressions as.
    Returns:
        A string with the formatted model.
    """
    if len(expressions) == 1:
        return expressions[0].sql(pretty=True, dialect=dialect)

    *statements, query = expressions
    query = query.copy()
    selects = []

    for expression in query.expressions:
        # Detach comments so they can be reattached to the (possibly new) alias node.
        comments = expression.comments
        expression.comments = None

        if not isinstance(expression, exp.Alias) and expression.name:
            expression = expression.replace(exp.alias_(expression.copy(), expression.name))

        inner = expression.this
        if isinstance(inner, exp.Cast):
            operand = inner.this
            # Bare binary/unary operands keep CAST(...) so `::` can't change precedence;
            # parenthesized operands are safe and are exempted explicitly.
            if not isinstance(operand, (exp.Binary, exp.Unary)) or isinstance(operand, exp.Paren):
                inner.replace(DColonCast(this=operand, to=inner.to))

        expression.comments = comments
        selects.append(expression)

    query.set("expressions", selects)

    return ";\n\n".join(
        [
            *(statement.sql(pretty=True, dialect=dialect) for statement in statements),
            query.sql(pretty=True, dialect=dialect),
        ]
    ).strip()


def text_diff(
    a: t.Optional[exp.Expression],
    b: t.Optional[exp.Expression],
    dialect: t.Optional[str] = None,
) -> str:
    """Find the unified text diff between two expressions.

    Both sides are rendered pretty-printed and without comments; a missing side is
    treated as empty.
    """
    return "\n".join(
        unified_diff(
            a.sql(pretty=True, comments=False, dialect=dialect).split("\n") if a else [],
            b.sql(pretty=True, comments=False, dialect=dialect).split("\n") if b else [],
        )
    )


# Finds the `dialect` property of a MODEL/AUDIT header; group 2 holds the dialect name
# (it may be empty or "," when no name follows).
DIALECT_PATTERN = re.compile(
    r"(model|audit).*?\(.*?dialect[^a-z,]+([a-z]*|,)", re.IGNORECASE | re.DOTALL
)


def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression]:
    """Parse a sql string.

    Supports parsing model definition.
    If a jinja block is detected, the query is stored as raw string in a Jinja node.

    Args:
        sql: The sql based definition.
        default_dialect: The dialect to use if the model does not specify one.

    Returns:
        A list of the expressions, [Model, *Statements, Query | Jinja]
    """
    match = DIALECT_PATTERN.search(sql)
    dialect = Dialect.get_or_raise(match.group(2) if match else default_dialect)()

    tokens = dialect.tokenizer.tokenize(sql)
    chunks: t.List[t.Tuple[t.List, bool]] = [([], False)]
    total = len(tokens)

    # Split the token stream into semicolon-separated statements, flagging any that
    # contain a BLOCK_START token or two consecutive `{` tokens as Jinja.
    for i, token in enumerate(tokens):
        if token.token_type == TokenType.SEMICOLON:
            # A trailing semicolon does not open a new, empty statement.
            if i < total - 1:
                chunks.append(([], False))
        else:
            if token.token_type == TokenType.BLOCK_START or (
                i < total - 1
                and token.token_type == TokenType.L_BRACE
                and tokens[i + 1].token_type == TokenType.L_BRACE
            ):
                chunks[-1] = (chunks[-1][0], True)
            chunks[-1][0].append(token)

    expressions: t.List[exp.Expression] = []
    sql_lines = None

    for chunk, is_jinja in chunks:
        if is_jinja:
            # Index rather than star-unpack: a chunk may consist of a single token.
            start, end = chunk[0], chunk[-1]
            sql_lines = sql_lines or sql.split("\n")
            lines = sql_lines[start.line - 1 : end.line]
            # Trim the end before the start: both offsets refer to the original line,
            # which matters when the segment starts and ends on the same line.
            lines[-1] = lines[-1][: end.col + len(end.text) - 1]
            lines[0] = lines[0][start.col - 1 :]
            segment = "\n".join(lines)
            variables = [
                exp.Literal.string(var)
                for var in find_undeclared_variables(ENVIRONMENT.parse(segment))
            ]
            expressions.append(Jinja(this=exp.Literal.string(segment), expressions=variables))
        else:
            for expression in dialect.parser().parse(chunk, sql):
                if expression:
                    expressions.append(expression)

    return expressions


@t.no_type_check
def extend_sqlglot() -> None:
    """Extend SQLGlot with SQLMesh's custom macro aware dialect.

    Safe to call more than once: transforms and parser overrides are only installed
    if they are not already present.
    """
    parsers = {Parser}
    generators = {Generator}

    for dialect in Dialect.classes.values():
        if hasattr(dialect, "Parser"):
            parsers.add(dialect.Parser)
        if hasattr(dialect, "Generator"):
            generators.add(dialect.Generator)

    for generator in generators:
        if MacroFunc not in generator.TRANSFORMS:
            generator.TRANSFORMS.update(
                {
                    DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
                    MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
                    MacroFunc: _macro_func_sql,
                    MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}",
                    MacroSQL: lambda self, e: f"@SQL({self.sql(e.this)})",
                    MacroVar: lambda self, e: f"@{e.name}",
                    Model: _model_sql,
                    Jinja: lambda self, e: e.name,
                    ModelKind: _model_kind_sql,
                    PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
                }
            )
            generator.WITH_SEPARATED_COMMENTS = (
                *generator.WITH_SEPARATED_COMMENTS,
                Model,
            )

    for parser in parsers:
        parser.FUNCTIONS.update(
            {
                "JINJA": Jinja.from_arg_list,
            }
        )
        parser.PLACEHOLDER_PARSERS.update(
            {
                TokenType.PARAMETER: _parse_macro,
            }
        )

    for func in (
        _parse_statement,
        _parse_join,
        _parse_order,
        _parse_where,
        _parse_group,
        _parse_with,
        _parse_having,
        _parse_lambda,
    ):
        # Re-overriding would alias `__<name>` to the override itself and recurse forever.
        if getattr(Parser, func.__name__) is not func:
            _override(Parser, func)


def select_from_values(
    values: t.Iterable[t.Tuple[t.Any, ...]],
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Generator[exp.Select, None, None]:
    """Generate a VALUES expression that has a select wrapped around it to cast the values to their correct types.

    Args:
        values: List of values to use for the VALUES expression.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch, if <= 0 then no batching will occur.
        alias: The alias to assign to the values expression. If not provided then will default to "t".

    Returns:
        A generator yielding, per batch, a SELECT of the casted columns from a VALUES expression.
    """
    casted_columns = [
        exp.alias_(exp.cast(column, to=kind), column) for column, kind in columns_to_types.items()
    ]
    batch: t.List[t.Tuple[t.Any, ...]] = []
    for row in values:
        batch.append(row)
        # Flush as soon as the batch is full so no batch exceeds `batch_size` rows.
        if batch_size > 0 and len(batch) >= batch_size:
            values_exp = exp.values(batch, alias=alias, columns=columns_to_types)
            yield exp.select(*casted_columns).from_(values_exp)
            batch.clear()
    if batch:
        values_exp = exp.values(batch, alias=alias, columns=columns_to_types)
        yield exp.select(*casted_columns).from_(values_exp)


def pandas_to_sql(
    df: pd.DataFrame,
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Generator[exp.Select, None, None]:
    """Convert a pandas dataframe into SELECT statements over VALUES expressions.

    Args:
        df: A pandas dataframe to convert.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch, if <= 0 then no batching will occur.
        alias: The alias to assign to the values expression. If not provided then will default to "t".

    Returns:
        A generator yielding, per batch, a SELECT of the casted columns from a VALUES expression.
    """
    yield from select_from_values(
        values=df.itertuples(index=False, name=None),
        columns_to_types=columns_to_types,
        batch_size=batch_size,
        alias=alias,
    )
class
Model(sqlglot.expressions.Expression):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
class
Audit(sqlglot.expressions.Expression):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
class
Jinja(sqlglot.expressions.Func):
class Jinja(exp.Func):
    """A raw Jinja-templated SQL segment.

    `this` carries the template text. The optional, variadic `expressions` are, when
    the node is produced by `parse`, string literals naming undeclared template variables.
    """

    is_var_len_args = True
    arg_types = dict(this=True, expressions=False)
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
- sqlglot.expressions.Func
- from_arg_list
- sql_names
- sql_name
- default_parser_mappings
- sqlglot.expressions.Condition
- and_
- or_
- not_
class
ModelKind(sqlglot.expressions.Expression):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
class
MacroVar(sqlglot.expressions.Var):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
class
MacroFunc(sqlglot.expressions.Func):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
- sqlglot.expressions.Func
- from_arg_list
- sql_names
- sql_name
- default_parser_mappings
- sqlglot.expressions.Condition
- and_
- or_
- not_
class
MacroDef(MacroFunc):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
- sqlglot.expressions.Func
- from_arg_list
- sql_names
- sql_name
- default_parser_mappings
- sqlglot.expressions.Condition
- and_
- or_
- not_
class
MacroSQL(MacroFunc):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
- sqlglot.expressions.Func
- from_arg_list
- sql_names
- sql_name
- default_parser_mappings
- sqlglot.expressions.Condition
- and_
- or_
- not_
class
MacroStrReplace(MacroFunc):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
- sqlglot.expressions.Func
- from_arg_list
- sql_names
- sql_name
- default_parser_mappings
- sqlglot.expressions.Condition
- and_
- or_
- not_
class
PythonCode(sqlglot.expressions.Expression):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- output_name
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
class
DColonCast(sqlglot.expressions.Cast):
Inherited Members
- sqlglot.expressions.Expression
- Expression
- this
- expression
- expressions
- text
- is_string
- is_number
- is_int
- is_star
- alias
- copy
- append
- set
- depth
- iter_expressions
- find
- find_all
- find_ancestor
- parent_select
- same_parent
- root
- walk
- dfs
- bfs
- unnest
- unalias
- unnest_operands
- flatten
- sql
- transform
- replace
- pop
- assert_is
- error_messages
- dump
- load
- sqlglot.expressions.Cast
- output_name
- is_type
- sqlglot.expressions.Func
- from_arg_list
- sql_names
- sql_name
- default_parser_mappings
- sqlglot.expressions.Condition
- and_
- or_
- not_
def
format_model_expressions( expressions: List[sqlglot.expressions.Expression], dialect: Optional[str] = None) -> str:
def format_model_expressions(
    expressions: t.List[exp.Expression], dialect: t.Optional[str] = None
) -> str:
    """Format a model's expressions into a standardized format.

    Projections with an output name are aliased with it. A cast that is the direct
    child of a projection is rewritten to the `x::type` shorthand, unless its operand
    is an unparenthesized binary/unary expression.

    Args:
        expressions: The model's expressions, must be at least model def + query.
        dialect: The dialect to render the expressions as.
    Returns:
        A string with the formatted model.
    """
    if len(expressions) == 1:
        return expressions[0].sql(pretty=True, dialect=dialect)

    *statements, query = expressions
    query = query.copy()
    selects = []

    for expression in query.expressions:
        # Detach comments so they can be reattached to the (possibly new) alias node.
        comments = expression.comments
        expression.comments = None

        if not isinstance(expression, exp.Alias) and expression.name:
            expression = expression.replace(exp.alias_(expression.copy(), expression.name))

        inner = expression.this
        if isinstance(inner, exp.Cast):
            operand = inner.this
            # Bare binary/unary operands keep CAST(...) so `::` can't change precedence;
            # parenthesized operands are safe and are exempted explicitly.
            if not isinstance(operand, (exp.Binary, exp.Unary)) or isinstance(operand, exp.Paren):
                inner.replace(DColonCast(this=operand, to=inner.to))

        expression.comments = comments
        selects.append(expression)

    query.set("expressions", selects)

    return ";\n\n".join(
        [
            *(statement.sql(pretty=True, dialect=dialect) for statement in statements),
            query.sql(pretty=True, dialect=dialect),
        ]
    ).strip()
Format a model's expressions into a standardized format.
Arguments:
- expressions: The model's expressions, must be at least model def + query.
- dialect: The dialect to render the expressions as.
Returns:
A string with the formatted model.
def
text_diff( a: Optional[sqlglot.expressions.Expression], b: Optional[sqlglot.expressions.Expression], dialect: Optional[str] = None) -> str:
def text_diff(
    a: t.Optional[exp.Expression],
    b: t.Optional[exp.Expression],
    dialect: t.Optional[str] = None,
) -> str:
    """Return the unified text diff of two expressions.

    Each side is rendered pretty-printed without comments; a missing side counts as empty.
    """

    def _render_lines(node: t.Optional[exp.Expression]) -> t.Sequence[str]:
        if not node:
            return ""
        rendered = node.sql(pretty=True, comments=False, dialect=dialect)
        return rendered.split("\n")

    diff_lines = unified_diff(_render_lines(a), _render_lines(b))
    return "\n".join(diff_lines)
Find the unified text diff between two expressions.
def
parse( sql: str, default_dialect: str | None = None) -> List[sqlglot.expressions.Expression]:
def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression]:
    """Parse a sql string.

    Supports parsing model definition.
    If a jinja block is detected, the query is stored as raw string in a Jinja node.

    Args:
        sql: The sql based definition.
        default_dialect: The dialect to use if the model does not specify one.

    Returns:
        A list of the expressions, [Model, *Statements, Query | Jinja]
    """
    match = DIALECT_PATTERN.search(sql)
    dialect = Dialect.get_or_raise(match.group(2) if match else default_dialect)()

    tokens = dialect.tokenizer.tokenize(sql)
    chunks: t.List[t.Tuple[t.List, bool]] = [([], False)]
    total = len(tokens)

    # Split the token stream into semicolon-separated statements, flagging any that
    # contain a BLOCK_START token or two consecutive `{` tokens as Jinja.
    for i, token in enumerate(tokens):
        if token.token_type == TokenType.SEMICOLON:
            # A trailing semicolon does not open a new, empty statement.
            if i < total - 1:
                chunks.append(([], False))
        else:
            if token.token_type == TokenType.BLOCK_START or (
                i < total - 1
                and token.token_type == TokenType.L_BRACE
                and tokens[i + 1].token_type == TokenType.L_BRACE
            ):
                chunks[-1] = (chunks[-1][0], True)
            chunks[-1][0].append(token)

    expressions: t.List[exp.Expression] = []
    sql_lines = None

    for chunk, is_jinja in chunks:
        if is_jinja:
            # Index rather than star-unpack: a chunk may consist of a single token.
            start, end = chunk[0], chunk[-1]
            sql_lines = sql_lines or sql.split("\n")
            lines = sql_lines[start.line - 1 : end.line]
            # Trim the end before the start: both offsets refer to the original line,
            # which matters when the segment starts and ends on the same line.
            lines[-1] = lines[-1][: end.col + len(end.text) - 1]
            lines[0] = lines[0][start.col - 1 :]
            segment = "\n".join(lines)
            variables = [
                exp.Literal.string(var)
                for var in find_undeclared_variables(ENVIRONMENT.parse(segment))
            ]
            expressions.append(Jinja(this=exp.Literal.string(segment), expressions=variables))
        else:
            for expression in dialect.parser().parse(chunk, sql):
                if expression:
                    expressions.append(expression)

    return expressions
Parse a sql string.
Supports parsing model definition. If a jinja block is detected, the query is stored as raw string in a Jinja node.
Arguments:
- sql: The sql based definition.
- default_dialect: The dialect to use if the model does not specify one.
Returns:
A list of the expressions, [Model, *Statements, Query | Jinja]
@t.no_type_check
def
extend_sqlglot() -> None:
@t.no_type_check
def extend_sqlglot() -> None:
    """Extend SQLGlot with SQLMesh's custom macro aware dialect.

    Safe to call more than once: transforms and parser overrides are only installed
    if they are not already present.
    """
    parsers = {Parser}
    generators = {Generator}

    for dialect in Dialect.classes.values():
        if hasattr(dialect, "Parser"):
            parsers.add(dialect.Parser)
        if hasattr(dialect, "Generator"):
            generators.add(dialect.Generator)

    for generator in generators:
        if MacroFunc not in generator.TRANSFORMS:
            generator.TRANSFORMS.update(
                {
                    DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
                    MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
                    MacroFunc: _macro_func_sql,
                    MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}",
                    MacroSQL: lambda self, e: f"@SQL({self.sql(e.this)})",
                    MacroVar: lambda self, e: f"@{e.name}",
                    Model: _model_sql,
                    Jinja: lambda self, e: e.name,
                    ModelKind: _model_kind_sql,
                    PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
                }
            )
            generator.WITH_SEPARATED_COMMENTS = (
                *generator.WITH_SEPARATED_COMMENTS,
                Model,
            )

    for parser in parsers:
        parser.FUNCTIONS.update(
            {
                "JINJA": Jinja.from_arg_list,
            }
        )
        parser.PLACEHOLDER_PARSERS.update(
            {
                TokenType.PARAMETER: _parse_macro,
            }
        )

    for func in (
        _parse_statement,
        _parse_join,
        _parse_order,
        _parse_where,
        _parse_group,
        _parse_with,
        _parse_having,
        _parse_lambda,
    ):
        # Re-overriding would alias `__<name>` to the override itself and recurse forever.
        if getattr(Parser, func.__name__) is not func:
            _override(Parser, func)
Extend SQLGlot with SQLMesh's custom macro aware dialect.
def
select_from_values( values: Iterable[Tuple[Any, ...]], columns_to_types: Dict[str, sqlglot.expressions.DataType], batch_size: int = 0, alias: str = 't') -> Generator[sqlglot.expressions.Select, NoneType, NoneType]:
def select_from_values(
    values: t.Iterable[t.Tuple[t.Any, ...]],
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Generator[exp.Select, None, None]:
    """Generate a VALUES expression that has a select wrapped around it to cast the values to their correct types.

    Args:
        values: List of values to use for the VALUES expression.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch, if <= 0 then no batching will occur.
        alias: The alias to assign to the values expression. If not provided then will default to "t".

    Returns:
        A generator yielding, per batch, a SELECT of the casted columns from a VALUES expression.
    """
    casted_columns = [
        exp.alias_(exp.cast(column, to=kind), column) for column, kind in columns_to_types.items()
    ]
    batch: t.List[t.Tuple[t.Any, ...]] = []
    for row in values:
        batch.append(row)
        # Flush as soon as the batch is full so no batch exceeds `batch_size` rows.
        if batch_size > 0 and len(batch) >= batch_size:
            values_exp = exp.values(batch, alias=alias, columns=columns_to_types)
            yield exp.select(*casted_columns).from_(values_exp)
            batch.clear()
    if batch:
        values_exp = exp.values(batch, alias=alias, columns=columns_to_types)
        yield exp.select(*casted_columns).from_(values_exp)
Generate a VALUES expression that has a select wrapped around it to cast the values to their correct types.
Arguments:
- values: List of values to use for the VALUES expression.
- columns_to_types: Mapping of column names to types to assign to the values.
- batch_size: The maximum number of tuples per batch, if <= 0 then no batching will occur.
- alias: The alias to assign to the values expression. If not provided then will default to "t"
Returns:
This method operates as a generator and yields, for each batch, a SELECT expression that casts the columns of a VALUES expression.
def
pandas_to_sql( df: pandas.core.frame.DataFrame, columns_to_types: Dict[str, sqlglot.expressions.DataType], batch_size: int = 0, alias: str = 't') -> Generator[sqlglot.expressions.Select, NoneType, NoneType]:
def pandas_to_sql(
    df: pd.DataFrame,
    columns_to_types: t.Dict[str, exp.DataType],
    batch_size: int = 0,
    alias: str = "t",
) -> t.Generator[exp.Select, None, None]:
    """Turn a pandas dataframe into SELECT statements over VALUES expressions.

    Args:
        df: The dataframe whose rows become the VALUES tuples.
        columns_to_types: Mapping of column names to types to assign to the values.
        batch_size: The maximum number of tuples per batch, if <= 0 then no batching will occur.
        alias: The alias to assign to the values expression; defaults to "t".

    Returns:
        A generator that delegates to `select_from_values` and yields its SELECT expressions.
    """
    # Plain tuples, without the index, as `select_from_values` expects.
    rows = df.itertuples(index=False, name=None)
    yield from select_from_values(
        rows,
        columns_to_types,
        batch_size=batch_size,
        alias=alias,
    )
Convert a pandas dataframe into SELECT statements over VALUES expressions (one per batch).
Arguments:
- df: A pandas dataframe to convert.
- columns_to_types: Mapping of column names to types to assign to the values.
- batch_size: The maximum number of tuples per batch, if <= 0 then no batching will occur.
- alias: The alias to assign to the values expression. If not provided then will default to "t"
Returns:
This method operates as a generator and yields, for each batch, a SELECT expression that casts the columns of a VALUES expression.