sqlmesh.dbt.builtin
1from __future__ import annotations 2 3import json 4import os 5import typing as t 6from ast import literal_eval 7 8import agate 9import jinja2 10from dbt.adapters.base import BaseRelation 11from dbt.contracts.relation import Policy 12from ruamel.yaml import YAMLError 13 14from sqlmesh.core.engine_adapter import EngineAdapter 15from sqlmesh.dbt.adapter import ParsetimeAdapter, RuntimeAdapter 16from sqlmesh.utils import AttributeDict, yaml 17from sqlmesh.utils.errors import ConfigError, MacroEvalError 18from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal 19 20 21class Exceptions: 22 def raise_compiler_error(self, msg: str) -> None: 23 from dbt.exceptions import CompilationError 24 25 raise CompilationError(msg) 26 27 def warn(self, msg: str) -> str: 28 print(msg) 29 return "" 30 31 32class Api: 33 def __init__(self) -> None: 34 from dbt.adapters.base.column import Column 35 from dbt.adapters.base.relation import BaseRelation 36 37 self.Relation = BaseRelation 38 self.Column = Column 39 40 41class Flags: 42 def __init__(self) -> None: 43 # Temporary placeholder values for now (these are generally passed from the CLI) 44 self.FULL_REFRESH = None 45 self.STORE_FAILURES = None 46 self.WHICH = None 47 48 49class Modules: 50 def __init__(self) -> None: 51 import datetime 52 import itertools 53 import re 54 55 try: 56 import pytz 57 58 self.pytz = pytz 59 except ImportError: 60 pass 61 62 self.datetime = datetime 63 self.re = re 64 self.itertools = itertools 65 66 67class SQLExecution: 68 def __init__(self, adapter: RuntimeAdapter): 69 self.adapter = adapter 70 self._results: t.Dict[str, AttributeDict] = {} 71 72 def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str: 73 from dbt.clients import agate_helper 74 75 if agate_table is None: 76 agate_table = agate_helper.empty_table() 77 78 self._results[name] = AttributeDict( 79 { 80 "response": response, 81 "data": agate_helper.as_matrix(agate_table), 82 "table": 
agate_table, 83 } 84 ) 85 return "" 86 87 def load_result(self, name: str) -> t.Optional[AttributeDict]: 88 return self._results.get(name) 89 90 def run_query(self, sql: str) -> agate.Table: 91 self.statement("run_query_statement", fetch_result=True, auto_begin=False, caller=sql) 92 resp = self.load_result("run_query_statement") 93 assert resp is not None 94 return resp["table"] 95 96 def statement( 97 self, 98 name: t.Optional[str], 99 fetch_result: bool = False, 100 auto_begin: bool = True, 101 language: str = "sql", 102 caller: t.Optional[jinja2.runtime.Macro | str] = None, 103 ) -> str: 104 """ 105 Executes the SQL that is defined within the context of the caller. Therefore caller really isn't optional 106 but we make it optional and at the end because we need to match the signature of the jinja2 macro. 107 108 Name is the name that we store the results to which can be retrieved with `load_result`. If name is not 109 provided then the SQL is executed but the results are not stored. 110 """ 111 if not caller: 112 raise RuntimeError( 113 "Statement relies on a caller to be set that is the target SQL to be run" 114 ) 115 sql = caller if isinstance(caller, str) else caller() 116 if language != "sql": 117 raise NotImplementedError( 118 "SQLMesh's dbt integration only supports SQL statements at this time." 
119 ) 120 assert self.adapter is not None 121 res, table = self.adapter.execute(sql, fetch=fetch_result, auto_begin=auto_begin) 122 if name: 123 self.store_result(name, res, table) 124 return "" 125 126 127def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]: 128 if name not in os.environ and default is None: 129 raise ConfigError(f"Missing environment variable '{name}'") 130 return os.environ.get(name, default) 131 132 133def is_incremental() -> bool: 134 return False 135 136 137def log(msg: str, info: bool = False) -> str: 138 print(msg) 139 return "" 140 141 142def no_log(msg: str, info: bool = False) -> str: 143 return "" 144 145 146def config(*args: t.Any, **kwargs: t.Any) -> str: 147 return "" 148 149 150def generate_var(variables: t.Dict[str, t.Any]) -> t.Callable: 151 def var(name: str, default: t.Optional[str] = None) -> str: 152 return variables.get(name, default) 153 154 return var 155 156 157def generate_ref(refs: t.Dict[str, t.Any]) -> t.Callable: 158 159 # TODO suport package name 160 def ref(package: str, name: t.Optional[str] = None) -> t.Optional[BaseRelation]: 161 name = name or package 162 relation_info = refs.get(name) 163 if relation_info is None: 164 return relation_info 165 166 return BaseRelation.create(**relation_info, quote_policy=quote_policy()) 167 168 return ref 169 170 171def generate_source(sources: t.Dict[str, t.Any]) -> t.Callable: 172 def source(package: str, name: str) -> t.Optional[BaseRelation]: 173 relation_info = sources.get(f"{package}.{name}") 174 if relation_info is None: 175 return relation_info 176 177 return BaseRelation.create(**relation_info, quote_policy=quote_policy()) 178 179 return source 180 181 182def quote_policy() -> Policy: 183 return Policy(database=False, schema=False, identifier=False) 184 185 186def return_val(val: t.Any) -> None: 187 raise MacroReturnVal(val) 188 189 190def to_set(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 191 try: 192 return set(value) 
193 except TypeError: 194 return default 195 196 197def to_json(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 198 try: 199 return json.dumps(value) 200 except TypeError: 201 return default 202 203 204def from_json(value: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 205 try: 206 return json.loads(value) 207 except (TypeError, json.JSONDecodeError): 208 return default 209 210 211def to_yaml(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 212 try: 213 return yaml.dumps(value) 214 except (TypeError, YAMLError): 215 return default 216 217 218def from_yaml(value: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 219 try: 220 return dict(yaml.load(value, raise_if_empty=False, render_jinja=False)) 221 except (TypeError, YAMLError): 222 return default 223 224 225def do_zip(*args: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: 226 try: 227 return list(zip(*args)) 228 except TypeError: 229 return default 230 231 232def as_bool(value: str) -> bool: 233 result = _try_literal_eval(value) 234 if isinstance(result, bool): 235 return result 236 raise MacroEvalError(f"Failed to convert '{value}' into boolean.") 237 238 239def as_number(value: str) -> t.Any: 240 result = _try_literal_eval(value) 241 if isinstance(value, (int, float)) and not isinstance(result, bool): 242 return result 243 raise MacroEvalError(f"Failed to convert '{value}' into number.") 244 245 246def _try_literal_eval(value: str) -> t.Any: 247 try: 248 return literal_eval(value) 249 except (ValueError, SyntaxError, MemoryError): 250 return value 251 252 253BUILTIN_GLOBALS = { 254 "api": Api(), 255 "config": config, 256 "env_var": env_var, 257 "exceptions": Exceptions(), 258 "flags": Flags(), 259 "fromjson": from_json, 260 "fromyaml": from_yaml, 261 "is_incremental": is_incremental, 262 "log": no_log, 263 "modules": Modules(), 264 "print": no_log, 265 "return": return_val, 266 "set": to_set, 267 "set_strict": set, 268 
"sqlmesh": True, 269 "tojson": to_json, 270 "toyaml": to_yaml, 271 "zip": do_zip, 272 "zip_strict": lambda *args: list(zip(*args)), 273} 274 275BUILTIN_FILTERS = { 276 "as_bool": as_bool, 277 "as_native": _try_literal_eval, 278 "as_number": as_number, 279 "as_text": lambda v: "" if v is None else str(v), 280} 281 282 283def create_builtin_globals( 284 jinja_macros: JinjaMacroRegistry, 285 jinja_globals: t.Dict[str, t.Any], 286 engine_adapter: t.Optional[EngineAdapter], 287) -> t.Dict[str, t.Any]: 288 builtin_globals = BUILTIN_GLOBALS.copy() 289 jinja_globals = jinja_globals.copy() 290 291 this = jinja_globals.pop("this", None) 292 if this is not None: 293 if not isinstance(this, BaseRelation): 294 builtin_globals["this"] = BaseRelation.create(**this, quote_policy=quote_policy()) 295 else: 296 builtin_globals["this"] = this 297 298 sources = jinja_globals.pop("sources", None) 299 if sources is not None: 300 builtin_globals["source"] = generate_source(sources) 301 302 refs = jinja_globals.pop("refs", None) 303 if refs is not None: 304 builtin_globals["ref"] = generate_ref(refs) 305 306 variables = jinja_globals.pop("vars", None) 307 if variables is not None: 308 builtin_globals["var"] = generate_var(variables) 309 310 builtin_globals["builtins"] = AttributeDict( 311 {k: builtin_globals.get(k) for k in ("ref", "source", "config")} 312 ) 313 314 if engine_adapter is not None: 315 adapter = RuntimeAdapter( 316 engine_adapter, jinja_macros, jinja_globals={**builtin_globals, **jinja_globals} 317 ) 318 sql_execution = SQLExecution(adapter) 319 builtin_globals.update( 320 { 321 "execute": True, 322 "adapter": adapter, 323 "load_relation": lambda r: adapter.get_relation(r.database, r.schema, r.identifier), 324 "store_result": sql_execution.store_result, 325 "load_result": sql_execution.load_result, 326 "run_query": sql_execution.run_query, 327 "statement": sql_execution.statement, 328 "log": log, 329 "print": log, 330 } 331 ) 332 else: 333 builtin_globals.update( 334 { 335 
"execute": False, 336 "adapter": ParsetimeAdapter( 337 jinja_macros, jinja_globals={**builtin_globals, **jinja_globals} 338 ), 339 "load_relation": lambda *args, **kwargs: None, 340 "store_result": lambda *args, **kwargs: "", 341 "load_result": lambda *args, **kwargs: None, 342 "run_query": lambda *args, **kwargs: None, 343 "statement": lambda *args, **kwargs: "", 344 "log": no_log, 345 "print": no_log, 346 } 347 ) 348 349 return {**builtin_globals, **jinja_globals} 350 351 352def create_builtin_filters() -> t.Dict[str, t.Callable]: 353 return BUILTIN_FILTERS
class
Exceptions:
class Exceptions:
    """Error helpers made available to templates: raise a compiler error or print a warning."""

    def raise_compiler_error(self, msg: str) -> None:
        """Raise dbt's ``CompilationError`` built from ``msg``; never returns normally."""
        # Resolved at call time rather than when the class is defined.
        from dbt import exceptions as dbt_exceptions

        raise dbt_exceptions.CompilationError(msg)

    def warn(self, msg: str) -> str:
        """Write ``msg`` to standard output and hand back an empty string."""
        print(msg)
        return ""
class
Api:
class
Flags:
class
Modules:
class
SQLExecution:
class SQLExecution:
    """Runs SQL via the given adapter and remembers named results for later retrieval."""

    def __init__(self, adapter: RuntimeAdapter):
        # Named results, keyed by the name passed to store_result.
        self._results: t.Dict[str, AttributeDict] = {}
        self.adapter = adapter

    def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str:
        """Record ``response`` and its table under ``name`` and return an empty string.

        A missing table (``None``) is substituted with ``agate_helper.empty_table()``.
        """
        from dbt.clients import agate_helper

        table = agate_helper.empty_table() if agate_table is None else agate_table
        entry = {
            "response": response,
            "data": agate_helper.as_matrix(table),
            "table": table,
        }
        self._results[name] = AttributeDict(entry)
        return ""

    def load_result(self, name: str) -> t.Optional[AttributeDict]:
        """Look up a stored result by name; ``None`` when nothing was stored under it."""
        return self._results.get(name)

    def run_query(self, sql: str) -> agate.Table:
        """Run ``sql`` with fetching enabled and return the resulting table."""
        result_name = "run_query_statement"
        self.statement(result_name, fetch_result=True, auto_begin=False, caller=sql)
        stored = self.load_result(result_name)
        assert stored is not None
        return stored["table"]

    def statement(
        self,
        name: t.Optional[str],
        fetch_result: bool = False,
        auto_begin: bool = True,
        language: str = "sql",
        caller: t.Optional[jinja2.runtime.Macro | str] = None,
    ) -> str:
        """
        Execute the SQL supplied through ``caller`` and optionally keep the result.

        ``caller`` is declared optional and last only so that the signature matches the
        jinja2 macro; in practice it is required. It may be the SQL string itself or a
        callable producing it. When ``name`` is truthy the result is stored under that
        name and can be fetched with `load_result`; otherwise the SQL is executed and
        the result is discarded. Only ``language="sql"`` is supported.
        """
        if not caller:
            raise RuntimeError(
                "Statement relies on a caller to be set that is the target SQL to be run"
            )

        # The callable form is invoked before the language check, as in the string form.
        if isinstance(caller, str):
            sql = caller
        else:
            sql = caller()

        if language != "sql":
            raise NotImplementedError(
                "SQLMesh's dbt integration only supports SQL statements at this time."
            )

        assert self.adapter is not None
        response, result_table = self.adapter.execute(
            sql, fetch=fetch_result, auto_begin=auto_begin
        )
        if name:
            self.store_result(name, response, result_table)
        return ""
SQLExecution(adapter: sqlmesh.dbt.adapter.RuntimeAdapter)
def
store_result( self, name: str, response: Any, agate_table: Optional[agate.table.Table]) -> str:
def store_result(self, name: str, response: t.Any, agate_table: t.Optional[agate.Table]) -> str:
    """Record ``response`` and its table under ``name`` and return an empty string.

    A missing table (``None``) is substituted with ``agate_helper.empty_table()``.
    The stored entry has the keys ``"response"``, ``"data"`` and ``"table"``.
    """
    from dbt.clients import agate_helper

    table = agate_helper.empty_table() if agate_table is None else agate_table
    entry = {
        "response": response,
        "data": agate_helper.as_matrix(table),
        "table": table,
    }
    self._results[name] = AttributeDict(entry)
    return ""
def
statement( self, name: Optional[str], fetch_result: bool = False, auto_begin: bool = True, language: str = 'sql', caller: Union[jinja2.runtime.Macro, str, NoneType] = None) -> str:
def statement(
    self,
    name: t.Optional[str],
    fetch_result: bool = False,
    auto_begin: bool = True,
    language: str = "sql",
    caller: t.Optional[jinja2.runtime.Macro | str] = None,
) -> str:
    """
    Execute the SQL supplied through ``caller`` and optionally keep the result.

    ``caller`` is declared optional and last only so that the signature matches the
    jinja2 macro; in practice it is required. It may be the SQL string itself or a
    callable producing it. When ``name`` is truthy the result is stored under that
    name and can be fetched with `load_result`; otherwise the SQL is executed and
    the result is discarded. Only ``language="sql"`` is supported.
    """
    if not caller:
        raise RuntimeError(
            "Statement relies on a caller to be set that is the target SQL to be run"
        )

    # The callable form is invoked before the language check, as in the string form.
    if isinstance(caller, str):
        sql = caller
    else:
        sql = caller()

    if language != "sql":
        raise NotImplementedError(
            "SQLMesh's dbt integration only supports SQL statements at this time."
        )

    assert self.adapter is not None
    response, result_table = self.adapter.execute(
        sql, fetch=fetch_result, auto_begin=auto_begin
    )
    if name:
        self.store_result(name, response, result_table)
    return ""
Executes the SQL that is defined within the context of the caller. The caller is therefore not really optional, but it is declared optional and placed last because the signature must match that of the jinja2 macro.
`name` is the key under which the results are stored; they can be retrieved later with `load_result`. If `name` is not provided, the SQL is executed but the results are not stored.
def
env_var(name: str, default: Optional[str] = None) -> Optional[str]:
def
is_incremental() -> bool:
def
log(msg: str, info: bool = False) -> str:
def
no_log(msg: str, info: bool = False) -> str:
def
config(*args: Any, **kwargs: Any) -> str:
def
generate_var(variables: Dict[str, Any]) -> Callable:
def
generate_ref(refs: Dict[str, Any]) -> Callable:
def generate_ref(refs: t.Dict[str, t.Any]) -> t.Callable:
    """Create the ``ref`` lookup function backed by the ``refs`` mapping.

    The returned function gives ``None`` for a name that is not in ``refs``;
    otherwise it builds a ``BaseRelation`` from the stored keyword arguments.
    """

    # TODO support package name
    def ref(package: str, name: t.Optional[str] = None) -> t.Optional[BaseRelation]:
        # Called with one argument, that argument is the name to look up;
        # called with two, the first (package) is currently ignored.
        target = name or package
        info = refs.get(target)
        if info is not None:
            return BaseRelation.create(**info, quote_policy=quote_policy())
        return None

    return ref
def
generate_source(sources: Dict[str, Any]) -> Callable:
def generate_source(sources: t.Dict[str, t.Any]) -> t.Callable:
    """Create the ``source`` lookup function backed by the ``sources`` mapping.

    Entries are keyed by ``"<package>.<name>"``. The returned function gives ``None``
    for an unknown key; otherwise it builds a ``BaseRelation`` from the stored
    keyword arguments.
    """

    def source(package: str, name: str) -> t.Optional[BaseRelation]:
        key = f"{package}.{name}"
        info = sources.get(key)
        if info is not None:
            return BaseRelation.create(**info, quote_policy=quote_policy())
        return None

    return source
def
quote_policy() -> dbt.contracts.relation.Policy:
def
return_val(val: Any) -> None:
def
to_set(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def
to_json(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def
from_json(value: str, default: Optional[Any] = None) -> Optional[Any]:
def
to_yaml(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def
from_yaml(value: Any, default: Optional[Any] = None) -> Optional[Any]:
def
do_zip(*args: Any, default: Optional[Any] = None) -> Optional[Any]:
def
as_bool(value: str) -> bool:
def
as_number(value: str) -> Any:
def
create_builtin_globals( jinja_macros: sqlmesh.utils.jinja.JinjaMacroRegistry, jinja_globals: Dict[str, Any], engine_adapter: Optional[sqlmesh.core.engine_adapter.base.EngineAdapter]) -> Dict[str, Any]:
def create_builtin_globals(
    jinja_macros: JinjaMacroRegistry,
    jinja_globals: t.Dict[str, t.Any],
    engine_adapter: t.Optional[EngineAdapter],
) -> t.Dict[str, t.Any]:
    """Build the full globals mapping for one rendering context.

    The ``this``, ``sources``, ``refs`` and ``vars`` entries of ``jinja_globals`` are
    consumed and converted into the ``this``, ``source``, ``ref`` and ``var`` globals.
    With an ``engine_adapter`` the execution globals are backed by a ``RuntimeAdapter``
    and ``execute`` is ``True``; without one they are no-op stand-ins and ``execute``
    is ``False``. The caller's dictionary and ``BUILTIN_GLOBALS`` are left untouched.

    Returns:
        A new dictionary: the builtins overlaid with the remaining ``jinja_globals``.
    """
    builtin_globals = BUILTIN_GLOBALS.copy()
    jinja_globals = jinja_globals.copy()

    this = jinja_globals.pop("this", None)
    if this is not None:
        # Anything that is not already a relation is treated as keyword arguments for one.
        builtin_globals["this"] = (
            this
            if isinstance(this, BaseRelation)
            else BaseRelation.create(**this, quote_policy=quote_policy())
        )

    # (key consumed from jinja_globals, global it produces, factory building the global)
    generated = (
        ("sources", "source", generate_source),
        ("refs", "ref", generate_ref),
        ("vars", "var", generate_var),
    )
    for consumed_key, global_name, factory in generated:
        payload = jinja_globals.pop(consumed_key, None)
        if payload is not None:
            builtin_globals[global_name] = factory(payload)

    # Globals that were not generated above appear here with the value None.
    builtin_globals["builtins"] = AttributeDict(
        {k: builtin_globals.get(k) for k in ("ref", "source", "config")}
    )

    if engine_adapter is None:
        execution_globals = {
            "execute": False,
            "adapter": ParsetimeAdapter(
                jinja_macros, jinja_globals={**builtin_globals, **jinja_globals}
            ),
            "load_relation": lambda *args, **kwargs: None,
            "store_result": lambda *args, **kwargs: "",
            "load_result": lambda *args, **kwargs: None,
            "run_query": lambda *args, **kwargs: None,
            "statement": lambda *args, **kwargs: "",
            "log": no_log,
            "print": no_log,
        }
    else:
        # The adapter sees the globals as they are before the execution entries are added.
        adapter = RuntimeAdapter(
            engine_adapter, jinja_macros, jinja_globals={**builtin_globals, **jinja_globals}
        )
        sql_execution = SQLExecution(adapter)
        execution_globals = {
            "execute": True,
            "adapter": adapter,
            "load_relation": lambda r: adapter.get_relation(r.database, r.schema, r.identifier),
            "store_result": sql_execution.store_result,
            "load_result": sql_execution.load_result,
            "run_query": sql_execution.run_query,
            "statement": sql_execution.statement,
            "log": log,
            "print": log,
        }

    builtin_globals.update(execution_globals)
    return {**builtin_globals, **jinja_globals}
def
create_builtin_filters() -> Dict[str, Callable]: