Edit on GitHub

sqlmesh.schedulers.airflow.client

  1import json
  2import time
  3import typing as t
  4import uuid
  5from urllib.parse import urlencode, urljoin
  6
  7import requests
  8from requests.models import Response
  9
 10from sqlmesh.core._typing import NotificationTarget
 11from sqlmesh.core.console import Console
 12from sqlmesh.core.environment import Environment
 13from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotNameVersion
 14from sqlmesh.core.user import User
 15from sqlmesh.schedulers.airflow import common
 16from sqlmesh.utils.errors import (
 17    ApiClientError,
 18    ApiServerError,
 19    NotFoundError,
 20    SQLMeshError,
 21)
 22from sqlmesh.utils.pydantic import PydanticModel
 23
# Relative path template for Airflow's stable REST API "DAG Runs" endpoint;
# formatted with a DAG id before use.
DAG_RUN_PATH_TEMPLATE = "api/v1/dags/{}/dagRuns"


# Endpoints exposed by the SQLMesh Airflow plugin, rooted at its base path.
PLANS_PATH = f"{common.SQLMESH_API_BASE_PATH}/plans"
ENVIRONMENTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/environments"
SNAPSHOTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/snapshots"
 30
 31
 32class AirflowClient:
 33    def __init__(
 34        self,
 35        session: requests.Session,
 36        airflow_url: str,
 37        console: t.Optional[Console] = None,
 38    ):
 39        self._session = session
 40        self._airflow_url = airflow_url
 41        self._console = console
 42
 43    def apply_plan(
 44        self,
 45        new_snapshots: t.Iterable[Snapshot],
 46        environment: Environment,
 47        request_id: str,
 48        no_gaps: bool = False,
 49        skip_backfill: bool = False,
 50        restatements: t.Optional[t.Iterable[str]] = None,
 51        notification_targets: t.Optional[t.List[NotificationTarget]] = None,
 52        backfill_concurrent_tasks: int = 1,
 53        ddl_concurrent_tasks: int = 1,
 54        users: t.Optional[t.List[User]] = None,
 55        is_dev: bool = False,
 56    ) -> None:
 57        request = common.PlanApplicationRequest(
 58            new_snapshots=list(new_snapshots),
 59            environment=environment,
 60            no_gaps=no_gaps,
 61            skip_backfill=skip_backfill,
 62            request_id=request_id,
 63            restatements=set(restatements or []),
 64            notification_targets=notification_targets or [],
 65            backfill_concurrent_tasks=backfill_concurrent_tasks,
 66            ddl_concurrent_tasks=ddl_concurrent_tasks,
 67            users=users or [],
 68            is_dev=is_dev,
 69        )
 70
 71        response = self._session.post(
 72            urljoin(self._airflow_url, PLANS_PATH),
 73            data=request.json(),
 74        )
 75        self._raise_for_status(response)
 76
 77    def get_snapshots(self, snapshot_ids: t.Optional[t.List[SnapshotId]]) -> t.List[Snapshot]:
 78        params: t.Dict[str, str] = {}
 79        if snapshot_ids is not None:
 80            params["ids"] = _list_to_json(snapshot_ids)
 81
 82        return common.SnapshotsResponse.parse_obj(self._get(SNAPSHOTS_PATH, **params)).snapshots
 83
 84    def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]:
 85        return set(
 86            common.SnapshotIdsResponse.parse_obj(
 87                self._get(SNAPSHOTS_PATH, "check_existence", ids=_list_to_json(snapshot_ids))
 88            ).snapshot_ids
 89        )
 90
 91    def get_snapshots_with_same_version(
 92        self, snapshot_name_versions: t.List[SnapshotNameVersion]
 93    ) -> t.List[Snapshot]:
 94        return common.SnapshotsResponse.parse_obj(
 95            self._get(SNAPSHOTS_PATH, versions=_list_to_json(snapshot_name_versions))
 96        ).snapshots
 97
 98    def get_environment(self, environment: str) -> t.Optional[Environment]:
 99        try:
100            response = self._get(f"{ENVIRONMENTS_PATH}/{environment}")
101            return Environment.parse_obj(response)
102        except NotFoundError:
103            return None
104
105    def get_environments(self) -> t.List[Environment]:
106        response = self._get(ENVIRONMENTS_PATH)
107        return common.EnvironmentsResponse.parse_obj(response).environments
108
109    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
110        url = f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}/{dag_run_id}"
111        return self._get(url)["state"].lower()
112
113    def get_janitor_dag(self) -> t.Dict[str, t.Any]:
114        return self._get_dag(common.JANITOR_DAG_ID)
115
116    def get_snapshot_dag(self, name: str, version: str) -> t.Dict[str, t.Any]:
117        return self._get_dag(common.dag_id_for_name_version(name, version))
118
119    def get_all_dags(self) -> t.Dict[str, t.Any]:
120        return self._get("api/v1/dags")
121
122    def wait_for_dag_run_completion(
123        self, dag_id: str, dag_run_id: str, poll_interval_secs: int
124    ) -> bool:
125        loading_id = self._console_loading_start()
126
127        while True:
128            state = self.get_dag_run_state(dag_id, dag_run_id)
129            if state in ("failed", "success"):
130                if self._console and loading_id:
131                    self._console.loading_stop(loading_id)
132                return state == "success"
133
134            time.sleep(poll_interval_secs)
135
136    def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
137        loading_id = self._console_loading_start()
138
139        attempt_num = 1
140
141        try:
142            while True:
143                try:
144                    first_dag_run_id = self._get_first_dag_run_id(dag_id)
145                    if first_dag_run_id is None:
146                        raise SQLMeshError(f"Missing a DAG Run for DAG '{dag_id}'")
147                    return first_dag_run_id
148                except ApiServerError:
149                    raise
150                except SQLMeshError:
151                    if attempt_num > max_retries:
152                        raise
153
154                attempt_num += 1
155                time.sleep(poll_interval_secs)
156        finally:
157            if self._console and loading_id:
158                self._console.loading_stop(loading_id)
159
160    def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
161        if not self._console:
162            return
163
164        tracking_url = self.dag_run_tracking_url(dag_id, dag_run_id)
165        # TODO: Figure out generalized solution for links
166        self._console.log_status_update(
167            f"Track [green]{op_name}[/green] progress using [link={tracking_url}]link[/link]"
168        )
169
170    def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str:
171        url_params = urlencode(
172            dict(
173                dag_id=dag_id,
174                run_id=dag_run_id,
175            )
176        )
177        return urljoin(self._airflow_url, f"dagrun_details?{url_params}")
178
179    def close(self) -> None:
180        self._session.close()
181
182    def _get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
183        dag_runs_response = self._get(f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}", limit="1")
184        dag_runs = dag_runs_response["dag_runs"]
185        if not dag_runs:
186            return None
187        return dag_runs[0]["dag_run_id"]
188
189    def _get_dag(self, dag_id: str) -> t.Dict[str, t.Any]:
190        return self._get(f"api/v1/dags/{dag_id}")
191
192    def _get(self, path: str, *flags: str, **params: str) -> t.Dict[str, t.Any]:
193        all_params = [*flags, *([urlencode(params)] if params else [])]
194        query_string = "&".join(all_params)
195        if query_string:
196            path = f"{path}?{query_string}"
197        response = self._session.get(urljoin(self._airflow_url, path))
198        self._raise_for_status(response)
199        return response.json()
200
201    def _console_loading_start(self) -> t.Optional[uuid.UUID]:
202        if self._console:
203            return self._console.loading_start()
204        return None
205
206    def _raise_for_status(self, response: Response) -> None:
207        if response.status_code == 404:
208            raise NotFoundError(response.text)
209        elif 400 <= response.status_code < 500:
210            raise ApiClientError(response.text)
211        elif 500 <= response.status_code < 600:
212            raise ApiServerError(response.text)
213
214
T = t.TypeVar("T", bound=PydanticModel)


def _list_to_json(models: t.List[T]) -> str:
    """Renders the given models as a compact JSON array of their dict forms."""
    payload = [model.dict() for model in models]
    return json.dumps(payload, separators=(",", ":"))
class AirflowClient:
 33class AirflowClient:
 34    def __init__(
 35        self,
 36        session: requests.Session,
 37        airflow_url: str,
 38        console: t.Optional[Console] = None,
 39    ):
 40        self._session = session
 41        self._airflow_url = airflow_url
 42        self._console = console
 43
 44    def apply_plan(
 45        self,
 46        new_snapshots: t.Iterable[Snapshot],
 47        environment: Environment,
 48        request_id: str,
 49        no_gaps: bool = False,
 50        skip_backfill: bool = False,
 51        restatements: t.Optional[t.Iterable[str]] = None,
 52        notification_targets: t.Optional[t.List[NotificationTarget]] = None,
 53        backfill_concurrent_tasks: int = 1,
 54        ddl_concurrent_tasks: int = 1,
 55        users: t.Optional[t.List[User]] = None,
 56        is_dev: bool = False,
 57    ) -> None:
 58        request = common.PlanApplicationRequest(
 59            new_snapshots=list(new_snapshots),
 60            environment=environment,
 61            no_gaps=no_gaps,
 62            skip_backfill=skip_backfill,
 63            request_id=request_id,
 64            restatements=set(restatements or []),
 65            notification_targets=notification_targets or [],
 66            backfill_concurrent_tasks=backfill_concurrent_tasks,
 67            ddl_concurrent_tasks=ddl_concurrent_tasks,
 68            users=users or [],
 69            is_dev=is_dev,
 70        )
 71
 72        response = self._session.post(
 73            urljoin(self._airflow_url, PLANS_PATH),
 74            data=request.json(),
 75        )
 76        self._raise_for_status(response)
 77
 78    def get_snapshots(self, snapshot_ids: t.Optional[t.List[SnapshotId]]) -> t.List[Snapshot]:
 79        params: t.Dict[str, str] = {}
 80        if snapshot_ids is not None:
 81            params["ids"] = _list_to_json(snapshot_ids)
 82
 83        return common.SnapshotsResponse.parse_obj(self._get(SNAPSHOTS_PATH, **params)).snapshots
 84
 85    def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]:
 86        return set(
 87            common.SnapshotIdsResponse.parse_obj(
 88                self._get(SNAPSHOTS_PATH, "check_existence", ids=_list_to_json(snapshot_ids))
 89            ).snapshot_ids
 90        )
 91
 92    def get_snapshots_with_same_version(
 93        self, snapshot_name_versions: t.List[SnapshotNameVersion]
 94    ) -> t.List[Snapshot]:
 95        return common.SnapshotsResponse.parse_obj(
 96            self._get(SNAPSHOTS_PATH, versions=_list_to_json(snapshot_name_versions))
 97        ).snapshots
 98
 99    def get_environment(self, environment: str) -> t.Optional[Environment]:
100        try:
101            response = self._get(f"{ENVIRONMENTS_PATH}/{environment}")
102            return Environment.parse_obj(response)
103        except NotFoundError:
104            return None
105
106    def get_environments(self) -> t.List[Environment]:
107        response = self._get(ENVIRONMENTS_PATH)
108        return common.EnvironmentsResponse.parse_obj(response).environments
109
110    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
111        url = f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}/{dag_run_id}"
112        return self._get(url)["state"].lower()
113
114    def get_janitor_dag(self) -> t.Dict[str, t.Any]:
115        return self._get_dag(common.JANITOR_DAG_ID)
116
117    def get_snapshot_dag(self, name: str, version: str) -> t.Dict[str, t.Any]:
118        return self._get_dag(common.dag_id_for_name_version(name, version))
119
120    def get_all_dags(self) -> t.Dict[str, t.Any]:
121        return self._get("api/v1/dags")
122
123    def wait_for_dag_run_completion(
124        self, dag_id: str, dag_run_id: str, poll_interval_secs: int
125    ) -> bool:
126        loading_id = self._console_loading_start()
127
128        while True:
129            state = self.get_dag_run_state(dag_id, dag_run_id)
130            if state in ("failed", "success"):
131                if self._console and loading_id:
132                    self._console.loading_stop(loading_id)
133                return state == "success"
134
135            time.sleep(poll_interval_secs)
136
137    def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
138        loading_id = self._console_loading_start()
139
140        attempt_num = 1
141
142        try:
143            while True:
144                try:
145                    first_dag_run_id = self._get_first_dag_run_id(dag_id)
146                    if first_dag_run_id is None:
147                        raise SQLMeshError(f"Missing a DAG Run for DAG '{dag_id}'")
148                    return first_dag_run_id
149                except ApiServerError:
150                    raise
151                except SQLMeshError:
152                    if attempt_num > max_retries:
153                        raise
154
155                attempt_num += 1
156                time.sleep(poll_interval_secs)
157        finally:
158            if self._console and loading_id:
159                self._console.loading_stop(loading_id)
160
161    def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
162        if not self._console:
163            return
164
165        tracking_url = self.dag_run_tracking_url(dag_id, dag_run_id)
166        # TODO: Figure out generalized solution for links
167        self._console.log_status_update(
168            f"Track [green]{op_name}[/green] progress using [link={tracking_url}]link[/link]"
169        )
170
171    def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str:
172        url_params = urlencode(
173            dict(
174                dag_id=dag_id,
175                run_id=dag_run_id,
176            )
177        )
178        return urljoin(self._airflow_url, f"dagrun_details?{url_params}")
179
180    def close(self) -> None:
181        self._session.close()
182
183    def _get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
184        dag_runs_response = self._get(f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}", limit="1")
185        dag_runs = dag_runs_response["dag_runs"]
186        if not dag_runs:
187            return None
188        return dag_runs[0]["dag_run_id"]
189
190    def _get_dag(self, dag_id: str) -> t.Dict[str, t.Any]:
191        return self._get(f"api/v1/dags/{dag_id}")
192
193    def _get(self, path: str, *flags: str, **params: str) -> t.Dict[str, t.Any]:
194        all_params = [*flags, *([urlencode(params)] if params else [])]
195        query_string = "&".join(all_params)
196        if query_string:
197            path = f"{path}?{query_string}"
198        response = self._session.get(urljoin(self._airflow_url, path))
199        self._raise_for_status(response)
200        return response.json()
201
202    def _console_loading_start(self) -> t.Optional[uuid.UUID]:
203        if self._console:
204            return self._console.loading_start()
205        return None
206
207    def _raise_for_status(self, response: Response) -> None:
208        if response.status_code == 404:
209            raise NotFoundError(response.text)
210        elif 400 <= response.status_code < 500:
211            raise ApiClientError(response.text)
212        elif 500 <= response.status_code < 600:
213            raise ApiServerError(response.text)
AirflowClient( session: requests.sessions.Session, airflow_url: str, console: Optional[sqlmesh.core.console.Console] = None)
34    def __init__(
35        self,
36        session: requests.Session,
37        airflow_url: str,
38        console: t.Optional[Console] = None,
39    ):
40        self._session = session
41        self._airflow_url = airflow_url
42        self._console = console
def apply_plan( self, new_snapshots: Iterable[sqlmesh.core.snapshot.definition.Snapshot], environment: sqlmesh.core.environment.Environment, request_id: str, no_gaps: bool = False, skip_backfill: bool = False, restatements: Optional[Iterable[str]] = None, notification_targets: Optional[List[Annotated[Union[sqlmesh.core.notification_target.ConsoleNotificationTarget, sqlmesh.integrations.github.notification_target.GithubNotificationTarget], FieldInfo(default=PydanticUndefined, discriminator='type_', extra={})]]] = None, backfill_concurrent_tasks: int = 1, ddl_concurrent_tasks: int = 1, users: Optional[List[sqlmesh.core.user.User]] = None, is_dev: bool = False) -> None:
44    def apply_plan(
45        self,
46        new_snapshots: t.Iterable[Snapshot],
47        environment: Environment,
48        request_id: str,
49        no_gaps: bool = False,
50        skip_backfill: bool = False,
51        restatements: t.Optional[t.Iterable[str]] = None,
52        notification_targets: t.Optional[t.List[NotificationTarget]] = None,
53        backfill_concurrent_tasks: int = 1,
54        ddl_concurrent_tasks: int = 1,
55        users: t.Optional[t.List[User]] = None,
56        is_dev: bool = False,
57    ) -> None:
58        request = common.PlanApplicationRequest(
59            new_snapshots=list(new_snapshots),
60            environment=environment,
61            no_gaps=no_gaps,
62            skip_backfill=skip_backfill,
63            request_id=request_id,
64            restatements=set(restatements or []),
65            notification_targets=notification_targets or [],
66            backfill_concurrent_tasks=backfill_concurrent_tasks,
67            ddl_concurrent_tasks=ddl_concurrent_tasks,
68            users=users or [],
69            is_dev=is_dev,
70        )
71
72        response = self._session.post(
73            urljoin(self._airflow_url, PLANS_PATH),
74            data=request.json(),
75        )
76        self._raise_for_status(response)
def get_snapshots( self, snapshot_ids: Optional[List[sqlmesh.core.snapshot.definition.SnapshotId]]) -> List[sqlmesh.core.snapshot.definition.Snapshot]:
78    def get_snapshots(self, snapshot_ids: t.Optional[t.List[SnapshotId]]) -> t.List[Snapshot]:
79        params: t.Dict[str, str] = {}
80        if snapshot_ids is not None:
81            params["ids"] = _list_to_json(snapshot_ids)
82
83        return common.SnapshotsResponse.parse_obj(self._get(SNAPSHOTS_PATH, **params)).snapshots
def snapshots_exist( self, snapshot_ids: List[sqlmesh.core.snapshot.definition.SnapshotId]) -> Set[sqlmesh.core.snapshot.definition.SnapshotId]:
85    def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]:
86        return set(
87            common.SnapshotIdsResponse.parse_obj(
88                self._get(SNAPSHOTS_PATH, "check_existence", ids=_list_to_json(snapshot_ids))
89            ).snapshot_ids
90        )
def get_snapshots_with_same_version( self, snapshot_name_versions: List[sqlmesh.core.snapshot.definition.SnapshotNameVersion]) -> List[sqlmesh.core.snapshot.definition.Snapshot]:
92    def get_snapshots_with_same_version(
93        self, snapshot_name_versions: t.List[SnapshotNameVersion]
94    ) -> t.List[Snapshot]:
95        return common.SnapshotsResponse.parse_obj(
96            self._get(SNAPSHOTS_PATH, versions=_list_to_json(snapshot_name_versions))
97        ).snapshots
def get_environment(self, environment: str) -> Optional[sqlmesh.core.environment.Environment]:
 99    def get_environment(self, environment: str) -> t.Optional[Environment]:
100        try:
101            response = self._get(f"{ENVIRONMENTS_PATH}/{environment}")
102            return Environment.parse_obj(response)
103        except NotFoundError:
104            return None
def get_environments(self) -> List[sqlmesh.core.environment.Environment]:
106    def get_environments(self) -> t.List[Environment]:
107        response = self._get(ENVIRONMENTS_PATH)
108        return common.EnvironmentsResponse.parse_obj(response).environments
def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
110    def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str:
111        url = f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}/{dag_run_id}"
112        return self._get(url)["state"].lower()
def get_janitor_dag(self) -> Dict[str, Any]:
114    def get_janitor_dag(self) -> t.Dict[str, t.Any]:
115        return self._get_dag(common.JANITOR_DAG_ID)
def get_snapshot_dag(self, name: str, version: str) -> Dict[str, Any]:
117    def get_snapshot_dag(self, name: str, version: str) -> t.Dict[str, t.Any]:
118        return self._get_dag(common.dag_id_for_name_version(name, version))
def get_all_dags(self) -> Dict[str, Any]:
120    def get_all_dags(self) -> t.Dict[str, t.Any]:
121        return self._get("api/v1/dags")
def wait_for_dag_run_completion(self, dag_id: str, dag_run_id: str, poll_interval_secs: int) -> bool:
123    def wait_for_dag_run_completion(
124        self, dag_id: str, dag_run_id: str, poll_interval_secs: int
125    ) -> bool:
126        loading_id = self._console_loading_start()
127
128        while True:
129            state = self.get_dag_run_state(dag_id, dag_run_id)
130            if state in ("failed", "success"):
131                if self._console and loading_id:
132                    self._console.loading_stop(loading_id)
133                return state == "success"
134
135            time.sleep(poll_interval_secs)
def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
137    def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str:
138        loading_id = self._console_loading_start()
139
140        attempt_num = 1
141
142        try:
143            while True:
144                try:
145                    first_dag_run_id = self._get_first_dag_run_id(dag_id)
146                    if first_dag_run_id is None:
147                        raise SQLMeshError(f"Missing a DAG Run for DAG '{dag_id}'")
148                    return first_dag_run_id
149                except ApiServerError:
150                    raise
151                except SQLMeshError:
152                    if attempt_num > max_retries:
153                        raise
154
155                attempt_num += 1
156                time.sleep(poll_interval_secs)
157        finally:
158            if self._console and loading_id:
159                self._console.loading_stop(loading_id)
def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
161    def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None:
162        if not self._console:
163            return
164
165        tracking_url = self.dag_run_tracking_url(dag_id, dag_run_id)
166        # TODO: Figure out generalized solution for links
167        self._console.log_status_update(
168            f"Track [green]{op_name}[/green] progress using [link={tracking_url}]link[/link]"
169        )
def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str:
171    def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str:
172        url_params = urlencode(
173            dict(
174                dag_id=dag_id,
175                run_id=dag_run_id,
176            )
177        )
178        return urljoin(self._airflow_url, f"dagrun_details?{url_params}")
def close(self) -> None:
180    def close(self) -> None:
181        self._session.close()