sqlmesh.schedulers.airflow.dag_generator
from __future__ import annotations

import logging
import os
import typing as t

from airflow import DAG
from airflow.models import BaseOperator, baseoperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator

from sqlmesh.core._typing import NotificationTarget
from sqlmesh.core.environment import Environment
from sqlmesh.core.plan import PlanStatus
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo
from sqlmesh.integrations.github.notification_operator_provider import (
    GithubNotificationOperatorProvider,
)
from sqlmesh.integrations.github.notification_target import GithubNotificationTarget
from sqlmesh.schedulers.airflow import common, util
from sqlmesh.schedulers.airflow.operators import targets
from sqlmesh.schedulers.airflow.operators.hwm_sensor import HighWaterMarkSensor
from sqlmesh.schedulers.airflow.operators.notification import (
    BaseNotificationOperatorProvider,
)
from sqlmesh.utils.date import TimeLike, now, to_datetime
from sqlmesh.utils.errors import SQLMeshError

logger = logging.getLogger(__name__)


TASK_ID_DATE_FORMAT = "%Y-%m-%d_%H-%M-%S"

NOTIFICATION_TARGET_TO_OPERATOR_PROVIDER: t.Dict[
    t.Type[NotificationTarget], BaseNotificationOperatorProvider
] = {
    GithubNotificationTarget: GithubNotificationOperatorProvider(),
}

DAG_DEFAULT_ARGS = {
    # Support for `AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY` was added in Airflow 2.4.0.
    # We can't rely on that variable directly because Cloud Composer doesn't allow
    # setting this config option from an environment variable.
    "retry_delay": int(
        os.getenv(
            "SQLMESH_AIRFLOW_DEFAULT_TASK_RETRY_DELAY",
            os.getenv("AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY", "300"),
        )
    ),
}


class SnapshotDagGenerator:
    def __init__(
        self,
        engine_operator: t.Type[BaseOperator],
        engine_operator_args: t.Optional[t.Dict[str, t.Any]],
        ddl_engine_operator: t.Type[BaseOperator],
        ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]],
        snapshots: t.Dict[SnapshotId, Snapshot],
    ):
        self._engine_operator = engine_operator
        self._engine_operator_args = engine_operator_args or {}
        self._ddl_engine_operator = ddl_engine_operator
        self._ddl_engine_operator_args = ddl_engine_operator_args or {}
        self._snapshots = snapshots

    def generate_cadence_dags(self) -> t.List[DAG]:
        return [
            self._create_cadence_dag_for_snapshot(s)
            for s in self._snapshots.values()
            if s.unpaused_ts and not s.is_embedded_kind and not s.is_seed_kind
        ]

    def generate_plan_application_dag(self, spec: common.PlanDagSpec) -> DAG:
        return self._create_plan_application_dag(spec)

    def _create_cadence_dag_for_snapshot(self, snapshot: Snapshot) -> DAG:
        dag_id = common.dag_id_for_snapshot_info(snapshot.table_info)
        logger.info(
            "Generating the cadence DAG '%s' for snapshot %s",
            dag_id,
            snapshot.snapshot_id,
        )

        if not snapshot.unpaused_ts:
            raise SQLMeshError(
                f"Can't create a cadence DAG for the paused snapshot {snapshot.snapshot_id}"
            )

        with DAG(
            dag_id=dag_id,
            schedule_interval=snapshot.model.cron,
            start_date=to_datetime(snapshot.unpaused_ts),
            max_active_runs=1,
            catchup=True,
            is_paused_upon_creation=False,
            tags=[
                common.SQLMESH_AIRFLOW_TAG,
                common.SNAPSHOT_AIRFLOW_TAG,
                snapshot.name,
            ],
            default_args={
                **DAG_DEFAULT_ARGS,
                "email": snapshot.model.owner,
                "email_on_failure": True,
            },
        ) as dag:
            hwm_sensor_tasks = self._create_hwm_sensors(snapshot=snapshot)

            evaluator_task = self._create_snapshot_evaluator_operator(
                snapshots=self._snapshots,
                snapshot=snapshot,
                task_id="snapshot_evaluator",
            )

            hwm_sensor_tasks >> evaluator_task

            return dag

    def _create_plan_application_dag(self, plan_dag_spec: common.PlanDagSpec) -> DAG:
        dag_id = common.plan_application_dag_id(
            plan_dag_spec.environment_name, plan_dag_spec.request_id
        )
        logger.info(
            "Generating the plan application DAG '%s' for environment '%s'",
            dag_id,
            plan_dag_spec.environment_name,
        )

        all_snapshots = {
            **{s.snapshot_id: s for s in plan_dag_spec.new_snapshots},
            **self._snapshots,
        }

        with DAG(
            dag_id=dag_id,
            schedule_interval="@once",
            start_date=now(),
            max_active_tasks=plan_dag_spec.backfill_concurrent_tasks,
            catchup=False,
            is_paused_upon_creation=False,
            default_args=DAG_DEFAULT_ARGS,
            tags=[
                common.SQLMESH_AIRFLOW_TAG,
                common.PLAN_AIRFLOW_TAG,
                plan_dag_spec.environment_name,
            ],
        ) as dag:
            start_task = EmptyOperator(task_id="plan_application_start")
            end_task = EmptyOperator(task_id="plan_application_end")

            (create_start_task, create_end_task) = self._create_creation_tasks(
                plan_dag_spec.new_snapshots, plan_dag_spec.ddl_concurrent_tasks
            )

            (backfill_start_task, backfill_end_task) = self._create_backfill_tasks(
                plan_dag_spec.backfill_intervals_per_snapshot,
                all_snapshots,
                plan_dag_spec.is_dev,
            )

            (
                promote_start_task,
                promote_end_task,
            ) = self._create_promotion_demotion_tasks(plan_dag_spec)

            start_task >> create_start_task
            create_end_task >> backfill_start_task
            backfill_end_task >> promote_start_task

            self._add_notification_target_tasks(
                plan_dag_spec, start_task, end_task, promote_end_task
            )
            return dag

    def _add_notification_target_tasks(
        self,
        request: common.PlanDagSpec,
        start_task: BaseOperator,
        end_task: BaseOperator,
        promote_end_task: BaseOperator,
    ) -> None:
        has_success_or_failed_notification = False
        for notification_target in request.notification_targets:
            notification_operator_provider = NOTIFICATION_TARGET_TO_OPERATOR_PROVIDER.get(
                type(notification_target)
            )
            if not notification_operator_provider:
                continue
            plan_start_notification_task = notification_operator_provider.operator(
                notification_target, PlanStatus.STARTED, request
            )
            plan_success_notification_task = notification_operator_provider.operator(
                notification_target, PlanStatus.FINISHED, request
            )
            plan_failed_notification_task = notification_operator_provider.operator(
                notification_target, PlanStatus.FAILED, request
            )
            if plan_start_notification_task:
                start_task >> plan_start_notification_task
            if plan_success_notification_task:
                has_success_or_failed_notification = True
                promote_end_task >> plan_success_notification_task
                plan_success_notification_task >> end_task
            if plan_failed_notification_task:
                has_success_or_failed_notification = True
                promote_end_task >> plan_failed_notification_task
                plan_failed_notification_task >> end_task
        if not has_success_or_failed_notification:
            promote_end_task >> end_task

    def _create_creation_tasks(
        self, new_snapshots: t.List[Snapshot], ddl_concurrent_tasks: int
    ) -> t.Tuple[BaseOperator, BaseOperator]:
        start_task = EmptyOperator(task_id="snapshot_creation_start")
        end_task = EmptyOperator(task_id="snapshot_creation_end")

        if not new_snapshots:
            start_task >> end_task
            return (start_task, end_task)

        creation_task = self._create_snapshot_create_tables_operator(
            new_snapshots, ddl_concurrent_tasks, "snapshot_creation__create_tables"
        )

        update_state_task = PythonOperator(
            task_id="snapshot_creation__update_state",
            python_callable=creation_update_state_task,
            op_kwargs={"new_snapshots": new_snapshots},
        )

        start_task >> creation_task
        creation_task >> update_state_task
        update_state_task >> end_task

        return (start_task, end_task)

    def _create_promotion_demotion_tasks(
        self, request: common.PlanDagSpec
    ) -> t.Tuple[BaseOperator, BaseOperator]:
        start_task = EmptyOperator(task_id="snapshot_promotion_start")
        end_task = EmptyOperator(task_id="snapshot_promotion_end")

        update_state_task = PythonOperator(
            task_id="snapshot_promotion__update_state",
            python_callable=promotion_update_state_task,
            op_kwargs={
                "snapshots": request.promoted_snapshots,
                "environment_name": request.environment_name,
                "start": request.start,
                "end": request.end,
                "unpaused_dt": request.unpaused_dt,
                "no_gaps": request.no_gaps,
                "plan_id": request.plan_id,
                "previous_plan_id": request.previous_plan_id,
                "environment_expiration_ts": request.environment_expiration_ts,
            },
        )

        start_task >> update_state_task

        if request.promoted_snapshots:
            create_views_task = self._create_snapshot_promotion_operator(
                request.promoted_snapshots,
                request.environment_name,
                request.ddl_concurrent_tasks,
                request.is_dev,
                "snapshot_promotion__create_views",
            )
            create_views_task >> end_task

            if not request.is_dev and request.unpaused_dt:
                migrate_tables_task = self._create_snapshot_migrate_tables_operator(
                    request.promoted_snapshots,
                    request.ddl_concurrent_tasks,
                    "snapshot_promotion__migrate_tables",
                )
                update_state_task >> migrate_tables_task
                migrate_tables_task >> create_views_task
            else:
                update_state_task >> create_views_task

        if request.demoted_snapshots:
            delete_views_task = self._create_snapshot_demotion_operator(
                request.demoted_snapshots,
                request.environment_name,
                request.ddl_concurrent_tasks,
                "snapshot_promotion__delete_views",
            )
            update_state_task >> delete_views_task
            delete_views_task >> end_task

        if not request.promoted_snapshots and not request.demoted_snapshots:
            update_state_task >> end_task

        return (start_task, end_task)

    def _create_backfill_tasks(
        self,
        backfill_intervals: t.List[common.BackfillIntervalsPerSnapshot],
        snapshots: t.Dict[SnapshotId, Snapshot],
        is_dev: bool,
    ) -> t.Tuple[BaseOperator, BaseOperator]:
        snapshot_to_tasks = {}
        for intervals_per_snapshot in backfill_intervals:
            sid = intervals_per_snapshot.snapshot_id

            if not intervals_per_snapshot.intervals:
                logger.info("Skipping backfill for snapshot %s", sid)
                continue

            snapshot = snapshots[sid]

            task_id_prefix = f"snapshot_evaluator__{snapshot.name}__{snapshot.identifier}"
            tasks = [
                self._create_snapshot_evaluator_operator(
                    snapshots=snapshots,
                    snapshot=snapshot,
                    task_id=f"{task_id_prefix}__{start.strftime(TASK_ID_DATE_FORMAT)}__{end.strftime(TASK_ID_DATE_FORMAT)}",
                    start=start,
                    end=end,
                    is_dev=is_dev,
                )
                for (start, end) in intervals_per_snapshot.intervals
            ]
            snapshot_start_task = EmptyOperator(
                task_id=f"snapshot_backfill__{snapshot.name}__{snapshot.identifier}__start"
            )
            snapshot_end_task = EmptyOperator(
                task_id=f"snapshot_backfill__{snapshot.name}__{snapshot.identifier}__end"
            )
            if snapshot.is_incremental_by_unique_key_kind:
                baseoperator.chain(snapshot_start_task, *tasks, snapshot_end_task)
            else:
                snapshot_start_task >> tasks >> snapshot_end_task
            snapshot_to_tasks[snapshot.snapshot_id] = (
                snapshot_start_task,
                snapshot_end_task,
            )

        backfill_start_task = EmptyOperator(task_id="snapshot_backfill_start")
        backfill_end_task = EmptyOperator(task_id="snapshot_backfill_end")

        if not snapshot_to_tasks:
            backfill_start_task >> backfill_end_task
            return (backfill_start_task, backfill_end_task)

        entry_tasks = []
        parent_ids_to_backfill = set()
        for sid, (start_task, _) in snapshot_to_tasks.items():
            has_parents_to_backfill = False
            for p_sid in snapshots[sid].parents:
                if p_sid in snapshot_to_tasks:
                    snapshot_to_tasks[p_sid][1] >> start_task
                    parent_ids_to_backfill.add(p_sid)
                    has_parents_to_backfill = True

            if not has_parents_to_backfill:
                entry_tasks.append(start_task)

        backfill_start_task >> entry_tasks

        exit_tasks = [
            end_task
            for sid, (_, end_task) in snapshot_to_tasks.items()
            if sid not in parent_ids_to_backfill
        ]
        for task in exit_tasks:
            task >> backfill_end_task

        return (backfill_start_task, backfill_end_task)

    def _create_snapshot_promotion_operator(
        self,
        snapshots: t.List[SnapshotTableInfo],
        environment: str,
        ddl_concurrent_tasks: int,
        is_dev: bool,
        task_id: str,
    ) -> BaseOperator:
        return self._ddl_engine_operator(
            **self._ddl_engine_operator_args,
            target=targets.SnapshotPromotionTarget(
                snapshots=snapshots,
                environment=environment,
                ddl_concurrent_tasks=ddl_concurrent_tasks,
                is_dev=is_dev,
            ),
            task_id=task_id,
        )

    def _create_snapshot_demotion_operator(
        self,
        snapshots: t.List[SnapshotTableInfo],
        environment: str,
        ddl_concurrent_tasks: int,
        task_id: str,
    ) -> BaseOperator:
        return self._ddl_engine_operator(
            **self._ddl_engine_operator_args,
            target=targets.SnapshotDemotionTarget(
                snapshots=snapshots,
                environment=environment,
                ddl_concurrent_tasks=ddl_concurrent_tasks,
            ),
            task_id=task_id,
        )

    def _create_snapshot_create_tables_operator(
        self,
        new_snapshots: t.List[Snapshot],
        ddl_concurrent_tasks: int,
        task_id: str,
    ) -> BaseOperator:
        return self._ddl_engine_operator(
            **self._ddl_engine_operator_args,
            target=targets.SnapshotCreateTablesTarget(
                new_snapshots=new_snapshots, ddl_concurrent_tasks=ddl_concurrent_tasks
            ),
            task_id=task_id,
        )

    def _create_snapshot_migrate_tables_operator(
        self,
        snapshots: t.List[SnapshotTableInfo],
        ddl_concurrent_tasks: int,
        task_id: str,
    ) -> BaseOperator:
        return self._ddl_engine_operator(
            **self._ddl_engine_operator_args,
            target=targets.SnapshotMigrateTablesTarget(
                snapshots=snapshots, ddl_concurrent_tasks=ddl_concurrent_tasks
            ),
            task_id=task_id,
        )

    def _create_snapshot_evaluator_operator(
        self,
        snapshots: t.Dict[SnapshotId, Snapshot],
        snapshot: Snapshot,
        task_id: str,
        start: t.Optional[TimeLike] = None,
        end: t.Optional[TimeLike] = None,
        is_dev: bool = False,
    ) -> BaseOperator:
        parent_snapshots = {sid.name: snapshots[sid] for sid in snapshot.parents}

        return self._engine_operator(
            **self._engine_operator_args,
            target=targets.SnapshotEvaluationTarget(
                snapshot=snapshot,
                parent_snapshots=parent_snapshots,
                start=start,
                end=end,
                is_dev=is_dev,
            ),
            task_id=task_id,
        )

    def _create_hwm_sensors(self, snapshot: Snapshot) -> t.List[HighWaterMarkSensor]:
        output = []
        for upstream_snapshot_id in snapshot.parents:
            upstream_snapshot = self._snapshots[upstream_snapshot_id]
            if not upstream_snapshot.is_embedded_kind and not upstream_snapshot.is_seed_kind:
                output.append(
                    HighWaterMarkSensor(
                        target_snapshot_info=upstream_snapshot.table_info,
                        this_snapshot=snapshot,
                        task_id=f"{upstream_snapshot.name}_{upstream_snapshot.version}_high_water_mark_sensor",
                    )
                )
        return output


def creation_update_state_task(new_snapshots: t.Iterable[Snapshot]) -> None:
    with util.scoped_state_sync() as state_sync:
        state_sync.push_snapshots(new_snapshots)


def promotion_update_state_task(
    snapshots: t.List[SnapshotTableInfo],
    environment_name: str,
    start: TimeLike,
    end: t.Optional[TimeLike],
    unpaused_dt: t.Optional[TimeLike],
    no_gaps: bool,
    plan_id: str,
    previous_plan_id: t.Optional[str],
    environment_expiration_ts: t.Optional[int],
) -> None:
    environment = Environment(
        name=environment_name,
        snapshots=snapshots,
        start_at=start,
        end_at=end,
        plan_id=plan_id,
        previous_plan_id=previous_plan_id,
        expiration_ts=environment_expiration_ts,
    )
    with util.scoped_state_sync() as state_sync:
        state_sync.promote(environment, no_gaps=no_gaps)
        if snapshots and not end and unpaused_dt:
            state_sync.unpause_snapshots(snapshots, unpaused_dt)
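The generator is typically driven from a module in the Airflow `dags/` folder, since Airflow discovers DAGs by scanning module-level globals. A minimal sketch of that wiring, assuming `MyEngineOperator` is a `BaseOperator` subclass that accepts the `target` keyword the generator passes to it, and `load_snapshots` is a hypothetical helper that reads the snapshot catalog from the SQLMesh state store:

from sqlmesh.schedulers.airflow.dag_generator import SnapshotDagGenerator

snapshots = load_snapshots()  # hypothetical: returns Dict[SnapshotId, Snapshot]

generator = SnapshotDagGenerator(
    engine_operator=MyEngineOperator,  # assumption: accepts the `target` kwarg
    engine_operator_args={"conn_id": "warehouse"},  # illustrative shared operator args
    ddl_engine_operator=MyEngineOperator,
    ddl_engine_operator_args=None,  # defaults to {} in __init__
    snapshots=snapshots,
)

# Expose each generated cadence DAG under its own dag_id so the Airflow
# scheduler picks it up.
for dag in generator.generate_cadence_dags():
    globals()[dag.dag_id] = dag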
class SnapshotDagGenerator:
SnapshotDagGenerator(
    engine_operator: Type[airflow.models.baseoperator.BaseOperator],
    engine_operator_args: Optional[Dict[str, Any]],
    ddl_engine_operator: Type[airflow.models.baseoperator.BaseOperator],
    ddl_engine_operator_args: Optional[Dict[str, Any]],
    snapshots: Dict[sqlmesh.core.snapshot.definition.SnapshotId, sqlmesh.core.snapshot.definition.Snapshot],
)
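The call sites in the source instantiate both operator classes with the user-supplied args plus `target` and `task_id` keywords, so any `BaseOperator` subclass honoring that contract can be plugged in. A minimal sketch of a conforming operator; the class name and the engine call inside `execute` are assumptions, not part of this module:

import typing as t

from airflow.models import BaseOperator
from airflow.utils.context import Context


class MyEngineOperator(BaseOperator):
    # Hypothetical operator matching the constructor contract used by
    # SnapshotDagGenerator: it must accept `target` plus standard kwargs.
    def __init__(self, target: t.Any, **kwargs: t.Any) -> None:
        super().__init__(**kwargs)  # consumes task_id and shared operator args
        self._target = target

    def execute(self, context: Context) -> None:
        # Assumption: hand the SQLMesh target off to the engine for execution.
        run_target_on_engine(self._target)  # hypothetical helper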
def generate_plan_application_dag(self, spec: sqlmesh.schedulers.airflow.common.PlanDagSpec) -> airflow.models.dag.DAG:
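Unlike the recurring cadence DAGs, a plan application DAG is built for a single plan request and scheduled `@once`. A brief sketch, assuming `generator` is the instance constructed earlier and `plan_dag_spec` was obtained from SQLMesh's plan submission flow:

plan_dag = generator.generate_plan_application_dag(plan_dag_spec)
globals()[plan_dag.dag_id] = plan_dag  # make it visible to the Airflow scheduler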
def creation_update_state_task(new_snapshots: Iterable[sqlmesh.core.snapshot.definition.Snapshot]) -> None:
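This callable is meant to run inside a `PythonOperator`; the snippet below mirrors how `_create_creation_tasks` wires it in the source above (assuming `new_snapshots` is in scope):

update_state_task = PythonOperator(
    task_id="snapshot_creation__update_state",
    python_callable=creation_update_state_task,
    op_kwargs={"new_snapshots": new_snapshots},
)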
def promotion_update_state_task(
    snapshots: List[sqlmesh.core.snapshot.definition.SnapshotTableInfo],
    environment_name: str,
    start: Union[datetime.date, datetime.datetime, str, int, float],
    end: Optional[Union[datetime.date, datetime.datetime, str, int, float]],
    unpaused_dt: Optional[Union[datetime.date, datetime.datetime, str, int, float]],
    no_gaps: bool,
    plan_id: str,
    previous_plan_id: Optional[str],
    environment_expiration_ts: Optional[int],
) -> None:
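The module runs this callable through a `PythonOperator` in `_create_promotion_demotion_tasks`; calling it directly illustrates the promotion semantics, in particular that promoting without an explicit `end` while supplying `unpaused_dt` also unpauses the promoted snapshots. The argument values below are illustrative only:

promotion_update_state_task(
    snapshots=promoted_snapshots,  # List[SnapshotTableInfo], assumed in scope
    environment_name="prod",
    start="2023-01-01",
    end=None,  # no explicit end, so the branch in the source unpauses the snapshots
    unpaused_dt=now(),
    no_gaps=True,
    plan_id="plan_123",
    previous_plan_id=None,
    environment_expiration_ts=None,
)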