From 4da50d86783d17d56179f304f8d9fdb36d77e1e8 Mon Sep 17 00:00:00 2001
From: Interitio
Date: Thu, 31 Jul 2025 07:39:53 +1000
Subject: [PATCH] Initial Commit

---
 __init__.py   |   9 +
 adapted.py    |  40 ++++
 base.py       |  45 ++++
 columns.py    | 155 ++++++++++++
 conditions.py | 214 +++++++++++++++++
 connector.py  | 135 +++++++++++
 cursor.py     |  42 ++++
 database.py   |  47 ++++
 models.py     | 323 +++++++++++++++++++++++++
 queries.py    | 644 ++++++++++++++++++++++++++++++++++++++++++++++++++
 registry.py   | 102 ++++++++
 table.py      |  95 ++++++++
 12 files changed, 1851 insertions(+)
 create mode 100644 __init__.py
 create mode 100644 adapted.py
 create mode 100644 base.py
 create mode 100644 columns.py
 create mode 100644 conditions.py
 create mode 100644 connector.py
 create mode 100644 cursor.py
 create mode 100644 database.py
 create mode 100644 models.py
 create mode 100644 queries.py
 create mode 100644 registry.py
 create mode 100644 table.py

diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..affb160
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,9 @@
+from .conditions import Condition, condition, NULL
+from .database import Database
+from .models import RowModel, RowTable, WeakCache
+from .table import Table
+from .base import Expression, RawExpr
+from .columns import ColumnExpr, Column, Integer, String
+from .registry import Registry, AttachableClass, Attachable
+from .adapted import RegisterEnum
+from .queries import ORDER, NULLS, JOINTYPE
diff --git a/adapted.py b/adapted.py
new file mode 100644
index 0000000..a6b4597
--- /dev/null
+++ b/adapted.py
@@ -0,0 +1,40 @@
+# from enum import Enum
+from typing import Optional
+from psycopg.types.enum import register_enum, EnumInfo
+from psycopg import AsyncConnection
+from .registry import Attachable, Registry
+
+
+class RegisterEnum(Attachable):
+    def __init__(self, enum, name: Optional[str] = None, mapper=None):
+        super().__init__()
+        self.enum = enum
+        self.name = name or enum.__name__
+        self.mapping = mapper(enum) if mapper is not None else self._mapper()
+
+    def _mapper(self):
+        return {m: m.value[0] for m in self.enum}
+
+    def attach_to(self, registry: Registry):
+        self._registry = registry
+        registry.init_task(self.on_init)
+        return self
+
+    async def on_init(self, registry: Registry):
+        connector = registry._conn
+        if connector is None:
+            raise ValueError("Cannot initialise without connector!")
+        connector.connect_hook(self.connection_hook)
+        # await connector.refresh_pool()
+        # The below may be somewhat dangerous
+        # But adaptation should never write to the database
+        await connector.map_over_pool(self.connection_hook)
+        # if conn := connector.conn:
+        #     # Ensure the adaptation is run in the current context as well
+        #     await self.connection_hook(conn)
+
+    async def connection_hook(self, conn: AsyncConnection):
+        info = await EnumInfo.fetch(conn, self.name)
+        if info is None:
+            raise ValueError(f"Enum {self.name} not found in database.")
+        register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
diff --git a/base.py b/base.py
new file mode 100644
index 0000000..272d588
--- /dev/null
+++ b/base.py
@@ -0,0 +1,45 @@
+from abc import abstractmethod
+from typing import Any, Protocol, runtime_checkable
+from itertools import chain
+from psycopg import sql
+
+
+@runtime_checkable
+class Expression(Protocol):
+    __slots__ = ()
+
+    @abstractmethod
+    def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
+        raise NotImplementedError
+
+
+class RawExpr(Expression):
+    __slots__ = ('expr', 'values')
+
+    expr: sql.Composable
+    values: tuple[Any, ...]
+
+    def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
+        self.expr = expr
+        self.values = values
+
+    def as_tuple(self):
+        return (self.expr, self.values)
+
+    @classmethod
+    def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
+        """
+        Join a sequence of Expressions into a single RawExpr.
+        """
+        tups = (
+            expression.as_tuple()
+            for expression in expressions
+        )
+        return cls.join_tuples(*tups, joiner=joiner)
+
+    @classmethod
+    def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
+        exprs, values = zip(*tuples)
+        expr = joiner.join(exprs)
+        value = tuple(chain(*values))
+        return cls(expr, value)
diff --git a/columns.py b/columns.py
new file mode 100644
index 0000000..252db83
--- /dev/null
+++ b/columns.py
@@ -0,0 +1,155 @@
+from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
+from psycopg import sql
+from datetime import datetime
+
+from .base import RawExpr, Expression
+from .conditions import Condition, Joiner
+from .table import Table
+
+
+class ColumnExpr(RawExpr):
+    __slots__ = ()
+
+    def __lt__(self, obj) -> Condition:
+        expr, values = self.as_tuple()
+
+        if isinstance(obj, Expression):
+            # column < Expression
+            obj_expr, obj_values = obj.as_tuple()
+            cond_exprs = (expr, Joiner.LT, obj_expr)
+            cond_values = (*values, *obj_values)
+        else:
+            # column < Literal
+            cond_exprs = (expr, Joiner.LT, sql.Placeholder())
+            cond_values = (*values, obj)
+
+        return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
+
+    def __le__(self, obj) -> Condition:
+        expr, values = self.as_tuple()
+
+        if isinstance(obj, Expression):
+            # column <= Expression
+            obj_expr, obj_values = obj.as_tuple()
+            cond_exprs = (expr, Joiner.LE, obj_expr)
+            cond_values = (*values, *obj_values)
+        else:
+            # column <= Literal
+            cond_exprs = (expr, Joiner.LE, sql.Placeholder())
+            cond_values = (*values, obj)
+
+        return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
+
+    def __eq__(self, obj) -> Condition:  # type: ignore[override]
+        return Condition._expression_equality(self, obj)
+
+    def __ne__(self, obj) -> Condition:  # type: ignore[override]
+        return ~(self.__eq__(obj))
+
+    def __gt__(self, obj) -> Condition:
+        return ~(self.__le__(obj))
+
+    def __ge__(self, obj) -> Condition:
+        return ~(self.__lt__(obj))
+
+    def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
+        if isinstance(obj, Expression):
+            obj_expr, obj_values = obj.as_tuple()
+            return ColumnExpr(
+                sql.SQL("({} + {})").format(self.expr, obj_expr),
+                (*self.values, *obj_values)
+            )
+        else:
+            return ColumnExpr(
+                sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
+                (*self.values, obj)
+            )
+
+    def __sub__(self, obj) -> 'ColumnExpr':
+        if isinstance(obj, Expression):
+            obj_expr, obj_values = obj.as_tuple()
+            return ColumnExpr(
+                sql.SQL("({} - {})").format(self.expr, obj_expr),
+                (*self.values, *obj_values)
+            )
+        else:
+            return ColumnExpr(
+                sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
+                (*self.values, obj)
+            )
+
+    def __mul__(self, obj) -> 'ColumnExpr':
+        if isinstance(obj, Expression):
+            obj_expr, obj_values = obj.as_tuple()
+            return ColumnExpr(
+                sql.SQL("({} * {})").format(self.expr, obj_expr),
+                (*self.values, *obj_values)
+            )
+        else:
+            return ColumnExpr(
+                sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
+                (*self.values, obj)
+            )
+
+    def CAST(self, target_type: sql.Composable):
+        return ColumnExpr(
+            sql.SQL("({}::{})").format(self.expr, target_type),
+            self.values
+ ) + + +T = TypeVar('T') + +if TYPE_CHECKING: + from .models import RowModel + + +class Column(ColumnExpr, Generic[T]): + def __init__(self, name: Optional[str] = None, + primary: bool = False, references: Optional['Column'] = None, + type: Optional[Type[T]] = None): + self.primary = primary + self.references = references + self.name: str = name # type: ignore + self.owner: Optional['RowModel'] = None + self._type = type + + self.expr = sql.Identifier(name) if name else sql.SQL('') + self.values = () + + def __set_name__(self, owner, name): + # Only allow setting the owner once + self.name = self.name or name + self.owner = owner + self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name) + + @overload + def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': + ... + + @overload + def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T: + ... + + def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]": + # Get value from row data or session + if obj is None: + return self + else: + return obj.data[self.name] + + +class Integer(Column[int]): + pass + + +class String(Column[str]): + pass + + +class Bool(Column[bool]): + pass + + +class Timestamp(Column[datetime]): + pass diff --git a/conditions.py b/conditions.py new file mode 100644 index 0000000..f40dff6 --- /dev/null +++ b/conditions.py @@ -0,0 +1,214 @@ +# from meta import sharding +from typing import Any, Union +from enum import Enum +from itertools import chain +from psycopg import sql + +from .base import Expression, RawExpr + + +""" +A Condition is a "logical" database expression, intended for use in Where statements. +Conditions support bitwise logical operators ~, &, |, each producing another Condition. +""" + +NULL = None + + +class Joiner(Enum): + EQUALS = ('=', '!=') + IS = ('IS', 'IS NOT') + LIKE = ('LIKE', 'NOT LIKE') + BETWEEN = ('BETWEEN', 'NOT BETWEEN') + IN = ('IN', 'NOT IN') + LT = ('<', '>=') + LE = ('<=', '>') + NONE = ('', '') + + +class Condition(Expression): + __slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values') + + def __init__(self, + expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''), + values: tuple[Any, ...] = (), negated=False + ): + self.expr1 = expr1 + self.joiner = joiner + self.negated = negated + self.expr2 = expr2 + self.values = values + + def as_tuple(self): + expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2)) + if self.negated and self.joiner is Joiner.NONE: + expr = sql.SQL("NOT ({})").format(expr) + return (expr, self.values) + + @classmethod + def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]): + """ + Construct a Condition from a sequence of Conditions, + together with some explicit column conditions. + """ + # TODO: Consider adding a _table identifier here so we can identify implicit columns + # Or just require subquery type conditions to always come from modelled tables. 
+        implicit_conditions = (
+            cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
+        )
+        return cls._and(*conditions, *implicit_conditions)
+
+    @classmethod
+    def _and(cls, *conditions: 'Condition'):
+        if not len(conditions):
+            raise ValueError("Cannot combine 0 Conditions")
+        if len(conditions) == 1:
+            return conditions[0]
+
+        exprs, values = zip(*(condition.as_tuple() for condition in conditions))
+        cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
+        cond_values = tuple(chain(*values))
+
+        return Condition(cond_expr, values=cond_values)
+
+    @classmethod
+    def _or(cls, *conditions: 'Condition'):
+        if not len(conditions):
+            raise ValueError("Cannot combine 0 Conditions")
+        if len(conditions) == 1:
+            return conditions[0]
+
+        exprs, values = zip(*(condition.as_tuple() for condition in conditions))
+        cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
+        cond_values = tuple(chain(*values))
+
+        return Condition(cond_expr, values=cond_values)
+
+    @classmethod
+    def _not(cls, condition: 'Condition'):
+        condition.negated = not condition.negated
+        return condition
+
+    @classmethod
+    def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
+        # TODO: Check if this supports subqueries
+        col_expr, col_values = column.as_tuple()
+
+        # TODO: Also support sql.SQL? For joins?
+        if isinstance(value, Expression):
+            # column = Expression
+            value_expr, value_values = value.as_tuple()
+            cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
+            cond_values = (*col_values, *value_values)
+        elif isinstance(value, (tuple, list)):
+            # column in (...)
+            # TODO: Support expressions in value tuple?
+            if not value:
+                raise ValueError("Cannot create Condition from empty iterable!")
+            value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
+            cond_exprs = (col_expr, Joiner.IN, value_expr)
+            cond_values = (*col_values, *value)
+        elif value is None:
+            # column IS NULL
+            cond_exprs = (col_expr, Joiner.IS, sql.NULL)
+            cond_values = col_values
+        else:
+            # column = Literal
+            cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
+            cond_values = (*col_values, value)
+
+        return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
+
+    def __invert__(self) -> 'Condition':
+        self.negated = not self.negated
+        return self
+
+    def __and__(self, condition: 'Condition') -> 'Condition':
+        return self._and(self, condition)
+
+    def __or__(self, condition: 'Condition') -> 'Condition':
+        return self._or(self, condition)
+
+
+# Helper method to simplify condition construction
+def condition(*args, **kwargs) -> Condition:
+    return Condition.construct(*args, **kwargs)
+
+
+# class NOT(Condition):
+#     __slots__ = ('value',)
+#
+#     def __init__(self, value):
+#         self.value = value
+#
+#     def apply(self, key, values, conditions):
+#         item = self.value
+#         if isinstance(item, (list, tuple)):
+#             if item:
+#                 conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
+#                 values.extend(item)
+#             else:
+#                 raise ValueError("Cannot check an empty iterable!")
+#         else:
+#             conditions.append("{}!={}".format(key, _replace_char))
+#             values.append(item)
+#
+#
+# class GEQ(Condition):
+#     __slots__ = ('value',)
+#
+#     def __init__(self, value):
+#         self.value = value
+#
+#     def apply(self, key, values, conditions):
+#         item = self.value
+#         if isinstance(item, (list, tuple)):
+#             raise ValueError("Cannot apply GEQ condition to a list!")
+#         else:
+#             conditions.append("{} >= {}".format(key, _replace_char))
+#             values.append(item)
+#
+#
+# class LEQ(Condition):
+#     __slots__ = ('value',)
+#
+#     def __init__(self, value):
+#         self.value = value
+#
+#     def apply(self, key, values, conditions):
+#         item = self.value
+#         if isinstance(item, (list, tuple)):
+#             raise ValueError("Cannot apply LEQ condition to a list!")
+#         else:
+#             conditions.append("{} <= {}".format(key, _replace_char))
+#             values.append(item)
+#
+#
+# class Constant(Condition):
+#     __slots__ = ('value',)
+#
+#     def __init__(self, value):
+#         self.value = value
+#
+#     def apply(self, key, values, conditions):
+#         conditions.append("{} {}".format(key, self.value))
+#
+#
+# class SHARDID(Condition):
+#     __slots__ = ('shardid', 'shard_count')
+#
+#     def __init__(self, shardid, shard_count):
+#         self.shardid = shardid
+#         self.shard_count = shard_count
+#
+#     def apply(self, key, values, conditions):
+#         if self.shard_count > 1:
+#             conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
+#             values.append(self.shardid)
+#
+#
+# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
+#
+#
+# NULL = Constant('IS NULL')
+# NOTNULL = Constant('IS NOT NULL')
diff --git a/connector.py b/connector.py
new file mode 100644
index 0000000..6baf3b5
--- /dev/null
+++ b/connector.py
@@ -0,0 +1,135 @@
+from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
+import logging
+
+from contextvars import ContextVar
+from contextlib import asynccontextmanager
+import psycopg as psq
+from psycopg_pool import AsyncConnectionPool
+from psycopg.pq import TransactionStatus
+
+from .cursor import AsyncLoggingCursor
+
+logger = logging.getLogger(__name__)
+
+row_factory = psq.rows.dict_row
+
+ctx_connection: ContextVar[Optional[psq.AsyncConnection]] = ContextVar('connection', default=None)
+
+
+class Connector:
+    cursor_factory = AsyncLoggingCursor
+
+    def __init__(self, conn_args):
+        self._conn_args = conn_args
+        self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
+
+        self.pool = self.make_pool()
+
+        self.conn_hooks = []
+
+    @property
+    def conn(self) -> Optional[psq.AsyncConnection]:
+        """
+        Convenience property for the current context connection.
+        """
+        return ctx_connection.get()
+
+    @conn.setter
+    def conn(self, conn: psq.AsyncConnection):
+        """
+        Set the contextual connection in the current context.
+        Always do this in an isolated context!
+        """
+        ctx_connection.set(conn)
+
+    def make_pool(self) -> AsyncConnectionPool:
+        logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
+        return AsyncConnectionPool(
+            self._conn_args,
+            open=False,
+            min_size=1,
+            max_size=4,
+            configure=self._setup_connection,
+            kwargs=self._conn_kwargs
+        )
+
+    async def refresh_pool(self):
+        """
+        Refresh the pool.
+
+        The point of this is to invalidate any existing connections so that the connection setup is run again.
+        Better ways should be sought (a way to invalidate the existing connections without replacing the whole pool, for instance).
+        """
+        logger.info("Pool refresh requested, closing and reopening.")
+        old_pool = self.pool
+        self.pool = self.make_pool()
+        await self.pool.open()
+        logger.info(f"Old pool statistics: {old_pool.get_stats()}")
+        await old_pool.close()
+        logger.info("Pool refresh complete.")
+
+    async def map_over_pool(self, callable):
+        """
+        Dangerous method to call a method on each connection in the pool.
+
+        Utilises private methods of the AsyncConnectionPool.
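+
+        A minimal usage sketch; the hook body here is illustrative only:
+
+            async def rehook(conn):
+                await conn.execute("SET TIME ZONE 'UTC'")
+
+            await connector.map_over_pool(rehook)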
+ """ + async with self.pool._lock: + conns = list(self.pool._pool) + while conns: + conn = conns.pop() + try: + await callable(conn) + except Exception: + logger.exception(f"Mapped connection task failed. {callable.__name__}") + + @asynccontextmanager + async def open(self): + try: + logger.info("Opening database pool.") + await self.pool.open() + yield + finally: + # May be a different pool! + logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}") + await self.pool.close() + + @asynccontextmanager + async def connection(self) -> psq.AsyncConnection: + """ + Asynchronous context manager to get and manage a connection. + + If the context connection is set, uses this and does not manage the lifetime. + Otherwise, requests a new connection from the pool and returns it when done. + """ + logger.debug("Database connection requested.", extra={'action': "Data Connect"}) + if (conn := self.conn): + yield conn + else: + async with self.pool.connection() as conn: + yield conn + + async def _setup_connection(self, conn: psq.AsyncConnection): + logger.debug("Initialising new connection.", extra={'action': "Conn Init"}) + for hook in self.conn_hooks: + try: + await hook(conn) + except Exception: + logger.exception("Exception encountered setting up new connection") + return conn + + def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]): + """ + Minimal decorator to register a coroutine to run on connect or reconnect. + + Note that these are only run on connect and reconnect. + If a hook is registered after connection, it will not be run. + """ + self.conn_hooks.append(coro) + return coro + + +@runtime_checkable +class Connectable(Protocol): + def bind(self, connector: Connector): + raise NotImplementedError diff --git a/cursor.py b/cursor.py new file mode 100644 index 0000000..5d183e0 --- /dev/null +++ b/cursor.py @@ -0,0 +1,42 @@ +import logging +from typing import Optional + +from psycopg import AsyncCursor, sql +from psycopg.abc import Query, Params +from psycopg._encodings import conn_encoding + +logger = logging.getLogger(__name__) + + +class AsyncLoggingCursor(AsyncCursor): + def mogrify_query(self, query: Query): + if isinstance(query, str): + msg = query + elif isinstance(query, (sql.SQL, sql.Composed)): + msg = query.as_string(self) + elif isinstance(query, bytes): + msg = query.decode(conn_encoding(self._conn.pgconn), 'replace') + else: + msg = repr(query) + return msg + + async def execute(self, query: Query, params: Optional[Params] = None, **kwargs): + if logging.DEBUG >= logger.getEffectiveLevel(): + msg = self.mogrify_query(query) + logger.debug( + "Executing query (%s) with values %s", msg, params, + extra={'action': "Query Execute"} + ) + try: + return await super().execute(query, params=params, **kwargs) + except Exception: + msg = self.mogrify_query(query) + logger.exception( + "Exception during query execution. 
+                "Exception during query execution. Query (%s) with parameters %s.",
+                msg, params,
+                extra={'action': "Query Execute"},
+                stack_info=True
+            )
+            raise
+        else:
+            # TODO: Possibly log execution time
+            pass
diff --git a/database.py b/database.py
new file mode 100644
index 0000000..255e412
--- /dev/null
+++ b/database.py
@@ -0,0 +1,47 @@
+from typing import TypeVar
+import logging
+from collections import namedtuple
+
+# from .cursor import AsyncLoggingCursor
+from .registry import Registry
+from .connector import Connector
+
+
+logger = logging.getLogger(__name__)
+
+Version = namedtuple('Version', ('version', 'time', 'author'))
+
+T = TypeVar('T', bound=Registry)
+
+
+class Database(Connector):
+    # cursor_factory = AsyncLoggingCursor
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.registries: dict[str, Registry] = {}
+
+    def load_registry(self, registry: T) -> T:
+        logger.debug(
+            f"Loading and binding registry '{registry.name}'.",
+            extra={'action': f"Reg {registry.name}"}
+        )
+        registry.bind(self)
+        self.registries[registry.name] = registry
+        return registry
+
+    async def version(self) -> Version:
+        """
+        Return the current schema version as a Version namedtuple.
+        """
+        async with self.connection() as conn:
+            async with conn.cursor() as cursor:
+                # Get last entry in version table, compare against desired version
+                await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
+                row = await cursor.fetchone()
+                if row:
+                    return Version(row['version'], row['time'], row['author'])
+                else:
+                    # No versions in the database
+                    return Version(-1, None, None)
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..54b6282
--- /dev/null
+++ b/models.py
@@ -0,0 +1,323 @@
+from typing import TypeVar, Type, Optional, Generic, Union
+# from typing_extensions import Self
+from weakref import WeakValueDictionary
+from collections.abc import MutableMapping
+
+from psycopg.rows import DictRow
+
+from .table import Table
+from .columns import Column
+from . import queries as q
+from .connector import Connector
+from .registry import Registry
+
+
+RowT = TypeVar('RowT', bound='RowModel')
+
+
+class MISSING:
+    __slots__ = ('oid',)
+
+    def __init__(self, oid):
+        self.oid = oid
+
+
+class RowTable(Table, Generic[RowT]):
+    __slots__ = (
+        'model',
+    )
+
+    def __init__(self, name, model: Type[RowT], **kwargs):
+        super().__init__(name, **kwargs)
+        self.model = model
+
+    @property
+    def columns(self):
+        return self.model._columns_
+
+    @property
+    def id_col(self):
+        return self.model._key_
+
+    @property
+    def row_cache(self):
+        return self.model._cache_
+
+    def _many_query_adapter(self, *data):
+        self.model._make_rows(*data)
+        return data
+
+    def _single_query_adapter(self, *data):
+        if data:
+            self.model._make_rows(*data)
+            return data[0]
+        else:
+            return None
+
+    def _delete_query_adapter(self, *data):
+        self.model._delete_rows(*data)
+        return data
+
+    # New methods to fetch and create rows
+    async def create_row(self, *args, **kwargs) -> RowT:
+        data = await super().insert(*args, **kwargs)
+        return self.model._make_rows(data)[0]
+
+    def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
+        # TODO: Handle list of rowids here?
+        return q.Select(
+            self.identifier,
+            row_adapter=self.model._make_rows,
+            connector=self.connector
+        ).where(*args, **kwargs)
+
+
+WK = TypeVar('WK')
+WV = TypeVar('WV')
+
+
+class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
+    def __init__(self, ref_cache):
+        self.ref_cache = ref_cache
+        self.weak_cache = WeakValueDictionary()
+
+    def __getitem__(self, key):
+        value = self.weak_cache[key]
+        self.ref_cache[key] = value
+        return value
+
+    def __setitem__(self, key, value):
+        self.weak_cache[key] = value
+        self.ref_cache[key] = value
+
+    def __delitem__(self, key):
+        del self.weak_cache[key]
+        try:
+            del self.ref_cache[key]
+        except KeyError:
+            pass
+
+    def __contains__(self, key):
+        return key in self.weak_cache
+
+    def __iter__(self):
+        return iter(self.weak_cache)
+
+    def __len__(self):
+        return len(self.weak_cache)
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def pop(self, key, default=None):
+        if key in self:
+            value = self[key]
+            del self[key]
+        else:
+            value = default
+        return value
+
+
+# TODO: Implement getitem and setitem, for dynamic column access
+class RowModel:
+    __slots__ = ('data',)
+
+    _schema_: str = 'public'
+    _tablename_: Optional[str] = None
+    _columns_: dict[str, Column] = {}
+
+    # Cache to keep track of registered Rows
+    _cache_: Union[dict, WeakValueDictionary, WeakCache] = None  # type: ignore
+
+    _key_: tuple[str, ...] = ()
+    _connector: Optional[Connector] = None
+    _registry: Optional[Registry] = None
+
+    # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
+    table: RowTable = None
+
+    def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
+        """
+        Set table, _columns_, and _key_.
+        """
+        if table is not None:
+            cls._tablename_ = table
+
+        if cls._tablename_ is not None:
+            columns = {}
+            for key, value in cls.__dict__.items():
+                if isinstance(value, Column):
+                    columns[key] = value
+
+            cls._columns_ = columns
+            if not cls._key_:
+                cls._key_ = tuple(column.name for column in columns.values() if column.primary)
+            cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
+            if cls._cache_ is None:
+                cls._cache_ = WeakValueDictionary()
+
+    def __new__(cls, data):
+        # Registry pattern.
+        # Ensure each rowid always refers to a single Model instance
+        if data is not None:
+            rowid = cls._id_from_data(data)
+
+            cache = cls._cache_
+
+            if (row := cache.get(rowid, None)) is not None:
+                obj = row
+            else:
+                obj = cache[rowid] = super().__new__(cls)
+        else:
+            obj = super().__new__(cls)
+
+        return obj
+
+    @classmethod
+    def as_tuple(cls):
+        return (cls.table.identifier, ())
+
+    def __init__(self, data):
+        self.data = data
+
+    def __getitem__(self, key):
+        return self.data[key]
+
+    def __setitem__(self, key, value):
+        self.data[key] = value
+
+    @classmethod
+    def bind(cls, connector: Connector):
+        if cls.table is None:
+            raise ValueError("Cannot bind abstract RowModel")
+        cls._connector = connector
+        cls.table.bind(connector)
+        return cls
+
+    @classmethod
+    def attach_to(cls, registry: Registry):
+        cls._registry = registry
+        return cls
+
+    @property
+    def _dict_(self):
+        return {key: self.data[key] for key in self._key_}
+
+    @property
+    def _rowid_(self):
+        return tuple(self.data[key] for key in self._key_)
+
+    def __repr__(self):
+        return "{}.{}({})".format(
+            self.table.schema,
+            self.table.name,
+            ', '.join(repr(column.__get__(self)) for column in self._columns_.values())
+        )
+
+    @classmethod
+    def _id_from_data(cls, data):
+        return tuple(data[key] for key in cls._key_)
+
+    @classmethod
+    def _dict_from_id(cls, rowid):
+        return dict(zip(cls._key_, rowid))
+
+    @classmethod
+    def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
+        """
+        Create or retrieve Row objects for each provided data row.
+        If the rows already exist in cache, updates the cached row.
+        """
+        # TODO: Handle partial row data here somehow?
+        rows = [cls(data_row) for data_row in data_rows]
+        return rows
+
+    @classmethod
+    def _delete_rows(cls, *data_rows):
+        """
+        Remove the given rows from cache, if they exist.
+        May be extended to handle object deletion.
+        """
+        cache = cls._cache_
+
+        for data_row in data_rows:
+            rowid = cls._id_from_data(data_row)
+            cache.pop(rowid, None)
+
+    @classmethod
+    async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
+        return await cls.table.create_row(*args, **kwargs)
+
+    @classmethod
+    def fetch_where(cls: Type[RowT], *args, **kwargs):
+        return cls.table.fetch_rows_where(*args, **kwargs)
+
+    @classmethod
+    async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
+        """
+        Fetch the row with the given id, retrieving from cache where possible.
+        """
+        row = cls._cache_.get(rowid, None) if cached else None
+        if row is None:
+            rows = await cls.fetch_where(**cls._dict_from_id(rowid))
+            row = rows[0] if rows else None
+            if row is None:
+                cls._cache_[rowid] = cls(None)
+        elif row.data is None:
+            row = None
+
+        return row
+
+    @classmethod
+    async def fetch_or_create(cls, *rowid, **kwargs):
+        """
+        Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
+        """
+        if rowid:
+            row = await cls.fetch(*rowid)
+        else:
+            rows = await cls.fetch_where(**kwargs).limit(1)
+            row = rows[0] if rows else None
+
+        if row is None:
+            creation_kwargs = kwargs
+            if rowid:
+                creation_kwargs.update(cls._dict_from_id(rowid))
+            row = await cls.create(**creation_kwargs)
+        return row
+
+    async def refresh(self: RowT) -> Optional[RowT]:
+        """
+        Refresh this Row from the database.
+
+        The return value may be `None` if the row was deleted.
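+
+        Example (assuming a bound model `User` keyed by `userid`):
+
+            user = await User.fetch(userid)
+            ...
+            user = await user.refresh()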
+        """
+        rows = await self.table.select_where(**self._dict_)
+        if not rows:
+            return None
+        else:
+            self.data = rows[0]
+            return self
+
+    async def update(self: RowT, **values) -> Optional[RowT]:
+        """
+        Update this Row with the given values.
+
+        Internally passes the provided `values` to the `update` Query.
+        The return value may be `None` if the row was deleted.
+        """
+        data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
+        if not data:
+            return None
+        else:
+            return data[0]
+
+    async def delete(self: RowT) -> Optional[RowT]:
+        """
+        Delete this Row.
+        """
+        data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
+        return data[0] if data is not None else None
diff --git a/queries.py b/queries.py
new file mode 100644
index 0000000..0232928
--- /dev/null
+++ b/queries.py
@@ -0,0 +1,644 @@
+from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
+from enum import Enum
+from itertools import chain
+from psycopg import AsyncConnection, AsyncCursor
+from psycopg import sql
+from psycopg.rows import DictRow
+
+import logging
+
+from .conditions import Condition
+from .base import Expression, RawExpr
+from .connector import Connector
+
+
+logger = logging.getLogger(__name__)
+
+
+TQueryT = TypeVar('TQueryT', bound='TableQuery')
+SQueryT = TypeVar('SQueryT', bound='Select')
+
+QueryResult = TypeVar('QueryResult')
+
+
+class Query(Generic[QueryResult]):
+    """
+    ABC for an executable query statement.
+    """
+    __slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
+
+    _adapter: Callable[..., QueryResult]
+
+    def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
+        self.connector: Optional[Connector] = connector
+        self.conn: Optional[AsyncConnection] = conn
+        self.cursor: Optional[AsyncCursor] = cursor
+
+        if row_adapter is not None:
+            self._adapter = row_adapter
+        else:
+            self._adapter = self._no_adapter
+
+        self.result: Optional[QueryResult] = None
+
+    def bind(self, connector: Connector):
+        self.connector = connector
+        return self
+
+    def with_cursor(self, cursor: AsyncCursor):
+        self.cursor = cursor
+        return self
+
+    def with_connection(self, conn: AsyncConnection):
+        self.conn = conn
+        return self
+
+    def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
+        return data
+
+    def with_adapter(self, callable: Callable[..., QueryResult]):
+        # NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
+        # For this to work cleanly, callable should have arg type of QR1, not any
+        self._adapter = callable
+        return self
+
+    def with_no_adapter(self):
+        """
+        Sets the adapter to the identity.
+        """
+        self._adapter = self._no_adapter
+        return self
+
+    def one(self):
+        # TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
+        return self
+
+    def build(self) -> Expression:
+        raise NotImplementedError
+
+    async def _execute(self, cursor: AsyncCursor) -> QueryResult:
+        query, values = self.build().as_tuple()
+        # TODO: Move logging out to a custom cursor
+        # logger.debug(
+        #     f"Executing query ({query.as_string(cursor)}) with values {values}",
+        #     extra={'action': "Query"}
+        # )
+        await cursor.execute(sql.Composed((query,)), values)
+        data = await cursor.fetchall()
+        self.result = self._adapter(*data)
+        return self.result
+
+    async def execute(self, cursor=None) -> QueryResult:
+        """
+        Execute the query, optionally with the provided cursor, and return the result rows.
+        If no cursor is provided, and no cursor has been set with `with_cursor`,
+        the execution will create a new cursor from the connection and close it automatically.
+        """
+        # Create a cursor if possible
+        cursor = cursor if cursor is not None else self.cursor
+        if cursor is None:
+            if self.conn is None:
+                if self.connector is None:
+                    raise ValueError("Cannot execute query without cursor, connection, or connector.")
+                else:
+                    async with self.connector.connection() as conn:
+                        async with conn.cursor() as cursor:
+                            data = await self._execute(cursor)
+            else:
+                async with self.conn.cursor() as cursor:
+                    data = await self._execute(cursor)
+        else:
+            data = await self._execute(cursor)
+        return data
+
+    def __await__(self):
+        return self.execute().__await__()
+
+
+class TableQuery(Query[QueryResult]):
+    """
+    ABC for an executable query statement expected to be run on a single table.
+    """
+    __slots__ = (
+        'tableid',
+        'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group'
+    )
+
+    def __init__(self, tableid, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tableid: sql.Identifier = tableid
+
+    def options(self, **kwargs):
+        """
+        Set some query options.
+        Default implementation does nothing.
+        Should be overridden to provide specific options.
+        """
+        return self
+
+
+class WhereMixin(TableQuery[QueryResult]):
+    __slots__ = ()
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.condition: Optional[Condition] = None
+
+    def where(self, *args: Condition, **kwargs):
+        """
+        Add a Condition to the query.
+        Positional arguments should be Conditions,
+        and keyword arguments should be of the form `column=Value`,
+        where Value may be a Value-type or a literal value.
+        All provided Conditions will be and-ed together to create a new Condition.
+        TODO: Maybe just pass this verbatim to a condition.
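+
+        Example, combining an explicit Condition with a keyword equality
+        (model and column names are illustrative):
+
+            query.where(MyModel.userid == 123, enabled=True)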
+ """ + if args or kwargs: + condition = Condition.construct(*args, **kwargs) + if self.condition is not None: + condition = self.condition & condition + + self.condition = condition + + return self + + @property + def _where_section(self) -> Optional[Expression]: + if self.condition is not None: + return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple()) + else: + return None + + +class JOINTYPE(Enum): + LEFT = sql.SQL('LEFT JOIN') + RIGHT = sql.SQL('RIGHT JOIN') + INNER = sql.SQL('INNER JOIN') + OUTER = sql.SQL('OUTER JOIN') + FULLOUTER = sql.SQL('FULL OUTER JOIN') + + +class JoinMixin(TableQuery[QueryResult]): + __slots__ = () + # TODO: Remember to add join slots to TableQuery + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._joins: list[Expression] = [] + + def join(self, + target: Union[str, Expression], + on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None, + join_type: JOINTYPE = JOINTYPE.INNER, + natural=False): + available = (on is not None) + (using is not None) + natural + if available == 0: + raise ValueError("No conditions given for Query Join") + if available > 1: + raise ValueError("Exactly one join format must be given for Query Join") + + sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())] + if isinstance(target, str): + sections.append((sql.Identifier(target), ())) + else: + sections.append(target.as_tuple()) + + if on is not None: + sections.append((sql.SQL('ON'), ())) + sections.append(on.as_tuple()) + elif using is not None: + sections.append((sql.SQL('USING'), ())) + if isinstance(using, Expression): + sections.append(using.as_tuple()) + elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str): + cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using)) + sections.append((cols, ())) + else: + raise ValueError("Unrecognised 'using' type.") + elif natural: + sections.insert(0, (sql.SQL('NATURAL'), ())) + + expr = RawExpr.join_tuples(*sections) + self._joins.append(expr) + return self + + def leftjoin(self, *args, **kwargs): + return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs) + + @property + def _join_section(self) -> Optional[Expression]: + if self._joins: + return RawExpr.join(*self._joins) + else: + return None + + +class ExtraMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._extra: Optional[Expression] = None + + def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()): + """ + Add an extra string, and optionally values, to this query. + The extra string is inserted after any condition, and before the limit. + """ + extra_expr = RawExpr(extra, values) + if self._extra is not None: + extra_expr = RawExpr.join(self._extra, extra_expr) + self._extra = extra_expr + return self + + @property + def _extra_section(self) -> Optional[Expression]: + if self._extra is None: + return None + else: + return self._extra + + +class LimitMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._limit: Optional[int] = None + + def limit(self, limit: int): + """ + Add a limit to this query. 
+ """ + self._limit = limit + return self + + @property + def _limit_section(self) -> Optional[Expression]: + if self._limit is not None: + return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,)) + else: + return None + + +class FromMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._from: Optional[Expression] = None + + def from_expr(self, _from: Expression): + self._from = _from + return self + + @property + def _from_section(self) -> Optional[Expression]: + if self._from is not None: + expr, values = self._from.as_tuple() + return RawExpr(sql.SQL("FROM {}").format(expr), values) + else: + return None + + +class ORDER(Enum): + ASC = sql.SQL('ASC') + DESC = sql.SQL('DESC') + + +class NULLS(Enum): + FIRST = sql.SQL('NULLS FIRST') + LAST = sql.SQL('NULLS LAST') + + +class OrderMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._order: list[Expression] = [] + + def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None): + """ + Add a single sort expression to the query. + This method stacks. + """ + if isinstance(expr, Expression): + string, values = expr.as_tuple() + else: + string = sql.Identifier(expr) + values = () + + parts = [string] + if direction is not None: + parts.append(direction.value) + if nulls is not None: + parts.append(nulls.value) + + order_string = sql.SQL(' ').join(parts) + self._order.append(RawExpr(order_string, values)) + return self + + @property + def _order_section(self) -> Optional[Expression]: + if self._order: + expr = RawExpr.join(*self._order, joiner=sql.SQL(', ')) + expr.expr = sql.SQL("ORDER BY {}").format(expr.expr) + return expr + else: + return None + + +class GroupMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._group: list[Expression] = [] + + def group_by(self, *exprs: Union[Expression, str]): + """ + Add a group expression(s) to the query. + This method stacks. + """ + for expr in exprs: + if isinstance(expr, Expression): + self._group.append(expr) + else: + self._group.append(RawExpr(sql.Identifier(expr))) + return self + + @property + def _group_section(self) -> Optional[Expression]: + if self._group: + expr = RawExpr.join(*self._group, joiner=sql.SQL(', ')) + expr.expr = sql.SQL("GROUP BY {}").format(expr.expr) + return expr + else: + return None + + +class Insert(ExtraMixin, TableQuery[QueryResult]): + """ + Query type representing a table insert query. + """ + # TODO: Support ON CONFLICT for upserts + __slots__ = ('_columns', '_values', '_conflict') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._columns: tuple[str, ...] = () + self._values: tuple[tuple[Any, ...], ...] = () + self._conflict: Optional[Expression] = None + + def insert(self, columns, *values): + """ + Insert the given data. + + Parameters + ---------- + columns: tuple[str] + Tuple of column names to insert. + + values: tuple[tuple[Any, ...], ...] + Tuple of values to insert, corresponding to the columns. 
+ """ + if not values: + raise ValueError("Cannot insert zero rows.") + if len(values[0]) != len(columns): + raise ValueError("Number of columns does not match length of values.") + + self._columns = columns + self._values = values + return self + + def on_conflict(self, ignore=False): + # TODO lots more we can do here + # Maybe return a Conflict object that can chain itself (not the query) + if ignore: + self._conflict = RawExpr(sql.SQL('DO NOTHING')) + return self + + @property + def _conflict_section(self) -> Optional[Expression]: + if self._conflict is not None: + e, v = self._conflict.as_tuple() + expr = RawExpr( + sql.SQL("ON CONFLICT {}").format( + e + ), + v + ) + return expr + return None + + def build(self): + columns = sql.SQL(',').join(map(sql.Identifier, self._columns)) + single_value_str = sql.SQL('({})').format( + sql.SQL(',').join(sql.Placeholder() * len(self._columns)) + ) + values_str = sql.SQL(',').join(single_value_str * len(self._values)) + + # TODO: Check efficiency of inserting multiple values like this + # Also implement a Copy query + base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format( + table=self.tableid, + columns=columns, + values_str=values_str + ) + + sections = [ + RawExpr(base, tuple(chain(*self._values))), + self._conflict_section, + self._extra_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]): + """ + Select rows from a table matching provided conditions. + """ + __slots__ = ('_columns',) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._columns: tuple[Expression, ...] = () + + def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]): + """ + Set the columns and expressions to select. + If none are given, selects all columns. + """ + cols: List[Expression] = [] + if columns: + cols.extend(map(RawExpr, map(sql.Identifier, columns))) + if exprs: + for name, expr in exprs.items(): + if isinstance(expr, str): + cols.append( + RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name)) + ) + elif isinstance(expr, sql.Composable): + cols.append( + RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name)) + ) + elif isinstance(expr, Expression): + value_expr, value_values = expr.as_tuple() + cols.append(RawExpr( + value_expr + sql.SQL(' AS ') + sql.Identifier(name), + value_values + )) + if cols: + self._columns = (*self._columns, *cols) + return self + + def build(self): + if not self._columns: + columns, columns_values = sql.SQL('*'), () + else: + columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple() + + base = sql.SQL("SELECT {columns} FROM {table}").format( + columns=columns, + table=self.tableid + ) + + sections = [ + RawExpr(base, columns_values), + self._join_section, + self._where_section, + self._group_section, + self._extra_section, + self._order_section, + self._limit_section, + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]): + """ + Query type representing a table delete query. 
+ """ + # TODO: Cascade option for delete, maybe other options + # TODO: Require a where unless specifically disabled, for safety + + def build(self): + base = sql.SQL("DELETE FROM {table}").format( + table=self.tableid, + ) + sections = [ + RawExpr(base), + self._where_section, + self._extra_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]): + __slots__ = ( + '_set', + ) + # TODO: Again, require a where unless specifically disabled + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._set: List[Expression] = [] + + def set(self, **column_values: Union[Any, Expression]): + exprs: List[Expression] = [] + for name, value in column_values.items(): + if isinstance(value, Expression): + value_tup = value.as_tuple() + else: + value_tup = (sql.Placeholder(), (value,)) + + exprs.append( + RawExpr.join_tuples( + (sql.Identifier(name), ()), + value_tup, + joiner=sql.SQL(' = ') + ) + ) + self._set.extend(exprs) + return self + + def build(self): + if not self._set: + raise ValueError("No columns provided to update.") + set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple() + + base = sql.SQL("UPDATE {table} SET {set}").format( + table=self.tableid, + set=set_expr + ) + sections = [ + RawExpr(base, set_values), + self._from_section, + self._where_section, + self._extra_section, + self._limit_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +# async def upsert(cursor, table, constraint, **values): +# """ +# Insert or on conflict update. 
+#     """
+#     valuedict = values
+#     keys, values = zip(*values.items())
+#
+#     key_str = _format_insertkeys(keys)
+#     value_str, values = _format_insertvalues(values)
+#     update_key_str, update_key_values = _format_updatestr(valuedict)
+#
+#     if not isinstance(constraint, str):
+#         constraint = ", ".join(constraint)
+#
+#     await cursor.execute(
+#         'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
+#             table, key_str, value_str, constraint, update_key_str
+#         ),
+#         tuple((*values, *update_key_values))
+#     )
+#     return await cursor.fetchone()
+
+
+# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
+#     cursor = cursor or conn.cursor()
+#
+#     # TODO: executemany or copy syntax now
+#     return execute_values(
+#         cursor,
+#         """
+#         UPDATE {table}
+#         SET {set_clause}
+#         FROM (VALUES {cast_row}%s)
+#         AS {temp_table}
+#         WHERE {where_clause}
+#         RETURNING *
+#         """.format(
+#             table=table,
+#             set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
+#             cast_row=cast_row + ',' if cast_row else '',
+#             where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
+#             temp_table="_t ({})".format(', '.join(set_keys + where_keys))
+#         ),
+#         values,
+#         fetch=True
+#     )
diff --git a/registry.py b/registry.py
new file mode 100644
index 0000000..c130d0f
--- /dev/null
+++ b/registry.py
@@ -0,0 +1,102 @@
+from typing import Protocol, runtime_checkable, Optional
+
+from psycopg import AsyncConnection
+
+from .connector import Connector, Connectable
+
+
+@runtime_checkable
+class _Attachable(Connectable, Protocol):
+    def attach_to(self, registry: 'Registry'):
+        raise NotImplementedError
+
+
+class Registry:
+    _attached: list[_Attachable] = []
+    _name: Optional[str] = None
+
+    def __init_subclass__(cls, name=None):
+        attached = []
+        for _, member in cls.__dict__.items():
+            if isinstance(member, _Attachable):
+                attached.append(member)
+        cls._attached = attached
+        cls._name = name or cls.__name__
+
+    def __init__(self, name=None):
+        self._conn: Optional[Connector] = None
+        self.name: str = name if name is not None else self._name
+        if self.name is None:
+            raise ValueError("A Registry must have a name!")
+
+        self.init_tasks = []
+
+        for member in self._attached:
+            member.attach_to(self)
+
+    def bind(self, connector: Connector):
+        self._conn = connector
+        for child in self._attached:
+            child.bind(connector)
+
+    def attach(self, attachable):
+        self._attached.append(attachable)
+        if self._conn is not None:
+            attachable.bind(self._conn)
+        return attachable
+
+    def init_task(self, coro):
+        """
+        Initialisation tasks are run to set up the registry state.
+        These tasks will be run in the event loop, after connection to the database.
+        These tasks should be idempotent, as they may be run on reload and reconnect.
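+
+        Example, registering a setup coroutine (names are illustrative):
+
+            @registry.init_task
+            async def load_defaults(registry):
+                ...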
+ """ + self.init_tasks.append(coro) + return coro + + async def init(self): + for task in self.init_tasks: + await task(self) + return self + + +class AttachableClass: + """ABC for a default implementation of an Attachable class.""" + + _connector: Optional[Connector] = None + _registry: Optional[Registry] = None + + @classmethod + def bind(cls, connector: Connector): + cls._connector = connector + connector.connect_hook(cls.on_connect) + return cls + + @classmethod + def attach_to(cls, registry: Registry): + cls._registry = registry + return cls + + @classmethod + async def on_connect(cls, connection: AsyncConnection): + pass + + +class Attachable: + """ABC for a default implementation of an Attachable object.""" + + def __init__(self, *args, **kwargs): + self._connector: Optional[Connector] = None + self._registry: Optional[Registry] = None + + def bind(self, connector: Connector): + self._connector = connector + connector.connect_hook(self.on_connect) + return self + + def attach_to(self, registry: Registry): + self._registry = registry + return self + + async def on_connect(self, connection: AsyncConnection): + pass diff --git a/table.py b/table.py new file mode 100644 index 0000000..e20647e --- /dev/null +++ b/table.py @@ -0,0 +1,95 @@ +from typing import Optional +from psycopg.rows import DictRow +from psycopg import sql + +from . import queries as q +from .connector import Connector +from .registry import Registry + + +class Table: + """ + Transparent interface to a single table structure in the database. + Contains standard methods to access the table. + """ + + def __init__(self, name, *args, schema='public', **kwargs): + self.name: str = name + self.schema: str = schema + self.connector: Connector = None + + @property + def identifier(self): + if self.schema == 'public': + return sql.Identifier(self.name) + else: + return sql.Identifier(self.schema, self.name) + + def bind(self, connector: Connector): + self.connector = connector + return self + + def attach_to(self, registry: Registry): + self._registry = registry + return self + + def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: + return data + + def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]: + if data: + return data[0] + else: + return None + + def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: + return data + + def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]: + return q.Select( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]: + return q.Select( + self.identifier, + row_adapter=self._single_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]: + return q.Update( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]: + return q.Delete( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def insert(self, **column_values) -> q.Insert[DictRow]: + return q.Insert( + self.identifier, + row_adapter=self._single_query_adapter, + connector=self.connector + ).insert(column_values.keys(), column_values.values()) + + def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]: + return q.Insert( + self.identifier, + 
+            row_adapter=self._many_query_adapter,
+            connector=self.connector
+        ).insert(*args, **kwargs)
+
+# def update_many(self, *args, **kwargs):
+#     with self.conn:
+#         return update_many(self.identifier, *args, **kwargs)
+
+# def upsert(self, *args, **kwargs):
+#     return upsert(self.identifier, *args, **kwargs)