Edit on GitHub

sqlmesh.schedulers.airflow.plan

  1from __future__ import annotations
  2
  3import typing as t
  4from collections import defaultdict
  5
  6from sqlmesh.core import scheduler
  7from sqlmesh.core.environment import Environment
  8from sqlmesh.core.snapshot import SnapshotTableInfo
  9from sqlmesh.core.state_sync import StateSync
 10from sqlmesh.schedulers.airflow import common
 11from sqlmesh.utils.date import now
 12from sqlmesh.utils.errors import SQLMeshError
 13
 14
 15def create_plan_dag_spec(
 16    request: common.PlanApplicationRequest, state_sync: StateSync
 17) -> common.PlanDagSpec:
 18    new_snapshots = {s.snapshot_id: s for s in request.new_snapshots}
 19    stored_snapshots = state_sync.get_snapshots(None)
 20
 21    duplicated_snapshots = set(stored_snapshots).intersection(new_snapshots)
 22    if duplicated_snapshots:
 23        raise SQLMeshError(
 24            f"Snapshots {duplicated_snapshots} already exist. "
 25            "Make sure your code base is up to date and try re-creating the plan"
 26        )
 27
 28    if request.environment.end_at:
 29        end = request.environment.end_at
 30        unpaused_dt = None
 31    else:
 32        # Unbounded end date means we need to unpause all paused snapshots
 33        # that are part of the target environment.
 34        end = now()
 35        unpaused_dt = end
 36
 37    snapshots_for_intervals = {**new_snapshots, **stored_snapshots}
 38    all_snapshots_by_version = defaultdict(set)
 39    for snapshot in snapshots_for_intervals.values():
 40        all_snapshots_by_version[(snapshot.name, snapshot.version)].add(snapshot.snapshot_id)
 41
 42    if request.is_dev:
 43        # When in development mode exclude snapshots that match the version of each
 44        # paused forward-only snapshot that is a part of the plan.
 45        for s in request.environment.snapshots:
 46            if s.is_forward_only and snapshots_for_intervals[s.snapshot_id].is_paused:
 47                previous_snapshot_ids = all_snapshots_by_version[(s.name, s.version)] - {
 48                    s.snapshot_id
 49                }
 50                for sid in previous_snapshot_ids:
 51                    snapshots_for_intervals.pop(sid)
 52
 53    if request.restatements:
 54        state_sync.remove_interval(
 55            [],
 56            start=request.environment.start_at,
 57            end=end,
 58            all_snapshots=(
 59                snapshot
 60                for snapshot in snapshots_for_intervals.values()
 61                if snapshot.name in request.restatements
 62                and snapshot.snapshot_id not in new_snapshots
 63            ),
 64        )
 65
 66    if not request.skip_backfill:
 67        backfill_batches = scheduler.compute_interval_params(
 68            request.environment.snapshots,
 69            snapshots=snapshots_for_intervals,
 70            start=request.environment.start_at,
 71            end=end,
 72            latest=end,
 73        )
 74    else:
 75        backfill_batches = {}
 76
 77    backfill_intervals_per_snapshot = [
 78        common.BackfillIntervalsPerSnapshot(
 79            snapshot_id=snapshot.snapshot_id,
 80            intervals=intervals,
 81        )
 82        for snapshot, intervals in backfill_batches.items()
 83    ]
 84
 85    return common.PlanDagSpec(
 86        request_id=request.request_id,
 87        environment_name=request.environment.name,
 88        new_snapshots=request.new_snapshots,
 89        backfill_intervals_per_snapshot=backfill_intervals_per_snapshot,
 90        promoted_snapshots=request.environment.snapshots,
 91        demoted_snapshots=_get_demoted_snapshots(request.environment, state_sync),
 92        start=request.environment.start_at,
 93        end=request.environment.end_at,
 94        unpaused_dt=unpaused_dt,
 95        no_gaps=request.no_gaps,
 96        plan_id=request.environment.plan_id,
 97        previous_plan_id=request.environment.previous_plan_id,
 98        notification_targets=request.notification_targets,
 99        backfill_concurrent_tasks=request.backfill_concurrent_tasks,
100        ddl_concurrent_tasks=request.ddl_concurrent_tasks,
101        users=request.users,
102        is_dev=request.is_dev,
103        environment_expiration_ts=request.environment.expiration_ts,
104    )
105
106
def _get_demoted_snapshots(
    new_environment: Environment, state_sync: StateSync
) -> t.List[SnapshotTableInfo]:
    """Return the snapshots that the new environment state no longer includes.

    Args:
        new_environment: The environment as it should look after the plan is applied.
        state_sync: Used to look up the currently stored environment of the same name.

    Returns:
        Snapshots of the currently stored environment whose names do not appear among
        the new environment's snapshots. Empty if no environment with that name is
        returned by the state sync.
    """
    existing_environment = state_sync.get_environment(new_environment.name)
    if not existing_environment:
        return []

    # Matching is done by name only, so a snapshot that merely changed version is
    # treated as preserved rather than demoted.
    kept_names = {info.name for info in new_environment.snapshots}
    return [info for info in existing_environment.snapshots if info.name not in kept_names]
 16def create_plan_dag_spec(
 17    request: common.PlanApplicationRequest, state_sync: StateSync
 18) -> common.PlanDagSpec:
 19    new_snapshots = {s.snapshot_id: s for s in request.new_snapshots}
 20    stored_snapshots = state_sync.get_snapshots(None)
 21
 22    duplicated_snapshots = set(stored_snapshots).intersection(new_snapshots)
 23    if duplicated_snapshots:
 24        raise SQLMeshError(
 25            f"Snapshots {duplicated_snapshots} already exist. "
 26            "Make sure your code base is up to date and try re-creating the plan"
 27        )
 28
 29    if request.environment.end_at:
 30        end = request.environment.end_at
 31        unpaused_dt = None
 32    else:
 33        # Unbounded end date means we need to unpause all paused snapshots
 34        # that are part of the target environment.
 35        end = now()
 36        unpaused_dt = end
 37
 38    snapshots_for_intervals = {**new_snapshots, **stored_snapshots}
 39    all_snapshots_by_version = defaultdict(set)
 40    for snapshot in snapshots_for_intervals.values():
 41        all_snapshots_by_version[(snapshot.name, snapshot.version)].add(snapshot.snapshot_id)
 42
 43    if request.is_dev:
 44        # When in development mode exclude snapshots that match the version of each
 45        # paused forward-only snapshot that is a part of the plan.
 46        for s in request.environment.snapshots:
 47            if s.is_forward_only and snapshots_for_intervals[s.snapshot_id].is_paused:
 48                previous_snapshot_ids = all_snapshots_by_version[(s.name, s.version)] - {
 49                    s.snapshot_id
 50                }
 51                for sid in previous_snapshot_ids:
 52                    snapshots_for_intervals.pop(sid)
 53
 54    if request.restatements:
 55        state_sync.remove_interval(
 56            [],
 57            start=request.environment.start_at,
 58            end=end,
 59            all_snapshots=(
 60                snapshot
 61                for snapshot in snapshots_for_intervals.values()
 62                if snapshot.name in request.restatements
 63                and snapshot.snapshot_id not in new_snapshots
 64            ),
 65        )
 66
 67    if not request.skip_backfill:
 68        backfill_batches = scheduler.compute_interval_params(
 69            request.environment.snapshots,
 70            snapshots=snapshots_for_intervals,
 71            start=request.environment.start_at,
 72            end=end,
 73            latest=end,
 74        )
 75    else:
 76        backfill_batches = {}
 77
 78    backfill_intervals_per_snapshot = [
 79        common.BackfillIntervalsPerSnapshot(
 80            snapshot_id=snapshot.snapshot_id,
 81            intervals=intervals,
 82        )
 83        for snapshot, intervals in backfill_batches.items()
 84    ]
 85
 86    return common.PlanDagSpec(
 87        request_id=request.request_id,
 88        environment_name=request.environment.name,
 89        new_snapshots=request.new_snapshots,
 90        backfill_intervals_per_snapshot=backfill_intervals_per_snapshot,
 91        promoted_snapshots=request.environment.snapshots,
 92        demoted_snapshots=_get_demoted_snapshots(request.environment, state_sync),
 93        start=request.environment.start_at,
 94        end=request.environment.end_at,
 95        unpaused_dt=unpaused_dt,
 96        no_gaps=request.no_gaps,
 97        plan_id=request.environment.plan_id,
 98        previous_plan_id=request.environment.previous_plan_id,
 99        notification_targets=request.notification_targets,
100        backfill_concurrent_tasks=request.backfill_concurrent_tasks,
101        ddl_concurrent_tasks=request.ddl_concurrent_tasks,
102        users=request.users,
103        is_dev=request.is_dev,
104        environment_expiration_ts=request.environment.expiration_ts,
105    )