Initial Commit
This commit is contained in:
323
models.py
Normal file
323
models.py
Normal file
@@ -0,0 +1,323 @@
|
||||
from typing import TypeVar, Type, Optional, Generic, Union
|
||||
# from typing_extensions import Self
|
||||
from weakref import WeakValueDictionary
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
from psycopg.rows import DictRow
|
||||
|
||||
from .table import Table
|
||||
from .columns import Column
|
||||
from . import queries as q
|
||||
from .connector import Connector
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
RowT = TypeVar('RowT', bound='RowModel')
|
||||
|
||||
|
||||
class MISSING:
|
||||
__slots__ = ('oid',)
|
||||
|
||||
def __init__(self, oid):
|
||||
self.oid = oid
|
||||
|
||||
|
||||
class RowTable(Table, Generic[RowT]):
|
||||
__slots__ = (
|
||||
'model',
|
||||
)
|
||||
|
||||
def __init__(self, name, model: Type[RowT], **kwargs):
|
||||
super().__init__(name, **kwargs)
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self.model._columns_
|
||||
|
||||
@property
|
||||
def id_col(self):
|
||||
return self.model._key_
|
||||
|
||||
@property
|
||||
def row_cache(self):
|
||||
return self.model._cache_
|
||||
|
||||
def _many_query_adapter(self, *data):
|
||||
self.model._make_rows(*data)
|
||||
return data
|
||||
|
||||
def _single_query_adapter(self, *data):
|
||||
if data:
|
||||
self.model._make_rows(*data)
|
||||
return data[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _delete_query_adapter(self, *data):
|
||||
self.model._delete_rows(*data)
|
||||
return data
|
||||
|
||||
# New methods to fetch and create rows
|
||||
async def create_row(self, *args, **kwargs) -> RowT:
|
||||
data = await super().insert(*args, **kwargs)
|
||||
return self.model._make_rows(data)[0]
|
||||
|
||||
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
|
||||
# TODO: Handle list of rowids here?
|
||||
return q.Select(
|
||||
self.identifier,
|
||||
row_adapter=self.model._make_rows,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
|
||||
WK = TypeVar('WK')
|
||||
WV = TypeVar('WV')
|
||||
|
||||
|
||||
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
|
||||
def __init__(self, ref_cache):
|
||||
self.ref_cache = ref_cache
|
||||
self.weak_cache = WeakValueDictionary()
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = self.weak_cache[key]
|
||||
self.ref_cache[key] = value
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.weak_cache[key] = value
|
||||
self.ref_cache[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.weak_cache[key]
|
||||
try:
|
||||
del self.ref_cache[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.weak_cache
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.weak_cache)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.weak_cache)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def pop(self, key, default=None):
|
||||
if key in self:
|
||||
value = self[key]
|
||||
del self[key]
|
||||
else:
|
||||
value = default
|
||||
return value
|
||||
|
||||
|
||||
# TODO: Implement getitem and setitem, for dynamic column access
|
||||
class RowModel:
|
||||
__slots__ = ('data',)
|
||||
|
||||
_schema_: str = 'public'
|
||||
_tablename_: Optional[str] = None
|
||||
_columns_: dict[str, Column] = {}
|
||||
|
||||
# Cache to keep track of registered Rows
|
||||
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
|
||||
|
||||
_key_: tuple[str, ...] = ()
|
||||
_connector: Optional[Connector] = None
|
||||
_registry: Optional[Registry] = None
|
||||
|
||||
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
|
||||
table: RowTable = None
|
||||
|
||||
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
|
||||
"""
|
||||
Set table, _columns_, and _key_.
|
||||
"""
|
||||
if table is not None:
|
||||
cls._tablename_ = table
|
||||
|
||||
if cls._tablename_ is not None:
|
||||
columns = {}
|
||||
for key, value in cls.__dict__.items():
|
||||
if isinstance(value, Column):
|
||||
columns[key] = value
|
||||
|
||||
cls._columns_ = columns
|
||||
if not cls._key_:
|
||||
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
|
||||
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
|
||||
if cls._cache_ is None:
|
||||
cls._cache_ = WeakValueDictionary()
|
||||
|
||||
def __new__(cls, data):
|
||||
# Registry pattern.
|
||||
# Ensure each rowid always refers to a single Model instance
|
||||
if data is not None:
|
||||
rowid = cls._id_from_data(data)
|
||||
|
||||
cache = cls._cache_
|
||||
|
||||
if (row := cache.get(rowid, None)) is not None:
|
||||
obj = row
|
||||
else:
|
||||
obj = cache[rowid] = super().__new__(cls)
|
||||
else:
|
||||
obj = super().__new__(cls)
|
||||
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def as_tuple(cls):
|
||||
return (cls.table.identifier, ())
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
@classmethod
|
||||
def bind(cls, connector: Connector):
|
||||
if cls.table is None:
|
||||
raise ValueError("Cannot bind abstract RowModel")
|
||||
cls._connector = connector
|
||||
cls.table.bind(connector)
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def attach_to(cls, registry: Registry):
|
||||
cls._registry = registry
|
||||
return cls
|
||||
|
||||
@property
|
||||
def _dict_(self):
|
||||
return {key: self.data[key] for key in self._key_}
|
||||
|
||||
@property
|
||||
def _rowid_(self):
|
||||
return tuple(self.data[key] for key in self._key_)
|
||||
|
||||
def __repr__(self):
|
||||
return "{}.{}({})".format(
|
||||
self.table.schema,
|
||||
self.table.name,
|
||||
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _id_from_data(cls, data):
|
||||
return tuple(data[key] for key in cls._key_)
|
||||
|
||||
@classmethod
|
||||
def _dict_from_id(cls, rowid):
|
||||
return dict(zip(cls._key_, rowid))
|
||||
|
||||
@classmethod
|
||||
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
|
||||
"""
|
||||
Create or retrieve Row objects for each provided data row.
|
||||
If the rows already exist in cache, updates the cached row.
|
||||
"""
|
||||
# TODO: Handle partial row data here somehow?
|
||||
rows = [cls(data_row) for data_row in data_rows]
|
||||
return rows
|
||||
|
||||
@classmethod
|
||||
def _delete_rows(cls, *data_rows):
|
||||
"""
|
||||
Remove the given rows from cache, if they exist.
|
||||
May be extended to handle object deletion.
|
||||
"""
|
||||
cache = cls._cache_
|
||||
|
||||
for data_row in data_rows:
|
||||
rowid = cls._id_from_data(data_row)
|
||||
cache.pop(rowid, None)
|
||||
|
||||
@classmethod
|
||||
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
|
||||
return await cls.table.create_row(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def fetch_where(cls: Type[RowT], *args, **kwargs):
|
||||
return cls.table.fetch_rows_where(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
|
||||
"""
|
||||
Fetch the row with the given id, retrieving from cache where possible.
|
||||
"""
|
||||
row = cls._cache_.get(rowid, None) if cached else None
|
||||
if row is None:
|
||||
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
|
||||
row = rows[0] if rows else None
|
||||
if row is None:
|
||||
cls._cache_[rowid] = cls(None)
|
||||
elif row.data is None:
|
||||
row = None
|
||||
|
||||
return row
|
||||
|
||||
@classmethod
|
||||
async def fetch_or_create(cls, *rowid, **kwargs):
|
||||
"""
|
||||
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
|
||||
"""
|
||||
if rowid:
|
||||
row = await cls.fetch(*rowid)
|
||||
else:
|
||||
rows = await cls.fetch_where(**kwargs).limit(1)
|
||||
row = rows[0] if rows else None
|
||||
|
||||
if row is None:
|
||||
creation_kwargs = kwargs
|
||||
if rowid:
|
||||
creation_kwargs.update(cls._dict_from_id(rowid))
|
||||
row = await cls.create(**creation_kwargs)
|
||||
return row
|
||||
|
||||
async def refresh(self: RowT) -> Optional[RowT]:
|
||||
"""
|
||||
Refresh this Row from data.
|
||||
|
||||
The return value may be `None` if the row was deleted.
|
||||
"""
|
||||
rows = await self.table.select_where(**self._dict_)
|
||||
if not rows:
|
||||
return None
|
||||
else:
|
||||
self.data = rows[0]
|
||||
return self
|
||||
|
||||
async def update(self: RowT, **values) -> Optional[RowT]:
|
||||
"""
|
||||
Update this Row with the given values.
|
||||
|
||||
Internally passes the provided `values` to the `update` Query.
|
||||
The return value may be `None` if the row was deleted.
|
||||
"""
|
||||
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
|
||||
if not data:
|
||||
return None
|
||||
else:
|
||||
return data[0]
|
||||
|
||||
async def delete(self: RowT) -> Optional[RowT]:
|
||||
"""
|
||||
Delete this Row.
|
||||
"""
|
||||
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
|
||||
return data[0] if data is not None else None
|
||||
Reference in New Issue
Block a user