commit 4da50d86783d17d56179f304f8d9fdb36d77e1e8 Author: Interitio Date: Thu Jul 31 07:39:53 2025 +1000 Initial Commit diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..affb160 --- /dev/null +++ b/__init__.py @@ -0,0 +1,9 @@ +from .conditions import Condition, condition, NULL +from .database import Database +from .models import RowModel, RowTable, WeakCache +from .table import Table +from .base import Expression, RawExpr +from .columns import ColumnExpr, Column, Integer, String +from .registry import Registry, AttachableClass, Attachable +from .adapted import RegisterEnum +from .queries import ORDER, NULLS, JOINTYPE diff --git a/adapted.py b/adapted.py new file mode 100644 index 0000000..a6b4597 --- /dev/null +++ b/adapted.py @@ -0,0 +1,40 @@ +# from enum import Enum +from typing import Optional +from psycopg.types.enum import register_enum, EnumInfo +from psycopg import AsyncConnection +from .registry import Attachable, Registry + + +class RegisterEnum(Attachable): + def __init__(self, enum, name: Optional[str] = None, mapper=None): + super().__init__() + self.enum = enum + self.name = name or enum.__name__ + self.mapping = mapper(enum) if mapper is not None else self._mapper() + + def _mapper(self): + return {m: m.value[0] for m in self.enum} + + def attach_to(self, registry: Registry): + self._registry = registry + registry.init_task(self.on_init) + return self + + async def on_init(self, registry: Registry): + connector = registry._conn + if connector is None: + raise ValueError("Cannot initialise without connector!") + connector.connect_hook(self.connection_hook) + # await connector.refresh_pool() + # The below may be somewhat dangerous + # But adaption should never write to the database + await connector.map_over_pool(self.connection_hook) + # if conn := connector.conn: + # # Ensure the adaption is run in the current context as well + # await self.connection_hook(conn) + + async def connection_hook(self, conn: AsyncConnection): + info = await EnumInfo.fetch(conn, self.name) + if info is None: + raise ValueError(f"Enum {self.name} not found in database.") + register_enum(info, conn, self.enum, mapping=list(self.mapping.items())) diff --git a/base.py b/base.py new file mode 100644 index 0000000..272d588 --- /dev/null +++ b/base.py @@ -0,0 +1,45 @@ +from abc import abstractmethod +from typing import Any, Protocol, runtime_checkable +from itertools import chain +from psycopg import sql + + +@runtime_checkable +class Expression(Protocol): + __slots__ = () + + @abstractmethod + def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]: + raise NotImplementedError + + +class RawExpr(Expression): + __slots__ = ('expr', 'values') + + expr: sql.Composable + values: tuple[Any, ...] + + def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()): + self.expr = expr + self.values = values + + def as_tuple(self): + return (self.expr, self.values) + + @classmethod + def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')): + """ + Join a sequence of Expressions into a single RawExpr. + """ + tups = ( + expression.as_tuple() + for expression in expressions + ) + return cls.join_tuples(*tups, joiner=joiner) + + @classmethod + def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')): + exprs, values = zip(*tuples) + expr = joiner.join(exprs) + value = tuple(chain(*values)) + return cls(expr, value) diff --git a/columns.py b/columns.py new file mode 100644 index 0000000..252db83 --- /dev/null +++ b/columns.py @@ -0,0 +1,155 @@ +from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING +from psycopg import sql +from datetime import datetime + +from .base import RawExpr, Expression +from .conditions import Condition, Joiner +from .table import Table + + +class ColumnExpr(RawExpr): + __slots__ = () + + def __lt__(self, obj) -> Condition: + expr, values = self.as_tuple() + + if isinstance(obj, Expression): + # column < Expression + obj_expr, obj_values = obj.as_tuple() + cond_exprs = (expr, Joiner.LT, obj_expr) + cond_values = (*values, *obj_values) + else: + # column < Literal + cond_exprs = (expr, Joiner.LT, sql.Placeholder()) + cond_values = (*values, obj) + + return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) + + def __le__(self, obj) -> Condition: + expr, values = self.as_tuple() + + if isinstance(obj, Expression): + # column <= Expression + obj_expr, obj_values = obj.as_tuple() + cond_exprs = (expr, Joiner.LE, obj_expr) + cond_values = (*values, *obj_values) + else: + # column <= Literal + cond_exprs = (expr, Joiner.LE, sql.Placeholder()) + cond_values = (*values, obj) + + return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) + + def __eq__(self, obj) -> Condition: # type: ignore[override] + return Condition._expression_equality(self, obj) + + def __ne__(self, obj) -> Condition: # type: ignore[override] + return ~(self.__eq__(obj)) + + def __gt__(self, obj) -> Condition: + return ~(self.__le__(obj)) + + def __ge__(self, obj) -> Condition: + return ~(self.__lt__(obj)) + + def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr': + if isinstance(obj, Expression): + obj_expr, obj_values = obj.as_tuple() + return ColumnExpr( + sql.SQL("({} + {})").format(self.expr, obj_expr), + (*self.values, *obj_values) + ) + else: + return ColumnExpr( + sql.SQL("({} + {})").format(self.expr, sql.Placeholder()), + (*self.values, obj) + ) + + def __sub__(self, obj) -> 'ColumnExpr': + if isinstance(obj, Expression): + obj_expr, obj_values = obj.as_tuple() + return ColumnExpr( + sql.SQL("({} - {})").format(self.expr, obj_expr), + (*self.values, *obj_values) + ) + else: + return ColumnExpr( + sql.SQL("({} - {})").format(self.expr, sql.Placeholder()), + (*self.values, obj) + ) + + def __mul__(self, obj) -> 'ColumnExpr': + if isinstance(obj, Expression): + obj_expr, obj_values = obj.as_tuple() + return ColumnExpr( + sql.SQL("({} * {})").format(self.expr, obj_expr), + (*self.values, *obj_values) + ) + else: + return ColumnExpr( + sql.SQL("({} * {})").format(self.expr, sql.Placeholder()), + (*self.values, obj) + ) + + def CAST(self, target_type: sql.Composable): + return ColumnExpr( + sql.SQL("({}::{})").format(self.expr, target_type), + self.values + ) + + +T = TypeVar('T') + +if TYPE_CHECKING: + from .models import RowModel + + +class Column(ColumnExpr, Generic[T]): + def __init__(self, name: Optional[str] = None, + primary: bool = False, references: Optional['Column'] = None, + type: Optional[Type[T]] = None): + self.primary = primary + self.references = references + self.name: str = name # type: ignore + self.owner: Optional['RowModel'] = None + self._type = type + + self.expr = sql.Identifier(name) if name else sql.SQL('') + self.values = () + + def __set_name__(self, owner, name): + # Only allow setting the owner once + self.name = self.name or name + self.owner = owner + self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name) + + @overload + def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': + ... + + @overload + def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T: + ... + + def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]": + # Get value from row data or session + if obj is None: + return self + else: + return obj.data[self.name] + + +class Integer(Column[int]): + pass + + +class String(Column[str]): + pass + + +class Bool(Column[bool]): + pass + + +class Timestamp(Column[datetime]): + pass diff --git a/conditions.py b/conditions.py new file mode 100644 index 0000000..f40dff6 --- /dev/null +++ b/conditions.py @@ -0,0 +1,214 @@ +# from meta import sharding +from typing import Any, Union +from enum import Enum +from itertools import chain +from psycopg import sql + +from .base import Expression, RawExpr + + +""" +A Condition is a "logical" database expression, intended for use in Where statements. +Conditions support bitwise logical operators ~, &, |, each producing another Condition. +""" + +NULL = None + + +class Joiner(Enum): + EQUALS = ('=', '!=') + IS = ('IS', 'IS NOT') + LIKE = ('LIKE', 'NOT LIKE') + BETWEEN = ('BETWEEN', 'NOT BETWEEN') + IN = ('IN', 'NOT IN') + LT = ('<', '>=') + LE = ('<=', '>') + NONE = ('', '') + + +class Condition(Expression): + __slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values') + + def __init__(self, + expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''), + values: tuple[Any, ...] = (), negated=False + ): + self.expr1 = expr1 + self.joiner = joiner + self.negated = negated + self.expr2 = expr2 + self.values = values + + def as_tuple(self): + expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2)) + if self.negated and self.joiner is Joiner.NONE: + expr = sql.SQL("NOT ({})").format(expr) + return (expr, self.values) + + @classmethod + def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]): + """ + Construct a Condition from a sequence of Conditions, + together with some explicit column conditions. + """ + # TODO: Consider adding a _table identifier here so we can identify implicit columns + # Or just require subquery type conditions to always come from modelled tables. + implicit_conditions = ( + cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items() + ) + return cls._and(*conditions, *implicit_conditions) + + @classmethod + def _and(cls, *conditions: 'Condition'): + if not len(conditions): + raise ValueError("Cannot combine 0 Conditions") + if len(conditions) == 1: + return conditions[0] + + exprs, values = zip(*(condition.as_tuple() for condition in conditions)) + cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs)) + cond_values = tuple(chain(*values)) + + return Condition(cond_expr, values=cond_values) + + @classmethod + def _or(cls, *conditions: 'Condition'): + if not len(conditions): + raise ValueError("Cannot combine 0 Conditions") + if len(conditions) == 1: + return conditions[0] + + exprs, values = zip(*(condition.as_tuple() for condition in conditions)) + cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs)) + cond_values = tuple(chain(*values)) + + return Condition(cond_expr, values=cond_values) + + @classmethod + def _not(cls, condition: 'Condition'): + condition.negated = not condition.negated + return condition + + @classmethod + def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition': + # TODO: Check if this supports sbqueries + col_expr, col_values = column.as_tuple() + + # TODO: Also support sql.SQL? For joins? + if isinstance(value, Expression): + # column = Expression + value_expr, value_values = value.as_tuple() + cond_exprs = (col_expr, Joiner.EQUALS, value_expr) + cond_values = (*col_values, *value_values) + elif isinstance(value, (tuple, list)): + # column in (...) + # TODO: Support expressions in value tuple? + if not value: + raise ValueError("Cannot create Condition from empty iterable!") + value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value))) + cond_exprs = (col_expr, Joiner.IN, value_expr) + cond_values = (*col_values, *value) + elif value is None: + # column IS NULL + cond_exprs = (col_expr, Joiner.IS, sql.NULL) + cond_values = col_values + else: + # column = Literal + cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder()) + cond_values = (*col_values, value) + + return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) + + def __invert__(self) -> 'Condition': + self.negated = not self.negated + return self + + def __and__(self, condition: 'Condition') -> 'Condition': + return self._and(self, condition) + + def __or__(self, condition: 'Condition') -> 'Condition': + return self._or(self, condition) + + +# Helper method to simply condition construction +def condition(*args, **kwargs) -> Condition: + return Condition.construct(*args, **kwargs) + + +# class NOT(Condition): +# __slots__ = ('value',) +# +# def __init__(self, value): +# self.value = value +# +# def apply(self, key, values, conditions): +# item = self.value +# if isinstance(item, (list, tuple)): +# if item: +# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item)))) +# values.extend(item) +# else: +# raise ValueError("Cannot check an empty iterable!") +# else: +# conditions.append("{}!={}".format(key, _replace_char)) +# values.append(item) +# +# +# class GEQ(Condition): +# __slots__ = ('value',) +# +# def __init__(self, value): +# self.value = value +# +# def apply(self, key, values, conditions): +# item = self.value +# if isinstance(item, (list, tuple)): +# raise ValueError("Cannot apply GEQ condition to a list!") +# else: +# conditions.append("{} >= {}".format(key, _replace_char)) +# values.append(item) +# +# +# class LEQ(Condition): +# __slots__ = ('value',) +# +# def __init__(self, value): +# self.value = value +# +# def apply(self, key, values, conditions): +# item = self.value +# if isinstance(item, (list, tuple)): +# raise ValueError("Cannot apply LEQ condition to a list!") +# else: +# conditions.append("{} <= {}".format(key, _replace_char)) +# values.append(item) +# +# +# class Constant(Condition): +# __slots__ = ('value',) +# +# def __init__(self, value): +# self.value = value +# +# def apply(self, key, values, conditions): +# conditions.append("{} {}".format(key, self.value)) +# +# +# class SHARDID(Condition): +# __slots__ = ('shardid', 'shard_count') +# +# def __init__(self, shardid, shard_count): +# self.shardid = shardid +# self.shard_count = shard_count +# +# def apply(self, key, values, conditions): +# if self.shard_count > 1: +# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char)) +# values.append(self.shardid) +# +# +# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count) +# +# +# NULL = Constant('IS NULL') +# NOTNULL = Constant('IS NOT NULL') diff --git a/connector.py b/connector.py new file mode 100644 index 0000000..6baf3b5 --- /dev/null +++ b/connector.py @@ -0,0 +1,135 @@ +from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional +import logging + +from contextvars import ContextVar +from contextlib import asynccontextmanager +import psycopg as psq +from psycopg_pool import AsyncConnectionPool +from psycopg.pq import TransactionStatus + +from .cursor import AsyncLoggingCursor + +logger = logging.getLogger(__name__) + +row_factory = psq.rows.dict_row + +ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None) + + +class Connector: + cursor_factory = AsyncLoggingCursor + + def __init__(self, conn_args): + self._conn_args = conn_args + self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory) + + self.pool = self.make_pool() + + self.conn_hooks = [] + + @property + def conn(self) -> Optional[psq.AsyncConnection]: + """ + Convenience property for the current context connection. + """ + return ctx_connection.get() + + @conn.setter + def conn(self, conn: psq.AsyncConnection): + """ + Set the contextual connection in the current context. + Always do this in an isolated context! + """ + ctx_connection.set(conn) + + def make_pool(self) -> AsyncConnectionPool: + logger.info("Initialising connection pool.", extra={'action': "Pool Init"}) + return AsyncConnectionPool( + self._conn_args, + open=False, + min_size=1, + max_size=4, + configure=self._setup_connection, + kwargs=self._conn_kwargs + ) + + async def refresh_pool(self): + """ + Refresh the pool. + + The point of this is to invalidate any existing connections so that the connection set up is run again. + Better ways should be sought (a way to + """ + logger.info("Pool refresh requested, closing and reopening.") + old_pool = self.pool + self.pool = self.make_pool() + await self.pool.open() + logger.info(f"Old pool statistics: {self.pool.get_stats()}") + await old_pool.close() + logger.info("Pool refresh complete.") + + async def map_over_pool(self, callable): + """ + Dangerous method to call a method on each connection in the pool. + + Utilises private methods of the AsyncConnectionPool. + """ + async with self.pool._lock: + conns = list(self.pool._pool) + while conns: + conn = conns.pop() + try: + await callable(conn) + except Exception: + logger.exception(f"Mapped connection task failed. {callable.__name__}") + + @asynccontextmanager + async def open(self): + try: + logger.info("Opening database pool.") + await self.pool.open() + yield + finally: + # May be a different pool! + logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}") + await self.pool.close() + + @asynccontextmanager + async def connection(self) -> psq.AsyncConnection: + """ + Asynchronous context manager to get and manage a connection. + + If the context connection is set, uses this and does not manage the lifetime. + Otherwise, requests a new connection from the pool and returns it when done. + """ + logger.debug("Database connection requested.", extra={'action': "Data Connect"}) + if (conn := self.conn): + yield conn + else: + async with self.pool.connection() as conn: + yield conn + + async def _setup_connection(self, conn: psq.AsyncConnection): + logger.debug("Initialising new connection.", extra={'action': "Conn Init"}) + for hook in self.conn_hooks: + try: + await hook(conn) + except Exception: + logger.exception("Exception encountered setting up new connection") + return conn + + def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]): + """ + Minimal decorator to register a coroutine to run on connect or reconnect. + + Note that these are only run on connect and reconnect. + If a hook is registered after connection, it will not be run. + """ + self.conn_hooks.append(coro) + return coro + + +@runtime_checkable +class Connectable(Protocol): + def bind(self, connector: Connector): + raise NotImplementedError diff --git a/cursor.py b/cursor.py new file mode 100644 index 0000000..5d183e0 --- /dev/null +++ b/cursor.py @@ -0,0 +1,42 @@ +import logging +from typing import Optional + +from psycopg import AsyncCursor, sql +from psycopg.abc import Query, Params +from psycopg._encodings import conn_encoding + +logger = logging.getLogger(__name__) + + +class AsyncLoggingCursor(AsyncCursor): + def mogrify_query(self, query: Query): + if isinstance(query, str): + msg = query + elif isinstance(query, (sql.SQL, sql.Composed)): + msg = query.as_string(self) + elif isinstance(query, bytes): + msg = query.decode(conn_encoding(self._conn.pgconn), 'replace') + else: + msg = repr(query) + return msg + + async def execute(self, query: Query, params: Optional[Params] = None, **kwargs): + if logging.DEBUG >= logger.getEffectiveLevel(): + msg = self.mogrify_query(query) + logger.debug( + "Executing query (%s) with values %s", msg, params, + extra={'action': "Query Execute"} + ) + try: + return await super().execute(query, params=params, **kwargs) + except Exception: + msg = self.mogrify_query(query) + logger.exception( + "Exception during query execution. Query (%s) with parameters %s.", + msg, params, + extra={'action': "Query Execute"}, + stack_info=True + ) + else: + # TODO: Possibly log execution time + pass diff --git a/database.py b/database.py new file mode 100644 index 0000000..255e412 --- /dev/null +++ b/database.py @@ -0,0 +1,47 @@ +from typing import TypeVar +import logging +from collections import namedtuple + +# from .cursor import AsyncLoggingCursor +from .registry import Registry +from .connector import Connector + + +logger = logging.getLogger(__name__) + +Version = namedtuple('Version', ('version', 'time', 'author')) + +T = TypeVar('T', bound=Registry) + + +class Database(Connector): + # cursor_factory = AsyncLoggingCursor + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.registries: dict[str, Registry] = {} + + def load_registry(self, registry: T) -> T: + logger.debug( + f"Loading and binding registry '{registry.name}'.", + extra={'action': f"Reg {registry.name}"} + ) + registry.bind(self) + self.registries[registry.name] = registry + return registry + + async def version(self) -> Version: + """ + Return the current schema version as a Version namedtuple. + """ + async with self.connection() as conn: + async with conn.cursor() as cursor: + # Get last entry in version table, compare against desired version + await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") + row = await cursor.fetchone() + if row: + return Version(row['version'], row['time'], row['author']) + else: + # No versions in the database + return Version(-1, None, None) diff --git a/models.py b/models.py new file mode 100644 index 0000000..54b6282 --- /dev/null +++ b/models.py @@ -0,0 +1,323 @@ +from typing import TypeVar, Type, Optional, Generic, Union +# from typing_extensions import Self +from weakref import WeakValueDictionary +from collections.abc import MutableMapping + +from psycopg.rows import DictRow + +from .table import Table +from .columns import Column +from . import queries as q +from .connector import Connector +from .registry import Registry + + +RowT = TypeVar('RowT', bound='RowModel') + + +class MISSING: + __slots__ = ('oid',) + + def __init__(self, oid): + self.oid = oid + + +class RowTable(Table, Generic[RowT]): + __slots__ = ( + 'model', + ) + + def __init__(self, name, model: Type[RowT], **kwargs): + super().__init__(name, **kwargs) + self.model = model + + @property + def columns(self): + return self.model._columns_ + + @property + def id_col(self): + return self.model._key_ + + @property + def row_cache(self): + return self.model._cache_ + + def _many_query_adapter(self, *data): + self.model._make_rows(*data) + return data + + def _single_query_adapter(self, *data): + if data: + self.model._make_rows(*data) + return data[0] + else: + return None + + def _delete_query_adapter(self, *data): + self.model._delete_rows(*data) + return data + + # New methods to fetch and create rows + async def create_row(self, *args, **kwargs) -> RowT: + data = await super().insert(*args, **kwargs) + return self.model._make_rows(data)[0] + + def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]: + # TODO: Handle list of rowids here? + return q.Select( + self.identifier, + row_adapter=self.model._make_rows, + connector=self.connector + ).where(*args, **kwargs) + + +WK = TypeVar('WK') +WV = TypeVar('WV') + + +class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]): + def __init__(self, ref_cache): + self.ref_cache = ref_cache + self.weak_cache = WeakValueDictionary() + + def __getitem__(self, key): + value = self.weak_cache[key] + self.ref_cache[key] = value + return value + + def __setitem__(self, key, value): + self.weak_cache[key] = value + self.ref_cache[key] = value + + def __delitem__(self, key): + del self.weak_cache[key] + try: + del self.ref_cache[key] + except KeyError: + pass + + def __contains__(self, key): + return key in self.weak_cache + + def __iter__(self): + return iter(self.weak_cache) + + def __len__(self): + return len(self.weak_cache) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def pop(self, key, default=None): + if key in self: + value = self[key] + del self[key] + else: + value = default + return value + + +# TODO: Implement getitem and setitem, for dynamic column access +class RowModel: + __slots__ = ('data',) + + _schema_: str = 'public' + _tablename_: Optional[str] = None + _columns_: dict[str, Column] = {} + + # Cache to keep track of registered Rows + _cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore + + _key_: tuple[str, ...] = () + _connector: Optional[Connector] = None + _registry: Optional[Registry] = None + + # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass + table: RowTable = None + + def __init_subclass__(cls: Type[RowT], table: Optional[str] = None): + """ + Set table, _columns_, and _key_. + """ + if table is not None: + cls._tablename_ = table + + if cls._tablename_ is not None: + columns = {} + for key, value in cls.__dict__.items(): + if isinstance(value, Column): + columns[key] = value + + cls._columns_ = columns + if not cls._key_: + cls._key_ = tuple(column.name for column in columns.values() if column.primary) + cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_) + if cls._cache_ is None: + cls._cache_ = WeakValueDictionary() + + def __new__(cls, data): + # Registry pattern. + # Ensure each rowid always refers to a single Model instance + if data is not None: + rowid = cls._id_from_data(data) + + cache = cls._cache_ + + if (row := cache.get(rowid, None)) is not None: + obj = row + else: + obj = cache[rowid] = super().__new__(cls) + else: + obj = super().__new__(cls) + + return obj + + @classmethod + def as_tuple(cls): + return (cls.table.identifier, ()) + + def __init__(self, data): + self.data = data + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + + @classmethod + def bind(cls, connector: Connector): + if cls.table is None: + raise ValueError("Cannot bind abstract RowModel") + cls._connector = connector + cls.table.bind(connector) + return cls + + @classmethod + def attach_to(cls, registry: Registry): + cls._registry = registry + return cls + + @property + def _dict_(self): + return {key: self.data[key] for key in self._key_} + + @property + def _rowid_(self): + return tuple(self.data[key] for key in self._key_) + + def __repr__(self): + return "{}.{}({})".format( + self.table.schema, + self.table.name, + ', '.join(repr(column.__get__(self)) for column in self._columns_.values()) + ) + + @classmethod + def _id_from_data(cls, data): + return tuple(data[key] for key in cls._key_) + + @classmethod + def _dict_from_id(cls, rowid): + return dict(zip(cls._key_, rowid)) + + @classmethod + def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]: + """ + Create or retrieve Row objects for each provided data row. + If the rows already exist in cache, updates the cached row. + """ + # TODO: Handle partial row data here somehow? + rows = [cls(data_row) for data_row in data_rows] + return rows + + @classmethod + def _delete_rows(cls, *data_rows): + """ + Remove the given rows from cache, if they exist. + May be extended to handle object deletion. + """ + cache = cls._cache_ + + for data_row in data_rows: + rowid = cls._id_from_data(data_row) + cache.pop(rowid, None) + + @classmethod + async def create(cls: Type[RowT], *args, **kwargs) -> RowT: + return await cls.table.create_row(*args, **kwargs) + + @classmethod + def fetch_where(cls: Type[RowT], *args, **kwargs): + return cls.table.fetch_rows_where(*args, **kwargs) + + @classmethod + async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]: + """ + Fetch the row with the given id, retrieving from cache where possible. + """ + row = cls._cache_.get(rowid, None) if cached else None + if row is None: + rows = await cls.fetch_where(**cls._dict_from_id(rowid)) + row = rows[0] if rows else None + if row is None: + cls._cache_[rowid] = cls(None) + elif row.data is None: + row = None + + return row + + @classmethod + async def fetch_or_create(cls, *rowid, **kwargs): + """ + Helper method to fetch a row with the given id or fields, or create it if it doesn't exist. + """ + if rowid: + row = await cls.fetch(*rowid) + else: + rows = await cls.fetch_where(**kwargs).limit(1) + row = rows[0] if rows else None + + if row is None: + creation_kwargs = kwargs + if rowid: + creation_kwargs.update(cls._dict_from_id(rowid)) + row = await cls.create(**creation_kwargs) + return row + + async def refresh(self: RowT) -> Optional[RowT]: + """ + Refresh this Row from data. + + The return value may be `None` if the row was deleted. + """ + rows = await self.table.select_where(**self._dict_) + if not rows: + return None + else: + self.data = rows[0] + return self + + async def update(self: RowT, **values) -> Optional[RowT]: + """ + Update this Row with the given values. + + Internally passes the provided `values` to the `update` Query. + The return value may be `None` if the row was deleted. + """ + data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows) + if not data: + return None + else: + return data[0] + + async def delete(self: RowT) -> Optional[RowT]: + """ + Delete this Row. + """ + data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows) + return data[0] if data is not None else None diff --git a/queries.py b/queries.py new file mode 100644 index 0000000..0232928 --- /dev/null +++ b/queries.py @@ -0,0 +1,644 @@ +from typing import Optional, TypeVar, Any, Callable, Generic, List, Union +from enum import Enum +from itertools import chain +from psycopg import AsyncConnection, AsyncCursor +from psycopg import sql +from psycopg.rows import DictRow + +import logging + +from .conditions import Condition +from .base import Expression, RawExpr +from .connector import Connector + + +logger = logging.getLogger(__name__) + + +TQueryT = TypeVar('TQueryT', bound='TableQuery') +SQueryT = TypeVar('SQueryT', bound='Select') + +QueryResult = TypeVar('QueryResult') + + +class Query(Generic[QueryResult]): + """ + ABC for an executable query statement. + """ + __slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result') + + _adapter: Callable[..., QueryResult] + + def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs): + self.connector: Optional[Connector] = connector + self.conn: Optional[AsyncConnection] = conn + self.cursor: Optional[AsyncCursor] = cursor + + if row_adapter is not None: + self._adapter = row_adapter + else: + self._adapter = self._no_adapter + + self.result: Optional[QueryResult] = None + + def bind(self, connector: Connector): + self.connector = connector + return self + + def with_cursor(self, cursor: AsyncCursor): + self.cursor = cursor + return self + + def with_connection(self, conn: AsyncConnection): + self.conn = conn + return self + + def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: + return data + + def with_adapter(self, callable: Callable[..., QueryResult]): + # NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1] + # For this to work cleanly, callable should have arg type of QR1, not any + self._adapter = callable + return self + + def with_no_adapter(self): + """ + Sets the adapater to the identity. + """ + self._adapter = self._no_adapter + return self + + def one(self): + # TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1] + return self + + def build(self) -> Expression: + raise NotImplementedError + + async def _execute(self, cursor: AsyncCursor) -> QueryResult: + query, values = self.build().as_tuple() + # TODO: Move logging out to a custom cursor + # logger.debug( + # f"Executing query ({query.as_string(cursor)}) with values {values}", + # extra={'action': "Query"} + # ) + await cursor.execute(sql.Composed((query,)), values) + data = await cursor.fetchall() + self.result = self._adapter(*data) + return self.result + + async def execute(self, cursor=None) -> QueryResult: + """ + Execute the query, optionally with the provided cursor, and return the result rows. + If no cursor is provided, and no cursor has been set with `with_cursor`, + the execution will create a new cursor from the connection and close it automatically. + """ + # Create a cursor if possible + cursor = cursor if cursor is not None else self.cursor + if self.cursor is None: + if self.conn is None: + if self.connector is None: + raise ValueError("Cannot execute query without cursor, connection, or connector.") + else: + async with self.connector.connection() as conn: + async with conn.cursor() as cursor: + data = await self._execute(cursor) + else: + async with self.conn.cursor() as cursor: + data = await self._execute(cursor) + else: + data = await self._execute(cursor) + return data + + def __await__(self): + return self.execute().__await__() + + +class TableQuery(Query[QueryResult]): + """ + ABC for an executable query statement expected to be run on a single table. + """ + __slots__ = ( + 'tableid', + 'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group' + ) + + def __init__(self, tableid, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tableid: sql.Identifier = tableid + + def options(self, **kwargs): + """ + Set some query options. + Default implementation does nothing. + Should be overridden to provide specific options. + """ + return self + + +class WhereMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.condition: Optional[Condition] = None + + def where(self, *args: Condition, **kwargs): + """ + Add a Condition to the query. + Position arguments should be Conditions, + and keyword arguments should be of the form `column=Value`, + where Value may be a Value-type or a literal value. + All provided Conditions will be and-ed together to create a new Condition. + TODO: Maybe just pass this verbatim to a condition. + """ + if args or kwargs: + condition = Condition.construct(*args, **kwargs) + if self.condition is not None: + condition = self.condition & condition + + self.condition = condition + + return self + + @property + def _where_section(self) -> Optional[Expression]: + if self.condition is not None: + return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple()) + else: + return None + + +class JOINTYPE(Enum): + LEFT = sql.SQL('LEFT JOIN') + RIGHT = sql.SQL('RIGHT JOIN') + INNER = sql.SQL('INNER JOIN') + OUTER = sql.SQL('OUTER JOIN') + FULLOUTER = sql.SQL('FULL OUTER JOIN') + + +class JoinMixin(TableQuery[QueryResult]): + __slots__ = () + # TODO: Remember to add join slots to TableQuery + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._joins: list[Expression] = [] + + def join(self, + target: Union[str, Expression], + on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None, + join_type: JOINTYPE = JOINTYPE.INNER, + natural=False): + available = (on is not None) + (using is not None) + natural + if available == 0: + raise ValueError("No conditions given for Query Join") + if available > 1: + raise ValueError("Exactly one join format must be given for Query Join") + + sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())] + if isinstance(target, str): + sections.append((sql.Identifier(target), ())) + else: + sections.append(target.as_tuple()) + + if on is not None: + sections.append((sql.SQL('ON'), ())) + sections.append(on.as_tuple()) + elif using is not None: + sections.append((sql.SQL('USING'), ())) + if isinstance(using, Expression): + sections.append(using.as_tuple()) + elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str): + cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using)) + sections.append((cols, ())) + else: + raise ValueError("Unrecognised 'using' type.") + elif natural: + sections.insert(0, (sql.SQL('NATURAL'), ())) + + expr = RawExpr.join_tuples(*sections) + self._joins.append(expr) + return self + + def leftjoin(self, *args, **kwargs): + return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs) + + @property + def _join_section(self) -> Optional[Expression]: + if self._joins: + return RawExpr.join(*self._joins) + else: + return None + + +class ExtraMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._extra: Optional[Expression] = None + + def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()): + """ + Add an extra string, and optionally values, to this query. + The extra string is inserted after any condition, and before the limit. + """ + extra_expr = RawExpr(extra, values) + if self._extra is not None: + extra_expr = RawExpr.join(self._extra, extra_expr) + self._extra = extra_expr + return self + + @property + def _extra_section(self) -> Optional[Expression]: + if self._extra is None: + return None + else: + return self._extra + + +class LimitMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._limit: Optional[int] = None + + def limit(self, limit: int): + """ + Add a limit to this query. + """ + self._limit = limit + return self + + @property + def _limit_section(self) -> Optional[Expression]: + if self._limit is not None: + return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,)) + else: + return None + + +class FromMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._from: Optional[Expression] = None + + def from_expr(self, _from: Expression): + self._from = _from + return self + + @property + def _from_section(self) -> Optional[Expression]: + if self._from is not None: + expr, values = self._from.as_tuple() + return RawExpr(sql.SQL("FROM {}").format(expr), values) + else: + return None + + +class ORDER(Enum): + ASC = sql.SQL('ASC') + DESC = sql.SQL('DESC') + + +class NULLS(Enum): + FIRST = sql.SQL('NULLS FIRST') + LAST = sql.SQL('NULLS LAST') + + +class OrderMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._order: list[Expression] = [] + + def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None): + """ + Add a single sort expression to the query. + This method stacks. + """ + if isinstance(expr, Expression): + string, values = expr.as_tuple() + else: + string = sql.Identifier(expr) + values = () + + parts = [string] + if direction is not None: + parts.append(direction.value) + if nulls is not None: + parts.append(nulls.value) + + order_string = sql.SQL(' ').join(parts) + self._order.append(RawExpr(order_string, values)) + return self + + @property + def _order_section(self) -> Optional[Expression]: + if self._order: + expr = RawExpr.join(*self._order, joiner=sql.SQL(', ')) + expr.expr = sql.SQL("ORDER BY {}").format(expr.expr) + return expr + else: + return None + + +class GroupMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._group: list[Expression] = [] + + def group_by(self, *exprs: Union[Expression, str]): + """ + Add a group expression(s) to the query. + This method stacks. + """ + for expr in exprs: + if isinstance(expr, Expression): + self._group.append(expr) + else: + self._group.append(RawExpr(sql.Identifier(expr))) + return self + + @property + def _group_section(self) -> Optional[Expression]: + if self._group: + expr = RawExpr.join(*self._group, joiner=sql.SQL(', ')) + expr.expr = sql.SQL("GROUP BY {}").format(expr.expr) + return expr + else: + return None + + +class Insert(ExtraMixin, TableQuery[QueryResult]): + """ + Query type representing a table insert query. + """ + # TODO: Support ON CONFLICT for upserts + __slots__ = ('_columns', '_values', '_conflict') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._columns: tuple[str, ...] = () + self._values: tuple[tuple[Any, ...], ...] = () + self._conflict: Optional[Expression] = None + + def insert(self, columns, *values): + """ + Insert the given data. + + Parameters + ---------- + columns: tuple[str] + Tuple of column names to insert. + + values: tuple[tuple[Any, ...], ...] + Tuple of values to insert, corresponding to the columns. + """ + if not values: + raise ValueError("Cannot insert zero rows.") + if len(values[0]) != len(columns): + raise ValueError("Number of columns does not match length of values.") + + self._columns = columns + self._values = values + return self + + def on_conflict(self, ignore=False): + # TODO lots more we can do here + # Maybe return a Conflict object that can chain itself (not the query) + if ignore: + self._conflict = RawExpr(sql.SQL('DO NOTHING')) + return self + + @property + def _conflict_section(self) -> Optional[Expression]: + if self._conflict is not None: + e, v = self._conflict.as_tuple() + expr = RawExpr( + sql.SQL("ON CONFLICT {}").format( + e + ), + v + ) + return expr + return None + + def build(self): + columns = sql.SQL(',').join(map(sql.Identifier, self._columns)) + single_value_str = sql.SQL('({})').format( + sql.SQL(',').join(sql.Placeholder() * len(self._columns)) + ) + values_str = sql.SQL(',').join(single_value_str * len(self._values)) + + # TODO: Check efficiency of inserting multiple values like this + # Also implement a Copy query + base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format( + table=self.tableid, + columns=columns, + values_str=values_str + ) + + sections = [ + RawExpr(base, tuple(chain(*self._values))), + self._conflict_section, + self._extra_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]): + """ + Select rows from a table matching provided conditions. + """ + __slots__ = ('_columns',) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._columns: tuple[Expression, ...] = () + + def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]): + """ + Set the columns and expressions to select. + If none are given, selects all columns. + """ + cols: List[Expression] = [] + if columns: + cols.extend(map(RawExpr, map(sql.Identifier, columns))) + if exprs: + for name, expr in exprs.items(): + if isinstance(expr, str): + cols.append( + RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name)) + ) + elif isinstance(expr, sql.Composable): + cols.append( + RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name)) + ) + elif isinstance(expr, Expression): + value_expr, value_values = expr.as_tuple() + cols.append(RawExpr( + value_expr + sql.SQL(' AS ') + sql.Identifier(name), + value_values + )) + if cols: + self._columns = (*self._columns, *cols) + return self + + def build(self): + if not self._columns: + columns, columns_values = sql.SQL('*'), () + else: + columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple() + + base = sql.SQL("SELECT {columns} FROM {table}").format( + columns=columns, + table=self.tableid + ) + + sections = [ + RawExpr(base, columns_values), + self._join_section, + self._where_section, + self._group_section, + self._extra_section, + self._order_section, + self._limit_section, + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]): + """ + Query type representing a table delete query. + """ + # TODO: Cascade option for delete, maybe other options + # TODO: Require a where unless specifically disabled, for safety + + def build(self): + base = sql.SQL("DELETE FROM {table}").format( + table=self.tableid, + ) + sections = [ + RawExpr(base), + self._where_section, + self._extra_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]): + __slots__ = ( + '_set', + ) + # TODO: Again, require a where unless specifically disabled + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._set: List[Expression] = [] + + def set(self, **column_values: Union[Any, Expression]): + exprs: List[Expression] = [] + for name, value in column_values.items(): + if isinstance(value, Expression): + value_tup = value.as_tuple() + else: + value_tup = (sql.Placeholder(), (value,)) + + exprs.append( + RawExpr.join_tuples( + (sql.Identifier(name), ()), + value_tup, + joiner=sql.SQL(' = ') + ) + ) + self._set.extend(exprs) + return self + + def build(self): + if not self._set: + raise ValueError("No columns provided to update.") + set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple() + + base = sql.SQL("UPDATE {table} SET {set}").format( + table=self.tableid, + set=set_expr + ) + sections = [ + RawExpr(base, set_values), + self._from_section, + self._where_section, + self._extra_section, + self._limit_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +# async def upsert(cursor, table, constraint, **values): +# """ +# Insert or on conflict update. +# """ +# valuedict = values +# keys, values = zip(*values.items()) +# +# key_str = _format_insertkeys(keys) +# value_str, values = _format_insertvalues(values) +# update_key_str, update_key_values = _format_updatestr(valuedict) +# +# if not isinstance(constraint, str): +# constraint = ", ".join(constraint) +# +# await cursor.execute( +# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format( +# table, key_str, value_str, constraint, update_key_str +# ), +# tuple((*values, *update_key_values)) +# ) +# return await cursor.fetchone() + + +# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None): +# cursor = cursor or conn.cursor() +# +# # TODO: executemany or copy syntax now +# return execute_values( +# cursor, +# """ +# UPDATE {table} +# SET {set_clause} +# FROM (VALUES {cast_row}%s) +# AS {temp_table} +# WHERE {where_clause} +# RETURNING * +# """.format( +# table=table, +# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys), +# cast_row=cast_row + ',' if cast_row else '', +# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys), +# temp_table="_t ({})".format(', '.join(set_keys + where_keys)) +# ), +# values, +# fetch=True +# ) diff --git a/registry.py b/registry.py new file mode 100644 index 0000000..c130d0f --- /dev/null +++ b/registry.py @@ -0,0 +1,102 @@ +from typing import Protocol, runtime_checkable, Optional + +from psycopg import AsyncConnection + +from .connector import Connector, Connectable + + +@runtime_checkable +class _Attachable(Connectable, Protocol): + def attach_to(self, registry: 'Registry'): + raise NotImplementedError + + +class Registry: + _attached: list[_Attachable] = [] + _name: Optional[str] = None + + def __init_subclass__(cls, name=None): + attached = [] + for _, member in cls.__dict__.items(): + if isinstance(member, _Attachable): + attached.append(member) + cls._attached = attached + cls._name = name or cls.__name__ + + def __init__(self, name=None): + self._conn: Optional[Connector] = None + self.name: str = name if name is not None else self._name + if self.name is None: + raise ValueError("A Registry must have a name!") + + self.init_tasks = [] + + for member in self._attached: + member.attach_to(self) + + def bind(self, connector: Connector): + self._conn = connector + for child in self._attached: + child.bind(connector) + + def attach(self, attachable): + self._attached.append(attachable) + if self._conn is not None: + attachable.bind(self._conn) + return attachable + + def init_task(self, coro): + """ + Initialisation tasks are run to setup the registry state. + These tasks will be run in the event loop, after connection to the database. + These tasks should be idempotent, as they may be run on reload and reconnect. + """ + self.init_tasks.append(coro) + return coro + + async def init(self): + for task in self.init_tasks: + await task(self) + return self + + +class AttachableClass: + """ABC for a default implementation of an Attachable class.""" + + _connector: Optional[Connector] = None + _registry: Optional[Registry] = None + + @classmethod + def bind(cls, connector: Connector): + cls._connector = connector + connector.connect_hook(cls.on_connect) + return cls + + @classmethod + def attach_to(cls, registry: Registry): + cls._registry = registry + return cls + + @classmethod + async def on_connect(cls, connection: AsyncConnection): + pass + + +class Attachable: + """ABC for a default implementation of an Attachable object.""" + + def __init__(self, *args, **kwargs): + self._connector: Optional[Connector] = None + self._registry: Optional[Registry] = None + + def bind(self, connector: Connector): + self._connector = connector + connector.connect_hook(self.on_connect) + return self + + def attach_to(self, registry: Registry): + self._registry = registry + return self + + async def on_connect(self, connection: AsyncConnection): + pass diff --git a/table.py b/table.py new file mode 100644 index 0000000..e20647e --- /dev/null +++ b/table.py @@ -0,0 +1,95 @@ +from typing import Optional +from psycopg.rows import DictRow +from psycopg import sql + +from . import queries as q +from .connector import Connector +from .registry import Registry + + +class Table: + """ + Transparent interface to a single table structure in the database. + Contains standard methods to access the table. + """ + + def __init__(self, name, *args, schema='public', **kwargs): + self.name: str = name + self.schema: str = schema + self.connector: Connector = None + + @property + def identifier(self): + if self.schema == 'public': + return sql.Identifier(self.name) + else: + return sql.Identifier(self.schema, self.name) + + def bind(self, connector: Connector): + self.connector = connector + return self + + def attach_to(self, registry: Registry): + self._registry = registry + return self + + def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: + return data + + def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]: + if data: + return data[0] + else: + return None + + def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: + return data + + def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]: + return q.Select( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]: + return q.Select( + self.identifier, + row_adapter=self._single_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]: + return q.Update( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]: + return q.Delete( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def insert(self, **column_values) -> q.Insert[DictRow]: + return q.Insert( + self.identifier, + row_adapter=self._single_query_adapter, + connector=self.connector + ).insert(column_values.keys(), column_values.values()) + + def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]: + return q.Insert( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).insert(*args, **kwargs) + +# def update_many(self, *args, **kwargs): +# with self.conn: +# return update_many(self.identifier, *args, **kwargs) + +# def upsert(self, *args, **kwargs): +# return upsert(self.identifier, *args, **kwargs)