From 42af454864ef21904c50287a335350641faf73d5 Mon Sep 17 00:00:00 2001 From: Interitio Date: Fri, 1 Aug 2025 01:30:23 +1000 Subject: [PATCH] Move data to psqlmapper module. --- .gitmodules | 3 + src/data | 1 + src/data/__init__.py | 9 - src/data/adapted.py | 40 --- src/data/base.py | 45 --- src/data/columns.py | 155 ---------- src/data/conditions.py | 214 -------------- src/data/connector.py | 135 --------- src/data/cursor.py | 42 --- src/data/database.py | 47 --- src/data/models.py | 323 --------------------- src/data/queries.py | 644 ----------------------------------------- src/data/registry.py | 102 ------- src/data/table.py | 95 ------ 14 files changed, 4 insertions(+), 1851 deletions(-) create mode 160000 src/data delete mode 100644 src/data/__init__.py delete mode 100644 src/data/adapted.py delete mode 100644 src/data/base.py delete mode 100644 src/data/columns.py delete mode 100644 src/data/conditions.py delete mode 100644 src/data/connector.py delete mode 100644 src/data/cursor.py delete mode 100644 src/data/database.py delete mode 100644 src/data/models.py delete mode 100644 src/data/queries.py delete mode 100644 src/data/registry.py delete mode 100644 src/data/table.py diff --git a/.gitmodules b/.gitmodules index 64ee955..562cc5e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "src/modules/streamalerts"] path = src/modules/streamalerts url = git@github.com:Intery/StudyLion-streamalerts.git +[submodule "src/data"] + path = src/data + url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git diff --git a/src/data b/src/data new file mode 160000 index 0000000..cfdfe0e --- /dev/null +++ b/src/data @@ -0,0 +1 @@ +Subproject commit cfdfe0eb50034d54a08c8449e8a62a5b8854e259 diff --git a/src/data/__init__.py b/src/data/__init__.py deleted file mode 100644 index affb160..0000000 --- a/src/data/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .conditions import Condition, condition, NULL -from .database import Database -from .models import RowModel, RowTable, WeakCache -from .table import Table -from .base import Expression, RawExpr -from .columns import ColumnExpr, Column, Integer, String -from .registry import Registry, AttachableClass, Attachable -from .adapted import RegisterEnum -from .queries import ORDER, NULLS, JOINTYPE diff --git a/src/data/adapted.py b/src/data/adapted.py deleted file mode 100644 index a6b4597..0000000 --- a/src/data/adapted.py +++ /dev/null @@ -1,40 +0,0 @@ -# from enum import Enum -from typing import Optional -from psycopg.types.enum import register_enum, EnumInfo -from psycopg import AsyncConnection -from .registry import Attachable, Registry - - -class RegisterEnum(Attachable): - def __init__(self, enum, name: Optional[str] = None, mapper=None): - super().__init__() - self.enum = enum - self.name = name or enum.__name__ - self.mapping = mapper(enum) if mapper is not None else self._mapper() - - def _mapper(self): - return {m: m.value[0] for m in self.enum} - - def attach_to(self, registry: Registry): - self._registry = registry - registry.init_task(self.on_init) - return self - - async def on_init(self, registry: Registry): - connector = registry._conn - if connector is None: - raise ValueError("Cannot initialise without connector!") - connector.connect_hook(self.connection_hook) - # await connector.refresh_pool() - # The below may be somewhat dangerous - # But adaption should never write to the database - await connector.map_over_pool(self.connection_hook) - # if conn := connector.conn: - # # Ensure the adaption is run in the current context as well - # await self.connection_hook(conn) - - async def connection_hook(self, conn: AsyncConnection): - info = await EnumInfo.fetch(conn, self.name) - if info is None: - raise ValueError(f"Enum {self.name} not found in database.") - register_enum(info, conn, self.enum, mapping=list(self.mapping.items())) diff --git a/src/data/base.py b/src/data/base.py deleted file mode 100644 index 272d588..0000000 --- a/src/data/base.py +++ /dev/null @@ -1,45 +0,0 @@ -from abc import abstractmethod -from typing import Any, Protocol, runtime_checkable -from itertools import chain -from psycopg import sql - - -@runtime_checkable -class Expression(Protocol): - __slots__ = () - - @abstractmethod - def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]: - raise NotImplementedError - - -class RawExpr(Expression): - __slots__ = ('expr', 'values') - - expr: sql.Composable - values: tuple[Any, ...] - - def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()): - self.expr = expr - self.values = values - - def as_tuple(self): - return (self.expr, self.values) - - @classmethod - def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')): - """ - Join a sequence of Expressions into a single RawExpr. - """ - tups = ( - expression.as_tuple() - for expression in expressions - ) - return cls.join_tuples(*tups, joiner=joiner) - - @classmethod - def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')): - exprs, values = zip(*tuples) - expr = joiner.join(exprs) - value = tuple(chain(*values)) - return cls(expr, value) diff --git a/src/data/columns.py b/src/data/columns.py deleted file mode 100644 index 252db83..0000000 --- a/src/data/columns.py +++ /dev/null @@ -1,155 +0,0 @@ -from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING -from psycopg import sql -from datetime import datetime - -from .base import RawExpr, Expression -from .conditions import Condition, Joiner -from .table import Table - - -class ColumnExpr(RawExpr): - __slots__ = () - - def __lt__(self, obj) -> Condition: - expr, values = self.as_tuple() - - if isinstance(obj, Expression): - # column < Expression - obj_expr, obj_values = obj.as_tuple() - cond_exprs = (expr, Joiner.LT, obj_expr) - cond_values = (*values, *obj_values) - else: - # column < Literal - cond_exprs = (expr, Joiner.LT, sql.Placeholder()) - cond_values = (*values, obj) - - return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) - - def __le__(self, obj) -> Condition: - expr, values = self.as_tuple() - - if isinstance(obj, Expression): - # column <= Expression - obj_expr, obj_values = obj.as_tuple() - cond_exprs = (expr, Joiner.LE, obj_expr) - cond_values = (*values, *obj_values) - else: - # column <= Literal - cond_exprs = (expr, Joiner.LE, sql.Placeholder()) - cond_values = (*values, obj) - - return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) - - def __eq__(self, obj) -> Condition: # type: ignore[override] - return Condition._expression_equality(self, obj) - - def __ne__(self, obj) -> Condition: # type: ignore[override] - return ~(self.__eq__(obj)) - - def __gt__(self, obj) -> Condition: - return ~(self.__le__(obj)) - - def __ge__(self, obj) -> Condition: - return ~(self.__lt__(obj)) - - def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr': - if isinstance(obj, Expression): - obj_expr, obj_values = obj.as_tuple() - return ColumnExpr( - sql.SQL("({} + {})").format(self.expr, obj_expr), - (*self.values, *obj_values) - ) - else: - return ColumnExpr( - sql.SQL("({} + {})").format(self.expr, sql.Placeholder()), - (*self.values, obj) - ) - - def __sub__(self, obj) -> 'ColumnExpr': - if isinstance(obj, Expression): - obj_expr, obj_values = obj.as_tuple() - return ColumnExpr( - sql.SQL("({} - {})").format(self.expr, obj_expr), - (*self.values, *obj_values) - ) - else: - return ColumnExpr( - sql.SQL("({} - {})").format(self.expr, sql.Placeholder()), - (*self.values, obj) - ) - - def __mul__(self, obj) -> 'ColumnExpr': - if isinstance(obj, Expression): - obj_expr, obj_values = obj.as_tuple() - return ColumnExpr( - sql.SQL("({} * {})").format(self.expr, obj_expr), - (*self.values, *obj_values) - ) - else: - return ColumnExpr( - sql.SQL("({} * {})").format(self.expr, sql.Placeholder()), - (*self.values, obj) - ) - - def CAST(self, target_type: sql.Composable): - return ColumnExpr( - sql.SQL("({}::{})").format(self.expr, target_type), - self.values - ) - - -T = TypeVar('T') - -if TYPE_CHECKING: - from .models import RowModel - - -class Column(ColumnExpr, Generic[T]): - def __init__(self, name: Optional[str] = None, - primary: bool = False, references: Optional['Column'] = None, - type: Optional[Type[T]] = None): - self.primary = primary - self.references = references - self.name: str = name # type: ignore - self.owner: Optional['RowModel'] = None - self._type = type - - self.expr = sql.Identifier(name) if name else sql.SQL('') - self.values = () - - def __set_name__(self, owner, name): - # Only allow setting the owner once - self.name = self.name or name - self.owner = owner - self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name) - - @overload - def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': - ... - - @overload - def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T: - ... - - def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]": - # Get value from row data or session - if obj is None: - return self - else: - return obj.data[self.name] - - -class Integer(Column[int]): - pass - - -class String(Column[str]): - pass - - -class Bool(Column[bool]): - pass - - -class Timestamp(Column[datetime]): - pass diff --git a/src/data/conditions.py b/src/data/conditions.py deleted file mode 100644 index f40dff6..0000000 --- a/src/data/conditions.py +++ /dev/null @@ -1,214 +0,0 @@ -# from meta import sharding -from typing import Any, Union -from enum import Enum -from itertools import chain -from psycopg import sql - -from .base import Expression, RawExpr - - -""" -A Condition is a "logical" database expression, intended for use in Where statements. -Conditions support bitwise logical operators ~, &, |, each producing another Condition. -""" - -NULL = None - - -class Joiner(Enum): - EQUALS = ('=', '!=') - IS = ('IS', 'IS NOT') - LIKE = ('LIKE', 'NOT LIKE') - BETWEEN = ('BETWEEN', 'NOT BETWEEN') - IN = ('IN', 'NOT IN') - LT = ('<', '>=') - LE = ('<=', '>') - NONE = ('', '') - - -class Condition(Expression): - __slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values') - - def __init__(self, - expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''), - values: tuple[Any, ...] = (), negated=False - ): - self.expr1 = expr1 - self.joiner = joiner - self.negated = negated - self.expr2 = expr2 - self.values = values - - def as_tuple(self): - expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2)) - if self.negated and self.joiner is Joiner.NONE: - expr = sql.SQL("NOT ({})").format(expr) - return (expr, self.values) - - @classmethod - def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]): - """ - Construct a Condition from a sequence of Conditions, - together with some explicit column conditions. - """ - # TODO: Consider adding a _table identifier here so we can identify implicit columns - # Or just require subquery type conditions to always come from modelled tables. - implicit_conditions = ( - cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items() - ) - return cls._and(*conditions, *implicit_conditions) - - @classmethod - def _and(cls, *conditions: 'Condition'): - if not len(conditions): - raise ValueError("Cannot combine 0 Conditions") - if len(conditions) == 1: - return conditions[0] - - exprs, values = zip(*(condition.as_tuple() for condition in conditions)) - cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs)) - cond_values = tuple(chain(*values)) - - return Condition(cond_expr, values=cond_values) - - @classmethod - def _or(cls, *conditions: 'Condition'): - if not len(conditions): - raise ValueError("Cannot combine 0 Conditions") - if len(conditions) == 1: - return conditions[0] - - exprs, values = zip(*(condition.as_tuple() for condition in conditions)) - cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs)) - cond_values = tuple(chain(*values)) - - return Condition(cond_expr, values=cond_values) - - @classmethod - def _not(cls, condition: 'Condition'): - condition.negated = not condition.negated - return condition - - @classmethod - def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition': - # TODO: Check if this supports sbqueries - col_expr, col_values = column.as_tuple() - - # TODO: Also support sql.SQL? For joins? - if isinstance(value, Expression): - # column = Expression - value_expr, value_values = value.as_tuple() - cond_exprs = (col_expr, Joiner.EQUALS, value_expr) - cond_values = (*col_values, *value_values) - elif isinstance(value, (tuple, list)): - # column in (...) - # TODO: Support expressions in value tuple? - if not value: - raise ValueError("Cannot create Condition from empty iterable!") - value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value))) - cond_exprs = (col_expr, Joiner.IN, value_expr) - cond_values = (*col_values, *value) - elif value is None: - # column IS NULL - cond_exprs = (col_expr, Joiner.IS, sql.NULL) - cond_values = col_values - else: - # column = Literal - cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder()) - cond_values = (*col_values, value) - - return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) - - def __invert__(self) -> 'Condition': - self.negated = not self.negated - return self - - def __and__(self, condition: 'Condition') -> 'Condition': - return self._and(self, condition) - - def __or__(self, condition: 'Condition') -> 'Condition': - return self._or(self, condition) - - -# Helper method to simply condition construction -def condition(*args, **kwargs) -> Condition: - return Condition.construct(*args, **kwargs) - - -# class NOT(Condition): -# __slots__ = ('value',) -# -# def __init__(self, value): -# self.value = value -# -# def apply(self, key, values, conditions): -# item = self.value -# if isinstance(item, (list, tuple)): -# if item: -# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item)))) -# values.extend(item) -# else: -# raise ValueError("Cannot check an empty iterable!") -# else: -# conditions.append("{}!={}".format(key, _replace_char)) -# values.append(item) -# -# -# class GEQ(Condition): -# __slots__ = ('value',) -# -# def __init__(self, value): -# self.value = value -# -# def apply(self, key, values, conditions): -# item = self.value -# if isinstance(item, (list, tuple)): -# raise ValueError("Cannot apply GEQ condition to a list!") -# else: -# conditions.append("{} >= {}".format(key, _replace_char)) -# values.append(item) -# -# -# class LEQ(Condition): -# __slots__ = ('value',) -# -# def __init__(self, value): -# self.value = value -# -# def apply(self, key, values, conditions): -# item = self.value -# if isinstance(item, (list, tuple)): -# raise ValueError("Cannot apply LEQ condition to a list!") -# else: -# conditions.append("{} <= {}".format(key, _replace_char)) -# values.append(item) -# -# -# class Constant(Condition): -# __slots__ = ('value',) -# -# def __init__(self, value): -# self.value = value -# -# def apply(self, key, values, conditions): -# conditions.append("{} {}".format(key, self.value)) -# -# -# class SHARDID(Condition): -# __slots__ = ('shardid', 'shard_count') -# -# def __init__(self, shardid, shard_count): -# self.shardid = shardid -# self.shard_count = shard_count -# -# def apply(self, key, values, conditions): -# if self.shard_count > 1: -# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char)) -# values.append(self.shardid) -# -# -# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count) -# -# -# NULL = Constant('IS NULL') -# NOTNULL = Constant('IS NOT NULL') diff --git a/src/data/connector.py b/src/data/connector.py deleted file mode 100644 index 7b25aed..0000000 --- a/src/data/connector.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional -import logging - -from contextvars import ContextVar -from contextlib import asynccontextmanager -import psycopg as psq -from psycopg_pool import AsyncConnectionPool -from psycopg.pq import TransactionStatus - -from .cursor import AsyncLoggingCursor - -logger = logging.getLogger(__name__) - -row_factory = psq.rows.dict_row - -ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None) - - -class Connector: - cursor_factory = AsyncLoggingCursor - - def __init__(self, conn_args): - self._conn_args = conn_args - self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory) - - self.pool = self.make_pool() - - self.conn_hooks = [] - - @property - def conn(self) -> Optional[psq.AsyncConnection]: - """ - Convenience property for the current context connection. - """ - return ctx_connection.get() - - @conn.setter - def conn(self, conn: psq.AsyncConnection): - """ - Set the contextual connection in the current context. - Always do this in an isolated context! - """ - ctx_connection.set(conn) - - def make_pool(self) -> AsyncConnectionPool: - logger.info("Initialising connection pool.", extra={'action': "Pool Init"}) - return AsyncConnectionPool( - self._conn_args, - open=False, - min_size=4, - max_size=8, - configure=self._setup_connection, - kwargs=self._conn_kwargs - ) - - async def refresh_pool(self): - """ - Refresh the pool. - - The point of this is to invalidate any existing connections so that the connection set up is run again. - Better ways should be sought (a way to - """ - logger.info("Pool refresh requested, closing and reopening.") - old_pool = self.pool - self.pool = self.make_pool() - await self.pool.open() - logger.info(f"Old pool statistics: {self.pool.get_stats()}") - await old_pool.close() - logger.info("Pool refresh complete.") - - async def map_over_pool(self, callable): - """ - Dangerous method to call a method on each connection in the pool. - - Utilises private methods of the AsyncConnectionPool. - """ - async with self.pool._lock: - conns = list(self.pool._pool) - while conns: - conn = conns.pop() - try: - await callable(conn) - except Exception: - logger.exception(f"Mapped connection task failed. {callable.__name__}") - - @asynccontextmanager - async def open(self): - try: - logger.info("Opening database pool.") - await self.pool.open() - yield - finally: - # May be a different pool! - logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}") - await self.pool.close() - - @asynccontextmanager - async def connection(self) -> psq.AsyncConnection: - """ - Asynchronous context manager to get and manage a connection. - - If the context connection is set, uses this and does not manage the lifetime. - Otherwise, requests a new connection from the pool and returns it when done. - """ - logger.debug("Database connection requested.", extra={'action': "Data Connect"}) - if (conn := self.conn): - yield conn - else: - async with self.pool.connection() as conn: - yield conn - - async def _setup_connection(self, conn: psq.AsyncConnection): - logger.debug("Initialising new connection.", extra={'action': "Conn Init"}) - for hook in self.conn_hooks: - try: - await hook(conn) - except Exception: - logger.exception("Exception encountered setting up new connection") - return conn - - def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]): - """ - Minimal decorator to register a coroutine to run on connect or reconnect. - - Note that these are only run on connect and reconnect. - If a hook is registered after connection, it will not be run. - """ - self.conn_hooks.append(coro) - return coro - - -@runtime_checkable -class Connectable(Protocol): - def bind(self, connector: Connector): - raise NotImplementedError diff --git a/src/data/cursor.py b/src/data/cursor.py deleted file mode 100644 index 5d183e0..0000000 --- a/src/data/cursor.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging -from typing import Optional - -from psycopg import AsyncCursor, sql -from psycopg.abc import Query, Params -from psycopg._encodings import conn_encoding - -logger = logging.getLogger(__name__) - - -class AsyncLoggingCursor(AsyncCursor): - def mogrify_query(self, query: Query): - if isinstance(query, str): - msg = query - elif isinstance(query, (sql.SQL, sql.Composed)): - msg = query.as_string(self) - elif isinstance(query, bytes): - msg = query.decode(conn_encoding(self._conn.pgconn), 'replace') - else: - msg = repr(query) - return msg - - async def execute(self, query: Query, params: Optional[Params] = None, **kwargs): - if logging.DEBUG >= logger.getEffectiveLevel(): - msg = self.mogrify_query(query) - logger.debug( - "Executing query (%s) with values %s", msg, params, - extra={'action': "Query Execute"} - ) - try: - return await super().execute(query, params=params, **kwargs) - except Exception: - msg = self.mogrify_query(query) - logger.exception( - "Exception during query execution. Query (%s) with parameters %s.", - msg, params, - extra={'action': "Query Execute"}, - stack_info=True - ) - else: - # TODO: Possibly log execution time - pass diff --git a/src/data/database.py b/src/data/database.py deleted file mode 100644 index 255e412..0000000 --- a/src/data/database.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import TypeVar -import logging -from collections import namedtuple - -# from .cursor import AsyncLoggingCursor -from .registry import Registry -from .connector import Connector - - -logger = logging.getLogger(__name__) - -Version = namedtuple('Version', ('version', 'time', 'author')) - -T = TypeVar('T', bound=Registry) - - -class Database(Connector): - # cursor_factory = AsyncLoggingCursor - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.registries: dict[str, Registry] = {} - - def load_registry(self, registry: T) -> T: - logger.debug( - f"Loading and binding registry '{registry.name}'.", - extra={'action': f"Reg {registry.name}"} - ) - registry.bind(self) - self.registries[registry.name] = registry - return registry - - async def version(self) -> Version: - """ - Return the current schema version as a Version namedtuple. - """ - async with self.connection() as conn: - async with conn.cursor() as cursor: - # Get last entry in version table, compare against desired version - await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") - row = await cursor.fetchone() - if row: - return Version(row['version'], row['time'], row['author']) - else: - # No versions in the database - return Version(-1, None, None) diff --git a/src/data/models.py b/src/data/models.py deleted file mode 100644 index 54b6282..0000000 --- a/src/data/models.py +++ /dev/null @@ -1,323 +0,0 @@ -from typing import TypeVar, Type, Optional, Generic, Union -# from typing_extensions import Self -from weakref import WeakValueDictionary -from collections.abc import MutableMapping - -from psycopg.rows import DictRow - -from .table import Table -from .columns import Column -from . import queries as q -from .connector import Connector -from .registry import Registry - - -RowT = TypeVar('RowT', bound='RowModel') - - -class MISSING: - __slots__ = ('oid',) - - def __init__(self, oid): - self.oid = oid - - -class RowTable(Table, Generic[RowT]): - __slots__ = ( - 'model', - ) - - def __init__(self, name, model: Type[RowT], **kwargs): - super().__init__(name, **kwargs) - self.model = model - - @property - def columns(self): - return self.model._columns_ - - @property - def id_col(self): - return self.model._key_ - - @property - def row_cache(self): - return self.model._cache_ - - def _many_query_adapter(self, *data): - self.model._make_rows(*data) - return data - - def _single_query_adapter(self, *data): - if data: - self.model._make_rows(*data) - return data[0] - else: - return None - - def _delete_query_adapter(self, *data): - self.model._delete_rows(*data) - return data - - # New methods to fetch and create rows - async def create_row(self, *args, **kwargs) -> RowT: - data = await super().insert(*args, **kwargs) - return self.model._make_rows(data)[0] - - def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]: - # TODO: Handle list of rowids here? - return q.Select( - self.identifier, - row_adapter=self.model._make_rows, - connector=self.connector - ).where(*args, **kwargs) - - -WK = TypeVar('WK') -WV = TypeVar('WV') - - -class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]): - def __init__(self, ref_cache): - self.ref_cache = ref_cache - self.weak_cache = WeakValueDictionary() - - def __getitem__(self, key): - value = self.weak_cache[key] - self.ref_cache[key] = value - return value - - def __setitem__(self, key, value): - self.weak_cache[key] = value - self.ref_cache[key] = value - - def __delitem__(self, key): - del self.weak_cache[key] - try: - del self.ref_cache[key] - except KeyError: - pass - - def __contains__(self, key): - return key in self.weak_cache - - def __iter__(self): - return iter(self.weak_cache) - - def __len__(self): - return len(self.weak_cache) - - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def pop(self, key, default=None): - if key in self: - value = self[key] - del self[key] - else: - value = default - return value - - -# TODO: Implement getitem and setitem, for dynamic column access -class RowModel: - __slots__ = ('data',) - - _schema_: str = 'public' - _tablename_: Optional[str] = None - _columns_: dict[str, Column] = {} - - # Cache to keep track of registered Rows - _cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore - - _key_: tuple[str, ...] = () - _connector: Optional[Connector] = None - _registry: Optional[Registry] = None - - # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass - table: RowTable = None - - def __init_subclass__(cls: Type[RowT], table: Optional[str] = None): - """ - Set table, _columns_, and _key_. - """ - if table is not None: - cls._tablename_ = table - - if cls._tablename_ is not None: - columns = {} - for key, value in cls.__dict__.items(): - if isinstance(value, Column): - columns[key] = value - - cls._columns_ = columns - if not cls._key_: - cls._key_ = tuple(column.name for column in columns.values() if column.primary) - cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_) - if cls._cache_ is None: - cls._cache_ = WeakValueDictionary() - - def __new__(cls, data): - # Registry pattern. - # Ensure each rowid always refers to a single Model instance - if data is not None: - rowid = cls._id_from_data(data) - - cache = cls._cache_ - - if (row := cache.get(rowid, None)) is not None: - obj = row - else: - obj = cache[rowid] = super().__new__(cls) - else: - obj = super().__new__(cls) - - return obj - - @classmethod - def as_tuple(cls): - return (cls.table.identifier, ()) - - def __init__(self, data): - self.data = data - - def __getitem__(self, key): - return self.data[key] - - def __setitem__(self, key, value): - self.data[key] = value - - @classmethod - def bind(cls, connector: Connector): - if cls.table is None: - raise ValueError("Cannot bind abstract RowModel") - cls._connector = connector - cls.table.bind(connector) - return cls - - @classmethod - def attach_to(cls, registry: Registry): - cls._registry = registry - return cls - - @property - def _dict_(self): - return {key: self.data[key] for key in self._key_} - - @property - def _rowid_(self): - return tuple(self.data[key] for key in self._key_) - - def __repr__(self): - return "{}.{}({})".format( - self.table.schema, - self.table.name, - ', '.join(repr(column.__get__(self)) for column in self._columns_.values()) - ) - - @classmethod - def _id_from_data(cls, data): - return tuple(data[key] for key in cls._key_) - - @classmethod - def _dict_from_id(cls, rowid): - return dict(zip(cls._key_, rowid)) - - @classmethod - def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]: - """ - Create or retrieve Row objects for each provided data row. - If the rows already exist in cache, updates the cached row. - """ - # TODO: Handle partial row data here somehow? - rows = [cls(data_row) for data_row in data_rows] - return rows - - @classmethod - def _delete_rows(cls, *data_rows): - """ - Remove the given rows from cache, if they exist. - May be extended to handle object deletion. - """ - cache = cls._cache_ - - for data_row in data_rows: - rowid = cls._id_from_data(data_row) - cache.pop(rowid, None) - - @classmethod - async def create(cls: Type[RowT], *args, **kwargs) -> RowT: - return await cls.table.create_row(*args, **kwargs) - - @classmethod - def fetch_where(cls: Type[RowT], *args, **kwargs): - return cls.table.fetch_rows_where(*args, **kwargs) - - @classmethod - async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]: - """ - Fetch the row with the given id, retrieving from cache where possible. - """ - row = cls._cache_.get(rowid, None) if cached else None - if row is None: - rows = await cls.fetch_where(**cls._dict_from_id(rowid)) - row = rows[0] if rows else None - if row is None: - cls._cache_[rowid] = cls(None) - elif row.data is None: - row = None - - return row - - @classmethod - async def fetch_or_create(cls, *rowid, **kwargs): - """ - Helper method to fetch a row with the given id or fields, or create it if it doesn't exist. - """ - if rowid: - row = await cls.fetch(*rowid) - else: - rows = await cls.fetch_where(**kwargs).limit(1) - row = rows[0] if rows else None - - if row is None: - creation_kwargs = kwargs - if rowid: - creation_kwargs.update(cls._dict_from_id(rowid)) - row = await cls.create(**creation_kwargs) - return row - - async def refresh(self: RowT) -> Optional[RowT]: - """ - Refresh this Row from data. - - The return value may be `None` if the row was deleted. - """ - rows = await self.table.select_where(**self._dict_) - if not rows: - return None - else: - self.data = rows[0] - return self - - async def update(self: RowT, **values) -> Optional[RowT]: - """ - Update this Row with the given values. - - Internally passes the provided `values` to the `update` Query. - The return value may be `None` if the row was deleted. - """ - data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows) - if not data: - return None - else: - return data[0] - - async def delete(self: RowT) -> Optional[RowT]: - """ - Delete this Row. - """ - data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows) - return data[0] if data is not None else None diff --git a/src/data/queries.py b/src/data/queries.py deleted file mode 100644 index 0232928..0000000 --- a/src/data/queries.py +++ /dev/null @@ -1,644 +0,0 @@ -from typing import Optional, TypeVar, Any, Callable, Generic, List, Union -from enum import Enum -from itertools import chain -from psycopg import AsyncConnection, AsyncCursor -from psycopg import sql -from psycopg.rows import DictRow - -import logging - -from .conditions import Condition -from .base import Expression, RawExpr -from .connector import Connector - - -logger = logging.getLogger(__name__) - - -TQueryT = TypeVar('TQueryT', bound='TableQuery') -SQueryT = TypeVar('SQueryT', bound='Select') - -QueryResult = TypeVar('QueryResult') - - -class Query(Generic[QueryResult]): - """ - ABC for an executable query statement. - """ - __slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result') - - _adapter: Callable[..., QueryResult] - - def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs): - self.connector: Optional[Connector] = connector - self.conn: Optional[AsyncConnection] = conn - self.cursor: Optional[AsyncCursor] = cursor - - if row_adapter is not None: - self._adapter = row_adapter - else: - self._adapter = self._no_adapter - - self.result: Optional[QueryResult] = None - - def bind(self, connector: Connector): - self.connector = connector - return self - - def with_cursor(self, cursor: AsyncCursor): - self.cursor = cursor - return self - - def with_connection(self, conn: AsyncConnection): - self.conn = conn - return self - - def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: - return data - - def with_adapter(self, callable: Callable[..., QueryResult]): - # NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1] - # For this to work cleanly, callable should have arg type of QR1, not any - self._adapter = callable - return self - - def with_no_adapter(self): - """ - Sets the adapater to the identity. - """ - self._adapter = self._no_adapter - return self - - def one(self): - # TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1] - return self - - def build(self) -> Expression: - raise NotImplementedError - - async def _execute(self, cursor: AsyncCursor) -> QueryResult: - query, values = self.build().as_tuple() - # TODO: Move logging out to a custom cursor - # logger.debug( - # f"Executing query ({query.as_string(cursor)}) with values {values}", - # extra={'action': "Query"} - # ) - await cursor.execute(sql.Composed((query,)), values) - data = await cursor.fetchall() - self.result = self._adapter(*data) - return self.result - - async def execute(self, cursor=None) -> QueryResult: - """ - Execute the query, optionally with the provided cursor, and return the result rows. - If no cursor is provided, and no cursor has been set with `with_cursor`, - the execution will create a new cursor from the connection and close it automatically. - """ - # Create a cursor if possible - cursor = cursor if cursor is not None else self.cursor - if self.cursor is None: - if self.conn is None: - if self.connector is None: - raise ValueError("Cannot execute query without cursor, connection, or connector.") - else: - async with self.connector.connection() as conn: - async with conn.cursor() as cursor: - data = await self._execute(cursor) - else: - async with self.conn.cursor() as cursor: - data = await self._execute(cursor) - else: - data = await self._execute(cursor) - return data - - def __await__(self): - return self.execute().__await__() - - -class TableQuery(Query[QueryResult]): - """ - ABC for an executable query statement expected to be run on a single table. - """ - __slots__ = ( - 'tableid', - 'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group' - ) - - def __init__(self, tableid, *args, **kwargs): - super().__init__(*args, **kwargs) - self.tableid: sql.Identifier = tableid - - def options(self, **kwargs): - """ - Set some query options. - Default implementation does nothing. - Should be overridden to provide specific options. - """ - return self - - -class WhereMixin(TableQuery[QueryResult]): - __slots__ = () - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.condition: Optional[Condition] = None - - def where(self, *args: Condition, **kwargs): - """ - Add a Condition to the query. - Position arguments should be Conditions, - and keyword arguments should be of the form `column=Value`, - where Value may be a Value-type or a literal value. - All provided Conditions will be and-ed together to create a new Condition. - TODO: Maybe just pass this verbatim to a condition. - """ - if args or kwargs: - condition = Condition.construct(*args, **kwargs) - if self.condition is not None: - condition = self.condition & condition - - self.condition = condition - - return self - - @property - def _where_section(self) -> Optional[Expression]: - if self.condition is not None: - return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple()) - else: - return None - - -class JOINTYPE(Enum): - LEFT = sql.SQL('LEFT JOIN') - RIGHT = sql.SQL('RIGHT JOIN') - INNER = sql.SQL('INNER JOIN') - OUTER = sql.SQL('OUTER JOIN') - FULLOUTER = sql.SQL('FULL OUTER JOIN') - - -class JoinMixin(TableQuery[QueryResult]): - __slots__ = () - # TODO: Remember to add join slots to TableQuery - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._joins: list[Expression] = [] - - def join(self, - target: Union[str, Expression], - on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None, - join_type: JOINTYPE = JOINTYPE.INNER, - natural=False): - available = (on is not None) + (using is not None) + natural - if available == 0: - raise ValueError("No conditions given for Query Join") - if available > 1: - raise ValueError("Exactly one join format must be given for Query Join") - - sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())] - if isinstance(target, str): - sections.append((sql.Identifier(target), ())) - else: - sections.append(target.as_tuple()) - - if on is not None: - sections.append((sql.SQL('ON'), ())) - sections.append(on.as_tuple()) - elif using is not None: - sections.append((sql.SQL('USING'), ())) - if isinstance(using, Expression): - sections.append(using.as_tuple()) - elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str): - cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using)) - sections.append((cols, ())) - else: - raise ValueError("Unrecognised 'using' type.") - elif natural: - sections.insert(0, (sql.SQL('NATURAL'), ())) - - expr = RawExpr.join_tuples(*sections) - self._joins.append(expr) - return self - - def leftjoin(self, *args, **kwargs): - return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs) - - @property - def _join_section(self) -> Optional[Expression]: - if self._joins: - return RawExpr.join(*self._joins) - else: - return None - - -class ExtraMixin(TableQuery[QueryResult]): - __slots__ = () - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._extra: Optional[Expression] = None - - def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()): - """ - Add an extra string, and optionally values, to this query. - The extra string is inserted after any condition, and before the limit. - """ - extra_expr = RawExpr(extra, values) - if self._extra is not None: - extra_expr = RawExpr.join(self._extra, extra_expr) - self._extra = extra_expr - return self - - @property - def _extra_section(self) -> Optional[Expression]: - if self._extra is None: - return None - else: - return self._extra - - -class LimitMixin(TableQuery[QueryResult]): - __slots__ = () - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._limit: Optional[int] = None - - def limit(self, limit: int): - """ - Add a limit to this query. - """ - self._limit = limit - return self - - @property - def _limit_section(self) -> Optional[Expression]: - if self._limit is not None: - return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,)) - else: - return None - - -class FromMixin(TableQuery[QueryResult]): - __slots__ = () - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._from: Optional[Expression] = None - - def from_expr(self, _from: Expression): - self._from = _from - return self - - @property - def _from_section(self) -> Optional[Expression]: - if self._from is not None: - expr, values = self._from.as_tuple() - return RawExpr(sql.SQL("FROM {}").format(expr), values) - else: - return None - - -class ORDER(Enum): - ASC = sql.SQL('ASC') - DESC = sql.SQL('DESC') - - -class NULLS(Enum): - FIRST = sql.SQL('NULLS FIRST') - LAST = sql.SQL('NULLS LAST') - - -class OrderMixin(TableQuery[QueryResult]): - __slots__ = () - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._order: list[Expression] = [] - - def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None): - """ - Add a single sort expression to the query. - This method stacks. - """ - if isinstance(expr, Expression): - string, values = expr.as_tuple() - else: - string = sql.Identifier(expr) - values = () - - parts = [string] - if direction is not None: - parts.append(direction.value) - if nulls is not None: - parts.append(nulls.value) - - order_string = sql.SQL(' ').join(parts) - self._order.append(RawExpr(order_string, values)) - return self - - @property - def _order_section(self) -> Optional[Expression]: - if self._order: - expr = RawExpr.join(*self._order, joiner=sql.SQL(', ')) - expr.expr = sql.SQL("ORDER BY {}").format(expr.expr) - return expr - else: - return None - - -class GroupMixin(TableQuery[QueryResult]): - __slots__ = () - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._group: list[Expression] = [] - - def group_by(self, *exprs: Union[Expression, str]): - """ - Add a group expression(s) to the query. - This method stacks. - """ - for expr in exprs: - if isinstance(expr, Expression): - self._group.append(expr) - else: - self._group.append(RawExpr(sql.Identifier(expr))) - return self - - @property - def _group_section(self) -> Optional[Expression]: - if self._group: - expr = RawExpr.join(*self._group, joiner=sql.SQL(', ')) - expr.expr = sql.SQL("GROUP BY {}").format(expr.expr) - return expr - else: - return None - - -class Insert(ExtraMixin, TableQuery[QueryResult]): - """ - Query type representing a table insert query. - """ - # TODO: Support ON CONFLICT for upserts - __slots__ = ('_columns', '_values', '_conflict') - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._columns: tuple[str, ...] = () - self._values: tuple[tuple[Any, ...], ...] = () - self._conflict: Optional[Expression] = None - - def insert(self, columns, *values): - """ - Insert the given data. - - Parameters - ---------- - columns: tuple[str] - Tuple of column names to insert. - - values: tuple[tuple[Any, ...], ...] - Tuple of values to insert, corresponding to the columns. - """ - if not values: - raise ValueError("Cannot insert zero rows.") - if len(values[0]) != len(columns): - raise ValueError("Number of columns does not match length of values.") - - self._columns = columns - self._values = values - return self - - def on_conflict(self, ignore=False): - # TODO lots more we can do here - # Maybe return a Conflict object that can chain itself (not the query) - if ignore: - self._conflict = RawExpr(sql.SQL('DO NOTHING')) - return self - - @property - def _conflict_section(self) -> Optional[Expression]: - if self._conflict is not None: - e, v = self._conflict.as_tuple() - expr = RawExpr( - sql.SQL("ON CONFLICT {}").format( - e - ), - v - ) - return expr - return None - - def build(self): - columns = sql.SQL(',').join(map(sql.Identifier, self._columns)) - single_value_str = sql.SQL('({})').format( - sql.SQL(',').join(sql.Placeholder() * len(self._columns)) - ) - values_str = sql.SQL(',').join(single_value_str * len(self._values)) - - # TODO: Check efficiency of inserting multiple values like this - # Also implement a Copy query - base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format( - table=self.tableid, - columns=columns, - values_str=values_str - ) - - sections = [ - RawExpr(base, tuple(chain(*self._values))), - self._conflict_section, - self._extra_section, - RawExpr(sql.SQL('RETURNING *')) - ] - - sections = (section for section in sections if section is not None) - return RawExpr.join(*sections) - - -class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]): - """ - Select rows from a table matching provided conditions. - """ - __slots__ = ('_columns',) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._columns: tuple[Expression, ...] = () - - def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]): - """ - Set the columns and expressions to select. - If none are given, selects all columns. - """ - cols: List[Expression] = [] - if columns: - cols.extend(map(RawExpr, map(sql.Identifier, columns))) - if exprs: - for name, expr in exprs.items(): - if isinstance(expr, str): - cols.append( - RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name)) - ) - elif isinstance(expr, sql.Composable): - cols.append( - RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name)) - ) - elif isinstance(expr, Expression): - value_expr, value_values = expr.as_tuple() - cols.append(RawExpr( - value_expr + sql.SQL(' AS ') + sql.Identifier(name), - value_values - )) - if cols: - self._columns = (*self._columns, *cols) - return self - - def build(self): - if not self._columns: - columns, columns_values = sql.SQL('*'), () - else: - columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple() - - base = sql.SQL("SELECT {columns} FROM {table}").format( - columns=columns, - table=self.tableid - ) - - sections = [ - RawExpr(base, columns_values), - self._join_section, - self._where_section, - self._group_section, - self._extra_section, - self._order_section, - self._limit_section, - ] - - sections = (section for section in sections if section is not None) - return RawExpr.join(*sections) - - -class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]): - """ - Query type representing a table delete query. - """ - # TODO: Cascade option for delete, maybe other options - # TODO: Require a where unless specifically disabled, for safety - - def build(self): - base = sql.SQL("DELETE FROM {table}").format( - table=self.tableid, - ) - sections = [ - RawExpr(base), - self._where_section, - self._extra_section, - RawExpr(sql.SQL('RETURNING *')) - ] - - sections = (section for section in sections if section is not None) - return RawExpr.join(*sections) - - -class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]): - __slots__ = ( - '_set', - ) - # TODO: Again, require a where unless specifically disabled - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._set: List[Expression] = [] - - def set(self, **column_values: Union[Any, Expression]): - exprs: List[Expression] = [] - for name, value in column_values.items(): - if isinstance(value, Expression): - value_tup = value.as_tuple() - else: - value_tup = (sql.Placeholder(), (value,)) - - exprs.append( - RawExpr.join_tuples( - (sql.Identifier(name), ()), - value_tup, - joiner=sql.SQL(' = ') - ) - ) - self._set.extend(exprs) - return self - - def build(self): - if not self._set: - raise ValueError("No columns provided to update.") - set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple() - - base = sql.SQL("UPDATE {table} SET {set}").format( - table=self.tableid, - set=set_expr - ) - sections = [ - RawExpr(base, set_values), - self._from_section, - self._where_section, - self._extra_section, - self._limit_section, - RawExpr(sql.SQL('RETURNING *')) - ] - - sections = (section for section in sections if section is not None) - return RawExpr.join(*sections) - - -# async def upsert(cursor, table, constraint, **values): -# """ -# Insert or on conflict update. -# """ -# valuedict = values -# keys, values = zip(*values.items()) -# -# key_str = _format_insertkeys(keys) -# value_str, values = _format_insertvalues(values) -# update_key_str, update_key_values = _format_updatestr(valuedict) -# -# if not isinstance(constraint, str): -# constraint = ", ".join(constraint) -# -# await cursor.execute( -# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format( -# table, key_str, value_str, constraint, update_key_str -# ), -# tuple((*values, *update_key_values)) -# ) -# return await cursor.fetchone() - - -# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None): -# cursor = cursor or conn.cursor() -# -# # TODO: executemany or copy syntax now -# return execute_values( -# cursor, -# """ -# UPDATE {table} -# SET {set_clause} -# FROM (VALUES {cast_row}%s) -# AS {temp_table} -# WHERE {where_clause} -# RETURNING * -# """.format( -# table=table, -# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys), -# cast_row=cast_row + ',' if cast_row else '', -# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys), -# temp_table="_t ({})".format(', '.join(set_keys + where_keys)) -# ), -# values, -# fetch=True -# ) diff --git a/src/data/registry.py b/src/data/registry.py deleted file mode 100644 index c130d0f..0000000 --- a/src/data/registry.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Protocol, runtime_checkable, Optional - -from psycopg import AsyncConnection - -from .connector import Connector, Connectable - - -@runtime_checkable -class _Attachable(Connectable, Protocol): - def attach_to(self, registry: 'Registry'): - raise NotImplementedError - - -class Registry: - _attached: list[_Attachable] = [] - _name: Optional[str] = None - - def __init_subclass__(cls, name=None): - attached = [] - for _, member in cls.__dict__.items(): - if isinstance(member, _Attachable): - attached.append(member) - cls._attached = attached - cls._name = name or cls.__name__ - - def __init__(self, name=None): - self._conn: Optional[Connector] = None - self.name: str = name if name is not None else self._name - if self.name is None: - raise ValueError("A Registry must have a name!") - - self.init_tasks = [] - - for member in self._attached: - member.attach_to(self) - - def bind(self, connector: Connector): - self._conn = connector - for child in self._attached: - child.bind(connector) - - def attach(self, attachable): - self._attached.append(attachable) - if self._conn is not None: - attachable.bind(self._conn) - return attachable - - def init_task(self, coro): - """ - Initialisation tasks are run to setup the registry state. - These tasks will be run in the event loop, after connection to the database. - These tasks should be idempotent, as they may be run on reload and reconnect. - """ - self.init_tasks.append(coro) - return coro - - async def init(self): - for task in self.init_tasks: - await task(self) - return self - - -class AttachableClass: - """ABC for a default implementation of an Attachable class.""" - - _connector: Optional[Connector] = None - _registry: Optional[Registry] = None - - @classmethod - def bind(cls, connector: Connector): - cls._connector = connector - connector.connect_hook(cls.on_connect) - return cls - - @classmethod - def attach_to(cls, registry: Registry): - cls._registry = registry - return cls - - @classmethod - async def on_connect(cls, connection: AsyncConnection): - pass - - -class Attachable: - """ABC for a default implementation of an Attachable object.""" - - def __init__(self, *args, **kwargs): - self._connector: Optional[Connector] = None - self._registry: Optional[Registry] = None - - def bind(self, connector: Connector): - self._connector = connector - connector.connect_hook(self.on_connect) - return self - - def attach_to(self, registry: Registry): - self._registry = registry - return self - - async def on_connect(self, connection: AsyncConnection): - pass diff --git a/src/data/table.py b/src/data/table.py deleted file mode 100644 index e20647e..0000000 --- a/src/data/table.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Optional -from psycopg.rows import DictRow -from psycopg import sql - -from . import queries as q -from .connector import Connector -from .registry import Registry - - -class Table: - """ - Transparent interface to a single table structure in the database. - Contains standard methods to access the table. - """ - - def __init__(self, name, *args, schema='public', **kwargs): - self.name: str = name - self.schema: str = schema - self.connector: Connector = None - - @property - def identifier(self): - if self.schema == 'public': - return sql.Identifier(self.name) - else: - return sql.Identifier(self.schema, self.name) - - def bind(self, connector: Connector): - self.connector = connector - return self - - def attach_to(self, registry: Registry): - self._registry = registry - return self - - def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: - return data - - def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]: - if data: - return data[0] - else: - return None - - def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: - return data - - def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]: - return q.Select( - self.identifier, - row_adapter=self._many_query_adapter, - connector=self.connector - ).where(*args, **kwargs) - - def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]: - return q.Select( - self.identifier, - row_adapter=self._single_query_adapter, - connector=self.connector - ).where(*args, **kwargs) - - def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]: - return q.Update( - self.identifier, - row_adapter=self._many_query_adapter, - connector=self.connector - ).where(*args, **kwargs) - - def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]: - return q.Delete( - self.identifier, - row_adapter=self._many_query_adapter, - connector=self.connector - ).where(*args, **kwargs) - - def insert(self, **column_values) -> q.Insert[DictRow]: - return q.Insert( - self.identifier, - row_adapter=self._single_query_adapter, - connector=self.connector - ).insert(column_values.keys(), column_values.values()) - - def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]: - return q.Insert( - self.identifier, - row_adapter=self._many_query_adapter, - connector=self.connector - ).insert(*args, **kwargs) - -# def update_many(self, *args, **kwargs): -# with self.conn: -# return update_many(self.identifier, *args, **kwargs) - -# def upsert(self, *args, **kwargs): -# return upsert(self.identifier, *args, **kwargs)