diff --git a/bot/data/adapted.py b/bot/data/adapted.py index 12bebd4e..a4344f19 100644 --- a/bot/data/adapted.py +++ b/bot/data/adapted.py @@ -29,4 +29,4 @@ class RegisterEnum(Attachable): info = await EnumInfo.fetch(connection, self.name) if info is None: raise ValueError(f"Enum {self.name} not found in database.") - register_enum(info, connection, self.enum, mapping=self.mapping) + register_enum(info, connection, self.enum, mapping=list(self.mapping.items())) diff --git a/bot/data/columns.py b/bot/data/columns.py index 2fe3b498..60626f17 100644 --- a/bot/data/columns.py +++ b/bot/data/columns.py @@ -1,5 +1,6 @@ -from typing import Any, Union, TypeVar, Generic, Type, overload, TYPE_CHECKING +from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING from psycopg import sql +from datetime import datetime from .base import RawExpr, Expression from .conditions import Condition, Joiner @@ -103,16 +104,26 @@ if TYPE_CHECKING: class Column(ColumnExpr, Generic[T]): - def __init__(self, name: str = None, primary: bool = False): # type: ignore + def __init__(self, name: Optional[str] = None, + primary: bool = False, references: Optional['Column'] = None, + type: Optional[Type[T]] = None): self.primary = primary - self.name: str = name + self.references = references + self.name: str = name # type: ignore + self.owner: Optional['RowModel'] = None + self.tablename: Optional[str] = None + self._type = type self.expr = sql.Identifier(name) if name else sql.SQL('') self.values = () def __set_name__(self, owner, name): - self.name = self.name or name - self.expr = sql.Identifier(owner._tablename_, self.name) + # Only allow setting the owner once + if self.owner is None: + self.name = self.name or name + self.owner = owner + self.tablename = owner._tablename_ + self.expr = sql.Identifier(self.tablename, self.name) @overload def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': @@ -126,8 +137,10 @@ class Column(ColumnExpr, Generic[T]): # Get value from row data or session if obj is None: return self - else: + elif obj is self.owner: return obj.data[self.name] + else: + return self class Integer(Column[int]): @@ -140,3 +153,7 @@ class String(Column[str]): class Bool(Column[bool]): pass + + +class Timestamp(Column[datetime]): + pass diff --git a/bot/data/database.py b/bot/data/database.py index 6d6b89e4..039a0b0e 100644 --- a/bot/data/database.py +++ b/bot/data/database.py @@ -1,3 +1,4 @@ +from typing import TypeVar import logging from collections import namedtuple @@ -10,6 +11,8 @@ logger = logging.getLogger(__name__) Version = namedtuple('Version', ('version', 'time', 'author')) +T = TypeVar('T', bound=Registry) + class Database(Connector): # cursor_factory = AsyncLoggingCursor @@ -19,9 +22,14 @@ class Database(Connector): self.registries: dict[str, Registry] = {} - def load_registry(self, registry: Registry): + def load_registry(self, registry: T) -> T: + logger.debug( + f"Loading and binding registry '{registry.name}'.", + extra={'action': f"Reg {registry.name}"} + ) registry.bind(self) self.registries[registry.name] = registry + return registry async def version(self) -> Version: """ diff --git a/bot/data/models.py b/bot/data/models.py index a5a15651..2ebb470b 100644 --- a/bot/data/models.py +++ b/bot/data/models.py @@ -124,7 +124,7 @@ class RowModel: # Cache to keep track of registered Rows _cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore - _key_: tuple[str, ...] + _key_: tuple[str, ...] = () _connector: Optional[Connector] = None _registry: Optional[Registry] = None @@ -145,7 +145,8 @@ class RowModel: columns[key] = value cls._columns_ = columns - cls._key_ = tuple(column.name for column in columns.values() if column.primary) + if not cls._key_: + cls._key_ = tuple(column.name for column in columns.values() if column.primary) cls.table = RowTable(cls._tablename_, cls) if cls._cache_ is None: cls._cache_ = WeakValueDictionary() diff --git a/bot/data/queries.py b/bot/data/queries.py index 118767a4..5d7bc0a0 100644 --- a/bot/data/queries.py +++ b/bot/data/queries.py @@ -163,6 +163,14 @@ class WhereMixin(TableQuery[QueryResult]): return None +class JOINTYPE(Enum): + LEFT = sql.SQL('LEFT JOIN') + RIGHT = sql.SQL('RIGHT JOIN') + INNER = sql.SQL('INNER JOIN') + OUTER = sql.SQL('OUTER JOIN') + FULLOUTER = sql.SQL('FULL OUTER JOIN') + + class JoinMixin(TableQuery[QueryResult]): __slots__ = () # TODO: Remember to add join slots to TableQuery @@ -174,6 +182,7 @@ class JoinMixin(TableQuery[QueryResult]): def join(self, target: Union[str, Expression], on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None, + join_type: JOINTYPE = JOINTYPE.INNER, natural=False): available = (on is not None) + (using is not None) + natural if available == 0: @@ -181,7 +190,7 @@ class JoinMixin(TableQuery[QueryResult]): if available > 1: raise ValueError("Exactly one join format must be given for Query Join") - sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(sql.SQL('JOIN'), ())] + sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())] if isinstance(target, str): sections.append((sql.Identifier(target), ())) else: @@ -206,6 +215,9 @@ class JoinMixin(TableQuery[QueryResult]): self._joins.append(expr) return self + def leftjoin(self, *args, **kwargs): + return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs) + @property def _join_section(self) -> Optional[Expression]: if self._joins: diff --git a/bot/data/registry.py b/bot/data/registry.py index 0d8fa84d..c130d0f3 100644 --- a/bot/data/registry.py +++ b/bot/data/registry.py @@ -17,7 +17,7 @@ class Registry: def __init_subclass__(cls, name=None): attached = [] - for name, member in cls.__dict__.items(): + for _, member in cls.__dict__.items(): if isinstance(member, _Attachable): attached.append(member) cls._attached = attached