fix (data): Parallel connection pool.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# from enum import Enum
|
||||
from typing import Optional
|
||||
from psycopg.types.enum import register_enum, EnumInfo
|
||||
from psycopg import AsyncConnection
|
||||
from .registry import Attachable, Registry
|
||||
|
||||
|
||||
@@ -23,10 +24,17 @@ class RegisterEnum(Attachable):
|
||||
connector = registry._conn
|
||||
if connector is None:
|
||||
raise ValueError("Cannot initialise without connector!")
|
||||
connection = await connector.get_connection()
|
||||
if connection is None:
|
||||
raise ValueError("Cannot Init without connection.")
|
||||
info = await EnumInfo.fetch(connection, self.name)
|
||||
connector.connect_hook(self.connection_hook)
|
||||
# await connector.refresh_pool()
|
||||
# The below may be somewhat dangerous
|
||||
# But adaption should never write to the database
|
||||
await connector.map_over_pool(self.connection_hook)
|
||||
# if conn := connector.conn:
|
||||
# # Ensure the adaption is run in the current context as well
|
||||
# await self.connection_hook(conn)
|
||||
|
||||
async def connection_hook(self, conn: AsyncConnection):
|
||||
info = await EnumInfo.fetch(conn, self.name)
|
||||
if info is None:
|
||||
raise ValueError(f"Enum {self.name} not found in database.")
|
||||
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))
|
||||
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from typing import Protocol, runtime_checkable, Callable, Awaitable
|
||||
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
|
||||
import logging
|
||||
|
||||
from contextvars import ContextVar
|
||||
from contextlib import asynccontextmanager
|
||||
import psycopg as psq
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.pq import TransactionStatus
|
||||
|
||||
from .cursor import AsyncLoggingCursor
|
||||
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
row_factory = psq.rows.dict_row
|
||||
|
||||
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
|
||||
|
||||
|
||||
class Connector:
|
||||
cursor_factory = AsyncLoggingCursor
|
||||
|
||||
def __init__(self, conn_args):
|
||||
self._conn_args = conn_args
|
||||
self.conn: psq.AsyncConnection = None
|
||||
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
|
||||
|
||||
self.pool = self.make_pool()
|
||||
|
||||
self.conn_hooks = []
|
||||
|
||||
async def get_connection(self) -> psq.AsyncConnection:
|
||||
@property
|
||||
def conn(self) -> Optional[psq.AsyncConnection]:
|
||||
"""
|
||||
Get the current active connection.
|
||||
This should never be cached outside of a transaction.
|
||||
Convenience property for the current context connection.
|
||||
"""
|
||||
# TODO: Reconnection logic?
|
||||
if not self.conn:
|
||||
raise ValueError("Attempting to get connection before initialisation!")
|
||||
if self.conn.info.transaction_status is TransactionStatus.INERROR:
|
||||
await self.connect()
|
||||
logger.error(
|
||||
"Database connection transaction failed!! This should not happen. Reconnecting."
|
||||
)
|
||||
return self.conn
|
||||
return ctx_connection.get()
|
||||
|
||||
async def connect(self) -> psq.AsyncConnection:
|
||||
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
|
||||
self.conn = await psq.AsyncConnection.connect(
|
||||
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
|
||||
@conn.setter
|
||||
def conn(self, conn: psq.AsyncConnection):
|
||||
"""
|
||||
Set the contextual connection in the current context.
|
||||
Always do this in an isolated context!
|
||||
"""
|
||||
ctx_connection.set(conn)
|
||||
|
||||
def make_pool(self) -> AsyncConnectionPool:
|
||||
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
|
||||
return AsyncConnectionPool(
|
||||
self._conn_args,
|
||||
open=False,
|
||||
min_size=4,
|
||||
max_size=8,
|
||||
configure=self._setup_connection,
|
||||
kwargs=self._conn_kwargs
|
||||
)
|
||||
for hook in self.conn_hooks:
|
||||
await hook(self.conn)
|
||||
return self.conn
|
||||
|
||||
async def reconnect(self) -> psq.AsyncConnection:
|
||||
return await self.connect()
|
||||
async def refresh_pool(self):
|
||||
"""
|
||||
Refresh the pool.
|
||||
|
||||
The point of this is to invalidate any existing connections so that the connection set up is run again.
|
||||
Better ways should be sought (a way to
|
||||
"""
|
||||
logger.info("Pool refresh requested, closing and reopening.")
|
||||
old_pool = self.pool
|
||||
self.pool = self.make_pool()
|
||||
await self.pool.open()
|
||||
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
|
||||
await old_pool.close()
|
||||
logger.info("Pool refresh complete.")
|
||||
|
||||
async def map_over_pool(self, callable):
|
||||
"""
|
||||
Dangerous method to call a method on each connection in the pool.
|
||||
|
||||
Utilises private methods of the AsyncConnectionPool.
|
||||
"""
|
||||
async with self.pool._lock:
|
||||
conns = list(self.pool._pool)
|
||||
while conns:
|
||||
conn = conns.pop()
|
||||
try:
|
||||
await callable(conn)
|
||||
except Exception:
|
||||
logger.exception(f"Mapped connection task failed. {callable.__name__}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def open(self):
|
||||
try:
|
||||
logger.info("Opening database pool.")
|
||||
await self.pool.open()
|
||||
yield
|
||||
finally:
|
||||
# May be a different pool!
|
||||
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
|
||||
await self.pool.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> psq.AsyncConnection:
|
||||
"""
|
||||
Asynchronous context manager to get and manage a connection.
|
||||
|
||||
If the context connection is set, uses this and does not manage the lifetime.
|
||||
Otherwise, requests a new connection from the pool and returns it when done.
|
||||
"""
|
||||
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
|
||||
if (conn := self.conn):
|
||||
yield conn
|
||||
else:
|
||||
async with self.pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
async def _setup_connection(self, conn: psq.AsyncConnection):
|
||||
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
|
||||
for hook in self.conn_hooks:
|
||||
try:
|
||||
await hook(conn)
|
||||
except Exception:
|
||||
logger.exception("Exception encountered setting up new connection")
|
||||
return conn
|
||||
|
||||
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
||||
"""
|
||||
|
||||
@@ -35,12 +35,13 @@ class Database(Connector):
|
||||
"""
|
||||
Return the current schema version as a Version namedtuple.
|
||||
"""
|
||||
async with self.conn.cursor() as cursor:
|
||||
# Get last entry in version table, compare against desired version
|
||||
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return Version(row['version'], row['time'], row['author'])
|
||||
else:
|
||||
# No versions in the database
|
||||
return Version(-1, None, None)
|
||||
async with self.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# Get last entry in version table, compare against desired version
|
||||
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return Version(row['version'], row['time'], row['author'])
|
||||
else:
|
||||
# No versions in the database
|
||||
return Version(-1, None, None)
|
||||
|
||||
@@ -101,12 +101,12 @@ class Query(Generic[QueryResult]):
|
||||
if self.connector is None:
|
||||
raise ValueError("Cannot execute query without cursor, connection, or connector.")
|
||||
else:
|
||||
conn = await self.connector.get_connection()
|
||||
async with self.connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
data = await self._execute(cursor)
|
||||
else:
|
||||
conn = self.conn
|
||||
|
||||
async with conn.cursor() as cursor:
|
||||
data = await self._execute(cursor)
|
||||
async with self.conn.cursor() as cursor:
|
||||
data = await self._execute(cursor)
|
||||
else:
|
||||
data = await self._execute(cursor)
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user