Initial framework

2025-07-24 03:33:42 +10:00
commit 9c0ae404c8
31 changed files with 3435 additions and 0 deletions

151
.gitignore vendored Normal file

@@ -0,0 +1,151 @@
src/modules/test/*
pending-rewrite/
logs/*
notes/*
tmp/*
output/*
locales/domains
.idea/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
config/**

16
config/example-bot.conf Normal file

@@ -0,0 +1,16 @@
[CROCBOT]
prefix = ?
owner_id =
bot_id =
ALSO_READ = config/secrets.conf
wshost = localhost
wsport = 4343
wsdomain = localhost:4343
[LOGGING]
general_log =
warning_log =
error_log =
critical_log =
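
The values above are consumed through the Conf wrapper defined in src/meta/config.py. A minimal sketch, assuming this example has been copied to config/bot.conf (the default CONFIG_FILE):

    from meta import conf

    prefix = conf.bot['prefix']        # '?'
    port = conf.bot.getint('wsport')   # 4343
    domain = conf.bot.get('wsdomain')  # 'localhost:4343'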


@@ -0,0 +1,7 @@
[CROCBOT]
client_id =
client_secret =
[DATA]
args =
appid =

0
data/.gitignore vendored Normal file

116
data/schema.sql Normal file

@@ -0,0 +1,116 @@
-- Metadata {{{
CREATE TABLE VersionHistory(
version INTEGER NOT NULL,
time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
author TEXT
);
INSERT INTO VersionHistory (version, author) VALUES (1, 'Initial Creation');
CREATE OR REPLACE FUNCTION update_timestamp_column()
RETURNS TRIGGER AS $$
BEGIN
NEW._timestamp = now();
RETURN NEW;
END;
$$ language 'plpgsql';
-- }}}
-- App metadata {{{
CREATE TABLE app_config(
appname TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- }}}
-- Twitch Auth {{{
-- Authorisation tokens allowing us to take actions on behalf of certain users or channels.
-- For example, channels we have joined will need to be authorised with a 'channel:bot' scope.
CREATE TABLE user_auth(
userid TEXT PRIMARY KEY,
token TEXT NOT NULL,
refresh_token TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER user_auth_timestamp BEFORE UPDATE ON user_auth
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
CREATE TABLE user_auth_scopes(
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
scope TEXT NOT NULL
);
-- Which channels will be joined at startup,
-- and any configurable choices needed when joining the channel
CREATE TABLE bot_channels(
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
autojoin BOOLEAN DEFAULT true,
listen_redeems BOOLEAN,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER bot_channels_timestamp BEFORE UPDATE ON bot_channels
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
-- }}}
-- Twitch user data {{{
---- Users are internally represented by 'profiles' with a unique profileid
---- This is associated to the user's twitch userid.
---- Any user-specific configuration data or preferences can be added here
CREATE TABLE user_profiles(
profileid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
twitchid TEXT NOT NULL,
name TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER user_profiles_timestamp BEFORE UPDATE ON user_profiles
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
CREATE UNIQUE INDEX user_profile_twitchid ON user_profiles (twitchid);
-- }}}
-- Twitch channel data {{{
---- Similar to user profiles, we associate twitch channels with 'communities'
---- This slight abstraction gives us more flexibility and control over the community and user data.
CREATE TABLE communities(
communityid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
twitchid TEXT NOT NULL,
name TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER communities_timestamp BEFORE UPDATE ON communities
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
-- }}}
-- Koan data {{{
---- !koans lists koans. !koan gives a random koan. !koans add name ... !koans del name ...
CREATE TABLE koans(
koanid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities ON UPDATE CASCADE ON DELETE CASCADE,
name TEXT NOT NULL,
message TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER koans_timestamp BEFORE UPDATE ON koans
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
-- }}}
-- vim: set fdm=marker:

4
requirements.txt Normal file

@@ -0,0 +1,4 @@
websockets
twitchio
psycopg[pool]
cachetools

12
scripts/start_bot.py Normal file

@@ -0,0 +1,12 @@
#!/bin/python3
import sys
import os
sys.path.insert(0, os.getcwd())
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
if __name__ == '__main__':
from bot import _main
_main()

46
src/bot.py Normal file

@@ -0,0 +1,46 @@
import asyncio
import logging
from twitchio.web import AiohttpAdapter
from meta import CrocBot, conf, setup_main_logger, args
from data import Database
from constants import DATA_VERSION
logger = logging.getLogger(__name__)
async def main():
db = Database(conf.data['args'])
async with db.open():
version = await db.version()
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error)
raise RuntimeError(error)
adapter = AiohttpAdapter(
host=conf.bot.get('wshost', None),
port=conf.bot.getint('wsport', None),
domain=conf.bot.get('wsdomain', None),
)
bot = CrocBot(
config=conf,
dbconn=db,
adapter=adapter,
)
# await bot.load_module('modules')
try:
await bot.start()
finally:
await bot.close()
def _main():
setup_main_logger()
asyncio.run(main())

2
src/constants.py Normal file

@@ -0,0 +1,2 @@
CONFIG_FILE = 'config/bot.conf'
DATA_VERSION = 1

9
src/data/__init__.py Normal file

@@ -0,0 +1,9 @@
from .conditions import Condition, condition, NULL
from .database import Database
from .models import RowModel, RowTable, WeakCache
from .table import Table
from .base import Expression, RawExpr
from .columns import ColumnExpr, Column, Integer, String
from .registry import Registry, AttachableClass, Attachable
from .adapted import RegisterEnum
from .queries import ORDER, NULLS, JOINTYPE

40
src/data/adapted.py Normal file

@@ -0,0 +1,40 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry
class RegisterEnum(Attachable):
def __init__(self, enum, name: Optional[str] = None, mapper=None):
super().__init__()
self.enum = enum
self.name = name or enum.__name__
self.mapping = mapper(enum) if mapper is not None else self._mapper()
def _mapper(self):
return {m: m.value[0] for m in self.enum}
def attach_to(self, registry: Registry):
self._registry = registry
registry.init_task(self.on_init)
return self
async def on_init(self, registry: Registry):
connector = registry._conn
if connector is None:
raise ValueError("Cannot initialise without connector!")
connector.connect_hook(self.connection_hook)
# await connector.refresh_pool()
# The below may be somewhat dangerous
# But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None:
raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
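
A hypothetical usage sketch (the Colour enum and its matching 'colour' database type are illustrative, not part of this commit). The default mapper sends each member to the first element of its value tuple:

    from enum import Enum
    from data import Registry, RegisterEnum

    class Colour(Enum):
        RED = ('red',)
        BLUE = ('blue',)

    class PaletteData(Registry, name='palette'):
        # Registered as an init task; adapts the enum on every pooled connection.
        _colour = RegisterEnum(Colour, name='colour')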

45
src/data/base.py Normal file

@@ -0,0 +1,45 @@
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable
from itertools import chain
from psycopg import sql
@runtime_checkable
class Expression(Protocol):
__slots__ = ()
@abstractmethod
def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
raise NotImplementedError
class RawExpr(Expression):
__slots__ = ('expr', 'values')
expr: sql.Composable
values: tuple[Any, ...]
def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
self.expr = expr
self.values = values
def as_tuple(self):
return (self.expr, self.values)
@classmethod
def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
"""
Join a sequence of Expressions into a single RawExpr.
"""
tups = (
expression.as_tuple()
for expression in expressions
)
return cls.join_tuples(*tups, joiner=joiner)
@classmethod
def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
exprs, values = zip(*tuples)
expr = joiner.join(exprs)
value = tuple(chain(*values))
return cls(expr, value)
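
A minimal sketch of how expressions compose: every Expression reduces to a (sql.Composable, values) pair, and join/join_tuples splice the SQL fragments while concatenating the value tuples:

    from psycopg import sql
    from data import RawExpr

    left = RawExpr(sql.SQL("price >"))
    right = RawExpr(sql.Placeholder(), (100,))
    expr, values = RawExpr.join(left, right).as_tuple()
    # expr renders roughly as: price > %s
    # values == (100,)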

155
src/data/columns.py Normal file

@@ -0,0 +1,155 @@
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
from psycopg import sql
from datetime import datetime
from .base import RawExpr, Expression
from .conditions import Condition, Joiner
from .table import Table
class ColumnExpr(RawExpr):
__slots__ = ()
def __lt__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column < Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LT, obj_expr)
cond_values = (*values, *obj_values)
else:
# column < Literal
cond_exprs = (expr, Joiner.LT, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __le__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column <= Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LE, obj_expr)
cond_values = (*values, *obj_values)
else:
# column <= Literal
cond_exprs = (expr, Joiner.LE, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __eq__(self, obj) -> Condition: # type: ignore[override]
return Condition._expression_equality(self, obj)
def __ne__(self, obj) -> Condition: # type: ignore[override]
return ~(self.__eq__(obj))
def __gt__(self, obj) -> Condition:
return ~(self.__le__(obj))
def __ge__(self, obj) -> Condition:
return ~(self.__lt__(obj))
def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __sub__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __mul__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def CAST(self, target_type: sql.Composable):
return ColumnExpr(
sql.SQL("({}::{})").format(self.expr, target_type),
self.values
)
T = TypeVar('T')
if TYPE_CHECKING:
from .models import RowModel
class Column(ColumnExpr, Generic[T]):
def __init__(self, name: Optional[str] = None,
primary: bool = False, references: Optional['Column'] = None,
type: Optional[Type[T]] = None):
self.primary = primary
self.references = references
self.name: str = name # type: ignore
self.owner: Optional['RowModel'] = None
self._type = type
self.expr = sql.Identifier(name) if name else sql.SQL('')
self.values = ()
def __set_name__(self, owner, name):
# Only allow setting the owner once
self.name = self.name or name
self.owner = owner
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
@overload
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
...
@overload
def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
...
def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
# Get value from row data or session
if obj is None:
return self
else:
return obj.data[self.name]
class Integer(Column[int]):
pass
class String(Column[str]):
pass
class Bool(Column[bool]):
pass
class Timestamp(Column[datetime]):
pass
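
Comparisons on columns build Condition objects rather than booleans. A standalone sketch using a bare identifier (inside a RowModel, __set_name__ would instead qualify the column with schema and table):

    from psycopg import sql
    from data import ColumnExpr

    age = ColumnExpr(sql.Identifier('age'))
    cond = (age >= 18) & (age < 65)
    expr, values = cond.as_tuple()
    # expr renders roughly as: ("age" >= %s) AND ("age" < %s)
    # values == (18, 65)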

214
src/data/conditions.py Normal file

@@ -0,0 +1,214 @@
# from meta import sharding
from typing import Any, Union
from enum import Enum
from itertools import chain
from psycopg import sql
from .base import Expression, RawExpr
"""
A Condition is a "logical" database expression, intended for use in Where statements.
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
"""
NULL = None
class Joiner(Enum):
EQUALS = ('=', '!=')
IS = ('IS', 'IS NOT')
LIKE = ('LIKE', 'NOT LIKE')
BETWEEN = ('BETWEEN', 'NOT BETWEEN')
IN = ('IN', 'NOT IN')
LT = ('<', '>=')
LE = ('<=', '>')
NONE = ('', '')
class Condition(Expression):
__slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')
def __init__(self,
expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
values: tuple[Any, ...] = (), negated=False
):
self.expr1 = expr1
self.joiner = joiner
self.negated = negated
self.expr2 = expr2
self.values = values
def as_tuple(self):
expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
if self.negated and self.joiner is Joiner.NONE:
expr = sql.SQL("NOT ({})").format(expr)
return (expr, self.values)
@classmethod
def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
"""
Construct a Condition from a sequence of Conditions,
together with some explicit column conditions.
"""
# TODO: Consider adding a _table identifier here so we can identify implicit columns
# Or just require subquery type conditions to always come from modelled tables.
implicit_conditions = (
cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
)
return cls._and(*conditions, *implicit_conditions)
@classmethod
def _and(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _or(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _not(cls, condition: 'Condition'):
condition.negated = not condition.negated
return condition
@classmethod
def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
# TODO: Check if this supports subqueries
col_expr, col_values = column.as_tuple()
# TODO: Also support sql.SQL? For joins?
if isinstance(value, Expression):
# column = Expression
value_expr, value_values = value.as_tuple()
cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
cond_values = (*col_values, *value_values)
elif isinstance(value, (tuple, list)):
# column in (...)
# TODO: Support expressions in value tuple?
if not value:
raise ValueError("Cannot create Condition from empty iterable!")
value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
cond_exprs = (col_expr, Joiner.IN, value_expr)
cond_values = (*col_values, *value)
elif value is None:
# column IS NULL
cond_exprs = (col_expr, Joiner.IS, sql.NULL)
cond_values = col_values
else:
# column = Literal
cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
cond_values = (*col_values, value)
return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __invert__(self) -> 'Condition':
self.negated = not self.negated
return self
def __and__(self, condition: 'Condition') -> 'Condition':
return self._and(self, condition)
def __or__(self, condition: 'Condition') -> 'Condition':
return self._or(self, condition)
# Helper method to simplify condition construction
def condition(*args, **kwargs) -> Condition:
return Condition.construct(*args, **kwargs)
# class NOT(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# if item:
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
# values.extend(item)
# else:
# raise ValueError("Cannot check an empty iterable!")
# else:
# conditions.append("{}!={}".format(key, _replace_char))
# values.append(item)
#
#
# class GEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply GEQ condition to a list!")
# else:
# conditions.append("{} >= {}".format(key, _replace_char))
# values.append(item)
#
#
# class LEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply LEQ condition to a list!")
# else:
# conditions.append("{} <= {}".format(key, _replace_char))
# values.append(item)
#
#
# class Constant(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# conditions.append("{} {}".format(key, self.value))
#
#
# class SHARDID(Condition):
# __slots__ = ('shardid', 'shard_count')
#
# def __init__(self, shardid, shard_count):
# self.shardid = shardid
# self.shard_count = shard_count
#
# def apply(self, key, values, conditions):
# if self.shard_count > 1:
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
# values.append(self.shardid)
#
#
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
#
#
# NULL = Constant('IS NULL')
# NOTNULL = Constant('IS NOT NULL')
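
A short sketch of the condition() helper defined above, before the commented-out legacy conditions (the column names are illustrative). Keyword arguments become equality conditions and-ed together, with iterables mapping to IN and NULL (None) to IS NULL:

    from data import condition, NULL

    cond = condition(userid='123', scope=['channel:bot', 'chat:read'], banned_at=NULL)
    expr, values = cond.as_tuple()
    # expr renders roughly as:
    #   ("userid" = %s) AND ("scope" IN (%s,%s)) AND ("banned_at" IS NULL)
    # values == ('123', 'channel:bot', 'chat:read')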

135
src/data/connector.py Normal file

@@ -0,0 +1,135 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: ContextVar[Optional[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Convenience property for the current context connection.
"""
return ctx_connection.get()
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=1,
max_size=4,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection setup is run again.
Better ways should be sought (e.g. a way to re-run the setup hooks on the existing connections).
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to run a callable over each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""
Minimal decorator to register a coroutine to run on connect or reconnect.
Note that these are only run on connect and reconnect.
If a hook is registered after connection, it will not be run.
"""
self.conn_hooks.append(coro)
return coro
@runtime_checkable
class Connectable(Protocol):
def bind(self, connector: Connector):
raise NotImplementedError
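
A usage sketch (the connection string is a placeholder). The pool is opened by the open() context manager, and connection() transparently reuses a context-bound connection when one has been set:

    import asyncio
    from data.connector import Connector

    async def main():
        connector = Connector("dbname=crocbot user=croc")
        async with connector.open():
            async with connector.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("SELECT 1")
                    print(await cursor.fetchone())

    asyncio.run(main())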

42
src/data/cursor.py Normal file

@@ -0,0 +1,42 @@
import logging
from typing import Optional
from psycopg import AsyncCursor, sql
from psycopg.abc import Query, Params
from psycopg._encodings import conn_encoding
logger = logging.getLogger(__name__)
class AsyncLoggingCursor(AsyncCursor):
def mogrify_query(self, query: Query):
if isinstance(query, str):
msg = query
elif isinstance(query, (sql.SQL, sql.Composed)):
msg = query.as_string(self)
elif isinstance(query, bytes):
msg = query.decode(conn_encoding(self._conn.pgconn), 'replace')
else:
msg = repr(query)
return msg
async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
if logging.DEBUG >= logger.getEffectiveLevel():
msg = self.mogrify_query(query)
logger.debug(
"Executing query (%s) with values %s", msg, params,
extra={'action': "Query Execute"}
)
try:
return await super().execute(query, params=params, **kwargs)
except Exception:
msg = self.mogrify_query(query)
logger.exception(
"Exception during query execution. Query (%s) with parameters %s.",
msg, params,
extra={'action': "Query Execute"},
stack_info=True
)
raise
else:
# TODO: Possibly log execution time
pass

47
src/data/database.py Normal file

@@ -0,0 +1,47 @@
from typing import TypeVar
import logging
from collections import namedtuple
# from .cursor import AsyncLoggingCursor
from .registry import Registry
from .connector import Connector
logger = logging.getLogger(__name__)
Version = namedtuple('Version', ('version', 'time', 'author'))
T = TypeVar('T', bound=Registry)
class Database(Connector):
# cursor_factory = AsyncLoggingCursor
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.registries: dict[str, Registry] = {}
def load_registry(self, registry: T) -> T:
logger.debug(
f"Loading and binding registry '{registry.name}'.",
extra={'action': f"Reg {registry.name}"}
)
registry.bind(self)
self.registries[registry.name] = registry
return registry
async def version(self) -> Version:
"""
Return the current schema version as a Version namedtuple.
"""
async with self.connection() as conn:
async with conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)

323
src/data/models.py Normal file

@@ -0,0 +1,323 @@
from typing import TypeVar, Type, Optional, Generic, Union
# from typing_extensions import Self
from weakref import WeakValueDictionary
from collections.abc import MutableMapping
from psycopg.rows import DictRow
from .table import Table
from .columns import Column
from . import queries as q
from .connector import Connector
from .registry import Registry
RowT = TypeVar('RowT', bound='RowModel')
class MISSING:
__slots__ = ('oid',)
def __init__(self, oid):
self.oid = oid
class RowTable(Table, Generic[RowT]):
__slots__ = (
'model',
)
def __init__(self, name, model: Type[RowT], **kwargs):
super().__init__(name, **kwargs)
self.model = model
@property
def columns(self):
return self.model._columns_
@property
def id_col(self):
return self.model._key_
@property
def row_cache(self):
return self.model._cache_
def _many_query_adapter(self, *data):
self.model._make_rows(*data)
return data
def _single_query_adapter(self, *data):
if data:
self.model._make_rows(*data)
return data[0]
else:
return None
def _delete_query_adapter(self, *data):
self.model._delete_rows(*data)
return data
# New methods to fetch and create rows
async def create_row(self, *args, **kwargs) -> RowT:
data = await super().insert(*args, **kwargs)
return self.model._make_rows(data)[0]
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
# TODO: Handle list of rowids here?
return q.Select(
self.identifier,
row_adapter=self.model._make_rows,
connector=self.connector
).where(*args, **kwargs)
WK = TypeVar('WK')
WV = TypeVar('WV')
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
def __init__(self, ref_cache):
self.ref_cache = ref_cache
self.weak_cache = WeakValueDictionary()
def __getitem__(self, key):
value = self.weak_cache[key]
self.ref_cache[key] = value
return value
def __setitem__(self, key, value):
self.weak_cache[key] = value
self.ref_cache[key] = value
def __delitem__(self, key):
del self.weak_cache[key]
try:
del self.ref_cache[key]
except KeyError:
pass
def __contains__(self, key):
return key in self.weak_cache
def __iter__(self):
return iter(self.weak_cache)
def __len__(self):
return len(self.weak_cache)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def pop(self, key, default=None):
if key in self:
value = self[key]
del self[key]
else:
value = default
return value
# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
__slots__ = ('data',)
_schema_: str = 'public'
_tablename_: Optional[str] = None
_columns_: dict[str, Column] = {}
# Cache to keep track of registered Rows
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
_key_: tuple[str, ...] = ()
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
table: RowTable = None
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
"""
Set table, _columns_, and _key_.
"""
if table is not None:
cls._tablename_ = table
if cls._tablename_ is not None:
columns = {}
for key, value in cls.__dict__.items():
if isinstance(value, Column):
columns[key] = value
cls._columns_ = columns
if not cls._key_:
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
if cls._cache_ is None:
cls._cache_ = WeakValueDictionary()
def __new__(cls, data):
# Registry pattern.
# Ensure each rowid always refers to a single Model instance
if data is not None:
rowid = cls._id_from_data(data)
cache = cls._cache_
if (row := cache.get(rowid, None)) is not None:
obj = row
else:
obj = cache[rowid] = super().__new__(cls)
else:
obj = super().__new__(cls)
return obj
@classmethod
def as_tuple(cls):
return (cls.table.identifier, ())
def __init__(self, data):
self.data = data
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
@classmethod
def bind(cls, connector: Connector):
if cls.table is None:
raise ValueError("Cannot bind abstract RowModel")
cls._connector = connector
cls.table.bind(connector)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@property
def _dict_(self):
return {key: self.data[key] for key in self._key_}
@property
def _rowid_(self):
return tuple(self.data[key] for key in self._key_)
def __repr__(self):
return "{}.{}({})".format(
self.table.schema,
self.table.name,
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
)
@classmethod
def _id_from_data(cls, data):
return tuple(data[key] for key in cls._key_)
@classmethod
def _dict_from_id(cls, rowid):
return dict(zip(cls._key_, rowid))
@classmethod
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
"""
Create or retrieve Row objects for each provided data row.
If the rows already exist in cache, updates the cached row.
"""
# TODO: Handle partial row data here somehow?
rows = [cls(data_row) for data_row in data_rows]
return rows
@classmethod
def _delete_rows(cls, *data_rows):
"""
Remove the given rows from cache, if they exist.
May be extended to handle object deletion.
"""
cache = cls._cache_
for data_row in data_rows:
rowid = cls._id_from_data(data_row)
cache.pop(rowid, None)
@classmethod
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
return await cls.table.create_row(*args, **kwargs)
@classmethod
def fetch_where(cls: Type[RowT], *args, **kwargs):
return cls.table.fetch_rows_where(*args, **kwargs)
@classmethod
async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
"""
Fetch the row with the given id, retrieving from cache where possible.
"""
row = cls._cache_.get(rowid, None) if cached else None
if row is None:
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
row = rows[0] if rows else None
if row is None:
cls._cache_[rowid] = cls(None)
elif row.data is None:
row = None
return row
@classmethod
async def fetch_or_create(cls, *rowid, **kwargs):
"""
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
"""
if rowid:
row = await cls.fetch(*rowid)
else:
rows = await cls.fetch_where(**kwargs).limit(1)
row = rows[0] if rows else None
if row is None:
creation_kwargs = kwargs
if rowid:
creation_kwargs.update(cls._dict_from_id(rowid))
row = await cls.create(**creation_kwargs)
return row
async def refresh(self: RowT) -> Optional[RowT]:
"""
Refresh this Row from data.
The return value may be `None` if the row was deleted.
"""
rows = await self.table.select_where(**self._dict_)
if not rows:
return None
else:
self.data = rows[0]
return self
async def update(self: RowT, **values) -> Optional[RowT]:
"""
Update this Row with the given values.
Internally passes the provided `values` to the `update` Query.
The return value may be `None` if the row was deleted.
"""
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
if not data:
return None
else:
return data[0]
async def delete(self: RowT) -> Optional[RowT]:
"""
Delete this Row.
"""
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
return data[0] if data is not None else None
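
A sketch of declaring and using a model (the Widget model and its widgets table are hypothetical). __init_subclass__ collects the Column descriptors, derives the primary key, and builds the backing RowTable:

    from data import RowModel, Integer, String

    class Widget(RowModel, table='widgets'):
        widgetid = Integer(primary=True)
        name = String()

    # After Widget.bind(connector):
    #   widget = await Widget.create(name='sprocket')
    #   widget = await Widget.fetch(widget.widgetid)  # served from cache
    #   await widget.update(name='gear')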

644
src/data/queries.py Normal file

@@ -0,0 +1,644 @@
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
from enum import Enum
from itertools import chain
from psycopg import AsyncConnection, AsyncCursor
from psycopg import sql
from psycopg.rows import DictRow
import logging
from .conditions import Condition
from .base import Expression, RawExpr
from .connector import Connector
logger = logging.getLogger(__name__)
TQueryT = TypeVar('TQueryT', bound='TableQuery')
SQueryT = TypeVar('SQueryT', bound='Select')
QueryResult = TypeVar('QueryResult')
class Query(Generic[QueryResult]):
"""
ABC for an executable query statement.
"""
__slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
_adapter: Callable[..., QueryResult]
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
self.connector: Optional[Connector] = connector
self.conn: Optional[AsyncConnection] = conn
self.cursor: Optional[AsyncCursor] = cursor
if row_adapter is not None:
self._adapter = row_adapter
else:
self._adapter = self._no_adapter
self.result: Optional[QueryResult] = None
def bind(self, connector: Connector):
self.connector = connector
return self
def with_cursor(self, cursor: AsyncCursor):
self.cursor = cursor
return self
def with_connection(self, conn: AsyncConnection):
self.conn = conn
return self
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def with_adapter(self, callable: Callable[..., QueryResult]):
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
# For this to work cleanly, callable should have arg type of QR1, not any
self._adapter = callable
return self
def with_no_adapter(self):
"""
Sets the adapter to the identity.
"""
self._adapter = self._no_adapter
return self
def one(self):
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
return self
def build(self) -> Expression:
raise NotImplementedError
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
query, values = self.build().as_tuple()
# TODO: Move logging out to a custom cursor
# logger.debug(
# f"Executing query ({query.as_string(cursor)}) with values {values}",
# extra={'action': "Query"}
# )
await cursor.execute(sql.Composed((query,)), values)
data = await cursor.fetchall()
self.result = self._adapter(*data)
return self.result
async def execute(self, cursor=None) -> QueryResult:
"""
Execute the query, optionally with the provided cursor, and return the result rows.
If no cursor is provided, and no cursor has been set with `with_cursor`,
the execution will create a new cursor from the connection and close it automatically.
"""
# Create a cursor if possible
cursor = cursor if cursor is not None else self.cursor
if self.cursor is None:
if self.conn is None:
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
async with self.connector.connection() as conn:
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data
def __await__(self):
return self.execute().__await__()
class TableQuery(Query[QueryResult]):
"""
ABC for an executable query statement expected to be run on a single table.
"""
__slots__ = (
'tableid',
'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group'
)
def __init__(self, tableid, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tableid: sql.Identifier = tableid
def options(self, **kwargs):
"""
Set some query options.
Default implementation does nothing.
Should be overridden to provide specific options.
"""
return self
class WhereMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.condition: Optional[Condition] = None
def where(self, *args: Condition, **kwargs):
"""
Add a Condition to the query.
Positional arguments should be Conditions,
and keyword arguments should be of the form `column=Value`,
where Value may be a Value-type or a literal value.
All provided Conditions will be and-ed together to create a new Condition.
TODO: Maybe just pass this verbatim to a condition.
"""
if args or kwargs:
condition = Condition.construct(*args, **kwargs)
if self.condition is not None:
condition = self.condition & condition
self.condition = condition
return self
@property
def _where_section(self) -> Optional[Expression]:
if self.condition is not None:
return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
else:
return None
class JOINTYPE(Enum):
LEFT = sql.SQL('LEFT JOIN')
RIGHT = sql.SQL('RIGHT JOIN')
INNER = sql.SQL('INNER JOIN')
OUTER = sql.SQL('OUTER JOIN')
FULLOUTER = sql.SQL('FULL OUTER JOIN')
class JoinMixin(TableQuery[QueryResult]):
__slots__ = ()
# TODO: Remember to add join slots to TableQuery
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._joins: list[Expression] = []
def join(self,
target: Union[str, Expression],
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
join_type: JOINTYPE = JOINTYPE.INNER,
natural=False):
available = (on is not None) + (using is not None) + natural
if available == 0:
raise ValueError("No conditions given for Query Join")
if available > 1:
raise ValueError("Exactly one join format must be given for Query Join")
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
if isinstance(target, str):
sections.append((sql.Identifier(target), ()))
else:
sections.append(target.as_tuple())
if on is not None:
sections.append((sql.SQL('ON'), ()))
sections.append(on.as_tuple())
elif using is not None:
sections.append((sql.SQL('USING'), ()))
if isinstance(using, Expression):
sections.append(using.as_tuple())
elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
sections.append((cols, ()))
else:
raise ValueError("Unrecognised 'using' type.")
elif natural:
sections.insert(0, (sql.SQL('NATURAL'), ()))
expr = RawExpr.join_tuples(*sections)
self._joins.append(expr)
return self
def leftjoin(self, *args, **kwargs):
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
@property
def _join_section(self) -> Optional[Expression]:
if self._joins:
return RawExpr.join(*self._joins)
else:
return None
class ExtraMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._extra: Optional[Expression] = None
def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
"""
Add an extra string, and optionally values, to this query.
The extra string is inserted after any condition, and before the limit.
"""
extra_expr = RawExpr(extra, values)
if self._extra is not None:
extra_expr = RawExpr.join(self._extra, extra_expr)
self._extra = extra_expr
return self
@property
def _extra_section(self) -> Optional[Expression]:
if self._extra is None:
return None
else:
return self._extra
class LimitMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._limit: Optional[int] = None
def limit(self, limit: int):
"""
Add a limit to this query.
"""
self._limit = limit
return self
@property
def _limit_section(self) -> Optional[Expression]:
if self._limit is not None:
return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
else:
return None
class FromMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._from: Optional[Expression] = None
def from_expr(self, _from: Expression):
self._from = _from
return self
@property
def _from_section(self) -> Optional[Expression]:
if self._from is not None:
expr, values = self._from.as_tuple()
return RawExpr(sql.SQL("FROM {}").format(expr), values)
else:
return None
class ORDER(Enum):
ASC = sql.SQL('ASC')
DESC = sql.SQL('DESC')
class NULLS(Enum):
FIRST = sql.SQL('NULLS FIRST')
LAST = sql.SQL('NULLS LAST')
class OrderMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._order: list[Expression] = []
def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
"""
Add a single sort expression to the query.
This method stacks.
"""
if isinstance(expr, Expression):
string, values = expr.as_tuple()
else:
string = sql.Identifier(expr)
values = ()
parts = [string]
if direction is not None:
parts.append(direction.value)
if nulls is not None:
parts.append(nulls.value)
order_string = sql.SQL(' ').join(parts)
self._order.append(RawExpr(order_string, values))
return self
@property
def _order_section(self) -> Optional[Expression]:
if self._order:
expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
return expr
else:
return None
class GroupMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._group: list[Expression] = []
def group_by(self, *exprs: Union[Expression, str]):
"""
Add one or more grouping expressions to the query.
This method stacks.
"""
for expr in exprs:
if isinstance(expr, Expression):
self._group.append(expr)
else:
self._group.append(RawExpr(sql.Identifier(expr)))
return self
@property
def _group_section(self) -> Optional[Expression]:
if self._group:
expr = RawExpr.join(*self._group, joiner=sql.SQL(', '))
expr.expr = sql.SQL("GROUP BY {}").format(expr.expr)
return expr
else:
return None
class Insert(ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table insert query.
"""
# TODO: Support ON CONFLICT for upserts
__slots__ = ('_columns', '_values', '_conflict')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[str, ...] = ()
self._values: tuple[tuple[Any, ...], ...] = ()
self._conflict: Optional[Expression] = None
def insert(self, columns, *values):
"""
Insert the given data.
Parameters
----------
columns: tuple[str]
Tuple of column names to insert.
values: tuple[tuple[Any, ...], ...]
Tuple of values to insert, corresponding to the columns.
"""
if not values:
raise ValueError("Cannot insert zero rows.")
if len(values[0]) != len(columns):
raise ValueError("Number of columns does not match length of values.")
self._columns = columns
self._values = values
return self
def on_conflict(self, ignore=False):
# TODO lots more we can do here
# Maybe return a Conflict object that can chain itself (not the query)
if ignore:
self._conflict = RawExpr(sql.SQL('DO NOTHING'))
return self
@property
def _conflict_section(self) -> Optional[Expression]:
if self._conflict is not None:
e, v = self._conflict.as_tuple()
expr = RawExpr(
sql.SQL("ON CONFLICT {}").format(
e
),
v
)
return expr
return None
def build(self):
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
single_value_str = sql.SQL('({})').format(
sql.SQL(',').join(sql.Placeholder() * len(self._columns))
)
values_str = sql.SQL(',').join(single_value_str * len(self._values))
# TODO: Check efficiency of inserting multiple values like this
# Also implement a Copy query
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
table=self.tableid,
columns=columns,
values_str=values_str
)
sections = [
RawExpr(base, tuple(chain(*self._values))),
self._conflict_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]):
"""
Select rows from a table matching provided conditions.
"""
__slots__ = ('_columns',)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[Expression, ...] = ()
def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
"""
Set the columns and expressions to select.
If none are given, selects all columns.
"""
cols: List[Expression] = []
if columns:
cols.extend(map(RawExpr, map(sql.Identifier, columns)))
if exprs:
for name, expr in exprs.items():
if isinstance(expr, str):
cols.append(
RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, sql.Composable):
cols.append(
RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, Expression):
value_expr, value_values = expr.as_tuple()
cols.append(RawExpr(
value_expr + sql.SQL(' AS ') + sql.Identifier(name),
value_values
))
if cols:
self._columns = (*self._columns, *cols)
return self
def build(self):
if not self._columns:
columns, columns_values = sql.SQL('*'), ()
else:
columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()
base = sql.SQL("SELECT {columns} FROM {table}").format(
columns=columns,
table=self.tableid
)
sections = [
RawExpr(base, columns_values),
self._join_section,
self._where_section,
self._group_section,
self._extra_section,
self._order_section,
self._limit_section,
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table delete query.
"""
# TODO: Cascade option for delete, maybe other options
# TODO: Require a where unless specifically disabled, for safety
def build(self):
base = sql.SQL("DELETE FROM {table}").format(
table=self.tableid,
)
sections = [
RawExpr(base),
self._where_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]):
__slots__ = (
'_set',
)
# TODO: Again, require a where unless specifically disabled
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set: List[Expression] = []
def set(self, **column_values: Union[Any, Expression]):
exprs: List[Expression] = []
for name, value in column_values.items():
if isinstance(value, Expression):
value_tup = value.as_tuple()
else:
value_tup = (sql.Placeholder(), (value,))
exprs.append(
RawExpr.join_tuples(
(sql.Identifier(name), ()),
value_tup,
joiner=sql.SQL(' = ')
)
)
self._set.extend(exprs)
return self
def build(self):
if not self._set:
raise ValueError("No columns provided to update.")
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
base = sql.SQL("UPDATE {table} SET {set}").format(
table=self.tableid,
set=set_expr
)
sections = [
RawExpr(base, set_values),
self._from_section,
self._where_section,
self._extra_section,
self._limit_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
# async def upsert(cursor, table, constraint, **values):
# """
# Insert or on conflict update.
# """
# valuedict = values
# keys, values = zip(*values.items())
#
# key_str = _format_insertkeys(keys)
# value_str, values = _format_insertvalues(values)
# update_key_str, update_key_values = _format_updatestr(valuedict)
#
# if not isinstance(constraint, str):
# constraint = ", ".join(constraint)
#
# await cursor.execute(
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
# table, key_str, value_str, constraint, update_key_str
# ),
# tuple((*values, *update_key_values))
# )
# return await cursor.fetchone()
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
# cursor = cursor or conn.cursor()
#
# # TODO: executemany or copy syntax now
# return execute_values(
# cursor,
# """
# UPDATE {table}
# SET {set_clause}
# FROM (VALUES {cast_row}%s)
# AS {temp_table}
# WHERE {where_clause}
# RETURNING *
# """.format(
# table=table,
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
# cast_row=cast_row + ',' if cast_row else '',
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
# ),
# values,
# fetch=True
# )
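
A sketch of composing a Select with the mixins above; build() returns a RawExpr whose as_tuple() yields the composed SQL and the flattened parameter tuple (executing the query would additionally need a connector, connection, or cursor):

    from psycopg import sql
    from data.queries import Select, ORDER

    query = (
        Select(sql.Identifier('user_profiles'))
        .select('profileid', 'name')
        .where(twitchid='12345')
        .order_by('created_at', direction=ORDER.DESC)
        .limit(5)
    )
    expr, values = query.build().as_tuple()
    # values == ('12345', 5)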

102
src/data/registry.py Normal file

@@ -0,0 +1,102 @@
from typing import Protocol, runtime_checkable, Optional
from psycopg import AsyncConnection
from .connector import Connector, Connectable
@runtime_checkable
class _Attachable(Connectable, Protocol):
def attach_to(self, registry: 'Registry'):
raise NotImplementedError
class Registry:
_attached: list[_Attachable] = []
_name: Optional[str] = None
def __init_subclass__(cls, name=None):
attached = []
for _, member in cls.__dict__.items():
if isinstance(member, _Attachable):
attached.append(member)
cls._attached = attached
cls._name = name or cls.__name__
def __init__(self, name=None):
self._conn: Optional[Connector] = None
self.name: str = name if name is not None else self._name
if self.name is None:
raise ValueError("A Registry must have a name!")
self.init_tasks = []
for member in self._attached:
member.attach_to(self)
def bind(self, connector: Connector):
self._conn = connector
for child in self._attached:
child.bind(connector)
def attach(self, attachable):
self._attached.append(attachable)
if self._conn is not None:
attachable.bind(self._conn)
return attachable
def init_task(self, coro):
"""
Initialisation tasks are run to setup the registry state.
These tasks will be run in the event loop, after connection to the database.
These tasks should be idempotent, as they may be run on reload and reconnect.
"""
self.init_tasks.append(coro)
return coro
async def init(self):
for task in self.init_tasks:
await task(self)
return self
class AttachableClass:
"""ABC for a default implementation of an Attachable class."""
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
@classmethod
def bind(cls, connector: Connector):
cls._connector = connector
connector.connect_hook(cls.on_connect)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@classmethod
async def on_connect(cls, connection: AsyncConnection):
pass
class Attachable:
"""ABC for a default implementation of an Attachable object."""
def __init__(self, *args, **kwargs):
self._connector: Optional[Connector] = None
self._registry: Optional[Registry] = None
def bind(self, connector: Connector):
self._connector = connector
connector.connect_hook(self.on_connect)
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
async def on_connect(self, connection: AsyncConnection):
pass
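
A sketch of a Registry subclass with an init task (MyData and the task body are illustrative). Attachables declared on the class body are collected by __init_subclass__; tasks registered with init_task run once init() is awaited after connecting:

    from data import Registry, Table

    class MyData(Registry, name='mydata'):
        scopes = Table('user_auth_scopes')

    mydata = MyData()

    @mydata.init_task
    async def seed(registry: Registry):
        pass  # idempotent setup work goes here

    # db.load_registry(mydata); later: await mydata.init()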

95
src/data/table.py Normal file

@@ -0,0 +1,95 @@
from typing import Optional
from psycopg.rows import DictRow
from psycopg import sql
from . import queries as q
from .connector import Connector
from .registry import Registry
class Table:
"""
Transparent interface to a single table structure in the database.
Contains standard methods to access the table.
"""
def __init__(self, name, *args, schema='public', **kwargs):
self.name: str = name
self.schema: str = schema
self.connector: Connector = None
@property
def identifier(self):
if self.schema == 'public':
return sql.Identifier(self.name)
else:
return sql.Identifier(self.schema, self.name)
def bind(self, connector: Connector):
self.connector = connector
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
if data:
return data[0]
else:
return None
def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
return q.Select(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
return q.Select(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
return q.Update(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
return q.Delete(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def insert(self, **column_values) -> q.Insert[DictRow]:
return q.Insert(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).insert(column_values.keys(), column_values.values())
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
return q.Insert(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).insert(*args, **kwargs)
# def update_many(self, *args, **kwargs):
# with self.conn:
# return update_many(self.identifier, *args, **kwargs)
# def upsert(self, *args, **kwargs):
# return upsert(self.identifier, *args, **kwargs)
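
A usage sketch, assuming db is an opened Database (which is a Connector). The query builders are awaitable, so a chain can be awaited directly:

    from data import Table

    scopes = Table('user_auth_scopes').bind(db)
    await scopes.insert(userid='1234', scope='channel:bot')
    rows = await scopes.select_where(userid='1234')
    await scopes.delete_where(userid='1234', scope='chat:read')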

133
src/datamodels.py Normal file

@@ -0,0 +1,133 @@
from data import Registry, RowModel, Table
from data.columns import String, Timestamp, Integer, Bool
class UserAuth(RowModel):
"""
Schema
======
CREATE TABLE user_auth(
userid TEXT PRIMARY KEY,
token TEXT NOT NULL,
refresh_token TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_auth'
_cache_ = {}
userid = String(primary=True)
token = String()
refresh_token = String()
created_at = Timestamp()
_timestamp = Timestamp()
class BotChannel(RowModel):
"""
Schema
======
CREATE TABLE bot_channels(
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
autojoin BOOLEAN DEFAULT true,
listen_redeems BOOLEAN,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'bot_channels'
_cache_ = {}
userid = String(primary=True)
autojoin = Bool()
listen_redeems = Bool()
joined_at = Timestamp()
_timestamp = Timestamp()
class UserProfile(RowModel):
"""
Schema
======
CREATE TABLE user_profiles(
profileid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
twitchid TEXT NOT NULL,
name TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_profiles'
_cache_ = {}
profileid = Integer(primary=True)
twitchid = String()
name = String()
created_at = Timestamp()
_timestamp = Timestamp()
class Communities(RowModel):
"""
Schema
======
CREATE TABLE communities(
communityid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
twitchid TEXT NOT NULL,
name TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities'
_cache_ = {}
communityid = Integer(primary=True)
twitchid = String()
name = String()
created_at = Timestamp()
_timestamp = Timestamp()
class Koan(RowModel):
"""
Schema
======
CREATE TABLE koans(
koanid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities ON UPDATE CASCADE ON DELETE CASCADE,
name TEXT NOT NULL,
message TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'koans'
_cache_ = {}
koanid = Integer(primary=True)
communityid = Integer()
name = String()
message = String()
created_at = Timestamp()
_timestamp = Timestamp()
class BotData(Registry):
user_auth = UserAuth.table
"""
CREATE TABLE user_auth_scopes(
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
scope TEXT NOT NULL
);
"""
user_auth_scopes = Table('user_auth_scopes')
bot_channels = BotChannel.table
user_profiles = UserProfile.table
communities = Communities.table
koans = Koan.table
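For reference, a sketch of how these models are driven elsewhere in this
commit (fetch_where, fetch_or_create and update come from the RowModel base;
the literal values are placeholders):

    channels = await BotChannel.fetch_where(autojoin=True)
    auth = await UserAuth.fetch_or_create('12345', token='tok', refresh_token='ref')
    if auth.token != 'tok':
        await auth.update(token='tok', refresh_token='ref')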

4
src/meta/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .args import args
from .crocbot import CrocBot
from .config import Conf, conf
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task

28
src/meta/args.py Normal file
View File

@@ -0,0 +1,28 @@
import argparse
from constants import CONFIG_FILE
# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
'--conf',
dest='config',
default=CONFIG_FILE,
help="Path to configuration file."
)
parser.add_argument(
'--host',
dest='host',
default='127.0.0.1',
help="IP address to run the websocket server on."
)
parser.add_argument(
'--port',
dest='port',
default='5001',
help="Port to run the websocket server on."
)
args = parser.parse_args()

105
src/meta/config.py Normal file
View File

@@ -0,0 +1,105 @@
import configparser as cfgp
from .args import args
class MapDotProxy:
"""
Allows dot access to an underlying Mappable object.
"""
__slots__ = ("_map", "_converter")
def __init__(self, mappable, converter=None):
self._map = mappable
self._converter = converter
def __getattribute__(self, key):
_map = object.__getattribute__(self, '_map')
if key == '_map':
return _map
if key in _map:
_converter = object.__getattribute__(self, '_converter')
if _converter:
return _converter(_map[key])
else:
return _map[key]
else:
return object.__getattribute__(_map, key)
def __getitem__(self, key):
return self._map.__getitem__(key)
class ConfigParser(cfgp.ConfigParser):
"""
Extension of base ConfigParser allowing optional
section option retrieval without defaults.
"""
def options(self, section, no_defaults=False, **kwargs):
if no_defaults:
try:
return list(self._sections[section].keys())
except KeyError:
raise cfgp.NoSectionError(section)
else:
return super().options(section, **kwargs)
class Conf:
def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile
self.config = ConfigParser(
converters={
"intlist": self._getintlist,
"list": self._getlist,
}
)
with open(configfile) as conff:
# Opening with read_file mainly to ensure the file exists
self.config.read_file(conff)
self.section_name = section_name if section_name in self.config else 'DEFAULT'
self.default = self.config["DEFAULT"]
self.section = MapDotProxy(self.config[self.section_name])
self.bot = self.section
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
more_to_read = self.section.getlist("ALSO_READ", [])
read = set()
while more_to_read:
to_read = more_to_read.pop(0)
read.add(to_read)
self.config.read(to_read)
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
if path not in read and path not in more_to_read]
more_to_read.extend(new_paths)
global conf
conf = self
def __getitem__(self, key):
return self.section[key].strip()
def __getattr__(self, section):
name = section.upper()
return self.config[name]
def get(self, name, fallback=None):
result = self.section.get(name, fallback)
return result.strip() if result else result
def _getintlist(self, value):
return [int(item.strip()) for item in value.split(',')]
def _getlist(self, value):
return [item.strip() for item in value.split(',')]
def write(self):
with open(self.configfile, 'w') as conffile:
self.config.write(conffile)
conf = Conf(args.config, 'CROCBOT')
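A small access sketch, assuming the config/example-bot.conf shipped in this
commit (its ALSO_READ chain pulls config/secrets.conf in transparently):

    conf = Conf('config/example-bot.conf', 'CROCBOT')
    prefix = conf.bot['prefix']       # plain section access
    owner = conf.get('owner_id')      # stripped value; None when the key is missing
    wsport = int(conf.bot['wsport'])  # values are strings, cast as needed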

157
src/meta/crocbot.py Normal file
View File

@@ -0,0 +1,157 @@
import logging
from typing import Optional
from twitchio.authentication import UserTokenPayload
from twitchio.ext import commands
from twitchio import Scopes, eventsub
from data import Database
from datamodels import BotData, UserAuth, BotChannel
from .config import Conf
logger = logging.getLogger(__name__)
class CrocBot(commands.Bot):
def __init__(self, *args, config: Conf, dbconn: Database, **kwargs):
kwargs.setdefault('client_id', config.bot['client_id'])
kwargs.setdefault('client_secret', config.bot['client_secret'])
kwargs.setdefault('bot_id', config.bot['bot_id'])
kwargs.setdefault('prefix', config.bot['prefix'])
super().__init__(*args, **kwargs)
self.config = config
self.dbconn = dbconn
self.data: BotData = dbconn.load_registry(BotData())
self.joined: dict[str, BotChannel] = {}
async def event_ready(self):
# logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
logger.info("Logged in as %s", self.bot_id)
async def setup_hook(self):
await self.data.init()
# Get all current bot channels
channels = await BotChannel.fetch_where(autojoin=True)
# Join the channels
await self.join_channels(*channels)
# Build bot account's own url
scopes = Scopes((
Scopes.user_read_chat,
Scopes.user_write_chat,
Scopes.user_bot,
Scopes.channel_read_redemptions,
Scopes.channel_manage_redemptions,
Scopes.channel_bot,
))
url = self.get_auth_url(scopes)
logger.info("Bot account authorisation url: %s", url)
# Build everyone else's url
scopes = Scopes((
Scopes.channel_bot,
Scopes.channel_read_redemptions,
Scopes.channel_manage_redemptions,
))
url = self.get_auth_url(scopes)
logger.info("User account authorisation url: %s", url)
logger.info("Finished setup")
def get_auth_url(self, scopes: Optional[Scopes] = None):
if scopes is None:
scopes = Scopes((Scopes.channel_bot,))
url = self._adapter.get_authorization_url(scopes=scopes)
return url
async def join_channels(self, *channels: BotChannel):
"""
Register webhook subscriptions to the given channel(s).
"""
# TODO: If channels are already joined, unsubscribe
# TODO: Determine (or switch?) whether to use webhook or websocket
for channel in channels:
sub = None
try:
sub = eventsub.ChatMessageSubscription(
broadcaster_user_id=channel.userid,
user_id=self.bot_id,
)
resp = await self.subscribe_websocket(sub)
logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
if channel.listen_redeems:
sub = eventsub.ChannelPointsRedeemAddSubscription(
broadcaster_user_id=channel.userid
)
resp = await self.subscribe_websocket(sub, as_bot=False, token_for=channel.userid)
logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
self.joined[channel.userid] = channel
except Exception:
logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)
async def event_oauth_authorized(self, payload: UserTokenPayload):
logger.debug("Oauth flow authorization with payload %s", repr(payload))
# Save the token and scopes and update internal authorisations
resp = await self.add_token(payload.access_token, payload.refresh_token)
if resp.user_id is None:
logger.warning(
"Oauth flow recieved with no user_id. Payload was: %s",
repr(payload)
)
return
# If the scopes authorised included channel:bot, ensure a BotChannel exists
# And join it if needed
if Scopes.channel_bot.value in resp.scopes:
bot_channel = await BotChannel.fetch_or_create(
resp.user_id,
autojoin=True,
)
if bot_channel.autojoin:
await self.join_channels(bot_channel)
logger.info("Oauth flow authorization complete for payload %s", repr(payload))
async def add_token(self, token: str, refresh: str):
# Update the tokens in internal cache
# This also validates the token
# And hopefully gets the userid and scopes
resp = await super().add_token(token, refresh)
if resp.user_id is None:
logger.warning(
"Added a token with no user_id. Response was: %s",
repr(resp)
)
return resp
userid = resp.user_id
new_scopes = resp.scopes
# Save the token and scopes to data
# Wrap this in a transaction so if it fails halfway we rollback correctly
async with self.dbconn.connection() as conn:
self.dbconn.conn = conn
async with conn.transaction():
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
if row.token != token or row.refresh_token != refresh:
await row.update(token=token, refresh_token=refresh)
await self.data.user_auth_scopes.delete_where(userid=userid)
await self.data.user_auth_scopes.insert_many(
('userid', 'scope'),
*((userid, scope) for scope in new_scopes)
)
logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
return resp
async def load_tokens(self, path: str | None = None):
for row in await UserAuth.fetch_where():
await self.add_token(row.token, row.refresh_token)
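A minimal bootstrap sketch tying the pieces together (the Database constructor
argument and the start() entrypoint are assumptions; the real wiring
presumably lives in the application's main module):

    import asyncio
    from data import Database
    from meta import CrocBot, conf

    async def main():
        # Assumed: the connection string is stored under the [DATA] section's 'args' key
        dbconn = Database(conf.data['args'])
        bot = CrocBot(config=conf, dbconn=dbconn)
        await bot.start()

    asyncio.run(main())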

466
src/meta/logger.py Normal file
View File

@@ -0,0 +1,466 @@
import sys
import logging
import asyncio
from typing import List, Optional
from logging.handlers import QueueListener, QueueHandler
import queue
import multiprocessing
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar
import aiohttp
# discord.py is used only for the webhook-based live logger below
import discord
from discord import Webhook, File

from .config import conf
from utils.lib import utc_now
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
log_logger = logging.getLogger(__name__)
log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT")
# TODO merge into TwitchIO context
context: ContextVar[Optional[str]] = ContextVar('context', default=None)
def set_logging_context(
context: Optional[str] = None,
action: Optional[str] = None,
stack: Optional[tuple[str, ...]] = None
):
"""
Statically set the logging context variables to the given values.
If `action` is given, pushes it onto the `log_action_stack`.
"""
if context is not None:
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
@contextmanager
def logging_context(context=None, action=None, stack=None):
"""
Context manager for executing a block of code in a given logging context.
This context manager should only be used around synchronous code.
This is because async code *may* get cancelled or externally garbage collected,
in which case the finally block will be executed in the wrong context.
See https://github.com/python/cpython/issues/93740
This can be refactored nicely if this gets merged:
https://github.com/python/cpython/pull/99634
(It will not necessarily break on async code,
if the async code can be guaranteed to clean up in its own context.)
"""
if context is not None:
oldcontext = log_context.get()
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
try:
yield
finally:
if context is not None:
log_context.set(oldcontext)
if stack is not None or action is not None:
log_action_stack.set(astack)
def with_log_ctx(isolate=True, **kwargs):
"""
Execute a coroutine inside a given logging context.
If `isolate` is true, ensures that context does not leak
outside the coroutine.
If `isolate` is false, just statically set the context,
which will leak unless the coroutine is
called in an externally copied context.
"""
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
if isolate:
with logging_context(**kwargs):
# Task creation will synchronously copy the context
# This is gc safe
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
return await task
else:
# This will leak context changes
set_logging_context(**kwargs)
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
# For backwards compatibility
log_wrap = with_log_ctx
def persist_task(task_collection: set):
"""
Coroutine decorator that ensures the coroutine is scheduled as a task
and added to the given task_collection for strong reference
when it is called.
This is just a hack to handle discord.py events potentially
being unexpectedly garbage collected.
Since this also implicitly schedules the coroutine as a task when it is called,
the coroutine will also be run inside an isolated context.
"""
def decorator(coro):
@wraps(coro)
async def wrapped(*w_args, **w_kwargs):
name = f"persisted-{coro.__name__}"
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
task_collection.add(task)
task.add_done_callback(lambda f: task_collection.discard(f))
            await task
        return wrapped
    return decorator
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str:
cmap = {
'%(black)': COLOR_SEQ % BLACK,
'%(red)': COLOR_SEQ % RED,
'%(green)': COLOR_SEQ % GREEN,
'%(yellow)': COLOR_SEQ % YELLOW,
'%(blue)': COLOR_SEQ % BLUE,
'%(magenta)': COLOR_SEQ % MAGENTA,
'%(cyan)': COLOR_SEQ % CYAN,
'%(white)': COLOR_SEQ % WHITE,
'%(reset)': RESET_SEQ,
'%(bold)': BOLD_SEQ,
}
for key, value in cmap.items():
fmt = fmt.replace(key, value)
return fmt
log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
'%(cyan)%(app)-15s%(reset)|' +
'%(cyan)%(context)-24s%(reset)|' +
'%(cyan)%(actionstr)-22s%(reset)|' +
' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format)
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=log_format,
# datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
class ExactLevelFilter(logging.Filter):
def __init__(self, target_level, name=""):
super().__init__(name)
self.target_level = target_level
def filter(self, record):
return (record.levelno == self.target_level)
class ThreadFilter(logging.Filter):
def __init__(self, thread_name):
super().__init__("")
self.thread = thread_name
def filter(self, record):
# non-zero return means we log this message
return 1 if record.threadName == self.thread else 0
class ContextInjection(logging.Filter):
def filter(self, record):
# These guards are to allow override through _extra
# And to ensure the injection is idempotent
if not hasattr(record, 'context'):
record.context = log_context.get()
if not hasattr(record, 'actionstr'):
action_stack = log_action_stack.get()
if hasattr(record, 'action'):
action_stack = (*action_stack, record.action)
if action_stack:
record.actionstr = ''.join(action_stack)
else:
record.actionstr = "Unknown Action"
if not hasattr(record, 'app'):
record.app = log_app.get()
if not hasattr(record, 'ctx'):
if ctx := context.get():
record.ctx = repr(ctx)
else:
record.ctx = None
if getattr(record, 'with_ctx', False) and record.ctx:
record.ctxstr = '\n' + record.ctx
else:
record.ctxstr = ""
return True
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)
class LocalQueueHandler(QueueHandler):
def _emit(self, record: logging.LogRecord) -> None:
# Removed the call to self.prepare(), handle task cancellation
try:
self.enqueue(record)
except asyncio.CancelledError:
raise
except Exception:
self.handleError(record)
class WebHookHandler(logging.StreamHandler):
def __init__(self, webhook_url, prefix="", batch=True, loop=None):
super().__init__()
self.webhook_url = webhook_url
self.prefix = prefix
self.batched = ""
self.batch = batch
self.loop = loop
self.batch_delay = 10
self.batch_task = None
self.last_batched = None
self.waiting = []
self.bucket = Bucket(20, 40)
self.ignored = 0
self.session = None
self.webhook = None
def get_loop(self):
if self.loop is None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
return self.loop
    def emit(self, record):
        # Format eagerly (on the logging thread) so asctime and message are populated
        self.format(record)
        self.get_loop().call_soon_threadsafe(self._post, record)
def _post(self, record):
if self.session is None:
self.setup()
asyncio.create_task(self.post(record))
def setup(self):
self.session = aiohttp.ClientSession()
self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
async def post(self, record):
if record.context == 'Webhook Logger':
# Don't livelog livelog errors
# Otherwise we recurse and Cloudflare hates us
return
log_context.set("Webhook Logger")
log_action_stack.set(("Logging",))
log_app.set(record.app)
try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
context = f"\n# Context: {record.ctx}" if record.ctx else ""
message = f"{header}\n{record.msg}{context}"
if len(message) > 1900:
as_file = True
else:
as_file = False
message = "```md\n{}\n```".format(message)
# Post the log message(s)
if self.batch:
if len(message) > 1500:
await self._send_batched_now()
await self._send(message, as_file=as_file)
else:
self.batched += message
if len(self.batched) + len(message) > 1500:
await self._send_batched_now()
else:
asyncio.create_task(self._schedule_batched())
else:
await self._send(message, as_file=as_file)
except Exception as ex:
print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)
async def _schedule_batched(self):
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
# noop, don't reschedule if it is already scheduled
return
try:
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
await self.batch_task
await self._send_batched()
except asyncio.CancelledError:
return
except Exception as ex:
print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)
async def _send_batched_now(self):
if self.batch_task is not None and not self.batch_task.done():
self.batch_task.cancel()
self.last_batched = None
await self._send_batched()
async def _send_batched(self):
if self.batched:
batched = self.batched
self.batched = ""
await self._send(batched)
async def _send(self, message, as_file=False):
try:
self.bucket.request()
except BucketOverFull:
# Silently ignore
self.ignored += 1
return
except BucketFull:
logger.warning(
"Can't keep up! "
f"Ignoring records on live-logger {self.webhook.id}."
)
self.ignored += 1
return
else:
if self.ignored > 0:
logger.warning(
"Can't keep up! "
f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
)
self.ignored = 0
try:
if as_file or len(message) > 1900:
with StringIO(message) as fp:
fp.seek(0)
await self.webhook.send(
f"{self.prefix}\n`{message.splitlines()[0]}`",
file=File(fp, filename="logs.md"),
username=log_app.get()
)
else:
await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
except discord.HTTPException:
logger.exception(
"Live logger errored. Slowing down live logger."
)
self.bucket.fill()
handlers = []
if webhook := conf.logging['general_log']:
handler = WebHookHandler(webhook, batch=True)
handlers.append(handler)
if webhook := conf.logging['warning_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True)
handler.addFilter(ExactLevelFilter(logging.WARNING))
handler.setLevel(logging.WARNING)
handlers.append(handler)
if webhook := conf.logging['error_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True)
handler.setLevel(logging.ERROR)
handlers.append(handler)
if webhook := conf.logging['critical_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False)
handler.setLevel(logging.CRITICAL)
handlers.append(handler)
def make_queue_handler(queue):
qhandler = QueueHandler(queue)
qhandler.setLevel(logging.INFO)
qhandler.addFilter(ContextInjection())
return qhandler
def setup_main_logger(multiprocess=False):
q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
if handlers:
# First create a separate loop to run the handlers on
import threading
def run_loop(loop):
asyncio.set_event_loop(loop)
try:
loop.run_forever()
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=lambda: run_loop(loop))
loop_thread.daemon = True
loop_thread.start()
for handler in handlers:
handler.loop = loop
qhandler = make_queue_handler(q)
# qhandler.addFilter(ThreadFilter('MainThread'))
logger.addHandler(qhandler)
listener = QueueListener(
q, *handlers, respect_handler_level=True
)
listener.start()
return q
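A short sketch of the intended call patterns for this module (the task body is
illustrative; all names come from the code above):

    setup_main_logger()

    # Synchronous block executed under a specific context and action stack
    with logging_context(context="ctx: startup", action="Init"):
        logger.info("This record carries 'Init' on its action stack")

    # Coroutine isolated in its own context; 'Join' is pushed for its duration
    @with_log_ctx(action="Join")
    async def join_flow():
        ...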

15
src/modules/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
this_package = 'modules'
active = [
'.vstudio',
]
def prepare(bot):
for ext in active:
bot.load_module(this_package + ext)
async def setup(bot):
for ext in active:
await bot.load_module(this_package + ext)

68
src/sockets.py Normal file
View File

@@ -0,0 +1,68 @@
from abc import ABC
from collections import defaultdict
import json
from typing import Any
import logging
import websockets
logger = logging.getLogger(__name__)
class Channel(ABC):
"""
A channel is a stateful connection handler for a group of connected websockets.
"""
name = "Root Channel"
def __init__(self, **kwargs):
self.connections = set()
@property
def empty(self):
return not self.connections
async def on_connection(self, websocket: websockets.WebSocketServerProtocol, event: dict[str, Any]):
logger.info(f"Channel '{self.name}' attached new connection {websocket=} {event=}")
self.connections.add(websocket)
async def del_connection(self, websocket: websockets.WebSocketServerProtocol):
logger.info(f"Channel '{self.name}' dropped connection {websocket=}")
self.connections.discard(websocket)
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message):
raise NotImplementedError
async def send_event(self, event, websocket=None):
message = json.dumps(event)
if not websocket:
for ws in self.connections:
await ws.send(message)
else:
await websocket.send(message)
channels = {}
def register_channel(name, channel: Channel):
channels[name] = channel
async def root_handler(websocket: websockets.WebSocketServerProtocol):
message = await websocket.recv()
event = json.loads(message)
if event.get('type', None) != 'init':
raise ValueError("Received Websocket connection with no init.")
if (channel_name := event.get('channel', None)) not in channels:
raise ValueError(f"Received Init for unhandled channel {channel_name=}")
channel = channels[channel_name]
try:
await channel.on_connection(websocket, event)
async for message in websocket:
await channel.handle_message(websocket, message)
finally:
await channel.del_connection(websocket)
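Wiring a concrete channel into a server might look like the following sketch
(EchoChannel and the serve() wrapper are illustrative; it assumes a websockets
version whose handlers accept a single argument, matching root_handler):

    import asyncio

    class EchoChannel(Channel):
        name = "Echo Channel"

        async def handle_message(self, websocket, message):
            # Fan the message back out to every connection on this channel
            await self.send_event({'type': 'echo', 'data': message})

    register_channel('echo', EchoChannel())

    async def serve():
        async with websockets.serve(root_handler, 'localhost', 4343):
            await asyncio.Future()  # run until cancelled

Clients are expected to open with an init event such as
{"type": "init", "channel": "echo"} before any other traffic.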

88
src/utils/lib.py Normal file
View File

@@ -0,0 +1,88 @@
import re
import datetime as dt
def strfdelta(delta: dt.timedelta, sec=False, minutes=True, short=False) -> str:
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's' # type: ignore
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
    # Hours are only a middle term when minutes/seconds follow; the final unit is appended below
    if len(output) > 2 and (output[0][0] != 0 or output[1][0] != 0):
        reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def utc_now() -> dt.datetime:
"""
Return the current timezone-aware utc timestamp.
"""
    return dt.datetime.now(dt.timezone.utc)
def replace_multiple(format_string, mapping):
"""
    Substitutes the keys of `mapping` with their corresponding values.
Substitution is non-chained, and done in a single pass via regex.
"""
if not mapping:
raise ValueError("Empty mapping passed.")
keys = list(mapping.keys())
pattern = '|'.join(f"({key})" for key in keys)
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
return string
def parse_dur(time_str):
"""
Parses a user provided time duration string into a number of seconds.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
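Some illustrative values for the helpers above (outputs follow from the
definitions; note parse_dur only reads the leading letter of each unit):

    >>> strfdelta(dt.timedelta(days=1, hours=2, minutes=3))
    '1 day 2 hours and 3 minutes'
    >>> parse_dur("1d 2h 30m")
    95400
    >>> replace_multiple("{name} says {greeting}", {"{name}": "croc", "{greeting}": "hi"})
    'croc says hi'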

166
src/utils/ratelimits.py Normal file
View File

@@ -0,0 +1,166 @@
import asyncio
import time
import logging
from cachetools import TTLCache
logger = logging.getLogger()
class BucketFull(Exception):
"""
Throw when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Throw when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.monotonic()
self._last_full = False
self._wait_lock = asyncio.Lock()
@property
def full(self) -> bool:
"""
Return whether the bucket is 'full',
that is, whether an immediate request against the bucket will raise `BucketFull`.
"""
self._leak()
return self._level + 1 > self.max_level
@property
def overfull(self):
self._leak()
return self._level > self.max_level
@property
def delay(self):
self._leak()
        if self._level + 1 > self.max_level:
            # Excess level divided by the leak rate (levels per second) gives seconds to wait
            delay = (self._level + 1 - self.max_level) / self.leak_rate
else:
delay = 0
return delay
def _leak(self):
if self._level:
elapsed = time.monotonic() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.monotonic()
def request(self):
self._leak()
if self._level > self.max_level:
raise BucketOverFull
elif self._level == self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
def fill(self):
self._leak()
self._level = max(self._level, self.max_level + 1)
async def wait(self):
"""
Wait until the bucket has room.
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
"""
# Wrapped in a lock so that waiters are correctly handled in wait-order
# Otherwise multiple waiters will have the same delay,
# and race for the wakeup after sleep.
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
async with self._wait_lock:
# We do this in a loop in case asyncio.sleep throws us out early,
# or a synchronous request overflows the bucket while we are waiting.
while self.full:
await asyncio.sleep(self.delay)
async def wrapped(self, coro):
await self.wait()
self.request()
        return await coro
class RateLimit:
    def __init__(self, max_level, empty_time, error=None, cache=None):
        self.max_level = max_level
        self.empty_time = empty_time
        self.error = error or "Too many requests, please slow down!"
        # A fresh cache per instance avoids sharing a mutable default between limiters
        self.buckets = cache if cache is not None else TTLCache(1000, 60 * 60)
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
bucket.request()
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator
async def limit_concurrency(aws, limit):
"""
Run provided awaitables concurrently,
ensuring that no more than `limit` are running at once.
"""
aws = iter(aws)
aws_ended = False
pending = set()
count = 0
logger.debug("Starting limited concurrency executor")
while pending or not aws_ended:
while len(pending) < limit and not aws_ended:
aw = next(aws, None)
if aw is None:
aws_ended = True
else:
pending.add(asyncio.create_task(aw))
count += 1
if not pending:
break
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
while done:
yield done.pop()
logger.debug(f"Completed {count} tasks")