Edit on GitHub

sqlmesh.utils.connection_pool

  1import abc
  2import logging
  3import typing as t
  4from threading import Lock, get_ident
  5
  6logger = logging.getLogger(__name__)
  7
  8
  9class ConnectionPool(abc.ABC):
 10    @abc.abstractmethod
 11    def get_cursor(self) -> t.Any:
 12        """Returns cached cursor instance.
 13
 14        Automatically creates a new instance if one is not available.
 15
 16        Returns:
 17            A cursor instance.
 18        """
 19
 20    @abc.abstractmethod
 21    def get(self) -> t.Any:
 22        """Returns cached connection instance.
 23
 24        Automatically opens a new connection if one is not available.
 25
 26        Returns:
 27            A connection instance.
 28        """
 29
 30    @abc.abstractmethod
 31    def begin(self) -> None:
 32        """Starts a new transaction."""
 33
 34    @abc.abstractmethod
 35    def commit(self) -> None:
 36        """Commits the current transaction."""
 37
 38    @abc.abstractmethod
 39    def rollback(self) -> None:
 40        """Rolls back the current transaction."""
 41
 42    @property
 43    @abc.abstractmethod
 44    def is_transaction_active(self) -> bool:
 45        """Returns True if there is an active transaction and False otherwise."""
 46
 47    @abc.abstractmethod
 48    def close_cursor(self) -> None:
 49        """Closes the current cursor instance if exists."""
 50
 51    @abc.abstractmethod
 52    def close(self) -> None:
 53        """Closes the current connection instance if exists.
 54
 55        Note: if there is a cursor instance available it will be closed as well.
 56        """
 57
 58    @abc.abstractmethod
 59    def close_all(self, exclude_calling_thread: bool = False) -> None:
 60        """Closes all cached cursors and connections.
 61
 62        Args:
 63            exclude_calling_thread: If set to True excludes cursors and connections associated
 64                with the calling thread.
 65        """
 66
 67
 68class _TransactionManagementMixin(ConnectionPool):
 69    def _do_begin(self) -> None:
 70        cursor = self.get_cursor()
 71        if hasattr(cursor, "begin"):
 72            cursor.begin()
 73        else:
 74            conn = self.get()
 75            if hasattr(conn, "begin"):
 76                conn.begin()
 77
 78    def _do_commit(self) -> None:
 79        cursor = self.get_cursor()
 80        if hasattr(cursor, "commit"):
 81            cursor.commit()
 82        else:
 83            self.get().commit()
 84
 85    def _do_rollback(self) -> None:
 86        cursor = self.get_cursor()
 87        if hasattr(cursor, "rollback"):
 88            cursor.rollback()
 89        else:
 90            self.get().rollback()
 91
 92
 93class ThreadLocalConnectionPool(_TransactionManagementMixin):
 94    def __init__(self, connection_factory: t.Callable[[], t.Any]):
 95        self._connection_factory = connection_factory
 96        self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
 97        self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
 98        self._thread_transactions: t.Set[t.Hashable] = set()
 99        self._thread_connections_lock = Lock()
100        self._thread_cursors_lock = Lock()
101        self._thread_transactions_lock = Lock()
102
103    def get_cursor(self) -> t.Any:
104        thread_id = get_ident()
105        with self._thread_cursors_lock:
106            if thread_id not in self._thread_cursors:
107                self._thread_cursors[thread_id] = self.get().cursor()
108            return self._thread_cursors[thread_id]
109
110    def get(self) -> t.Any:
111        thread_id = get_ident()
112        with self._thread_connections_lock:
113            if thread_id not in self._thread_connections:
114                self._thread_connections[thread_id] = self._connection_factory()
115            return self._thread_connections[thread_id]
116
117    def begin(self) -> None:
118        self._do_begin()
119        with self._thread_transactions_lock:
120            self._thread_transactions.add(get_ident())
121
122    def commit(self) -> None:
123        self._do_commit()
124        self._discard_transaction(get_ident())
125
126    def rollback(self) -> None:
127        self._do_rollback()
128        self._discard_transaction(get_ident())
129
130    @property
131    def is_transaction_active(self) -> bool:
132        with self._thread_transactions_lock:
133            return get_ident() in self._thread_transactions
134
135    def close_cursor(self) -> None:
136        thread_id = get_ident()
137        with self._thread_cursors_lock:
138            if thread_id in self._thread_cursors:
139                _try_close(self._thread_cursors[thread_id], "cursor")
140                self._thread_cursors.pop(thread_id)
141
142    def close(self) -> None:
143        thread_id = get_ident()
144        with self._thread_cursors_lock, self._thread_connections_lock:
145            if thread_id in self._thread_connections:
146                _try_close(self._thread_connections[thread_id], "connection")
147                self._thread_connections.pop(thread_id)
148                self._thread_cursors.pop(thread_id, None)
149                self._discard_transaction(thread_id)
150
151    def close_all(self, exclude_calling_thread: bool = False) -> None:
152        calling_thread_id = get_ident()
153        with self._thread_cursors_lock, self._thread_connections_lock:
154            for thread_id, connection in self._thread_connections.copy().items():
155                if not exclude_calling_thread or thread_id != calling_thread_id:
156                    # NOTE: the access to the connection instance itself is not thread-safe here.
157                    _try_close(connection, "connection")
158                    self._thread_connections.pop(thread_id)
159                    self._thread_cursors.pop(thread_id, None)
160                    self._discard_transaction(thread_id)
161
162    def _discard_transaction(self, thread_id: t.Hashable) -> None:
163        with self._thread_transactions_lock:
164            self._thread_transactions.discard(thread_id)
165
166
class SingletonConnectionPool(_TransactionManagementMixin):
    """Connection pool that shares one connection and one cursor among all callers.

    No locking is performed. ``create_connection_pool`` selects this pool when
    ``multithreaded`` is False.
    """

    def __init__(self, connection_factory: t.Callable[[], t.Any]):
        """
        Args:
            connection_factory: Zero-argument callable used to open a new connection.
        """
        self._connection_factory = connection_factory
        self._connection: t.Optional[t.Any] = None
        self._cursor: t.Optional[t.Any] = None
        self._is_transaction_active: bool = False

    def get_cursor(self) -> t.Any:
        """Return the shared cursor, creating it from the shared connection on first use."""
        # Compare with None rather than testing truthiness. A driver cursor that
        # defines __bool__/__len__ may be falsy and would otherwise be recreated
        # (and the previous one leaked) on every call.
        if self._cursor is None:
            self._cursor = self.get().cursor()
        return self._cursor

    def get(self) -> t.Any:
        """Return the shared connection, opening it through the factory on first use."""
        # Same None check as get_cursor, so a merely-falsy connection is never reopened.
        if self._connection is None:
            self._connection = self._connection_factory()
        return self._connection

    def begin(self) -> None:
        """Start a transaction and mark it as active."""
        self._do_begin()
        self._is_transaction_active = True

    def commit(self) -> None:
        """Commit the current transaction and clear the active flag."""
        self._do_commit()
        self._is_transaction_active = False

    def rollback(self) -> None:
        """Roll back the current transaction and clear the active flag."""
        self._do_rollback()
        self._is_transaction_active = False

    @property
    def is_transaction_active(self) -> bool:
        """True between a successful begin() and the following commit()/rollback()/close()."""
        return self._is_transaction_active

    def close_cursor(self) -> None:
        """Close the shared cursor (no-op when there is none) and forget it."""
        _try_close(self._cursor, "cursor")
        self._cursor = None

    def close(self) -> None:
        """Close the shared connection and reset the cursor and transaction state."""
        _try_close(self._connection, "connection")
        self._connection = None
        # NOTE(review): the cursor is dropped without an explicit close(); this
        # relies on the driver invalidating cursors when the connection closes.
        self._cursor = None
        self._is_transaction_active = False

    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Close the shared connection unless ``exclude_calling_thread`` is True.

        The only connection is treated as the calling thread's, so excluding the
        caller leaves nothing to close.
        """
        if not exclude_calling_thread:
            self.close()
213
214
def create_connection_pool(
    connection_factory: t.Callable[[], t.Any], multithreaded: bool
) -> ConnectionPool:
    """Build the connection pool that matches the requested threading mode.

    Args:
        connection_factory: Zero-argument callable that opens a new connection.
        multithreaded: True for a per-thread pool, False for a single shared connection.

    Returns:
        A ``ThreadLocalConnectionPool`` when ``multithreaded`` is truthy, otherwise a
        ``SingletonConnectionPool``.
    """
    if multithreaded:
        return ThreadLocalConnectionPool(connection_factory)
    return SingletonConnectionPool(connection_factory)
223
224
def _try_close(closeable: t.Any, kind: str) -> None:
    """Close ``closeable`` on a best-effort basis.

    ``None`` is ignored. Any ``Exception`` raised by ``close()`` is logged with its
    traceback and suppressed, so cleanup paths do not fail because of it.

    Args:
        closeable: Object with a ``close()`` method, or None.
        kind: Label (e.g. "cursor", "connection") used in the log message.
    """
    if closeable is not None:
        try:
            closeable.close()
        except Exception:
            logger.exception("Failed to close %s", kind)
class ConnectionPool(abc.ABC):
    """Abstract interface for a cache of database connections and cursors.

    Implementations also track whether a transaction started through ``begin``
    is still open.
    """

    @abc.abstractmethod
    def get_cursor(self) -> t.Any:
        """Return the cached cursor, creating one if none is cached yet.

        Returns:
            A cursor instance.
        """

    @abc.abstractmethod
    def get(self) -> t.Any:
        """Return the cached connection, opening one if none is cached yet.

        Returns:
            A connection instance.
        """

    @abc.abstractmethod
    def begin(self) -> None:
        """Start a new transaction."""

    @abc.abstractmethod
    def commit(self) -> None:
        """Commit the transaction that is currently open."""

    @abc.abstractmethod
    def rollback(self) -> None:
        """Roll back the transaction that is currently open."""

    @property
    @abc.abstractmethod
    def is_transaction_active(self) -> bool:
        """Whether a transaction is currently open (True) or not (False)."""

    @abc.abstractmethod
    def close_cursor(self) -> None:
        """Close the cached cursor, if there is one."""

    @abc.abstractmethod
    def close(self) -> None:
        """Close the cached connection, if there is one.

        Any cached cursor is released together with the connection.
        """

    @abc.abstractmethod
    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Close every cached cursor and connection.

        Args:
            exclude_calling_thread: When True, leave the cursor and connection that
                belong to the calling thread open.
        """

Abstract base class (an `abc.ABC` subclass) defining the interface for pools that cache database connections and cursors and track whether a transaction is active.

@abc.abstractmethod
def get_cursor(self) -> t.Any:
    """Return the cached cursor, creating one if none is cached yet.

    Returns:
        A cursor instance.
    """

Returns cached cursor instance.

Automatically creates a new instance if one is not available.

Returns:

A cursor instance.

@abc.abstractmethod
def get(self) -> t.Any:
    """Return the cached connection, opening one if none is cached yet.

    Returns:
        A connection instance.
    """

Returns cached connection instance.

Automatically opens a new connection if one is not available.

Returns:

A connection instance.

@abc.abstractmethod
def begin(self) -> None:
    """Start a new transaction."""

Starts a new transaction.

@abc.abstractmethod
def commit(self) -> None:
    """Commit the transaction that is currently open."""

Commits the current transaction.

@abc.abstractmethod
def rollback(self) -> None:
    """Roll back the transaction that is currently open."""

Rolls back the current transaction.

is_transaction_active: bool

Returns True if there is an active transaction and False otherwise.

@abc.abstractmethod
def close_cursor(self) -> None:
    """Close the cached cursor, if there is one."""

Closes the current cursor instance if exists.

@abc.abstractmethod
def close(self) -> None:
    """Close the cached connection, if there is one.

    Any cached cursor is released together with the connection.
    """

Closes the current connection instance if exists.

Note: if there is a cursor instance available it will be closed as well.

@abc.abstractmethod
def close_all(self, exclude_calling_thread: bool = False) -> None:
    """Close every cached cursor and connection.

    Args:
        exclude_calling_thread: When True, leave the cursor and connection that
            belong to the calling thread open.
    """

Closes all cached cursors and connections.

Arguments:
  • exclude_calling_thread: If set to True excludes cursors and connections associated with the calling thread.
class ThreadLocalConnectionPool(_TransactionManagementMixin):
    """Connection pool that keeps a separate connection and cursor for every thread.

    Cached objects are keyed by ``threading.get_ident()`` of the calling thread.
    The connection, cursor and transaction registries each have their own lock.
    Whenever more than one lock is held, they are taken in the order
    cursors -> connections -> transactions.

    NOTE(review): Python may recycle a thread identifier after its thread exits.
    A connection that is never closed could therefore be reused by a later
    thread that receives the same identifier.
    """

    def __init__(self, connection_factory: t.Callable[[], t.Any]):
        """
        Args:
            connection_factory: Zero-argument callable used to open a new connection.
        """
        self._connection_factory = connection_factory
        self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
        self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
        self._thread_transactions: t.Set[t.Hashable] = set()
        self._thread_connections_lock = Lock()
        self._thread_cursors_lock = Lock()
        self._thread_transactions_lock = Lock()

    def get_cursor(self) -> t.Any:
        """Return this thread's cursor, creating it from this thread's connection if needed."""
        key = get_ident()
        with self._thread_cursors_lock:
            cursors = self._thread_cursors
            if key not in cursors:
                # self.get() takes the connections lock while the cursors lock is held.
                cursors[key] = self.get().cursor()
            return cursors[key]

    def get(self) -> t.Any:
        """Return this thread's connection, opening one through the factory if needed."""
        key = get_ident()
        with self._thread_connections_lock:
            connections = self._thread_connections
            if key not in connections:
                connections[key] = self._connection_factory()
            return connections[key]

    def begin(self) -> None:
        """Start a transaction and record it as active for the calling thread."""
        self._do_begin()
        current = get_ident()
        with self._thread_transactions_lock:
            self._thread_transactions.add(current)

    def commit(self) -> None:
        """Commit, then clear the calling thread's active-transaction marker."""
        self._do_commit()
        current = get_ident()
        self._discard_transaction(current)

    def rollback(self) -> None:
        """Roll back, then clear the calling thread's active-transaction marker."""
        self._do_rollback()
        current = get_ident()
        self._discard_transaction(current)

    @property
    def is_transaction_active(self) -> bool:
        """True if the calling thread has begun a transaction that has not ended."""
        current = get_ident()
        with self._thread_transactions_lock:
            return current in self._thread_transactions

    def close_cursor(self) -> None:
        """Close and forget the calling thread's cursor, if it has one."""
        key = get_ident()
        with self._thread_cursors_lock:
            if key not in self._thread_cursors:
                return
            _try_close(self._thread_cursors[key], "cursor")
            del self._thread_cursors[key]

    def close(self) -> None:
        """Close the calling thread's connection and drop its cursor and transaction marker."""
        key = get_ident()
        with self._thread_cursors_lock, self._thread_connections_lock:
            if key not in self._thread_connections:
                return
            _try_close(self._thread_connections[key], "connection")
            del self._thread_connections[key]
            # NOTE(review): the cursor is dropped without an explicit close(); this
            # relies on the driver invalidating cursors when the connection closes.
            self._thread_cursors.pop(key, None)
            self._discard_transaction(key)

    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Close cached connections for all threads.

        Args:
            exclude_calling_thread: When True, the calling thread's entries are kept.
        """
        caller = get_ident()
        with self._thread_cursors_lock, self._thread_connections_lock:
            # Iterate over a snapshot because entries are removed inside the loop.
            for thread_id, connection in list(self._thread_connections.items()):
                if exclude_calling_thread and thread_id == caller:
                    continue
                # NOTE: the access to the connection instance itself is not thread-safe here.
                _try_close(connection, "connection")
                self._thread_connections.pop(thread_id)
                self._thread_cursors.pop(thread_id, None)
                self._discard_transaction(thread_id)

    def _discard_transaction(self, thread_id: t.Hashable) -> None:
        """Remove ``thread_id`` from the active-transaction set; absent ids are ignored."""
        with self._thread_transactions_lock:
            self._thread_transactions.discard(thread_id)

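Connection pool that caches one connection and one cursor per thread, keyed by the calling thread's identifier (`threading.get_ident()`), with a separate lock guarding each registry.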

def __init__(self, connection_factory: t.Callable[[], t.Any]):
    """Set up empty per-thread registries and one lock for each.

    Args:
        connection_factory: Zero-argument callable used to open a new connection.
    """
    self._connection_factory = connection_factory
    # Registries keyed by thread identifier.
    self._thread_connections: t.Dict[t.Hashable, t.Any] = {}
    self._thread_cursors: t.Dict[t.Hashable, t.Any] = {}
    self._thread_transactions: t.Set[t.Hashable] = set()
    # One lock per registry.
    self._thread_connections_lock = Lock()
    self._thread_cursors_lock = Lock()
    self._thread_transactions_lock = Lock()
def get_cursor(self) -> t.Any:
    """Return this thread's cursor, creating it from this thread's connection if needed.

    Returns:
        A cursor instance.
    """
    key = get_ident()
    with self._thread_cursors_lock:
        cursors = self._thread_cursors
        if key not in cursors:
            # self.get() takes the connections lock while the cursors lock is held.
            cursors[key] = self.get().cursor()
        return cursors[key]

Returns cached cursor instance.

Automatically creates a new instance if one is not available.

Returns:

A cursor instance.

def get(self) -> t.Any:
    """Return this thread's connection, opening one through the factory if needed.

    Returns:
        A connection instance.
    """
    key = get_ident()
    with self._thread_connections_lock:
        connections = self._thread_connections
        if key not in connections:
            connections[key] = self._connection_factory()
        return connections[key]

Returns cached connection instance.

Automatically opens a new connection if one is not available.

Returns:

A connection instance.

def begin(self) -> None:
    """Start a transaction and record it as active for the calling thread."""
    self._do_begin()
    current = get_ident()
    with self._thread_transactions_lock:
        self._thread_transactions.add(current)

Starts a new transaction.

def commit(self) -> None:
    """Commit, then clear the calling thread's active-transaction marker."""
    self._do_commit()
    current = get_ident()
    self._discard_transaction(current)

Commits the current transaction.

def rollback(self) -> None:
    """Roll back, then clear the calling thread's active-transaction marker."""
    self._do_rollback()
    current = get_ident()
    self._discard_transaction(current)

Rolls back the current transaction.

is_transaction_active: bool

Returns True if there is an active transaction and False otherwise.

def close_cursor(self) -> None:
    """Close and forget the calling thread's cursor, if it has one."""
    key = get_ident()
    with self._thread_cursors_lock:
        if key not in self._thread_cursors:
            return
        _try_close(self._thread_cursors[key], "cursor")
        del self._thread_cursors[key]

Closes the current cursor instance if exists.

def close(self) -> None:
    """Close the calling thread's connection and drop its cursor and transaction marker."""
    key = get_ident()
    with self._thread_cursors_lock, self._thread_connections_lock:
        if key not in self._thread_connections:
            return
        _try_close(self._thread_connections[key], "connection")
        del self._thread_connections[key]
        # NOTE(review): the cursor is dropped without an explicit close(); this
        # relies on the driver invalidating cursors when the connection closes.
        self._thread_cursors.pop(key, None)
        self._discard_transaction(key)

Closes the current connection instance if exists.

Note: if there is a cursor instance available it will be closed as well.

def close_all(self, exclude_calling_thread: bool = False) -> None:
    """Close cached connections for all threads.

    Args:
        exclude_calling_thread: When True, the calling thread's entries are kept.
    """
    caller = get_ident()
    with self._thread_cursors_lock, self._thread_connections_lock:
        # Iterate over a snapshot because entries are removed inside the loop.
        for thread_id, connection in list(self._thread_connections.items()):
            if exclude_calling_thread and thread_id == caller:
                continue
            # NOTE: the access to the connection instance itself is not thread-safe here.
            _try_close(connection, "connection")
            self._thread_connections.pop(thread_id)
            self._thread_cursors.pop(thread_id, None)
            self._discard_transaction(thread_id)

Closes all cached cursors and connections.

Arguments:
  • exclude_calling_thread: If set to True excludes cursors and connections associated with the calling thread.
class SingletonConnectionPool(_TransactionManagementMixin):
    """Connection pool that shares one connection and one cursor among all callers.

    No locking is performed. ``create_connection_pool`` selects this pool when
    ``multithreaded`` is False.
    """

    def __init__(self, connection_factory: t.Callable[[], t.Any]):
        """
        Args:
            connection_factory: Zero-argument callable used to open a new connection.
        """
        self._connection_factory = connection_factory
        self._connection: t.Optional[t.Any] = None
        self._cursor: t.Optional[t.Any] = None
        self._is_transaction_active: bool = False

    def get_cursor(self) -> t.Any:
        """Return the shared cursor, creating it from the shared connection on first use."""
        # Compare with None rather than testing truthiness. A driver cursor that
        # defines __bool__/__len__ may be falsy and would otherwise be recreated
        # (and the previous one leaked) on every call.
        if self._cursor is None:
            self._cursor = self.get().cursor()
        return self._cursor

    def get(self) -> t.Any:
        """Return the shared connection, opening it through the factory on first use."""
        # Same None check as get_cursor, so a merely-falsy connection is never reopened.
        if self._connection is None:
            self._connection = self._connection_factory()
        return self._connection

    def begin(self) -> None:
        """Start a transaction and mark it as active."""
        self._do_begin()
        self._is_transaction_active = True

    def commit(self) -> None:
        """Commit the current transaction and clear the active flag."""
        self._do_commit()
        self._is_transaction_active = False

    def rollback(self) -> None:
        """Roll back the current transaction and clear the active flag."""
        self._do_rollback()
        self._is_transaction_active = False

    @property
    def is_transaction_active(self) -> bool:
        """True between a successful begin() and the following commit()/rollback()/close()."""
        return self._is_transaction_active

    def close_cursor(self) -> None:
        """Close the shared cursor (no-op when there is none) and forget it."""
        _try_close(self._cursor, "cursor")
        self._cursor = None

    def close(self) -> None:
        """Close the shared connection and reset the cursor and transaction state."""
        _try_close(self._connection, "connection")
        self._connection = None
        # NOTE(review): the cursor is dropped without an explicit close(); this
        # relies on the driver invalidating cursors when the connection closes.
        self._cursor = None
        self._is_transaction_active = False

    def close_all(self, exclude_calling_thread: bool = False) -> None:
        """Close the shared connection unless ``exclude_calling_thread`` is True.

        The only connection is treated as the calling thread's, so excluding the
        caller leaves nothing to close.
        """
        if not exclude_calling_thread:
            self.close()

Connection pool that caches a single connection and a single cursor shared by all callers, without any locking.

def __init__(self, connection_factory: t.Callable[[], t.Any]):
    """Store the factory; the connection and cursor are created lazily later.

    Args:
        connection_factory: Zero-argument callable used to open a new connection.
    """
    self._connection_factory = connection_factory
    # Both start empty and are filled on first get()/get_cursor().
    self._connection: t.Optional[t.Any] = None
    self._cursor: t.Optional[t.Any] = None
    self._is_transaction_active: bool = False
def get_cursor(self) -> t.Any:
    """Return the shared cursor, creating it from the shared connection on first use.

    Returns:
        A cursor instance.
    """
    # Compare with None rather than testing truthiness. A driver cursor that
    # defines __bool__/__len__ may be falsy and would otherwise be recreated
    # (and the previous one leaked) on every call.
    if self._cursor is None:
        self._cursor = self.get().cursor()
    return self._cursor

Returns cached cursor instance.

Automatically creates a new instance if one is not available.

Returns:

A cursor instance.

def get(self) -> t.Any:
    """Return the shared connection, opening it through the factory on first use.

    Returns:
        A connection instance.
    """
    # Compare with None rather than testing truthiness, so a connection object
    # that happens to be falsy is cached instead of reopened on every call.
    if self._connection is None:
        self._connection = self._connection_factory()
    return self._connection

Returns cached connection instance.

Automatically opens a new connection if one is not available.

Returns:

A connection instance.

def begin(self) -> None:
    """Start a transaction; the active flag is set only if the driver call returns normally."""
    self._do_begin()
    self._is_transaction_active = True

Starts a new transaction.

def commit(self) -> None:
    """Commit the open transaction, then mark no transaction as active."""
    self._do_commit()
    self._is_transaction_active = False

Commits the current transaction.

def rollback(self) -> None:
    """Roll back the open transaction, then mark no transaction as active."""
    self._do_rollback()
    self._is_transaction_active = False

Rolls back the current transaction.

is_transaction_active: bool

Returns True if there is an active transaction and False otherwise.

def close_cursor(self) -> None:
    """Close the shared cursor (no-op when there is none) and forget it."""
    cursor = self._cursor
    _try_close(cursor, "cursor")
    self._cursor = None

Closes the current cursor instance if exists.

def close(self) -> None:
    """Close the shared connection and reset the cursor and transaction state."""
    _try_close(self._connection, "connection")
    # NOTE(review): the cursor is dropped without an explicit close(); this
    # relies on the driver invalidating cursors when the connection closes.
    self._connection = self._cursor = None
    self._is_transaction_active = False

Closes the current connection instance if exists.

Note: if there is a cursor instance available it will be closed as well.

def close_all(self, exclude_calling_thread: bool = False) -> None:
    """Close the shared connection unless ``exclude_calling_thread`` is True.

    The only connection is treated as the calling thread's, so excluding the
    caller leaves nothing to close.
    """
    if exclude_calling_thread:
        return
    self.close()

Closes all cached cursors and connections.

Arguments:
  • exclude_calling_thread: If set to True excludes cursors and connections associated with the calling thread.
def create_connection_pool(
    connection_factory: t.Callable[[], t.Any], multithreaded: bool
) -> ConnectionPool:
    """Build the connection pool that matches the requested threading mode.

    Args:
        connection_factory: Zero-argument callable that opens a new connection.
        multithreaded: True for a per-thread pool, False for a single shared connection.

    Returns:
        A ``ThreadLocalConnectionPool`` when ``multithreaded`` is truthy, otherwise a
        ``SingletonConnectionPool``.
    """
    if multithreaded:
        return ThreadLocalConnectionPool(connection_factory)
    return SingletonConnectionPool(connection_factory)