Edit on GitHub

sqlmesh.utils.concurrency

  1import typing as t
  2from concurrent.futures import Executor, Future, ThreadPoolExecutor
  3from threading import Lock
  4
  5from sqlmesh.core.snapshot import SnapshotId, SnapshotInfoLike
  6from sqlmesh.utils.dag import DAG
  7from sqlmesh.utils.errors import ConfigError, SQLMeshError
  8
# H: any hashable value can serve as a DAG node.
H = t.TypeVar("H", bound=t.Hashable)
# S: snapshot-like objects accepted by the snapshot-specific helpers below.
S = t.TypeVar("S", bound=SnapshotInfoLike)
 11
 12
 13class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
 14    def __init__(self, node: H):
 15        self.node = node
 16        super().__init__(f"Execution failed for node {node}")
 17
 18
 19class ConcurrentDAGExecutor(t.Generic[H]):
 20    """Concurrently traverses the given DAG in topological order while applying a function to each node.
 21
 22    If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes.
 23
 24    Args:
 25        dag: The target DAG.
 26        fn: The function that will be applied concurrently to each snapshot.
 27        tasks_num: The number of concurrent tasks.
 28        raise_on_error: If set to True raises an exception on a first encountered error,
 29            otherwises returns a tuple which contains a list of failed nodes and a list of
 30            skipped nodes.
 31    """
 32
 33    def __init__(
 34        self,
 35        dag: DAG[H],
 36        fn: t.Callable[[H], None],
 37        tasks_num: int,
 38        raise_on_error: bool,
 39    ):
 40        self.dag = dag
 41        self.fn = fn
 42        self.tasks_num = tasks_num
 43        self.raise_on_error = raise_on_error
 44
 45        self._init_state()
 46
 47    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
 48        """Runs the executor.
 49
 50        Raises:
 51            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
 52
 53        Returns:
 54            A pair which contains a list of node errors and a list of skipped nodes.
 55        """
 56        if self._finished_future.done():
 57            self._init_state()
 58
 59        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
 60            with self._unprocessed_nodes_lock:
 61                self._submit_next_nodes(pool)
 62            self._finished_future.result()
 63        return self._node_errors, self._skipped_nodes
 64
 65    def _process_node(self, node: H, executor: Executor) -> None:
 66        try:
 67            self.fn(node)
 68
 69            with self._unprocessed_nodes_lock:
 70                self._unprocessed_nodes_num -= 1
 71                self._submit_next_nodes(executor, node)
 72        except Exception as ex:
 73            error = NodeExecutionFailedError(node)
 74            error.__cause__ = ex
 75
 76            if self.raise_on_error:
 77                self._finished_future.set_exception(error)
 78                return
 79
 80            with self._unprocessed_nodes_lock:
 81                self._unprocessed_nodes_num -= 1
 82                self._node_errors.append(error)
 83                self._skip_next_nodes(node)
 84
 85    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
 86        if not self._unprocessed_nodes_num:
 87            self._finished_future.set_result(None)
 88            return
 89
 90        submitted_nodes = []
 91        for next_node, deps in self._unprocessed_nodes.items():
 92            if processed_node:
 93                deps.discard(processed_node)
 94            if not deps:
 95                submitted_nodes.append(next_node)
 96
 97        for submitted_node in submitted_nodes:
 98            self._unprocessed_nodes.pop(submitted_node)
 99            executor.submit(self._process_node, submitted_node, executor)
100
101    def _skip_next_nodes(self, parent: H) -> None:
102        if not self._unprocessed_nodes_num:
103            self._finished_future.set_result(None)
104            return
105
106        skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps]
107
108        self._skipped_nodes.extend(skipped_nodes)
109
110        for skipped_node in skipped_nodes:
111            self._unprocessed_nodes_num -= 1
112            self._unprocessed_nodes.pop(skipped_node)
113            self._skip_next_nodes(skipped_node)
114
115    def _init_state(self) -> None:
116        self._unprocessed_nodes = self.dag.graph
117        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
118        self._unprocessed_nodes_lock = Lock()
119        self._finished_future = Future()  # type: ignore
120
121        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
122        self._skipped_nodes: t.List[H] = []
123
124
def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Args:
        snapshots: Target snapshots.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on a first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    # Consume `snapshots` exactly once so one-shot iterables (e.g. generators)
    # are handled correctly; the DAG below is built from the materialized values
    # rather than by iterating the input a second time.
    snapshots_by_id = {s.snapshot_id: s for s in snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    for snapshot in snapshots_by_id.values():
        dag.add(
            snapshot.snapshot_id,
            # Only parents that are part of the target collection participate
            # in the ordering; external parents are ignored.
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag if not reverse_order else dag.reversed,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )
165
166
def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to every node of `dag` while preserving topological order.

    Dispatches to a simple sequential traversal when only one task is requested;
    otherwise delegates to `ConcurrentDAGExecutor`.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on a first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    # A single task needs no thread pool — fall back to the sequential path.
    if tasks_num == 1:
        return sequential_apply_to_dag(dag, fn, raise_on_error)

    executor = ConcurrentDAGExecutor(dag, fn, tasks_num, raise_on_error)
    return executor.run()
202
203
def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to each DAG node in topological order, one node at a time.

    Args:
        dag: The target DAG.
        fn: The function applied to each node.
        raise_on_error: If True, the first failure is raised immediately; otherwise
            failures are collected and dependents of failed nodes are skipped.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is True and `fn` fails for any node.

    Returns:
        A pair of (node errors, skipped nodes).
    """
    graph = dag.graph

    errors: t.List[NodeExecutionFailedError[H]] = []
    skipped: t.List[H] = []

    # Nodes that failed or were skipped; any node depending on one of these
    # is skipped as well.
    unavailable: t.Set[H] = set()

    for node in dag.sorted():
        if any(dep in unavailable for dep in graph[node]):
            skipped.append(node)
            unavailable.add(node)
            continue

        try:
            fn(node)
        except Exception as ex:
            if raise_on_error:
                raise NodeExecutionFailedError(node) from ex

            failure = NodeExecutionFailedError(node)
            failure.__cause__ = ex

            errors.append(failure)
            unavailable.add(node)

    return errors, skipped
class NodeExecutionFailedError(typing.Generic[~H], sqlmesh.utils.errors.SQLMeshError):
14class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
15    def __init__(self, node: H):
16        self.node = node
17        super().__init__(f"Execution failed for node {node}")

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

NodeExecutionFailedError(node: ~H)
15    def __init__(self, node: H):
16        self.node = node
17        super().__init__(f"Execution failed for node {node}")
Inherited Members
builtins.BaseException
with_traceback
class ConcurrentDAGExecutor(typing.Generic[~H]):
 20class ConcurrentDAGExecutor(t.Generic[H]):
 21    """Concurrently traverses the given DAG in topological order while applying a function to each node.
 22
 23    If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes.
 24
 25    Args:
 26        dag: The target DAG.
 27        fn: The function that will be applied concurrently to each snapshot.
 28        tasks_num: The number of concurrent tasks.
 29        raise_on_error: If set to True raises an exception on a first encountered error,
 30            otherwises returns a tuple which contains a list of failed nodes and a list of
 31            skipped nodes.
 32    """
 33
 34    def __init__(
 35        self,
 36        dag: DAG[H],
 37        fn: t.Callable[[H], None],
 38        tasks_num: int,
 39        raise_on_error: bool,
 40    ):
 41        self.dag = dag
 42        self.fn = fn
 43        self.tasks_num = tasks_num
 44        self.raise_on_error = raise_on_error
 45
 46        self._init_state()
 47
 48    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
 49        """Runs the executor.
 50
 51        Raises:
 52            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
 53
 54        Returns:
 55            A pair which contains a list of node errors and a list of skipped nodes.
 56        """
 57        if self._finished_future.done():
 58            self._init_state()
 59
 60        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
 61            with self._unprocessed_nodes_lock:
 62                self._submit_next_nodes(pool)
 63            self._finished_future.result()
 64        return self._node_errors, self._skipped_nodes
 65
 66    def _process_node(self, node: H, executor: Executor) -> None:
 67        try:
 68            self.fn(node)
 69
 70            with self._unprocessed_nodes_lock:
 71                self._unprocessed_nodes_num -= 1
 72                self._submit_next_nodes(executor, node)
 73        except Exception as ex:
 74            error = NodeExecutionFailedError(node)
 75            error.__cause__ = ex
 76
 77            if self.raise_on_error:
 78                self._finished_future.set_exception(error)
 79                return
 80
 81            with self._unprocessed_nodes_lock:
 82                self._unprocessed_nodes_num -= 1
 83                self._node_errors.append(error)
 84                self._skip_next_nodes(node)
 85
 86    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
 87        if not self._unprocessed_nodes_num:
 88            self._finished_future.set_result(None)
 89            return
 90
 91        submitted_nodes = []
 92        for next_node, deps in self._unprocessed_nodes.items():
 93            if processed_node:
 94                deps.discard(processed_node)
 95            if not deps:
 96                submitted_nodes.append(next_node)
 97
 98        for submitted_node in submitted_nodes:
 99            self._unprocessed_nodes.pop(submitted_node)
100            executor.submit(self._process_node, submitted_node, executor)
101
102    def _skip_next_nodes(self, parent: H) -> None:
103        if not self._unprocessed_nodes_num:
104            self._finished_future.set_result(None)
105            return
106
107        skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps]
108
109        self._skipped_nodes.extend(skipped_nodes)
110
111        for skipped_node in skipped_nodes:
112            self._unprocessed_nodes_num -= 1
113            self._unprocessed_nodes.pop(skipped_node)
114            self._skip_next_nodes(skipped_node)
115
116    def _init_state(self) -> None:
117        self._unprocessed_nodes = self.dag.graph
118        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
119        self._unprocessed_nodes_lock = Lock()
120        self._finished_future = Future()  # type: ignore
121
122        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
123        self._skipped_nodes: t.List[H] = []

Concurrently traverses the given DAG in topological order while applying a function to each node.

If raise_on_error is set to False maintains a state of execution errors as well as of skipped nodes.

Arguments:
  • dag: The target DAG.
  • fn: The function that will be applied concurrently to each snapshot.
  • tasks_num: The number of concurrent tasks.
  • raise_on_error: If set to True raises an exception on a first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
ConcurrentDAGExecutor( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], tasks_num: int, raise_on_error: bool)
34    def __init__(
35        self,
36        dag: DAG[H],
37        fn: t.Callable[[H], None],
38        tasks_num: int,
39        raise_on_error: bool,
40    ):
41        self.dag = dag
42        self.fn = fn
43        self.tasks_num = tasks_num
44        self.raise_on_error = raise_on_error
45
46        self._init_state()
def run( self) -> Tuple[List[sqlmesh.utils.concurrency.NodeExecutionFailedError[~H]], List[~H]]:
48    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
49        """Runs the executor.
50
51        Raises:
52            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.
53
54        Returns:
55            A pair which contains a list of node errors and a list of skipped nodes.
56        """
57        if self._finished_future.done():
58            self._init_state()
59
60        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
61            with self._unprocessed_nodes_lock:
62                self._submit_next_nodes(pool)
63            self._finished_future.result()
64        return self._node_errors, self._skipped_nodes

Runs the executor.

Raises:
  • NodeExecutionFailedError if raise_on_error was set to True and execution fails for any snapshot.
Returns:

A pair which contains a list of node errors and a list of skipped nodes.

def concurrent_apply_to_snapshots( snapshots: Iterable[~S], fn: Callable[[~S], NoneType], tasks_num: int, reverse_order: bool = False, raise_on_error: bool = True) -> Tuple[List[sqlmesh.utils.concurrency.NodeExecutionFailedError[sqlmesh.core.snapshot.definition.SnapshotId]], List[sqlmesh.core.snapshot.definition.SnapshotId]]:
126def concurrent_apply_to_snapshots(
127    snapshots: t.Iterable[S],
128    fn: t.Callable[[S], None],
129    tasks_num: int,
130    reverse_order: bool = False,
131    raise_on_error: bool = True,
132) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
133    """Applies a function to the given collection of snapshots concurrently while
134    preserving the topological order between snapshots.
135
136    Args:
137        snapshots: Target snapshots.
138        fn: The function that will be applied concurrently to each snapshot.
139        tasks_num: The number of concurrent tasks.
140        reverse_order: Whether the order should be reversed. Default: False.
141        raise_on_error: If set to True raises an exception on a first encountered error,
142            otherwises returns a tuple which contains a list of failed nodes and a list of
143            skipped nodes.
144
145    Raises:
146        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.
147
148    Returns:
149        A pair which contains a list of errors and a list of skipped snapshot IDs.
150    """
151    snapshots_by_id = {s.snapshot_id: s for s in snapshots}
152
153    dag: DAG[SnapshotId] = DAG[SnapshotId]()
154    for snapshot in snapshots:
155        dag.add(
156            snapshot.snapshot_id,
157            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
158        )
159
160    return concurrent_apply_to_dag(
161        dag if not reverse_order else dag.reversed,
162        lambda s_id: fn(snapshots_by_id[s_id]),
163        tasks_num,
164        raise_on_error=raise_on_error,
165    )

Applies a function to the given collection of snapshots concurrently while preserving the topological order between snapshots.

Arguments:
  • snapshots: Target snapshots.
  • fn: The function that will be applied concurrently to each snapshot.
  • tasks_num: The number of concurrent tasks.
  • reverse_order: Whether the order should be reversed. Default: False.
  • raise_on_error: If set to True raises an exception on a first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
  • NodeExecutionFailedError if raise_on_error is set to True and execution fails for any snapshot.
Returns:

A pair which contains a list of errors and a list of skipped snapshot IDs.

def concurrent_apply_to_dag( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], tasks_num: int, raise_on_error: bool = True) -> Tuple[List[sqlmesh.utils.concurrency.NodeExecutionFailedError[~H]], List[~H]]:
168def concurrent_apply_to_dag(
169    dag: DAG[H],
170    fn: t.Callable[[H], None],
171    tasks_num: int,
172    raise_on_error: bool = True,
173) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
174    """Applies a function to the given DAG concurrently while preserving the topological
175    order between snapshots.
176
177    Args:
178        dag: The target DAG.
179        fn: The function that will be applied concurrently to each snapshot.
180        tasks_num: The number of concurrent tasks.
181        raise_on_error: If set to True raises an exception on a first encountered error,
182            otherwises returns a tuple which contains a list of failed nodes and a list of
183            skipped nodes.
184
185    Raises:
186        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.
187
188    Returns:
189        A pair which contains a list of node errors and a list of skipped nodes.
190    """
191    if tasks_num <= 0:
192        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")
193
194    if tasks_num == 1:
195        return sequential_apply_to_dag(dag, fn, raise_on_error)
196
197    return ConcurrentDAGExecutor(
198        dag,
199        fn,
200        tasks_num,
201        raise_on_error,
202    ).run()

Applies a function to the given DAG concurrently while preserving the topological order between snapshots.

Arguments:
  • dag: The target DAG.
  • fn: The function that will be applied concurrently to each snapshot.
  • tasks_num: The number of concurrent tasks.
  • raise_on_error: If set to True raises an exception on a first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
  • NodeExecutionFailedError if raise_on_error is set to True and execution fails for any snapshot.
Returns:

A pair which contains a list of node errors and a list of skipped nodes.

def sequential_apply_to_dag( dag: sqlmesh.utils.dag.DAG[~H], fn: Callable[[~H], NoneType], raise_on_error: bool = True) -> Tuple[List[sqlmesh.utils.concurrency.NodeExecutionFailedError[~H]], List[~H]]:
205def sequential_apply_to_dag(
206    dag: DAG[H],
207    fn: t.Callable[[H], None],
208    raise_on_error: bool = True,
209) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
210    dependencies = dag.graph
211
212    node_errors: t.List[NodeExecutionFailedError[H]] = []
213    skipped_nodes: t.List[H] = []
214
215    failed_or_skipped_nodes: t.Set[H] = set()
216
217    for node in dag.sorted():
218        if not failed_or_skipped_nodes.isdisjoint(dependencies[node]):
219            skipped_nodes.append(node)
220            failed_or_skipped_nodes.add(node)
221            continue
222
223        try:
224            fn(node)
225        except Exception as ex:
226            if raise_on_error:
227                raise NodeExecutionFailedError(node) from ex
228
229            error = NodeExecutionFailedError(node)
230            error.__cause__ = ex
231
232            node_errors.append(error)
233            failed_or_skipped_nodes.add(node)
234
235    return node_errors, skipped_nodes