
sqlmesh.schedulers.airflow.util
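
Utility functions used by the SQLMesh Airflow integration: obtaining a StateSync backed by the Airflow metadata database, querying and cleaning up SQLMesh-related DAGs, XComs, and Variables, and discovering the engine-specific operator class for a given engine name.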

from __future__ import annotations

import contextlib
import logging
import typing as t
from datetime import timedelta

from airflow import settings
from airflow.api.common.experimental.delete_dag import delete_dag
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import BaseOperator, DagRun, DagTag, Variable, XCom
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState
from sqlalchemy.orm import Session

from sqlmesh.core.engine_adapter import create_engine_adapter
from sqlmesh.core.state_sync import EngineAdapterStateSync, StateSync
from sqlmesh.schedulers.airflow import common
from sqlmesh.utils.date import now
from sqlmesh.utils.errors import SQLMeshError

logger = logging.getLogger(__name__)


# Used to omit Optional for session instances supplied by
# Airflow at runtime. This makes the type signature cleaner
# and prevents mypy from complaining.
PROVIDED_SESSION: Session = t.cast(Session, None)


@contextlib.contextmanager
def scoped_state_sync() -> t.Generator[StateSync, None, None]:
    dialect = settings.engine.dialect.name
    engine_adapter = create_engine_adapter(
        settings.engine.raw_connection, dialect, multithreaded=True
    )
    try:
        yield EngineAdapterStateSync(engine_adapter, "sqlmesh")
    finally:
        engine_adapter.close()


@provide_session
def get_snapshot_dag_ids(session: Session = PROVIDED_SESSION) -> t.List[str]:
    dag_tags = session.query(DagTag).filter(DagTag.name == common.SNAPSHOT_AIRFLOW_TAG).all()
    return [tag.dag_id for tag in dag_tags]


@provide_session
def get_finished_plan_application_dag_ids(
    ttl: t.Optional[timedelta] = None, session: Session = PROVIDED_SESSION
) -> t.Set[str]:
    dag_ids = (
        session.query(DagTag.dag_id)
        .join(DagRun, DagTag.dag_id == DagRun.dag_id)
        .filter(
            DagTag.name == common.PLAN_AIRFLOW_TAG,
            DagRun.state.in_((DagRunState.SUCCESS, DagRunState.FAILED)),
        )
    )
    if ttl is not None:
        dag_ids = dag_ids.filter(DagRun.last_scheduling_decision <= now() - ttl)
    return {dag_id[0] for dag_id in dag_ids.all()}


@provide_session
def delete_dags(dag_ids: t.Set[str], session: Session = PROVIDED_SESSION) -> None:
    for dag_id in dag_ids:
        try:
            delete_dag(dag_id, session=session)
        except DagNotFound:
            logger.warning("DAG '%s' was not found", dag_id)
        except AirflowException:
            logger.warning("Failed to delete DAG '%s'", dag_id, exc_info=True)


@provide_session
def delete_xcoms(
    dag_id: str,
    keys: t.Set[str],
    task_id: t.Optional[str] = None,
    run_id: t.Optional[str] = None,
    session: Session = PROVIDED_SESSION,
) -> None:
    query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.key.in_(keys))
    if task_id is not None:
        query = query.filter_by(task_id=task_id)
    if run_id is not None:
        query = query.filter_by(run_id=run_id)
    query.delete(synchronize_session=False)


@provide_session
def delete_variables(
    keys: t.Set[str],
    session: Session = PROVIDED_SESSION,
) -> None:
    (session.query(Variable).filter(Variable.key.in_(keys)).delete(synchronize_session=False))


def discover_engine_operator(name: str, sql_only: bool = False) -> t.Type[BaseOperator]:
    name = name.lower()

    try:
        if name == "spark":
            from sqlmesh.schedulers.airflow.operators.spark_submit import (
                SQLMeshSparkSubmitOperator,
            )

            return SQLMeshSparkSubmitOperator
        if name in ("databricks", "databricks-submit", "databricks-sql"):
            if name == "databricks-submit" or (name == "databricks" and not sql_only):
                from sqlmesh.schedulers.airflow.operators.databricks import (
                    SQLMeshDatabricksSubmitOperator,
                )

                return SQLMeshDatabricksSubmitOperator
            if name == "databricks-sql" or (name == "databricks" and sql_only):
                from sqlmesh.schedulers.airflow.operators.databricks import (
                    SQLMeshDatabricksSQLOperator,
                )

                return SQLMeshDatabricksSQLOperator
        if name == "snowflake":
            from sqlmesh.schedulers.airflow.operators.snowflake import (
                SQLMeshSnowflakeOperator,
            )

            return SQLMeshSnowflakeOperator
        if name == "bigquery":
            from sqlmesh.schedulers.airflow.operators.bigquery import (
                SQLMeshBigQueryOperator,
            )

            return SQLMeshBigQueryOperator
        if name == "redshift":
            from sqlmesh.schedulers.airflow.operators.redshift import (
                SQLMeshRedshiftOperator,
            )

            return SQLMeshRedshiftOperator
    except ImportError:
        raise SQLMeshError(f"Failed to automatically discover an operator for '{name}'.")

    raise ValueError(f"Unsupported engine name '{name}'.")

@contextlib.contextmanager
def scoped_state_sync() -> t.Generator[StateSync, None, None]:
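
Builds an engine adapter from the Airflow metadata database connection (settings.engine), yields an EngineAdapterStateSync, and closes the adapter once the block exits. A minimal usage sketch, assuming it runs inside an Airflow component where settings.engine is initialized; the get_environments call is only an illustrative StateSync method, not something this module requires:

from sqlmesh.schedulers.airflow.util import scoped_state_sync

with scoped_state_sync() as state_sync:
    # The underlying engine adapter is closed automatically when this block exits.
    environments = state_sync.get_environments()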

@provide_session
def get_snapshot_dag_ids(session: Session = PROVIDED_SESSION) -> t.List[str]:
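
Returns the dag_id of every DAG tagged with common.SNAPSHOT_AIRFLOW_TAG in the Airflow metadata database. The session argument is injected by Airflow's @provide_session decorator, so callers normally omit it. A short sketch:

from sqlmesh.schedulers.airflow.util import get_snapshot_dag_ids

# No session needs to be passed; @provide_session supplies one at call time.
for dag_id in get_snapshot_dag_ids():
    print(dag_id)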

@provide_session
def get_finished_plan_application_dag_ids(ttl: t.Optional[timedelta] = None, session: Session = PROVIDED_SESSION) -> t.Set[str]:
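
Collects the dag_ids of plan application DAGs (tagged with common.PLAN_AIRFLOW_TAG) whose runs reached a terminal state (SUCCESS or FAILED). When ttl is provided, only DAGs whose last scheduling decision is older than now() - ttl are included, which makes the function suitable for periodic cleanup. A hedged sketch of that cleanup pattern; the one-day TTL is chosen purely for illustration:

from datetime import timedelta

from sqlmesh.schedulers.airflow.util import (
    delete_dags,
    get_finished_plan_application_dag_ids,
)

# Find plan application DAGs that finished at least a day ago and remove them.
stale_dag_ids = get_finished_plan_application_dag_ids(ttl=timedelta(days=1))
delete_dags(stale_dag_ids)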

@provide_session
def delete_dags(dag_ids: t.Set[str], session: Session = PROVIDED_SESSION) -> None:
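
Deletes each DAG in dag_ids from the Airflow metadata database, logging a warning instead of raising when a DAG is already gone or the deletion fails. The cleanup sketch above already exercises it; an isolated call with a hypothetical DAG id looks like this:

from sqlmesh.schedulers.airflow.util import delete_dags

# The DAG id below is a made-up placeholder, not a name defined by SQLMesh.
delete_dags({"example_plan_application_dag"})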

@provide_session
def delete_xcoms(dag_id: str, keys: t.Set[str], task_id: t.Optional[str] = None, run_id: t.Optional[str] = None, session: Session = PROVIDED_SESSION) -> None:
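
Bulk-deletes the XCom entries with the given keys for a DAG, optionally narrowed to a single task_id and/or run_id; synchronize_session=False keeps the operation a single bulk DELETE. A sketch in which every identifier is a hypothetical placeholder:

from sqlmesh.schedulers.airflow.util import delete_xcoms

delete_xcoms(
    dag_id="example_dag",       # placeholder DAG id
    keys={"example_xcom_key"},  # placeholder XCom keys
    run_id="manual__2023-01-01T00:00:00+00:00",  # optional: limit to one run
)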

@provide_session
def delete_variables(keys: t.Set[str], session: Session = PROVIDED_SESSION) -> None:
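
Removes the Airflow Variables whose keys are in keys, again as a single bulk DELETE. A sketch with a hypothetical key:

from sqlmesh.schedulers.airflow.util import delete_variables

delete_variables({"example_variable_key"})  # placeholder key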

def discover_engine_operator(name: str, sql_only: bool = False) -> t.Type[BaseOperator]:
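
Maps an engine name to the SQLMesh operator class that should execute work for that engine; for Databricks, the sql_only flag selects between the submit-based and SQL-based operators. The imports happen lazily, so a missing provider package surfaces as a SQLMeshError, and an unrecognized name raises ValueError. A sketch, assuming the Snowflake operator dependencies are installed:

from sqlmesh.schedulers.airflow.util import discover_engine_operator

operator_cls = discover_engine_operator("snowflake")
# operator_cls is SQLMeshSnowflakeOperator and can be instantiated inside a DAG definition.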