139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
from typing import AsyncIterator, Protocol, runtime_checkable, Callable, Awaitable, Optional
|
|
import logging
|
|
|
|
from contextvars import ContextVar
|
|
from contextlib import asynccontextmanager
|
|
import psycopg as psq
|
|
from psycopg_pool import AsyncConnectionPool
|
|
from psycopg.pq import TransactionStatus
|
|
|
|
from .cursor import AsyncLoggingCursor
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
row_factory = psq.rows.dict_row
|
|
|
|
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
|
|
|
|
|
|
class Connector:
|
|
cursor_factory = AsyncLoggingCursor
|
|
|
|
def __init__(self, conn_args, pool_min_size=1, pool_max_size=4):
|
|
self._conn_args = conn_args
|
|
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
|
|
|
|
self.pool_min_size = pool_min_size
|
|
self.pool_max_size = pool_max_size
|
|
|
|
self.pool = self.make_pool()
|
|
|
|
self.conn_hooks = []
|
|
|
|
@property
|
|
def conn(self) -> Optional[psq.AsyncConnection]:
|
|
"""
|
|
Convenience property for the current context connection.
|
|
"""
|
|
return ctx_connection.get()
|
|
|
|
@conn.setter
|
|
def conn(self, conn: psq.AsyncConnection):
|
|
"""
|
|
Set the contextual connection in the current context.
|
|
Always do this in an isolated context!
|
|
"""
|
|
ctx_connection.set(conn)
|
|
|
|
def make_pool(self) -> AsyncConnectionPool:
|
|
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
|
|
return AsyncConnectionPool(
|
|
self._conn_args,
|
|
open=False,
|
|
min_size=self.pool_min_size,
|
|
max_size=self.pool_max_size,
|
|
configure=self._setup_connection,
|
|
kwargs=self._conn_kwargs
|
|
)
|
|
|
|
async def refresh_pool(self):
|
|
"""
|
|
Refresh the pool.
|
|
|
|
The point of this is to invalidate any existing connections so that the connection set up is run again.
|
|
Better ways should be sought (a way to
|
|
"""
|
|
logger.info("Pool refresh requested, closing and reopening.")
|
|
old_pool = self.pool
|
|
self.pool = self.make_pool()
|
|
await self.pool.open()
|
|
logger.info(f"Old pool statistics: {old_pool.get_stats()}")
|
|
await old_pool.close()
|
|
logger.info("Pool refresh complete.")
|
|
|
|
async def map_over_pool(self, callable):
|
|
"""
|
|
Dangerous method to call a method on each connection in the pool.
|
|
|
|
Utilises private methods of the AsyncConnectionPool.
|
|
"""
|
|
async with self.pool._lock:
|
|
conns = list(self.pool._pool)
|
|
while conns:
|
|
conn = conns.pop()
|
|
try:
|
|
await callable(conn)
|
|
except Exception:
|
|
logger.exception(f"Mapped connection task failed. {callable.__name__}")
|
|
|
|
@asynccontextmanager
|
|
async def open(self):
|
|
try:
|
|
logger.info("Opening database pool.")
|
|
await self.pool.open()
|
|
yield
|
|
finally:
|
|
# May be a different pool!
|
|
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
|
|
await self.pool.close()
|
|
|
|
@asynccontextmanager
|
|
async def connection(self) -> AsyncIterator[psq.AsyncConnection]:
|
|
"""
|
|
Asynchronous context manager to get and manage a connection.
|
|
|
|
If the context connection is set, uses this and does not manage the lifetime.
|
|
Otherwise, requests a new connection from the pool and returns it when done.
|
|
"""
|
|
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
|
|
if (conn := self.conn):
|
|
yield conn
|
|
else:
|
|
async with self.pool.connection() as conn:
|
|
yield conn
|
|
|
|
async def _setup_connection(self, conn: psq.AsyncConnection):
|
|
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
|
|
for hook in self.conn_hooks:
|
|
try:
|
|
await hook(conn)
|
|
except Exception:
|
|
logger.exception("Exception encountered setting up new connection")
|
|
return conn
|
|
|
|
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
|
"""
|
|
Minimal decorator to register a coroutine to run on connect or reconnect.
|
|
|
|
Note that these are only run on connect and reconnect.
|
|
If a hook is registered after connection, it will not be run.
|
|
"""
|
|
self.conn_hooks.append(coro)
|
|
return coro
|
|
|
|
|
|
@runtime_checkable
|
|
class Connectable(Protocol):
|
|
def bind(self, connector: Connector):
|
|
raise NotImplementedError
|