diff --git a/bot/data/__init__.py b/bot/data/__init__.py index 3fcfd4a0..2fd48650 100644 --- a/bot/data/__init__.py +++ b/bot/data/__init__.py @@ -6,3 +6,4 @@ from .base import Expression, RawExpr from .columns import ColumnExpr, Column, Integer, String from .registry import Registry, AttachableClass, Attachable from .adapted import RegisterEnum +from .queries import ORDER, NULLS diff --git a/bot/data/columns.py b/bot/data/columns.py index ba5a3f5c..8a14b27e 100644 --- a/bot/data/columns.py +++ b/bot/data/columns.py @@ -115,7 +115,7 @@ class Column(ColumnExpr, Generic[T]): self.expr = sql.Identifier(owner._tablename_, self.name) @overload - def __get__(self: 'Column[T]', obj: None, objtype: None) -> 'Column[T]': + def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': ... @overload diff --git a/bot/data/models.py b/bot/data/models.py index 0b692d69..38db7a2f 100644 --- a/bot/data/models.py +++ b/bot/data/models.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Type, Any, Optional, Generic, Union, Mapping +from typing import TypeVar, Type, Optional, Generic, Union # from typing_extensions import Self from weakref import WeakValueDictionary from collections.abc import MutableMapping @@ -170,6 +170,12 @@ class RowModel: def __init__(self, data): self.data = data + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + @classmethod def bind(cls, connector: Connector): if cls.table is None: diff --git a/bot/data/queries.py b/bot/data/queries.py index edad381e..7e50a0cb 100644 --- a/bot/data/queries.py +++ b/bot/data/queries.py @@ -1,4 +1,6 @@ from typing import Optional, TypeVar, Any, Callable, Generic, List, Union +from enum import Enum +from itertools import chain from psycopg import AsyncConnection, AsyncCursor from psycopg import sql from psycopg.rows import DictRow @@ -23,7 +25,7 @@ class Query(Generic[QueryResult]): """ ABC for an executable query statement. """ - __slots__ = ('conn', 'cursor', '_adapter', 'connector') + __slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result') _adapter: Callable[..., QueryResult] @@ -37,6 +39,8 @@ class Query(Generic[QueryResult]): else: self._adapter = self._no_adapter + self.result: Optional[QueryResult] = None + def bind(self, connector: Connector): self.connector = connector return self @@ -74,7 +78,8 @@ class Query(Generic[QueryResult]): # ) await cursor.execute(sql.Composed((query,)), values) data = await cursor.fetchall() - return self._adapter(*data) + self.result = self._adapter(*data) + return self.result async def execute(self, cursor=None) -> QueryResult: """ @@ -258,27 +263,50 @@ class LimitMixin(TableQuery[QueryResult]): return None +class ORDER(Enum): + ASC = sql.SQL('ASC') + DESC = sql.SQL('DESC') + + +class NULLS(Enum): + FIRST = sql.SQL('NULLS FIRST') + LAST = sql.SQL('NULLS LAST') + + class OrderMixin(TableQuery[QueryResult]): __slots__ = () def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._order: Optional[sql.Composable] = None + self._order: list[Expression] = [] - def order_by(self, expression, direction=None, nulls=None): + def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None): """ Add a single sort expression to the query. This method stacks. """ - # TODO: Accept a ColumnExpression, string, or sql Composable - # TODO: Enums for direction (ORDER.ASC or ORDER.DESC) and nulls (NULLS.FIRST, NULLS.LAST) - ... + if isinstance(expr, Expression): + string, values = expr.as_tuple() + else: + string = sql.Identifier(expr) + values = () + + parts = [string] + if direction is not None: + parts.append(direction.value) + if nulls is not None: + parts.append(nulls.value) + + order_string = sql.SQL(' ').join(parts) + self._order.append(RawExpr(order_string, values)) + return self @property def _order_section(self) -> Optional[Expression]: if self._order is not None: - return RawExpr(sql.SQL("ORDER BY {}").format(self._order), ()) + expr = RawExpr.join(*self._order, joiner=sql.SQL(', ')) + expr.expr = sql.SQL("ORDER BY {}").formt(expr.expr) else: return None @@ -314,6 +342,7 @@ class Insert(ExtraMixin, TableQuery[QueryResult]): self._columns = columns self._values = values + return self def build(self): columns = sql.SQL(',').join(map(sql.Identifier, self._columns)) @@ -324,14 +353,14 @@ class Insert(ExtraMixin, TableQuery[QueryResult]): # TODO: Check efficiency of inserting multiple values like this # Also implement a Copy query - base = sql.Sql("INSERT INTO {table} {columns} VALUES {values_str}").format( + base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format( table=sql.Identifier(self.table), columns=columns, values_str=values_str ) sections = [ - RawExpr(base), + RawExpr(base, tuple(chain(*self._values))), self._extra_section, RawExpr(sql.SQL('RETURNING *')) ] @@ -440,7 +469,7 @@ class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]): if isinstance(value, Expression): value_tup = value.as_tuple() else: - value_tup = (sql.Placeholder(), value) + value_tup = (sql.Placeholder(), (value,)) exprs.append( RawExpr.join_tuples(