fix (data): Parallel connection pool.
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
from typing import Protocol, runtime_checkable, Callable, Awaitable
|
||||
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
|
||||
import logging
|
||||
|
||||
from contextvars import ContextVar
|
||||
from contextlib import asynccontextmanager
|
||||
import psycopg as psq
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.pq import TransactionStatus
|
||||
|
||||
from .cursor import AsyncLoggingCursor
|
||||
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
row_factory = psq.rows.dict_row
|
||||
|
||||
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
|
||||
|
||||
|
||||
class Connector:
|
||||
cursor_factory = AsyncLoggingCursor
|
||||
|
||||
def __init__(self, conn_args):
|
||||
self._conn_args = conn_args
|
||||
self.conn: psq.AsyncConnection = None
|
||||
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
|
||||
|
||||
self.pool = self.make_pool()
|
||||
|
||||
self.conn_hooks = []
|
||||
|
||||
async def get_connection(self) -> psq.AsyncConnection:
|
||||
@property
|
||||
def conn(self) -> Optional[psq.AsyncConnection]:
|
||||
"""
|
||||
Get the current active connection.
|
||||
This should never be cached outside of a transaction.
|
||||
Convenience property for the current context connection.
|
||||
"""
|
||||
# TODO: Reconnection logic?
|
||||
if not self.conn:
|
||||
raise ValueError("Attempting to get connection before initialisation!")
|
||||
if self.conn.info.transaction_status is TransactionStatus.INERROR:
|
||||
await self.connect()
|
||||
logger.error(
|
||||
"Database connection transaction failed!! This should not happen. Reconnecting."
|
||||
)
|
||||
return self.conn
|
||||
return ctx_connection.get()
|
||||
|
||||
async def connect(self) -> psq.AsyncConnection:
|
||||
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
|
||||
self.conn = await psq.AsyncConnection.connect(
|
||||
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
|
||||
@conn.setter
|
||||
def conn(self, conn: psq.AsyncConnection):
|
||||
"""
|
||||
Set the contextual connection in the current context.
|
||||
Always do this in an isolated context!
|
||||
"""
|
||||
ctx_connection.set(conn)
|
||||
|
||||
def make_pool(self) -> AsyncConnectionPool:
|
||||
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
|
||||
return AsyncConnectionPool(
|
||||
self._conn_args,
|
||||
open=False,
|
||||
min_size=4,
|
||||
max_size=8,
|
||||
configure=self._setup_connection,
|
||||
kwargs=self._conn_kwargs
|
||||
)
|
||||
for hook in self.conn_hooks:
|
||||
await hook(self.conn)
|
||||
return self.conn
|
||||
|
||||
async def reconnect(self) -> psq.AsyncConnection:
|
||||
return await self.connect()
|
||||
async def refresh_pool(self):
|
||||
"""
|
||||
Refresh the pool.
|
||||
|
||||
The point of this is to invalidate any existing connections so that the connection set up is run again.
|
||||
Better ways should be sought (a way to
|
||||
"""
|
||||
logger.info("Pool refresh requested, closing and reopening.")
|
||||
old_pool = self.pool
|
||||
self.pool = self.make_pool()
|
||||
await self.pool.open()
|
||||
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
|
||||
await old_pool.close()
|
||||
logger.info("Pool refresh complete.")
|
||||
|
||||
async def map_over_pool(self, callable):
|
||||
"""
|
||||
Dangerous method to call a method on each connection in the pool.
|
||||
|
||||
Utilises private methods of the AsyncConnectionPool.
|
||||
"""
|
||||
async with self.pool._lock:
|
||||
conns = list(self.pool._pool)
|
||||
while conns:
|
||||
conn = conns.pop()
|
||||
try:
|
||||
await callable(conn)
|
||||
except Exception:
|
||||
logger.exception(f"Mapped connection task failed. {callable.__name__}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def open(self):
|
||||
try:
|
||||
logger.info("Opening database pool.")
|
||||
await self.pool.open()
|
||||
yield
|
||||
finally:
|
||||
# May be a different pool!
|
||||
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
|
||||
await self.pool.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> psq.AsyncConnection:
|
||||
"""
|
||||
Asynchronous context manager to get and manage a connection.
|
||||
|
||||
If the context connection is set, uses this and does not manage the lifetime.
|
||||
Otherwise, requests a new connection from the pool and returns it when done.
|
||||
"""
|
||||
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
|
||||
if (conn := self.conn):
|
||||
yield conn
|
||||
else:
|
||||
async with self.pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
async def _setup_connection(self, conn: psq.AsyncConnection):
|
||||
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
|
||||
for hook in self.conn_hooks:
|
||||
try:
|
||||
await hook(conn)
|
||||
except Exception:
|
||||
logger.exception("Exception encountered setting up new connection")
|
||||
return conn
|
||||
|
||||
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user