Initial Commit

This commit is contained in:
2025-07-31 07:39:53 +10:00
commit 4da50d8678
12 changed files with 1851 additions and 0 deletions

135
connector.py Normal file
View File

@@ -0,0 +1,135 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Convenience property for the current context connection.
"""
return ctx_connection.get()
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=1,
max_size=4,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""
Minimal decorator to register a coroutine to run on connect or reconnect.
Note that these are only run on connect and reconnect.
If a hook is registered after connection, it will not be run.
"""
self.conn_hooks.append(coro)
return coro
@runtime_checkable
class Connectable(Protocol):
def bind(self, connector: Connector):
raise NotImplementedError