rewrite: Update data ORM.

Fix parsing bug in Update Query.
Fix parsing bug in Insert Query.

Add mapping interface to Model.
Implement OrderMixin, with ORDER and NULLS Enums.
Add `result` field to Query.
This commit is contained in:
2022-11-07 16:04:02 +02:00
parent 322f519640
commit 872e5fd71f
4 changed files with 49 additions and 13 deletions

View File

@@ -6,3 +6,4 @@ from .base import Expression, RawExpr
from .columns import ColumnExpr, Column, Integer, String from .columns import ColumnExpr, Column, Integer, String
from .registry import Registry, AttachableClass, Attachable from .registry import Registry, AttachableClass, Attachable
from .adapted import RegisterEnum from .adapted import RegisterEnum
from .queries import ORDER, NULLS

View File

@@ -115,7 +115,7 @@ class Column(ColumnExpr, Generic[T]):
self.expr = sql.Identifier(owner._tablename_, self.name) self.expr = sql.Identifier(owner._tablename_, self.name)
@overload @overload
def __get__(self: 'Column[T]', obj: None, objtype: None) -> 'Column[T]': def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
... ...
@overload @overload

View File

@@ -1,4 +1,4 @@
from typing import TypeVar, Type, Any, Optional, Generic, Union, Mapping from typing import TypeVar, Type, Optional, Generic, Union
# from typing_extensions import Self # from typing_extensions import Self
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from collections.abc import MutableMapping from collections.abc import MutableMapping
@@ -170,6 +170,12 @@ class RowModel:
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
@classmethod @classmethod
def bind(cls, connector: Connector): def bind(cls, connector: Connector):
if cls.table is None: if cls.table is None:

View File

@@ -1,4 +1,6 @@
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
from enum import Enum
from itertools import chain
from psycopg import AsyncConnection, AsyncCursor from psycopg import AsyncConnection, AsyncCursor
from psycopg import sql from psycopg import sql
from psycopg.rows import DictRow from psycopg.rows import DictRow
@@ -23,7 +25,7 @@ class Query(Generic[QueryResult]):
""" """
ABC for an executable query statement. ABC for an executable query statement.
""" """
__slots__ = ('conn', 'cursor', '_adapter', 'connector') __slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
_adapter: Callable[..., QueryResult] _adapter: Callable[..., QueryResult]
@@ -37,6 +39,8 @@ class Query(Generic[QueryResult]):
else: else:
self._adapter = self._no_adapter self._adapter = self._no_adapter
self.result: Optional[QueryResult] = None
def bind(self, connector: Connector): def bind(self, connector: Connector):
self.connector = connector self.connector = connector
return self return self
@@ -74,7 +78,8 @@ class Query(Generic[QueryResult]):
# ) # )
await cursor.execute(sql.Composed((query,)), values) await cursor.execute(sql.Composed((query,)), values)
data = await cursor.fetchall() data = await cursor.fetchall()
return self._adapter(*data) self.result = self._adapter(*data)
return self.result
async def execute(self, cursor=None) -> QueryResult: async def execute(self, cursor=None) -> QueryResult:
""" """
@@ -258,27 +263,50 @@ class LimitMixin(TableQuery[QueryResult]):
return None return None
class ORDER(Enum):
ASC = sql.SQL('ASC')
DESC = sql.SQL('DESC')
class NULLS(Enum):
FIRST = sql.SQL('NULLS FIRST')
LAST = sql.SQL('NULLS LAST')
class OrderMixin(TableQuery[QueryResult]): class OrderMixin(TableQuery[QueryResult]):
__slots__ = () __slots__ = ()
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._order: Optional[sql.Composable] = None self._order: list[Expression] = []
def order_by(self, expression, direction=None, nulls=None): def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
""" """
Add a single sort expression to the query. Add a single sort expression to the query.
This method stacks. This method stacks.
""" """
# TODO: Accept a ColumnExpression, string, or sql Composable if isinstance(expr, Expression):
# TODO: Enums for direction (ORDER.ASC or ORDER.DESC) and nulls (NULLS.FIRST, NULLS.LAST) string, values = expr.as_tuple()
... else:
string = sql.Identifier(expr)
values = ()
parts = [string]
if direction is not None:
parts.append(direction.value)
if nulls is not None:
parts.append(nulls.value)
order_string = sql.SQL(' ').join(parts)
self._order.append(RawExpr(order_string, values))
return self
@property @property
def _order_section(self) -> Optional[Expression]: def _order_section(self) -> Optional[Expression]:
if self._order is not None: if self._order is not None:
return RawExpr(sql.SQL("ORDER BY {}").format(self._order), ()) expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
expr.expr = sql.SQL("ORDER BY {}").formt(expr.expr)
else: else:
return None return None
@@ -314,6 +342,7 @@ class Insert(ExtraMixin, TableQuery[QueryResult]):
self._columns = columns self._columns = columns
self._values = values self._values = values
return self
def build(self): def build(self):
columns = sql.SQL(',').join(map(sql.Identifier, self._columns)) columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
@@ -324,14 +353,14 @@ class Insert(ExtraMixin, TableQuery[QueryResult]):
# TODO: Check efficiency of inserting multiple values like this # TODO: Check efficiency of inserting multiple values like this
# Also implement a Copy query # Also implement a Copy query
base = sql.Sql("INSERT INTO {table} {columns} VALUES {values_str}").format( base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
table=sql.Identifier(self.table), table=sql.Identifier(self.table),
columns=columns, columns=columns,
values_str=values_str values_str=values_str
) )
sections = [ sections = [
RawExpr(base), RawExpr(base, tuple(chain(*self._values))),
self._extra_section, self._extra_section,
RawExpr(sql.SQL('RETURNING *')) RawExpr(sql.SQL('RETURNING *'))
] ]
@@ -440,7 +469,7 @@ class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]):
if isinstance(value, Expression): if isinstance(value, Expression):
value_tup = value.as_tuple() value_tup = value.as_tuple()
else: else:
value_tup = (sql.Placeholder(), value) value_tup = (sql.Placeholder(), (value,))
exprs.append( exprs.append(
RawExpr.join_tuples( RawExpr.join_tuples(