sqlmesh.utils.concurrency
import typing as t
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from threading import Lock

from sqlmesh.core.snapshot import SnapshotId, SnapshotInfoLike
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.errors import ConfigError, SQLMeshError

H = t.TypeVar("H", bound=t.Hashable)
S = t.TypeVar("S", bound=SnapshotInfoLike)


class NodeExecutionFailedError(t.Generic[H], SQLMeshError):
    """Raised (or collected) when applying the user function to a DAG node fails.

    The original exception is attached as ``__cause__``.
    """

    def __init__(self, node: H):
        self.node = node
        super().__init__(f"Execution failed for node {node}")


class ConcurrentDAGExecutor(t.Generic[H]):
    """Concurrently traverses the given DAG in topological order while applying a function to each node.

    If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on a first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.
    """

    def __init__(
        self,
        dag: DAG[H],
        fn: t.Callable[[H], None],
        tasks_num: int,
        raise_on_error: bool,
    ):
        self.dag = dag
        self.fn = fn
        self.tasks_num = tasks_num
        self.raise_on_error = raise_on_error

        self._init_state()

    def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
        """Runs the executor.

        Raises:
            NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot.

        Returns:
            A pair which contains a list of node errors and a list of skipped nodes.
        """
        # Re-initialize state so the same executor instance can be run again.
        if self._finished_future.done():
            self._init_state()

        with ThreadPoolExecutor(max_workers=self.tasks_num) as pool:
            with self._unprocessed_nodes_lock:
                self._submit_next_nodes(pool)
            self._finished_future.result()
        return self._node_errors, self._skipped_nodes

    def _process_node(self, node: H, executor: Executor) -> None:
        """Applies `fn` to a single node, then either schedules its dependents or records the failure."""
        try:
            self.fn(node)

            with self._unprocessed_nodes_lock:
                self._unprocessed_nodes_num -= 1
                self._submit_next_nodes(executor, node)
        except Exception as ex:
            error = NodeExecutionFailedError(node)
            error.__cause__ = ex

            if self.raise_on_error:
                with self._unprocessed_nodes_lock:
                    # Only the first failure resolves the future; an unguarded
                    # set_exception() raises InvalidStateError when several nodes
                    # fail concurrently.
                    if not self._finished_future.done():
                        self._finished_future.set_exception(error)
                return

            with self._unprocessed_nodes_lock:
                self._unprocessed_nodes_num -= 1
                self._node_errors.append(error)
                self._skip_next_nodes(node)

    def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None:
        """Submits all nodes whose dependencies have been fully processed.

        Must be called while holding `_unprocessed_nodes_lock`.
        """
        if self._finished_future.done():
            # The run already finished (e.g. a failure with `raise_on_error`);
            # don't schedule new work on a pool that may be shutting down.
            return
        if not self._unprocessed_nodes_num:
            self._finished_future.set_result(None)
            return

        submitted_nodes = []
        for next_node, deps in self._unprocessed_nodes.items():
            # NOTE: an identity check is required here; a falsy node (0, "", ())
            # would otherwise never be discarded from its dependents' dep-sets.
            if processed_node is not None:
                deps.discard(processed_node)
            if not deps:
                submitted_nodes.append(next_node)

        for submitted_node in submitted_nodes:
            self._unprocessed_nodes.pop(submitted_node)
            executor.submit(self._process_node, submitted_node, executor)

    def _skip_next_nodes(self, parent: H) -> None:
        """Recursively marks all transitive dependents of `parent` as skipped.

        Must be called while holding `_unprocessed_nodes_lock`.
        """
        if not self._unprocessed_nodes_num:
            if not self._finished_future.done():
                self._finished_future.set_result(None)
            return

        skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps]

        self._skipped_nodes.extend(skipped_nodes)

        for skipped_node in skipped_nodes:
            self._unprocessed_nodes_num -= 1
            self._unprocessed_nodes.pop(skipped_node)
            self._skip_next_nodes(skipped_node)

    def _init_state(self) -> None:
        """Resets all mutable run state so the executor can be (re-)run."""
        # Copy the graph defensively: the traversal mutates it (discard / pop),
        # which must not leak back into the DAG instance.
        self._unprocessed_nodes = {node: set(deps) for node, deps in self.dag.graph.items()}
        self._unprocessed_nodes_num = len(self._unprocessed_nodes)
        self._unprocessed_nodes_lock = Lock()
        self._finished_future: Future = Future()

        self._node_errors: t.List[NodeExecutionFailedError[H]] = []
        self._skipped_nodes: t.List[H] = []


def concurrent_apply_to_snapshots(
    snapshots: t.Iterable[S],
    fn: t.Callable[[S], None],
    tasks_num: int,
    reverse_order: bool = False,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]:
    """Applies a function to the given collection of snapshots concurrently while
    preserving the topological order between snapshots.

    Args:
        snapshots: Target snapshots.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        reverse_order: Whether the order should be reversed. Default: False.
        raise_on_error: If set to True raises an exception on a first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of errors and a list of skipped snapshot IDs.
    """
    snapshots_by_id = {s.snapshot_id: s for s in snapshots}

    dag: DAG[SnapshotId] = DAG[SnapshotId]()
    for snapshot in snapshots:
        dag.add(
            snapshot.snapshot_id,
            # Only depend on parents that are part of this batch.
            [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id],
        )

    return concurrent_apply_to_dag(
        dag if not reverse_order else dag.reversed,
        lambda s_id: fn(snapshots_by_id[s_id]),
        tasks_num,
        raise_on_error=raise_on_error,
    )


def concurrent_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    tasks_num: int,
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to the given DAG concurrently while preserving the topological
    order between snapshots.

    Args:
        dag: The target DAG.
        fn: The function that will be applied concurrently to each snapshot.
        tasks_num: The number of concurrent tasks.
        raise_on_error: If set to True raises an exception on a first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    if tasks_num <= 0:
        raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")

    # A single task needs no thread pool; run sequentially to avoid the overhead.
    if tasks_num == 1:
        return sequential_apply_to_dag(dag, fn, raise_on_error)

    return ConcurrentDAGExecutor(
        dag,
        fn,
        tasks_num,
        raise_on_error,
    ).run()


def sequential_apply_to_dag(
    dag: DAG[H],
    fn: t.Callable[[H], None],
    raise_on_error: bool = True,
) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]:
    """Applies a function to the given DAG sequentially in topological order.

    Nodes downstream of a failed node are skipped rather than executed.

    Args:
        dag: The target DAG.
        fn: The function that will be applied to each node.
        raise_on_error: If set to True raises an exception on a first encountered error,
            otherwise returns a tuple which contains a list of failed nodes and a list of
            skipped nodes.

    Raises:
        NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any node.

    Returns:
        A pair which contains a list of node errors and a list of skipped nodes.
    """
    dependencies = dag.graph

    node_errors: t.List[NodeExecutionFailedError[H]] = []
    skipped_nodes: t.List[H] = []

    failed_or_skipped_nodes: t.Set[H] = set()

    for node in dag.sorted():
        # Skip any node with a failed or skipped upstream dependency.
        if not failed_or_skipped_nodes.isdisjoint(dependencies[node]):
            skipped_nodes.append(node)
            failed_or_skipped_nodes.add(node)
            continue

        try:
            fn(node)
        except Exception as ex:
            if raise_on_error:
                raise NodeExecutionFailedError(node) from ex

            error = NodeExecutionFailedError(node)
            error.__cause__ = ex

            node_errors.append(error)
            failed_or_skipped_nodes.add(node)

    return node_errors, skipped_nodes
14class NodeExecutionFailedError(t.Generic[H], SQLMeshError): 15 def __init__(self, node: H): 16 self.node = node 17 super().__init__(f"Execution failed for node {node}")
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
Inherited Members
- builtins.BaseException
- with_traceback
20class ConcurrentDAGExecutor(t.Generic[H]): 21 """Concurrently traverses the given DAG in topological order while applying a function to each node. 22 23 If `raise_on_error` is set to False maintains a state of execution errors as well as of skipped nodes. 24 25 Args: 26 dag: The target DAG. 27 fn: The function that will be applied concurrently to each snapshot. 28 tasks_num: The number of concurrent tasks. 29 raise_on_error: If set to True raises an exception on a first encountered error, 30 otherwises returns a tuple which contains a list of failed nodes and a list of 31 skipped nodes. 32 """ 33 34 def __init__( 35 self, 36 dag: DAG[H], 37 fn: t.Callable[[H], None], 38 tasks_num: int, 39 raise_on_error: bool, 40 ): 41 self.dag = dag 42 self.fn = fn 43 self.tasks_num = tasks_num 44 self.raise_on_error = raise_on_error 45 46 self._init_state() 47 48 def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]: 49 """Runs the executor. 50 51 Raises: 52 NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot. 53 54 Returns: 55 A pair which contains a list of node errors and a list of skipped nodes. 
56 """ 57 if self._finished_future.done(): 58 self._init_state() 59 60 with ThreadPoolExecutor(max_workers=self.tasks_num) as pool: 61 with self._unprocessed_nodes_lock: 62 self._submit_next_nodes(pool) 63 self._finished_future.result() 64 return self._node_errors, self._skipped_nodes 65 66 def _process_node(self, node: H, executor: Executor) -> None: 67 try: 68 self.fn(node) 69 70 with self._unprocessed_nodes_lock: 71 self._unprocessed_nodes_num -= 1 72 self._submit_next_nodes(executor, node) 73 except Exception as ex: 74 error = NodeExecutionFailedError(node) 75 error.__cause__ = ex 76 77 if self.raise_on_error: 78 self._finished_future.set_exception(error) 79 return 80 81 with self._unprocessed_nodes_lock: 82 self._unprocessed_nodes_num -= 1 83 self._node_errors.append(error) 84 self._skip_next_nodes(node) 85 86 def _submit_next_nodes(self, executor: Executor, processed_node: t.Optional[H] = None) -> None: 87 if not self._unprocessed_nodes_num: 88 self._finished_future.set_result(None) 89 return 90 91 submitted_nodes = [] 92 for next_node, deps in self._unprocessed_nodes.items(): 93 if processed_node: 94 deps.discard(processed_node) 95 if not deps: 96 submitted_nodes.append(next_node) 97 98 for submitted_node in submitted_nodes: 99 self._unprocessed_nodes.pop(submitted_node) 100 executor.submit(self._process_node, submitted_node, executor) 101 102 def _skip_next_nodes(self, parent: H) -> None: 103 if not self._unprocessed_nodes_num: 104 self._finished_future.set_result(None) 105 return 106 107 skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps] 108 109 self._skipped_nodes.extend(skipped_nodes) 110 111 for skipped_node in skipped_nodes: 112 self._unprocessed_nodes_num -= 1 113 self._unprocessed_nodes.pop(skipped_node) 114 self._skip_next_nodes(skipped_node) 115 116 def _init_state(self) -> None: 117 self._unprocessed_nodes = self.dag.graph 118 self._unprocessed_nodes_num = len(self._unprocessed_nodes) 119 
self._unprocessed_nodes_lock = Lock() 120 self._finished_future = Future() # type: ignore 121 122 self._node_errors: t.List[NodeExecutionFailedError[H]] = [] 123 self._skipped_nodes: t.List[H] = []
Concurrently traverses the given DAG in topological order while applying a function to each node.
If `raise_on_error` is set to False, maintains a state of execution errors as well as of skipped nodes.
Arguments:
- dag: The target DAG.
- fn: The function that will be applied concurrently to each snapshot.
- tasks_num: The number of concurrent tasks.
- raise_on_error: If set to True raises an exception on a first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
48 def run(self) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]: 49 """Runs the executor. 50 51 Raises: 52 NodeExecutionFailedError if `raise_on_error` was set to True and execution fails for any snapshot. 53 54 Returns: 55 A pair which contains a list of node errors and a list of skipped nodes. 56 """ 57 if self._finished_future.done(): 58 self._init_state() 59 60 with ThreadPoolExecutor(max_workers=self.tasks_num) as pool: 61 with self._unprocessed_nodes_lock: 62 self._submit_next_nodes(pool) 63 self._finished_future.result() 64 return self._node_errors, self._skipped_nodes
Runs the executor.
Raises:
- NodeExecutionFailedError: if `raise_on_error` was set to True and execution fails for any snapshot.
Returns:
A pair which contains a list of node errors and a list of skipped nodes.
126def concurrent_apply_to_snapshots( 127 snapshots: t.Iterable[S], 128 fn: t.Callable[[S], None], 129 tasks_num: int, 130 reverse_order: bool = False, 131 raise_on_error: bool = True, 132) -> t.Tuple[t.List[NodeExecutionFailedError[SnapshotId]], t.List[SnapshotId]]: 133 """Applies a function to the given collection of snapshots concurrently while 134 preserving the topological order between snapshots. 135 136 Args: 137 snapshots: Target snapshots. 138 fn: The function that will be applied concurrently to each snapshot. 139 tasks_num: The number of concurrent tasks. 140 reverse_order: Whether the order should be reversed. Default: False. 141 raise_on_error: If set to True raises an exception on a first encountered error, 142 otherwises returns a tuple which contains a list of failed nodes and a list of 143 skipped nodes. 144 145 Raises: 146 NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot. 147 148 Returns: 149 A pair which contains a list of errors and a list of skipped snapshot IDs. 150 """ 151 snapshots_by_id = {s.snapshot_id: s for s in snapshots} 152 153 dag: DAG[SnapshotId] = DAG[SnapshotId]() 154 for snapshot in snapshots: 155 dag.add( 156 snapshot.snapshot_id, 157 [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id], 158 ) 159 160 return concurrent_apply_to_dag( 161 dag if not reverse_order else dag.reversed, 162 lambda s_id: fn(snapshots_by_id[s_id]), 163 tasks_num, 164 raise_on_error=raise_on_error, 165 )
Applies a function to the given collection of snapshots concurrently while preserving the topological order between snapshots.
Arguments:
- snapshots: Target snapshots.
- fn: The function that will be applied concurrently to each snapshot.
- tasks_num: The number of concurrent tasks.
- reverse_order: Whether the order should be reversed. Default: False.
- raise_on_error: If set to True raises an exception on a first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
- NodeExecutionFailedError: if `raise_on_error` is set to True and execution fails for any snapshot.
Returns:
A pair which contains a list of errors and a list of skipped snapshot IDs.
168def concurrent_apply_to_dag( 169 dag: DAG[H], 170 fn: t.Callable[[H], None], 171 tasks_num: int, 172 raise_on_error: bool = True, 173) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]: 174 """Applies a function to the given DAG concurrently while preserving the topological 175 order between snapshots. 176 177 Args: 178 dag: The target DAG. 179 fn: The function that will be applied concurrently to each snapshot. 180 tasks_num: The number of concurrent tasks. 181 raise_on_error: If set to True raises an exception on a first encountered error, 182 otherwises returns a tuple which contains a list of failed nodes and a list of 183 skipped nodes. 184 185 Raises: 186 NodeExecutionFailedError if `raise_on_error` is set to True and execution fails for any snapshot. 187 188 Returns: 189 A pair which contains a list of node errors and a list of skipped nodes. 190 """ 191 if tasks_num <= 0: 192 raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}") 193 194 if tasks_num == 1: 195 return sequential_apply_to_dag(dag, fn, raise_on_error) 196 197 return ConcurrentDAGExecutor( 198 dag, 199 fn, 200 tasks_num, 201 raise_on_error, 202 ).run()
Applies a function to the given DAG concurrently while preserving the topological order between snapshots.
Arguments:
- dag: The target DAG.
- fn: The function that will be applied concurrently to each snapshot.
- tasks_num: The number of concurrent tasks.
- raise_on_error: If set to True raises an exception on a first encountered error, otherwise returns a tuple which contains a list of failed nodes and a list of skipped nodes.
Raises:
- NodeExecutionFailedError: if `raise_on_error` is set to True and execution fails for any snapshot.
Returns:
A pair which contains a list of node errors and a list of skipped nodes.
205def sequential_apply_to_dag( 206 dag: DAG[H], 207 fn: t.Callable[[H], None], 208 raise_on_error: bool = True, 209) -> t.Tuple[t.List[NodeExecutionFailedError[H]], t.List[H]]: 210 dependencies = dag.graph 211 212 node_errors: t.List[NodeExecutionFailedError[H]] = [] 213 skipped_nodes: t.List[H] = [] 214 215 failed_or_skipped_nodes: t.Set[H] = set() 216 217 for node in dag.sorted(): 218 if not failed_or_skipped_nodes.isdisjoint(dependencies[node]): 219 skipped_nodes.append(node) 220 failed_or_skipped_nodes.add(node) 221 continue 222 223 try: 224 fn(node) 225 except Exception as ex: 226 if raise_on_error: 227 raise NodeExecutionFailedError(node) from ex 228 229 error = NodeExecutionFailedError(node) 230 error.__cause__ = ex 231 232 node_errors.append(error) 233 failed_or_skipped_nodes.add(node) 234 235 return node_errors, skipped_nodes