Switch to new plugin framework.

This commit is contained in:
2025-09-03 20:35:35 +10:00
parent e3bdebe221
commit 749f2a021c
28 changed files with 228 additions and 2179 deletions

6
.gitmodules vendored Normal file
View File

@@ -0,0 +1,6 @@
[submodule "src/modules/profiles"]
path = src/modules/profiles
url = https://git.thewisewolf.dev/HoloTech/profiles-plugin.git
[submodule "src/data"]
path = src/data
url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git

View File

@@ -1,44 +1,45 @@
import asyncio import asyncio
import logging import logging
import websockets
from twitchio.web import AiohttpAdapter from twitchio.web import AiohttpAdapter
from meta import CrocBot, conf, setup_main_logger, args from meta import Bot, conf, setup_main_logger, args, sockets
from data import Database from data import Database
from constants import DATA_VERSION
from modules import setup from modules import twitch_setup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ProxyAiohttpAdapter(AiohttpAdapter):
def _find_redirect(self, request):
return self.redirect_url
async def main(): async def main():
db = Database(conf.data['args']) db = Database(conf.data['args'])
async with db.open(): async with db.open():
version = await db.version() adapter = ProxyAiohttpAdapter(
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error)
raise RuntimeError(error)
adapter = AiohttpAdapter(
host=conf.bot.get('wshost', None), host=conf.bot.get('wshost', None),
port=conf.bot.getint('wsport', None), port=conf.bot.getint('wsport', None),
domain=conf.bot.get('wsdomain', None), domain=conf.bot.get('wsdomain', None),
eventsub_secret=conf.bot.get('eventsub_secret', None)
) )
bot = CrocBot( bot = Bot(
config=conf, config=conf,
dbconn=db, dbconn=db,
adapter=adapter, adapter=adapter,
setup=setup, setup=twitch_setup,
) )
try: async with websockets.serve(sockets.root_handler, '', conf.wserver.getint('port')):
await bot.start() try:
finally: await bot.start()
await bot.close() finally:
await bot.close()
def _main(): def _main():

82
src/botdata.py Normal file
View File

@@ -0,0 +1,82 @@
from data import Registry, RowModel, Table
from data.columns import String, Timestamp, Integer, Bool
class UserAuth(RowModel):
"""
Schema
======
CREATE TABLE user_auth(
userid TEXT PRIMARY KEY,
token TEXT NOT NULL,
refresh_token TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_auth'
_cache_ = {}
userid = String(primary=True)
token = String()
refresh_token = String()
created_at = Timestamp()
_timestamp = Timestamp()
class BotChannel(RowModel):
"""
Schema
======
CREATE TABLE bot_channels(
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
autojoin BOOLEAN DEFAULT true,
listen_redeems BOOLEAN,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'bot_channels'
_cache_ = {}
userid = String(primary=True)
autojoin = Bool()
listen_redeems = Bool()
joined_at = Timestamp()
_timestamp = Timestamp()
class VersionHistory(RowModel):
"""
CREATE TABLE version_history(
component TEXT NOT NULL,
from_version INTEGER NOT NULL,
to_version INTEGER NOT NULL,
author TEXT NOT NULL,
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
);
"""
_tablename_ = 'version_history'
_cache_ = {}
component = String()
from_version = Integer()
to_version = Integer()
author = String()
_timestamp = Timestamp()
class BotData(Registry):
version_history = VersionHistory.table
user_auth = UserAuth.table
bot_channels = BotChannel.table
"""
CREATE TABLE user_auth_scopes(
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
scope TEXT NOT NULL
);
"""
user_auth_scopes = Table('user_auth_scopes')

View File

@@ -1,2 +1,22 @@
from twitchio import Scopes
CONFIG_FILE = 'config/bot.conf' CONFIG_FILE = 'config/bot.conf'
DATA_VERSION = 1
SCHEMA_VERSIONS = {
'ROOT': 1,
'TWITCH_AUTH': 1
}
# Requested scopes for the bots own twitch user
BOTUSER_SCOPES = Scopes((
Scopes.user_read_chat,
Scopes.user_write_chat,
Scopes.user_bot,
Scopes.channel_bot,
))
# Default requested scopes for joining a channel
CHANNEL_SCOPES = Scopes((
Scopes.channel_bot,
))

1
src/data Submodule

Submodule src/data added at c495b4d097

View File

@@ -1,9 +0,0 @@
from .conditions import Condition, condition, NULL
from .database import Database
from .models import RowModel, RowTable, WeakCache
from .table import Table
from .base import Expression, RawExpr
from .columns import ColumnExpr, Column, Integer, String
from .registry import Registry, AttachableClass, Attachable
from .adapted import RegisterEnum
from .queries import ORDER, NULLS, JOINTYPE

View File

@@ -1,40 +0,0 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry
class RegisterEnum(Attachable):
def __init__(self, enum, name: Optional[str] = None, mapper=None):
super().__init__()
self.enum = enum
self.name = name or enum.__name__
self.mapping = mapper(enum) if mapper is not None else self._mapper()
def _mapper(self):
return {m: m.value[0] for m in self.enum}
def attach_to(self, registry: Registry):
self._registry = registry
registry.init_task(self.on_init)
return self
async def on_init(self, registry: Registry):
connector = registry._conn
if connector is None:
raise ValueError("Cannot initialise without connector!")
connector.connect_hook(self.connection_hook)
# await connector.refresh_pool()
# The below may be somewhat dangerous
# But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None:
raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))

View File

@@ -1,45 +0,0 @@
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable
from itertools import chain
from psycopg import sql
@runtime_checkable
class Expression(Protocol):
__slots__ = ()
@abstractmethod
def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
raise NotImplementedError
class RawExpr(Expression):
__slots__ = ('expr', 'values')
expr: sql.Composable
values: tuple[Any, ...]
def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
self.expr = expr
self.values = values
def as_tuple(self):
return (self.expr, self.values)
@classmethod
def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
"""
Join a sequence of Expressions into a single RawExpr.
"""
tups = (
expression.as_tuple()
for expression in expressions
)
return cls.join_tuples(*tups, joiner=joiner)
@classmethod
def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
exprs, values = zip(*tuples)
expr = joiner.join(exprs)
value = tuple(chain(*values))
return cls(expr, value)

View File

@@ -1,155 +0,0 @@
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
from psycopg import sql
from datetime import datetime
from .base import RawExpr, Expression
from .conditions import Condition, Joiner
from .table import Table
class ColumnExpr(RawExpr):
__slots__ = ()
def __lt__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column < Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LT, obj_expr)
cond_values = (*values, *obj_values)
else:
# column < Literal
cond_exprs = (expr, Joiner.LT, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __le__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column <= Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LE, obj_expr)
cond_values = (*values, *obj_values)
else:
# column <= Literal
cond_exprs = (expr, Joiner.LE, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __eq__(self, obj) -> Condition: # type: ignore[override]
return Condition._expression_equality(self, obj)
def __ne__(self, obj) -> Condition: # type: ignore[override]
return ~(self.__eq__(obj))
def __gt__(self, obj) -> Condition:
return ~(self.__le__(obj))
def __ge__(self, obj) -> Condition:
return ~(self.__lt__(obj))
def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __sub__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __mul__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def CAST(self, target_type: sql.Composable):
return ColumnExpr(
sql.SQL("({}::{})").format(self.expr, target_type),
self.values
)
T = TypeVar('T')
if TYPE_CHECKING:
from .models import RowModel
class Column(ColumnExpr, Generic[T]):
def __init__(self, name: Optional[str] = None,
primary: bool = False, references: Optional['Column'] = None,
type: Optional[Type[T]] = None):
self.primary = primary
self.references = references
self.name: str = name # type: ignore
self.owner: Optional['RowModel'] = None
self._type = type
self.expr = sql.Identifier(name) if name else sql.SQL('')
self.values = ()
def __set_name__(self, owner, name):
# Only allow setting the owner once
self.name = self.name or name
self.owner = owner
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
@overload
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
...
@overload
def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
...
def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
# Get value from row data or session
if obj is None:
return self
else:
return obj.data[self.name]
class Integer(Column[int]):
pass
class String(Column[str]):
pass
class Bool(Column[bool]):
pass
class Timestamp(Column[datetime]):
pass

View File

@@ -1,214 +0,0 @@
# from meta import sharding
from typing import Any, Union
from enum import Enum
from itertools import chain
from psycopg import sql
from .base import Expression, RawExpr
"""
A Condition is a "logical" database expression, intended for use in Where statements.
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
"""
NULL = None
class Joiner(Enum):
EQUALS = ('=', '!=')
IS = ('IS', 'IS NOT')
LIKE = ('LIKE', 'NOT LIKE')
BETWEEN = ('BETWEEN', 'NOT BETWEEN')
IN = ('IN', 'NOT IN')
LT = ('<', '>=')
LE = ('<=', '>')
NONE = ('', '')
class Condition(Expression):
__slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')
def __init__(self,
expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
values: tuple[Any, ...] = (), negated=False
):
self.expr1 = expr1
self.joiner = joiner
self.negated = negated
self.expr2 = expr2
self.values = values
def as_tuple(self):
expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
if self.negated and self.joiner is Joiner.NONE:
expr = sql.SQL("NOT ({})").format(expr)
return (expr, self.values)
@classmethod
def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
"""
Construct a Condition from a sequence of Conditions,
together with some explicit column conditions.
"""
# TODO: Consider adding a _table identifier here so we can identify implicit columns
# Or just require subquery type conditions to always come from modelled tables.
implicit_conditions = (
cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
)
return cls._and(*conditions, *implicit_conditions)
@classmethod
def _and(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _or(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _not(cls, condition: 'Condition'):
condition.negated = not condition.negated
return condition
@classmethod
def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
# TODO: Check if this supports sbqueries
col_expr, col_values = column.as_tuple()
# TODO: Also support sql.SQL? For joins?
if isinstance(value, Expression):
# column = Expression
value_expr, value_values = value.as_tuple()
cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
cond_values = (*col_values, *value_values)
elif isinstance(value, (tuple, list)):
# column in (...)
# TODO: Support expressions in value tuple?
if not value:
raise ValueError("Cannot create Condition from empty iterable!")
value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
cond_exprs = (col_expr, Joiner.IN, value_expr)
cond_values = (*col_values, *value)
elif value is None:
# column IS NULL
cond_exprs = (col_expr, Joiner.IS, sql.NULL)
cond_values = col_values
else:
# column = Literal
cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
cond_values = (*col_values, value)
return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __invert__(self) -> 'Condition':
self.negated = not self.negated
return self
def __and__(self, condition: 'Condition') -> 'Condition':
return self._and(self, condition)
def __or__(self, condition: 'Condition') -> 'Condition':
return self._or(self, condition)
# Helper method to simply condition construction
def condition(*args, **kwargs) -> Condition:
return Condition.construct(*args, **kwargs)
# class NOT(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# if item:
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
# values.extend(item)
# else:
# raise ValueError("Cannot check an empty iterable!")
# else:
# conditions.append("{}!={}".format(key, _replace_char))
# values.append(item)
#
#
# class GEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply GEQ condition to a list!")
# else:
# conditions.append("{} >= {}".format(key, _replace_char))
# values.append(item)
#
#
# class LEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply LEQ condition to a list!")
# else:
# conditions.append("{} <= {}".format(key, _replace_char))
# values.append(item)
#
#
# class Constant(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# conditions.append("{} {}".format(key, self.value))
#
#
# class SHARDID(Condition):
# __slots__ = ('shardid', 'shard_count')
#
# def __init__(self, shardid, shard_count):
# self.shardid = shardid
# self.shard_count = shard_count
#
# def apply(self, key, values, conditions):
# if self.shard_count > 1:
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
# values.append(self.shardid)
#
#
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
#
#
# NULL = Constant('IS NULL')
# NOTNULL = Constant('IS NOT NULL')

View File

@@ -1,135 +0,0 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Convenience property for the current context connection.
"""
return ctx_connection.get()
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=1,
max_size=4,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""
Minimal decorator to register a coroutine to run on connect or reconnect.
Note that these are only run on connect and reconnect.
If a hook is registered after connection, it will not be run.
"""
self.conn_hooks.append(coro)
return coro
@runtime_checkable
class Connectable(Protocol):
def bind(self, connector: Connector):
raise NotImplementedError

View File

@@ -1,42 +0,0 @@
import logging
from typing import Optional
from psycopg import AsyncCursor, sql
from psycopg.abc import Query, Params
from psycopg._encodings import conn_encoding
logger = logging.getLogger(__name__)
class AsyncLoggingCursor(AsyncCursor):
def mogrify_query(self, query: Query):
if isinstance(query, str):
msg = query
elif isinstance(query, (sql.SQL, sql.Composed)):
msg = query.as_string(self)
elif isinstance(query, bytes):
msg = query.decode(conn_encoding(self._conn.pgconn), 'replace')
else:
msg = repr(query)
return msg
async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
if logging.DEBUG >= logger.getEffectiveLevel():
msg = self.mogrify_query(query)
logger.debug(
"Executing query (%s) with values %s", msg, params,
extra={'action': "Query Execute"}
)
try:
return await super().execute(query, params=params, **kwargs)
except Exception:
msg = self.mogrify_query(query)
logger.exception(
"Exception during query execution. Query (%s) with parameters %s.",
msg, params,
extra={'action': "Query Execute"},
stack_info=True
)
else:
# TODO: Possibly log execution time
pass

View File

@@ -1,47 +0,0 @@
from typing import TypeVar
import logging
from collections import namedtuple
# from .cursor import AsyncLoggingCursor
from .registry import Registry
from .connector import Connector
logger = logging.getLogger(__name__)
Version = namedtuple('Version', ('version', 'time', 'author'))
T = TypeVar('T', bound=Registry)
class Database(Connector):
# cursor_factory = AsyncLoggingCursor
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.registries: dict[str, Registry] = {}
def load_registry(self, registry: T) -> T:
logger.debug(
f"Loading and binding registry '{registry.name}'.",
extra={'action': f"Reg {registry.name}"}
)
registry.bind(self)
self.registries[registry.name] = registry
return registry
async def version(self) -> Version:
"""
Return the current schema version as a Version namedtuple.
"""
async with self.connection() as conn:
async with conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)

View File

@@ -1,323 +0,0 @@
from typing import TypeVar, Type, Optional, Generic, Union
# from typing_extensions import Self
from weakref import WeakValueDictionary
from collections.abc import MutableMapping
from psycopg.rows import DictRow
from .table import Table
from .columns import Column
from . import queries as q
from .connector import Connector
from .registry import Registry
RowT = TypeVar('RowT', bound='RowModel')
class MISSING:
__slots__ = ('oid',)
def __init__(self, oid):
self.oid = oid
class RowTable(Table, Generic[RowT]):
__slots__ = (
'model',
)
def __init__(self, name, model: Type[RowT], **kwargs):
super().__init__(name, **kwargs)
self.model = model
@property
def columns(self):
return self.model._columns_
@property
def id_col(self):
return self.model._key_
@property
def row_cache(self):
return self.model._cache_
def _many_query_adapter(self, *data):
self.model._make_rows(*data)
return data
def _single_query_adapter(self, *data):
if data:
self.model._make_rows(*data)
return data[0]
else:
return None
def _delete_query_adapter(self, *data):
self.model._delete_rows(*data)
return data
# New methods to fetch and create rows
async def create_row(self, *args, **kwargs) -> RowT:
data = await super().insert(*args, **kwargs)
return self.model._make_rows(data)[0]
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
# TODO: Handle list of rowids here?
return q.Select(
self.identifier,
row_adapter=self.model._make_rows,
connector=self.connector
).where(*args, **kwargs)
WK = TypeVar('WK')
WV = TypeVar('WV')
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
def __init__(self, ref_cache):
self.ref_cache = ref_cache
self.weak_cache = WeakValueDictionary()
def __getitem__(self, key):
value = self.weak_cache[key]
self.ref_cache[key] = value
return value
def __setitem__(self, key, value):
self.weak_cache[key] = value
self.ref_cache[key] = value
def __delitem__(self, key):
del self.weak_cache[key]
try:
del self.ref_cache[key]
except KeyError:
pass
def __contains__(self, key):
return key in self.weak_cache
def __iter__(self):
return iter(self.weak_cache)
def __len__(self):
return len(self.weak_cache)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def pop(self, key, default=None):
if key in self:
value = self[key]
del self[key]
else:
value = default
return value
# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
__slots__ = ('data',)
_schema_: str = 'public'
_tablename_: Optional[str] = None
_columns_: dict[str, Column] = {}
# Cache to keep track of registered Rows
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
_key_: tuple[str, ...] = ()
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
table: RowTable = None
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
"""
Set table, _columns_, and _key_.
"""
if table is not None:
cls._tablename_ = table
if cls._tablename_ is not None:
columns = {}
for key, value in cls.__dict__.items():
if isinstance(value, Column):
columns[key] = value
cls._columns_ = columns
if not cls._key_:
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
if cls._cache_ is None:
cls._cache_ = WeakValueDictionary()
def __new__(cls, data):
# Registry pattern.
# Ensure each rowid always refers to a single Model instance
if data is not None:
rowid = cls._id_from_data(data)
cache = cls._cache_
if (row := cache.get(rowid, None)) is not None:
obj = row
else:
obj = cache[rowid] = super().__new__(cls)
else:
obj = super().__new__(cls)
return obj
@classmethod
def as_tuple(cls):
return (cls.table.identifier, ())
def __init__(self, data):
self.data = data
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
@classmethod
def bind(cls, connector: Connector):
if cls.table is None:
raise ValueError("Cannot bind abstract RowModel")
cls._connector = connector
cls.table.bind(connector)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@property
def _dict_(self):
return {key: self.data[key] for key in self._key_}
@property
def _rowid_(self):
return tuple(self.data[key] for key in self._key_)
def __repr__(self):
return "{}.{}({})".format(
self.table.schema,
self.table.name,
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
)
@classmethod
def _id_from_data(cls, data):
return tuple(data[key] for key in cls._key_)
@classmethod
def _dict_from_id(cls, rowid):
return dict(zip(cls._key_, rowid))
@classmethod
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
"""
Create or retrieve Row objects for each provided data row.
If the rows already exist in cache, updates the cached row.
"""
# TODO: Handle partial row data here somehow?
rows = [cls(data_row) for data_row in data_rows]
return rows
@classmethod
def _delete_rows(cls, *data_rows):
"""
Remove the given rows from cache, if they exist.
May be extended to handle object deletion.
"""
cache = cls._cache_
for data_row in data_rows:
rowid = cls._id_from_data(data_row)
cache.pop(rowid, None)
@classmethod
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
return await cls.table.create_row(*args, **kwargs)
@classmethod
def fetch_where(cls: Type[RowT], *args, **kwargs):
return cls.table.fetch_rows_where(*args, **kwargs)
@classmethod
async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
"""
Fetch the row with the given id, retrieving from cache where possible.
"""
row = cls._cache_.get(rowid, None) if cached else None
if row is None:
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
row = rows[0] if rows else None
if row is None:
cls._cache_[rowid] = cls(None)
elif row.data is None:
row = None
return row
@classmethod
async def fetch_or_create(cls, *rowid, **kwargs):
"""
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
"""
if rowid:
row = await cls.fetch(*rowid)
else:
rows = await cls.fetch_where(**kwargs).limit(1)
row = rows[0] if rows else None
if row is None:
creation_kwargs = kwargs
if rowid:
creation_kwargs.update(cls._dict_from_id(rowid))
row = await cls.create(**creation_kwargs)
return row
async def refresh(self: RowT) -> Optional[RowT]:
"""
Refresh this Row from data.
The return value may be `None` if the row was deleted.
"""
rows = await self.table.select_where(**self._dict_)
if not rows:
return None
else:
self.data = rows[0]
return self
async def update(self: RowT, **values) -> Optional[RowT]:
"""
Update this Row with the given values.
Internally passes the provided `values` to the `update` Query.
The return value may be `None` if the row was deleted.
"""
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
if not data:
return None
else:
return data[0]
async def delete(self: RowT) -> Optional[RowT]:
"""
Delete this Row.
"""
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
return data[0] if data is not None else None

View File

@@ -1,644 +0,0 @@
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
from enum import Enum
from itertools import chain
from psycopg import AsyncConnection, AsyncCursor
from psycopg import sql
from psycopg.rows import DictRow
import logging
from .conditions import Condition
from .base import Expression, RawExpr
from .connector import Connector
logger = logging.getLogger(__name__)
TQueryT = TypeVar('TQueryT', bound='TableQuery')
SQueryT = TypeVar('SQueryT', bound='Select')
QueryResult = TypeVar('QueryResult')
class Query(Generic[QueryResult]):
"""
ABC for an executable query statement.
"""
__slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
_adapter: Callable[..., QueryResult]
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
self.connector: Optional[Connector] = connector
self.conn: Optional[AsyncConnection] = conn
self.cursor: Optional[AsyncCursor] = cursor
if row_adapter is not None:
self._adapter = row_adapter
else:
self._adapter = self._no_adapter
self.result: Optional[QueryResult] = None
def bind(self, connector: Connector):
self.connector = connector
return self
def with_cursor(self, cursor: AsyncCursor):
self.cursor = cursor
return self
def with_connection(self, conn: AsyncConnection):
self.conn = conn
return self
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def with_adapter(self, callable: Callable[..., QueryResult]):
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
# For this to work cleanly, callable should have arg type of QR1, not any
self._adapter = callable
return self
def with_no_adapter(self):
"""
Sets the adapater to the identity.
"""
self._adapter = self._no_adapter
return self
def one(self):
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
return self
def build(self) -> Expression:
raise NotImplementedError
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
query, values = self.build().as_tuple()
# TODO: Move logging out to a custom cursor
# logger.debug(
# f"Executing query ({query.as_string(cursor)}) with values {values}",
# extra={'action': "Query"}
# )
await cursor.execute(sql.Composed((query,)), values)
data = await cursor.fetchall()
self.result = self._adapter(*data)
return self.result
async def execute(self, cursor=None) -> QueryResult:
"""
Execute the query, optionally with the provided cursor, and return the result rows.
If no cursor is provided, and no cursor has been set with `with_cursor`,
the execution will create a new cursor from the connection and close it automatically.
"""
# Create a cursor if possible
cursor = cursor if cursor is not None else self.cursor
if self.cursor is None:
if self.conn is None:
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
async with self.connector.connection() as conn:
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data
def __await__(self):
return self.execute().__await__()
class TableQuery(Query[QueryResult]):
"""
ABC for an executable query statement expected to be run on a single table.
"""
__slots__ = (
'tableid',
'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group'
)
def __init__(self, tableid, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tableid: sql.Identifier = tableid
def options(self, **kwargs):
"""
Set some query options.
Default implementation does nothing.
Should be overridden to provide specific options.
"""
return self
class WhereMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.condition: Optional[Condition] = None
def where(self, *args: Condition, **kwargs):
"""
Add a Condition to the query.
Position arguments should be Conditions,
and keyword arguments should be of the form `column=Value`,
where Value may be a Value-type or a literal value.
All provided Conditions will be and-ed together to create a new Condition.
TODO: Maybe just pass this verbatim to a condition.
"""
if args or kwargs:
condition = Condition.construct(*args, **kwargs)
if self.condition is not None:
condition = self.condition & condition
self.condition = condition
return self
@property
def _where_section(self) -> Optional[Expression]:
if self.condition is not None:
return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
else:
return None
class JOINTYPE(Enum):
LEFT = sql.SQL('LEFT JOIN')
RIGHT = sql.SQL('RIGHT JOIN')
INNER = sql.SQL('INNER JOIN')
OUTER = sql.SQL('OUTER JOIN')
FULLOUTER = sql.SQL('FULL OUTER JOIN')
class JoinMixin(TableQuery[QueryResult]):
__slots__ = ()
# TODO: Remember to add join slots to TableQuery
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._joins: list[Expression] = []
def join(self,
target: Union[str, Expression],
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
join_type: JOINTYPE = JOINTYPE.INNER,
natural=False):
available = (on is not None) + (using is not None) + natural
if available == 0:
raise ValueError("No conditions given for Query Join")
if available > 1:
raise ValueError("Exactly one join format must be given for Query Join")
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
if isinstance(target, str):
sections.append((sql.Identifier(target), ()))
else:
sections.append(target.as_tuple())
if on is not None:
sections.append((sql.SQL('ON'), ()))
sections.append(on.as_tuple())
elif using is not None:
sections.append((sql.SQL('USING'), ()))
if isinstance(using, Expression):
sections.append(using.as_tuple())
elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
sections.append((cols, ()))
else:
raise ValueError("Unrecognised 'using' type.")
elif natural:
sections.insert(0, (sql.SQL('NATURAL'), ()))
expr = RawExpr.join_tuples(*sections)
self._joins.append(expr)
return self
def leftjoin(self, *args, **kwargs):
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
@property
def _join_section(self) -> Optional[Expression]:
if self._joins:
return RawExpr.join(*self._joins)
else:
return None
class ExtraMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._extra: Optional[Expression] = None
def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
"""
Add an extra string, and optionally values, to this query.
The extra string is inserted after any condition, and before the limit.
"""
extra_expr = RawExpr(extra, values)
if self._extra is not None:
extra_expr = RawExpr.join(self._extra, extra_expr)
self._extra = extra_expr
return self
@property
def _extra_section(self) -> Optional[Expression]:
if self._extra is None:
return None
else:
return self._extra
class LimitMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._limit: Optional[int] = None
def limit(self, limit: int):
"""
Add a limit to this query.
"""
self._limit = limit
return self
@property
def _limit_section(self) -> Optional[Expression]:
if self._limit is not None:
return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
else:
return None
class FromMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._from: Optional[Expression] = None
def from_expr(self, _from: Expression):
self._from = _from
return self
@property
def _from_section(self) -> Optional[Expression]:
if self._from is not None:
expr, values = self._from.as_tuple()
return RawExpr(sql.SQL("FROM {}").format(expr), values)
else:
return None
class ORDER(Enum):
ASC = sql.SQL('ASC')
DESC = sql.SQL('DESC')
class NULLS(Enum):
FIRST = sql.SQL('NULLS FIRST')
LAST = sql.SQL('NULLS LAST')
class OrderMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._order: list[Expression] = []
def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
"""
Add a single sort expression to the query.
This method stacks.
"""
if isinstance(expr, Expression):
string, values = expr.as_tuple()
else:
string = sql.Identifier(expr)
values = ()
parts = [string]
if direction is not None:
parts.append(direction.value)
if nulls is not None:
parts.append(nulls.value)
order_string = sql.SQL(' ').join(parts)
self._order.append(RawExpr(order_string, values))
return self
@property
def _order_section(self) -> Optional[Expression]:
if self._order:
expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
return expr
else:
return None
class GroupMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._group: list[Expression] = []
def group_by(self, *exprs: Union[Expression, str]):
"""
Add a group expression(s) to the query.
This method stacks.
"""
for expr in exprs:
if isinstance(expr, Expression):
self._group.append(expr)
else:
self._group.append(RawExpr(sql.Identifier(expr)))
return self
@property
def _group_section(self) -> Optional[Expression]:
if self._group:
expr = RawExpr.join(*self._group, joiner=sql.SQL(', '))
expr.expr = sql.SQL("GROUP BY {}").format(expr.expr)
return expr
else:
return None
class Insert(ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table insert query.
"""
# TODO: Support ON CONFLICT for upserts
__slots__ = ('_columns', '_values', '_conflict')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[str, ...] = ()
self._values: tuple[tuple[Any, ...], ...] = ()
self._conflict: Optional[Expression] = None
def insert(self, columns, *values):
"""
Insert the given data.
Parameters
----------
columns: tuple[str]
Tuple of column names to insert.
values: tuple[tuple[Any, ...], ...]
Tuple of values to insert, corresponding to the columns.
"""
if not values:
raise ValueError("Cannot insert zero rows.")
if len(values[0]) != len(columns):
raise ValueError("Number of columns does not match length of values.")
self._columns = columns
self._values = values
return self
def on_conflict(self, ignore=False):
# TODO lots more we can do here
# Maybe return a Conflict object that can chain itself (not the query)
if ignore:
self._conflict = RawExpr(sql.SQL('DO NOTHING'))
return self
@property
def _conflict_section(self) -> Optional[Expression]:
if self._conflict is not None:
e, v = self._conflict.as_tuple()
expr = RawExpr(
sql.SQL("ON CONFLICT {}").format(
e
),
v
)
return expr
return None
def build(self):
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
single_value_str = sql.SQL('({})').format(
sql.SQL(',').join(sql.Placeholder() * len(self._columns))
)
values_str = sql.SQL(',').join(single_value_str * len(self._values))
# TODO: Check efficiency of inserting multiple values like this
# Also implement a Copy query
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
table=self.tableid,
columns=columns,
values_str=values_str
)
sections = [
RawExpr(base, tuple(chain(*self._values))),
self._conflict_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]):
"""
Select rows from a table matching provided conditions.
"""
__slots__ = ('_columns',)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[Expression, ...] = ()
def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
"""
Set the columns and expressions to select.
If none are given, selects all columns.
"""
cols: List[Expression] = []
if columns:
cols.extend(map(RawExpr, map(sql.Identifier, columns)))
if exprs:
for name, expr in exprs.items():
if isinstance(expr, str):
cols.append(
RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, sql.Composable):
cols.append(
RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, Expression):
value_expr, value_values = expr.as_tuple()
cols.append(RawExpr(
value_expr + sql.SQL(' AS ') + sql.Identifier(name),
value_values
))
if cols:
self._columns = (*self._columns, *cols)
return self
def build(self):
if not self._columns:
columns, columns_values = sql.SQL('*'), ()
else:
columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()
base = sql.SQL("SELECT {columns} FROM {table}").format(
columns=columns,
table=self.tableid
)
sections = [
RawExpr(base, columns_values),
self._join_section,
self._where_section,
self._group_section,
self._extra_section,
self._order_section,
self._limit_section,
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table delete query.
"""
# TODO: Cascade option for delete, maybe other options
# TODO: Require a where unless specifically disabled, for safety
def build(self):
base = sql.SQL("DELETE FROM {table}").format(
table=self.tableid,
)
sections = [
RawExpr(base),
self._where_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]):
__slots__ = (
'_set',
)
# TODO: Again, require a where unless specifically disabled
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set: List[Expression] = []
def set(self, **column_values: Union[Any, Expression]):
exprs: List[Expression] = []
for name, value in column_values.items():
if isinstance(value, Expression):
value_tup = value.as_tuple()
else:
value_tup = (sql.Placeholder(), (value,))
exprs.append(
RawExpr.join_tuples(
(sql.Identifier(name), ()),
value_tup,
joiner=sql.SQL(' = ')
)
)
self._set.extend(exprs)
return self
def build(self):
if not self._set:
raise ValueError("No columns provided to update.")
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
base = sql.SQL("UPDATE {table} SET {set}").format(
table=self.tableid,
set=set_expr
)
sections = [
RawExpr(base, set_values),
self._from_section,
self._where_section,
self._extra_section,
self._limit_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
# async def upsert(cursor, table, constraint, **values):
# """
# Insert or on conflict update.
# """
# valuedict = values
# keys, values = zip(*values.items())
#
# key_str = _format_insertkeys(keys)
# value_str, values = _format_insertvalues(values)
# update_key_str, update_key_values = _format_updatestr(valuedict)
#
# if not isinstance(constraint, str):
# constraint = ", ".join(constraint)
#
# await cursor.execute(
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
# table, key_str, value_str, constraint, update_key_str
# ),
# tuple((*values, *update_key_values))
# )
# return await cursor.fetchone()
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
# cursor = cursor or conn.cursor()
#
# # TODO: executemany or copy syntax now
# return execute_values(
# cursor,
# """
# UPDATE {table}
# SET {set_clause}
# FROM (VALUES {cast_row}%s)
# AS {temp_table}
# WHERE {where_clause}
# RETURNING *
# """.format(
# table=table,
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
# cast_row=cast_row + ',' if cast_row else '',
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
# ),
# values,
# fetch=True
# )

View File

@@ -1,102 +0,0 @@
from typing import Protocol, runtime_checkable, Optional
from psycopg import AsyncConnection
from .connector import Connector, Connectable
@runtime_checkable
class _Attachable(Connectable, Protocol):
def attach_to(self, registry: 'Registry'):
raise NotImplementedError
class Registry:
_attached: list[_Attachable] = []
_name: Optional[str] = None
def __init_subclass__(cls, name=None):
attached = []
for _, member in cls.__dict__.items():
if isinstance(member, _Attachable):
attached.append(member)
cls._attached = attached
cls._name = name or cls.__name__
def __init__(self, name=None):
self._conn: Optional[Connector] = None
self.name: str = name if name is not None else self._name
if self.name is None:
raise ValueError("A Registry must have a name!")
self.init_tasks = []
for member in self._attached:
member.attach_to(self)
def bind(self, connector: Connector):
self._conn = connector
for child in self._attached:
child.bind(connector)
def attach(self, attachable):
self._attached.append(attachable)
if self._conn is not None:
attachable.bind(self._conn)
return attachable
def init_task(self, coro):
"""
Initialisation tasks are run to setup the registry state.
These tasks will be run in the event loop, after connection to the database.
These tasks should be idempotent, as they may be run on reload and reconnect.
"""
self.init_tasks.append(coro)
return coro
async def init(self):
for task in self.init_tasks:
await task(self)
return self
class AttachableClass:
"""ABC for a default implementation of an Attachable class."""
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
@classmethod
def bind(cls, connector: Connector):
cls._connector = connector
connector.connect_hook(cls.on_connect)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@classmethod
async def on_connect(cls, connection: AsyncConnection):
pass
class Attachable:
"""ABC for a default implementation of an Attachable object."""
def __init__(self, *args, **kwargs):
self._connector: Optional[Connector] = None
self._registry: Optional[Registry] = None
def bind(self, connector: Connector):
self._connector = connector
connector.connect_hook(self.on_connect)
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
async def on_connect(self, connection: AsyncConnection):
pass

View File

@@ -1,95 +0,0 @@
from typing import Optional
from psycopg.rows import DictRow
from psycopg import sql
from . import queries as q
from .connector import Connector
from .registry import Registry
class Table:
"""
Transparent interface to a single table structure in the database.
Contains standard methods to access the table.
"""
def __init__(self, name, *args, schema='public', **kwargs):
self.name: str = name
self.schema: str = schema
self.connector: Connector = None
@property
def identifier(self):
if self.schema == 'public':
return sql.Identifier(self.name)
else:
return sql.Identifier(self.schema, self.name)
def bind(self, connector: Connector):
self.connector = connector
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
if data:
return data[0]
else:
return None
def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
return q.Select(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
return q.Select(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
return q.Update(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
return q.Delete(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def insert(self, **column_values) -> q.Insert[DictRow]:
return q.Insert(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).insert(column_values.keys(), column_values.values())
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
return q.Insert(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).insert(*args, **kwargs)
# def update_many(self, *args, **kwargs):
# with self.conn:
# return update_many(self.identifier, *args, **kwargs)
# def upsert(self, *args, **kwargs):
# return upsert(self.identifier, *args, **kwargs)

View File

@@ -1,133 +0,0 @@
from data import Registry, RowModel, Table
from data.columns import String, Timestamp, Integer, Bool
class UserAuth(RowModel):
"""
Schema
======
CREATE TABLE user_auth(
userid TEXT PRIMARY KEY,
token TEXT NOT NULL,
refresh_token TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_auth'
_cache_ = {}
userid = String(primary=True)
token = String()
refresh_token = String()
created_at = Timestamp()
_timestamp = Timestamp()
class BotChannel(RowModel):
"""
Schema
======
CREATE TABLE bot_channels(
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
autojoin BOOLEAN DEFAULT true,
listen_redeems BOOLEAN,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'bot_channels'
_cache_ = {}
userid = String(primary=True)
autojoin = Bool()
listen_redeems = Bool()
joined_at = Timestamp()
_timestamp = Timestamp()
class UserProfile(RowModel):
"""
Schema
======
CREATE TABLE user_profiles(
profileid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
twitchid TEXT NOT NULL,
name TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_profiles'
_cache_ = {}
profileid = Integer(primary=True)
twitchid = String()
name = String()
created_at = Timestamp()
_timestamp = Timestamp()
class Communities(RowModel):
"""
Schema
======
CREATE TABLE communities(
communityid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
twitchid TEXT NOT NULL,
name TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities'
_cache_ = {}
communityid = Integer(primary=True)
twitchid = Integer()
name = String()
created_at = Timestamp()
_timestamp = Timestamp()
class Koan(RowModel):
"""
Schema
======
CREATE TABLE koans(
koanid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities ON UPDATE CASCADE ON DELETE CASCADE,
name TEXT NOT NULL,
message TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'koans'
_cache_ = {}
koanid = Integer(primary=True)
communityid = Integer()
name = String()
message = String()
created_at = Timestamp()
_timestamp = Timestamp()
class BotData(Registry):
user_auth = UserAuth.table
"""
CREATE TABLE user_auth_scopes(
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
scope TEXT NOT NULL
);
"""
user_auth_scopes = Table('user_auth_scopes')
bot_channels = BotChannel.table
user_profiles = UserProfile.table
communities = Communities.table
koans = Koan.table

View File

@@ -1,4 +1,5 @@
from .args import args from .args import args
from .crocbot import CrocBot from .bot import Bot
from .context import Context
from .config import Conf, conf from .config import Conf, conf
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task

View File

@@ -1,20 +1,25 @@
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Any, Literal, Optional, overload
from twitchio.authentication import UserTokenPayload from twitchio.authentication import UserTokenPayload
from twitchio.ext import commands from twitchio.ext import commands
from twitchio import Scopes, eventsub from twitchio import Scopes, eventsub
from data import Database from data import Database, ORDER
from datamodels import BotData, UserAuth, BotChannel from botdata import BotData, UserAuth, BotChannel, VersionHistory
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
from .config import Conf from .config import Conf
from .context import Context
if TYPE_CHECKING:
from modules.profiles.profiles.twitch.component import ProfilesComponent
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CrocBot(commands.Bot): class Bot(commands.Bot):
def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs): def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs):
kwargs.setdefault('client_id', config.bot['client_id']) kwargs.setdefault('client_id', config.bot['client_id'])
kwargs.setdefault('client_secret', config.bot['client_secret']) kwargs.setdefault('client_secret', config.bot['client_secret'])
@@ -23,6 +28,12 @@ class CrocBot(commands.Bot):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Whether we should do eventsub via webhooks or websockets
if config.bot.get('eventsub_secret', None):
self.using_webhooks = True
else:
self.using_webhooks = False
self.config = config self.config = config
self.dbconn = dbconn self.dbconn = dbconn
self.data: BotData = dbconn.load_registry(BotData()) self.data: BotData = dbconn.load_registry(BotData())
@@ -30,12 +41,56 @@ class CrocBot(commands.Bot):
self.joined: dict[str, BotChannel] = {} self.joined: dict[str, BotChannel] = {}
# Make the type checker happy about fetching components by name
# TODO: Move to stubs
@property
def profiles(self):
return self.get_component('ProfilesComponent')
@overload
def get_component(self, name: Literal['ProfilesComponent']) -> 'ProfilesComponent':
...
@overload
def get_component(self, name: str) -> Optional[commands.Component]:
...
def get_component(self, name: str) -> Optional[commands.Component]:
return super().get_component(name)
def get_context(self, payload, *, cls: Any = None) -> Context:
cls = cls or Context
return cls(payload, bot=self)
async def event_ready(self): async def event_ready(self):
# logger.info(f"Logged in as {self.nick}. User id is {self.user_id}") # logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
logger.info("Logged in as %s", self.bot_id) logger.info("Logged in as %s", self.bot_id)
async def version_check(self, component: str, req_version: int):
# Query the database to confirm that the given component is listed with the given version.
# Typically done upon loading a component
rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
version = rows[0].to_version if rows else 0
if version != req_version:
raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
else:
logger.debug(
"Component %s passed version check with version %s",
component,
version
)
return True
async def setup_hook(self): async def setup_hook(self):
await self.data.init() await self.data.init()
for component, req in SCHEMA_VERSIONS.items():
await self.version_check(component, req)
if self._setup_hook is not None:
await self._setup_hook(self)
# Get all current bot channels # Get all current bot channels
channels = await BotChannel.fetch_where(autojoin=True) channels = await BotChannel.fetch_where(autojoin=True)
@@ -44,29 +99,15 @@ class CrocBot(commands.Bot):
await self.join_channels(*channels) await self.join_channels(*channels)
# Build bot account's own url # Build bot account's own url
scopes = Scopes(( scopes = BOTUSER_SCOPES
Scopes.user_read_chat,
Scopes.user_write_chat,
Scopes.user_bot,
Scopes.channel_read_redemptions,
Scopes.channel_manage_redemptions,
Scopes.channel_bot,
))
url = self.get_auth_url(scopes) url = self.get_auth_url(scopes)
logger.info("Bot account authorisation url: %s", url) logger.info("Bot account authorisation url: %s", url)
# Build everyone else's url # Build everyone else's url
scopes = Scopes(( scopes = CHANNEL_SCOPES
Scopes.channel_bot,
Scopes.channel_read_redemptions,
Scopes.channel_manage_redemptions,
))
url = self.get_auth_url(scopes) url = self.get_auth_url(scopes)
logger.info("User account authorisation url: %s", url) logger.info("User account authorisation url: %s", url)
if self._setup_hook is not None:
await self._setup_hook(self)
logger.info("Finished setup") logger.info("Finished setup")
def get_auth_url(self, scopes: Optional[Scopes] = None): def get_auth_url(self, scopes: Optional[Scopes] = None):
@@ -81,7 +122,6 @@ class CrocBot(commands.Bot):
Register webhook subscriptions to the given channel(s). Register webhook subscriptions to the given channel(s).
""" """
# TODO: If channels are already joined, unsubscribe # TODO: If channels are already joined, unsubscribe
# TODO: Determine (or switch?) whether to use webhook or websocket
for channel in channels: for channel in channels:
sub = None sub = None
try: try:
@@ -89,19 +129,16 @@ class CrocBot(commands.Bot):
broadcaster_user_id=channel.userid, broadcaster_user_id=channel.userid,
user_id=self.bot_id, user_id=self.bot_id,
) )
resp = await self.subscribe_websocket(sub) if self.using_webhooks:
resp = await self.subscribe_webhook(sub)
else:
resp = await self.subscribe_websocket(sub)
logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp) logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
if channel.listen_redeems:
sub = eventsub.ChannelPointsRedeemAddSubscription(
broadcaster_user_id=channel.userid
)
resp = await self.subscribe_websocket(sub, as_bot=False, token_for=channel.userid)
logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
self.joined[channel.userid] = channel self.joined[channel.userid] = channel
self.safe_dispatch('channel_joined', payload=channel)
except Exception: except Exception:
logger.exception("Failed to subscribe to %s with %s", channel.userid, sub) logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)
async def event_oauth_authorized(self, payload: UserTokenPayload): async def event_oauth_authorized(self, payload: UserTokenPayload):
logger.debug("Oauth flow authorization with payload %s", repr(payload)) logger.debug("Oauth flow authorization with payload %s", repr(payload))
# Save the token and scopes and update internal authorisations # Save the token and scopes and update internal authorisations
@@ -141,21 +178,22 @@ class CrocBot(commands.Bot):
# Save the token and scopes to data # Save the token and scopes to data
# Wrap this in a transaction so if it fails halfway we rollback correctly # Wrap this in a transaction so if it fails halfway we rollback correctly
async with self.dbconn.connection() as conn: # TODO
self.dbconn.conn = conn row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
async with conn.transaction(): if row.token != token or row.refresh_token != refresh:
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh) await row.update(token=token, refresh_token=refresh)
if row.token != token or row.refresh_token != refresh: await self.data.user_auth_scopes.delete_where(userid=userid)
await row.update(token=token, refresh_token=refresh) await self.data.user_auth_scopes.insert_many(
await self.data.user_auth_scopes.delete_where(userid=userid) ('userid', 'scope'),
await self.data.user_auth_scopes.insert_many( *((userid, scope) for scope in new_scopes)
('userid', 'scope'), )
*((userid, scope) for scope in new_scopes)
)
logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes)) logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
return resp return resp
async def load_tokens(self, path: str | None = None): async def load_tokens(self, path: str | None = None):
for row in await UserAuth.fetch_where(): for row in await UserAuth.fetch_where():
await self.add_token(row.token, row.refresh_token) try:
await self.add_token(row.token, row.refresh_token)
except Exception:
logger.exception(f"Failed to add token for {row}")

View File

@@ -102,4 +102,4 @@ class Conf:
self.config.write(conffile) self.config.write(conffile)
conf = Conf(args.config, 'CROCBOT') conf = Conf(args.config, 'BOT')

4
src/meta/context.py Normal file
View File

@@ -0,0 +1,4 @@
from twitchio.ext import commands as cmds
class Context(cmds.Context):
...

View File

@@ -14,6 +14,7 @@ import aiohttp
from .config import conf from .config import conf
from utils.ratelimits import Bucket, BucketOverFull, BucketFull from utils.ratelimits import Bucket, BucketOverFull, BucketFull
from utils.lib import utc_now
log_logger = logging.getLogger(__name__) log_logger = logging.getLogger(__name__)
@@ -24,8 +25,7 @@ log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=()) log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT") log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT")
# TODO merge into TwithIO context context: ContextVar[Optional[str]] = ContextVar('context', default=None)
context: ContextVar[str] = ContextVar('context', default=None)
def set_logging_context( def set_logging_context(
context: Optional[str] = None, context: Optional[str] = None,
@@ -299,6 +299,7 @@ class WebHookHandler(logging.StreamHandler):
asyncio.create_task(self.post(record)) asyncio.create_task(self.post(record))
def setup(self): def setup(self):
from discord import Webhook
self.session = aiohttp.ClientSession() self.session = aiohttp.ClientSession()
self.webhook = Webhook.from_url(self.webhook_url, session=self.session) self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
@@ -365,6 +366,8 @@ class WebHookHandler(logging.StreamHandler):
await self._send(batched) await self._send(batched)
async def _send(self, message, as_file=False): async def _send(self, message, as_file=False):
import discord
from discord import File
try: try:
self.bucket.request() self.bucket.request()
except BucketOverFull: except BucketOverFull:

View File

@@ -1,3 +1,9 @@
async def setup(bot): from typing import TYPE_CHECKING
from . import koans
await koans.setup(bot) if TYPE_CHECKING:
from meta import Bot
async def twitch_setup(bot: 'Bot'):
from . import profiles
await profiles.twitch_setup(bot)

View File

@@ -1,7 +0,0 @@
import logging
logger = logging.getLogger(__name__)
async def setup(bot):
from .component import KoanComponent
await bot.add_component(KoanComponent(bot))

View File

@@ -1,123 +0,0 @@
from typing import Optional
import random
import twitchio
from twitchio.ext import commands as cmds
from datamodels import Koan, Communities
from . import logger
class KoanComponent(cmds.Component):
def __init__(self, bot):
self.bot = bot
@cmds.Component.listener()
async def event_message(self, payload: twitchio.ChatMessage) -> None:
print(f"[{payload.broadcaster.name}] - {payload.chatter.name}: {payload.text}")
@cmds.group(invoke_fallback=True)
async def koans(self, ctx: cmds.Context) -> None:
"""
List the (names of the) koans in this channel.
!koans
"""
community = await Communities.fetch_or_create(twitchid=ctx.channel.id, name=ctx.channel.name)
cid = community.communityid
koans = await Koan.fetch_where(communityid=cid)
if koans:
names = ', '.join(koan.name for koan in koans)
await ctx.reply(
f"Koans: {names}"
)
else:
await ctx.reply("No koans have been made in this channel!")
@koans.command(name='add', aliases=['new', 'create'])
async def koans_add(self, ctx: cmds.Context, name: str, *, text: str):
"""
Add or overwrite a koan to this channel.
!koans add wind This is a wind koan
"""
community = await Communities.fetch_or_create(twitchid=ctx.channel.id, name=ctx.channel.name)
cid = community.communityid
name = name.lower()
assert isinstance(ctx.author, twitchio.Chatter)
if (ctx.author.moderator or ctx.author.broadcaster):
# Delete the koan with this name if it exists
existing = await Koan.table.delete_where(
communityid=cid,
name=name,
)
# Insert the new koan
await Koan.create(
communityid=cid,
name=name,
message=text
)
# Ack
if existing:
await ctx.reply(f"Updated the koan '{name}'")
else:
await ctx.reply(f"Created the new koan '{name}'")
@koans.command(name='del', aliases=['delete', 'rm', 'remove'])
async def koans_del(self, ctx: cmds.Context, name: str):
"""
Remove a koan from this channel by name.
!koans del wind
"""
community = await Communities.fetch_or_create(twitchid=ctx.channel.id, name=ctx.channel.name)
cid = community.communityid
name = name.lower()
assert isinstance(ctx.author, twitchio.Chatter)
if (ctx.author.moderator or ctx.author.broadcaster):
# Delete the koan with this name if it exists
existing = await Koan.table.delete_where(
communityid=cid,
name=name,
)
if existing:
await ctx.reply(f"Deleted the koan '{name}'")
else:
await ctx.reply(f"The koan '{name}' does not exist to delete!")
@cmds.command(name='koan')
async def koan(self, ctx: cmds.Context, name: Optional[str] = None):
"""
Show a koan from this channel. Optionally by name.
!koan
!koan wind
"""
community = await Communities.fetch_or_create(twitchid=ctx.channel.id, name=ctx.channel.name)
cid = community.communityid
if name is not None:
name = name.lower()
koans = await Koan.fetch_where(
communityid=cid,
name=name
)
if koans:
koan = koans[0]
await ctx.reply(koan.message)
else:
await ctx.reply(f"The requested koan '{name}' does not exist! Use '{ctx.prefix}koans' to see all the koans.")
else:
koans = await Koan.fetch_where(communityid=cid)
if koans:
koan = random.choice(koans)
await ctx.reply(koan.message)
else:
await ctx.reply("This channel doesn't have any koans!")

1
src/modules/profiles Submodule

Submodule src/modules/profiles added at 0363dc2bcd