Files
psqlmapper/models.py
2025-07-31 07:39:53 +10:00

324 lines
8.9 KiB
Python

from typing import TypeVar, Type, Optional, Generic, Union
# from typing_extensions import Self
from weakref import WeakValueDictionary
from collections.abc import MutableMapping
from psycopg.rows import DictRow
from .table import Table
from .columns import Column
from . import queries as q
from .connector import Connector
from .registry import Registry
RowT = TypeVar('RowT', bound='RowModel')
class MISSING:
__slots__ = ('oid',)
def __init__(self, oid):
self.oid = oid
class RowTable(Table, Generic[RowT]):
__slots__ = (
'model',
)
def __init__(self, name, model: Type[RowT], **kwargs):
super().__init__(name, **kwargs)
self.model = model
@property
def columns(self):
return self.model._columns_
@property
def id_col(self):
return self.model._key_
@property
def row_cache(self):
return self.model._cache_
def _many_query_adapter(self, *data):
self.model._make_rows(*data)
return data
def _single_query_adapter(self, *data):
if data:
self.model._make_rows(*data)
return data[0]
else:
return None
def _delete_query_adapter(self, *data):
self.model._delete_rows(*data)
return data
# New methods to fetch and create rows
async def create_row(self, *args, **kwargs) -> RowT:
data = await super().insert(*args, **kwargs)
return self.model._make_rows(data)[0]
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
# TODO: Handle list of rowids here?
return q.Select(
self.identifier,
row_adapter=self.model._make_rows,
connector=self.connector
).where(*args, **kwargs)
WK = TypeVar('WK')
WV = TypeVar('WV')
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
def __init__(self, ref_cache):
self.ref_cache = ref_cache
self.weak_cache = WeakValueDictionary()
def __getitem__(self, key):
value = self.weak_cache[key]
self.ref_cache[key] = value
return value
def __setitem__(self, key, value):
self.weak_cache[key] = value
self.ref_cache[key] = value
def __delitem__(self, key):
del self.weak_cache[key]
try:
del self.ref_cache[key]
except KeyError:
pass
def __contains__(self, key):
return key in self.weak_cache
def __iter__(self):
return iter(self.weak_cache)
def __len__(self):
return len(self.weak_cache)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def pop(self, key, default=None):
if key in self:
value = self[key]
del self[key]
else:
value = default
return value
# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
__slots__ = ('data',)
_schema_: str = 'public'
_tablename_: Optional[str] = None
_columns_: dict[str, Column] = {}
# Cache to keep track of registered Rows
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
_key_: tuple[str, ...] = ()
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
table: RowTable = None
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
"""
Set table, _columns_, and _key_.
"""
if table is not None:
cls._tablename_ = table
if cls._tablename_ is not None:
columns = {}
for key, value in cls.__dict__.items():
if isinstance(value, Column):
columns[key] = value
cls._columns_ = columns
if not cls._key_:
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
if cls._cache_ is None:
cls._cache_ = WeakValueDictionary()
def __new__(cls, data):
# Registry pattern.
# Ensure each rowid always refers to a single Model instance
if data is not None:
rowid = cls._id_from_data(data)
cache = cls._cache_
if (row := cache.get(rowid, None)) is not None:
obj = row
else:
obj = cache[rowid] = super().__new__(cls)
else:
obj = super().__new__(cls)
return obj
@classmethod
def as_tuple(cls):
return (cls.table.identifier, ())
def __init__(self, data):
self.data = data
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
@classmethod
def bind(cls, connector: Connector):
if cls.table is None:
raise ValueError("Cannot bind abstract RowModel")
cls._connector = connector
cls.table.bind(connector)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@property
def _dict_(self):
return {key: self.data[key] for key in self._key_}
@property
def _rowid_(self):
return tuple(self.data[key] for key in self._key_)
def __repr__(self):
return "{}.{}({})".format(
self.table.schema,
self.table.name,
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
)
@classmethod
def _id_from_data(cls, data):
return tuple(data[key] for key in cls._key_)
@classmethod
def _dict_from_id(cls, rowid):
return dict(zip(cls._key_, rowid))
@classmethod
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
"""
Create or retrieve Row objects for each provided data row.
If the rows already exist in cache, updates the cached row.
"""
# TODO: Handle partial row data here somehow?
rows = [cls(data_row) for data_row in data_rows]
return rows
@classmethod
def _delete_rows(cls, *data_rows):
"""
Remove the given rows from cache, if they exist.
May be extended to handle object deletion.
"""
cache = cls._cache_
for data_row in data_rows:
rowid = cls._id_from_data(data_row)
cache.pop(rowid, None)
@classmethod
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
return await cls.table.create_row(*args, **kwargs)
@classmethod
def fetch_where(cls: Type[RowT], *args, **kwargs):
return cls.table.fetch_rows_where(*args, **kwargs)
@classmethod
async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
"""
Fetch the row with the given id, retrieving from cache where possible.
"""
row = cls._cache_.get(rowid, None) if cached else None
if row is None:
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
row = rows[0] if rows else None
if row is None:
cls._cache_[rowid] = cls(None)
elif row.data is None:
row = None
return row
@classmethod
async def fetch_or_create(cls, *rowid, **kwargs):
"""
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
"""
if rowid:
row = await cls.fetch(*rowid)
else:
rows = await cls.fetch_where(**kwargs).limit(1)
row = rows[0] if rows else None
if row is None:
creation_kwargs = kwargs
if rowid:
creation_kwargs.update(cls._dict_from_id(rowid))
row = await cls.create(**creation_kwargs)
return row
async def refresh(self: RowT) -> Optional[RowT]:
"""
Refresh this Row from data.
The return value may be `None` if the row was deleted.
"""
rows = await self.table.select_where(**self._dict_)
if not rows:
return None
else:
self.data = rows[0]
return self
async def update(self: RowT, **values) -> Optional[RowT]:
"""
Update this Row with the given values.
Internally passes the provided `values` to the `update` Query.
The return value may be `None` if the row was deleted.
"""
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
if not data:
return None
else:
return data[0]
async def delete(self: RowT) -> Optional[RowT]:
"""
Delete this Row.
"""
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
return data[0] if data is not None else None