Edit on GitHub

sqlmesh.schedulers.airflow.dag_generator

  1from __future__ import annotations
  2
  3import logging
  4import os
  5import typing as t
  6
  7from airflow import DAG
  8from airflow.models import BaseOperator, baseoperator
  9from airflow.operators.empty import EmptyOperator
 10from airflow.operators.python import PythonOperator
 11
 12from sqlmesh.core._typing import NotificationTarget
 13from sqlmesh.core.environment import Environment
 14from sqlmesh.core.plan import PlanStatus
 15from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo
 16from sqlmesh.integrations.github.notification_operator_provider import (
 17    GithubNotificationOperatorProvider,
 18)
 19from sqlmesh.integrations.github.notification_target import GithubNotificationTarget
 20from sqlmesh.schedulers.airflow import common, util
 21from sqlmesh.schedulers.airflow.operators import targets
 22from sqlmesh.schedulers.airflow.operators.hwm_sensor import HighWaterMarkSensor
 23from sqlmesh.schedulers.airflow.operators.notification import (
 24    BaseNotificationOperatorProvider,
 25)
 26from sqlmesh.utils.date import TimeLike, now, to_datetime
 27from sqlmesh.utils.errors import SQLMeshError
 28
logger = logging.getLogger(__name__)


# Timestamp format embedded in backfill task IDs (Airflow task IDs cannot
# contain characters such as ':' that appear in ISO timestamps).
TASK_ID_DATE_FORMAT = "%Y-%m-%d_%H-%M-%S"

# Maps a notification target type to the provider which knows how to build
# the corresponding Airflow notification operator. Targets whose type is not
# present here are silently skipped during plan application DAG generation.
NOTIFICATION_TARGET_TO_OPERATOR_PROVIDER: t.Dict[
    t.Type[NotificationTarget], BaseNotificationOperatorProvider
] = {
    GithubNotificationTarget: GithubNotificationOperatorProvider(),
}

# Default arguments applied to every generated DAG.
DAG_DEFAULT_ARGS = {
    # `AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY` support added in 2.4.0
    # We can't use `AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY` because cloud composer doesn't allow you to set config
    # from an environment variable
    # NOTE(review): the value is passed as an int number of seconds — presumably
    # Airflow coerces this into a timedelta for `retry_delay`; confirm against
    # the targeted Airflow version.
    "retry_delay": int(
        os.getenv(
            "SQLMESH_AIRFLOW_DEFAULT_TASK_RETRY_DELAY",
            os.getenv("AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY", "300"),
        )
    ),
}
 51
 52
 53class SnapshotDagGenerator:
 54    def __init__(
 55        self,
 56        engine_operator: t.Type[BaseOperator],
 57        engine_operator_args: t.Optional[t.Dict[str, t.Any]],
 58        ddl_engine_operator: t.Type[BaseOperator],
 59        ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]],
 60        snapshots: t.Dict[SnapshotId, Snapshot],
 61    ):
 62        self._engine_operator = engine_operator
 63        self._engine_operator_args = engine_operator_args or {}
 64        self._ddl_engine_operator = ddl_engine_operator
 65        self._ddl_engine_operator_args = ddl_engine_operator_args or {}
 66        self._snapshots = snapshots
 67
 68    def generate_cadence_dags(self) -> t.List[DAG]:
 69        return [
 70            self._create_cadence_dag_for_snapshot(s)
 71            for s in self._snapshots.values()
 72            if s.unpaused_ts and not s.is_embedded_kind and not s.is_seed_kind
 73        ]
 74
 75    def generate_plan_application_dag(self, spec: common.PlanDagSpec) -> DAG:
 76        return self._create_plan_application_dag(spec)
 77
 78    def _create_cadence_dag_for_snapshot(self, snapshot: Snapshot) -> DAG:
 79        dag_id = common.dag_id_for_snapshot_info(snapshot.table_info)
 80        logger.info(
 81            "Generating the cadence DAG '%s' for snapshot %s",
 82            dag_id,
 83            snapshot.snapshot_id,
 84        )
 85
 86        if not snapshot.unpaused_ts:
 87            raise SQLMeshError(
 88                f"Can't create a cadence DAG for the paused snapshot {snapshot.snapshot_id}"
 89            )
 90
 91        with DAG(
 92            dag_id=dag_id,
 93            schedule_interval=snapshot.model.cron,
 94            start_date=to_datetime(snapshot.unpaused_ts),
 95            max_active_runs=1,
 96            catchup=True,
 97            is_paused_upon_creation=False,
 98            tags=[
 99                common.SQLMESH_AIRFLOW_TAG,
100                common.SNAPSHOT_AIRFLOW_TAG,
101                snapshot.name,
102            ],
103            default_args={
104                **DAG_DEFAULT_ARGS,
105                "email": snapshot.model.owner,
106                "email_on_failure": True,
107            },
108        ) as dag:
109            hwm_sensor_tasks = self._create_hwm_sensors(snapshot=snapshot)
110
111            evaluator_task = self._create_snapshot_evaluator_operator(
112                snapshots=self._snapshots,
113                snapshot=snapshot,
114                task_id="snapshot_evaluator",
115            )
116
117            hwm_sensor_tasks >> evaluator_task
118
119            return dag
120
121    def _create_plan_application_dag(self, plan_dag_spec: common.PlanDagSpec) -> DAG:
122        dag_id = common.plan_application_dag_id(
123            plan_dag_spec.environment_name, plan_dag_spec.request_id
124        )
125        logger.info(
126            "Generating the plan application DAG '%s' for environment '%s'",
127            dag_id,
128            plan_dag_spec.environment_name,
129        )
130
131        all_snapshots = {
132            **{s.snapshot_id: s for s in plan_dag_spec.new_snapshots},
133            **self._snapshots,
134        }
135
136        with DAG(
137            dag_id=dag_id,
138            schedule_interval="@once",
139            start_date=now(),
140            max_active_tasks=plan_dag_spec.backfill_concurrent_tasks,
141            catchup=False,
142            is_paused_upon_creation=False,
143            default_args=DAG_DEFAULT_ARGS,
144            tags=[
145                common.SQLMESH_AIRFLOW_TAG,
146                common.PLAN_AIRFLOW_TAG,
147                plan_dag_spec.environment_name,
148            ],
149        ) as dag:
150            start_task = EmptyOperator(task_id="plan_application_start")
151            end_task = EmptyOperator(task_id="plan_application_end")
152
153            (create_start_task, create_end_task) = self._create_creation_tasks(
154                plan_dag_spec.new_snapshots, plan_dag_spec.ddl_concurrent_tasks
155            )
156
157            (backfill_start_task, backfill_end_task) = self._create_backfill_tasks(
158                plan_dag_spec.backfill_intervals_per_snapshot,
159                all_snapshots,
160                plan_dag_spec.is_dev,
161            )
162
163            (
164                promote_start_task,
165                promote_end_task,
166            ) = self._create_promotion_demotion_tasks(plan_dag_spec)
167
168            start_task >> create_start_task
169            create_end_task >> backfill_start_task
170            backfill_end_task >> promote_start_task
171
172            self._add_notification_target_tasks(
173                plan_dag_spec, start_task, end_task, promote_end_task
174            )
175            return dag
176
177    def _add_notification_target_tasks(
178        self,
179        request: common.PlanDagSpec,
180        start_task: BaseOperator,
181        end_task: BaseOperator,
182        promote_end_task: BaseOperator,
183    ) -> None:
184        has_success_or_failed_notification = False
185        for notification_target in request.notification_targets:
186            notification_operator_provider = NOTIFICATION_TARGET_TO_OPERATOR_PROVIDER.get(
187                type(notification_target)
188            )
189            if not notification_operator_provider:
190                continue
191            plan_start_notification_task = notification_operator_provider.operator(
192                notification_target, PlanStatus.STARTED, request
193            )
194            plan_success_notification_task = notification_operator_provider.operator(
195                notification_target, PlanStatus.FINISHED, request
196            )
197            plan_failed_notification_task = notification_operator_provider.operator(
198                notification_target, PlanStatus.FAILED, request
199            )
200            if plan_start_notification_task:
201                start_task >> plan_start_notification_task
202            if plan_success_notification_task:
203                has_success_or_failed_notification = True
204                promote_end_task >> plan_success_notification_task
205                plan_success_notification_task >> end_task
206            if plan_failed_notification_task:
207                has_success_or_failed_notification = True
208                promote_end_task >> plan_failed_notification_task
209                plan_failed_notification_task >> end_task
210        if not has_success_or_failed_notification:
211            promote_end_task >> end_task
212
213    def _create_creation_tasks(
214        self, new_snapshots: t.List[Snapshot], ddl_concurrent_tasks: int
215    ) -> t.Tuple[BaseOperator, BaseOperator]:
216        start_task = EmptyOperator(task_id="snapshot_creation_start")
217        end_task = EmptyOperator(task_id="snapshot_creation_end")
218
219        if not new_snapshots:
220            start_task >> end_task
221            return (start_task, end_task)
222
223        creation_task = self._create_snapshot_create_tables_operator(
224            new_snapshots, ddl_concurrent_tasks, "snapshot_creation__create_tables"
225        )
226
227        update_state_task = PythonOperator(
228            task_id="snapshot_creation__update_state",
229            python_callable=creation_update_state_task,
230            op_kwargs={"new_snapshots": new_snapshots},
231        )
232
233        start_task >> creation_task
234        creation_task >> update_state_task
235        update_state_task >> end_task
236
237        return (start_task, end_task)
238
239    def _create_promotion_demotion_tasks(
240        self, request: common.PlanDagSpec
241    ) -> t.Tuple[BaseOperator, BaseOperator]:
242        start_task = EmptyOperator(task_id="snapshot_promotion_start")
243        end_task = EmptyOperator(task_id="snapshot_promotion_end")
244
245        update_state_task = PythonOperator(
246            task_id="snapshot_promotion__update_state",
247            python_callable=promotion_update_state_task,
248            op_kwargs={
249                "snapshots": request.promoted_snapshots,
250                "environment_name": request.environment_name,
251                "start": request.start,
252                "end": request.end,
253                "unpaused_dt": request.unpaused_dt,
254                "no_gaps": request.no_gaps,
255                "plan_id": request.plan_id,
256                "previous_plan_id": request.previous_plan_id,
257                "environment_expiration_ts": request.environment_expiration_ts,
258            },
259        )
260
261        start_task >> update_state_task
262
263        if request.promoted_snapshots:
264            create_views_task = self._create_snapshot_promotion_operator(
265                request.promoted_snapshots,
266                request.environment_name,
267                request.ddl_concurrent_tasks,
268                request.is_dev,
269                "snapshot_promotion__create_views",
270            )
271            create_views_task >> end_task
272
273            if not request.is_dev and request.unpaused_dt:
274                migrate_tables_task = self._create_snapshot_migrate_tables_operator(
275                    request.promoted_snapshots,
276                    request.ddl_concurrent_tasks,
277                    "snapshot_promotion__migrate_tables",
278                )
279                update_state_task >> migrate_tables_task
280                migrate_tables_task >> create_views_task
281            else:
282                update_state_task >> create_views_task
283
284        if request.demoted_snapshots:
285            delete_views_task = self._create_snapshot_demotion_operator(
286                request.demoted_snapshots,
287                request.environment_name,
288                request.ddl_concurrent_tasks,
289                "snapshot_promotion__delete_views",
290            )
291            update_state_task >> delete_views_task
292            delete_views_task >> end_task
293
294        if not request.promoted_snapshots and not request.demoted_snapshots:
295            update_state_task >> end_task
296
297        return (start_task, end_task)
298
299    def _create_backfill_tasks(
300        self,
301        backfill_intervals: t.List[common.BackfillIntervalsPerSnapshot],
302        snapshots: t.Dict[SnapshotId, Snapshot],
303        is_dev: bool,
304    ) -> t.Tuple[BaseOperator, BaseOperator]:
305        snapshot_to_tasks = {}
306        for intervals_per_snapshot in backfill_intervals:
307            sid = intervals_per_snapshot.snapshot_id
308
309            if not intervals_per_snapshot.intervals:
310                logger.info(f"Skipping backfill for snapshot %s", sid)
311                continue
312
313            snapshot = snapshots[sid]
314
315            task_id_prefix = f"snapshot_evaluator__{snapshot.name}__{snapshot.identifier}"
316            tasks = [
317                self._create_snapshot_evaluator_operator(
318                    snapshots=snapshots,
319                    snapshot=snapshot,
320                    task_id=f"{task_id_prefix}__{start.strftime(TASK_ID_DATE_FORMAT)}__{end.strftime(TASK_ID_DATE_FORMAT)}",
321                    start=start,
322                    end=end,
323                    is_dev=is_dev,
324                )
325                for (start, end) in intervals_per_snapshot.intervals
326            ]
327            snapshot_start_task = EmptyOperator(
328                task_id=f"snapshot_backfill__{snapshot.name}__{snapshot.identifier}__start"
329            )
330            snapshot_end_task = EmptyOperator(
331                task_id=f"snapshot_backfill__{snapshot.name}__{snapshot.identifier}__end"
332            )
333            if snapshot.is_incremental_by_unique_key_kind:
334                baseoperator.chain(snapshot_start_task, *tasks, snapshot_end_task)
335            else:
336                snapshot_start_task >> tasks >> snapshot_end_task
337            snapshot_to_tasks[snapshot.snapshot_id] = (
338                snapshot_start_task,
339                snapshot_end_task,
340            )
341
342        backfill_start_task = EmptyOperator(task_id="snapshot_backfill_start")
343        backfill_end_task = EmptyOperator(task_id="snapshot_backfill_end")
344
345        if not snapshot_to_tasks:
346            backfill_start_task >> backfill_end_task
347            return (backfill_start_task, backfill_end_task)
348
349        entry_tasks = []
350        parent_ids_to_backfill = set()
351        for sid, (start_task, _) in snapshot_to_tasks.items():
352            has_parents_to_backfill = False
353            for p_sid in snapshots[sid].parents:
354                if p_sid in snapshot_to_tasks:
355                    snapshot_to_tasks[p_sid][1] >> start_task
356                    parent_ids_to_backfill.add(p_sid)
357                    has_parents_to_backfill = True
358
359            if not has_parents_to_backfill:
360                entry_tasks.append(start_task)
361
362        backfill_start_task >> entry_tasks
363
364        exit_tasks = [
365            end_task
366            for sid, (_, end_task) in snapshot_to_tasks.items()
367            if sid not in parent_ids_to_backfill
368        ]
369        for task in exit_tasks:
370            task >> backfill_end_task
371
372        return (backfill_start_task, backfill_end_task)
373
374    def _create_snapshot_promotion_operator(
375        self,
376        snapshots: t.List[SnapshotTableInfo],
377        environment: str,
378        ddl_concurrent_tasks: int,
379        is_dev: bool,
380        task_id: str,
381    ) -> BaseOperator:
382        return self._ddl_engine_operator(
383            **self._ddl_engine_operator_args,
384            target=targets.SnapshotPromotionTarget(
385                snapshots=snapshots,
386                environment=environment,
387                ddl_concurrent_tasks=ddl_concurrent_tasks,
388                is_dev=is_dev,
389            ),
390            task_id=task_id,
391        )
392
393    def _create_snapshot_demotion_operator(
394        self,
395        snapshots: t.List[SnapshotTableInfo],
396        environment: str,
397        ddl_concurrent_tasks: int,
398        task_id: str,
399    ) -> BaseOperator:
400        return self._ddl_engine_operator(
401            **self._ddl_engine_operator_args,
402            target=targets.SnapshotDemotionTarget(
403                snapshots=snapshots,
404                environment=environment,
405                ddl_concurrent_tasks=ddl_concurrent_tasks,
406            ),
407            task_id=task_id,
408        )
409
410    def _create_snapshot_create_tables_operator(
411        self,
412        new_snapshots: t.List[Snapshot],
413        ddl_concurrent_tasks: int,
414        task_id: str,
415    ) -> BaseOperator:
416        return self._ddl_engine_operator(
417            **self._ddl_engine_operator_args,
418            target=targets.SnapshotCreateTablesTarget(
419                new_snapshots=new_snapshots, ddl_concurrent_tasks=ddl_concurrent_tasks
420            ),
421            task_id=task_id,
422        )
423
424    def _create_snapshot_migrate_tables_operator(
425        self,
426        snapshots: t.List[SnapshotTableInfo],
427        ddl_concurrent_tasks: int,
428        task_id: str,
429    ) -> BaseOperator:
430        return self._ddl_engine_operator(
431            **self._ddl_engine_operator_args,
432            target=targets.SnapshotMigrateTablesTarget(
433                snapshots=snapshots, ddl_concurrent_tasks=ddl_concurrent_tasks
434            ),
435            task_id=task_id,
436        )
437
438    def _create_snapshot_evaluator_operator(
439        self,
440        snapshots: t.Dict[SnapshotId, Snapshot],
441        snapshot: Snapshot,
442        task_id: str,
443        start: t.Optional[TimeLike] = None,
444        end: t.Optional[TimeLike] = None,
445        is_dev: bool = False,
446    ) -> BaseOperator:
447        parent_snapshots = {sid.name: snapshots[sid] for sid in snapshot.parents}
448
449        return self._engine_operator(
450            **self._engine_operator_args,
451            target=targets.SnapshotEvaluationTarget(
452                snapshot=snapshot,
453                parent_snapshots=parent_snapshots,
454                start=start,
455                end=end,
456                is_dev=is_dev,
457            ),
458            task_id=task_id,
459        )
460
461    def _create_hwm_sensors(self, snapshot: Snapshot) -> t.List[HighWaterMarkSensor]:
462        output = []
463        for upstream_snapshot_id in snapshot.parents:
464            upstream_snapshot = self._snapshots[upstream_snapshot_id]
465            if not upstream_snapshot.is_embedded_kind and not upstream_snapshot.is_seed_kind:
466                output.append(
467                    HighWaterMarkSensor(
468                        target_snapshot_info=upstream_snapshot.table_info,
469                        this_snapshot=snapshot,
470                        task_id=f"{upstream_snapshot.name}_{upstream_snapshot.version}_high_water_mark_sensor",
471                    )
472                )
473        return output
474
475
def creation_update_state_task(new_snapshots: t.Iterable[Snapshot]) -> None:
    """Persists the newly created snapshots in the SQLMesh state store."""
    with util.scoped_state_sync() as sync:
        sync.push_snapshots(new_snapshots)
479
480
def promotion_update_state_task(
    snapshots: t.List[SnapshotTableInfo],
    environment_name: str,
    start: TimeLike,
    end: t.Optional[TimeLike],
    unpaused_dt: t.Optional[TimeLike],
    no_gaps: bool,
    plan_id: str,
    previous_plan_id: t.Optional[str],
    environment_expiration_ts: t.Optional[int],
) -> None:
    """Promotes the given snapshots into the target environment's state.

    Builds the environment record, promotes it through the state sync and,
    for prod-style plans (no explicit end) with an unpause timestamp,
    unpauses the promoted snapshots.
    """
    target_env = Environment(
        name=environment_name,
        snapshots=snapshots,
        start_at=start,
        end_at=end,
        plan_id=plan_id,
        previous_plan_id=previous_plan_id,
        expiration_ts=environment_expiration_ts,
    )
    with util.scoped_state_sync() as sync:
        sync.promote(target_env, no_gaps=no_gaps)
        if snapshots and not end and unpaused_dt:
            sync.unpause_snapshots(snapshots, unpaused_dt)
class SnapshotDagGenerator:
 54class SnapshotDagGenerator:
 55    def __init__(
 56        self,
 57        engine_operator: t.Type[BaseOperator],
 58        engine_operator_args: t.Optional[t.Dict[str, t.Any]],
 59        ddl_engine_operator: t.Type[BaseOperator],
 60        ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]],
 61        snapshots: t.Dict[SnapshotId, Snapshot],
 62    ):
 63        self._engine_operator = engine_operator
 64        self._engine_operator_args = engine_operator_args or {}
 65        self._ddl_engine_operator = ddl_engine_operator
 66        self._ddl_engine_operator_args = ddl_engine_operator_args or {}
 67        self._snapshots = snapshots
 68
 69    def generate_cadence_dags(self) -> t.List[DAG]:
 70        return [
 71            self._create_cadence_dag_for_snapshot(s)
 72            for s in self._snapshots.values()
 73            if s.unpaused_ts and not s.is_embedded_kind and not s.is_seed_kind
 74        ]
 75
 76    def generate_plan_application_dag(self, spec: common.PlanDagSpec) -> DAG:
 77        return self._create_plan_application_dag(spec)
 78
 79    def _create_cadence_dag_for_snapshot(self, snapshot: Snapshot) -> DAG:
 80        dag_id = common.dag_id_for_snapshot_info(snapshot.table_info)
 81        logger.info(
 82            "Generating the cadence DAG '%s' for snapshot %s",
 83            dag_id,
 84            snapshot.snapshot_id,
 85        )
 86
 87        if not snapshot.unpaused_ts:
 88            raise SQLMeshError(
 89                f"Can't create a cadence DAG for the paused snapshot {snapshot.snapshot_id}"
 90            )
 91
 92        with DAG(
 93            dag_id=dag_id,
 94            schedule_interval=snapshot.model.cron,
 95            start_date=to_datetime(snapshot.unpaused_ts),
 96            max_active_runs=1,
 97            catchup=True,
 98            is_paused_upon_creation=False,
 99            tags=[
100                common.SQLMESH_AIRFLOW_TAG,
101                common.SNAPSHOT_AIRFLOW_TAG,
102                snapshot.name,
103            ],
104            default_args={
105                **DAG_DEFAULT_ARGS,
106                "email": snapshot.model.owner,
107                "email_on_failure": True,
108            },
109        ) as dag:
110            hwm_sensor_tasks = self._create_hwm_sensors(snapshot=snapshot)
111
112            evaluator_task = self._create_snapshot_evaluator_operator(
113                snapshots=self._snapshots,
114                snapshot=snapshot,
115                task_id="snapshot_evaluator",
116            )
117
118            hwm_sensor_tasks >> evaluator_task
119
120            return dag
121
122    def _create_plan_application_dag(self, plan_dag_spec: common.PlanDagSpec) -> DAG:
123        dag_id = common.plan_application_dag_id(
124            plan_dag_spec.environment_name, plan_dag_spec.request_id
125        )
126        logger.info(
127            "Generating the plan application DAG '%s' for environment '%s'",
128            dag_id,
129            plan_dag_spec.environment_name,
130        )
131
132        all_snapshots = {
133            **{s.snapshot_id: s for s in plan_dag_spec.new_snapshots},
134            **self._snapshots,
135        }
136
137        with DAG(
138            dag_id=dag_id,
139            schedule_interval="@once",
140            start_date=now(),
141            max_active_tasks=plan_dag_spec.backfill_concurrent_tasks,
142            catchup=False,
143            is_paused_upon_creation=False,
144            default_args=DAG_DEFAULT_ARGS,
145            tags=[
146                common.SQLMESH_AIRFLOW_TAG,
147                common.PLAN_AIRFLOW_TAG,
148                plan_dag_spec.environment_name,
149            ],
150        ) as dag:
151            start_task = EmptyOperator(task_id="plan_application_start")
152            end_task = EmptyOperator(task_id="plan_application_end")
153
154            (create_start_task, create_end_task) = self._create_creation_tasks(
155                plan_dag_spec.new_snapshots, plan_dag_spec.ddl_concurrent_tasks
156            )
157
158            (backfill_start_task, backfill_end_task) = self._create_backfill_tasks(
159                plan_dag_spec.backfill_intervals_per_snapshot,
160                all_snapshots,
161                plan_dag_spec.is_dev,
162            )
163
164            (
165                promote_start_task,
166                promote_end_task,
167            ) = self._create_promotion_demotion_tasks(plan_dag_spec)
168
169            start_task >> create_start_task
170            create_end_task >> backfill_start_task
171            backfill_end_task >> promote_start_task
172
173            self._add_notification_target_tasks(
174                plan_dag_spec, start_task, end_task, promote_end_task
175            )
176            return dag
177
178    def _add_notification_target_tasks(
179        self,
180        request: common.PlanDagSpec,
181        start_task: BaseOperator,
182        end_task: BaseOperator,
183        promote_end_task: BaseOperator,
184    ) -> None:
185        has_success_or_failed_notification = False
186        for notification_target in request.notification_targets:
187            notification_operator_provider = NOTIFICATION_TARGET_TO_OPERATOR_PROVIDER.get(
188                type(notification_target)
189            )
190            if not notification_operator_provider:
191                continue
192            plan_start_notification_task = notification_operator_provider.operator(
193                notification_target, PlanStatus.STARTED, request
194            )
195            plan_success_notification_task = notification_operator_provider.operator(
196                notification_target, PlanStatus.FINISHED, request
197            )
198            plan_failed_notification_task = notification_operator_provider.operator(
199                notification_target, PlanStatus.FAILED, request
200            )
201            if plan_start_notification_task:
202                start_task >> plan_start_notification_task
203            if plan_success_notification_task:
204                has_success_or_failed_notification = True
205                promote_end_task >> plan_success_notification_task
206                plan_success_notification_task >> end_task
207            if plan_failed_notification_task:
208                has_success_or_failed_notification = True
209                promote_end_task >> plan_failed_notification_task
210                plan_failed_notification_task >> end_task
211        if not has_success_or_failed_notification:
212            promote_end_task >> end_task
213
214    def _create_creation_tasks(
215        self, new_snapshots: t.List[Snapshot], ddl_concurrent_tasks: int
216    ) -> t.Tuple[BaseOperator, BaseOperator]:
217        start_task = EmptyOperator(task_id="snapshot_creation_start")
218        end_task = EmptyOperator(task_id="snapshot_creation_end")
219
220        if not new_snapshots:
221            start_task >> end_task
222            return (start_task, end_task)
223
224        creation_task = self._create_snapshot_create_tables_operator(
225            new_snapshots, ddl_concurrent_tasks, "snapshot_creation__create_tables"
226        )
227
228        update_state_task = PythonOperator(
229            task_id="snapshot_creation__update_state",
230            python_callable=creation_update_state_task,
231            op_kwargs={"new_snapshots": new_snapshots},
232        )
233
234        start_task >> creation_task
235        creation_task >> update_state_task
236        update_state_task >> end_task
237
238        return (start_task, end_task)
239
240    def _create_promotion_demotion_tasks(
241        self, request: common.PlanDagSpec
242    ) -> t.Tuple[BaseOperator, BaseOperator]:
243        start_task = EmptyOperator(task_id="snapshot_promotion_start")
244        end_task = EmptyOperator(task_id="snapshot_promotion_end")
245
246        update_state_task = PythonOperator(
247            task_id="snapshot_promotion__update_state",
248            python_callable=promotion_update_state_task,
249            op_kwargs={
250                "snapshots": request.promoted_snapshots,
251                "environment_name": request.environment_name,
252                "start": request.start,
253                "end": request.end,
254                "unpaused_dt": request.unpaused_dt,
255                "no_gaps": request.no_gaps,
256                "plan_id": request.plan_id,
257                "previous_plan_id": request.previous_plan_id,
258                "environment_expiration_ts": request.environment_expiration_ts,
259            },
260        )
261
262        start_task >> update_state_task
263
264        if request.promoted_snapshots:
265            create_views_task = self._create_snapshot_promotion_operator(
266                request.promoted_snapshots,
267                request.environment_name,
268                request.ddl_concurrent_tasks,
269                request.is_dev,
270                "snapshot_promotion__create_views",
271            )
272            create_views_task >> end_task
273
274            if not request.is_dev and request.unpaused_dt:
275                migrate_tables_task = self._create_snapshot_migrate_tables_operator(
276                    request.promoted_snapshots,
277                    request.ddl_concurrent_tasks,
278                    "snapshot_promotion__migrate_tables",
279                )
280                update_state_task >> migrate_tables_task
281                migrate_tables_task >> create_views_task
282            else:
283                update_state_task >> create_views_task
284
285        if request.demoted_snapshots:
286            delete_views_task = self._create_snapshot_demotion_operator(
287                request.demoted_snapshots,
288                request.environment_name,
289                request.ddl_concurrent_tasks,
290                "snapshot_promotion__delete_views",
291            )
292            update_state_task >> delete_views_task
293            delete_views_task >> end_task
294
295        if not request.promoted_snapshots and not request.demoted_snapshots:
296            update_state_task >> end_task
297
298        return (start_task, end_task)
299
300    def _create_backfill_tasks(
301        self,
302        backfill_intervals: t.List[common.BackfillIntervalsPerSnapshot],
303        snapshots: t.Dict[SnapshotId, Snapshot],
304        is_dev: bool,
305    ) -> t.Tuple[BaseOperator, BaseOperator]:
306        snapshot_to_tasks = {}
307        for intervals_per_snapshot in backfill_intervals:
308            sid = intervals_per_snapshot.snapshot_id
309
310            if not intervals_per_snapshot.intervals:
311                logger.info(f"Skipping backfill for snapshot %s", sid)
312                continue
313
314            snapshot = snapshots[sid]
315
316            task_id_prefix = f"snapshot_evaluator__{snapshot.name}__{snapshot.identifier}"
317            tasks = [
318                self._create_snapshot_evaluator_operator(
319                    snapshots=snapshots,
320                    snapshot=snapshot,
321                    task_id=f"{task_id_prefix}__{start.strftime(TASK_ID_DATE_FORMAT)}__{end.strftime(TASK_ID_DATE_FORMAT)}",
322                    start=start,
323                    end=end,
324                    is_dev=is_dev,
325                )
326                for (start, end) in intervals_per_snapshot.intervals
327            ]
328            snapshot_start_task = EmptyOperator(
329                task_id=f"snapshot_backfill__{snapshot.name}__{snapshot.identifier}__start"
330            )
331            snapshot_end_task = EmptyOperator(
332                task_id=f"snapshot_backfill__{snapshot.name}__{snapshot.identifier}__end"
333            )
334            if snapshot.is_incremental_by_unique_key_kind:
335                baseoperator.chain(snapshot_start_task, *tasks, snapshot_end_task)
336            else:
337                snapshot_start_task >> tasks >> snapshot_end_task
338            snapshot_to_tasks[snapshot.snapshot_id] = (
339                snapshot_start_task,
340                snapshot_end_task,
341            )
342
343        backfill_start_task = EmptyOperator(task_id="snapshot_backfill_start")
344        backfill_end_task = EmptyOperator(task_id="snapshot_backfill_end")
345
346        if not snapshot_to_tasks:
347            backfill_start_task >> backfill_end_task
348            return (backfill_start_task, backfill_end_task)
349
350        entry_tasks = []
351        parent_ids_to_backfill = set()
352        for sid, (start_task, _) in snapshot_to_tasks.items():
353            has_parents_to_backfill = False
354            for p_sid in snapshots[sid].parents:
355                if p_sid in snapshot_to_tasks:
356                    snapshot_to_tasks[p_sid][1] >> start_task
357                    parent_ids_to_backfill.add(p_sid)
358                    has_parents_to_backfill = True
359
360            if not has_parents_to_backfill:
361                entry_tasks.append(start_task)
362
363        backfill_start_task >> entry_tasks
364
365        exit_tasks = [
366            end_task
367            for sid, (_, end_task) in snapshot_to_tasks.items()
368            if sid not in parent_ids_to_backfill
369        ]
370        for task in exit_tasks:
371            task >> backfill_end_task
372
373        return (backfill_start_task, backfill_end_task)
374
375    def _create_snapshot_promotion_operator(
376        self,
377        snapshots: t.List[SnapshotTableInfo],
378        environment: str,
379        ddl_concurrent_tasks: int,
380        is_dev: bool,
381        task_id: str,
382    ) -> BaseOperator:
383        return self._ddl_engine_operator(
384            **self._ddl_engine_operator_args,
385            target=targets.SnapshotPromotionTarget(
386                snapshots=snapshots,
387                environment=environment,
388                ddl_concurrent_tasks=ddl_concurrent_tasks,
389                is_dev=is_dev,
390            ),
391            task_id=task_id,
392        )
393
394    def _create_snapshot_demotion_operator(
395        self,
396        snapshots: t.List[SnapshotTableInfo],
397        environment: str,
398        ddl_concurrent_tasks: int,
399        task_id: str,
400    ) -> BaseOperator:
401        return self._ddl_engine_operator(
402            **self._ddl_engine_operator_args,
403            target=targets.SnapshotDemotionTarget(
404                snapshots=snapshots,
405                environment=environment,
406                ddl_concurrent_tasks=ddl_concurrent_tasks,
407            ),
408            task_id=task_id,
409        )
410
411    def _create_snapshot_create_tables_operator(
412        self,
413        new_snapshots: t.List[Snapshot],
414        ddl_concurrent_tasks: int,
415        task_id: str,
416    ) -> BaseOperator:
417        return self._ddl_engine_operator(
418            **self._ddl_engine_operator_args,
419            target=targets.SnapshotCreateTablesTarget(
420                new_snapshots=new_snapshots, ddl_concurrent_tasks=ddl_concurrent_tasks
421            ),
422            task_id=task_id,
423        )
424
425    def _create_snapshot_migrate_tables_operator(
426        self,
427        snapshots: t.List[SnapshotTableInfo],
428        ddl_concurrent_tasks: int,
429        task_id: str,
430    ) -> BaseOperator:
431        return self._ddl_engine_operator(
432            **self._ddl_engine_operator_args,
433            target=targets.SnapshotMigrateTablesTarget(
434                snapshots=snapshots, ddl_concurrent_tasks=ddl_concurrent_tasks
435            ),
436            task_id=task_id,
437        )
438
439    def _create_snapshot_evaluator_operator(
440        self,
441        snapshots: t.Dict[SnapshotId, Snapshot],
442        snapshot: Snapshot,
443        task_id: str,
444        start: t.Optional[TimeLike] = None,
445        end: t.Optional[TimeLike] = None,
446        is_dev: bool = False,
447    ) -> BaseOperator:
448        parent_snapshots = {sid.name: snapshots[sid] for sid in snapshot.parents}
449
450        return self._engine_operator(
451            **self._engine_operator_args,
452            target=targets.SnapshotEvaluationTarget(
453                snapshot=snapshot,
454                parent_snapshots=parent_snapshots,
455                start=start,
456                end=end,
457                is_dev=is_dev,
458            ),
459            task_id=task_id,
460        )
461
462    def _create_hwm_sensors(self, snapshot: Snapshot) -> t.List[HighWaterMarkSensor]:
463        output = []
464        for upstream_snapshot_id in snapshot.parents:
465            upstream_snapshot = self._snapshots[upstream_snapshot_id]
466            if not upstream_snapshot.is_embedded_kind and not upstream_snapshot.is_seed_kind:
467                output.append(
468                    HighWaterMarkSensor(
469                        target_snapshot_info=upstream_snapshot.table_info,
470                        this_snapshot=snapshot,
471                        task_id=f"{upstream_snapshot.name}_{upstream_snapshot.version}_high_water_mark_sensor",
472                    )
473                )
474        return output
SnapshotDagGenerator( engine_operator: Type[airflow.models.baseoperator.BaseOperator], engine_operator_args: Optional[Dict[str, Any]], ddl_engine_operator: Type[airflow.models.baseoperator.BaseOperator], ddl_engine_operator_args: Optional[Dict[str, Any]], snapshots: Dict[sqlmesh.core.snapshot.definition.SnapshotId, sqlmesh.core.snapshot.definition.Snapshot])
55    def __init__(
56        self,
57        engine_operator: t.Type[BaseOperator],
58        engine_operator_args: t.Optional[t.Dict[str, t.Any]],
59        ddl_engine_operator: t.Type[BaseOperator],
60        ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]],
61        snapshots: t.Dict[SnapshotId, Snapshot],
62    ):
63        self._engine_operator = engine_operator
64        self._engine_operator_args = engine_operator_args or {}
65        self._ddl_engine_operator = ddl_engine_operator
66        self._ddl_engine_operator_args = ddl_engine_operator_args or {}
67        self._snapshots = snapshots
def generate_cadence_dags(self) -> List[airflow.models.dag.DAG]:
69    def generate_cadence_dags(self) -> t.List[DAG]:
70        return [
71            self._create_cadence_dag_for_snapshot(s)
72            for s in self._snapshots.values()
73            if s.unpaused_ts and not s.is_embedded_kind and not s.is_seed_kind
74        ]
def generate_plan_application_dag( self, spec: sqlmesh.schedulers.airflow.common.PlanDagSpec) -> airflow.models.dag.DAG:
76    def generate_plan_application_dag(self, spec: common.PlanDagSpec) -> DAG:
77        return self._create_plan_application_dag(spec)
def creation_update_state_task( new_snapshots: Iterable[sqlmesh.core.snapshot.definition.Snapshot]) -> None:
def creation_update_state_task(new_snapshots: t.Iterable[Snapshot]) -> None:
    """Persists the newly created snapshots in the SQLMesh state store."""
    with util.scoped_state_sync() as sync:
        sync.push_snapshots(new_snapshots)
def promotion_update_state_task( snapshots: List[sqlmesh.core.snapshot.definition.SnapshotTableInfo], environment_name: str, start: Union[datetime.date, datetime.datetime, str, int, float], end: Union[datetime.date, datetime.datetime, str, int, float, NoneType], unpaused_dt: Union[datetime.date, datetime.datetime, str, int, float, NoneType], no_gaps: bool, plan_id: str, previous_plan_id: Optional[str], environment_expiration_ts: Optional[int]) -> None:
def promotion_update_state_task(
    snapshots: t.List[SnapshotTableInfo],
    environment_name: str,
    start: TimeLike,
    end: t.Optional[TimeLike],
    unpaused_dt: t.Optional[TimeLike],
    no_gaps: bool,
    plan_id: str,
    previous_plan_id: t.Optional[str],
    environment_expiration_ts: t.Optional[int],
) -> None:
    """Promotes the given snapshots into the target environment in the state store.

    After the promotion, snapshots are unpaused at ``unpaused_dt`` when there are
    snapshots to promote, no explicit end is set, and an unpause timestamp is given.
    """
    target_environment = Environment(
        name=environment_name,
        snapshots=snapshots,
        start_at=start,
        end_at=end,
        plan_id=plan_id,
        previous_plan_id=previous_plan_id,
        expiration_ts=environment_expiration_ts,
    )
    with util.scoped_state_sync() as state_sync:
        state_sync.promote(target_environment, no_gaps=no_gaps)
        # Truthiness (not identity) checks are deliberate and preserved here.
        should_unpause = bool(snapshots) and not end and bool(unpaused_dt)
        if should_unpause:
            state_sync.unpause_snapshots(snapshots, unpaused_dt)