Edit on GitHub

sqlmesh.schedulers.airflow.state_sync

  1from __future__ import annotations
  2
  3import logging
  4import typing as t
  5
  6from sqlmesh.core.console import Console
  7from sqlmesh.core.environment import Environment
  8from sqlmesh.core.snapshot import (
  9    Snapshot,
 10    SnapshotId,
 11    SnapshotIdLike,
 12    SnapshotNameVersion,
 13    SnapshotNameVersionLike,
 14)
 15from sqlmesh.core.state_sync import StateReader
 16from sqlmesh.schedulers.airflow.client import AirflowClient
 17
 18logger = logging.getLogger(__name__)
 19
 20
 21class HttpStateReader(StateReader):
 22    """Reads state of models and snapshot through the Airflow REST API.
 23
 24    Args:
 25        airflow_url: URL pointing to the airflow rest api.
 26        username: Username for Airflow.
 27        password: Password for Airflow.
 28        blocking_updates: Indicates whether calls that cause state updates should be blocking.
 29        dag_run_poll_interval_secs: Determines how frequently the state of a DAG run should be checked.
 30            Used to block on calls that update the state.
 31        console: Used to print out tracking URLs.
 32    """
 33
 34    def __init__(
 35        self,
 36        client: AirflowClient,
 37        blocking_updates: bool = True,
 38        dag_run_poll_interval_secs: int = 2,
 39        console: t.Optional[Console] = None,
 40    ):
 41        self._client = client
 42        self.blocking_updates = blocking_updates
 43        self.dag_run_poll_interval_secs = dag_run_poll_interval_secs
 44        self.console = console
 45
 46    def get_environment(self, environment: str) -> t.Optional[Environment]:
 47        """Fetches the environment if it exists.
 48
 49        Args:
 50            environment: The environment
 51
 52        Returns:
 53            The environment object.
 54        """
 55        return self._client.get_environment(environment)
 56
 57    def get_environments(self) -> t.List[Environment]:
 58        """Fetches all environments.
 59
 60        Returns:
 61            A list of all environments.
 62        """
 63        return self._client.get_environments()
 64
 65    def get_snapshots(
 66        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
 67    ) -> t.Dict[SnapshotId, Snapshot]:
 68        """Gets multiple snapshots from the rest api.
 69
 70        Because of the limitations of the Airflow API, this method is inherently inefficient.
 71        It's impossible to bulkfetch the snapshots and thus every snapshot needs to make an individual
 72        call to the rest api. Multiple threads can be used, but it could possibly have detrimental effects
 73        on the production server.
 74        """
 75        snapshots = self._client.get_snapshots(
 76            [s.snapshot_id for s in snapshot_ids] if snapshot_ids is not None else None
 77        )
 78        return {snapshot.snapshot_id: snapshot for snapshot in snapshots}
 79
 80    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
 81        """Checks if multiple snapshots exist in the state sync.
 82
 83        Args:
 84            snapshot_ids: Iterable of snapshot ids to bulk check.
 85
 86        Returns:
 87            A set of existing snapshot IDs.
 88        """
 89        if not snapshot_ids:
 90            return set()
 91        return self._client.snapshots_exist([s.snapshot_id for s in snapshot_ids])
 92
 93    def get_snapshots_with_same_version(
 94        self, snapshots: t.Iterable[SnapshotNameVersionLike]
 95    ) -> t.List[Snapshot]:
 96        if not snapshots:
 97            return []
 98        return self._client.get_snapshots_with_same_version(
 99            [SnapshotNameVersion(name=s.name, version=s.version) for s in snapshots]
100        )
101
102    def get_snapshots_by_models(self, *names: str) -> t.List[Snapshot]:
103        """
104        Get all snapshots by model name.
105
106        Returns:
107            The list of snapshots.
108        """
109        raise NotImplementedError(
110            "Getting snapshots by model names is not supported by the Airflow HTTP State Sync"
111        )
class HttpStateReader(sqlmesh.core.state_sync.base.StateReader):
 22class HttpStateReader(StateReader):
 23    """Reads state of models and snapshot through the Airflow REST API.
 24
 25    Args:
 26        airflow_url: URL pointing to the airflow rest api.
 27        username: Username for Airflow.
 28        password: Password for Airflow.
 29        blocking_updates: Indicates whether calls that cause state updates should be blocking.
 30        dag_run_poll_interval_secs: Determines how frequently the state of a DAG run should be checked.
 31            Used to block on calls that update the state.
 32        console: Used to print out tracking URLs.
 33    """
 34
 35    def __init__(
 36        self,
 37        client: AirflowClient,
 38        blocking_updates: bool = True,
 39        dag_run_poll_interval_secs: int = 2,
 40        console: t.Optional[Console] = None,
 41    ):
 42        self._client = client
 43        self.blocking_updates = blocking_updates
 44        self.dag_run_poll_interval_secs = dag_run_poll_interval_secs
 45        self.console = console
 46
 47    def get_environment(self, environment: str) -> t.Optional[Environment]:
 48        """Fetches the environment if it exists.
 49
 50        Args:
 51            environment: The environment
 52
 53        Returns:
 54            The environment object.
 55        """
 56        return self._client.get_environment(environment)
 57
 58    def get_environments(self) -> t.List[Environment]:
 59        """Fetches all environments.
 60
 61        Returns:
 62            A list of all environments.
 63        """
 64        return self._client.get_environments()
 65
 66    def get_snapshots(
 67        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
 68    ) -> t.Dict[SnapshotId, Snapshot]:
 69        """Gets multiple snapshots from the rest api.
 70
 71        Because of the limitations of the Airflow API, this method is inherently inefficient.
 72        It's impossible to bulkfetch the snapshots and thus every snapshot needs to make an individual
 73        call to the rest api. Multiple threads can be used, but it could possibly have detrimental effects
 74        on the production server.
 75        """
 76        snapshots = self._client.get_snapshots(
 77            [s.snapshot_id for s in snapshot_ids] if snapshot_ids is not None else None
 78        )
 79        return {snapshot.snapshot_id: snapshot for snapshot in snapshots}
 80
 81    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
 82        """Checks if multiple snapshots exist in the state sync.
 83
 84        Args:
 85            snapshot_ids: Iterable of snapshot ids to bulk check.
 86
 87        Returns:
 88            A set of existing snapshot IDs.
 89        """
 90        if not snapshot_ids:
 91            return set()
 92        return self._client.snapshots_exist([s.snapshot_id for s in snapshot_ids])
 93
 94    def get_snapshots_with_same_version(
 95        self, snapshots: t.Iterable[SnapshotNameVersionLike]
 96    ) -> t.List[Snapshot]:
 97        if not snapshots:
 98            return []
 99        return self._client.get_snapshots_with_same_version(
100            [SnapshotNameVersion(name=s.name, version=s.version) for s in snapshots]
101        )
102
103    def get_snapshots_by_models(self, *names: str) -> t.List[Snapshot]:
104        """
105        Get all snapshots by model name.
106
107        Returns:
108            The list of snapshots.
109        """
110        raise NotImplementedError(
111            "Getting snapshots by model names is not supported by the Airflow HTTP State Sync"
112        )

Reads the state of models and snapshots through the Airflow REST API.

Arguments:
  • client: The Airflow client used to communicate with the Airflow REST API.
  • blocking_updates: Indicates whether calls that cause state updates should be blocking.
  • dag_run_poll_interval_secs: Determines how frequently the state of a DAG run should be checked. Used to block on calls that update the state.
  • console: Used to print out tracking URLs.
HttpStateReader( client: sqlmesh.schedulers.airflow.client.AirflowClient, blocking_updates: bool = True, dag_run_poll_interval_secs: int = 2, console: Optional[sqlmesh.core.console.Console] = None)
35    def __init__(
36        self,
37        client: AirflowClient,
38        blocking_updates: bool = True,
39        dag_run_poll_interval_secs: int = 2,
40        console: t.Optional[Console] = None,
41    ):
42        self._client = client
43        self.blocking_updates = blocking_updates
44        self.dag_run_poll_interval_secs = dag_run_poll_interval_secs
45        self.console = console
def get_environment(self, environment: str) -> Optional[sqlmesh.core.environment.Environment]:
47    def get_environment(self, environment: str) -> t.Optional[Environment]:
48        """Fetches the environment if it exists.
49
50        Args:
51            environment: The environment
52
53        Returns:
54            The environment object.
55        """
56        return self._client.get_environment(environment)

Fetches the environment if it exists.

Arguments:
  • environment: The environment
Returns:

The environment object, or None if the environment does not exist.

def get_environments(self) -> List[sqlmesh.core.environment.Environment]:
58    def get_environments(self) -> t.List[Environment]:
59        """Fetches all environments.
60
61        Returns:
62            A list of all environments.
63        """
64        return self._client.get_environments()

Fetches all environments.

Returns:

A list of all environments.

66    def get_snapshots(
67        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
68    ) -> t.Dict[SnapshotId, Snapshot]:
69        """Gets multiple snapshots from the rest api.
70
71        Because of the limitations of the Airflow API, this method is inherently inefficient.
72        It's impossible to bulkfetch the snapshots and thus every snapshot needs to make an individual
73        call to the rest api. Multiple threads can be used, but it could possibly have detrimental effects
74        on the production server.
75        """
76        snapshots = self._client.get_snapshots(
77            [s.snapshot_id for s in snapshot_ids] if snapshot_ids is not None else None
78        )
79        return {snapshot.snapshot_id: snapshot for snapshot in snapshots}

Gets multiple snapshots from the rest api.

Because of the limitations of the Airflow API, this method is inherently inefficient. It is impossible to bulk-fetch the snapshots, so every snapshot requires an individual call to the REST API. Multiple threads can be used, but doing so could have detrimental effects on the production server.

81    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
82        """Checks if multiple snapshots exist in the state sync.
83
84        Args:
85            snapshot_ids: Iterable of snapshot ids to bulk check.
86
87        Returns:
88            A set of existing snapshot IDs.
89        """
90        if not snapshot_ids:
91            return set()
92        return self._client.snapshots_exist([s.snapshot_id for s in snapshot_ids])

Checks if multiple snapshots exist in the state sync.

Arguments:
  • snapshot_ids: Iterable of snapshot ids to bulk check.
Returns:

A set of existing snapshot IDs.

 94    def get_snapshots_with_same_version(
 95        self, snapshots: t.Iterable[SnapshotNameVersionLike]
 96    ) -> t.List[Snapshot]:
 97        if not snapshots:
 98            return []
 99        return self._client.get_snapshots_with_same_version(
100            [SnapshotNameVersion(name=s.name, version=s.version) for s in snapshots]
101        )

Fetches all snapshots that share the same version as the given snapshots.

The output includes the snapshots with the specified version.

Arguments:
  • snapshots: The collection of target name / version pairs.
Returns:

The list of Snapshot objects.

def get_snapshots_by_models(self, *names: str) -> List[sqlmesh.core.snapshot.definition.Snapshot]:
103    def get_snapshots_by_models(self, *names: str) -> t.List[Snapshot]:
104        """
105        Get all snapshots by model name.
106
107        Returns:
108            The list of snapshots.
109        """
110        raise NotImplementedError(
111            "Getting snapshots by model names is not supported by the Airflow HTTP State Sync"
112        )

Get all snapshots by model name.

Returns:

The list of snapshots.