sqlmesh.schedulers.airflow.util
from __future__ import annotations

import contextlib
import logging
import typing as t
from datetime import timedelta

from airflow import settings
from airflow.api.common.experimental.delete_dag import delete_dag
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import BaseOperator, DagRun, DagTag, Variable, XCom
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState
from sqlalchemy.orm import Session

from sqlmesh.core.engine_adapter import create_engine_adapter
from sqlmesh.core.state_sync import EngineAdapterStateSync, StateSync
from sqlmesh.schedulers.airflow import common
from sqlmesh.utils.date import now
from sqlmesh.utils.errors import SQLMeshError

logger = logging.getLogger(__name__)


# Used to omit Optional for session instances supplied by
# Airflow at runtime. This makes the type signature cleaner
# and prevents mypy from complaining.
PROVIDED_SESSION: Session = t.cast(Session, None)


@contextlib.contextmanager
def scoped_state_sync() -> t.Generator[StateSync, None, None]:
    dialect = settings.engine.dialect.name
    engine_adapter = create_engine_adapter(
        settings.engine.raw_connection, dialect, multithreaded=True
    )
    try:
        yield EngineAdapterStateSync(engine_adapter, "sqlmesh")
    finally:
        engine_adapter.close()


@provide_session
def get_snapshot_dag_ids(session: Session = PROVIDED_SESSION) -> t.List[str]:
    dag_tags = session.query(DagTag).filter(DagTag.name == common.SNAPSHOT_AIRFLOW_TAG).all()
    return [tag.dag_id for tag in dag_tags]


@provide_session
def get_finished_plan_application_dag_ids(
    ttl: t.Optional[timedelta] = None, session: Session = PROVIDED_SESSION
) -> t.Set[str]:
    dag_ids = (
        session.query(DagTag.dag_id)
        .join(DagRun, DagTag.dag_id == DagRun.dag_id)
        .filter(
            DagTag.name == common.PLAN_AIRFLOW_TAG,
            DagRun.state.in_((DagRunState.SUCCESS, DagRunState.FAILED)),
        )
    )
    if ttl is not None:
        dag_ids = dag_ids.filter(DagRun.last_scheduling_decision <= now() - ttl)
    return {dag_id[0] for dag_id in dag_ids.all()}


@provide_session
def delete_dags(dag_ids: t.Set[str], session: Session = PROVIDED_SESSION) -> None:
    for dag_id in dag_ids:
        try:
            delete_dag(dag_id, session=session)
        except DagNotFound:
            logger.warning("DAG '%s' was not found", dag_id)
        except AirflowException:
            logger.warning("Failed to delete DAG '%s'", dag_id, exc_info=True)


@provide_session
def delete_xcoms(
    dag_id: str,
    keys: t.Set[str],
    task_id: t.Optional[str] = None,
    run_id: t.Optional[str] = None,
    session: Session = PROVIDED_SESSION,
) -> None:
    query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.key.in_(keys))
    if task_id is not None:
        query = query.filter_by(task_id=task_id)
    if run_id is not None:
        query = query.filter_by(run_id=run_id)
    query.delete(synchronize_session=False)


@provide_session
def delete_variables(
    keys: t.Set[str],
    session: Session = PROVIDED_SESSION,
) -> None:
    (session.query(Variable).filter(Variable.key.in_(keys)).delete(synchronize_session=False))


def discover_engine_operator(name: str, sql_only: bool = False) -> t.Type[BaseOperator]:
    name = name.lower()

    try:
        if name == "spark":
            from sqlmesh.schedulers.airflow.operators.spark_submit import (
                SQLMeshSparkSubmitOperator,
            )

            return SQLMeshSparkSubmitOperator
        if name in ("databricks", "databricks-submit", "databricks-sql"):
            if name == "databricks-submit" or (name == "databricks" and not sql_only):
                from sqlmesh.schedulers.airflow.operators.databricks import (
                    SQLMeshDatabricksSubmitOperator,
                )

                return SQLMeshDatabricksSubmitOperator
            if name == "databricks-sql" or (name == "databricks" and sql_only):
                from sqlmesh.schedulers.airflow.operators.databricks import (
                    SQLMeshDatabricksSQLOperator,
                )

                return SQLMeshDatabricksSQLOperator
        if name == "snowflake":
            from sqlmesh.schedulers.airflow.operators.snowflake import (
                SQLMeshSnowflakeOperator,
            )

            return SQLMeshSnowflakeOperator
        if name == "bigquery":
            from sqlmesh.schedulers.airflow.operators.bigquery import (
                SQLMeshBigQueryOperator,
            )

            return SQLMeshBigQueryOperator
        if name == "redshift":
            from sqlmesh.schedulers.airflow.operators.redshift import (
                SQLMeshRedshiftOperator,
            )

            return SQLMeshRedshiftOperator
    except ImportError:
        raise SQLMeshError(f"Failed to automatically discover an operator for '{name}'.")

    raise ValueError(f"Unsupported engine name '{name}'.")
@contextlib.contextmanager
def scoped_state_sync() -> t.Generator[StateSync, None, None]:
    dialect = settings.engine.dialect.name
    engine_adapter = create_engine_adapter(
        settings.engine.raw_connection, dialect, multithreaded=True
    )
    try:
        yield EngineAdapterStateSync(engine_adapter, "sqlmesh")
    finally:
        engine_adapter.close()
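A minimal usage sketch, assuming the code runs where Airflow's `settings.engine` is already configured (for example inside a task or plugin) and assuming a `get_environments` accessor on `StateSync`; it is illustrative, not part of this module:

from sqlmesh.schedulers.airflow.util import scoped_state_sync

with scoped_state_sync() as state_sync:
    # state_sync is an EngineAdapterStateSync backed by Airflow's metadata
    # database; the underlying engine adapter is closed when the block exits.
    environments = state_sync.get_environments()  # assumed StateSync method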
@provide_session
def get_snapshot_dag_ids(session: Session = PROVIDED_SESSION) -> t.List[str]:
    dag_tags = session.query(DagTag).filter(DagTag.name == common.SNAPSHOT_AIRFLOW_TAG).all()
    return [tag.dag_id for tag in dag_tags]
@provide_session
def get_finished_plan_application_dag_ids(
    ttl: t.Optional[timedelta] = None, session: Session = PROVIDED_SESSION
) -> t.Set[str]:
    dag_ids = (
        session.query(DagTag.dag_id)
        .join(DagRun, DagTag.dag_id == DagRun.dag_id)
        .filter(
            DagTag.name == common.PLAN_AIRFLOW_TAG,
            DagRun.state.in_((DagRunState.SUCCESS, DagRunState.FAILED)),
        )
    )
    if ttl is not None:
        dag_ids = dag_ids.filter(DagRun.last_scheduling_decision <= now() - ttl)
    return {dag_id[0] for dag_id in dag_ids.all()}
@provide_session
def delete_dags(dag_ids: t.Set[str], session: Session = PROVIDED_SESSION) -> None:
    for dag_id in dag_ids:
        try:
            delete_dag(dag_id, session=session)
        except DagNotFound:
            logger.warning("DAG '%s' was not found", dag_id)
        except AirflowException:
            logger.warning("Failed to delete DAG '%s'", dag_id, exc_info=True)
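Combined with get_finished_plan_application_dag_ids above, a hedged cleanup sketch; the one-hour TTL is an arbitrary example value, and both functions obtain a session automatically via @provide_session when run against a configured Airflow metadata database:

from datetime import timedelta

from sqlmesh.schedulers.airflow.util import (
    delete_dags,
    get_finished_plan_application_dag_ids,
)

# Find plan application DAGs that reached SUCCESS or FAILED more than an
# hour ago, then remove them from Airflow's metadata database.
expired = get_finished_plan_application_dag_ids(ttl=timedelta(hours=1))
delete_dags(expired)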
@provide_session
def delete_xcoms(
    dag_id: str,
    keys: t.Set[str],
    task_id: t.Optional[str] = None,
    run_id: t.Optional[str] = None,
    session: Session = PROVIDED_SESSION,
) -> None:
    query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.key.in_(keys))
    if task_id is not None:
        query = query.filter_by(task_id=task_id)
    if run_id is not None:
        query = query.filter_by(run_id=run_id)
    query.delete(synchronize_session=False)
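A hedged example call; the DAG ID, XCom key, and run ID below are hypothetical placeholders, not values defined by this module:

from sqlmesh.schedulers.airflow.util import delete_xcoms

# Delete a single XCom key for one specific DAG run.
delete_xcoms(
    dag_id="example_sqlmesh_dag",  # hypothetical DAG ID
    keys={"example_xcom_key"},  # hypothetical XCom key
    run_id="manual__2023-01-01T00:00:00+00:00",
)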
@provide_session
def delete_variables(
    keys: t.Set[str],
    session: Session = PROVIDED_SESSION,
) -> None:
    (session.query(Variable).filter(Variable.key.in_(keys)).delete(synchronize_session=False))
def discover_engine_operator(name: str, sql_only: bool = False) -> t.Type[BaseOperator]:
    name = name.lower()

    try:
        if name == "spark":
            from sqlmesh.schedulers.airflow.operators.spark_submit import (
                SQLMeshSparkSubmitOperator,
            )

            return SQLMeshSparkSubmitOperator
        if name in ("databricks", "databricks-submit", "databricks-sql"):
            if name == "databricks-submit" or (name == "databricks" and not sql_only):
                from sqlmesh.schedulers.airflow.operators.databricks import (
                    SQLMeshDatabricksSubmitOperator,
                )

                return SQLMeshDatabricksSubmitOperator
            if name == "databricks-sql" or (name == "databricks" and sql_only):
                from sqlmesh.schedulers.airflow.operators.databricks import (
                    SQLMeshDatabricksSQLOperator,
                )

                return SQLMeshDatabricksSQLOperator
        if name == "snowflake":
            from sqlmesh.schedulers.airflow.operators.snowflake import (
                SQLMeshSnowflakeOperator,
            )

            return SQLMeshSnowflakeOperator
        if name == "bigquery":
            from sqlmesh.schedulers.airflow.operators.bigquery import (
                SQLMeshBigQueryOperator,
            )

            return SQLMeshBigQueryOperator
        if name == "redshift":
            from sqlmesh.schedulers.airflow.operators.redshift import (
                SQLMeshRedshiftOperator,
            )

            return SQLMeshRedshiftOperator
    except ImportError:
        raise SQLMeshError(f"Failed to automatically discover an operator for '{name}'.")

    raise ValueError(f"Unsupported engine name '{name}'.")
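A short sketch of resolving an operator class by engine name; it assumes the corresponding Airflow provider package (here the Databricks provider) is installed, otherwise SQLMeshError is raised:

from sqlmesh.schedulers.airflow.util import discover_engine_operator

# "databricks" with sql_only=True resolves to the SQL connector operator.
operator_cls = discover_engine_operator("databricks", sql_only=True)
print(operator_cls.__name__)  # SQLMeshDatabricksSQLOperator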