rewrite: New data ORM.

This commit is contained in:
2022-11-02 07:23:51 +02:00
parent a5147323b5
commit 069c032e02
15 changed files with 1542 additions and 717 deletions

View File

@@ -1,5 +1,8 @@
from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa
from .connection import conn # noqa
from .formatters import UpdateValue, UpdateValueAdd # noqa
from .interfaces import Table, RowTable, Row, tables # noqa
from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa
from .conditions import Condition, condition
from .database import Database
from .models import RowModel, RowTable, WeakCache
from .table import Table
from .base import Expression, RawExpr
from .columns import ColumnExpr, Column, Integer, String
from .registry import Registry, AttachableClass, Attachable
from .adapted import RegisterEnum

29
bot/data/adapted.py Normal file
View File

@@ -0,0 +1,29 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from .registry import Attachable, Registry
class RegisterEnum(Attachable):
    """
    Attachable which registers a Python Enum against a Postgres enum type.

    When the owning registry initialises, the database enum called `name`
    is fetched and `enum` is registered as its Python-side adapter.
    """

    def __init__(self, enum, name: Optional[str] = None, mapper=None):
        super().__init__()
        self.enum = enum
        # Fall back to the Python enum's own name when no explicit name is given
        self.name = name if name else enum.__name__
        # Member -> database label mapping, either caller-provided or derived
        self.mapping = mapper(enum) if mapper is not None else self._mapper()

    def _mapper(self):
        # Default mapping: each member maps to the first element of its value tuple
        return {member: member.value[0] for member in self.enum}

    def attach_to(self, registry: Registry):
        """Attach to the registry and schedule enum registration at init time."""
        self._registry = registry
        registry.init_task(self.on_init)
        return self

    async def on_init(self, registry: Registry):
        """Fetch the database enum info and register the adapter on the connection."""
        conn = await registry._conn.get_connection()
        if conn is None:
            raise ValueError("Cannot Init without connection.")
        info = await EnumInfo.fetch(conn, self.name)
        if info is None:
            raise ValueError(f"Enum {self.name} not found in database.")
        register_enum(info, conn, self.enum, mapping=self.mapping)

45
bot/data/base.py Normal file
View File

@@ -0,0 +1,45 @@
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable
from itertools import chain
from psycopg import sql
@runtime_checkable
class Expression(Protocol):
    """
    Protocol for objects renderable as a parameterised SQL fragment.

    Implementations produce a psycopg Composable together with the tuple of
    bound parameter values the fragment's placeholders consume.
    """
    __slots__ = ()

    @abstractmethod
    def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
        """Return this expression as a ``(sql_fragment, parameter_values)`` pair."""
        raise NotImplementedError
class RawExpr(Expression):
    """
    Expression wrapping an explicit psycopg Composable and its parameter values.
    """
    __slots__ = ('expr', 'values')

    expr: sql.Composable
    values: tuple[Any, ...]

    def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
        self.expr = expr
        self.values = values

    def as_tuple(self):
        """Return this expression as a ``(sql_fragment, parameter_values)`` pair."""
        return (self.expr, self.values)

    @classmethod
    def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
        """
        Join a sequence of Expressions into a single RawExpr.

        Parameters
        ----------
        expressions: Expression
            Expressions to join. May be empty, yielding an empty RawExpr.
        joiner: sql.SQL
            Fragment placed between consecutive expressions.
        """
        tups = (
            expression.as_tuple()
            for expression in expressions
        )
        return cls.join_tuples(*tups, joiner=joiner)

    @classmethod
    def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
        """
        Join ``(expr, values)`` tuples into a single RawExpr.
        """
        # Guard the empty case: zip(*()) yields nothing, so the unpacking below
        # would raise ValueError. An empty join is simply an empty expression.
        if not tuples:
            return cls(sql.SQL(''), ())
        exprs, values = zip(*tuples)
        expr = joiner.join(exprs)
        value = tuple(chain(*values))
        return cls(expr, value)

142
bot/data/columns.py Normal file
View File

@@ -0,0 +1,142 @@
from typing import Any, Union, TypeVar, Generic, Type, overload, TYPE_CHECKING
from psycopg import sql
from .base import RawExpr, Expression
from .conditions import Condition, Joiner
class ColumnExpr(RawExpr):
    """
    RawExpr with rich comparison and arithmetic operators.

    Comparisons produce `Condition` objects suitable for WHERE clauses;
    arithmetic produces further `ColumnExpr` objects. The right operand may
    be another Expression or a literal (which becomes a bound placeholder).
    """
    __slots__ = ()

    def _compare(self, joiner: 'Joiner', obj) -> 'Condition':
        # Shared builder for the binary comparison operators below.
        expr, values = self.as_tuple()
        if isinstance(obj, Expression):
            # column <op> Expression
            obj_expr, obj_values = obj.as_tuple()
            return Condition(expr, joiner, obj_expr, (*values, *obj_values))
        else:
            # column <op> Literal
            return Condition(expr, joiner, sql.Placeholder(), (*values, obj))

    def __lt__(self, obj) -> Condition:
        return self._compare(Joiner.LT, obj)

    def __le__(self, obj) -> Condition:
        return self._compare(Joiner.LE, obj)

    def __eq__(self, obj) -> Condition:  # type: ignore[override]
        return Condition._expression_equality(self, obj)

    def __ne__(self, obj) -> Condition:  # type: ignore[override]
        return ~(self.__eq__(obj))

    def __gt__(self, obj) -> Condition:
        # a > b  <=>  NOT (a <= b); the Joiner pair supplies the flipped operator
        return ~(self.__le__(obj))

    def __ge__(self, obj) -> Condition:
        return ~(self.__lt__(obj))

    def _arith(self, op: str, obj) -> 'ColumnExpr':
        # Shared builder for the binary arithmetic operators below.
        # `op` is one of '+', '-', '*'; the template is built per-call so no
        # sql machinery runs at class-creation time.
        template = sql.SQL("({} " + op + " {})")
        if isinstance(obj, Expression):
            obj_expr, obj_values = obj.as_tuple()
            return ColumnExpr(
                template.format(self.expr, obj_expr),
                (*self.values, *obj_values)
            )
        else:
            return ColumnExpr(
                template.format(self.expr, sql.Placeholder()),
                (*self.values, obj)
            )

    def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
        return self._arith('+', obj)

    def __sub__(self, obj) -> 'ColumnExpr':
        return self._arith('-', obj)

    def __mul__(self, obj) -> 'ColumnExpr':
        return self._arith('*', obj)

    def CAST(self, target_type: sql.Composable):
        """Return this expression cast to `target_type` (e.g. ``sql.SQL('BIGINT')``)."""
        return ColumnExpr(
            sql.SQL("({}::{})").format(self.expr, target_type),
            self.values
        )
T = TypeVar('T')
if TYPE_CHECKING:
from .models import RowModel
class Column(ColumnExpr, Generic[T]):
    """
    Descriptor representing a single table column on a RowModel.

    Accessed on the class it returns itself (usable as an expression in
    queries); accessed on an instance it returns the value stored in the
    row's data mapping.
    """
    def __init__(self, name: "str | None" = None, primary: bool = False):
        # Whether this column is part of the table's primary key
        self.primary = primary
        # Column name; if not given here, filled in by __set_name__
        self.name: str = name  # type: ignore
        # Bare identifier until bound to an owner; __set_name__ re-qualifies it
        self.expr = sql.Identifier(name) if name else sql.SQL('')
        self.values = ()

    def __set_name__(self, owner, name):
        # Called at class creation; qualifies the column with the owner's table name.
        # NOTE(review): assumes the owner defines `_tablename_` — confirm for non-model owners.
        self.name = self.name or name
        self.expr = sql.Identifier(owner._tablename_, self.name)

    @overload
    def __get__(self: 'Column[T]', obj: None, objtype: None) -> 'Column[T]':
        ...

    @overload
    def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
        ...

    def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
        """Return the column itself on class access, or the row's value on instance access."""
        # Get value from row data or session
        if obj is None:
            return self
        else:
            return obj.data[self.name]
class Integer(Column[int]):
    """Column whose Python-side values are typed as int."""
    pass


class String(Column[str]):
    """Column whose Python-side values are typed as str."""
    pass


class Bool(Column[bool]):
    """Column whose Python-side values are typed as bool."""
    pass

View File

@@ -1,92 +1,212 @@
from .connection import _replace_char
# from meta import sharding
from typing import Any, Union
from enum import Enum
from itertools import chain
from psycopg import sql
from meta import sharding
from .base import Expression, RawExpr
class Joiner(Enum):
    """
    Binary SQL operators as ``(operator, negated_operator)`` pairs.

    `Condition.as_tuple` indexes this pair with the condition's `negated`
    flag, so negation flips the operator rather than wrapping in NOT.
    """
    EQUALS = ('=', '!=')
    IS = ('IS', 'IS NOT')
    LIKE = ('LIKE', 'NOT LIKE')
    BETWEEN = ('BETWEEN', 'NOT BETWEEN')
    IN = ('IN', 'NOT IN')
    LT = ('<', '>=')
    LE = ('<=', '>')
    NONE = ('', '')


class Condition(Expression):
    """
    Represents a single logical database condition, intended for use in WHERE statements.

    Conditions support bitwise logical operators ~, &, |, each producing another Condition.

    NOTE: ``~`` (and `_not`) negate *in place* and return the same object,
    rather than constructing a new Condition.
    """
    __slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')

    def __init__(self,
                 expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
                 values: tuple[Any, ...] = (), negated=False
                 ):
        self.expr1 = expr1
        self.joiner = joiner
        self.negated = negated
        self.expr2 = expr2
        self.values = values

    def as_tuple(self):
        """Render as ``(composed_sql, values)``, applying negation through the Joiner pair."""
        expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
        if self.negated and self.joiner is Joiner.NONE:
            # No operator available to flip, so negate the whole expression instead
            expr = sql.SQL("NOT ({})").format(expr)
        return (expr, self.values)

    @classmethod
    def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
        """
        Construct a Condition from a sequence of Conditions,
        together with some explicit column conditions.
        """
        # TODO: Consider adding a _table identifier here so we can identify implicit columns
        # Or just require subquery type conditions to always come from modelled tables.
        implicit_conditions = (
            cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
        )
        return cls._and(*conditions, *implicit_conditions)

    @classmethod
    def _and(cls, *conditions: 'Condition'):
        """Combine one or more Conditions with AND; a single Condition is returned unchanged."""
        if not len(conditions):
            raise ValueError("Cannot combine 0 Conditions")
        if len(conditions) == 1:
            return conditions[0]
        exprs, values = zip(*(condition.as_tuple() for condition in conditions))
        cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
        cond_values = tuple(chain(*values))
        return Condition(cond_expr, values=cond_values)

    @classmethod
    def _or(cls, *conditions: 'Condition'):
        """Combine one or more Conditions with OR; a single Condition is returned unchanged."""
        if not len(conditions):
            raise ValueError("Cannot combine 0 Conditions")
        if len(conditions) == 1:
            return conditions[0]
        exprs, values = zip(*(condition.as_tuple() for condition in conditions))
        cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
        cond_values = tuple(chain(*values))
        return Condition(cond_expr, values=cond_values)

    @classmethod
    def _not(cls, condition: 'Condition'):
        """Negate the given Condition in place and return it."""
        condition.negated = not condition.negated
        return condition

    @classmethod
    def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
        """
        Build an equality-flavoured Condition between a column expression and a value.

        Expressions compare with ``=``, tuples/lists become ``IN (...)``,
        None becomes ``IS NULL``, and anything else binds as a placeholder.
        """
        # TODO: Check if this supports subqueries
        col_expr, col_values = column.as_tuple()
        # TODO: Also support sql.SQL? For joins?
        if isinstance(value, Expression):
            # column = Expression
            value_expr, value_values = value.as_tuple()
            cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
            cond_values = (*col_values, *value_values)
        elif isinstance(value, (tuple, list)):
            # column in (...)
            # TODO: Support expressions in value tuple?
            if not value:
                raise ValueError("Cannot create Condition from empty iterable!")
            value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
            cond_exprs = (col_expr, Joiner.IN, value_expr)
            cond_values = (*col_values, *value)
        elif value is None:
            # column IS NULL
            cond_exprs = (col_expr, Joiner.IS, sql.NULL)
            cond_values = col_values
        else:
            # column = Literal
            cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
            cond_values = (*col_values, value)
        return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)

    def __invert__(self) -> 'Condition':
        # In-place negation; see class docstring.
        self.negated = not self.negated
        return self

    def __and__(self, condition: 'Condition') -> 'Condition':
        return self._and(self, condition)

    def __or__(self, condition: 'Condition') -> 'Condition':
        return self._or(self, condition)
class GEQ(Condition):
__slots__ = ('value',)
def __init__(self, value):
self.value = value
def apply(self, key, values, conditions):
item = self.value
if isinstance(item, (list, tuple)):
raise ValueError("Cannot apply GEQ condition to a list!")
else:
conditions.append("{} >= {}".format(key, _replace_char))
values.append(item)
# Helper function to simplify condition construction
def condition(*args, **kwargs) -> Condition:
    """Convenience wrapper around `Condition.construct`."""
    return Condition.construct(*args, **kwargs)
class LEQ(Condition):
__slots__ = ('value',)
def __init__(self, value):
self.value = value
def apply(self, key, values, conditions):
item = self.value
if isinstance(item, (list, tuple)):
raise ValueError("Cannot apply LEQ condition to a list!")
else:
conditions.append("{} <= {}".format(key, _replace_char))
values.append(item)
class Constant(Condition):
__slots__ = ('value',)
def __init__(self, value):
self.value = value
def apply(self, key, values, conditions):
conditions.append("{} {}".format(key, self.value))
class SHARDID(Condition):
__slots__ = ('shardid', 'shard_count')
def __init__(self, shardid, shard_count):
self.shardid = shardid
self.shard_count = shard_count
def apply(self, key, values, conditions):
if self.shard_count > 1:
conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
values.append(self.shardid)
THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
NULL = Constant('IS NULL')
NOTNULL = Constant('IS NOT NULL')
# class NOT(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# if item:
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
# values.extend(item)
# else:
# raise ValueError("Cannot check an empty iterable!")
# else:
# conditions.append("{}!={}".format(key, _replace_char))
# values.append(item)
#
#
# class GEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply GEQ condition to a list!")
# else:
# conditions.append("{} >= {}".format(key, _replace_char))
# values.append(item)
#
#
# class LEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply LEQ condition to a list!")
# else:
# conditions.append("{} <= {}".format(key, _replace_char))
# values.append(item)
#
#
# class Constant(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# conditions.append("{} {}".format(key, self.value))
#
#
# class SHARDID(Condition):
# __slots__ = ('shardid', 'shard_count')
#
# def __init__(self, shardid, shard_count):
# self.shardid = shardid
# self.shard_count = shard_count
#
# def apply(self, key, values, conditions):
# if self.shard_count > 1:
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
# values.append(self.shardid)
#
#
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
#
#
# NULL = Constant('IS NULL')
# NOTNULL = Constant('IS NOT NULL')

View File

@@ -1,47 +0,0 @@
import logging
import psycopg2 as psy
from meta import log, conf
from constants import DATA_VERSION
from .cursor import DictLoggingCursor
# Set up database connection
log("Establishing connection.", "DB_INIT", level=logging.DEBUG)
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)
# Replace char used by the connection for query formatting
_replace_char: str = '%s'
# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG))
# sq.register_adapter(datetime, lambda dt: dt.timestamp())
# Check the version matches the required version
with conn:
log("Checking db version.", "DB_INIT")
cursor = conn.cursor()
# Get last entry in version table, compare against desired version
cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
current_version, _, _ = cursor.fetchone()
if current_version != DATA_VERSION:
# Complain
raise Exception(
("Database version is {}, required version is {}. "
"Please migrate database.").format(current_version, DATA_VERSION)
)
cursor.close()
log("Established connection.", "DB_INIT")
def reset_connection():
log("Re-establishing connection.", "DB_INIT", level=logging.DEBUG)
global conn
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)
log("Re-established connection.", "DB_INIT")

58
bot/data/connector.py Normal file
View File

@@ -0,0 +1,58 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable
import logging
import psycopg as psq
from .cursor import AsyncLoggingCursor
logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
class Connector:
    """
    Owns the application's asynchronous psycopg connection and runs
    registered hooks whenever a (re)connection is established.
    """
    # Cursor class used for every cursor created on this connection
    cursor_factory = AsyncLoggingCursor

    def __init__(self, conn_args):
        # Arguments forwarded verbatim to psycopg's AsyncConnection.connect
        self._conn_args = conn_args
        # Active connection; None until `connect` is first awaited
        self.conn: "psq.AsyncConnection | None" = None
        # Coroutines to run on each connect and reconnect
        self.conn_hooks = []

    async def get_connection(self) -> psq.AsyncConnection:
        """
        Get the current active connection.
        This should never be cached outside of a transaction.

        Raises ValueError if `connect` has not yet been called.
        """
        # TODO: Reconnection logic?
        if not self.conn:
            raise ValueError("Attempting to get connection before initialisation!")
        return self.conn

    async def connect(self) -> psq.AsyncConnection:
        """Open a new autocommit connection, then run the registered connect hooks in order."""
        logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
        self.conn = await psq.AsyncConnection.connect(
            self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
        )
        for hook in self.conn_hooks:
            await hook(self.conn)
        return self.conn

    async def reconnect(self) -> psq.AsyncConnection:
        """Re-establish the connection (delegates to `connect`, re-running all hooks)."""
        return await self.connect()

    def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
        """
        Minimal decorator to register a coroutine to run on connect or reconnect.
        Note that these are only run on connect and reconnect.
        If a hook is registered after connection, it will not be run.
        """
        self.conn_hooks.append(coro)
        return coro
@runtime_checkable
class Connectable(Protocol):
    """Protocol for objects that can be bound to a `Connector`."""

    def bind(self, connector: Connector):
        """Attach this object to the given connector."""
        raise NotImplementedError

View File

@@ -1,30 +1,42 @@
import logging
from psycopg2.extras import DictCursor, _ext
from typing import Optional
from meta import log
from psycopg import AsyncCursor, sql
from psycopg.abc import Query, Params
from psycopg._encodings import pgconn_encoding
logger = logging.getLogger(__name__)
class AsyncLoggingCursor(AsyncCursor):
    """
    AsyncCursor subclass that logs queries at DEBUG level and logs any
    exception raised during execution.
    """

    def mogrify_query(self, query: Query):
        """Render a Query (str, SQL/Composed, or bytes) into a readable string for logging."""
        if isinstance(query, str):
            msg = query
        elif isinstance(query, (sql.SQL, sql.Composed)):
            msg = query.as_string(self)
        elif isinstance(query, bytes):
            msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
        else:
            msg = repr(query)
        return msg

    async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
        """
        Execute the query, logging it beforehand (when DEBUG is enabled)
        and logging any exception raised during execution.
        """
        if logging.DEBUG >= logger.getEffectiveLevel():
            msg = self.mogrify_query(query)
            logger.debug(
                "Executing query (%s) with values %s", msg, params,
                extra={'action': "Query Execute"}
            )
        try:
            return await super().execute(query, params=params, **kwargs)
        except Exception:
            msg = self.mogrify_query(query)
            logger.exception(
                "Exception during query execution. Query (%s) with parameters %s.",
                msg, params,
                extra={'action': "Query Execute"},
                stack_info=True
            )
            # Re-raise so callers still observe the failure; logging must not
            # silently swallow database errors.
            raise
        # TODO: Possibly log execution time

38
bot/data/database.py Normal file
View File

@@ -0,0 +1,38 @@
import logging
from collections import namedtuple
# from .cursor import AsyncLoggingCursor
from .registry import Registry
from .connector import Connector
logger = logging.getLogger(__name__)
Version = namedtuple('Version', ('version', 'time', 'author'))
class Database(Connector):
    """
    Connector that owns the application's Registries and exposes the
    current schema version.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Loaded registries, keyed by registry name
        self.registries: dict[str, Registry] = {}

    def load_registry(self, registry: Registry):
        """Bind the registry to this database and track it by name."""
        registry.bind(self)
        self.registries[registry.name] = registry

    async def version(self) -> Version:
        """
        Return the current schema version as a Version namedtuple.
        """
        async with self.conn.cursor() as cursor:
            # Latest row of the version table is the active schema version
            await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
            record = await cursor.fetchone()
            if not record:
                # No versions recorded in the database yet
                return Version(-1, None, None)
            return Version(record['version'], record['time'], record['author'])

View File

@@ -1,115 +0,0 @@
from .connection import _replace_char
from .conditions import Condition
class _updateField:
__slots__ = ()
_EMPTY = object() # Return value for `value` indicating no value should be added
def key_field(self, key):
raise NotImplementedError
def value_field(self, key):
raise NotImplementedError
class UpdateValue(_updateField):
__slots__ = ('key_str', 'value')
def __init__(self, key_str, value=_updateField._EMPTY):
self.key_str = key_str
self.value = value
def key_field(self, key):
return self.key_str.format(key=key, value=_replace_char, replace=_replace_char)
def value_field(self, key):
return self.value
class UpdateValueAdd(_updateField):
__slots__ = ('value',)
def __init__(self, value):
self.value = value
def key_field(self, key):
return "{key} = {key} + {replace}".format(key=key, replace=_replace_char)
def value_field(self, key):
return self.value
def _format_conditions(conditions):
"""
Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.
Supports `IN` type conditionals.
"""
if not conditions:
return ("", tuple())
values = []
conditional_strings = []
for key, item in conditions.items():
if isinstance(item, (list, tuple)):
conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item))))
values.extend(item)
elif isinstance(item, Condition):
item.apply(key, values, conditional_strings)
else:
conditional_strings.append("{}={}".format(key, _replace_char))
values.append(item)
return (' AND '.join(conditional_strings), values)
def _format_selectkeys(keys):
"""
Formats a list of keys into a string suitable for `SELECT`.
"""
if not keys:
return "*"
elif type(keys) is str:
return keys
else:
return ", ".join(keys)
def _format_insertkeys(keys):
"""
Formats a list of keys into a string suitable for `INSERT`
"""
if not keys:
return ""
else:
return "({})".format(", ".join(keys))
def _format_insertvalues(values):
"""
Formats a list of values into a string suitable for `INSERT`
"""
value_str = "({})".format(", ".join(_replace_char for value in values))
return (value_str, values)
def _format_updatestr(valuedict):
"""
Formats a dictionary of keys and values into a string suitable for 'SET' clauses.
"""
if not valuedict:
return ("", tuple())
key_fields = []
values = []
for key, value in valuedict.items():
if isinstance(value, _updateField):
key_fields.append(value.key_field(key))
v = value.value_field(key)
if v is not _updateField._EMPTY:
values.append(value.value_field(key))
else:
key_fields.append("{} = {}".format(key, _replace_char))
values.append(value)
return (', '.join(key_fields), values)

View File

@@ -1,315 +0,0 @@
from __future__ import annotations
import logging
import traceback
import contextlib
from cachetools import LRUCache
from typing import Mapping
import psycopg2
import asyncio
from meta import log, client
from utils.lib import DotDict
from .connection import conn
from .queries import insert, insert_many, select_where, update_where, upsert, delete_where, update_many
# Global cache of interfaces
tables: Mapping[str, Table] = DotDict()
def _connection_guard(func):
"""
Query decorator that performs a client shutdown when the database isn't responding.
"""
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except (psycopg2.OperationalError, psycopg2.InterfaceError):
log("Critical error performing database query. Shutting down. "
"Exception traceback follows.\n{}".format(
traceback.format_exc()
),
context="DATABASE_QUERY",
level=logging.ERROR)
asyncio.create_task(client.close())
raise Exception("Critical error, database connection closed. Restarting client.")
return wrapper
class Table:
"""
Transparent interface to a single table structure in the database.
Contains standard methods to access the table.
Intended to be subclassed to provide more derivative access for specific tables.
"""
conn = conn
def __init__(self, name, attach_as=None):
self.name = name
self.queries = DotDict()
tables[attach_as or name] = self
@_connection_guard
def select_where(self, *args, **kwargs):
with self.conn:
return select_where(self.name, *args, **kwargs)
def select_one_where(self, *args, **kwargs):
rows = self.select_where(*args, **kwargs)
return rows[0] if rows else None
@_connection_guard
def update_where(self, *args, **kwargs):
with self.conn:
return update_where(self.name, *args, **kwargs)
@_connection_guard
def delete_where(self, *args, **kwargs):
with self.conn:
return delete_where(self.name, *args, **kwargs)
@_connection_guard
def insert(self, *args, **kwargs):
with self.conn:
return insert(self.name, *args, **kwargs)
@_connection_guard
def insert_many(self, *args, **kwargs):
with self.conn:
return insert_many(self.name, *args, **kwargs)
@_connection_guard
def update_many(self, *args, **kwargs):
with self.conn:
return update_many(self.name, *args, **kwargs)
@_connection_guard
def upsert(self, *args, **kwargs):
with self.conn:
return upsert(self.name, *args, **kwargs)
def save_query(self, func):
"""
Decorator to add a saved query to the table.
"""
self.queries[func.__name__] = func
return func
class Row:
__slots__ = ('table', 'data', '_pending')
conn = conn
def __init__(self, table, data, *args, **kwargs):
super().__setattr__('table', table)
self.data = data
self._pending = None
@property
def rowid(self):
return self.table.id_from_row(self.data)
def __repr__(self):
return "Row[{}]({})".format(
self.table.name,
', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns)
)
def __getattr__(self, key):
if key in self.table.columns:
if self._pending and key in self._pending:
return self._pending[key]
else:
return self.data[key]
else:
raise AttributeError(key)
def __setattr__(self, key, value):
if key in self.table.columns:
if self._pending is None:
self.update(**{key: value})
else:
self._pending[key] = value
else:
super().__setattr__(key, value)
@contextlib.contextmanager
def batch_update(self):
if self._pending:
raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__))
self._pending = {}
try:
yield self._pending
finally:
if self._pending:
self.update(**self._pending)
self._pending = None
def _refresh(self):
row = self.table.select_one_where(**self.table.dict_from_id(self.rowid))
if not row:
raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__))
self.data = row
def update(self, **values):
rows = self.table.update_where(values, **self.table.dict_from_id(self.rowid))
self.data = rows[0]
@classmethod
def _select_where(cls, _extra=None, **conditions):
return select_where(cls._table, **conditions)
@classmethod
def _insert(cls, **values):
return insert(cls._table, **values)
@classmethod
def _update_where(cls, values, **conditions):
return update_where(cls._table, values, **conditions)
class RowTable(Table):
__slots__ = (
'name',
'columns',
'id_col',
'multi_key',
'row_cache'
)
conn = conn
def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000, **kwargs):
super().__init__(name, **kwargs)
self.name = name
self.columns = columns
self.id_col = id_col
self.multi_key = isinstance(id_col, tuple)
self.row_cache = (cache if cache is not None else LRUCache(cache_size)) if use_cache else None
def id_from_row(self, row):
if self.multi_key:
return tuple(row[key] for key in self.id_col)
else:
return row[self.id_col]
def dict_from_id(self, rowid):
if self.multi_key:
return dict(zip(self.id_col, rowid))
else:
return {self.id_col: rowid}
# Extend original Table update methods to modify the cached rows
def insert(self, *args, **kwargs):
data = super().insert(*args, **kwargs)
if self.row_cache is not None:
self.row_cache[self.id_from_row(data)] = Row(self, data)
return data
def insert_many(self, *args, **kwargs):
data = super().insert_many(*args, **kwargs)
if self.row_cache is not None:
for data_row in data:
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
if cached_row is not None:
cached_row.data = data_row
return data
def update_where(self, *args, **kwargs):
data = super().update_where(*args, **kwargs)
if self.row_cache is not None:
for data_row in data:
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
if cached_row is not None:
cached_row.data = data_row
return data
def update_many(self, *args, **kwargs):
data = super().update_many(*args, **kwargs)
if self.row_cache is not None:
for data_row in data:
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
if cached_row is not None:
cached_row.data = data_row
return data
def delete_where(self, *args, **kwargs):
data = super().delete_where(*args, **kwargs)
if self.row_cache is not None:
for data_row in data:
self.row_cache.pop(self.id_from_row(data_row), None)
return data
def upsert(self, *args, **kwargs):
data = super().upsert(*args, **kwargs)
if self.row_cache is not None:
rowid = self.id_from_row(data)
cached_row = self.row_cache.get(rowid, None)
if cached_row is not None:
cached_row.data = data
else:
self.row_cache[rowid] = Row(self, data)
return data
# New methods to fetch and create rows
def _make_rows(self, *data_rows):
"""
Create or retrieve Row objects for each provided data row.
If the rows already exist in cache, updates the cached row.
"""
if self.row_cache is not None:
rows = []
for data_row in data_rows:
rowid = self.id_from_row(data_row)
cached_row = self.row_cache.get(rowid, None)
if cached_row is not None:
cached_row.data = data_row
row = cached_row
else:
row = Row(self, data_row)
self.row_cache[rowid] = row
rows.append(row)
else:
rows = [Row(self, data_row) for data_row in data_rows]
return rows
def create_row(self, *args, **kwargs):
data = self.insert(*args, **kwargs)
return self._make_rows(data)[0]
def fetch_rows_where(self, *args, **kwargs):
# TODO: Handle list of rowids here?
data = self.select_where(*args, **kwargs)
return self._make_rows(*data)
def fetch(self, rowid):
"""
Fetch the row with the given id, retrieving from cache where possible.
"""
row = self.row_cache.get(rowid, None) if self.row_cache is not None else None
if row is None:
rows = self.fetch_rows_where(**self.dict_from_id(rowid))
row = rows[0] if rows else None
return row
def fetch_or_create(self, rowid=None, **kwargs):
"""
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
"""
if rowid is not None:
row = self.fetch(rowid)
else:
data = self.select_where(**kwargs)
row = self._make_rows(data[0])[0] if data else None
if row is None:
creation_kwargs = kwargs
if rowid is not None:
creation_kwargs.update(self.dict_from_id(rowid))
row = self.create_row(**creation_kwargs)
return row

296
bot/data/models.py Normal file
View File

@@ -0,0 +1,296 @@
from typing import TypeVar, Type, Any, Optional, Generic, Union, Mapping
# from typing_extensions import Self
from weakref import WeakValueDictionary
from collections.abc import MutableMapping
from psycopg.rows import DictRow
from .table import Table
from .columns import Column
from . import queries as q
from .connector import Connector
from .registry import Registry
RowT = TypeVar('RowT', bound='RowModel')
class MISSING:
    """Sentinel marking a value that is not present, tagged with an identifier."""
    __slots__ = ('oid',)

    def __init__(self, oid):
        # Identifier carried by this sentinel — presumably a Postgres type oid; TODO confirm
        self.oid = oid
class RowTable(Table, Generic[RowT]):
    """
    Table bound to a RowModel class, keeping the model's row cache in sync
    with query results via the adapter hooks.
    """
    __slots__ = (
        'model',
    )

    def __init__(self, name, model: Type[RowT], **kwargs):
        super().__init__(name, **kwargs)
        # RowModel subclass whose instances represent rows of this table
        self.model = model

    @property
    def columns(self):
        # Column descriptors live on the model class
        return self.model._columns_

    @property
    def id_col(self):
        # Primary key column name(s), as declared on the model
        return self.model._key_

    @property
    def row_cache(self):
        return self.model._cache_

    def _many_query_adapter(self, *data):
        # Registers/refreshes cached model rows, but returns the raw data unchanged
        self.model._make_rows(*data)
        return data

    def _single_query_adapter(self, *data):
        # NOTE(review): indexes data[0] — presumably callers guarantee at least
        # one result row here; confirm against the query machinery.
        self.model._make_rows(*data)
        return data[0]

    def _delete_query_adapter(self, *data):
        # Evicts the deleted rows from the model cache; returns raw data unchanged
        self.model._delete_rows(*data)
        return data

    # New methods to fetch and create rows

    async def create_row(self, *args, **kwargs) -> RowT:
        """Insert a new row and return the corresponding model instance."""
        data = await super().insert(*args, **kwargs)
        return self.model._make_rows(data)[0]

    def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
        """Build a Select query whose results are adapted into model instances."""
        # TODO: Handle list of rowids here?
        return q.Select(
            self.name,
            row_adapter=self.model._make_rows,
            connector=self.connector
        ).where(*args, **kwargs)
class WeakCache(MutableMapping):
    """
    Mapping combining a weak-valued cache with a strong reference cache.

    Every read and write goes through `weak_cache` (a WeakValueDictionary)
    and also refreshes the entry in `ref_cache`, so recently used values
    stay strongly referenced while unused ones may be garbage collected.
    """

    def __init__(self, ref_cache):
        self.ref_cache = ref_cache
        self.weak_cache = WeakValueDictionary()

    def __getitem__(self, key):
        # Raises KeyError when absent; touch the strong cache on every hit.
        item = self.weak_cache[key]
        self.ref_cache[key] = item
        return item

    def __setitem__(self, key, item):
        self.weak_cache[key] = item
        self.ref_cache[key] = item

    def __delitem__(self, key):
        del self.weak_cache[key]
        try:
            del self.ref_cache[key]
        except KeyError:
            # The strong cache may have evicted this key already.
            pass

    def __contains__(self, key):
        return key in self.weak_cache

    def __iter__(self):
        return iter(self.weak_cache)

    def __len__(self):
        return len(self.weak_cache)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def pop(self, key, default=None):
        try:
            item = self[key]
        except KeyError:
            return default
        del self[key]
        return item
# TODO: Implement getitem and setitem, for dynamic column access
# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
    """
    Base class for table row models.

    Subclasses declare Column descriptors and a table name; instances wrap a
    single data row and are deduplicated per rowid through a class-level cache.
    """
    __slots__ = ('data',)

    # Table name for the concrete model; None marks an abstract model.
    _tablename_: Optional[str] = None
    # Column descriptors discovered from the subclass body.
    _columns_: dict[str, Column] = {}
    # Cache to keep track of registered Rows
    _cache_: Union[dict, WeakValueDictionary, WeakCache] = None  # type: ignore
    # Primary key column names, in declaration order.
    _key_: tuple[str, ...]
    _connector: Optional[Connector] = None
    _registry: Optional[Registry] = None

    # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
    table: RowTable

    def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
        """
        Set table, _columns_, and _key_.
        """
        if table is not None:
            cls._tablename_ = table
        if cls._tablename_ is not None:
            # Collect Column descriptors declared directly on this subclass.
            columns = {}
            for key, value in cls.__dict__.items():
                if isinstance(value, Column):
                    columns[key] = value
            cls._columns_ = columns
            cls._key_ = tuple(column.name for column in columns.values() if column.primary)
            cls.table = RowTable(cls._tablename_, cls)
            # Give each concrete model its own cache unless one was provided.
            if cls._cache_ is None:
                cls._cache_ = WeakValueDictionary()

    def __new__(cls, data):
        # Registry pattern.
        # Ensure each rowid always refers to a single Model instance
        if data is not None:
            rowid = cls._id_from_data(data)
            cache = cls._cache_
            if (row := cache.get(rowid, None)) is not None:
                obj = row
            else:
                obj = cache[rowid] = super().__new__(cls)
        else:
            # data=None creates an uncached instance (used as a negative-cache marker).
            obj = super().__new__(cls)
        return obj

    def __init__(self, data):
        # NOTE: __init__ also runs on cache hits, refreshing the stored data.
        self.data = data

    @classmethod
    def bind(cls, connector: Connector):
        """Bind this model (and its table) to a database connector."""
        if cls.table is None:
            raise ValueError("Cannot bind abstract RowModel")
        cls._connector = connector
        cls.table.bind(connector)
        return cls

    @classmethod
    def attach_to(cls, registry: Registry):
        """Record the registry this model belongs to."""
        cls._registry = registry
        return cls

    @property
    def _dict_(self):
        # Primary key columns as a name -> value mapping.
        return {key: self.data[key] for key in self._key_}

    @property
    def _rowid_(self):
        # Primary key values as a tuple, usable as a cache key.
        return tuple(self.data[key] for key in self._key_)

    def __repr__(self):
        return "{}({})".format(
            self.table.name,
            ', '.join(repr(column.__get__(self)) for column in self._columns_.values())
        )

    @classmethod
    def _id_from_data(cls, data):
        # Extract the rowid tuple from a raw data row.
        return tuple(data[key] for key in cls._key_)

    @classmethod
    def _dict_from_id(cls, rowid):
        # Expand a rowid tuple into key-column keyword arguments.
        return dict(zip(cls._key_, rowid))

    @classmethod
    def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
        """
        Create or retrieve Row objects for each provided data row.
        If the rows already exist in cache, updates the cached row.
        """
        # TODO: Handle partial row data here somehow?
        rows = [cls(data_row) for data_row in data_rows]
        return rows

    @classmethod
    def _delete_rows(cls, *data_rows):
        """
        Remove the given rows from cache, if they exist.
        May be extended to handle object deletion.
        """
        cache = cls._cache_
        for data_row in data_rows:
            rowid = cls._id_from_data(data_row)
            cache.pop(rowid, None)

    @classmethod
    async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
        """Insert a new row and return the corresponding model instance."""
        return await cls.table.create_row(*args, **kwargs)

    @classmethod
    def fetch_where(cls: Type[RowT], *args, **kwargs):
        """Build a SELECT query returning model instances for matching rows."""
        return cls.table.fetch_rows_where(*args, **kwargs)

    @classmethod
    async def fetch(cls: Type[RowT], *rowid) -> Optional[RowT]:
        """
        Fetch the row with the given id, retrieving from cache where possible.
        """
        row = cls._cache_.get(rowid, None)
        if row is None:
            rows = await cls.fetch_where(**cls._dict_from_id(rowid))
            row = rows[0] if rows else None
            if row is None:
                # Cache a data=None marker so repeated misses skip the query.
                cls._cache_[rowid] = cls(None)
        elif row.data is None:
            # Cached negative result: report the row as absent.
            row = None
        return row

    @classmethod
    async def fetch_or_create(cls, *rowid, **kwargs):
        """
        Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
        """
        if rowid:
            row = await cls.fetch(*rowid)
        else:
            rows = await cls.fetch_where(**kwargs).limit(1)
            row = rows[0] if rows else None
        if row is None:
            # Fold the id columns into the creation fields when an id was given.
            creation_kwargs = kwargs
            if rowid:
                creation_kwargs.update(cls._dict_from_id(rowid))
            row = await cls.create(**creation_kwargs)
        return row

    async def refresh(self: RowT) -> Optional[RowT]:
        """
        Refresh this Row from data.
        The return value may be `None` if the row was deleted.
        """
        rows = await self.table.select_where(**self._dict_)
        if not rows:
            return None
        else:
            self.data = rows[0]
            return self

    async def update(self: RowT, **values) -> Optional[RowT]:
        """
        Update this Row with the given values.
        Internally passes the provided `values` to the `update` Query.
        The return value may be `None` if the row was deleted.
        """
        data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
        if not data:
            return None
        else:
            return data[0]

View File

@@ -1,150 +1,521 @@
from itertools import chain
from psycopg2.extras import execute_values
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
from psycopg import AsyncConnection, AsyncCursor
from psycopg import sql
from psycopg.rows import DictRow
from .connection import conn
from .formatters import (_format_updatestr, _format_conditions, _format_insertkeys,
_format_selectkeys, _format_insertvalues)
import logging
from .conditions import Condition
from .base import Expression, RawExpr
from .connector import Connector
def select_where(table, select_columns=None, cursor=None, _extra='', **conditions):
logger = logging.getLogger(__name__)
TQueryT = TypeVar('TQueryT', bound='TableQuery')
SQueryT = TypeVar('SQueryT', bound='Select')
QueryResult = TypeVar('QueryResult')
class Query(Generic[QueryResult]):
"""
Select rows from the given table matching the conditions
ABC for an executable query statement.
"""
criteria, criteria_values = _format_conditions(conditions)
col_str = _format_selectkeys(select_columns)
__slots__ = ('conn', 'cursor', '_adapter', 'connector')
if criteria:
where_str = "WHERE {}".format(criteria)
_adapter: Callable[..., QueryResult]
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
self.connector: Optional[Connector] = connector
self.conn: Optional[AsyncConnection] = conn
self.cursor: Optional[AsyncCursor] = cursor
if row_adapter is not None:
self._adapter = row_adapter
else:
where_str = ""
self._adapter = self._no_adapter
cursor = cursor or conn.cursor()
cursor.execute(
'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra),
criteria_values
)
return cursor.fetchall()
def bind(self, connector: Connector):
self.connector = connector
return self
def with_cursor(self, cursor: AsyncCursor):
self.cursor = cursor
return self
def update_where(table, valuedict, cursor=None, **conditions):
def with_connection(self, conn: AsyncConnection):
self.conn = conn
return self
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def with_adapter(self, callable: Callable[..., QueryResult]):
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
# For this to work cleanly, callable should have arg type of QR1, not any
self._adapter = callable
return self
def one(self):
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
return self
def build(self) -> Expression:
raise NotImplementedError
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
query, values = self.build().as_tuple()
# TODO: Move logging out to a custom cursor
# logger.debug(
# f"Executing query ({query.as_string(cursor)}) with values {values}",
# extra={'action': "Query"}
# )
await cursor.execute(sql.Composed((query,)), values)
data = await cursor.fetchall()
return self._adapter(*data)
async def execute(self, cursor=None) -> QueryResult:
"""
Update rows in the given table matching the conditions
Execute the query, optionally with the provided cursor, and return the result rows.
If no cursor is provided, and no cursor has been set with `with_cursor`,
the execution will create a new cursor from the connection and close it automatically.
"""
key_str, key_values = _format_updatestr(valuedict)
criteria, criteria_values = _format_conditions(conditions)
if criteria:
where_str = "WHERE {}".format(criteria)
# Create a cursor if possible
cursor = cursor if cursor is not None else self.cursor
if self.cursor is None:
if self.conn is None:
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
where_str = ""
cursor = cursor or conn.cursor()
cursor.execute(
'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str),
tuple((*key_values, *criteria_values))
)
return cursor.fetchall()
def delete_where(table, cursor=None, **conditions):
"""
Delete rows in the given table matching the conditions
"""
criteria, criteria_values = _format_conditions(conditions)
if criteria:
where_str = "WHERE {}".format(criteria)
conn = await self.connector.get_connection()
else:
where_str = ""
conn = self.conn
cursor = cursor or conn.cursor()
cursor.execute(
'DELETE FROM {} {} RETURNING *'.format(table, where_str),
criteria_values
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data
def __await__(self):
return self.execute().__await__()
class TableQuery(Query[QueryResult]):
"""
ABC for an executable query statement expected to be run on a single table.
"""
__slots__ = (
'table',
'condition', '_extra', '_limit', '_order', '_joins'
)
return cursor.fetchall()
def __init__(self, table, *args, **kwargs):
super().__init__(*args, **kwargs)
self.table: str = table
def insert(table, cursor=None, allow_replace=False, **values):
def options(self, **kwargs):
"""
Insert the given values into the table
Set some query options.
Default implementation does nothing.
Should be overridden to provide specific options.
"""
keys, values = zip(*values.items())
return self
key_str = _format_insertkeys(keys)
value_str, values = _format_insertvalues(values)
action = 'REPLACE' if allow_replace else 'INSERT'
class WhereMixin(TableQuery[QueryResult]):
    """
    Mixin adding WHERE-clause support to a table query.
    """
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accumulated condition; None means no WHERE clause.
        self.condition: Optional[Condition] = None

    def where(self, *args: Condition, **kwargs):
        """
        Add a Condition to the query.
        Positional arguments should be Conditions,
        and keyword arguments should be of the form `column=Value`,
        where Value may be a Value-type or a literal value.
        All provided Conditions will be and-ed together to create a new Condition.
        """
        # TODO: Maybe just pass this verbatim to a condition.
        if args or kwargs:
            condition = Condition.construct(*args, **kwargs)
            if self.condition is not None:
                condition = self.condition & condition
            self.condition = condition
        return self

    @property
    def _where_section(self) -> Optional[Expression]:
        # Render as a WHERE fragment, or None when no condition was set.
        if self.condition is not None:
            return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
        else:
            return None
class JoinMixin(TableQuery[QueryResult]):
    """
    Mixin adding JOIN-clause support to a table query.
    """
    __slots__ = ()
    # TODO: Remember to add join slots to TableQuery

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # JOIN fragments, in the order they were added.
        self._joins: list[Expression] = []

    def join(self,
             target: Union[str, Expression],
             on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
             natural=False):
        """
        Add a JOIN clause to the query.

        Exactly one of `on`, `using`, or `natural` must be provided.
        `target` may be a table name or an arbitrary Expression.
        """
        # Exactly one join specification must be supplied.
        available = (on is not None) + (using is not None) + natural
        if available == 0:
            raise ValueError("No conditions given for Query Join")
        if available > 1:
            raise ValueError("Exactly one join format must be given for Query Join")
        sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(sql.SQL('JOIN'), ())]
        if isinstance(target, str):
            sections.append((sql.Identifier(target), ()))
        else:
            sections.append(target.as_tuple())
        if on is not None:
            sections.append((sql.SQL('ON'), ()))
            sections.append(on.as_tuple())
        elif using is not None:
            sections.append((sql.SQL('USING'), ()))
            if isinstance(using, Expression):
                sections.append(using.as_tuple())
            elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
                cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
                sections.append((cols, ()))
            else:
                raise ValueError("Unrecognised 'using' type.")
        elif natural:
            # NATURAL prefixes the JOIN keyword itself.
            sections.insert(0, (sql.SQL('NATURAL'), ()))
        expr = RawExpr.join_tuples(*sections)
        self._joins.append(expr)
        return self

    @property
    def _join_section(self) -> Optional[Expression]:
        # All join fragments combined, or None when no joins were added.
        if self._joins:
            return RawExpr.join(*self._joins)
        else:
            return None
class ExtraMixin(TableQuery[QueryResult]):
    """
    Mixin allowing arbitrary extra SQL to be appended to a table query.
    """
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._extra: Optional[Expression] = None

    def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
        """
        Add an extra string, and optionally values, to this query.
        The extra string is inserted after any condition, and before the limit.
        """
        # Stacks: successive calls are joined onto the existing extra expression.
        extra_expr = RawExpr(extra, values)
        if self._extra is not None:
            extra_expr = RawExpr.join(self._extra, extra_expr)
        self._extra = extra_expr
        return self

    @property
    def _extra_section(self) -> Optional[Expression]:
        if self._extra is None:
            return None
        else:
            return self._extra
class LimitMixin(TableQuery[QueryResult]):
    """
    Mixin adding LIMIT support to a table query.
    """
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._limit: Optional[int] = None

    def limit(self, limit: int):
        """
        Add a limit to this query.
        """
        self._limit = limit
        return self

    @property
    def _limit_section(self) -> Optional[Expression]:
        # The limit is bound as a query parameter, not interpolated.
        if self._limit is not None:
            return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
        else:
            return None
class OrderMixin(TableQuery[QueryResult]):
    """
    Mixin adding ORDER BY support to a table query.
    """
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._order: Optional[sql.Composable] = None

    def order_by(self, expression, direction=None, nulls=None):
        """
        Add a single sort expression to the query.
        This method stacks.

        NOTE(review): not yet implemented — the body is a stub and `_order`
        is never set by this method.
        """
        # TODO: Accept a ColumnExpression, string, or sql Composable
        # TODO: Enums for direction (ORDER.ASC or ORDER.DESC) and nulls (NULLS.FIRST, NULLS.LAST)
        ...

    @property
    def _order_section(self) -> Optional[Expression]:
        if self._order is not None:
            return RawExpr(sql.SQL("ORDER BY {}").format(self._order), ())
        else:
            return None
class Insert(ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table insert query.
    """
    # TODO: Support ON CONFLICT for upserts
    __slots__ = ('_columns', '_values')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._columns: tuple[str, ...] = ()
        self._values: tuple[tuple[Any, ...], ...] = ()

    def insert(self, columns, *values):
        """
        Insert the given data.

        Parameters
        ----------
        columns: tuple[str]
            Tuple of column names to insert.
        values: tuple[tuple[Any, ...], ...]
            Tuple of values to insert, corresponding to the columns.

        Returns self so calls may be chained (Table.insert relies on this).
        """
        if not values:
            raise ValueError("Cannot insert zero rows.")
        if len(values[0]) != len(columns):
            raise ValueError("Number of columns does not match length of values.")
        # Coerce to tuples so dict views and generators are accepted.
        self._columns = tuple(columns)
        self._values = tuple(values)
        return self

    def build(self):
        # Column list must be parenthesised: INSERT INTO t (c1, c2) VALUES ...
        columns = sql.SQL('({})').format(
            sql.SQL(',').join(map(sql.Identifier, self._columns))
        )
        single_value_str = sql.SQL('({})').format(
            sql.SQL(',').join(sql.Placeholder() * len(self._columns))
        )
        values_str = sql.SQL(',').join(single_value_str * len(self._values))
        # TODO: Check efficiency of inserting multiple values like this
        # Also implement a Copy query
        base = sql.SQL("INSERT INTO {table} {columns} VALUES {values_str}").format(
            table=sql.Identifier(self.table),
            columns=columns,
            values_str=values_str
        )
        sections = [
            # Flatten the row values so they bind to the placeholders above.
            RawExpr(base, tuple(chain(*self._values))),
            self._extra_section,
            RawExpr(sql.SQL('RETURNING *'))
        ]
        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
def upsert(table, constraint, cursor=None, **values):
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, TableQuery[QueryResult]):
    """
    Select rows from a table matching provided conditions.
    """
    __slots__ = ('_columns',)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Expressions to select; empty means SELECT *.
        self._columns: tuple[Expression, ...] = ()

    def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
        """
        Set the columns and expressions to select.
        If none are given, selects all columns.

        Positional arguments are plain column names; keyword arguments are
        `alias=expression` pairs, where the expression may be a raw SQL
        string, a Composable, or an Expression.
        """
        cols: List[Expression] = []
        if columns:
            cols.extend(map(RawExpr, map(sql.Identifier, columns)))
        if exprs:
            for name, expr in exprs.items():
                if isinstance(expr, str):
                    cols.append(
                        RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
                    )
                elif isinstance(expr, sql.Composable):
                    cols.append(
                        RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
                    )
                elif isinstance(expr, Expression):
                    # Carry the expression's bound values along with its SQL.
                    value_expr, value_values = expr.as_tuple()
                    cols.append(RawExpr(
                        value_expr + sql.SQL(' AS ') + sql.Identifier(name),
                        value_values
                    ))
        if cols:
            self._columns = (*self._columns, *cols)
        return self

    def build(self):
        # SELECT {columns} FROM {table} [JOIN..] [WHERE..] [extra] [ORDER..] [LIMIT..]
        if not self._columns:
            columns, columns_values = sql.SQL('*'), ()
        else:
            columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()
        base = sql.SQL("SELECT {columns} FROM {table}").format(
            columns=columns,
            table=sql.Identifier(self.table)
        )
        sections = [
            RawExpr(base, columns_values),
            self._join_section,
            self._where_section,
            self._extra_section,
            self._order_section,
            self._limit_section,
        ]
        # Drop sections that were never configured.
        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table delete query.
    """
    # TODO: Cascade option for delete, maybe other options
    # TODO: Require a where unless specifically disabled, for safety

    def build(self):
        # DELETE FROM {table} [WHERE ...] [extra] RETURNING *
        base = sql.SQL("DELETE FROM {table}").format(
            table=sql.Identifier(self.table),
        )
        sections = [
            RawExpr(base),
            self._where_section,
            self._extra_section,
            RawExpr(sql.SQL('RETURNING *'))
        ]
        # Drop sections that were never configured (e.g. no WHERE clause).
        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table update query.
    """
    __slots__ = (
        '_set',
    )
    # TODO: Again, require a where unless specifically disabled

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accumulated `column = value` expressions for the SET clause.
        self._set: List[Expression] = []

    def set(self, **column_values: Union[Any, Expression]):
        """
        Add `column = value` pairs to the SET clause.
        Values may be Expressions or literal values (bound as placeholders).
        Returns self to allow chaining.
        """
        exprs: List[Expression] = []
        for name, value in column_values.items():
            if isinstance(value, Expression):
                value_tup = value.as_tuple()
            else:
                # NOTE(review): as_tuple() results pair the SQL with a *tuple*
                # of values; here a bare value is passed — confirm that
                # RawExpr.join_tuples accepts a scalar in this position.
                value_tup = (sql.Placeholder(), value)
            exprs.append(
                RawExpr.join_tuples(
                    (sql.Identifier(name), ()),
                    value_tup,
                    joiner=sql.SQL(' = ')
                )
            )
        self._set.extend(exprs)
        return self

    def build(self):
        # UPDATE {table} SET ... [WHERE ...] [extra] [LIMIT ...] RETURNING *
        if not self._set:
            raise ValueError("No columns provided to update.")
        set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
        base = sql.SQL("UPDATE {table} SET {set}").format(
            table=sql.Identifier(self.table),
            set=set_expr
        )
        sections = [
            RawExpr(base, set_values),
            self._where_section,
            self._extra_section,
            self._limit_section,
            RawExpr(sql.SQL('RETURNING *'))
        ]
        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
# async def upsert(cursor, table, constraint, **values):
# """
# Insert or on conflict update.
# """
# valuedict = values
# keys, values = zip(*values.items())
#
# key_str = _format_insertkeys(keys)
# value_str, values = _format_insertvalues(values)
# update_key_str, update_key_values = _format_updatestr(valuedict)
#
# if not isinstance(constraint, str):
# constraint = ", ".join(constraint)
#
# await cursor.execute(
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
# table, key_str, value_str, constraint, update_key_str
# ),
# tuple((*values, *update_key_values))
# )
# return await cursor.fetchone()
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
# cursor = cursor or conn.cursor()
#
# # TODO: executemany or copy syntax now
# return execute_values(
# cursor,
# """
# UPDATE {table}
# SET {set_clause}
# FROM (VALUES {cast_row}%s)
# AS {temp_table}
# WHERE {where_clause}
# RETURNING *
# """.format(
# table=table,
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
# cast_row=cast_row + ',' if cast_row else '',
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
# ),
# values,
# fetch=True
# )

102
bot/data/registry.py Normal file
View File

@@ -0,0 +1,102 @@
from typing import Protocol, runtime_checkable, Optional
from psycopg import AsyncConnection
from .connector import Connector, Connectable
@runtime_checkable
class _Attachable(Connectable, Protocol):
    """
    Structural type for objects that can be attached to a Registry.
    Used to discover attachable members declared on Registry subclasses.
    """
    def attach_to(self, registry: 'Registry'):
        raise NotImplementedError
class Registry:
    """
    A named collection of attachable data objects (tables, models, hooks)
    which are bound to a database Connector as a unit.
    """
    _attached: list[_Attachable] = []
    _name: Optional[str] = None

    def __init_subclass__(cls, name=None):
        super().__init_subclass__()
        attached = []
        # NOTE: the loop variable must not be called `name`, or it shadows the
        # `name` keyword argument and clobbers the registry name below.
        for member in cls.__dict__.values():
            if isinstance(member, _Attachable):
                attached.append(member)
        cls._attached = attached
        cls._name = name or cls.__name__

    def __init__(self, name=None):
        self._conn: Optional[Connector] = None
        self.name: str = name or self._name
        if self.name is None:
            raise ValueError("A Registry must have a name!")
        self.init_tasks = []
        for member in self._attached:
            member.attach_to(self)

    def bind(self, connector: Connector):
        """Bind this registry and every attached member to the connector."""
        self._conn = connector
        for child in self._attached:
            child.bind(connector)

    def attach(self, attachable):
        """Attach a new member, binding it immediately if already connected."""
        self._attached.append(attachable)
        if self._conn is not None:
            attachable.bind(self._conn)
        return attachable

    def init_task(self, coro):
        """
        Initialisation tasks are run to setup the registry state.
        These tasks will be run in the event loop, after connection to the database.
        These tasks should be idempotent, as they may be run on reload and reconnect.
        """
        self.init_tasks.append(coro)
        return coro

    async def init(self):
        """Run all registered initialisation tasks, in registration order."""
        for task in self.init_tasks:
            await task(self)
        return self
class AttachableClass:
    """ABC for a default implementation of an Attachable class."""
    # Class-level state: the class itself (not instances) is bound/attached.
    _connector: Optional[Connector] = None
    _registry: Optional[Registry] = None

    @classmethod
    def bind(cls, connector: Connector):
        """Bind to a connector and register the on_connect hook."""
        cls._connector = connector
        connector.connect_hook(cls.on_connect)
        return cls

    @classmethod
    def attach_to(cls, registry: Registry):
        """Record the registry this class is attached to."""
        cls._registry = registry
        return cls

    @classmethod
    async def on_connect(cls, connection: AsyncConnection):
        # Hook run on each (re)connection; override in subclasses as needed.
        pass
class Attachable:
    """ABC for a default implementation of an Attachable object."""

    def __init__(self, *args, **kwargs):
        # Per-instance binding state, set by bind/attach_to.
        self._connector: Optional[Connector] = None
        self._registry: Optional[Registry] = None

    def bind(self, connector: Connector):
        """Bind to a connector and register the on_connect hook."""
        self._connector = connector
        connector.connect_hook(self.on_connect)
        return self

    def attach_to(self, registry: Registry):
        """Record the registry this object is attached to."""
        self._registry = registry
        return self

    async def on_connect(self, connection: AsyncConnection):
        # Hook run on each (re)connection; override in subclasses as needed.
        pass

86
bot/data/table.py Normal file
View File

@@ -0,0 +1,86 @@
from typing import Optional
from psycopg.rows import DictRow
from . import queries as q
from .connector import Connector
from .registry import Registry
class Table:
    """
    Transparent interface to a single table structure in the database.
    Contains standard methods to access the table.
    """

    def __init__(self, name, *args, **kwargs):
        self.name: str = name
        # Set by `bind`; queries cannot execute until a connector is attached.
        self.connector: Optional[Connector] = None

    def bind(self, connector: Connector):
        """Bind this table to a database connector."""
        self.connector = connector
        return self

    def attach_to(self, registry: Registry):
        """Record the registry this table belongs to."""
        self._registry = registry
        return self

    def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
        # Default adapter: return all result rows unchanged.
        return data

    def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
        # Default adapter: first row, or None when nothing was returned.
        if data:
            return data[0]
        else:
            return None

    def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
        # Default adapter for deletes: return the deleted rows unchanged.
        return data

    def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
        """Build a SELECT query on this table with the given conditions."""
        return q.Select(
            self.name,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
        """Build a SELECT query returning at most a single row."""
        return q.Select(
            self.name,
            row_adapter=self._single_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
        """Build an UPDATE query with the given conditions; chain `.set(...)`."""
        return q.Update(
            self.name,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
        """Build a DELETE query with the given conditions."""
        return q.Delete(
            self.name,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def insert(self, **column_values) -> q.Insert[DictRow]:
        """Build an INSERT query for a single row of column=value pairs."""
        return q.Insert(
            self.name,
            row_adapter=self._single_query_adapter,
            connector=self.connector
        ).insert(column_values.keys(), column_values.values())

    def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
        """Build an INSERT query for multiple rows (columns, *value_tuples)."""
        return q.Insert(
            self.name,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).insert(*args, **kwargs)

    # def update_many(self, *args, **kwargs):
    #     with self.conn:
    #         return update_many(self.name, *args, **kwargs)

    # def upsert(self, *args, **kwargs):
    #     return upsert(self.name, *args, **kwargs)