fix (data): Parallel connection pool.

This commit is contained in:
2023-08-23 17:31:38 +03:00
parent 5bca9bca33
commit df9b835cd5
27 changed files with 1175 additions and 1021 deletions

View File

@@ -1,6 +1,7 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry
@@ -23,10 +24,17 @@ class RegisterEnum(Attachable):
connector = registry._conn
if connector is None:
raise ValueError("Cannot initialise without connector!")
connection = await connector.get_connection()
if connection is None:
raise ValueError("Cannot Init without connection.")
info = await EnumInfo.fetch(connection, self.name)
connector.connect_hook(self.connection_hook)
# await connector.refresh_pool()
# The below may be somewhat dangerous
# But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None:
raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))

View File

@@ -1,7 +1,10 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self.conn: psq.AsyncConnection = None
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
async def get_connection(self) -> psq.AsyncConnection:
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Get the current active connection.
This should never be cached outside of a transaction.
Convenience property for the current context connection.
"""
# TODO: Reconnection logic?
if not self.conn:
raise ValueError("Attempting to get connection before initialisation!")
if self.conn.info.transaction_status is TransactionStatus.INERROR:
await self.connect()
logger.error(
"Database connection transaction failed!! This should not happen. Reconnecting."
)
return self.conn
return ctx_connection.get()
async def connect(self) -> psq.AsyncConnection:
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
self.conn = await psq.AsyncConnection.connect(
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
for hook in self.conn_hooks:
await hook(self.conn)
return self.conn
async def reconnect(self) -> psq.AsyncConnection:
return await self.connect()
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""

View File

@@ -35,12 +35,13 @@ class Database(Connector):
"""
Return the current schema version as a Version namedtuple.
"""
async with self.conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)
async with self.connection() as conn:
async with conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)

View File

@@ -101,12 +101,12 @@ class Query(Generic[QueryResult]):
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
conn = await self.connector.get_connection()
async with self.connector.connection() as conn:
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
conn = self.conn
async with conn.cursor() as cursor:
data = await self._execute(cursor)
async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data