324 lines
8.9 KiB
Python
324 lines
8.9 KiB
Python
from typing import TypeVar, Type, Optional, Generic, Union
|
|
# from typing_extensions import Self
|
|
from weakref import WeakValueDictionary
|
|
from collections.abc import MutableMapping
|
|
|
|
from psycopg.rows import DictRow
|
|
|
|
from .table import Table
|
|
from .columns import Column
|
|
from . import queries as q
|
|
from .connector import Connector
|
|
from .registry import Registry
|
|
|
|
|
|
# Type variable bound to RowModel. Methods annotated with ``Type[RowT]`` / ``RowT``
# use it to express that they return instances of the concrete model class
# they were called on, rather than the RowModel base.
RowT = TypeVar('RowT', bound='RowModel')
|
|
|
|
|
|
class MISSING:
    """
    Minimal slotted holder for a single ``oid`` value.

    NOTE(review): nothing in this module constructs it; presumably it is a
    marker used by callers elsewhere -- confirm before relying on it.
    """
    __slots__ = ('oid',)

    def __init__(self, oid):
        """Store `oid` on the instance; no validation is performed."""
        self.oid = oid
|
|
|
|
|
|
class RowTable(Table, Generic[RowT]):
    """
    Table bound to a row model class.

    Column, key and cache information are read from the associated model,
    and the query adapters route returned data through the model so that
    row objects are created, refreshed or dropped alongside each query.
    """
    __slots__ = ('model',)

    def __init__(self, name, model: Type[RowT], **kwargs):
        """
        Parameters
        ----------
        name:
            Table name, forwarded to the parent ``Table`` initialiser.
        model:
            The row model class this table produces rows for.
        **kwargs:
            Forwarded unchanged to the parent ``Table`` initialiser.
        """
        super().__init__(name, **kwargs)
        self.model = model

    @property
    def columns(self):
        """The model's ``_columns_`` mapping."""
        return self.model._columns_

    @property
    def id_col(self):
        """The model's ``_key_`` tuple."""
        return self.model._key_

    @property
    def row_cache(self):
        """The model's ``_cache_`` mapping."""
        return self.model._cache_

    def _many_query_adapter(self, *data):
        """Pass every data row to the model's ``_make_rows``, then return the raw data."""
        self.model._make_rows(*data)
        return data

    def _single_query_adapter(self, *data):
        """
        Pass every data row to the model's ``_make_rows`` and return the first raw row.

        Returns ``None`` when no rows were provided.
        """
        if not data:
            return None
        self.model._make_rows(*data)
        return data[0]

    def _delete_query_adapter(self, *data):
        """Pass every data row to the model's ``_delete_rows``, then return the raw data."""
        self.model._delete_rows(*data)
        return data

    # New methods to fetch and create rows
    async def create_row(self, *args, **kwargs) -> RowT:
        """
        Insert a row through the parent ``insert`` and return it as a model instance.

        All arguments are forwarded to ``insert``; its result is handed to the
        model's ``_make_rows`` as a single data row.
        """
        inserted = await super().insert(*args, **kwargs)
        made = self.model._make_rows(inserted)
        return made[0]

    def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
        """
        Build a ``Select`` on this table whose rows are adapted into model instances.

        The arguments are passed to the query's ``where``; the resulting query
        object is returned (it is not executed here).
        """
        # TODO: Handle list of rowids here?
        query = q.Select(
            self.identifier,
            row_adapter=self.model._make_rows,
            connector=self.connector,
        )
        return query.where(*args, **kwargs)
|
|
|
|
|
|
# Key and value type variables for WeakCache.
WK = TypeVar('WK')
WV = TypeVar('WV')


class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
    """
    Mapping that pairs a ``WeakValueDictionary`` with a caller-supplied cache.

    Membership, iteration and length are defined by the weak dictionary alone.
    Every successful read or write also stores the value in `ref_cache`, so
    whatever retention policy that mapping implements decides how long the
    value is kept referenced.

    NOTE(review): `ref_cache` is presumably a bounded cache (LRU/TTL style)
    holding strong references -- confirm against the callers.
    """

    def __init__(self, ref_cache):
        """
        Parameters
        ----------
        ref_cache:
            Mutable mapping that receives every value read from or written
            to this cache.
        """
        self.ref_cache = ref_cache
        self.weak_cache = WeakValueDictionary()

    def __getitem__(self, key):
        """Return the value for `key`, re-registering it in `ref_cache`.

        Raises ``KeyError`` when `key` is not in the weak dictionary.
        """
        value = self.weak_cache[key]
        self.ref_cache[key] = value
        return value

    def __setitem__(self, key, value):
        """Store `value` under `key` in both the weak dictionary and `ref_cache`."""
        self.weak_cache[key] = value
        self.ref_cache[key] = value

    def __delitem__(self, key):
        """Remove `key` from both mappings.

        Raises ``KeyError`` when `key` is not in the weak dictionary; a key
        missing only from `ref_cache` is ignored.
        """
        del self.weak_cache[key]
        try:
            del self.ref_cache[key]
        except KeyError:
            pass

    def __contains__(self, key):
        return key in self.weak_cache

    def __iter__(self):
        return iter(self.weak_cache)

    def __len__(self):
        return len(self.weak_cache)

    def get(self, key, default=None):
        """Return the value for `key` (re-registering it in `ref_cache`), or `default`."""
        try:
            return self[key]
        except KeyError:
            return default

    def pop(self, key, default=None):
        """
        Remove `key` and return its value, or `default` when it is absent.

        Unlike ``MutableMapping.pop``, a missing key without an explicit
        default returns ``None`` instead of raising.
        """
        # Pop straight from the weak dictionary: reading through ``self[key]``
        # would first write the key back into `ref_cache` (possibly displacing
        # another entry there) only to delete it again, and a separate
        # membership test could race with the value being collected.
        try:
            value = self.weak_cache.pop(key)
        except KeyError:
            return default
        try:
            del self.ref_cache[key]
        except KeyError:
            pass
        return value
|
|
|
|
|
|
# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
    """
    Base class for objects wrapping a single row of table data.

    A subclass that names a table (via the ``table`` class keyword or a
    ``_tablename_`` attribute) has its ``Column`` class attributes collected
    into ``_columns_``, its ``_key_`` derived from the columns flagged
    ``primary`` (unless already set), and a ``RowTable`` created as ``table``.

    Instances are registered per row id in ``_cache_``: constructing a model
    from data whose key is already cached returns the cached instance, whose
    ``data`` is then replaced by ``__init__``.
    """
    __slots__ = ('data',)

    _schema_: str = 'public'
    _tablename_: Optional[str] = None
    _columns_: dict[str, Column] = {}

    # Cache to keep track of registered Rows
    _cache_: Union[dict, WeakValueDictionary, WeakCache] = None  # type: ignore

    _key_: tuple[str, ...] = ()
    _connector: Optional[Connector] = None
    _registry: Optional[Registry] = None

    # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
    table: RowTable = None

    def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
        """
        Set table, _columns_, and _key_.

        Nothing is configured for subclasses that end up without a table name.
        """
        if table is not None:
            cls._tablename_ = table

        if cls._tablename_ is not None:
            # Only Column attributes defined directly on this class are collected.
            columns = {
                key: value
                for key, value in cls.__dict__.items()
                if isinstance(value, Column)
            }

            cls._columns_ = columns
            if not cls._key_:
                cls._key_ = tuple(column.name for column in columns.values() if column.primary)
            cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
            if cls._cache_ is None:
                cls._cache_ = WeakValueDictionary()

    def __new__(cls, data):
        """
        Return the instance registered for the row id of `data`, creating it if needed.

        With ``data=None``, or on a model that has no cache yet (no table was
        configured), a fresh unregistered instance is returned.
        """
        # Registry pattern.
        # Ensure each rowid always refers to a single Model instance
        cache = cls._cache_
        if data is None or cache is None:
            return super().__new__(cls)

        rowid = cls._id_from_data(data)
        obj = cache.get(rowid, None)
        if obj is None:
            obj = cache[rowid] = super().__new__(cls)
        return obj

    @classmethod
    def as_tuple(cls):
        """Return ``(table identifier, ())``."""
        return (cls.table.identifier, ())

    def __init__(self, data):
        # Also runs on cached instances returned by __new__, replacing their data.
        self.data = data

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value

    @classmethod
    def bind(cls, connector: Connector):
        """
        Attach `connector` to this model and to its table, returning the class.

        Raises ``ValueError`` when the model has no table.
        """
        if cls.table is None:
            raise ValueError("Cannot bind abstract RowModel")
        cls._connector = connector
        cls.table.bind(connector)
        return cls

    @classmethod
    def attach_to(cls, registry: Registry):
        """Record `registry` on the model and return the class."""
        cls._registry = registry
        return cls

    @property
    def _dict_(self):
        """Mapping of key column name to this row's value for it."""
        return {key: self.data[key] for key in self._key_}

    @property
    def _rowid_(self):
        """Tuple of this row's key column values, in ``_key_`` order."""
        return tuple(self.data[key] for key in self._key_)

    def __repr__(self):
        return "{}.{}({})".format(
            self.table.schema,
            self.table.name,
            ', '.join(repr(column.__get__(self)) for column in self._columns_.values())
        )

    @classmethod
    def _id_from_data(cls, data):
        """Extract the row id tuple (in ``_key_`` order) from a data row."""
        return tuple(data[key] for key in cls._key_)

    @classmethod
    def _dict_from_id(cls, rowid):
        """Pair the key column names with the values of `rowid`."""
        return dict(zip(cls._key_, rowid))

    @classmethod
    def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
        """
        Create or retrieve Row objects for each provided data row.
        If the rows already exist in cache, updates the cached row.
        """
        # TODO: Handle partial row data here somehow?
        return [cls(data_row) for data_row in data_rows]

    @classmethod
    def _delete_rows(cls, *data_rows):
        """
        Remove the given rows from cache, if they exist.
        May be extended to handle object deletion.
        """
        cache = cls._cache_

        for data_row in data_rows:
            cache.pop(cls._id_from_data(data_row), None)

    @classmethod
    async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
        """Create a row through the table's ``create_row``, forwarding all arguments."""
        return await cls.table.create_row(*args, **kwargs)

    @classmethod
    def fetch_where(cls: Type[RowT], *args, **kwargs):
        """Return the table's ``fetch_rows_where`` query for the given conditions."""
        return cls.table.fetch_rows_where(*args, **kwargs)

    @classmethod
    async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
        """
        Fetch the row with the given id, retrieving from cache where possible.

        Returns ``None`` when no such row is found. With ``cached=False`` the
        cache lookup is skipped and the row is always queried.
        """
        row = cls._cache_.get(rowid, None) if cached else None
        if row is None:
            rows = await cls.fetch_where(**cls._dict_from_id(rowid))
            row = rows[0] if rows else None
            if row is None:
                # Record the miss as a data-less placeholder under this id.
                cls._cache_[rowid] = cls(None)
        elif row.data is None:
            # Cached placeholder from an earlier miss.
            row = None

        return row

    @classmethod
    async def fetch_or_create(cls, *rowid, **kwargs):
        """
        Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.

        When `rowid` is given the lookup is by id and `kwargs` are used only
        for creation; otherwise `kwargs` are used as the search conditions too.
        """
        if rowid:
            row = await cls.fetch(*rowid)
        else:
            rows = await cls.fetch_where(**kwargs).limit(1)
            row = rows[0] if rows else None

        if row is None:
            # Key values from `rowid` take precedence over same-named kwargs.
            creation_kwargs = {**kwargs, **cls._dict_from_id(rowid)}
            row = await cls.create(**creation_kwargs)
        return row

    async def refresh(self: RowT) -> Optional[RowT]:
        """
        Refresh this Row from data.

        The return value may be `None` if the row was deleted.
        """
        rows = await self.table.select_where(**self._dict_)
        if not rows:
            return None
        self.data = rows[0]
        return self

    async def update(self: RowT, **values) -> Optional[RowT]:
        """
        Update this Row with the given values.

        Internally passes the provided `values` to the `update` Query.
        The return value may be `None` if the row was deleted.
        """
        data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
        return data[0] if data else None

    async def delete(self: RowT) -> Optional[RowT]:
        """
        Delete this Row.

        Returns the first element of the query result, or ``None`` when the
        result is ``None`` or empty.
        """
        data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
        # Truthiness check: an empty result must not raise IndexError.
        return data[0] if data else None
|