sqlmesh.utils.connection_pool
import abc
import logging
import typing as t
from threading import Lock, get_ident

logger = logging.getLogger(__name__)


class ConnectionPool(abc.ABC):
    """Abstract interface for a pool that caches connections and cursors.

    Implementations create connections lazily through a factory, hand out a cached
    cursor for the current connection, and record whether a transaction was started
    through the pool.
    """

    @abc.abstractmethod
    def get_cursor(self) -> t.Any:
        """Returns cached cursor instance.

        Automatically creates a new instance if one is not available.

        Returns:
            A cursor instance.
        """

    @abc.abstractmethod
    def get(self) -> t.Any:
        """Returns cached connection instance.

        Automatically opens a new connection if one is not available.

        Returns:
            A connection instance.
        """

    @abc.abstractmethod
    def begin(self) -> None:
        """Starts a new transaction."""

    @abc.abstractmethod
    def commit(self) -> None:
        """Commits the current transaction."""

    @abc.abstractmethod
    def rollback(self) -> None:
        """Rolls back the current transaction."""

    @property
    @abc.abstractmethod
    def is_transaction_active(self) -> bool:
        """Returns True if there is an active transaction and False otherwise."""

    @abc.abstractmethod
    def close_cursor(self) -> None:
        """Closes the current cursor instance if one exists."""

    @abc.abstractmethod
    def close(self) -> None:
        """Closes the current connection instance if one exists.

        Note: any cached cursor is discarded together with the connection and must not
        be used afterwards.
        """

    @abc.abstractmethod
    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Closes all cached cursors and connections.

        Args:
            exclude_calling_thread: If set to True excludes cursors and connections associated
                with the calling thread.
        """


class _TransactionManagementMixin(ConnectionPool):
    """Shared begin/commit/rollback logic for the pools in this module.

    Each operation prefers the method on the cached cursor when the cursor exposes it
    and otherwise falls back to the connection. ``begin`` is optional on the
    connection as well, while ``commit``/``rollback`` are called on the connection
    unconditionally when the cursor lacks them.
    """

    def _do_begin(self) -> None:
        cursor = self.get_cursor()
        if hasattr(cursor, "begin"):
            cursor.begin()
        else:
            conn = self.get()
            # Connections without an explicit ``begin`` are left untouched.
            if hasattr(conn, "begin"):
                conn.begin()

    def _do_commit(self) -> None:
        cursor = self.get_cursor()
        if hasattr(cursor, "commit"):
            cursor.commit()
        else:
            self.get().commit()

    def _do_rollback(self) -> None:
        cursor = self.get_cursor()
        if hasattr(cursor, "rollback"):
            cursor.rollback()
        else:
            self.get().rollback()


class ThreadLocalConnectionPool(_TransactionManagementMixin):
    """Connection pool that caches a separate connection and cursor for each thread.

    Cached objects and the transaction flag are keyed by ``threading.get_ident()`` of
    the calling thread.

    Args:
        connection_factory: Zero-argument callable that opens a new connection.
    """

    def __init__(self, connection_factory: t.Callable[[], t.Any]):
        self._connection_factory = connection_factory
        self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
        self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
        self._thread_transactions: t.Set[t.Hashable] = set()
        self._thread_connections_lock = Lock()
        self._thread_cursors_lock = Lock()
        self._thread_transactions_lock = Lock()

    def get_cursor(self) -> t.Any:
        """Returns the calling thread's cursor, creating it (and its connection) on demand."""
        thread_id = get_ident()
        # Lock order: cursors lock, then connections lock (taken inside get()).
        # close() and close_all() acquire the two locks in the same order.
        with self._thread_cursors_lock:
            if thread_id not in self._thread_cursors:
                self._thread_cursors[thread_id] = self.get().cursor()
            return self._thread_cursors[thread_id]

    def get(self) -> t.Any:
        """Returns the calling thread's connection, opening it via the factory on demand."""
        thread_id = get_ident()
        with self._thread_connections_lock:
            if thread_id not in self._thread_connections:
                self._thread_connections[thread_id] = self._connection_factory()
            return self._thread_connections[thread_id]

    def begin(self) -> None:
        """Starts a transaction and marks it active for the calling thread."""
        self._do_begin()
        with self._thread_transactions_lock:
            self._thread_transactions.add(get_ident())

    def commit(self) -> None:
        """Commits and clears the calling thread's active-transaction flag."""
        self._do_commit()
        self._discard_transaction(get_ident())

    def rollback(self) -> None:
        """Rolls back and clears the calling thread's active-transaction flag."""
        self._do_rollback()
        self._discard_transaction(get_ident())

    @property
    def is_transaction_active(self) -> bool:
        """Returns True if the calling thread started a transaction through this pool."""
        with self._thread_transactions_lock:
            return get_ident() in self._thread_transactions

    def close_cursor(self) -> None:
        """Closes and forgets the calling thread's cursor if one exists."""
        thread_id = get_ident()
        with self._thread_cursors_lock:
            if thread_id in self._thread_cursors:
                _try_close(self._thread_cursors[thread_id], "cursor")
                self._thread_cursors.pop(thread_id)

    def close(self) -> None:
        """Closes the calling thread's connection if one exists.

        The thread's cached cursor is dropped (not closed separately) and its
        transaction flag is cleared.
        """
        thread_id = get_ident()
        with self._thread_cursors_lock, self._thread_connections_lock:
            if thread_id in self._thread_connections:
                _try_close(self._thread_connections[thread_id], "connection")
                self._thread_connections.pop(thread_id)
                self._thread_cursors.pop(thread_id, None)
                self._discard_transaction(thread_id)

    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Closes every cached connection, optionally sparing the calling thread's.

        Args:
            exclude_calling_thread: If True, the calling thread's connection and cursor
                are left open.
        """
        calling_thread_id = get_ident()
        with self._thread_cursors_lock, self._thread_connections_lock:
            # Iterate over a snapshot because entries are removed inside the loop.
            for thread_id, connection in list(self._thread_connections.items()):
                if not exclude_calling_thread or thread_id != calling_thread_id:
                    # NOTE: the access to the connection instance itself is not thread-safe here.
                    _try_close(connection, "connection")
                    self._thread_connections.pop(thread_id)
                    self._thread_cursors.pop(thread_id, None)
                    self._discard_transaction(thread_id)

    def _discard_transaction(self, thread_id: t.Hashable) -> None:
        with self._thread_transactions_lock:
            self._thread_transactions.discard(thread_id)


class SingletonConnectionPool(_TransactionManagementMixin):
    """Connection pool that caches a single connection and cursor shared by all callers.

    This class takes no locks of its own; presumably it is meant for single-threaded
    use (``create_connection_pool`` returns it when ``multithreaded`` is False).

    Args:
        connection_factory: Zero-argument callable that opens a new connection.
    """

    def __init__(self, connection_factory: t.Callable[[], t.Any]):
        self._connection_factory = connection_factory
        self._connection: t.Optional[t.Any] = None
        self._cursor: t.Optional[t.Any] = None
        self._is_transaction_active: bool = False

    def get_cursor(self) -> t.Any:
        """Returns the cached cursor, creating it from the cached connection on demand."""
        # Compare with None instead of relying on truthiness: driver objects that define
        # __bool__/__len__ would otherwise be recreated on every call.
        if self._cursor is None:
            self._cursor = self.get().cursor()
        return self._cursor

    def get(self) -> t.Any:
        """Returns the cached connection, opening it via the factory on demand."""
        if self._connection is None:
            self._connection = self._connection_factory()
        return self._connection

    def begin(self) -> None:
        """Starts a transaction and marks it active."""
        self._do_begin()
        self._is_transaction_active = True

    def commit(self) -> None:
        """Commits and clears the active-transaction flag."""
        self._do_commit()
        self._is_transaction_active = False

    def rollback(self) -> None:
        """Rolls back and clears the active-transaction flag."""
        self._do_rollback()
        self._is_transaction_active = False

    @property
    def is_transaction_active(self) -> bool:
        """Returns True if a transaction was started through this pool and not yet ended."""
        return self._is_transaction_active

    def close_cursor(self) -> None:
        """Closes and forgets the cached cursor if one exists."""
        _try_close(self._cursor, "cursor")
        self._cursor = None

    def close(self) -> None:
        """Closes the cached connection; the cached cursor is dropped (not closed separately)."""
        _try_close(self._connection, "connection")
        self._connection = None
        self._cursor = None
        self._is_transaction_active = False

    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Closes the shared connection unless ``exclude_calling_thread`` is True."""
        if not exclude_calling_thread:
            self.close()


def create_connection_pool(
    connection_factory: t.Callable[[], t.Any], multithreaded: bool
) -> ConnectionPool:
    """Creates a connection pool for the requested threading mode.

    Args:
        connection_factory: Zero-argument callable that opens a new connection.
        multithreaded: If True, returns a ThreadLocalConnectionPool (one connection per
            thread); otherwise returns a SingletonConnectionPool (one shared connection).

    Returns:
        The new connection pool.
    """
    return (
        ThreadLocalConnectionPool(connection_factory)
        if multithreaded
        else SingletonConnectionPool(connection_factory)
    )


def _try_close(closeable: t.Any, kind: str) -> None:
    """Closes ``closeable`` on a best-effort basis.

    ``None`` is ignored. Any ``Exception`` raised by ``close()`` is logged, with
    ``kind`` in the message, instead of being propagated.
    """
    if closeable is None:
        return
    try:
        closeable.close()
    except Exception:
        logger.exception("Failed to close %s", kind)
class ConnectionPool(abc.ABC):
    """Abstract interface for a pool that caches connections and cursors.

    Implementations create connections lazily through a factory, hand out a cached
    cursor for the current connection, and record whether a transaction was started
    through the pool.
    """

    @abc.abstractmethod
    def get_cursor(self) -> t.Any:
        """Returns cached cursor instance.

        Automatically creates a new instance if one is not available.

        Returns:
            A cursor instance.
        """

    @abc.abstractmethod
    def get(self) -> t.Any:
        """Returns cached connection instance.

        Automatically opens a new connection if one is not available.

        Returns:
            A connection instance.
        """

    @abc.abstractmethod
    def begin(self) -> None:
        """Starts a new transaction."""

    @abc.abstractmethod
    def commit(self) -> None:
        """Commits the current transaction."""

    @abc.abstractmethod
    def rollback(self) -> None:
        """Rolls back the current transaction."""

    @property
    @abc.abstractmethod
    def is_transaction_active(self) -> bool:
        """Returns True if there is an active transaction and False otherwise."""

    @abc.abstractmethod
    def close_cursor(self) -> None:
        """Closes the current cursor instance if one exists."""

    @abc.abstractmethod
    def close(self) -> None:
        """Closes the current connection instance if one exists.

        Note: any cached cursor is discarded together with the connection and must not
        be used afterwards.
        """

    @abc.abstractmethod
    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Closes all cached cursors and connections.

        Args:
            exclude_calling_thread: If set to True excludes cursors and connections associated
                with the calling thread.
        """
Abstract interface for a pool that caches connections and cursors and tracks whether a transaction is active.
@abc.abstractmethod
def get_cursor(self) -> t.Any:
    """Returns the cursor held by the pool.

    When no cursor is cached yet, implementations create one on demand.

    Returns:
        The cached (or freshly created) cursor instance.
    """
Returns cached cursor instance.
Automatically creates a new instance if one is not available.
Returns:
A cursor instance.
@abc.abstractmethod
def get(self) -> t.Any:
    """Returns the connection held by the pool.

    When no connection is cached yet, implementations open one on demand.

    Returns:
        The cached (or newly opened) connection instance.
    """
Returns cached connection instance.
Automatically opens a new connection if one is not available.
Returns:
A connection instance.
@abc.abstractmethod
def close_cursor(self) -> None:
    """Closes the cached cursor, if one is currently held."""
Closes the current cursor instance if one exists.
@abc.abstractmethod
def close(self) -> None:
    """Closes the cached connection, if one is currently held.

    Note: a cached cursor, when present, is released along with the connection.
    """
Closes the current connection instance if one exists.
Note: if there is a cursor instance available it will be closed as well.
@abc.abstractmethod
def close_all(self, exclude_calling_thread: bool = False) -> None:
    """Closes every cached cursor and connection.

    Args:
        exclude_calling_thread: When True, the cursor and connection belonging to the
            calling thread are left open.
    """
Closes all cached cursors and connections.
Arguments:
- exclude_calling_thread: If set to True excludes cursors and connections associated with the calling thread.
class ThreadLocalConnectionPool(_TransactionManagementMixin):
    """Connection pool that caches a separate connection and cursor for each thread.

    Cached objects and the transaction flag are keyed by ``threading.get_ident()`` of
    the calling thread.

    Args:
        connection_factory: Zero-argument callable that opens a new connection.
    """

    def __init__(self, connection_factory: t.Callable[[], t.Any]):
        self._connection_factory = connection_factory
        self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
        self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
        self._thread_transactions: t.Set[t.Hashable] = set()
        self._thread_connections_lock = Lock()
        self._thread_cursors_lock = Lock()
        self._thread_transactions_lock = Lock()

    def get_cursor(self) -> t.Any:
        """Returns the calling thread's cursor, creating it (and its connection) on demand."""
        thread_id = get_ident()
        # Lock order: cursors lock, then connections lock (taken inside get()).
        # close() and close_all() acquire the two locks in the same order.
        with self._thread_cursors_lock:
            if thread_id not in self._thread_cursors:
                self._thread_cursors[thread_id] = self.get().cursor()
            return self._thread_cursors[thread_id]

    def get(self) -> t.Any:
        """Returns the calling thread's connection, opening it via the factory on demand."""
        thread_id = get_ident()
        with self._thread_connections_lock:
            if thread_id not in self._thread_connections:
                self._thread_connections[thread_id] = self._connection_factory()
            return self._thread_connections[thread_id]

    def begin(self) -> None:
        """Starts a transaction and marks it active for the calling thread."""
        self._do_begin()
        with self._thread_transactions_lock:
            self._thread_transactions.add(get_ident())

    def commit(self) -> None:
        """Commits and clears the calling thread's active-transaction flag."""
        self._do_commit()
        self._discard_transaction(get_ident())

    def rollback(self) -> None:
        """Rolls back and clears the calling thread's active-transaction flag."""
        self._do_rollback()
        self._discard_transaction(get_ident())

    @property
    def is_transaction_active(self) -> bool:
        """Returns True if the calling thread started a transaction through this pool."""
        with self._thread_transactions_lock:
            return get_ident() in self._thread_transactions

    def close_cursor(self) -> None:
        """Closes and forgets the calling thread's cursor if one exists."""
        thread_id = get_ident()
        with self._thread_cursors_lock:
            if thread_id in self._thread_cursors:
                _try_close(self._thread_cursors[thread_id], "cursor")
                self._thread_cursors.pop(thread_id)

    def close(self) -> None:
        """Closes the calling thread's connection if one exists.

        The thread's cached cursor is dropped (not closed separately) and its
        transaction flag is cleared.
        """
        thread_id = get_ident()
        with self._thread_cursors_lock, self._thread_connections_lock:
            if thread_id in self._thread_connections:
                _try_close(self._thread_connections[thread_id], "connection")
                self._thread_connections.pop(thread_id)
                self._thread_cursors.pop(thread_id, None)
                self._discard_transaction(thread_id)

    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Closes every cached connection, optionally sparing the calling thread's.

        Args:
            exclude_calling_thread: If True, the calling thread's connection and cursor
                are left open.
        """
        calling_thread_id = get_ident()
        with self._thread_cursors_lock, self._thread_connections_lock:
            # Iterate over a snapshot because entries are removed inside the loop.
            for thread_id, connection in list(self._thread_connections.items()):
                if not exclude_calling_thread or thread_id != calling_thread_id:
                    # NOTE: the access to the connection instance itself is not thread-safe here.
                    _try_close(connection, "connection")
                    self._thread_connections.pop(thread_id)
                    self._thread_cursors.pop(thread_id, None)
                    self._discard_transaction(thread_id)

    def _discard_transaction(self, thread_id: t.Hashable) -> None:
        """Clears the active-transaction flag for ``thread_id`` (no-op if not set)."""
        with self._thread_transactions_lock:
            self._thread_transactions.discard(thread_id)
Connection pool that caches a separate connection and cursor for each thread, keyed by the thread identifier.
def __init__(self, connection_factory: t.Callable[[], t.Any]):
    # Factory used to open a connection the first time a thread asks for one.
    self._connection_factory = connection_factory

    # Per-thread state, keyed by threading.get_ident() of the owning thread.
    self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
    self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
    self._thread_transactions: t.Set[t.Hashable] = set()

    # One lock guarding each of the three containers above, in the same order.
    (
        self._thread_connections_lock,
        self._thread_cursors_lock,
        self._thread_transactions_lock,
    ) = (Lock(), Lock(), Lock())
def get_cursor(self) -> t.Any:
    """Returns the calling thread's cursor, creating one from its connection on first use."""
    key = get_ident()
    with self._thread_cursors_lock:
        cursors = self._thread_cursors
        if key in cursors:
            return cursors[key]
        fresh_cursor = self.get().cursor()
        cursors[key] = fresh_cursor
        return fresh_cursor
Returns cached cursor instance.
Automatically creates a new instance if one is not available.
Returns:
A cursor instance.
def get(self) -> t.Any:
    """Returns the calling thread's connection, opening one via the factory on first use."""
    key = get_ident()
    with self._thread_connections_lock:
        connections = self._thread_connections
        if key in connections:
            return connections[key]
        fresh_connection = self._connection_factory()
        connections[key] = fresh_connection
        return fresh_connection
Returns cached connection instance.
Automatically opens a new connection if one is not available.
Returns:
A connection instance.
def begin(self) -> None:
    """Starts a transaction and records it as active for the calling thread."""
    self._do_begin()
    caller = get_ident()
    with self._thread_transactions_lock:
        self._thread_transactions.add(caller)
Starts a new transaction.
def close_cursor(self) -> None:
    """Closes and forgets the calling thread's cursor; does nothing if it has none."""
    key = get_ident()
    with self._thread_cursors_lock:
        if key not in self._thread_cursors:
            return
        _try_close(self._thread_cursors[key], "cursor")
        del self._thread_cursors[key]
Closes the current cursor instance if one exists.
def close(self) -> None:
    """Closes the calling thread's connection, if it has one.

    The thread's cached cursor is dropped along with the connection (it is not closed
    separately) and its transaction flag is cleared.
    """
    key = get_ident()
    with self._thread_cursors_lock, self._thread_connections_lock:
        if key not in self._thread_connections:
            return
        _try_close(self._thread_connections[key], "connection")
        del self._thread_connections[key]
        self._thread_cursors.pop(key, None)
        self._discard_transaction(key)
Closes the current connection instance if one exists.
Note: if there is a cursor instance available it will be closed as well.
def close_all(self, exclude_calling_thread: bool = False) -> None:
    """Closes the cached connection of every thread.

    Args:
        exclude_calling_thread: When True, the calling thread's connection and cursor
            are kept open.
    """
    caller = get_ident()
    with self._thread_cursors_lock, self._thread_connections_lock:
        # Snapshot the entries first: the loop removes them from the dict.
        snapshot = self._thread_connections.copy()
        for key, conn in snapshot.items():
            if exclude_calling_thread and key == caller:
                continue
            # NOTE: the access to the connection instance itself is not thread-safe here.
            _try_close(conn, "connection")
            self._thread_connections.pop(key)
            self._thread_cursors.pop(key, None)
            self._discard_transaction(key)
Closes all cached cursors and connections.
Arguments:
- exclude_calling_thread: If set to True excludes cursors and connections associated with the calling thread.
class SingletonConnectionPool(_TransactionManagementMixin):
    """Connection pool that caches a single connection and cursor shared by all callers.

    This class takes no locks of its own; presumably it is meant for single-threaded
    use (``create_connection_pool`` returns it when ``multithreaded`` is False).

    Args:
        connection_factory: Zero-argument callable that opens a new connection.
    """

    def __init__(self, connection_factory: t.Callable[[], t.Any]):
        self._connection_factory = connection_factory
        self._connection: t.Optional[t.Any] = None
        self._cursor: t.Optional[t.Any] = None
        self._is_transaction_active: bool = False

    def get_cursor(self) -> t.Any:
        """Returns the cached cursor, creating it from the cached connection on demand."""
        # Compare with None instead of relying on truthiness: driver objects that define
        # __bool__/__len__ would otherwise be recreated on every call.
        if self._cursor is None:
            self._cursor = self.get().cursor()
        return self._cursor

    def get(self) -> t.Any:
        """Returns the cached connection, opening it via the factory on demand."""
        if self._connection is None:
            self._connection = self._connection_factory()
        return self._connection

    def begin(self) -> None:
        """Starts a transaction and marks it active."""
        self._do_begin()
        self._is_transaction_active = True

    def commit(self) -> None:
        """Commits and clears the active-transaction flag."""
        self._do_commit()
        self._is_transaction_active = False

    def rollback(self) -> None:
        """Rolls back and clears the active-transaction flag."""
        self._do_rollback()
        self._is_transaction_active = False

    @property
    def is_transaction_active(self) -> bool:
        """Returns True if a transaction was started through this pool and not yet ended."""
        return self._is_transaction_active

    def close_cursor(self) -> None:
        """Closes and forgets the cached cursor if one exists."""
        _try_close(self._cursor, "cursor")
        self._cursor = None

    def close(self) -> None:
        """Closes the cached connection; the cached cursor is dropped (not closed separately)."""
        _try_close(self._connection, "connection")
        self._connection = None
        self._cursor = None
        self._is_transaction_active = False

    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Closes the shared connection unless ``exclude_calling_thread`` is True."""
        if not exclude_calling_thread:
            self.close()
Connection pool that caches a single connection and cursor shared by all callers.
def get_cursor(self) -> t.Any:
    """Returns the cached cursor, creating it from the cached connection on demand.

    Returns:
        The cursor instance.
    """
    # Compare with None instead of relying on truthiness: a cursor object that defines
    # __bool__/__len__ as falsy would otherwise be recreated on every call.
    if self._cursor is None:
        self._cursor = self.get().cursor()
    return self._cursor
Returns cached cursor instance.
Automatically creates a new instance if one is not available.
Returns:
A cursor instance.
def get(self) -> t.Any:
    """Returns the cached connection, opening it via the factory on demand.

    Returns:
        The connection instance.
    """
    # Compare with None instead of relying on truthiness: a connection object that
    # defines __bool__/__len__ as falsy would otherwise be reopened (and leaked) on
    # every call.
    if self._connection is None:
        self._connection = self._connection_factory()
    return self._connection
Returns cached connection instance.
Automatically opens a new connection if one is not available.
Returns:
A connection instance.
def close(self) -> None:
    """Closes the shared connection and resets the pool's cached state.

    The cached cursor is dropped (not closed separately) and the transaction flag is
    cleared.
    """
    _try_close(self._connection, "connection")
    self._connection = self._cursor = None
    self._is_transaction_active = False
Closes the current connection instance if one exists.
Note: if there is a cursor instance available it will be closed as well.
def close_all(self, exclude_calling_thread: bool = False) -> None:
    """Closes the shared connection unless the caller asked to keep its own.

    Args:
        exclude_calling_thread: When True, nothing is closed, because the single shared
            connection is also the calling thread's connection.
    """
    if exclude_calling_thread:
        return
    self.close()
Closes all cached cursors and connections.
Arguments:
- exclude_calling_thread: If set to True excludes cursors and connections associated with the calling thread.