Edit on GitHub

sqlmesh.schedulers.airflow.operators.notification

 1import abc
 2import typing as t
 3
 4from airflow.models import BaseOperator
 5
 6from sqlmesh.core.notification_target import BaseNotificationTarget
 7from sqlmesh.core.plan import PlanStatus
 8from sqlmesh.schedulers.airflow import common
 9
# Notification target type that a concrete provider subclass handles.
NT = t.TypeVar("NT", bound=BaseNotificationTarget)


class BaseNotificationOperatorProvider(abc.ABC, t.Generic[NT]):
    """Base class for providers of Airflow notification operators for a plan.

    Subclasses are parameterized by the notification target type ``NT`` and
    implement the abstract :meth:`operator`. The concrete helpers
    :meth:`get_trigger_rule` and :meth:`get_task_id` derive an Airflow
    trigger rule and a task ID from the plan status.
    """

    @abc.abstractmethod
    def operator(
        self,
        target: NT,
        plan_status: PlanStatus,
        plan_dag_spec: common.PlanDagSpec,
        **dag_kwargs: t.Any,
    ) -> t.Optional[BaseOperator]:
        """Build the notification operator for ``target`` and ``plan_status``.

        Args:
            target: The notification target the operator is built for.
            plan_status: The plan status the notification is about.
            plan_dag_spec: The spec of the plan DAG the operator belongs to.
            **dag_kwargs: Extra keyword arguments from the caller; how they
                are used is up to the subclass (presumably forwarded to the
                operator - confirm per implementation).

        Returns:
            The operator, or ``None`` when the subclass creates no operator.
        """

    def get_trigger_rule(self, plan_status: PlanStatus) -> str:
        """Return the Airflow trigger rule for the notification task.

        Args:
            plan_status: The plan status the notification is about.

        Returns:
            ``"one_failed"`` if ``plan_status.is_failed`` is true, otherwise
            ``"all_success"``.
        """
        # Per Airflow trigger-rule semantics, a failure notification must fire
        # when an upstream task fails, which "all_success" would prevent.
        return "one_failed" if plan_status.is_failed else "all_success"

    def get_task_id(self, target: NT, plan_status: PlanStatus) -> str:
        """Return the Airflow task ID for the notification task.

        Args:
            target: The notification target; its ``type_`` is part of the ID.
            plan_status: The plan status; its ``value`` is part of the ID.

        Returns:
            A string of the form
            ``plan_<plan_status.value>_<target.type_>_notification``.
        """
        status = plan_status.value
        target_type = target.type_
        return f"plan_{status}_{target_type}_notification"
class BaseNotificationOperatorProvider(abc.ABC, typing.Generic[~NT]):
class BaseNotificationOperatorProvider(abc.ABC, t.Generic[NT]):
    """Base class for providers of Airflow notification operators for a plan.

    Subclasses are parameterized by the notification target type ``NT`` and
    implement the abstract :meth:`operator`. The concrete helpers
    :meth:`get_trigger_rule` and :meth:`get_task_id` derive an Airflow
    trigger rule and a task ID from the plan status.
    """

    @abc.abstractmethod
    def operator(
        self,
        target: NT,
        plan_status: PlanStatus,
        plan_dag_spec: common.PlanDagSpec,
        **dag_kwargs: t.Any,
    ) -> t.Optional[BaseOperator]:
        """Build the notification operator for ``target`` and ``plan_status``.

        Args:
            target: The notification target the operator is built for.
            plan_status: The plan status the notification is about.
            plan_dag_spec: The spec of the plan DAG the operator belongs to.
            **dag_kwargs: Extra keyword arguments from the caller; how they
                are used is up to the subclass (presumably forwarded to the
                operator - confirm per implementation).

        Returns:
            The operator, or ``None`` when the subclass creates no operator.
        """

    def get_trigger_rule(self, plan_status: PlanStatus) -> str:
        """Return the Airflow trigger rule for the notification task.

        Args:
            plan_status: The plan status the notification is about.

        Returns:
            ``"one_failed"`` if ``plan_status.is_failed`` is true, otherwise
            ``"all_success"``.
        """
        # Per Airflow trigger-rule semantics, a failure notification must fire
        # when an upstream task fails, which "all_success" would prevent.
        return "one_failed" if plan_status.is_failed else "all_success"

    def get_task_id(self, target: NT, plan_status: PlanStatus) -> str:
        """Return the Airflow task ID for the notification task.

        Args:
            target: The notification target; its ``type_`` is part of the ID.
            plan_status: The plan status; its ``value`` is part of the ID.

        Returns:
            A string of the form
            ``plan_<plan_status.value>_<target.type_>_notification``.
        """
        status = plan_status.value
        target_type = target.type_
        return f"plan_{status}_{target_type}_notification"

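Base class for providers of Airflow notification operators for a plan. Subclasses are parameterized by the notification target type `NT` and implement the abstract `operator` method; the concrete helpers `get_trigger_rule` and `get_task_id` derive the notification task's trigger rule and task ID from the plan status.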

@abc.abstractmethod
def operator(self, target: ~NT, plan_status: sqlmesh.core.plan.definition.PlanStatus, plan_dag_spec: sqlmesh.schedulers.airflow.common.PlanDagSpec, **dag_kwargs: Any) -> Optional[airflow.models.baseoperator.BaseOperator]:
@abc.abstractmethod
def operator(
    self,
    target: NT,
    plan_status: PlanStatus,
    plan_dag_spec: common.PlanDagSpec,
    **dag_kwargs: t.Any,
) -> t.Optional[BaseOperator]:
    """Build the notification operator for ``target`` and ``plan_status``.

    Abstract: concrete providers must override this.

    Args:
        target: The notification target the operator is built for.
        plan_status: The plan status the notification is about.
        plan_dag_spec: The spec of the plan DAG the operator belongs to.
        **dag_kwargs: Extra keyword arguments from the caller; how they are
            used is up to the subclass (presumably forwarded to the
            operator - confirm per implementation).

    Returns:
        The operator, or ``None`` when the subclass creates no operator.
    """
def get_trigger_rule(self, plan_status: sqlmesh.core.plan.definition.PlanStatus) -> str:
def get_trigger_rule(self, plan_status: PlanStatus) -> str:
    """Return the Airflow trigger rule for the notification task.

    Args:
        plan_status: The plan status the notification is about.

    Returns:
        ``"one_failed"`` if ``plan_status.is_failed`` is true, otherwise
        ``"all_success"``.
    """
    # Per Airflow trigger-rule semantics, a failure notification must fire
    # when an upstream task fails, which "all_success" would prevent.
    return "one_failed" if plan_status.is_failed else "all_success"
def get_task_id(self, target: ~NT, plan_status: sqlmesh.core.plan.definition.PlanStatus) -> str:
def get_task_id(self, target: NT, plan_status: PlanStatus) -> str:
    """Return the Airflow task ID for the notification task.

    Args:
        target: The notification target; its ``type_`` is part of the ID.
        plan_status: The plan status; its ``value`` is part of the ID.

    Returns:
        A string of the form
        ``plan_<plan_status.value>_<target.type_>_notification``.
    """
    status = plan_status.value
    target_type = target.type_
    return f"plan_{status}_{target_type}_notification"