sqlmesh.schedulers.airflow.client
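An HTTP client for the REST API exposed by the SQLMesh Airflow plugin, plus the small slice of Airflow's stable REST API (DAGs and DAG runs) that SQLMesh needs for orchestration.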
import json
import time
import typing as t
import uuid
from urllib.parse import urlencode, urljoin

import requests
from requests.models import Response

from sqlmesh.core._typing import NotificationTarget
from sqlmesh.core.console import Console
from sqlmesh.core.environment import Environment
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotNameVersion
from sqlmesh.core.user import User
from sqlmesh.schedulers.airflow import common
from sqlmesh.utils.errors import (
    ApiClientError,
    ApiServerError,
    NotFoundError,
    SQLMeshError,
)
from sqlmesh.utils.pydantic import PydanticModel

DAG_RUN_PATH_TEMPLATE = "api/v1/dags/{}/dagRuns"


PLANS_PATH = f"{common.SQLMESH_API_BASE_PATH}/plans"
ENVIRONMENTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/environments"
SNAPSHOTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/snapshots"


class AirflowClient:
    def __init__(
        self,
        session: requests.Session,
        airflow_url: str,
        console: t.Optional[Console] = None,
    ):
        self._session = session
        self._airflow_url = airflow_url
        self._console = console

    def apply_plan(
        self,
        new_snapshots: t.Iterable[Snapshot],
        environment: Environment,
        request_id: str,
        no_gaps: bool = False,
        skip_backfill: bool = False,
        restatements: t.Optional[t.Iterable[str]] = None,
        notification_targets: t.Optional[t.List[NotificationTarget]] = None,
        backfill_concurrent_tasks: int = 1,
        ddl_concurrent_tasks: int = 1,
        users: t.Optional[t.List[User]] = None,
        is_dev: bool = False,
    ) -> None:
        request = common.PlanApplicationRequest(
            new_snapshots=list(new_snapshots),
            environment=environment,
            no_gaps=no_gaps,
            skip_backfill=skip_backfill,
            request_id=request_id,
            restatements=set(restatements or []),
            notification_targets=notification_targets or [],
            backfill_concurrent_tasks=backfill_concurrent_tasks,
            ddl_concurrent_tasks=ddl_concurrent_tasks,
            users=users or [],
            is_dev=is_dev,
        )

        response = self._session.post(
            urljoin(self._airflow_url, PLANS_PATH),
            data=request.json(),
        )
        self._raise_for_status(response)

    def get_snapshots(self, snapshot_ids: t.Optional[t.List[SnapshotId]]) -> t.List[Snapshot]:
        params: t.Dict[str, str] = {}
        if snapshot_ids is not None:
            params["ids"] = _list_to_json(snapshot_ids)

        return common.SnapshotsResponse.parse_obj(self._get(SNAPSHOTS_PATH, **params)).snapshots

    def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]:
        return set(
            common.SnapshotIdsResponse.parse_obj(
                self._get(SNAPSHOTS_PATH, "check_existence", ids=_list_to_json(snapshot_ids))
            ).snapshot_ids
        )

    def get_snapshots_with_same_version(
        self, snapshot_name_versions: t.List[SnapshotNameVersion]
    ) -> t.List[Snapshot]:
        return common.SnapshotsResponse.parse_obj(
            self._get(SNAPSHOTS_PATH, versions=_list_to_json(snapshot_name_versions))
        ).snapshots

    def get_environment(self, environment: str) -> t.Optional[Environment]:
        try:
            response = self._get(f"{ENVIRONMENTS_PATH}/{environment}")
            return Environment.parse_obj(response)
        except NotFoundError:
            return None

    def get_environments(self) -> t.List[Environment]:
        response = self._get(ENVIRONMENTS_PATH)
        return common.EnvironmentsResponse.parse_obj(response).environments

    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
        url = f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}/{dag_run_id}"
        return self._get(url)["state"].lower()

    def get_janitor_dag(self) -> t.Dict[str, t.Any]:
        return self._get_dag(common.JANITOR_DAG_ID)

    def get_snapshot_dag(self, name: str, version: str) -> t.Dict[str, t.Any]:
        return self._get_dag(common.dag_id_for_name_version(name, version))

    def get_all_dags(self) -> t.Dict[str, t.Any]:
        return self._get("api/v1/dags")

    def wait_for_dag_run_completion(
        self, dag_id: str, dag_run_id: str, poll_interval_secs: int
    ) -> bool:
        loading_id = self._console_loading_start()

        while True:
            state = self.get_dag_run_state(dag_id, dag_run_id)
            if state in ("failed", "success"):
                if self._console and loading_id:
                    self._console.loading_stop(loading_id)
                return state == "success"

            time.sleep(poll_interval_secs)

    def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
        loading_id = self._console_loading_start()

        attempt_num = 1

        try:
            while True:
                try:
                    first_dag_run_id = self._get_first_dag_run_id(dag_id)
                    if first_dag_run_id is None:
                        raise SQLMeshError(f"Missing a DAG Run for DAG '{dag_id}'")
                    return first_dag_run_id
                except ApiServerError:
                    raise
                except SQLMeshError:
                    if attempt_num > max_retries:
                        raise

                    attempt_num += 1
                    time.sleep(poll_interval_secs)
        finally:
            if self._console and loading_id:
                self._console.loading_stop(loading_id)

    def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
        if not self._console:
            return

        tracking_url = self.dag_run_tracking_url(dag_id, dag_run_id)
        # TODO: Figure out generalized solution for links
        self._console.log_status_update(
            f"Track [green]{op_name}[/green] progress using [link={tracking_url}]link[/link]"
        )

    def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str:
        url_params = urlencode(
            dict(
                dag_id=dag_id,
                run_id=dag_run_id,
            )
        )
        return urljoin(self._airflow_url, f"dagrun_details?{url_params}")

    def close(self) -> None:
        self._session.close()

    def _get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
        dag_runs_response = self._get(f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}", limit="1")
        dag_runs = dag_runs_response["dag_runs"]
        if not dag_runs:
            return None
        return dag_runs[0]["dag_run_id"]

    def _get_dag(self, dag_id: str) -> t.Dict[str, t.Any]:
        return self._get(f"api/v1/dags/{dag_id}")

    def _get(self, path: str, *flags: str, **params: str) -> t.Dict[str, t.Any]:
        all_params = [*flags, *([urlencode(params)] if params else [])]
        query_string = "&".join(all_params)
        if query_string:
            path = f"{path}?{query_string}"
        response = self._session.get(urljoin(self._airflow_url, path))
        self._raise_for_status(response)
        return response.json()

    def _console_loading_start(self) -> t.Optional[uuid.UUID]:
        if self._console:
            return self._console.loading_start()
        return None

    def _raise_for_status(self, response: Response) -> None:
        if response.status_code == 404:
            raise NotFoundError(response.text)
        elif 400 <= response.status_code < 500:
            raise ApiClientError(response.text)
        elif 500 <= response.status_code < 600:
            raise ApiServerError(response.text)


T = t.TypeVar("T", bound=PydanticModel)


def _list_to_json(models: t.List[T]) -> str:
    return json.dumps([m.dict() for m in models], separators=(",", ":"))
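All GET requests funnel through _get, and every response passes through _raise_for_status, which maps HTTP status codes onto SQLMesh's error hierarchy: 404 raises NotFoundError, other 4xx codes raise ApiClientError, and 5xx codes raise ApiServerError. A minimal sketch of handling these, assuming a client instance constructed as shown under the class entry below and a placeholder environment name:

from sqlmesh.utils.errors import ApiClientError, ApiServerError

try:
    # get_environment itself converts a 404 into None.
    env = client.get_environment("prod")
except ApiClientError:
    ...  # some other 4xx: malformed request, auth failure, etc.
except ApiServerError:
    ...  # 5xx: the Airflow webserver or the SQLMesh plugin failed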
class AirflowClient:
AirflowClient(
    session: requests.sessions.Session,
    airflow_url: str,
    console: Optional[sqlmesh.core.console.Console] = None,
)
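A minimal sketch of constructing a client. The URL and credentials are illustrative assumptions; any requests.Session that can authenticate against the Airflow webserver will do:

import requests

from sqlmesh.schedulers.airflow.client import AirflowClient

session = requests.Session()
session.auth = ("airflow", "airflow")  # hypothetical basic-auth credentials

client = AirflowClient(
    session=session,
    airflow_url="http://localhost:8080/",  # hypothetical Airflow webserver URL
)

dags = client.get_all_dags()  # sanity check: GET api/v1/dags

# When finished, client.close() releases the underlying requests.Session.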
def apply_plan(
    self,
    new_snapshots: Iterable[sqlmesh.core.snapshot.definition.Snapshot],
    environment: sqlmesh.core.environment.Environment,
    request_id: str,
    no_gaps: bool = False,
    skip_backfill: bool = False,
    restatements: Optional[Iterable[str]] = None,
    notification_targets: Optional[List[sqlmesh.core._typing.NotificationTarget]] = None,
    backfill_concurrent_tasks: int = 1,
    ddl_concurrent_tasks: int = 1,
    users: Optional[List[sqlmesh.core.user.User]] = None,
    is_dev: bool = False,
) -> None:
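apply_plan serializes a common.PlanApplicationRequest and POSTs it to the plugin's plans endpoint; it returns nothing and raises through _raise_for_status on failure. A hedged sketch: plan_snapshots and plan_environment are placeholders that would normally come out of SQLMesh's plan machinery rather than be constructed by hand.

import uuid

client.apply_plan(
    new_snapshots=plan_snapshots,   # assumed: Snapshot objects produced by a SQLMesh plan
    environment=plan_environment,   # assumed: the target Environment object
    request_id=str(uuid.uuid4()),
    no_gaps=True,                   # require contiguous intervals in the target environment
    backfill_concurrent_tasks=4,
    ddl_concurrent_tasks=2,
)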
def get_snapshots(
    self,
    snapshot_ids: Optional[List[sqlmesh.core.snapshot.definition.SnapshotId]],
) -> List[sqlmesh.core.snapshot.definition.Snapshot]:
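Passing None omits the ids query parameter and returns every snapshot the plugin knows about; a list of SnapshotIds is JSON-encoded into the query string to restrict the result. A sketch, reusing the client from above:

all_snapshots = client.get_snapshots(None)

# Restrict to specific IDs; Snapshot.snapshot_id yields the SnapshotId key.
subset = client.get_snapshots([s.snapshot_id for s in all_snapshots[:2]])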
def snapshots_exist(
    self,
    snapshot_ids: List[sqlmesh.core.snapshot.definition.SnapshotId],
) -> Set[sqlmesh.core.snapshot.definition.SnapshotId]:
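The check_existence flag makes the endpoint return only IDs, so this is a cheap way to find which of a set of snapshots the server already stores. Sketch, continuing from above:

wanted = [s.snapshot_id for s in all_snapshots[:5]]
existing = client.snapshots_exist(wanted)
missing = set(wanted) - existing  # IDs the server does not have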
def get_snapshots_with_same_version(
    self,
    snapshot_name_versions: List[sqlmesh.core.snapshot.definition.SnapshotNameVersion],
) -> List[sqlmesh.core.snapshot.definition.Snapshot]:
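This queries by (name, version) pairs rather than full snapshot IDs, returning every snapshot that shares a version. A sketch, assuming Snapshot exposes its pair via a name_version accessor (illustrative):

pair = all_snapshots[0].name_version  # assumed SnapshotNameVersion accessor
shared = client.get_snapshots_with_same_version([pair])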
def wait_for_dag_run_completion(self, dag_id: str, dag_run_id: str, poll_interval_secs: int) -> bool:
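Polls get_dag_run_state every poll_interval_secs seconds until the run reaches a terminal state, returning True for "success" and False for "failed". A sketch with placeholder identifiers:

succeeded = client.wait_for_dag_run_completion(
    dag_id="example_dag",                      # hypothetical DAG ID
    dag_run_id="manual__2023-01-01T00:00:00",  # hypothetical run ID
    poll_interval_secs=10,
)
if not succeeded:
    raise RuntimeError("DAG run failed")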
def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
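Useful right after a DAG is created, when its first run may not exist yet: "missing run" errors are retried up to max_retries times, while an ApiServerError is re-raised immediately. Sketch, with a placeholder DAG ID:

run_id = client.wait_for_first_dag_run(
    dag_id="example_dag",   # hypothetical DAG ID
    poll_interval_secs=5,
    max_retries=6,          # only "no run yet" errors count against this
)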
def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
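A no-op unless a Console was passed to the constructor; otherwise it logs a rich-markup link built by dag_run_tracking_url. Sketch, continuing the placeholders above:

client.print_tracking_url("example_dag", run_id, op_name="backfill")

# Or build the raw tracking URL directly:
url = client.dag_run_tracking_url("example_dag", run_id)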