rewrite (data): Increase Column flexibility.
New Column field types. Allow Column to be an attribute of a non-rowmodel. Add `references` field to Column. Add logging for registry attach. Add support for alternative join types.
This commit is contained in:
@@ -29,4 +29,4 @@ class RegisterEnum(Attachable):
|
|||||||
info = await EnumInfo.fetch(connection, self.name)
|
info = await EnumInfo.fetch(connection, self.name)
|
||||||
if info is None:
|
if info is None:
|
||||||
raise ValueError(f"Enum {self.name} not found in database.")
|
raise ValueError(f"Enum {self.name} not found in database.")
|
||||||
register_enum(info, connection, self.enum, mapping=self.mapping)
|
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Union, TypeVar, Generic, Type, overload, TYPE_CHECKING
|
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
|
||||||
from psycopg import sql
|
from psycopg import sql
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from .base import RawExpr, Expression
|
from .base import RawExpr, Expression
|
||||||
from .conditions import Condition, Joiner
|
from .conditions import Condition, Joiner
|
||||||
@@ -103,16 +104,26 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class Column(ColumnExpr, Generic[T]):
|
class Column(ColumnExpr, Generic[T]):
|
||||||
def __init__(self, name: str = None, primary: bool = False): # type: ignore
|
def __init__(self, name: Optional[str] = None,
|
||||||
|
primary: bool = False, references: Optional['Column'] = None,
|
||||||
|
type: Optional[Type[T]] = None):
|
||||||
self.primary = primary
|
self.primary = primary
|
||||||
self.name: str = name
|
self.references = references
|
||||||
|
self.name: str = name # type: ignore
|
||||||
|
self.owner: Optional['RowModel'] = None
|
||||||
|
self.tablename: Optional[str] = None
|
||||||
|
self._type = type
|
||||||
|
|
||||||
self.expr = sql.Identifier(name) if name else sql.SQL('')
|
self.expr = sql.Identifier(name) if name else sql.SQL('')
|
||||||
self.values = ()
|
self.values = ()
|
||||||
|
|
||||||
def __set_name__(self, owner, name):
|
def __set_name__(self, owner, name):
|
||||||
self.name = self.name or name
|
# Only allow setting the owner once
|
||||||
self.expr = sql.Identifier(owner._tablename_, self.name)
|
if self.owner is None:
|
||||||
|
self.name = self.name or name
|
||||||
|
self.owner = owner
|
||||||
|
self.tablename = owner._tablename_
|
||||||
|
self.expr = sql.Identifier(self.tablename, self.name)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
|
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
|
||||||
@@ -126,8 +137,10 @@ class Column(ColumnExpr, Generic[T]):
|
|||||||
# Get value from row data or session
|
# Get value from row data or session
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return self
|
return self
|
||||||
else:
|
elif obj is self.owner:
|
||||||
return obj.data[self.name]
|
return obj.data[self.name]
|
||||||
|
else:
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class Integer(Column[int]):
|
class Integer(Column[int]):
|
||||||
@@ -140,3 +153,7 @@ class String(Column[str]):
|
|||||||
|
|
||||||
class Bool(Column[bool]):
|
class Bool(Column[bool]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Timestamp(Column[datetime]):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from typing import TypeVar
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
@@ -10,6 +11,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
Version = namedtuple('Version', ('version', 'time', 'author'))
|
Version = namedtuple('Version', ('version', 'time', 'author'))
|
||||||
|
|
||||||
|
T = TypeVar('T', bound=Registry)
|
||||||
|
|
||||||
|
|
||||||
class Database(Connector):
|
class Database(Connector):
|
||||||
# cursor_factory = AsyncLoggingCursor
|
# cursor_factory = AsyncLoggingCursor
|
||||||
@@ -19,9 +22,14 @@ class Database(Connector):
|
|||||||
|
|
||||||
self.registries: dict[str, Registry] = {}
|
self.registries: dict[str, Registry] = {}
|
||||||
|
|
||||||
def load_registry(self, registry: Registry):
|
def load_registry(self, registry: T) -> T:
|
||||||
|
logger.debug(
|
||||||
|
f"Loading and binding registry '{registry.name}'.",
|
||||||
|
extra={'action': f"Reg {registry.name}"}
|
||||||
|
)
|
||||||
registry.bind(self)
|
registry.bind(self)
|
||||||
self.registries[registry.name] = registry
|
self.registries[registry.name] = registry
|
||||||
|
return registry
|
||||||
|
|
||||||
async def version(self) -> Version:
|
async def version(self) -> Version:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ class RowModel:
|
|||||||
# Cache to keep track of registered Rows
|
# Cache to keep track of registered Rows
|
||||||
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
|
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
|
||||||
|
|
||||||
_key_: tuple[str, ...]
|
_key_: tuple[str, ...] = ()
|
||||||
_connector: Optional[Connector] = None
|
_connector: Optional[Connector] = None
|
||||||
_registry: Optional[Registry] = None
|
_registry: Optional[Registry] = None
|
||||||
|
|
||||||
@@ -145,7 +145,8 @@ class RowModel:
|
|||||||
columns[key] = value
|
columns[key] = value
|
||||||
|
|
||||||
cls._columns_ = columns
|
cls._columns_ = columns
|
||||||
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
|
if not cls._key_:
|
||||||
|
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
|
||||||
cls.table = RowTable(cls._tablename_, cls)
|
cls.table = RowTable(cls._tablename_, cls)
|
||||||
if cls._cache_ is None:
|
if cls._cache_ is None:
|
||||||
cls._cache_ = WeakValueDictionary()
|
cls._cache_ = WeakValueDictionary()
|
||||||
|
|||||||
@@ -163,6 +163,14 @@ class WhereMixin(TableQuery[QueryResult]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class JOINTYPE(Enum):
|
||||||
|
LEFT = sql.SQL('LEFT JOIN')
|
||||||
|
RIGHT = sql.SQL('RIGHT JOIN')
|
||||||
|
INNER = sql.SQL('INNER JOIN')
|
||||||
|
OUTER = sql.SQL('OUTER JOIN')
|
||||||
|
FULLOUTER = sql.SQL('FULL OUTER JOIN')
|
||||||
|
|
||||||
|
|
||||||
class JoinMixin(TableQuery[QueryResult]):
|
class JoinMixin(TableQuery[QueryResult]):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
# TODO: Remember to add join slots to TableQuery
|
# TODO: Remember to add join slots to TableQuery
|
||||||
@@ -174,6 +182,7 @@ class JoinMixin(TableQuery[QueryResult]):
|
|||||||
def join(self,
|
def join(self,
|
||||||
target: Union[str, Expression],
|
target: Union[str, Expression],
|
||||||
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
|
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
|
||||||
|
join_type: JOINTYPE = JOINTYPE.INNER,
|
||||||
natural=False):
|
natural=False):
|
||||||
available = (on is not None) + (using is not None) + natural
|
available = (on is not None) + (using is not None) + natural
|
||||||
if available == 0:
|
if available == 0:
|
||||||
@@ -181,7 +190,7 @@ class JoinMixin(TableQuery[QueryResult]):
|
|||||||
if available > 1:
|
if available > 1:
|
||||||
raise ValueError("Exactly one join format must be given for Query Join")
|
raise ValueError("Exactly one join format must be given for Query Join")
|
||||||
|
|
||||||
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(sql.SQL('JOIN'), ())]
|
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
|
||||||
if isinstance(target, str):
|
if isinstance(target, str):
|
||||||
sections.append((sql.Identifier(target), ()))
|
sections.append((sql.Identifier(target), ()))
|
||||||
else:
|
else:
|
||||||
@@ -206,6 +215,9 @@ class JoinMixin(TableQuery[QueryResult]):
|
|||||||
self._joins.append(expr)
|
self._joins.append(expr)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def leftjoin(self, *args, **kwargs):
|
||||||
|
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _join_section(self) -> Optional[Expression]:
|
def _join_section(self) -> Optional[Expression]:
|
||||||
if self._joins:
|
if self._joins:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class Registry:
|
|||||||
|
|
||||||
def __init_subclass__(cls, name=None):
|
def __init_subclass__(cls, name=None):
|
||||||
attached = []
|
attached = []
|
||||||
for name, member in cls.__dict__.items():
|
for _, member in cls.__dict__.items():
|
||||||
if isinstance(member, _Attachable):
|
if isinstance(member, _Attachable):
|
||||||
attached.append(member)
|
attached.append(member)
|
||||||
cls._attached = attached
|
cls._attached = attached
|
||||||
|
|||||||
Reference in New Issue
Block a user