rewrite (data): Increase Column flexibility.

New Column field types.
Allow Column to be an attribute of a non-rowmodel.
Add `references` field to Column.
Add logging for registry attach.
Add support for alternative join types.
This commit is contained in:
2022-11-18 08:41:11 +02:00
parent e528e8d0b6
commit 56f66ec7d4
6 changed files with 50 additions and 12 deletions

View File

@@ -29,4 +29,4 @@ class RegisterEnum(Attachable):
info = await EnumInfo.fetch(connection, self.name) info = await EnumInfo.fetch(connection, self.name)
if info is None: if info is None:
raise ValueError(f"Enum {self.name} not found in database.") raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, connection, self.enum, mapping=self.mapping) register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))

View File

@@ -1,5 +1,6 @@
from typing import Any, Union, TypeVar, Generic, Type, overload, TYPE_CHECKING from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
from psycopg import sql from psycopg import sql
from datetime import datetime
from .base import RawExpr, Expression from .base import RawExpr, Expression
from .conditions import Condition, Joiner from .conditions import Condition, Joiner
@@ -103,16 +104,26 @@ if TYPE_CHECKING:
class Column(ColumnExpr, Generic[T]): class Column(ColumnExpr, Generic[T]):
def __init__(self, name: str = None, primary: bool = False): # type: ignore def __init__(self, name: Optional[str] = None,
primary: bool = False, references: Optional['Column'] = None,
type: Optional[Type[T]] = None):
self.primary = primary self.primary = primary
self.name: str = name self.references = references
self.name: str = name # type: ignore
self.owner: Optional['RowModel'] = None
self.tablename: Optional[str] = None
self._type = type
self.expr = sql.Identifier(name) if name else sql.SQL('') self.expr = sql.Identifier(name) if name else sql.SQL('')
self.values = () self.values = ()
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
self.name = self.name or name # Only allow setting the owner once
self.expr = sql.Identifier(owner._tablename_, self.name) if self.owner is None:
self.name = self.name or name
self.owner = owner
self.tablename = owner._tablename_
self.expr = sql.Identifier(self.tablename, self.name)
@overload @overload
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
@@ -126,8 +137,10 @@ class Column(ColumnExpr, Generic[T]):
# Get value from row data or session # Get value from row data or session
if obj is None: if obj is None:
return self return self
else: elif obj is self.owner:
return obj.data[self.name] return obj.data[self.name]
else:
return self
class Integer(Column[int]): class Integer(Column[int]):
@@ -140,3 +153,7 @@ class String(Column[str]):
class Bool(Column[bool]): class Bool(Column[bool]):
pass pass
class Timestamp(Column[datetime]):
pass

View File

@@ -1,3 +1,4 @@
from typing import TypeVar
import logging import logging
from collections import namedtuple from collections import namedtuple
@@ -10,6 +11,8 @@ logger = logging.getLogger(__name__)
Version = namedtuple('Version', ('version', 'time', 'author')) Version = namedtuple('Version', ('version', 'time', 'author'))
T = TypeVar('T', bound=Registry)
class Database(Connector): class Database(Connector):
# cursor_factory = AsyncLoggingCursor # cursor_factory = AsyncLoggingCursor
@@ -19,9 +22,14 @@ class Database(Connector):
self.registries: dict[str, Registry] = {} self.registries: dict[str, Registry] = {}
def load_registry(self, registry: Registry): def load_registry(self, registry: T) -> T:
logger.debug(
f"Loading and binding registry '{registry.name}'.",
extra={'action': f"Reg {registry.name}"}
)
registry.bind(self) registry.bind(self)
self.registries[registry.name] = registry self.registries[registry.name] = registry
return registry
async def version(self) -> Version: async def version(self) -> Version:
""" """

View File

@@ -124,7 +124,7 @@ class RowModel:
# Cache to keep track of registered Rows # Cache to keep track of registered Rows
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore _cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
_key_: tuple[str, ...] _key_: tuple[str, ...] = ()
_connector: Optional[Connector] = None _connector: Optional[Connector] = None
_registry: Optional[Registry] = None _registry: Optional[Registry] = None
@@ -145,7 +145,8 @@ class RowModel:
columns[key] = value columns[key] = value
cls._columns_ = columns cls._columns_ = columns
cls._key_ = tuple(column.name for column in columns.values() if column.primary) if not cls._key_:
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
cls.table = RowTable(cls._tablename_, cls) cls.table = RowTable(cls._tablename_, cls)
if cls._cache_ is None: if cls._cache_ is None:
cls._cache_ = WeakValueDictionary() cls._cache_ = WeakValueDictionary()

View File

@@ -163,6 +163,14 @@ class WhereMixin(TableQuery[QueryResult]):
return None return None
class JOINTYPE(Enum):
LEFT = sql.SQL('LEFT JOIN')
RIGHT = sql.SQL('RIGHT JOIN')
INNER = sql.SQL('INNER JOIN')
OUTER = sql.SQL('OUTER JOIN')
FULLOUTER = sql.SQL('FULL OUTER JOIN')
class JoinMixin(TableQuery[QueryResult]): class JoinMixin(TableQuery[QueryResult]):
__slots__ = () __slots__ = ()
# TODO: Remember to add join slots to TableQuery # TODO: Remember to add join slots to TableQuery
@@ -174,6 +182,7 @@ class JoinMixin(TableQuery[QueryResult]):
def join(self, def join(self,
target: Union[str, Expression], target: Union[str, Expression],
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None, on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
join_type: JOINTYPE = JOINTYPE.INNER,
natural=False): natural=False):
available = (on is not None) + (using is not None) + natural available = (on is not None) + (using is not None) + natural
if available == 0: if available == 0:
@@ -181,7 +190,7 @@ class JoinMixin(TableQuery[QueryResult]):
if available > 1: if available > 1:
raise ValueError("Exactly one join format must be given for Query Join") raise ValueError("Exactly one join format must be given for Query Join")
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(sql.SQL('JOIN'), ())] sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
if isinstance(target, str): if isinstance(target, str):
sections.append((sql.Identifier(target), ())) sections.append((sql.Identifier(target), ()))
else: else:
@@ -206,6 +215,9 @@ class JoinMixin(TableQuery[QueryResult]):
self._joins.append(expr) self._joins.append(expr)
return self return self
def leftjoin(self, *args, **kwargs):
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
@property @property
def _join_section(self) -> Optional[Expression]: def _join_section(self) -> Optional[Expression]:
if self._joins: if self._joins:

View File

@@ -17,7 +17,7 @@ class Registry:
def __init_subclass__(cls, name=None): def __init_subclass__(cls, name=None):
attached = [] attached = []
for name, member in cls.__dict__.items(): for _, member in cls.__dict__.items():
if isinstance(member, _Attachable): if isinstance(member, _Attachable):
attached.append(member) attached.append(member)
cls._attached = attached cls._attached = attached