Edit on GitHub

sqlmesh.schedulers.airflow.api

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from functools import wraps
  6
  7from airflow.api_connexion import security
  8from airflow.models import Variable
  9from airflow.www.app import csrf
 10from flask import Blueprint, Response, jsonify, make_response, request
 11
 12from sqlmesh.core import constants as c
 13from sqlmesh.core.snapshot import SnapshotId, SnapshotNameVersion
 14from sqlmesh.schedulers.airflow import common, util
 15from sqlmesh.schedulers.airflow.plan import create_plan_dag_spec
 16from sqlmesh.utils.pydantic import PydanticModel
 17
 18sqlmesh_api_v1 = Blueprint(
 19    c.SQLMESH,
 20    __name__,
 21    url_prefix=f"/{common.SQLMESH_API_BASE_PATH}",
 22)
 23
 24
 25def check_authentication(func: t.Callable) -> t.Callable:
 26    @wraps(func)
 27    def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
 28        security.check_authentication()
 29        return func(*args, **kwargs)
 30
 31    return wrapper
 32
 33
 34@sqlmesh_api_v1.post("/plans")
 35@csrf.exempt
 36@check_authentication
 37def apply_plan() -> Response:
 38    try:
 39        plan = common.PlanApplicationRequest.parse_obj(request.json or {})
 40        with util.scoped_state_sync() as state_sync:
 41            spec = create_plan_dag_spec(plan, state_sync)
 42    except Exception as ex:
 43        return _error(str(ex))
 44
 45    Variable.set(common.plan_dag_spec_key(spec.request_id), spec.json())
 46
 47    return make_response(jsonify(request_id=spec.request_id), 201)
 48
 49
 50@sqlmesh_api_v1.get("/environments/<name>")
 51@csrf.exempt
 52@check_authentication
 53def get_environment(name: str) -> Response:
 54    with util.scoped_state_sync() as state_sync:
 55        environment = state_sync.get_environment(name)
 56    if environment is None:
 57        return _error(f"Environment '{name}' was not found", 404)
 58    return _success(environment)
 59
 60
 61@sqlmesh_api_v1.get("/environments")
 62@csrf.exempt
 63@check_authentication
 64def get_environments() -> Response:
 65    with util.scoped_state_sync() as state_sync:
 66        environments = state_sync.get_environments()
 67    return _success(common.EnvironmentsResponse(environments=environments))
 68
 69
 70@sqlmesh_api_v1.get("/snapshots")
 71@csrf.exempt
 72@check_authentication
 73def get_snapshots() -> Response:
 74    with util.scoped_state_sync() as state_sync:
 75        snapshot_name_versions = _snapshot_name_versions_from_request()
 76        if snapshot_name_versions is not None:
 77            snapshots = state_sync.get_snapshots_with_same_version(snapshot_name_versions)
 78        else:
 79            snapshot_ids = _snapshot_ids_from_request()
 80
 81            if "check_existence" in request.args:
 82                existing_snapshot_ids = (
 83                    state_sync.snapshots_exist(snapshot_ids) if snapshot_ids is not None else set()
 84                )
 85                return _success(common.SnapshotIdsResponse(snapshot_ids=existing_snapshot_ids))
 86
 87            snapshots = list(state_sync.get_snapshots(snapshot_ids).values())
 88
 89        return _success(common.SnapshotsResponse(snapshots=snapshots))
 90
 91
 92T = t.TypeVar("T", bound=PydanticModel)
 93
 94
 95def _success(data: T, status_code: int = 200) -> Response:
 96    response = make_response(data.json(), status_code)
 97    response.mimetype = "application/json"
 98    return response
 99
100
def _error(message: str, status_code: int = 400) -> Response:
    """Build a JSON error response of the form ``{"message": message}``."""
    payload = jsonify(message=message)
    return make_response(payload, status_code)
103
104
def _snapshot_ids_from_request() -> t.Optional[t.List[SnapshotId]]:
    """Decode the ``ids`` query parameter (a JSON array) into ``SnapshotId`` objects.

    Returns None when the parameter is absent. JSON decoding and parsing errors
    propagate to the caller unchanged.
    """
    encoded = request.args.get("ids")
    if encoded is None:
        return None
    return list(map(SnapshotId.parse_obj, json.loads(encoded)))
111
112
def _snapshot_name_versions_from_request() -> t.Optional[t.List[SnapshotNameVersion]]:
    """Decode the ``versions`` query parameter (a JSON array) into ``SnapshotNameVersion`` objects.

    Returns None when the parameter is absent. JSON decoding and parsing errors
    propagate to the caller unchanged.
    """
    encoded = request.args.get("versions")
    if encoded is None:
        return None
    return list(map(SnapshotNameVersion.parse_obj, json.loads(encoded)))
def check_authentication(func: t.Callable) -> t.Callable:
    """Decorate a view so Airflow's API authentication check runs first.

    ``security.check_authentication()`` is called on every invocation; the
    wrapped view only runs if that call returns without raising. Metadata of
    ``func`` (name, docstring, ``__wrapped__``) is preserved via ``wraps``.
    """

    @wraps(func)
    def _authenticated_view(*view_args: t.Any, **view_kwargs: t.Any) -> t.Any:
        security.check_authentication()
        result = func(*view_args, **view_kwargs)
        return result

    return _authenticated_view
@sqlmesh_api_v1.post("/plans")
@csrf.exempt
@check_authentication
def apply_plan() -> Response:
    """Build a plan DAG spec from the JSON request body and persist it.

    The spec is serialized into the Airflow ``Variable`` keyed by
    ``common.plan_dag_spec_key(request_id)``. On success the response is 201
    with ``{"request_id": ...}``. Any exception raised while parsing the body
    or building the spec is returned via ``_error`` (status 400 by default).
    """
    try:
        plan_request = common.PlanApplicationRequest.parse_obj(request.json or {})
        with util.scoped_state_sync() as state_sync:
            dag_spec = create_plan_dag_spec(plan_request, state_sync)
    except Exception as error:
        return _error(str(error))

    spec_key = common.plan_dag_spec_key(dag_spec.request_id)
    Variable.set(spec_key, dag_spec.json())

    body = jsonify(request_id=dag_spec.request_id)
    return make_response(body, 201)
@sqlmesh_api_v1.get("/environments/<name>")
@csrf.exempt
@check_authentication
def get_environment(name: str) -> Response:
    """Return the environment called ``name`` as JSON, or a 404 error if it is unknown."""
    with util.scoped_state_sync() as state_sync:
        found = state_sync.get_environment(name)

    if found is not None:
        return _success(found)
    return _error(f"Environment '{name}' was not found", 404)
@sqlmesh_api_v1.get("/environments")
@csrf.exempt
@check_authentication
def get_environments() -> Response:
    """Return every environment from the state sync wrapped in an ``EnvironmentsResponse``."""
    with util.scoped_state_sync() as state_sync:
        all_environments = state_sync.get_environments()

    payload = common.EnvironmentsResponse(environments=all_environments)
    return _success(payload)
@sqlmesh_api_v1.get("/snapshots")
@csrf.exempt
@check_authentication
def get_snapshots() -> Response:
    """Fetch snapshots selected by the request's query-string parameters.

    Query parameters:
        versions: JSON array of ``SnapshotNameVersion`` objects. When present, the
            result of ``state_sync.get_snapshots_with_same_version`` is returned
            and ``ids`` / ``check_existence`` are not consulted.
        ids: JSON array of ``SnapshotId`` objects. When absent, ``None`` is passed
            to ``state_sync.get_snapshots`` (presumably meaning "all snapshots" --
            confirm against the StateSync implementation).
        check_existence: if present (its value is ignored), respond with a
            ``SnapshotIdsResponse`` holding the subset of ``ids`` that exist; the
            set is empty when ``ids`` is absent.

    Returns:
        A JSON ``SnapshotsResponse`` or ``SnapshotIdsResponse`` (status 200), or an
        ``_error`` response (status 400) when ``versions`` / ``ids`` is not valid
        JSON or does not parse into a list of the expected objects.
    """
    # Validate client input before opening a state-sync session. Without this, a
    # malformed ``ids``/``versions`` value escapes as an unhandled exception (HTTP
    # 500) instead of a client error; apply_plan already reports bad input via
    # _error. json.JSONDecodeError subclasses ValueError, as does pydantic's
    # ValidationError raised by ``parse_obj`` (assuming these are pydantic models);
    # TypeError covers JSON values that are not iterable (e.g. ``ids=5``, ``ids=null``).
    try:
        snapshot_name_versions = _snapshot_name_versions_from_request()
        snapshot_ids = (
            None if snapshot_name_versions is not None else _snapshot_ids_from_request()
        )
    except (ValueError, TypeError) as ex:
        return _error(str(ex))

    with util.scoped_state_sync() as state_sync:
        if snapshot_name_versions is not None:
            snapshots = state_sync.get_snapshots_with_same_version(snapshot_name_versions)
        elif "check_existence" in request.args:
            existing_snapshot_ids = (
                state_sync.snapshots_exist(snapshot_ids) if snapshot_ids is not None else set()
            )
            return _success(common.SnapshotIdsResponse(snapshot_ids=existing_snapshot_ids))
        else:
            snapshots = list(state_sync.get_snapshots(snapshot_ids).values())

    return _success(common.SnapshotsResponse(snapshots=snapshots))