fix (data): Parallel connection pool.

This commit is contained in:
2023-08-23 17:31:38 +03:00
parent 5bca9bca33
commit df9b835cd5
27 changed files with 1175 additions and 1021 deletions

View File

@@ -1,7 +1,10 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self.conn: psq.AsyncConnection = None
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
async def get_connection(self) -> psq.AsyncConnection:
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Get the current active connection.
This should never be cached outside of a transaction.
Convenience property for the current context connection.
"""
# TODO: Reconnection logic?
if not self.conn:
raise ValueError("Attempting to get connection before initialisation!")
if self.conn.info.transaction_status is TransactionStatus.INERROR:
await self.connect()
logger.error(
"Database connection transaction failed!! This should not happen. Reconnecting."
)
return self.conn
return ctx_connection.get()
async def connect(self) -> psq.AsyncConnection:
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
self.conn = await psq.AsyncConnection.connect(
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
for hook in self.conn_hooks:
await hook(self.conn)
return self.conn
async def reconnect(self) -> psq.AsyncConnection:
return await self.connect()
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""