sqlmesh.schedulers.airflow.state_sync
from __future__ import annotations

import logging
import typing as t

from sqlmesh.core.console import Console
from sqlmesh.core.environment import Environment
from sqlmesh.core.snapshot import (
    Snapshot,
    SnapshotId,
    SnapshotIdLike,
    SnapshotNameVersion,
    SnapshotNameVersionLike,
)
from sqlmesh.core.state_sync import StateReader
from sqlmesh.schedulers.airflow.client import AirflowClient

logger = logging.getLogger(__name__)


class HttpStateReader(StateReader):
    """Reads state of models and snapshots through the Airflow REST API.

    Args:
        client: The Airflow client that every read is delegated to.
        blocking_updates: Indicates whether calls that cause state updates should be blocking.
        dag_run_poll_interval_secs: Determines how frequently the state of a DAG run should be checked.
            Used to block on calls that update the state.
        console: Used to print out tracking URLs.
    """

    def __init__(
        self,
        client: AirflowClient,
        blocking_updates: bool = True,
        dag_run_poll_interval_secs: int = 2,
        console: t.Optional[Console] = None,
    ):
        self._client = client
        self.blocking_updates = blocking_updates
        self.dag_run_poll_interval_secs = dag_run_poll_interval_secs
        self.console = console

    def get_environment(self, environment: str) -> t.Optional[Environment]:
        """Fetches the environment if it exists.

        Args:
            environment: The environment to fetch.

        Returns:
            Whatever the client returns for this environment; the return type
            allows None.
        """
        return self._client.get_environment(environment)

    def get_environments(self) -> t.List[Environment]:
        """Fetches all environments.

        Returns:
            A list of all environments.
        """
        return self._client.get_environments()

    def get_snapshots(
        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
    ) -> t.Dict[SnapshotId, Snapshot]:
        """Gets multiple snapshots through the client.

        Args:
            snapshot_ids: The snapshots to fetch. None is forwarded to the
                client unchanged.

        Returns:
            The fetched snapshots keyed by their snapshot id.

        NOTE(review): earlier documentation stated that snapshots cannot be
        bulk-fetched and need one REST call each. This method hands all ids to
        the client in a single call; how the client issues requests is not
        visible here -- confirm before relying on either claim.
        """
        snapshots = self._client.get_snapshots(
            [s.snapshot_id for s in snapshot_ids] if snapshot_ids is not None else None
        )
        return {snapshot.snapshot_id: snapshot for snapshot in snapshots}

    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
        """Checks if multiple snapshots exist in the state sync.

        Args:
            snapshot_ids: Iterable of snapshot ids to bulk check.

        Returns:
            A set of existing snapshot IDs. Empty input yields an empty set
            without contacting the client.
        """
        # Materialize first: a generator is always truthy, so emptiness can
        # only be tested on the resulting list.
        ids = [s.snapshot_id for s in snapshot_ids or ()]
        if not ids:
            return set()
        return self._client.snapshots_exist(ids)

    def get_snapshots_with_same_version(
        self, snapshots: t.Iterable[SnapshotNameVersionLike]
    ) -> t.List[Snapshot]:
        """Fetches the snapshots the client reports for the given name / version pairs.

        Args:
            snapshots: The collection of target name / version pairs.

        Returns:
            The list of snapshots returned by the client. Empty input yields
            an empty list without contacting the client.
        """
        # Materialize first: a generator is always truthy, so emptiness can
        # only be tested on the resulting list.
        name_versions = [
            SnapshotNameVersion(name=s.name, version=s.version) for s in snapshots or ()
        ]
        if not name_versions:
            return []
        return self._client.get_snapshots_with_same_version(name_versions)

    def get_snapshots_by_models(self, *names: str) -> t.List[Snapshot]:
        """
        Get all snapshots by model name.

        Returns:
            Never returns normally.

        Raises:
            NotImplementedError: Always; this reader does not support the lookup.
        """
        raise NotImplementedError(
            "Getting snapshots by model names is not supported by the Airflow HTTP State Sync"
        )
class HttpStateReader(StateReader):
    """Reads state of models and snapshots through the Airflow REST API.

    Args:
        client: The Airflow client that every read is delegated to.
        blocking_updates: Indicates whether calls that cause state updates should be blocking.
        dag_run_poll_interval_secs: Determines how frequently the state of a DAG run should be checked.
            Used to block on calls that update the state.
        console: Used to print out tracking URLs.
    """

    def __init__(
        self,
        client: AirflowClient,
        blocking_updates: bool = True,
        dag_run_poll_interval_secs: int = 2,
        console: t.Optional[Console] = None,
    ):
        self._client = client
        self.blocking_updates = blocking_updates
        self.dag_run_poll_interval_secs = dag_run_poll_interval_secs
        self.console = console

    def get_environment(self, environment: str) -> t.Optional[Environment]:
        """Fetches the environment if it exists.

        Args:
            environment: The environment to fetch.

        Returns:
            Whatever the client returns for this environment; the return type
            allows None.
        """
        return self._client.get_environment(environment)

    def get_environments(self) -> t.List[Environment]:
        """Fetches all environments.

        Returns:
            A list of all environments.
        """
        return self._client.get_environments()

    def get_snapshots(
        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
    ) -> t.Dict[SnapshotId, Snapshot]:
        """Gets multiple snapshots through the client.

        Args:
            snapshot_ids: The snapshots to fetch. None is forwarded to the
                client unchanged.

        Returns:
            The fetched snapshots keyed by their snapshot id.

        NOTE(review): earlier documentation stated that snapshots cannot be
        bulk-fetched and need one REST call each. This method hands all ids to
        the client in a single call; how the client issues requests is not
        visible here -- confirm before relying on either claim.
        """
        snapshots = self._client.get_snapshots(
            [s.snapshot_id for s in snapshot_ids] if snapshot_ids is not None else None
        )
        return {snapshot.snapshot_id: snapshot for snapshot in snapshots}

    def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
        """Checks if multiple snapshots exist in the state sync.

        Args:
            snapshot_ids: Iterable of snapshot ids to bulk check.

        Returns:
            A set of existing snapshot IDs. Empty input yields an empty set
            without contacting the client.
        """
        # Materialize first: a generator is always truthy, so emptiness can
        # only be tested on the resulting list.
        ids = [s.snapshot_id for s in snapshot_ids or ()]
        if not ids:
            return set()
        return self._client.snapshots_exist(ids)

    def get_snapshots_with_same_version(
        self, snapshots: t.Iterable[SnapshotNameVersionLike]
    ) -> t.List[Snapshot]:
        """Fetches the snapshots the client reports for the given name / version pairs.

        Args:
            snapshots: The collection of target name / version pairs.

        Returns:
            The list of snapshots returned by the client. Empty input yields
            an empty list without contacting the client.
        """
        # Materialize first: a generator is always truthy, so emptiness can
        # only be tested on the resulting list.
        name_versions = [
            SnapshotNameVersion(name=s.name, version=s.version) for s in snapshots or ()
        ]
        if not name_versions:
            return []
        return self._client.get_snapshots_with_same_version(name_versions)

    def get_snapshots_by_models(self, *names: str) -> t.List[Snapshot]:
        """
        Get all snapshots by model name.

        Returns:
            Never returns normally.

        Raises:
            NotImplementedError: Always; this reader does not support the lookup.
        """
        raise NotImplementedError(
            "Getting snapshots by model names is not supported by the Airflow HTTP State Sync"
        )
Reads the state of models and snapshots through the Airflow REST API.
Arguments:
- client: The Airflow client through which all requests are made. (The constructor takes a client object rather than a separate Airflow URL, username, and password.)
- blocking_updates: Indicates whether calls that cause state updates should be blocking.
- dag_run_poll_interval_secs: Determines how frequently the state of a DAG run should be checked. Used to block on calls that update the state.
- console: Used to print out tracking URLs.
def __init__(
    self,
    client: AirflowClient,
    blocking_updates: bool = True,
    dag_run_poll_interval_secs: int = 2,
    console: t.Optional[Console] = None,
):
    """Stores the client and the reader's settings on the instance.

    Args:
        client: The Airflow client the reader delegates to.
        blocking_updates: Indicates whether calls that cause state updates
            should be blocking.
        dag_run_poll_interval_secs: Determines how frequently the state of a
            DAG run should be checked.
        console: Used to print out tracking URLs.
    """
    # The client is kept private; the remaining settings are public attributes.
    self._client = client
    self.console = console
    self.dag_run_poll_interval_secs = dag_run_poll_interval_secs
    self.blocking_updates = blocking_updates
def get_environment(self, environment: str) -> t.Optional[Environment]:
    """Looks up a single environment through the client.

    Args:
        environment: The environment to fetch.

    Returns:
        Whatever the client returns for this environment; the return type
        allows None.
    """
    found = self._client.get_environment(environment)
    return found
Fetches the environment if it exists.
Arguments:
- environment: The environment to fetch.
Returns:
The environment object, or None if it does not exist.
def get_environments(self) -> t.List[Environment]:
    """Retrieves every environment through the client.

    Returns:
        The client's list of environments, returned unchanged.
    """
    every_environment = self._client.get_environments()
    return every_environment
Fetches all environments.
Returns:
A list of all environments.
def get_snapshots(
    self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
) -> t.Dict[SnapshotId, Snapshot]:
    """Fetches snapshots through the client and indexes them by snapshot id.

    Args:
        snapshot_ids: The snapshots to fetch. None is forwarded to the client
            unchanged.

    Returns:
        The fetched snapshots keyed by their snapshot id.

    NOTE(review): earlier documentation stated that snapshots cannot be
    bulk-fetched and need one REST call each. This method hands all ids to the
    client in a single call; how the client issues requests is not visible
    here -- confirm before relying on either claim.
    """
    if snapshot_ids is None:
        requested = None
    else:
        requested = [item.snapshot_id for item in snapshot_ids]

    by_id = {}
    for fetched in self._client.get_snapshots(requested):
        by_id[fetched.snapshot_id] = fetched
    return by_id
Gets multiple snapshots from the REST API.
Because of limitations of the Airflow API, this method is inherently inefficient. It is impossible to bulk-fetch the snapshots, so every snapshot requires an individual call to the REST API. Multiple threads can be used, but doing so could have detrimental effects on the production server.
def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
    """Checks if multiple snapshots exist in the state sync.

    Args:
        snapshot_ids: Iterable of snapshot ids to bulk check. May be a lazy
            iterable such as a generator.

    Returns:
        A set of existing snapshot IDs. Empty input yields an empty set
        without contacting the client.
    """
    # Materialize before testing for emptiness: a generator or iterator is
    # always truthy, so `if not snapshot_ids` cannot detect an empty one.
    # `or ()` keeps falsy inputs (None, empty containers) returning an empty set.
    ids = [s.snapshot_id for s in snapshot_ids or ()]
    if not ids:
        return set()
    return self._client.snapshots_exist(ids)
Checks if multiple snapshots exist in the state sync.
Arguments:
- snapshot_ids: Iterable of snapshot ids to bulk check.
Returns:
A set of existing snapshot IDs.
def get_snapshots_with_same_version(
    self, snapshots: t.Iterable[SnapshotNameVersionLike]
) -> t.List[Snapshot]:
    """Fetches the snapshots the client reports for the given name / version pairs.

    Args:
        snapshots: The collection of target name / version pairs. May be a
            lazy iterable such as a generator.

    Returns:
        The list of snapshots returned by the client. Empty input yields an
        empty list without contacting the client.
    """
    # Materialize before testing for emptiness: a generator or iterator is
    # always truthy, so `if not snapshots` cannot detect an empty one.
    # `or ()` keeps falsy inputs (None, empty containers) returning [].
    name_versions = [
        SnapshotNameVersion(name=s.name, version=s.version) for s in snapshots or ()
    ]
    if not name_versions:
        return []
    return self._client.get_snapshots_with_same_version(name_versions)
Fetches all snapshots that share the same version as the given snapshots.
The output includes the snapshots with the specified version.
Arguments:
- snapshots: The collection of target name / version pairs.
Returns:
The list of Snapshot objects.
def get_snapshots_by_models(self, *names: str) -> t.List[Snapshot]:
    """
    Get all snapshots by model name.

    This reader does not implement the lookup.

    Raises:
        NotImplementedError: Always, regardless of the names passed.
    """
    unsupported = "Getting snapshots by model names is not supported by the Airflow HTTP State Sync"
    raise NotImplementedError(unsupported)
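Get all snapshots by model name. This operation is not supported by the Airflow HTTP state sync and always raises NotImplementedError.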
Returns:
The list of snapshots.