sqlmesh.schedulers.airflow.plan
"""Construction of plan DAG specs from plan application requests."""
from __future__ import annotations

import typing as t
from collections import defaultdict

from sqlmesh.core import scheduler
from sqlmesh.core.environment import Environment
from sqlmesh.core.snapshot import SnapshotTableInfo
from sqlmesh.core.state_sync import StateSync
from sqlmesh.schedulers.airflow import common
from sqlmesh.utils.date import now
from sqlmesh.utils.errors import SQLMeshError


def create_plan_dag_spec(
    request: common.PlanApplicationRequest, state_sync: StateSync
) -> common.PlanDagSpec:
    """Translate a plan application request into a plan DAG spec.

    Args:
        request: The plan application request to translate.
        state_sync: Used to load the stored snapshots and the current version
            of the target environment, and to remove intervals of restated
            snapshots.

    Returns:
        The spec carrying the request's settings together with the computed
        backfill intervals and the demoted snapshots.

    Raises:
        SQLMeshError: If any snapshot listed as new in the request is already
            present among the stored snapshots.
    """
    new_snapshots = {s.snapshot_id: s for s in request.new_snapshots}
    stored_snapshots = state_sync.get_snapshots(None)

    # A snapshot submitted as "new" must not be known to the state yet.
    duplicated_snapshots = set(stored_snapshots).intersection(new_snapshots)
    if duplicated_snapshots:
        raise SQLMeshError(
            f"Snapshots {duplicated_snapshots} already exist. "
            "Make sure your code base is up to date and try re-creating the plan"
        )

    environment = request.environment

    # With an unbounded end date, "now" is used as the end of the range and
    # all paused snapshots that are part of the target environment need to be
    # unpaused as of that moment.
    end = environment.end_at
    unpaused_dt = None
    if not end:
        end = now()
        unpaused_dt = end

    candidates = dict(new_snapshots)
    candidates.update(stored_snapshots)

    ids_by_version = defaultdict(set)
    for candidate in candidates.values():
        ids_by_version[(candidate.name, candidate.version)].add(candidate.snapshot_id)

    if request.is_dev:
        # In development mode, drop every other snapshot that shares a version
        # with a paused forward-only snapshot that is part of the plan.
        for promoted in environment.snapshots:
            if not promoted.is_forward_only:
                continue
            if not candidates[promoted.snapshot_id].is_paused:
                continue
            same_version_ids = ids_by_version[(promoted.name, promoted.version)] - {
                promoted.snapshot_id
            }
            for same_version_id in same_version_ids:
                del candidates[same_version_id]

    if request.restatements:
        # Only previously stored snapshots whose names are being restated are
        # handed over; the generator is consumed lazily by the callee.
        restated = (
            snapshot
            for snapshot in candidates.values()
            if snapshot.name in request.restatements
            and snapshot.snapshot_id not in new_snapshots
        )
        state_sync.remove_interval(
            [],
            start=environment.start_at,
            end=end,
            all_snapshots=restated,
        )

    if request.skip_backfill:
        backfill_batches = {}
    else:
        backfill_batches = scheduler.compute_interval_params(
            environment.snapshots,
            snapshots=candidates,
            start=environment.start_at,
            end=end,
            latest=end,
        )

    backfill_intervals_per_snapshot = []
    for snapshot, intervals in backfill_batches.items():
        backfill_intervals_per_snapshot.append(
            common.BackfillIntervalsPerSnapshot(
                snapshot_id=snapshot.snapshot_id,
                intervals=intervals,
            )
        )

    demoted_snapshots = _get_demoted_snapshots(environment, state_sync)

    # Note: the spec receives the requested (possibly unset) end date, not the
    # resolved ``end`` used for the interval computations above.
    return common.PlanDagSpec(
        request_id=request.request_id,
        environment_name=environment.name,
        new_snapshots=request.new_snapshots,
        backfill_intervals_per_snapshot=backfill_intervals_per_snapshot,
        promoted_snapshots=environment.snapshots,
        demoted_snapshots=demoted_snapshots,
        start=environment.start_at,
        end=environment.end_at,
        unpaused_dt=unpaused_dt,
        no_gaps=request.no_gaps,
        plan_id=environment.plan_id,
        previous_plan_id=environment.previous_plan_id,
        notification_targets=request.notification_targets,
        backfill_concurrent_tasks=request.backfill_concurrent_tasks,
        ddl_concurrent_tasks=request.ddl_concurrent_tasks,
        users=request.users,
        is_dev=request.is_dev,
        environment_expiration_ts=environment.expiration_ts,
    )


def _get_demoted_snapshots(
    new_environment: Environment, state_sync: StateSync
) -> t.List[SnapshotTableInfo]:
    """Return the snapshots of the stored environment that the new one drops.

    Args:
        new_environment: The environment as it should look after the plan.
        state_sync: Used to look up the currently stored environment by name.

    Returns:
        Snapshots of the stored environment whose names do not occur in
        ``new_environment``; an empty list when no stored environment is found.
    """
    existing = state_sync.get_environment(new_environment.name)
    if not existing:
        return []

    kept_names = {s.name for s in new_environment.snapshots}
    return [s for s in existing.snapshots if s.name not in kept_names]
def
create_plan_dag_spec( request: sqlmesh.schedulers.airflow.common.PlanApplicationRequest, state_sync: sqlmesh.core.state_sync.base.StateSync) -> sqlmesh.schedulers.airflow.common.PlanDagSpec:
def create_plan_dag_spec(
    request: common.PlanApplicationRequest, state_sync: StateSync
) -> common.PlanDagSpec:
    """Translate a plan application request into a plan DAG spec.

    Args:
        request: The plan application request to translate.
        state_sync: Used to load the stored snapshots and the current version
            of the target environment, and to remove intervals of restated
            snapshots.

    Returns:
        The spec carrying the request's settings together with the computed
        backfill intervals and the demoted snapshots.

    Raises:
        SQLMeshError: If any snapshot listed as new in the request is already
            present among the stored snapshots.
    """
    new_snapshots = {s.snapshot_id: s for s in request.new_snapshots}
    stored_snapshots = state_sync.get_snapshots(None)

    # A snapshot submitted as "new" must not be known to the state yet.
    duplicated_snapshots = set(stored_snapshots).intersection(new_snapshots)
    if duplicated_snapshots:
        raise SQLMeshError(
            f"Snapshots {duplicated_snapshots} already exist. "
            "Make sure your code base is up to date and try re-creating the plan"
        )

    environment = request.environment

    # With an unbounded end date, "now" is used as the end of the range and
    # all paused snapshots that are part of the target environment need to be
    # unpaused as of that moment.
    end = environment.end_at
    unpaused_dt = None
    if not end:
        end = now()
        unpaused_dt = end

    candidates = dict(new_snapshots)
    candidates.update(stored_snapshots)

    ids_by_version = defaultdict(set)
    for candidate in candidates.values():
        ids_by_version[(candidate.name, candidate.version)].add(candidate.snapshot_id)

    if request.is_dev:
        # In development mode, drop every other snapshot that shares a version
        # with a paused forward-only snapshot that is part of the plan.
        for promoted in environment.snapshots:
            if not promoted.is_forward_only:
                continue
            if not candidates[promoted.snapshot_id].is_paused:
                continue
            same_version_ids = ids_by_version[(promoted.name, promoted.version)] - {
                promoted.snapshot_id
            }
            for same_version_id in same_version_ids:
                del candidates[same_version_id]

    if request.restatements:
        # Only previously stored snapshots whose names are being restated are
        # handed over; the generator is consumed lazily by the callee.
        restated = (
            snapshot
            for snapshot in candidates.values()
            if snapshot.name in request.restatements
            and snapshot.snapshot_id not in new_snapshots
        )
        state_sync.remove_interval(
            [],
            start=environment.start_at,
            end=end,
            all_snapshots=restated,
        )

    if request.skip_backfill:
        backfill_batches = {}
    else:
        backfill_batches = scheduler.compute_interval_params(
            environment.snapshots,
            snapshots=candidates,
            start=environment.start_at,
            end=end,
            latest=end,
        )

    backfill_intervals_per_snapshot = []
    for snapshot, intervals in backfill_batches.items():
        backfill_intervals_per_snapshot.append(
            common.BackfillIntervalsPerSnapshot(
                snapshot_id=snapshot.snapshot_id,
                intervals=intervals,
            )
        )

    demoted_snapshots = _get_demoted_snapshots(environment, state_sync)

    # Note: the spec receives the requested (possibly unset) end date, not the
    # resolved ``end`` used for the interval computations above.
    return common.PlanDagSpec(
        request_id=request.request_id,
        environment_name=environment.name,
        new_snapshots=request.new_snapshots,
        backfill_intervals_per_snapshot=backfill_intervals_per_snapshot,
        promoted_snapshots=environment.snapshots,
        demoted_snapshots=demoted_snapshots,
        start=environment.start_at,
        end=environment.end_at,
        unpaused_dt=unpaused_dt,
        no_gaps=request.no_gaps,
        plan_id=environment.plan_id,
        previous_plan_id=environment.previous_plan_id,
        notification_targets=request.notification_targets,
        backfill_concurrent_tasks=request.backfill_concurrent_tasks,
        ddl_concurrent_tasks=request.ddl_concurrent_tasks,
        users=request.users,
        is_dev=request.is_dev,
        environment_expiration_ts=environment.expiration_ts,
    )