59 lines
1.8 KiB
Python
59 lines
1.8 KiB
Python
from typing import Protocol, runtime_checkable, Callable, Awaitable
|
|
import logging
|
|
|
|
import psycopg as psq
|
|
|
|
from .cursor import AsyncLoggingCursor
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
row_factory = psq.rows.dict_row
|
|
|
|
|
|
class Connector:
|
|
cursor_factory = AsyncLoggingCursor
|
|
|
|
def __init__(self, conn_args):
|
|
self._conn_args = conn_args
|
|
self.conn: psq.AsyncConnection = None
|
|
|
|
self.conn_hooks = []
|
|
|
|
async def get_connection(self) -> psq.AsyncConnection:
|
|
"""
|
|
Get the current active connection.
|
|
This should never be cached outside of a transaction.
|
|
"""
|
|
# TODO: Reconnection logic?
|
|
if not self.conn:
|
|
raise ValueError("Attempting to get connection before initialisation!")
|
|
return self.conn
|
|
|
|
async def connect(self) -> psq.AsyncConnection:
|
|
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
|
|
self.conn = await psq.AsyncConnection.connect(
|
|
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
|
|
)
|
|
for hook in self.conn_hooks:
|
|
await hook(self.conn)
|
|
return self.conn
|
|
|
|
async def reconnect(self) -> psq.AsyncConnection:
|
|
return await self.connect()
|
|
|
|
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
|
"""
|
|
Minimal decorator to register a coroutine to run on connect or reconnect.
|
|
|
|
Note that these are only run on connect and reconnect.
|
|
If a hook is registered after connection, it will not be run.
|
|
"""
|
|
self.conn_hooks.append(coro)
|
|
return coro
|
|
|
|
|
|
@runtime_checkable
|
|
class Connectable(Protocol):
|
|
def bind(self, connector: Connector):
|
|
raise NotImplementedError
|