Initial Creation from bot template.

commit 2e8d2555d5
2025-06-05 19:35:46 +10:00
50 changed files with 6751 additions and 0 deletions

151
.gitignore vendored Normal file

@@ -0,0 +1,151 @@
src/modules/test/*
pending-rewrite/
logs/*
notes/*
tmp/*
output/*
locales/domains
.idea/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
config/**

0
data/.gitignore vendored Normal file

34
data/schema.sql Normal file

@@ -0,0 +1,34 @@
-- Metadata {{{
CREATE TABLE VersionHistory(
version INTEGER NOT NULL,
time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
author TEXT
);
INSERT INTO VersionHistory (version, author) VALUES (1, 'Initial Creation');
CREATE OR REPLACE FUNCTION update_timestamp_column()
RETURNS TRIGGER AS $$
BEGIN
NEW._timestamp = (now() at time zone 'utc');
RETURN NEW;
END;
$$ language 'plpgsql';
-- }}}
-- App metadata {{{
CREATE TABLE app_config(
appname TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE bot_config(
appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE,
sponsor_prompt TEXT,
sponsor_message TEXT,
default_skin TEXT
);
-- }}}
-- vim: set fdm=marker:
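
The update_timestamp_column() function above is defined but never attached to a table in this initial schema. A hedged sketch of how a later migration might wire it up, assuming a hypothetical table carrying a _timestamp column:

CREATE TRIGGER some_table_timestamp
    BEFORE UPDATE ON some_table
    FOR EACH ROW
    EXECUTE PROCEDURE update_timestamp_column();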

6
requirements.txt Normal file

@@ -0,0 +1,6 @@
cachetools
configparser
discord.py[voice]
iso8601
psycopg[pool]
pytz

12
scripts/start_bot.py Executable file

@@ -0,0 +1,12 @@
#!/bin/python3
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
if __name__ == '__main__':
from bot import _main
_main()

35
scripts/start_debug.py Executable file

@@ -0,0 +1,35 @@
#!/bin/python3
import sys
import os
import tracemalloc
import asyncio
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
tracemalloc.start()
def loop_exception_handler(loop, context):
print(context)
task: asyncio.Task = context.get('task', None)
if task is not None:
addendum = f"<Task name='{task.get_name()}' stack='{task.get_stack()}'>"
message = context.get('message', '')
context['message'] = ' '.join((message, addendum))
loop.default_exception_handler(context)
def main():
loop = asyncio.get_event_loop()
loop.set_exception_handler(loop_exception_handler)
loop.set_debug(enabled=True)
from bot import _main
_main()
if __name__ == '__main__':
main()

114
src/bot.py Normal file

@@ -0,0 +1,114 @@
import asyncio
import logging
import aiohttp
import discord
from discord.ext import commands
from meta import LionBot, conf, sharding, appname
from meta.app import shardname
from meta.logger import log_context, log_action_stack, setup_main_logger
from meta.context import ctx_bot
from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus
from data import Database
from constants import DATA_VERSION
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
logging.getLogger(name).setLevel(conf.logging_levels[name])
logging_queue = setup_main_logger()
logger = logging.getLogger(__name__)
db = Database(conf.data['args'])
async def _data_monitor() -> ComponentStatus:
"""
Component monitor callback for the database.
"""
data = {
'stats': str(db.pool.get_stats())
}
if not db.pool._opened:
level = StatusLevel.WAITING
info = "(WAITING) Database Pool is not opened."
elif db.pool._closed:
level = StatusLevel.ERRORED
info = "(ERROR) Database Pool is closed."
else:
level = StatusLevel.OKAY
info = "(OK) Database Pool statistics: {stats}"
return ComponentStatus(level, info, info, data)
async def main():
log_action_stack.set(("Initialising",))
logger.info("Initialising StudyLion")
intents = discord.Intents.all()
intents.members = True
intents.message_content = True
intents.presences = False
async with db.open():
version = await db.version()
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error)
raise RuntimeError(error)
async with aiohttp.ClientSession() as session:
async with LionBot(
command_prefix='!leo!',
intents=intents,
appname=appname,
shardname=shardname,
db=db,
config=conf,
initial_extensions=[
'core',
'modules',
],
web_client=session,
testing_guilds=conf.bot.getintlist('admin_guilds'),
shard_id=sharding.shard_number,
shard_count=sharding.shard_count,
help_command=None,
proxy=conf.bot.get('proxy', None),
chunk_guilds_at_startup=False,
) as lionbot:
ctx_bot.set(lionbot)
lionbot.system_monitor.add_component(
ComponentMonitor('Database', _data_monitor)
)
try:
log_context.set(f"APP: {appname}")
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
await lionbot.start(conf.bot['TOKEN'])
except asyncio.CancelledError:
log_context.set(f"APP: {appname}")
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
def _main():
from signal import SIGINT, SIGTERM
loop = asyncio.get_event_loop()
main_task = asyncio.ensure_future(main())
for signal in [SIGINT, SIGTERM]:
loop.add_signal_handler(signal, main_task.cancel)
try:
loop.run_until_complete(main_task)
finally:
loop.close()
logging.shutdown()
if __name__ == '__main__':
_main()

6
src/constants.py Normal file

@@ -0,0 +1,6 @@
CONFIG_FILE = "config/bot.conf"
DATA_VERSION = 1
MAX_COINS = 2147483647 - 1
HINT_ICON = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png"

6
src/core/__init__.py Normal file

@@ -0,0 +1,6 @@
async def setup(bot):
from .cog import CoreCog
await bot.add_cog(CoreCog(bot))
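
This is discord.py's standard extension entry point: the 'core' entry in LionBot's initial_extensions (src/bot.py above) resolves to this module, and discord.py awaits its setup() on load. A minimal sketch of the equivalent manual call:

await bot.load_extension('core')  # imports src/core/__init__.py and awaits setup(bot)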

76
src/core/cog.py Normal file

@@ -0,0 +1,76 @@
import logging
from typing import Optional
from collections import defaultdict
from weakref import WeakValueDictionary
import discord
import discord.app_commands as appcmd
from meta import LionBot, LionCog, LionContext
from meta.app import shardname, appname
from meta.logger import log_wrap
from utils.lib import utc_now
from .data import CoreData
logger = logging.getLogger(__name__)
class keydefaultdict(defaultdict):
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
else:
ret = self[key] = self.default_factory(key)
return ret
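# Note: unlike collections.defaultdict, whose factory takes no arguments,
# keydefaultdict passes the missing key to the factory. A quick sketch with a
# hypothetical factory:
#
#     mentions = keydefaultdict(lambda key: f"</{key}:0>")
#     mentions['help']  # computed as '</help:0>' and cached on first access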
class CoreCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = CoreData()
bot.db.load_registry(self.data)
self.app_config: Optional[CoreData.AppConfig] = None
self.bot_config: Optional[CoreData.BotConfig] = None
self.app_cmd_cache: list[discord.app_commands.AppCommand] = []
self.cmd_name_cache: dict[str, discord.app_commands.AppCommand] = {}
self.mention_cache: dict[str, str] = keydefaultdict(self.mention_cmd)
async def cog_load(self):
# Fetch (and possibly create) core data rows.
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
# Load the app command cache
await self.reload_appcmd_cache()
async def reload_appcmd_cache(self):
# Reset the cache first so repeated reloads do not accumulate duplicates.
self.app_cmd_cache = []
for guildid in self.bot.testing_guilds:
self.app_cmd_cache += await self.bot.tree.fetch_commands(guild=discord.Object(guildid))
self.app_cmd_cache += await self.bot.tree.fetch_commands()
self.cmd_name_cache = {cmd.name: cmd for cmd in self.app_cmd_cache}
self.mention_cache = self._mention_cache_from(self.app_cmd_cache)
def _mention_cache_from(self, cmds: list[appcmd.AppCommand | appcmd.AppCommandGroup]):
cache = keydefaultdict(self.mention_cmd)
for cmd in cmds:
cache[cmd.qualified_name if isinstance(cmd, appcmd.AppCommandGroup) else cmd.name] = cmd.mention
subcommands = [option for option in cmd.options if isinstance(option, appcmd.AppCommandGroup)]
if subcommands:
subcache = self._mention_cache_from(subcommands)
cache |= subcache
return cache
def mention_cmd(self, name: str):
"""
Create an application command mention for the given name.
If the name is not found in the cache, creates a 'fake' mention with an invalid id.
"""
if name in self.mention_cache:
mention = self.mention_cache[name]
else:
mention = f"</{name}:1110834049204891730>"
return mention

45
src/core/data.py Normal file

@@ -0,0 +1,45 @@
from enum import Enum
from itertools import chain
from psycopg import sql
from cachetools import TTLCache
import discord
from meta import conf
from meta.logger import log_wrap
from data import Table, Registry, Column, RowModel, RegisterEnum
from data.models import WeakCache
from data.columns import Integer, String, Bool, Timestamp
class CoreData(Registry, name="core"):
class AppConfig(RowModel):
"""
Schema
------
CREATE TABLE app_config(
appname TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
_tablename_ = 'app_config'
appname = String(primary=True)
created_at = Timestamp()
class BotConfig(RowModel):
"""
Schema
------
CREATE TABLE bot_config(
appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE,
sponsor_prompt TEXT,
sponsor_message TEXT,
default_skin TEXT
);
"""
_tablename_ = 'bot_config'
appname = String(primary=True)
default_skin = String()
sponsor_prompt = String()
sponsor_message = String()

9
src/data/__init__.py Normal file

@@ -0,0 +1,9 @@
from .conditions import Condition, condition, NULL
from .database import Database
from .models import RowModel, RowTable, WeakCache
from .table import Table
from .base import Expression, RawExpr
from .columns import ColumnExpr, Column, Integer, String
from .registry import Registry, AttachableClass, Attachable
from .adapted import RegisterEnum
from .queries import ORDER, NULLS, JOINTYPE

40
src/data/adapted.py Normal file

@@ -0,0 +1,40 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry
class RegisterEnum(Attachable):
def __init__(self, enum, name: Optional[str] = None, mapper=None):
super().__init__()
self.enum = enum
self.name = name or enum.__name__
self.mapping = mapper(enum) if mapper is not None else self._mapper()
def _mapper(self):
return {m: m.value[0] for m in self.enum}
def attach_to(self, registry: Registry):
self._registry = registry
registry.init_task(self.on_init)
return self
async def on_init(self, registry: Registry):
connector = registry._conn
if connector is None:
raise ValueError("Cannot initialise without connector!")
connector.connect_hook(self.connection_hook)
# await connector.refresh_pool()
# The below may be somewhat dangerous
# But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None:
raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
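
In use, a RegisterEnum is attached inside a Registry body (see src/data/registry.py below); on registry init it installs a connection hook that adapts the Python Enum to the matching Postgres enum type on every pooled connection. A hedged sketch with a hypothetical enum and registry, noting that the default mapper sends value[0] to Postgres, so members are defined with tuple values:

from enum import Enum

class TransactionType(Enum):
    PURCHASE = ('purchase',)
    REFUND = ('refund',)

class EconomyData(Registry, name='economy'):
    _transaction_type = RegisterEnum(TransactionType, name='TransactionType')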

45
src/data/base.py Normal file

@@ -0,0 +1,45 @@
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable
from itertools import chain
from psycopg import sql
@runtime_checkable
class Expression(Protocol):
__slots__ = ()
@abstractmethod
def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
raise NotImplementedError
class RawExpr(Expression):
__slots__ = ('expr', 'values')
expr: sql.Composable
values: tuple[Any, ...]
def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
self.expr = expr
self.values = values
def as_tuple(self):
return (self.expr, self.values)
@classmethod
def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
"""
Join a sequence of Expressions into a single RawExpr.
"""
tups = (
expression.as_tuple()
for expression in expressions
)
return cls.join_tuples(*tups, joiner=joiner)
@classmethod
def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
exprs, values = zip(*tuples)
expr = joiner.join(exprs)
value = tuple(chain(*values))
return cls(expr, value)
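
Every Expression reduces to a (sql.Composable, values) pair through as_tuple(), and join/join_tuples splice the SQL fragments and their bound values in parallel. A small sketch:

from psycopg import sql

# Compose the fragment "coins" + %s, carrying its bound value alongside.
expr = RawExpr.join(
    RawExpr(sql.Identifier('coins')),
    RawExpr(sql.SQL('+')),
    RawExpr(sql.Placeholder(), (100,)),
)
composed, values = expr.as_tuple()  # values == (100,)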

155
src/data/columns.py Normal file

@@ -0,0 +1,155 @@
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
from psycopg import sql
from datetime import datetime
from .base import RawExpr, Expression
from .conditions import Condition, Joiner
from .table import Table
class ColumnExpr(RawExpr):
__slots__ = ()
def __lt__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column < Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LT, obj_expr)
cond_values = (*values, *obj_values)
else:
# column < Literal
cond_exprs = (expr, Joiner.LT, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __le__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column <= Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LE, obj_expr)
cond_values = (*values, *obj_values)
else:
# column <= Literal
cond_exprs = (expr, Joiner.LE, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __eq__(self, obj) -> Condition: # type: ignore[override]
return Condition._expression_equality(self, obj)
def __ne__(self, obj) -> Condition: # type: ignore[override]
return ~(self.__eq__(obj))
def __gt__(self, obj) -> Condition:
return ~(self.__le__(obj))
def __ge__(self, obj) -> Condition:
return ~(self.__lt__(obj))
def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __sub__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __mul__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def CAST(self, target_type: sql.Composable):
return ColumnExpr(
sql.SQL("({}::{})").format(self.expr, target_type),
self.values
)
T = TypeVar('T')
if TYPE_CHECKING:
from .models import RowModel
class Column(ColumnExpr, Generic[T]):
def __init__(self, name: Optional[str] = None,
primary: bool = False, references: Optional['Column'] = None,
type: Optional[Type[T]] = None):
self.primary = primary
self.references = references
self.name: str = name # type: ignore
self.owner: Optional['RowModel'] = None
self._type = type
self.expr = sql.Identifier(name) if name else sql.SQL('')
self.values = ()
def __set_name__(self, owner, name):
# Only allow setting the owner once
self.name = self.name or name
self.owner = owner
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
@overload
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
...
@overload
def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
...
def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
# Get value from row data or session
if obj is None:
return self
else:
return obj.data[self.name]
class Integer(Column[int]):
pass
class String(Column[str]):
pass
class Bool(Column[bool]):
pass
class Timestamp(Column[datetime]):
pass
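
Columns are descriptors: accessed on the class they are ColumnExprs whose operators build SQL Conditions and expressions, while on an instance they return the underlying row value. A hedged sketch against a hypothetical model (RowModel is defined in src/data/models.py below):

class Member(RowModel):
    _tablename_ = 'members'
    userid = Integer(primary=True)
    coins = Integer()

# Class access builds parameterised SQL fragments rather than touching data:
cond = (Member.coins >= 100) & (Member.userid == 1234)
bonus = Member.coins + 10  # ColumnExpr rendering as "(coins + %s)"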

214
src/data/conditions.py Normal file

@@ -0,0 +1,214 @@
# from meta import sharding
from typing import Any, Union
from enum import Enum
from itertools import chain
from psycopg import sql
from .base import Expression, RawExpr
"""
A Condition is a "logical" database expression, intended for use in Where statements.
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
"""
NULL = None
class Joiner(Enum):
EQUALS = ('=', '!=')
IS = ('IS', 'IS NOT')
LIKE = ('LIKE', 'NOT LIKE')
BETWEEN = ('BETWEEN', 'NOT BETWEEN')
IN = ('IN', 'NOT IN')
LT = ('<', '>=')
LE = ('<=', '>')
NONE = ('', '')
class Condition(Expression):
__slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')
def __init__(self,
expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
values: tuple[Any, ...] = (), negated=False
):
self.expr1 = expr1
self.joiner = joiner
self.negated = negated
self.expr2 = expr2
self.values = values
def as_tuple(self):
expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
if self.negated and self.joiner is Joiner.NONE:
expr = sql.SQL("NOT ({})").format(expr)
return (expr, self.values)
@classmethod
def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
"""
Construct a Condition from a sequence of Conditions,
together with some explicit column conditions.
"""
# TODO: Consider adding a _table identifier here so we can identify implicit columns
# Or just require subquery type conditions to always come from modelled tables.
implicit_conditions = (
cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
)
return cls._and(*conditions, *implicit_conditions)
@classmethod
def _and(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _or(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _not(cls, condition: 'Condition'):
condition.negated = not condition.negated
return condition
@classmethod
def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
# TODO: Check if this supports subqueries
col_expr, col_values = column.as_tuple()
# TODO: Also support sql.SQL? For joins?
if isinstance(value, Expression):
# column = Expression
value_expr, value_values = value.as_tuple()
cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
cond_values = (*col_values, *value_values)
elif isinstance(value, (tuple, list)):
# column in (...)
# TODO: Support expressions in value tuple?
if not value:
raise ValueError("Cannot create Condition from empty iterable!")
value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
cond_exprs = (col_expr, Joiner.IN, value_expr)
cond_values = (*col_values, *value)
elif value is None:
# column IS NULL
cond_exprs = (col_expr, Joiner.IS, sql.NULL)
cond_values = col_values
else:
# column = Literal
cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
cond_values = (*col_values, value)
return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __invert__(self) -> 'Condition':
self.negated = not self.negated
return self
def __and__(self, condition: 'Condition') -> 'Condition':
return self._and(self, condition)
def __or__(self, condition: 'Condition') -> 'Condition':
return self._or(self, condition)
# Helper method to simplify condition construction
def condition(*args, **kwargs) -> Condition:
return Condition.construct(*args, **kwargs)
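# Condition.construct (and this condition() helper) turns keyword pairs into
# column equalities and ANDs them together with any explicit Conditions; list
# values become IN clauses and None becomes IS NULL. A sketch with hypothetical
# column names:
#
#     cond = condition(userid=1234, guildid=[1, 2, 3], deleted_at=NULL)
#     expr, values = cond.as_tuple()
#     # expr   ~ ("userid" = %s) AND ("guildid" IN (%s,%s,%s)) AND ("deleted_at" IS NULL)
#     # values == (1234, 1, 2, 3)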
# class NOT(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# if item:
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
# values.extend(item)
# else:
# raise ValueError("Cannot check an empty iterable!")
# else:
# conditions.append("{}!={}".format(key, _replace_char))
# values.append(item)
#
#
# class GEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply GEQ condition to a list!")
# else:
# conditions.append("{} >= {}".format(key, _replace_char))
# values.append(item)
#
#
# class LEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply LEQ condition to a list!")
# else:
# conditions.append("{} <= {}".format(key, _replace_char))
# values.append(item)
#
#
# class Constant(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# conditions.append("{} {}".format(key, self.value))
#
#
# class SHARDID(Condition):
# __slots__ = ('shardid', 'shard_count')
#
# def __init__(self, shardid, shard_count):
# self.shardid = shardid
# self.shard_count = shard_count
#
# def apply(self, key, values, conditions):
# if self.shard_count > 1:
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
# values.append(self.shardid)
#
#
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
#
#
# NULL = Constant('IS NULL')
# NOTNULL = Constant('IS NOT NULL')

135
src/data/connector.py Normal file

@@ -0,0 +1,135 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: ContextVar[Optional[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Convenience property for the current context connection.
"""
return ctx_connection.get()
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection setup is run again.
Better ways should be sought (e.g. a way to re-run the setup hooks on existing connections without rebuilding the pool).
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method which applies a callable to every connection in the pool.
Utilises private attributes of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""
Minimal decorator to register a coroutine to run on connect or reconnect.
Note that these are only run on connect and reconnect.
If a hook is registered after connection, it will not be run.
"""
self.conn_hooks.append(coro)
return coro
@runtime_checkable
class Connectable(Protocol):
def bind(self, connector: Connector):
raise NotImplementedError
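
Putting the Connector together: open() manages the pool lifetime, connection() hands out the contextual connection if one is set and a pooled one otherwise, and connect_hook registers per-connection setup. A hedged usage sketch (the DSN is a placeholder):

connector = Connector("dbname=bot user=bot")  # hypothetical connection string

@connector.connect_hook
async def on_new_connection(conn):
    # Runs for each fresh pooled connection via _setup_connection.
    ...

async def demo():
    async with connector.open():
        async with connector.connection() as conn:
            async with conn.cursor() as curs:
                await curs.execute("SELECT 1")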

42
src/data/cursor.py Normal file

@@ -0,0 +1,42 @@
import logging
from typing import Optional
from psycopg import AsyncCursor, sql
from psycopg.abc import Query, Params
from psycopg._encodings import pgconn_encoding
logger = logging.getLogger(__name__)
class AsyncLoggingCursor(AsyncCursor):
def mogrify_query(self, query: Query):
if isinstance(query, str):
msg = query
elif isinstance(query, (sql.SQL, sql.Composed)):
msg = query.as_string(self)
elif isinstance(query, bytes):
msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
else:
msg = repr(query)
return msg
async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
if logging.DEBUG >= logger.getEffectiveLevel():
msg = self.mogrify_query(query)
logger.debug(
"Executing query (%s) with values %s", msg, params,
extra={'action': "Query Execute"}
)
try:
return await super().execute(query, params=params, **kwargs)
except Exception:
msg = self.mogrify_query(query)
logger.exception(
"Exception during query execution. Query (%s) with parameters %s.",
msg, params,
extra={'action': "Query Execute"},
stack_info=True
)
else:
# TODO: Possibly log execution time
pass

47
src/data/database.py Normal file

@@ -0,0 +1,47 @@
from typing import TypeVar
import logging
from collections import namedtuple
# from .cursor import AsyncLoggingCursor
from .registry import Registry
from .connector import Connector
logger = logging.getLogger(__name__)
Version = namedtuple('Version', ('version', 'time', 'author'))
T = TypeVar('T', bound=Registry)
class Database(Connector):
# cursor_factory = AsyncLoggingCursor
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.registries: dict[str, Registry] = {}
def load_registry(self, registry: T) -> T:
logger.debug(
f"Loading and binding registry '{registry.name}'.",
extra={'action': f"Reg {registry.name}"}
)
registry.bind(self)
self.registries[registry.name] = registry
return registry
async def version(self) -> Version:
"""
Return the current schema version as a Version namedtuple.
"""
async with self.connection() as conn:
async with conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)

323
src/data/models.py Normal file

@@ -0,0 +1,323 @@
from typing import TypeVar, Type, Optional, Generic, Union
# from typing_extensions import Self
from weakref import WeakValueDictionary
from collections.abc import MutableMapping
from psycopg.rows import DictRow
from .table import Table
from .columns import Column
from . import queries as q
from .connector import Connector
from .registry import Registry
RowT = TypeVar('RowT', bound='RowModel')
class MISSING:
__slots__ = ('oid',)
def __init__(self, oid):
self.oid = oid
class RowTable(Table, Generic[RowT]):
__slots__ = (
'model',
)
def __init__(self, name, model: Type[RowT], **kwargs):
super().__init__(name, **kwargs)
self.model = model
@property
def columns(self):
return self.model._columns_
@property
def id_col(self):
return self.model._key_
@property
def row_cache(self):
return self.model._cache_
def _many_query_adapter(self, *data):
self.model._make_rows(*data)
return data
def _single_query_adapter(self, *data):
if data:
self.model._make_rows(*data)
return data[0]
else:
return None
def _delete_query_adapter(self, *data):
self.model._delete_rows(*data)
return data
# New methods to fetch and create rows
async def create_row(self, *args, **kwargs) -> RowT:
data = await super().insert(*args, **kwargs)
return self.model._make_rows(data)[0]
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
# TODO: Handle list of rowids here?
return q.Select(
self.identifier,
row_adapter=self.model._make_rows,
connector=self.connector
).where(*args, **kwargs)
WK = TypeVar('WK')
WV = TypeVar('WV')
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
def __init__(self, ref_cache):
self.ref_cache = ref_cache
self.weak_cache = WeakValueDictionary()
def __getitem__(self, key):
value = self.weak_cache[key]
self.ref_cache[key] = value
return value
def __setitem__(self, key, value):
self.weak_cache[key] = value
self.ref_cache[key] = value
def __delitem__(self, key):
del self.weak_cache[key]
try:
del self.ref_cache[key]
except KeyError:
pass
def __contains__(self, key):
return key in self.weak_cache
def __iter__(self):
return iter(self.weak_cache)
def __len__(self):
return len(self.weak_cache)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def pop(self, key, default=None):
if key in self:
value = self[key]
del self[key]
else:
value = default
return value
# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
__slots__ = ('data',)
_schema_: str = 'public'
_tablename_: Optional[str] = None
_columns_: dict[str, Column] = {}
# Cache to keep track of registered Rows
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
_key_: tuple[str, ...] = ()
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
table: RowTable = None
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
"""
Set table, _columns_, and _key_.
"""
if table is not None:
cls._tablename_ = table
if cls._tablename_ is not None:
columns = {}
for key, value in cls.__dict__.items():
if isinstance(value, Column):
columns[key] = value
cls._columns_ = columns
if not cls._key_:
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
if cls._cache_ is None:
cls._cache_ = WeakValueDictionary()
def __new__(cls, data):
# Registry pattern.
# Ensure each rowid always refers to a single Model instance
if data is not None:
rowid = cls._id_from_data(data)
cache = cls._cache_
if (row := cache.get(rowid, None)) is not None:
obj = row
else:
obj = cache[rowid] = super().__new__(cls)
else:
obj = super().__new__(cls)
return obj
@classmethod
def as_tuple(cls):
return (cls.table.identifier, ())
def __init__(self, data):
self.data = data
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
@classmethod
def bind(cls, connector: Connector):
if cls.table is None:
raise ValueError("Cannot bind abstract RowModel")
cls._connector = connector
cls.table.bind(connector)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@property
def _dict_(self):
return {key: self.data[key] for key in self._key_}
@property
def _rowid_(self):
return tuple(self.data[key] for key in self._key_)
def __repr__(self):
return "{}.{}({})".format(
self.table.schema,
self.table.name,
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
)
@classmethod
def _id_from_data(cls, data):
return tuple(data[key] for key in cls._key_)
@classmethod
def _dict_from_id(cls, rowid):
return dict(zip(cls._key_, rowid))
@classmethod
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
"""
Create or retrieve Row objects for each provided data row.
If the rows already exist in cache, updates the cached row.
"""
# TODO: Handle partial row data here somehow?
rows = [cls(data_row) for data_row in data_rows]
return rows
@classmethod
def _delete_rows(cls, *data_rows):
"""
Remove the given rows from cache, if they exist.
May be extended to handle object deletion.
"""
cache = cls._cache_
for data_row in data_rows:
rowid = cls._id_from_data(data_row)
cache.pop(rowid, None)
@classmethod
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
return await cls.table.create_row(*args, **kwargs)
@classmethod
def fetch_where(cls: Type[RowT], *args, **kwargs):
return cls.table.fetch_rows_where(*args, **kwargs)
@classmethod
async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
"""
Fetch the row with the given id, retrieving from cache where possible.
"""
row = cls._cache_.get(rowid, None) if cached else None
if row is None:
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
row = rows[0] if rows else None
if row is None:
cls._cache_[rowid] = cls(None)
elif row.data is None:
row = None
return row
@classmethod
async def fetch_or_create(cls, *rowid, **kwargs):
"""
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
"""
if rowid:
row = await cls.fetch(*rowid)
else:
rows = await cls.fetch_where(**kwargs).limit(1)
row = rows[0] if rows else None
if row is None:
creation_kwargs = kwargs
if rowid:
creation_kwargs.update(cls._dict_from_id(rowid))
row = await cls.create(**creation_kwargs)
return row
async def refresh(self: RowT) -> Optional[RowT]:
"""
Refresh this Row from data.
The return value may be `None` if the row was deleted.
"""
rows = await self.table.select_where(**self._dict_)
if not rows:
return None
else:
self.data = rows[0]
return self
async def update(self: RowT, **values) -> Optional[RowT]:
"""
Update this Row with the given values.
Internally passes the provided `values` to the `update` Query.
The return value may be `None` if the row was deleted.
"""
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
if not data:
return None
else:
return data[0]
async def delete(self: RowT) -> Optional[RowT]:
"""
Delete this Row.
"""
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
return data[0] if data is not None else None
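
End to end, a RowModel row behaves like a cached live record: fetch_or_create resolves by primary key or keyword filter, column descriptors read the cached data, and update() writes back and returns the refreshed row. A hedged sketch reusing the hypothetical Member model from the columns example above:

async def give_coins(userid: int, amount: int):
    member = await Member.fetch_or_create(userid)
    return await member.update(coins=member.coins + amount)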

644
src/data/queries.py Normal file

@@ -0,0 +1,644 @@
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
from enum import Enum
from itertools import chain
from psycopg import AsyncConnection, AsyncCursor
from psycopg import sql
from psycopg.rows import DictRow
import logging
from .conditions import Condition
from .base import Expression, RawExpr
from .connector import Connector
logger = logging.getLogger(__name__)
TQueryT = TypeVar('TQueryT', bound='TableQuery')
SQueryT = TypeVar('SQueryT', bound='Select')
QueryResult = TypeVar('QueryResult')
class Query(Generic[QueryResult]):
"""
ABC for an executable query statement.
"""
__slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
_adapter: Callable[..., QueryResult]
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
self.connector: Optional[Connector] = connector
self.conn: Optional[AsyncConnection] = conn
self.cursor: Optional[AsyncCursor] = cursor
if row_adapter is not None:
self._adapter = row_adapter
else:
self._adapter = self._no_adapter
self.result: Optional[QueryResult] = None
def bind(self, connector: Connector):
self.connector = connector
return self
def with_cursor(self, cursor: AsyncCursor):
self.cursor = cursor
return self
def with_connection(self, conn: AsyncConnection):
self.conn = conn
return self
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def with_adapter(self, callable: Callable[..., QueryResult]):
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
# For this to work cleanly, callable should have arg type of QR1, not any
self._adapter = callable
return self
def with_no_adapter(self):
"""
Sets the adapter to the identity.
"""
self._adapter = self._no_adapter
return self
def one(self):
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
return self
def build(self) -> Expression:
raise NotImplementedError
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
query, values = self.build().as_tuple()
# TODO: Move logging out to a custom cursor
# logger.debug(
# f"Executing query ({query.as_string(cursor)}) with values {values}",
# extra={'action': "Query"}
# )
await cursor.execute(sql.Composed((query,)), values)
data = await cursor.fetchall()
self.result = self._adapter(*data)
return self.result
async def execute(self, cursor=None) -> QueryResult:
"""
Execute the query, optionally with the provided cursor, and return the result rows.
If no cursor is provided, and no cursor has been set with `with_cursor`,
the execution will create a new cursor from the connection and close it automatically.
"""
# Create a cursor if possible
cursor = cursor if cursor is not None else self.cursor
if cursor is None:
if self.conn is None:
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
async with self.connector.connection() as conn:
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data
def __await__(self):
return self.execute().__await__()
class TableQuery(Query[QueryResult]):
"""
ABC for an executable query statement expected to be run on a single table.
"""
__slots__ = (
'tableid',
'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group'
)
def __init__(self, tableid, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tableid: sql.Identifier = tableid
def options(self, **kwargs):
"""
Set some query options.
Default implementation does nothing.
Should be overridden to provide specific options.
"""
return self
class WhereMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.condition: Optional[Condition] = None
def where(self, *args: Condition, **kwargs):
"""
Add a Condition to the query.
Positional arguments should be Conditions,
and keyword arguments should be of the form `column=Value`,
where Value may be an Expression or a literal value.
All provided Conditions will be and-ed together to create a new Condition.
TODO: Maybe just pass this verbatim to a condition.
"""
if args or kwargs:
condition = Condition.construct(*args, **kwargs)
if self.condition is not None:
condition = self.condition & condition
self.condition = condition
return self
@property
def _where_section(self) -> Optional[Expression]:
if self.condition is not None:
return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
else:
return None
class JOINTYPE(Enum):
LEFT = sql.SQL('LEFT JOIN')
RIGHT = sql.SQL('RIGHT JOIN')
INNER = sql.SQL('INNER JOIN')
OUTER = sql.SQL('OUTER JOIN')
FULLOUTER = sql.SQL('FULL OUTER JOIN')
class JoinMixin(TableQuery[QueryResult]):
__slots__ = ()
# TODO: Remember to add join slots to TableQuery
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._joins: list[Expression] = []
def join(self,
target: Union[str, Expression],
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
join_type: JOINTYPE = JOINTYPE.INNER,
natural=False):
available = (on is not None) + (using is not None) + natural
if available == 0:
raise ValueError("No conditions given for Query Join")
if available > 1:
raise ValueError("Exactly one join format must be given for Query Join")
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
if isinstance(target, str):
sections.append((sql.Identifier(target), ()))
else:
sections.append(target.as_tuple())
if on is not None:
sections.append((sql.SQL('ON'), ()))
sections.append(on.as_tuple())
elif using is not None:
sections.append((sql.SQL('USING'), ()))
if isinstance(using, Expression):
sections.append(using.as_tuple())
elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
sections.append((cols, ()))
else:
raise ValueError("Unrecognised 'using' type.")
elif natural:
sections.insert(0, (sql.SQL('NATURAL'), ()))
expr = RawExpr.join_tuples(*sections)
self._joins.append(expr)
return self
def leftjoin(self, *args, **kwargs):
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
@property
def _join_section(self) -> Optional[Expression]:
if self._joins:
return RawExpr.join(*self._joins)
else:
return None
class ExtraMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._extra: Optional[Expression] = None
def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
"""
Add an extra string, and optionally values, to this query.
The extra string is inserted after any condition, and before the limit.
"""
extra_expr = RawExpr(extra, values)
if self._extra is not None:
extra_expr = RawExpr.join(self._extra, extra_expr)
self._extra = extra_expr
return self
@property
def _extra_section(self) -> Optional[Expression]:
if self._extra is None:
return None
else:
return self._extra
class LimitMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._limit: Optional[int] = None
def limit(self, limit: int):
"""
Add a limit to this query.
"""
self._limit = limit
return self
@property
def _limit_section(self) -> Optional[Expression]:
if self._limit is not None:
return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
else:
return None
class FromMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._from: Optional[Expression] = None
def from_expr(self, _from: Expression):
self._from = _from
return self
@property
def _from_section(self) -> Optional[Expression]:
if self._from is not None:
expr, values = self._from.as_tuple()
return RawExpr(sql.SQL("FROM {}").format(expr), values)
else:
return None
class ORDER(Enum):
ASC = sql.SQL('ASC')
DESC = sql.SQL('DESC')
class NULLS(Enum):
FIRST = sql.SQL('NULLS FIRST')
LAST = sql.SQL('NULLS LAST')
class OrderMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._order: list[Expression] = []
def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
"""
Add a single sort expression to the query.
This method stacks.
"""
if isinstance(expr, Expression):
string, values = expr.as_tuple()
else:
string = sql.Identifier(expr)
values = ()
parts = [string]
if direction is not None:
parts.append(direction.value)
if nulls is not None:
parts.append(nulls.value)
order_string = sql.SQL(' ').join(parts)
self._order.append(RawExpr(order_string, values))
return self
@property
def _order_section(self) -> Optional[Expression]:
if self._order:
expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
return expr
else:
return None
class GroupMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._group: list[Expression] = []
def group_by(self, *exprs: Union[Expression, str]):
"""
Add one or more group expressions to the query.
This method stacks.
"""
for expr in exprs:
if isinstance(expr, Expression):
self._group.append(expr)
else:
self._group.append(RawExpr(sql.Identifier(expr)))
return self
@property
def _group_section(self) -> Optional[Expression]:
if self._group:
expr = RawExpr.join(*self._group, joiner=sql.SQL(', '))
expr.expr = sql.SQL("GROUP BY {}").format(expr.expr)
return expr
else:
return None
class Insert(ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table insert query.
"""
# TODO: Support ON CONFLICT for upserts
__slots__ = ('_columns', '_values', '_conflict')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[str, ...] = ()
self._values: tuple[tuple[Any, ...], ...] = ()
self._conflict: Optional[Expression] = None
def insert(self, columns, *values):
"""
Insert the given data.
Parameters
----------
columns: tuple[str]
Tuple of column names to insert.
values: tuple[tuple[Any, ...], ...]
Tuple of values to insert, corresponding to the columns.
"""
if not values:
raise ValueError("Cannot insert zero rows.")
if len(values[0]) != len(columns):
raise ValueError("Number of columns does not match length of values.")
self._columns = columns
self._values = values
return self
def on_conflict(self, ignore=False):
# TODO lots more we can do here
# Maybe return a Conflict object that can chain itself (not the query)
if ignore:
self._conflict = RawExpr(sql.SQL('DO NOTHING'))
return self
@property
def _conflict_section(self) -> Optional[Expression]:
if self._conflict is not None:
e, v = self._conflict.as_tuple()
expr = RawExpr(
sql.SQL("ON CONFLICT {}").format(
e
),
v
)
return expr
return None
def build(self):
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
single_value_str = sql.SQL('({})').format(
sql.SQL(',').join(sql.Placeholder() * len(self._columns))
)
values_str = sql.SQL(',').join(single_value_str * len(self._values))
# TODO: Check efficiency of inserting multiple values like this
# Also implement a Copy query
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
table=self.tableid,
columns=columns,
values_str=values_str
)
sections = [
RawExpr(base, tuple(chain(*self._values))),
self._conflict_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]):
"""
Select rows from a table matching provided conditions.
"""
__slots__ = ('_columns',)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[Expression, ...] = ()
def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
"""
Set the columns and expressions to select.
If none are given, selects all columns.
"""
cols: List[Expression] = []
if columns:
cols.extend(map(RawExpr, map(sql.Identifier, columns)))
if exprs:
for name, expr in exprs.items():
if isinstance(expr, str):
cols.append(
RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, sql.Composable):
cols.append(
RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, Expression):
value_expr, value_values = expr.as_tuple()
cols.append(RawExpr(
value_expr + sql.SQL(' AS ') + sql.Identifier(name),
value_values
))
if cols:
self._columns = (*self._columns, *cols)
return self
def build(self):
if not self._columns:
columns, columns_values = sql.SQL('*'), ()
else:
columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()
base = sql.SQL("SELECT {columns} FROM {table}").format(
columns=columns,
table=self.tableid
)
sections = [
RawExpr(base, columns_values),
self._join_section,
self._where_section,
self._group_section,
self._extra_section,
self._order_section,
self._limit_section,
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
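# Since builder methods chain and every Query is awaitable, a full statement
# reads as one pipeline. A hedged sketch of a leaderboard-style query over a
# hypothetical 'members' Table:
#
#     rows = await (
#         members.select_where(guildid=1234)
#         .select('userid', total=sql.SQL("SUM(coins)"))
#         .group_by('userid')
#         .order_by('total', direction=ORDER.DESC)
#         .limit(10)
#     )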
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table delete query.
"""
# TODO: Cascade option for delete, maybe other options
# TODO: Require a where unless specifically disabled, for safety
def build(self):
base = sql.SQL("DELETE FROM {table}").format(
table=self.tableid,
)
sections = [
RawExpr(base),
self._where_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]):
__slots__ = (
'_set',
)
# TODO: Again, require a where unless specifically disabled
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set: List[Expression] = []
def set(self, **column_values: Union[Any, Expression]):
exprs: List[Expression] = []
for name, value in column_values.items():
if isinstance(value, Expression):
value_tup = value.as_tuple()
else:
value_tup = (sql.Placeholder(), (value,))
exprs.append(
RawExpr.join_tuples(
(sql.Identifier(name), ()),
value_tup,
joiner=sql.SQL(' = ')
)
)
self._set.extend(exprs)
return self
def build(self):
if not self._set:
raise ValueError("No columns provided to update.")
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
base = sql.SQL("UPDATE {table} SET {set}").format(
table=self.tableid,
set=set_expr
)
sections = [
RawExpr(base, set_values),
self._from_section,
self._where_section,
self._extra_section,
self._limit_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
# async def upsert(cursor, table, constraint, **values):
# """
# Insert or on conflict update.
# """
# valuedict = values
# keys, values = zip(*values.items())
#
# key_str = _format_insertkeys(keys)
# value_str, values = _format_insertvalues(values)
# update_key_str, update_key_values = _format_updatestr(valuedict)
#
# if not isinstance(constraint, str):
# constraint = ", ".join(constraint)
#
# await cursor.execute(
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
# table, key_str, value_str, constraint, update_key_str
# ),
# tuple((*values, *update_key_values))
# )
# return await cursor.fetchone()
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
# cursor = cursor or conn.cursor()
#
# # TODO: executemany or copy syntax now
# return execute_values(
# cursor,
# """
# UPDATE {table}
# SET {set_clause}
# FROM (VALUES {cast_row}%s)
# AS {temp_table}
# WHERE {where_clause}
# RETURNING *
# """.format(
# table=table,
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
# cast_row=cast_row + ',' if cast_row else '',
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
# ),
# values,
# fetch=True
# )

102
src/data/registry.py Normal file

@@ -0,0 +1,102 @@
from typing import Protocol, runtime_checkable, Optional
from psycopg import AsyncConnection
from .connector import Connector, Connectable
@runtime_checkable
class _Attachable(Connectable, Protocol):
def attach_to(self, registry: 'Registry'):
raise NotImplementedError
class Registry:
_attached: list[_Attachable] = []
_name: Optional[str] = None
def __init_subclass__(cls, name=None):
attached = []
for _, member in cls.__dict__.items():
if isinstance(member, _Attachable):
attached.append(member)
cls._attached = attached
cls._name = name or cls.__name__
def __init__(self, name=None):
self._conn: Optional[Connector] = None
self.name: str = name if name is not None else self._name
if self.name is None:
raise ValueError("A Registry must have a name!")
self.init_tasks = []
for member in self._attached:
member.attach_to(self)
def bind(self, connector: Connector):
self._conn = connector
for child in self._attached:
child.bind(connector)
def attach(self, attachable):
self._attached.append(attachable)
if self._conn is not None:
attachable.bind(self._conn)
return attachable
def init_task(self, coro):
"""
Initialisation tasks are run to set up the registry state.
These tasks will be run in the event loop, after connection to the database.
These tasks should be idempotent, as they may be run on reload and reconnect.
"""
self.init_tasks.append(coro)
return coro
async def init(self):
for task in self.init_tasks:
await task(self)
return self
class AttachableClass:
"""ABC for a default implementation of an Attachable class."""
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
@classmethod
def bind(cls, connector: Connector):
cls._connector = connector
connector.connect_hook(cls.on_connect)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@classmethod
async def on_connect(cls, connection: AsyncConnection):
pass
class Attachable:
"""ABC for a default implementation of an Attachable object."""
def __init__(self, *args, **kwargs):
self._connector: Optional[Connector] = None
self._registry: Optional[Registry] = None
def bind(self, connector: Connector):
self._connector = connector
connector.connect_hook(self.on_connect)
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
async def on_connect(self, connection: AsyncConnection):
pass

95
src/data/table.py Normal file

@@ -0,0 +1,95 @@
from typing import Optional
from psycopg.rows import DictRow
from psycopg import sql
from . import queries as q
from .connector import Connector
from .registry import Registry
class Table:
"""
Transparent interface to a single table structure in the database.
Contains standard methods to access the table.
"""
def __init__(self, name, *args, schema='public', **kwargs):
self.name: str = name
self.schema: str = schema
self.connector: Optional[Connector] = None
@property
def identifier(self):
if self.schema == 'public':
return sql.Identifier(self.name)
else:
return sql.Identifier(self.schema, self.name)
def bind(self, connector: Connector):
self.connector = connector
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
if data:
return data[0]
else:
return None
def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
return q.Select(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
return q.Select(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
return q.Update(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
return q.Delete(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def insert(self, **column_values) -> q.Insert[DictRow]:
return q.Insert(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).insert(column_values.keys(), column_values.values())
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
return q.Insert(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).insert(*args, **kwargs)
# def update_many(self, *args, **kwargs):
# with self.conn:
# return update_many(self.identifier, *args, **kwargs)
# def upsert(self, *args, **kwargs):
# return upsert(self.identifier, *args, **kwargs)
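
A hedged sketch of how these builders might be used; it assumes the `Query` objects from `data.queries` (not shown in this hunk) are awaitable once bound to a connector:

```python
bot_config = Table('bot_config')
# bot_config.bind(connector)  # bind before executing queries

async def read_and_write():
    # select_one_where(...) resolves to a single DictRow or None,
    # via _single_query_adapter.
    row = await bot_config.select_one_where(appname='mybot')
    # insert(...) resolves to the inserted row.
    created = await bot_config.insert(appname='otherbot', default_skin='light')
    # delete_where(...) resolves to a tuple of the deleted rows.
    deleted = await bot_config.delete_where(appname='oldbot')
```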

344
src/meta/LionBot.py Normal file

@@ -0,0 +1,344 @@
from typing import List, Literal, LiteralString, Optional, TYPE_CHECKING, overload
import logging
import asyncio
from weakref import WeakValueDictionary
import discord
from discord.utils import MISSING
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
from aiohttp import ClientSession
from data import Database
from utils.lib import tabulate
from .config import Conf
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
from .context import context
from .LionContext import LionContext
from .LionTree import LionTree
from .errors import HandledException, SafeCancellation
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
if TYPE_CHECKING:
from core.cog import CoreCog
logger = logging.getLogger(__name__)
class LionBot(Bot):
def __init__(
self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession,
testing_guilds: List[int] = [], **kwargs
):
kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs)
self.web_client = web_client
self.testing_guilds = testing_guilds
self.initial_extensions = initial_extensions
self.db = db
self.appname = appname
self.shardname = shardname
# self.appdata = appdata
self.config = config
self.system_monitor = SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
self.system_monitor.add_component(self.monitor)
self._locks = WeakValueDictionary()
self._running_events = set()
@property
def core(self):
return self.get_cog('CoreCog')
async def _monitor_status(self):
if self.is_closed():
level = StatusLevel.ERRORED
info = "(ERROR) Websocket is closed"
data = {}
elif self.is_ws_ratelimited():
level = StatusLevel.WAITING
info = "(WAITING) Websocket is ratelimited"
data = {}
elif not self.is_ready():
level = StatusLevel.STARTING
info = "(STARTING) Not yet ready"
data = {}
else:
level = StatusLevel.OKAY
info = (
"(OK) "
"Logged in with {guild_count} guilds, "
", websocket latency {latency}, and {events} running events."
)
data = {
'guild_count': len(self.guilds),
'latency': self.latency,
'events': len(self._running_events),
}
return ComponentStatus(level, info, info, data)
async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}")
for extension in self.initial_extensions:
await self.load_extension(extension)
for guildid in self.testing_guilds:
guild = discord.Object(guildid)
if not self.shard_count or (self.shard_id == ((guildid >> 22) % self.shard_count)):
self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild)
# To make the type checker happy about fetching cogs by name
# TODO: Move this to stubs at some point
@overload
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
...
@overload
def get_cog(self, name: str) -> Optional[Cog]:
...
def get_cog(self, name: str) -> Optional[Cog]:
return super().get_cog(name)
async def add_cog(self, cog: Cog, **kwargs):
sup = super()
@log_wrap(action=f"Attach {cog.__cog_name__}")
async def wrapper():
logger.info(f"Attaching Cog {cog.__cog_name__}")
await sup.add_cog(cog, **kwargs)
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
await wrapper()
async def load_extension(self, name, *, package=None, **kwargs):
sup = super()
@log_wrap(action=f"Load {name.strip('.')}")
async def wrapper():
logger.info(f"Loading extension {name} in package {package}.")
await sup.load_extension(name, package=package, **kwargs)
logger.debug(f"Loaded extension {name} in package {package}.")
await wrapper()
async def start(self, token: str, *, reconnect: bool = True):
with logging_context(action="Login"):
start_task = asyncio.create_task(self.login(token))
await start_task
with logging_context(stack=("Running",)):
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
await run_task
def dispatch(self, event_name: str, *args, **kwargs):
with logging_context(action=f"Dispatch {event_name}"):
super().dispatch(event_name, *args, **kwargs)
def _schedule_event(self, coro, event_name, *args, **kwargs):
"""
Extends client._schedule_event to keep a persistent
background task store.
"""
task = super()._schedule_event(coro, event_name, *args, **kwargs)
self._running_events.add(task)
task.add_done_callback(lambda fut: self._running_events.discard(fut))
def idlock(self, snowflakeid):
lock = self._locks.get(snowflakeid, None)
if lock is None:
lock = self._locks[snowflakeid] = asyncio.Lock()
return lock
async def on_ready(self):
logger.info(
f"Logged in as {self.application.name}\n"
f"Application id {self.application.id}\n"
f"Shard Talk identifier {self.shardname}\n"
"------------------------------\n"
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
"------------------------------\n"
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
"Ready to take commands!\n",
extra={'action': 'Ready'}
)
async def get_context(self, origin, /, *, cls=MISSING):
if cls is MISSING:
cls = LionContext
ctx = await super().get_context(origin, cls=cls)
context.set(ctx)
return ctx
async def on_command(self, ctx: LionContext):
logger.info(
f"Executing command '{ctx.command.qualified_name}' "
f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
f"with interaction: {ctx.interaction.data if ctx.interaction else None}",
extra={'with_ctx': True}
)
async def on_command_error(self, ctx, exception):
# TODO: Some of these could have more user-feedback
logger.debug(f"Handling command error for {ctx}: {exception}")
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict()
else:
cmd_str = str(ctx.command)
try:
raise exception
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
try:
if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
original = exception.original.original
raise original
else:
original = exception.original
raise original
except HandledException:
pass
except TransformerError as e:
msg = str(e)
if msg:
try:
await ctx.error_reply(msg)
except Exception:
pass
logger.debug(
f"Caught a transformer error: {repr(e)}",
extra={'action': 'BotError', 'with_ctx': True}
)
except SafeCancellation:
if original.msg:
try:
await ctx.error_reply(original.msg)
except Exception:
pass
logger.debug(
f"Caught a safe cancellation: {original.details}",
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.Forbidden:
# Unknown uncaught Forbidden
try:
# Attempt a general error reply
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
except Exception:
# We can't send anything at all. Exit quietly, but log.
logger.warning(
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.HTTPException:
logger.error(
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
pass
except Exception as e:
logger.exception(
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
extra={'action': 'BotError', 'with_ctx': True}
)
error_embed = discord.Embed(
title="Something went wrong!",
colour=discord.Colour.dark_red()
)
error_embed.description = (
"An unexpected error occurred while processing your command!\n"
"Our development team has been notified, and the issue will be addressed soon.\n"
"If the error persists, or you have any questions, please contact our [support team]({link}) "
"and give them the extra details below."
).format(link=self.config.bot.support_guild)
details = {}
details['error'] = f"`{repr(e)}`"
if ctx.interaction:
details['interactionid'] = f"`{ctx.interaction.id}`"
if ctx.command:
details['cmd'] = f"`{ctx.command.qualified_name}`"
if ctx.author:
details['author'] = f"`{ctx.author.id}` -- `{ctx.author}`"
if ctx.guild:
details['guild'] = f"`{ctx.guild.id}` -- `{ctx.guild.name}`"
details['my_guild_perms'] = f"`{ctx.guild.me.guild_permissions.value}`"
if ctx.author:
ownerstr = ' (owner)' if ctx.author.id == ctx.guild.owner_id else ''
details['author_guild_perms'] = f"`{ctx.author.guild_permissions.value}{ownerstr}`"
if ctx.channel.type is discord.enums.ChannelType.private:
details['channel'] = "`Direct Message`"
elif ctx.channel:
details['channel'] = f"`{ctx.channel.id}` -- `{ctx.channel.name}`"
details['my_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.guild.me).value}`"
if ctx.author:
details['author_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.author).value}`"
details['shard'] = f"`{self.shardname}`"
details['log_stack'] = f"`{log_action_stack.get()}`"
table = '\n'.join(tabulate(*details.items()))
error_embed.add_field(name='Details', value=table)
try:
await ctx.error_reply(embed=error_embed)
except discord.HTTPException:
pass
finally:
exception.original = HandledException(exception.original)
except CheckFailure as e:
logger.debug(
f"Command failed check: {e}: {e.args}",
extra={'action': 'BotError', 'with_ctx': True}
)
try:
await ctx.error_reply(str(e))
except discord.HTTPException:
pass
except Exception:
# Completely unknown exception outside of command invocation!
# Something is very wrong here, don't attempt user interaction.
logger.exception(
f"Caught an unknown top-level exception while executing: {cmd_str}",
extra={'action': 'BotError', 'with_ctx': True}
)
def add_command(self, command):
if not hasattr(command, '_placeholder_group_'):
super().add_command(command)
def request_chunking_for(self, guild):
if not guild.chunked:
return asyncio.create_task(
self._connection.chunk_guild(guild, wait=False, cache=True),
name=f"Background chunkreq for {guild.id}"
)
async def on_interaction(self, interaction: discord.Interaction):
"""
Adds the interaction author to guild cache if appropriate.
This gets run a little bit late, so it is possible the interaction gets handled
without the author being in the cache.
"""
guild = interaction.guild
user = interaction.user
if guild is not None and user is not None and isinstance(user, discord.Member):
if not guild.get_member(user.id):
guild._add_member(user)
if guild is not None and not guild.chunked:
# Getting an interaction in the guild is a good enough reason to request chunking
logger.info(
f"Unchunked guild <gid: {guild.id}> requesting chunking after interaction."
)
self.request_chunking_for(guild)
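
A standalone sketch of the `idlock` pattern above: a `WeakValueDictionary` hands out one `asyncio.Lock` per snowflake, and the lock is garbage-collected once no task holds a reference to it.

```python
import asyncio
from weakref import WeakValueDictionary

_locks = WeakValueDictionary()

def idlock(snowflakeid: int) -> asyncio.Lock:
    # Concurrent callers get the same lock object for the same id.
    lock = _locks.get(snowflakeid, None)
    if lock is None:
        lock = _locks[snowflakeid] = asyncio.Lock()
    return lock

async def per_guild_update(guildid: int):
    async with idlock(guildid):
        ...  # operations serialised per guild id
```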

58
src/meta/LionCog.py Normal file

@@ -0,0 +1,58 @@
from typing import Any
from discord.ext.commands import Cog
from discord.ext import commands as cmds
class LionCog(Cog):
# A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set()
_placeholder_groups_: set[str]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._placeholder_groups_ = set()
for base in reversed(cls.__mro__):
for elem, value in base.__dict__.items():
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
cls._placeholder_groups_.add(value.name)
def __new__(cls, *args: Any, **kwargs: Any):
# Patch to ensure no placeholder groups are in the command list
self = super().__new__(cls)
self.__cog_commands__ = [
command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_
]
return self
async def _inject(self, bot, *args, **kwargs):
if self.depends_on:
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
return await super()._inject(bot, *args, **kwargs)
@classmethod
def placeholder_group(cls, group: cmds.HybridGroup):
group._placeholder_group_ = True
return group
def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup):
"""
Crossload a placeholder group's commands into the target group
"""
if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup):
raise ValueError("Placeholder and target groups my be HypridGroups.")
if placeholder_group.name not in self._placeholder_groups_:
raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.")
if target_group.app_command is None:
raise ValueError("Target group has no app_command to crossload into.")
for command in placeholder_group.commands:
placeholder_group.remove_command(command.name)
target_group.remove_command(command.name)
acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self)
command.app_command = acmd
target_group.add_command(command)
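
A sketch of the placeholder-group workflow the two methods above enable; `CoreCog` and its `configure_group` are hypothetical stand-ins for whichever cog owns the live group, and the sketch assumes the cog stores `self.bot`:

```python
from discord.ext import commands as cmds

class SkinCog(LionCog):
    @LionCog.placeholder_group
    @cmds.hybrid_group('configure')
    async def configure_group(self, ctx):
        ...

    @configure_group.command('skins')
    async def configure_skins(self, ctx):
        ...

    async def cog_load(self):
        # Move this cog's subcommands into the real group at load time.
        core = self.bot.get_cog('CoreCog')  # hypothetical owner of the live group
        self.crossload_group(self.configure_group, core.configure_group)
```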

195
src/meta/LionContext.py Normal file

@@ -0,0 +1,195 @@
import types
import logging
from collections import namedtuple
from typing import Optional, TYPE_CHECKING
import discord
from discord.enums import ChannelType
from discord.ext.commands import Context
if TYPE_CHECKING:
from .LionBot import LionBot
logger = logging.getLogger(__name__)
"""
Stuff that might be useful to implement (see cmdClient):
sent_messages cache
tasks cache
error reply
usage
interaction cache
View cache?
setting access
"""
FlatContext = namedtuple(
'FlatContext',
('message',
'interaction',
'guild',
'author',
'channel',
'alias',
'prefix',
'failed')
)
class LionContext(Context['LionBot']):
"""
Represents the context a command is invoked under.
Extends Context to add Lion-specific methods and attributes.
Also adds several contextual wrapped utilities for simpler use during command invocation.
"""
def __repr__(self):
parts = {}
if self.interaction is not None:
parts['iid'] = self.interaction.id
parts['itype'] = f"\"{self.interaction.type.name}\""
if self.message is not None:
parts['mid'] = self.message.id
if self.author is not None:
parts['uid'] = self.author.id
parts['uname'] = f"\"{self.author.name}\""
if self.channel is not None:
parts['cid'] = self.channel.id
if self.channel.type is ChannelType.private:
parts['cname'] = f"\"{self.channel.recipient}\""
else:
parts['cname'] = f"\"{self.channel.name}\""
if self.guild is not None:
parts['gid'] = self.guild.id
parts['gname'] = f"\"{self.guild.name}\""
if self.command is not None:
parts['cmd'] = f"\"{self.command.qualified_name}\""
if self.invoked_with is not None:
parts['alias'] = f"\"{self.invoked_with}\""
if self.command_failed:
parts['failed'] = self.command_failed
return "<LionContext: {}>".format(
' '.join(f"{name}={value}" for name, value in parts.items())
)
def flatten(self):
"""Flat pure-data context information, for caching and logging."""
return FlatContext(
self.message.id,
self.interaction.id if self.interaction is not None else None,
self.guild.id if self.guild is not None else None,
self.author.id if self.author is not None else None,
self.channel.id if self.channel is not None else None,
self.invoked_with,
self.prefix,
self.command_failed
)
@classmethod
def util(cls, util_func):
"""
Decorator to make a utility function available as a Context instance method.
"""
setattr(cls, util_func.__name__, util_func)
logger.debug(f"Attached context utility function: {util_func.__name__}")
return util_func
@classmethod
def wrappable_util(cls, util_func):
"""
Decorator to add a Wrappable utility function as a Context instance method.
"""
wrapped = Wrappable(util_func)
setattr(cls, util_func.__name__, wrapped)
logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
return wrapped
async def error_reply(self, content: Optional[str] = None, **kwargs):
if content and 'embed' not in kwargs:
embed = discord.Embed(
colour=discord.Colour.red(),
description=content
)
kwargs['embed'] = embed
content = None
# Expect this may be run in highly unusual circumstances.
# This should never error, or at least handle all errors.
if self.interaction:
kwargs.setdefault('ephemeral', True)
try:
await self.reply(content=content, **kwargs)
except discord.HTTPException:
pass
except Exception:
logger.exception(
"Unknown exception in 'error_reply'.",
extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
)
class Wrappable:
__slots__ = ('_func', 'wrappers')
def __init__(self, func):
self._func = func
self.wrappers = None
@property
def __name__(self):
return self._func.__name__
def add_wrapper(self, func, name=None):
self.wrappers = self.wrappers or {}
name = name or func.__name__
self.wrappers[name] = func
logger.debug(
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
extra={'action': "Wrap Util"}
)
def remove_wrapper(self, name):
if not self.wrappers or name not in self.wrappers:
raise ValueError(
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
)
self.wrappers.pop(name)
logger.debug(
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
extra={'action': "Unwrap Util"}
)
def __call__(self, *args, **kwargs):
if self.wrappers:
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
else:
return self._func(*args, **kwargs)
def _wrapped(self, iter_wraps):
next_wrap = next(iter_wraps, None)
if next_wrap:
def _func(*args, **kwargs):
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
else:
_func = self._func
return _func
def __get__(self, instance, cls=None):
if instance is None:
return self
else:
return types.MethodType(self, instance)
LionContext.reply = Wrappable(LionContext.reply)
# @LionContext.reply.add_wrapper
# async def think(func, ctx, *args, **kwargs):
# await ctx.channel.send("thinking")
# await func(ctx, *args, **kwargs)
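
For comparison, a hedged sketch of the decorators above: `util` attaches a plain coroutine as an instance method, and `add_wrapper` hooks around every `ctx.reply(...)` call, just like the commented example.

```python
@LionContext.util
async def confirm(ctx, prompt: str):
    # Hypothetical helper, available afterwards as ctx.confirm(...)
    await ctx.reply(f"{prompt} (yes/no)")

@LionContext.reply.add_wrapper
async def log_replies(reply, ctx, *args, **kwargs):
    # Runs around every ctx.reply(...) call.
    logger.debug(f"Replying in {ctx!r}")
    return await reply(ctx, *args, **kwargs)
```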

150
src/meta/LionTree.py Normal file

@@ -0,0 +1,150 @@
import logging
import discord
from discord import Interaction
from discord.app_commands import CommandTree
from discord.app_commands.errors import AppCommandError, CommandInvokeError
from discord.enums import InteractionType
from discord.app_commands.namespace import Namespace
from utils.lib import tabulate
from .logger import logging_context, set_logging_context, log_wrap, log_action_stack
from .errors import SafeCancellation
from .config import conf
logger = logging.getLogger(__name__)
class LionTree(CommandTree):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._call_tasks = set()
async def on_error(self, interaction: discord.Interaction, error) -> None:
try:
if isinstance(error, CommandInvokeError):
raise error.original
else:
raise error
except SafeCancellation:
# Assume this has already been handled
pass
except Exception:
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
if interaction.type is not InteractionType.autocomplete:
embed = self.bugsplat(interaction, error)
await self.error_reply(interaction, embed)
async def error_reply(self, interaction, embed):
if not interaction.is_expired():
try:
if interaction.response.is_done():
await interaction.followup.send(embed=embed, ephemeral=True)
else:
await interaction.response.send_message(embed=embed, ephemeral=True)
except discord.HTTPException:
pass
def bugsplat(self, interaction, e):
error_embed = discord.Embed(title="Something went wrong!", colour=discord.Colour.red())
error_embed.description = (
"An unexpected error occurred during this interaction!\n"
"Our development team has been notified, and the issue will be addressed soon.\n"
"If the error persists, or you have any questions, please contact our [support team]({link}) "
"and give them the extra details below."
).format(link=interaction.client.config.bot.support_guild)
details = {}
details['error'] = f"`{repr(e)}`"
details['interactionid'] = f"`{interaction.id}`"
details['interactiontype'] = f"`{interaction.type}`"
if interaction.command:
details['cmd'] = f"`{interaction.command.qualified_name}`"
if interaction.user:
details['user'] = f"`{interaction.user.id}` -- `{interaction.user}`"
if interaction.guild:
details['guild'] = f"`{interaction.guild.id}` -- `{interaction.guild.name}`"
details['my_guild_perms'] = f"`{interaction.guild.me.guild_permissions.value}`"
if interaction.user:
ownerstr = ' (owner)' if interaction.user.id == interaction.guild.owner_id else ''
details['user_guild_perms'] = f"`{interaction.user.guild_permissions.value}{ownerstr}`"
if interaction.channel.type is discord.enums.ChannelType.private:
details['channel'] = "`Direct Message`"
elif interaction.channel:
details['channel'] = f"`{interaction.channel.id}` -- `{interaction.channel.name}`"
details['my_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.guild.me).value}`"
if interaction.user:
details['user_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.user).value}`"
details['shard'] = f"`{interaction.client.shardname}`"
details['log_stack'] = f"`{log_action_stack.get()}`"
table = '\n'.join(tabulate(*details.items()))
error_embed.add_field(name='Details', value=table)
return error_embed
def _from_interaction(self, interaction: Interaction) -> None:
@log_wrap(context=f"iid: {interaction.id}", isolate=False)
async def wrapper():
try:
await self._call(interaction)
except AppCommandError as e:
await self._dispatch_error(interaction, e)
task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
self._call_tasks.add(task)
task.add_done_callback(lambda fut: self._call_tasks.discard(fut))
async def _call(self, interaction):
if not await self.interaction_check(interaction):
interaction.command_failed = True
return
data = interaction.data # type: ignore
type = data.get('type', 1)
if type != 1:
# Context menu command...
await self._call_context_menu(interaction, data, type)
return
command, options = self._get_app_command_options(data)
# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command
# At this point options refers to the arguments of the command
# and command refers to the class type we care about
namespace = Namespace(interaction, data.get('resolved', {}), options)
# Same pre-fill as above
interaction._cs_namespace = namespace
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
if interaction.type is InteractionType.autocomplete:
set_logging_context(action=f"Acmp {command.qualified_name}")
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
if focused is None:
raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.'
)
try:
await command._invoke_autocomplete(interaction, focused, namespace)
except Exception as e:
await self.on_error(interaction, e)
return
set_logging_context(action=f"Run {command.qualified_name}")
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:
interaction.command_failed = True
await command._invoke_error_handlers(interaction, e)
await self.on_error(interaction, e)
else:
if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command)
finally:
if interaction.command_failed:
logger.debug("Command completed with errors.")
else:
logger.debug("Command completed without errors.")

15
src/meta/__init__.py Normal file

@@ -0,0 +1,15 @@
from .LionBot import LionBot
from .LionCog import LionCog
from .LionContext import LionContext
from .LionTree import LionTree
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
from .config import conf, configEmoji
from .args import args
from .app import appname, appname_from_shard, shard_from_appname
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
from .context import context, ctx_bot
from . import sharding
from . import logger
from . import app

32
src/meta/app.py Normal file

@@ -0,0 +1,32 @@
"""
appname: str
The base identifier for this application.
This identifies which services the app offers.
shardname: str
The specific name of the running application.
Only one process should be connected with a given appname.
For the bot apps, usually specifies the shard id and shard number.
"""
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
from . import sharding, conf
from .logger import log_app
from .args import args
appname = conf.data['appid']
appid = appname # backwards compatibility
def appname_from_shard(shardid):
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
return appname
def shard_from_appname(appname: str):
return int(appname.rsplit('_', maxsplit=1)[-1])
shardname = appname_from_shard(sharding.shard_number)
log_app.set(shardname)
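
A worked example of the naming scheme, assuming an `appid` of `lionbot` and a `shard_count` of 5 (both hypothetical config values):

```python
# appname_from_shard(3) builds:
name = f"lionbot_{5:02}_{3:02}"
assert name == "lionbot_05_03"
# and shard_from_appname inverts it:
assert int(name.rsplit('_', maxsplit=1)[-1]) == 3
```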

35
src/meta/args.py Normal file

@@ -0,0 +1,35 @@
import argparse
from constants import CONFIG_FILE
# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
'--conf',
dest='config',
default=CONFIG_FILE,
help="Path to configuration file."
)
parser.add_argument(
'--shard',
dest='shard',
default=None,
type=int,
help="Shard number to run, if applicable."
)
parser.add_argument(
'--host',
dest='host',
default='127.0.0.1',
help="IP address to run the app listener on."
)
parser.add_argument(
'--port',
dest='port',
default='5001',
help="Port to run the app listener on."
)
args = parser.parse_args()

146
src/meta/config.py Normal file

@@ -0,0 +1,146 @@
from discord import PartialEmoji
import configparser as cfgp
from .args import args
shard_number = args.shard
class configEmoji(PartialEmoji):
__slots__ = ('fallback',)
def __init__(self, *args, fallback=None, **kwargs):
super().__init__(*args, **kwargs)
self.fallback = fallback
@classmethod
def from_str(cls, emojistr: str):
"""
Parses emoji strings of one of the following forms
`<a:name:id> or fallback`
`<:name:id> or fallback`
`<a:name:id>`
`<:name:id>`
"""
splits = emojistr.rsplit(' or ', maxsplit=1)
fallback = splits[1] if len(splits) > 1 else None
emojistr = splits[0].strip('<> ')
animated, name, id = emojistr.split(':')
return cls(
name=name,
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
animated=bool(animated),
id=int(id) if id else None
)
class MapDotProxy:
"""
Allows dot access to an underlying Mappable object.
"""
__slots__ = ("_map", "_converter")
def __init__(self, mappable, converter=None):
self._map = mappable
self._converter = converter
def __getattribute__(self, key):
_map = object.__getattribute__(self, '_map')
if key == '_map':
return _map
if key in _map:
_converter = object.__getattribute__(self, '_converter')
if _converter:
return _converter(_map[key])
else:
return _map[key]
else:
return object.__getattribute__(_map, key)
def __getitem__(self, key):
return self._map.__getitem__(key)
class ConfigParser(cfgp.ConfigParser):
"""
Extension of base ConfigParser allowing optional
section option retrieval without defaults.
"""
def options(self, section, no_defaults=False, **kwargs):
if no_defaults:
try:
return list(self._sections[section].keys())
except KeyError:
raise cfgp.NoSectionError(section)
else:
return super().options(section, **kwargs)
class Conf:
def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile
self.config = ConfigParser(
converters={
"intlist": self._getintlist,
"list": self._getlist,
"emoji": configEmoji.from_str,
}
)
with open(configfile) as conff:
# Opening with read_file mainly to ensure the file exists
self.config.read_file(conff)
self.section_name = section_name if section_name in self.config else 'DEFAULT'
self.default = self.config["DEFAULT"]
self.section = MapDotProxy(self.config[self.section_name])
self.bot = self.section
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
more_to_read = self.section.getlist("ALSO_READ", [])
read = set()
while more_to_read:
to_read = more_to_read.pop(0)
read.add(to_read)
self.config.read(to_read)
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
if path not in read and path not in more_to_read]
more_to_read.extend(new_paths)
self.emojis = MapDotProxy(
self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section,
converter=configEmoji.from_str
)
global conf
conf = self
def __getitem__(self, key):
return self.section[key].strip()
def __getattr__(self, section):
name = section.upper()
shard_name = f"{name}-{shard_number}"
if shard_name in self.config:
return self.config[shard_name]
else:
return self.config[name]
def get(self, name, fallback=None):
result = self.section.get(name, fallback)
return result.strip() if result else result
def _getintlist(self, value):
return [int(item.strip()) for item in value.split(',')]
def _getlist(self, value):
return [item.strip() for item in value.split(',')]
def write(self):
with open(self.configfile, 'w') as conffile:
self.config.write(conffile)
conf = Conf(args.config, 'BOT')
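
A small sketch of the pieces above working together, with a hypothetical config file:

```python
# bot.conf (hypothetical):
#   [BOT]
#   shard_count = 1
#   ALSO_READ = secrets.conf
#
#   [EMOJIS]
#   party = <a:party:12345> or 🎉

conf = Conf("bot.conf", section_name="BOT")
emoji = conf.emojis.party          # parsed through configEmoji.from_str
assert emoji.name == "party" and emoji.animated and emoji.id == 12345
assert emoji.fallback.name == "🎉"
```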

20
src/meta/context.py Normal file

@@ -0,0 +1,20 @@
"""
Namespace for various global context variables.
Allows asyncio callbacks to accurately retrieve information about the current state.
"""
from typing import TYPE_CHECKING, Optional
from contextvars import ContextVar
if TYPE_CHECKING:
from .LionBot import LionBot
from .LionContext import LionContext
# Contains the current command context, if applicable
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
# Contains the current LionBot instance
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)
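
A standalone demonstration of why ContextVars are the right tool here: each asyncio task runs in a copy of the context, so concurrent commands cannot clobber each other's state.

```python
import asyncio
from contextvars import ContextVar

current_cmd: ContextVar[str] = ContextVar('current_cmd', default='<none>')

async def handle(name: str):
    current_cmd.set(name)
    await asyncio.sleep(0)              # interleave with the other task
    assert current_cmd.get() == name    # still sees its own value

async def main():
    await asyncio.gather(handle('ping'), handle('stats'))

asyncio.run(main())
```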

64
src/meta/errors.py Normal file

@@ -0,0 +1,64 @@
from typing import Optional
from string import Template
class SafeCancellation(Exception):
"""
Raised to safely cancel execution of the current operation.
If not caught, is expected to be propagated to the Tree and safely ignored there.
If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
The error handler should then set the `msg` to None, to avoid double handling.
Debugging information should go in `details`, to be logged by a top-level error handler.
"""
default_message = ""
@property
def msg(self):
return self._msg if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
self._msg: Optional[str] = _msg
self.details: str = details if details is not None else self.msg
super().__init__(**kwargs)
class UserInputError(SafeCancellation):
"""
A SafeCancellation induced from unparseable user input.
"""
default_message = "Could not understand your input."
@property
def msg(self):
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
self.info = info
super().__init__(_msg, **kwargs)
class UserCancelled(SafeCancellation):
"""
A SafeCancellation induced from manual user cancellation.
Usually silent.
"""
default_message = None
class ResponseTimedOut(SafeCancellation):
"""
A SafeCancellation induced from a user interaction time-out.
"""
default_msg = "Session timed out waiting for input."
class HandledException(SafeCancellation):
"""
Sentinel class to indicate to error handlers that this exception has been handled.
Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
"""
def __init__(self, exc=None, **kwargs):
self.exc = exc
super().__init__(**kwargs)
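
A sketch of the intended flow: `_msg` is the user-facing text (with `string.Template` substitution for `UserInputError`), while `details` is reserved for logging.

```python
try:
    raise UserInputError(
        "Could not parse the duration '$value'.",
        info={'value': '10x'},
        details="duration parse failure",
    )
except UserInputError as e:
    assert e.msg == "Could not parse the duration '10x'."
    assert e.details == "duration parse failure"
```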

468
src/meta/logger.py Normal file

@@ -0,0 +1,468 @@
import sys
import logging
import asyncio
from typing import List, Optional
from logging.handlers import QueueListener, QueueHandler
import queue
import multiprocessing
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar
import discord
from discord import Webhook, File
import aiohttp
from .config import conf
from . import sharding
from .context import context
from utils.lib import utc_now
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
log_logger = logging.getLogger(__name__)
log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
def set_logging_context(
context: Optional[str] = None,
action: Optional[str] = None,
stack: Optional[tuple[str, ...]] = None
):
"""
Statically set the logging context variables to the given values.
If `action` is given, pushes it onto the `log_action_stack`.
"""
if context is not None:
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
@contextmanager
def logging_context(context=None, action=None, stack=None):
"""
Context manager for executing a block of code in a given logging context.
This context manager should only be used around synchronous code.
This is because async code *may* get cancelled or externally garbage collected,
in which case the finally block will be executed in the wrong context.
See https://github.com/python/cpython/issues/93740
This can be refactored nicely if this gets merged:
https://github.com/python/cpython/pull/99634
(It will not necessarily break on async code,
if the async code can be guaranteed to clean up in its own context.)
"""
if context is not None:
oldcontext = log_context.get()
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
try:
yield
finally:
if context is not None:
log_context.set(oldcontext)
if stack is not None or action is not None:
log_action_stack.set(astack)
def with_log_ctx(isolate=True, **kwargs):
"""
Execute a coroutine inside a given logging context.
If `isolate` is true, ensures that context does not leak
outside the coroutine.
If `isolate` is false, just statically set the context,
which will leak unless the coroutine is
called in an externally copied context.
"""
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
if isolate:
with logging_context(**kwargs):
# Task creation will synchronously copy the context
# This is gc safe
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
return await task
else:
# This will leak context changes
set_logging_context(**kwargs)
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
# For backwards compatibility
log_wrap = with_log_ctx
def persist_task(task_collection: set):
"""
Coroutine decorator that ensures the coroutine is scheduled as a task
and added to the given task_collection for strong reference
when it is called.
This is just a hack to handle discord.py events potentially
being unexpectedly garbage collected.
Since this also implicitly schedules the coroutine as a task when it is called,
the coroutine will also be run inside an isolated context.
"""
def decorator(coro):
@wraps(coro)
async def wrapped(*w_args, **w_kwargs):
name = f"persisted-{coro.__name__}"
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
task_collection.add(task)
task.add_done_callback(lambda f: task_collection.discard(f))
await task
return wrapped
return decorator
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str:
cmap = {
'%(black)': COLOR_SEQ % BLACK,
'%(red)': COLOR_SEQ % RED,
'%(green)': COLOR_SEQ % GREEN,
'%(yellow)': COLOR_SEQ % YELLOW,
'%(blue)': COLOR_SEQ % BLUE,
'%(magenta)': COLOR_SEQ % MAGENTA,
'%(cyan)': COLOR_SEQ % CYAN,
'%(white)': COLOR_SEQ % WHITE,
'%(reset)': RESET_SEQ,
'%(bold)': BOLD_SEQ,
}
for key, value in cmap.items():
fmt = fmt.replace(key, value)
return fmt
log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
'%(cyan)%(app)-15s%(reset)|' +
'%(cyan)%(context)-24s%(reset)|' +
'%(cyan)%(actionstr)-22s%(reset)|' +
' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format)
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=log_format,
# datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
class ExactLevelFilter(logging.Filter):
def __init__(self, target_level, name=""):
super().__init__(name)
self.target_level = target_level
def filter(self, record):
return (record.levelno == self.target_level)
class ThreadFilter(logging.Filter):
def __init__(self, thread_name):
super().__init__("")
self.thread = thread_name
def filter(self, record):
# non-zero return means we log this message
return 1 if record.threadName == self.thread else 0
class ContextInjection(logging.Filter):
def filter(self, record):
# These guards are to allow override through _extra
# And to ensure the injection is idempotent
if not hasattr(record, 'context'):
record.context = log_context.get()
if not hasattr(record, 'actionstr'):
action_stack = log_action_stack.get()
if hasattr(record, 'action'):
action_stack = (*action_stack, record.action)
if action_stack:
record.actionstr = ''.join(action_stack)
else:
record.actionstr = "Unknown Action"
if not hasattr(record, 'app'):
record.app = log_app.get()
if not hasattr(record, 'ctx'):
if ctx := context.get():
record.ctx = repr(ctx)
else:
record.ctx = None
if getattr(record, 'with_ctx', False) and record.ctx:
record.ctxstr = '\n' + record.ctx
else:
record.ctxstr = ""
return True
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)
class LocalQueueHandler(QueueHandler):
def _emit(self, record: logging.LogRecord) -> None:
# Removed the call to self.prepare(), handle task cancellation
try:
self.enqueue(record)
except asyncio.CancelledError:
raise
except Exception:
self.handleError(record)
class WebHookHandler(logging.StreamHandler):
def __init__(self, webhook_url, prefix="", batch=True, loop=None):
super().__init__()
self.webhook_url = webhook_url
self.prefix = prefix
self.batched = ""
self.batch = batch
self.loop = loop
self.batch_delay = 10
self.batch_task = None
self.last_batched = None
self.waiting = []
self.bucket = Bucket(20, 40)
self.ignored = 0
self.session = None
self.webhook = None
def get_loop(self):
if self.loop is None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
return self.loop
def emit(self, record):
self.format(record)
self.get_loop().call_soon_threadsafe(self._post, record)
def _post(self, record):
if self.session is None:
self.setup()
asyncio.create_task(self.post(record))
def setup(self):
self.session = aiohttp.ClientSession()
self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
async def post(self, record):
if record.context == 'Webhook Logger':
# Don't livelog livelog errors
# Otherwise we recurse and Cloudflare hates us
return
log_context.set("Webhook Logger")
log_action_stack.set(("Logging",))
log_app.set(record.app)
try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
context = f"\n# Context: {record.ctx}" if record.ctx else ""
message = f"{header}\n{record.msg}{context}"
if len(message) > 1900:
as_file = True
else:
as_file = False
message = "```md\n{}\n```".format(message)
# Post the log message(s)
if self.batch:
if len(message) > 1500:
await self._send_batched_now()
await self._send(message, as_file=as_file)
else:
self.batched += message
if len(self.batched) + len(message) > 1500:
await self._send_batched_now()
else:
asyncio.create_task(self._schedule_batched())
else:
await self._send(message, as_file=as_file)
except Exception as ex:
print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)
async def _schedule_batched(self):
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
# noop, don't reschedule if it is already scheduled
return
try:
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
await self.batch_task
await self._send_batched()
except asyncio.CancelledError:
return
except Exception as ex:
print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)
async def _send_batched_now(self):
if self.batch_task is not None and not self.batch_task.done():
self.batch_task.cancel()
self.last_batched = None
await self._send_batched()
async def _send_batched(self):
if self.batched:
batched = self.batched
self.batched = ""
await self._send(batched)
async def _send(self, message, as_file=False):
try:
self.bucket.request()
except BucketOverFull:
# Silently ignore
self.ignored += 1
return
except BucketFull:
logger.warning(
"Can't keep up! "
f"Ignoring records on live-logger {self.webhook.id}."
)
self.ignored += 1
return
else:
if self.ignored > 0:
logger.warning(
"Can't keep up! "
f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
)
self.ignored = 0
try:
if as_file or len(message) > 1900:
with StringIO(message) as fp:
fp.seek(0)
await self.webhook.send(
f"{self.prefix}\n`{message.splitlines()[0]}`",
file=File(fp, filename="logs.md"),
username=log_app.get()
)
else:
await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
except discord.HTTPException:
logger.exception(
"Live logger errored. Slowing down live logger."
)
self.bucket.fill()
handlers = []
if webhook := conf.logging['general_log']:
handler = WebHookHandler(webhook, batch=True)
handlers.append(handler)
if webhook := conf.logging['warning_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True)
handler.addFilter(ExactLevelFilter(logging.WARNING))
handler.setLevel(logging.WARNING)
handlers.append(handler)
if webhook := conf.logging['error_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True)
handler.setLevel(logging.ERROR)
handlers.append(handler)
if webhook := conf.logging['critical_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False)
handler.setLevel(logging.CRITICAL)
handlers.append(handler)
def make_queue_handler(queue):
qhandler = QueueHandler(queue)
qhandler.setLevel(logging.INFO)
qhandler.addFilter(ContextInjection())
return qhandler
def setup_main_logger(multiprocess=False):
q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
if handlers:
# First create a separate loop to run the handlers on
import threading
def run_loop(loop):
asyncio.set_event_loop(loop)
try:
loop.run_forever()
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=lambda: run_loop(loop))
loop_thread.daemon = True
loop_thread.start()
for handler in handlers:
handler.loop = loop
qhandler = make_queue_handler(q)
# qhandler.addFilter(ThreadFilter('MainThread'))
logger.addHandler(qhandler)
listener = QueueListener(
q, *handlers, respect_handler_level=True
)
listener.start()
return q
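
A hedged usage sketch of the context helpers defined above; the action and context strings are hypothetical:

```python
import logging
mod_logger = logging.getLogger(__name__)

@log_wrap(action="Nightly Sweep")
async def sweep_guild(guildid: int):
    # Runs in an isolated context; the action column shows "Nightly Sweep".
    set_logging_context(context=f"gid: {guildid}")
    mod_logger.info("Sweeping guild")

def sync_section():
    # The context manager form is safe around synchronous code.
    with logging_context(action="Cleanup"):
        mod_logger.info("Synchronous work in a scoped action")
```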

139
src/meta/monitor.py Normal file

@@ -0,0 +1,139 @@
import logging
import asyncio
from enum import IntEnum
from collections import deque, ChainMap
import datetime as dt
logger = logging.getLogger(__name__)
class StatusLevel(IntEnum):
ERRORED = -2
UNSURE = -1
WAITING = 0
STARTING = 1
OKAY = 2
@property
def symbol(self):
return symbols[self]
symbols = {
StatusLevel.ERRORED: '🟥',
StatusLevel.UNSURE: '🟧',
StatusLevel.WAITING: '',
StatusLevel.STARTING: '🟫',
StatusLevel.OKAY: '🟩',
}
class ComponentStatus:
def __init__(self, level: StatusLevel, short_formatstr: str, long_formatstr: str, data: dict = {}):
self.level = level
self.short_formatstr = short_formatstr
self.long_formatstr = long_formatstr
self.data = data
self.created_at = dt.datetime.now(tz=dt.timezone.utc)
def format_args(self):
extra = {
'created_at': self.created_at,
'level': self.level,
'symbol': self.level.symbol,
}
return ChainMap(extra, self.data)
@property
def short(self):
return self.short_formatstr.format(**self.format_args())
@property
def long(self):
return self.long_formatstr.format(**self.format_args())
class ComponentMonitor:
_name = None
def __init__(self, name=None, callback=None):
self._callback = callback
self.name = name or self._name
if not self.name:
raise ValueError("ComponentMonitor must have a name")
async def _make_status(self, *args, **kwargs):
if self._callback is not None:
return await self._callback(*args, **kwargs)
else:
raise NotImplementedError
async def status(self) -> ComponentStatus:
try:
status = await self._make_status()
except Exception as e:
logger.exception(
f"Status callback for component '{self.name}' failed. This should not happen."
)
status = ComponentStatus(
level=StatusLevel.UNSURE,
short_formatstr="Status callback for '{name}' failed with error '{error}'",
long_formatstr="Status callback for '{name}' failed with error '{error}'",
data={
'name': self.name,
'error': repr(e)
}
)
return status
class SystemMonitor:
def __init__(self):
self.components = {}
self.recent = deque(maxlen=10)
def add_component(self, component: ComponentMonitor):
self.components[component.name] = component
return component
async def request(self):
"""
Request status from each component.
"""
tasks = {
name: asyncio.create_task(comp.status())
for name, comp in self.components.items()
}
await asyncio.gather(*tasks.values())
status = {
name: await fut for name, fut in tasks.items()
}
self.recent.append(status)
return status
async def _format_summary(self, status_dict: dict[str, ComponentStatus]):
"""
Format a one line summary from a status dict.
"""
freq = {level: 0 for level in StatusLevel}
for status in status_dict.values():
freq[status.level] += 1
summary = '\t'.join(f"{level.symbol} {count}" for level, count in freq.items() if count)
return summary
async def _format_overview(self, status_dict: dict[str, ComponentStatus]):
"""
Format an overview (one line per component) from a status dict.
"""
lines = []
for name, status in status_dict.items():
lines.append(f"{status.level.symbol} {name}: {status.short}")
summary = await self._format_summary(status_dict)
return '\n'.join((summary, *lines))
async def get_summary(self):
return await self._format_summary(await self.request())
async def get_overview(self):
return await self._format_overview(await self.request())
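
A self-contained sketch of wiring a component into the monitor above; the component name and pool figure are made up:

```python
import asyncio

async def _db_status() -> ComponentStatus:
    return ComponentStatus(
        StatusLevel.OKAY,
        "(OK) pool size {pool_size}",
        "(OK) pool size {pool_size}",
        data={'pool_size': 10},
    )

async def main():
    system = SystemMonitor()
    system.add_component(ComponentMonitor('Database', _db_status))
    # Prints the summary line followed by one line per component.
    print(await system.get_overview())

asyncio.run(main())
```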

35
src/meta/sharding.py Normal file

@@ -0,0 +1,35 @@
from .args import args
from .config import conf
from psycopg import sql
from data.conditions import Condition, Joiner
shard_number = args.shard or 0
shard_count = conf.bot.getint('shard_count', 1)
sharded = (shard_count > 1)
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
"""
Condition constructor for filtering by shard id.
Example Usage
-------------
Query.where(SHARDID(10, 'guildid', 1))
"""
return Condition(
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
guildid=sql.Identifier(guild_column),
shard_count=sql.Literal(shard_count)
),
Joiner.EQUALS,
sql.Placeholder(),
(shard_id,)
)
# Pre-built Condition for filtering by current shard.
THIS_SHARD = SHARDID(shard_number)
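
The same routing formula, evaluated client-side as a sanity check; the snowflake below is the example id from the Discord docs:

```python
def shard_of(guildid: int, shard_count: int) -> int:
    # Mirrors the SQL built by SHARDID: (guildid >> 22) % shard_count
    return (guildid >> 22) % shard_count

print(shard_of(175928847299117063, 16))
```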

10
src/modules/__init__.py Normal file

@@ -0,0 +1,10 @@
this_package = 'modules'
active = [
'.sysadmin',
]
async def setup(bot):
for ext in active:
await bot.load_extension(ext, package=this_package)

5
src/modules/sysadmin/__init__.py Normal file

@@ -0,0 +1,5 @@
async def setup(bot):
from .exec_cog import Exec
await bot.add_cog(Exec(bot))

336
src/modules/sysadmin/exec_cog.py Normal file

@@ -0,0 +1,336 @@
import io
import ast
import sys
import types
import asyncio
import traceback
import builtins
import inspect
import logging
from io import StringIO
from typing import Callable, Any, Optional
from enum import Enum
import discord
from discord.ext import commands
from discord.ext.commands.errors import CheckFailure
from discord.ui import TextInput, View
from discord.ui.button import button
import discord.app_commands as appcmd
from meta.logger import logging_context, log_wrap
from meta import conf
from meta.context import context, ctx_bot
from meta.LionContext import LionContext
from meta.LionCog import LionCog
from meta.LionBot import LionBot
from utils.ui import FastModal, input
from wards import sys_admin
def _(arg): return arg
def _p(ctx, arg): return arg
logger = logging.getLogger(__name__)
class ExecModal(FastModal, title="Execute"):
code: TextInput = TextInput(
label="Code to execute",
style=discord.TextStyle.long,
required=True
)
class ExecStyle(Enum):
EXEC = 'exec'
EVAL = 'eval'
class ExecUI(View):
def __init__(self, ctx, code=None, style=ExecStyle.EXEC, ephemeral=True) -> None:
super().__init__()
self.ctx: LionContext = ctx
self.interaction: Optional[discord.Interaction] = ctx.interaction
self.code: Optional[str] = code
self.style: ExecStyle = style
self.ephemeral: bool = ephemeral
self._modal: Optional[ExecModal] = None
self._msg: Optional[discord.Message] = None
async def interaction_check(self, interaction: discord.Interaction):
"""Only allow the original author to use this View"""
if interaction.user.id != self.ctx.author.id:
await interaction.response.send_message(
("You cannot use this interface!"),
ephemeral=True
)
return False
else:
return True
async def run(self):
if self.code is None:
if (interaction := self.interaction) is not None:
self.interaction = None
await interaction.response.send_modal(self.get_modal())
await self.wait()
else:
# Complain
# TODO: error_reply
await self.ctx.reply("Pls give code.")
else:
await self.interaction.response.defer(thinking=True, ephemeral=self.ephemeral)
await self.compile()
await self.wait()
@button(label="Recompile")
async def recompile_button(self, interaction, butt):
# Interaction response with modal
await interaction.response.send_modal(self.get_modal())
@button(label="Show Source")
async def source_button(self, interaction, butt):
if len(self.code) > 1900:
# Send as file
with StringIO(self.code) as fp:
fp.seek(0)
file = discord.File(fp, filename="source.py")
await interaction.response.send_message(file=file, ephemeral=True)
else:
# Send as message
await interaction.response.send_message(
content=f"```py\n{self.code}```",
ephemeral=True
)
def create_modal(self) -> ExecModal:
modal = ExecModal()
@modal.submit_callback()
async def exec_submit(interaction: discord.Interaction):
if self.interaction is None:
self.interaction = interaction
await interaction.response.defer(thinking=True)
else:
await interaction.response.defer()
# Set code
self.code = modal.code.value
# Call compile
await self.compile()
return modal
def get_modal(self):
self._modal = self.create_modal()
self._modal.code.default = self.code
return self._modal
async def compile(self):
# Call _async
result = await _async(self.code, style=self.style.value)
# Display output
await self.show_output(result)
async def show_output(self, output):
# Format output
# If output message exists and not ephemeral, edit
# Otherwise, send message, add buttons
if len(output) > 1900:
# Send as file
with StringIO(output) as fp:
fp.seek(0)
args = {
'content': None,
'attachments': [discord.File(fp, filename="output.md")]
}
else:
args = {
'content': f"```md\n{output}```",
'attachments': []
}
if self._msg is None:
if self.interaction is not None:
msg = await self.interaction.edit_original_response(**args, view=self)
else:
# Send new message
if args['content'] is None:
args['file'] = args.pop('attachments')[0]
msg = await self.ctx.reply(**args, ephemeral=self.ephemeral, view=self)
if not self.ephemeral:
self._msg = msg
else:
if self.interaction is not None:
await self.interaction.edit_original_response(**args, view=self)
else:
# Edit message
await self._msg.edit(**args)
def mk_print(fp: io.StringIO) -> Callable[..., None]:
def _print(*args, file: Any = fp, **kwargs):
return print(*args, file=file, **kwargs)
return _print
def mk_status_printer(bot, printer):
async def _status(details=False):
if details:
status = await bot.system_monitor.get_overview()
else:
status = await bot.system_monitor.get_summary()
printer(status)
return status
return _status
@log_wrap(action="Code Exec")
async def _async(to_eval: str, style='exec'):
newline = '\n' * ('\n' in to_eval)
logger.info(
f"Exec code with {style}: {newline}{to_eval}"
)
output = io.StringIO()
_print = mk_print(output)
scope: dict[str, Any] = dict(sys.modules)
scope['__builtins__'] = builtins
scope.update(builtins.__dict__)
scope['ctx'] = ctx = context.get()
scope['bot'] = ctx_bot.get()
scope['print'] = _print # type: ignore
scope['print_status'] = mk_status_printer(scope['bot'], _print)
try:
if ctx and ctx.message:
source_str = f"<msg: {ctx.message.id}>"
elif ctx and ctx.interaction:
source_str = f"<iid: {ctx.interaction.id}>"
else:
source_str = "Unknown async"
code = compile(
to_eval,
source_str,
style,
ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
)
func = types.FunctionType(code, scope)
ret = func()
if inspect.iscoroutine(ret):
ret = await ret
if ret is not None:
_print(repr(ret))
except Exception:
_, exc, tb = sys.exc_info()
_print("".join(traceback.format_tb(tb)))
_print(f"{type(exc).__name__}: {exc}")
result = output.getvalue().strip()
newline = '\n' * ('\n' in result)
logger.info(
f"Exec complete, output: {newline}{result}"
)
return result
class Exec(LionCog):
guild_ids = conf.bot.getintlist('admin_guilds')
def __init__(self, bot: LionBot):
self.bot = bot
async def cog_check(self, ctx: LionContext) -> bool: # type: ignore
passed = await sys_admin(ctx.bot, ctx.author.id)
if passed:
return True
else:
raise CheckFailure(
"You must be a bot owner to do this!"
)
@commands.hybrid_command(
name=_('async'),
description=_("Execute arbitrary code with Exec")
)
@appcmd.describe(
string="Code to execute.",
)
async def async_cmd(self, ctx: LionContext,
string: Optional[str] = None,
):
await ExecUI(ctx, string, ExecStyle.EXEC, ephemeral=False).run()
@commands.hybrid_command(
name=_('reload'),
description=_("Reload a given LionBot extension. Launches an ExecUI.")
)
@appcmd.describe(
extension=_("Name of the extension to reload. See autocomplete for options."),
force=_("Whether to force an extension reload even if it doesn't exist.")
)
@appcmd.guilds(*guild_ids)
async def reload_cmd(self, ctx: LionContext, extension: str, force: Optional[bool] = False):
"""
This is essentially just a friendly wrapper to reload an extension.
It is equivalent to running "await bot.reload_extension(extension)" in eval,
with a slightly nicer interface through the autocomplete and error handling.
"""
exists = (extension in self.bot.extensions)
if not (force or exists):
embed = discord.Embed(description=f"Unknown extension {extension}", colour=discord.Colour.red())
await ctx.reply(embed=embed)
else:
# Uses an ExecUI to simplify error handling and re-execution
if exists:
string = f"await bot.reload_extension('{extension}')"
else:
string = f"await bot.load_extension('{extension}')"
await ExecUI(ctx, string, ExecStyle.EVAL).run()
@reload_cmd.autocomplete('extension')
async def reload_extension_acmpl(self, interaction: discord.Interaction, partial: str):
keys = set(self.bot.extensions.keys())
results = [
appcmd.Choice(name=key, value=key)
for key in keys
if partial.lower() in key.lower()
]
if not results:
results = [
appcmd.Choice(name=f"No extensions found matching {partial}", value="None")
]
return results[:25]
@commands.hybrid_command(
name=_('shutdown'),
description=_("Shutdown (or restart) the client.")
)
@appcmd.guilds(*guild_ids)
async def shutdown_cmd(self, ctx: LionContext):
"""
Shutdown the client.
Maybe do something friendly here?
"""
logger.info("Shutting down on admin request.")
await ctx.reply(
embed=discord.Embed(
description=f"Understood {ctx.author.mention}, cleaning up and shutting down!",
colour=discord.Colour.orange()
)
)
await self.bot.close()

0
src/utils/__init__.py Normal file
View File

97
src/utils/ansi.py Normal file
View File

@@ -0,0 +1,97 @@
"""
Minimal library for making Discord Ansi colour codes.
"""
from enum import StrEnum
PREFIX = u'\u001b'
class TextColour(StrEnum):
Gray = '30'
Red = '31'
Green = '32'
Yellow = '33'
Blue = '34'
Pink = '35'
Cyan = '36'
White = '37'
def __str__(self) -> str:
return AnsiColour(fg=self).as_str()
def __call__(self):
return AnsiColour(fg=self)
class BgColour(StrEnum):
FireflyDarkBlue = '40'
Orange = '41'
MarbleBlue = '42'
GrayTurq = '43'
Gray = '44'
Indigo = '45'
LightGray = '46'
White = '47'
def __str__(self) -> str:
return AnsiColour(bg=self).as_str()
def __call__(self):
return AnsiColour(bg=self)
class Format(StrEnum):
NORMAL = '0'
BOLD = '1'
UNDERLINE = '4'
NOOP = '9'
def __str__(self) -> str:
return AnsiColour(self).as_str()
def __call__(self):
return AnsiColour(self)
class AnsiColour:
def __init__(self, *flags, fg=None, bg=None):
self.text_colour = fg
self.background_colour = bg
self.reset = (Format.NORMAL in flags)
self._flags = set(flags)
self._flags.discard(Format.NORMAL)
@property
def flags(self):
return (*((Format.NORMAL,) if self.reset else ()), *self._flags)
def as_str(self):
parts = []
if self.reset:
parts.append(Format.NORMAL)
elif not self.flags:
parts.append(Format.NOOP)
parts.extend(self._flags)
for c in (self.text_colour, self.background_colour):
if c is not None:
parts.append(c)
partstr = ';'.join(part.value for part in parts)
return f"{PREFIX}[{partstr}m" # ]
def __str__(self):
return self.as_str()
def __add__(self, obj: 'AnsiColour'):
text_colour = obj.text_colour or self.text_colour
background_colour = obj.background_colour or self.background_colour
flags = (*self.flags, *obj.flags)
return AnsiColour(*flags, fg=text_colour, bg=background_colour)
RESET = AnsiColour(Format.NORMAL)
BOLD = AnsiColour(Format.BOLD)
UNDERLINE = AnsiColour(Format.UNDERLINE)
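# A small usage sketch (purely illustrative): styles compose with `+`,
# and RESET returns the stream to the default style. Discord renders
# these codes inside codeblocks using the `ansi` syntax.
def demo_ansi() -> str:
    error_style = BOLD + TextColour.Red() + BgColour.Orange()
    return f"{error_style}Something went wrong!{RESET} Carrying on."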

165
src/utils/data.py Normal file
View File

@@ -0,0 +1,165 @@
"""
Some useful pre-built Conditions for data queries.
"""
from typing import Optional, Any
from itertools import chain
from psycopg import sql
from data.conditions import Condition, Joiner
from data.columns import ColumnExpr
from data.base import Expression
from constants import MAX_COINS
def MULTIVALUE_IN(columns: tuple[str, ...], *data: tuple[Any, ...]) -> Condition:
"""
Condition constructor for filtering by multiple column equalities.
Example Usage
-------------
Query.where(MULTIVALUE_IN(('guildid', 'userid'), (1, 2), (3, 4)))
"""
if not data:
raise ValueError("Cannot create empty multivalue condition.")
left = sql.SQL("({})").format(
sql.SQL(', ').join(
sql.Identifier(key)
for key in columns
)
)
right_item = sql.SQL('({})').format(
sql.SQL(', ').join(
sql.Placeholder()
for _ in columns
)
)
right = sql.SQL("({})").format(
sql.SQL(', ').join(
right_item
for _ in data
)
)
return Condition(
left,
Joiner.IN,
right,
chain(*data)
)
def MEMBERS(*memberids: tuple[int, int], guild_column='guildid', user_column='userid') -> Condition:
"""
Condition constructor for filtering member tables by guild and user id simultaneously.
Example Usage
-------------
Query.where(MEMBERS((1234,12), (5678,34)))
"""
if not memberids:
raise ValueError("Cannot create a condition with no members")
return Condition(
sql.SQL("({guildid}, {userid})").format(
guildid=sql.Identifier(guild_column),
userid=sql.Identifier(user_column)
),
Joiner.IN,
sql.SQL("({})").format(
sql.SQL(', ').join(
sql.SQL("({}, {})").format(
sql.Placeholder(),
sql.Placeholder()
) for _ in memberids
)
),
chain(*memberids)
)
def as_duration(expr: Expression) -> ColumnExpr:
"""
Convert an integer expression into a duration expression.
"""
expr_expr, expr_values = expr.as_tuple()
return ColumnExpr(
sql.SQL("({} * interval '1 second')").format(expr_expr),
expr_values
)
class TemporaryTable(Expression):
"""
Create a temporary table expression to be used in From or With clauses.
Example
-------
```
tmp_table = TemporaryTable('_col1', '_col2', name='data')
    tmp_table.set_values((1, 2), (3, 4))
real_table.update_where(col1=tmp_table['_col1']).set(col2=tmp_table['_col2']).from_(tmp_table)
```
"""
def __init__(self, *columns: str, name: str = '_t', types: Optional[tuple[str, ...]] = None):
self.name = name
self.columns = columns
self.types = types
if types and len(types) != len(columns):
raise ValueError("Number of types does not much number of columns!")
self._table_columns = {
col: ColumnExpr(sql.Identifier(name, col))
for col in columns
}
self.values = []
    def __getitem__(self, key) -> ColumnExpr:
return self._table_columns[key]
def as_tuple(self):
"""
(VALUES {})
AS
name (col1, col2)
"""
if not self.values:
raise ValueError("Cannot flatten CTE with no values.")
single_value = sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() for _ in self.columns))
if self.types:
first_value = sql.SQL("({})").format(
sql.SQL(", ").join(
sql.SQL("{}::{}").format(sql.Placeholder(), sql.SQL(cast))
for cast in self.types
)
)
else:
first_value = single_value
value_placeholder = sql.SQL("(VALUES {})").format(
sql.SQL(", ").join(
(first_value, *(single_value for _ in self.values[1:]))
)
)
expr = sql.SQL("{values} AS {name} ({columns})").format(
values=value_placeholder,
name=sql.Identifier(self.name),
columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.columns)
)
values = chain(*self.values)
return (expr, values)
def set_values(self, *data):
self.values = data
def SAFECOINS(expr: Expression) -> Expression:
expr_expr, expr_values = expr.as_tuple()
return ColumnExpr(
sql.SQL("LEAST({}, {})").format(
expr_expr,
sql.Literal(MAX_COINS)
),
expr_values
)

847
src/utils/lib.py Normal file
View File

@@ -0,0 +1,847 @@
from io import StringIO
from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any
import collections
import datetime
import datetime as dt
import iso8601 # type: ignore
import pytz
import re
import json
from contextvars import Context
import discord
from discord.partial_emoji import _EmojiTag
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
from discord.ui import View
from meta.errors import UserInputError
multiselect_regex = re.compile(
r"^([0-9, -]+)$",
re.DOTALL | re.IGNORECASE | re.VERBOSE
)
tick = '✅'
cross = '❌'
MISSING = object()
class MessageArgs:
"""
Utility class for storing message creation and editing arguments.
"""
# TODO: Overrides for mutually exclusive arguments, see Messageable.send
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
def __init__(self, **kwargs):
self.kwargs = kwargs
@property
def send_args(self) -> dict:
if self.kwargs.get('view', MISSING) is None:
kwargs = self.kwargs.copy()
kwargs.pop('view')
else:
kwargs = self.kwargs
return kwargs
@property
def edit_args(self) -> dict:
args = {}
kept = (
'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view'
)
for k in kept:
if k in self.kwargs:
args[k] = self.kwargs[k]
if 'file' in self.kwargs:
args['attachments'] = [self.kwargs['file']]
if 'files' in self.kwargs:
args['attachments'] = self.kwargs['files']
if 'suppress_embeds' in self.kwargs:
args['suppress'] = self.kwargs['suppress_embeds']
return args
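# A usage sketch (the channel is assumed to be supplied by the caller): the
# same MessageArgs feeds both send and edit, with edit_args remapping the
# send-only keys (e.g. 'file' -> 'attachments', 'suppress_embeds' -> 'suppress').
async def demo_message_args(channel: discord.abc.Messageable):
    args = MessageArgs(content="Status update", suppress_embeds=True)
    msg = await channel.send(**args.send_args)
    await msg.edit(**args.edit_args)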
def tabulate(
*fields: tuple[str, str],
row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}",
sub_format: str = "`{invis:<{pad}}{colon}`\t{value}",
colon: str = ':',
invis: str = "",
**args
) -> list[str]:
"""
Turns a list of (property, value) pairs into
a pretty string with one `prop: value` pair each line,
padded so that the colons in each line are lined up.
Use `\\r\\n` in a value to break the line with padding.
Parameters
----------
fields: List[tuple[str, str]]
List of (key, value) pairs.
row_format: str
The format string used to format each row.
sub_format: str
The format string used to format each subline in a row.
colon: str
The colon character used.
invis: str
The invisible character used (to avoid Discord stripping the string).
Returns: List[str]
The list of resulting table rows.
Each row corresponds to one (key, value) pair from fields.
"""
max_len = max(len(field[0]) for field in fields)
rows = []
for field in fields:
key = field[0]
value = field[1]
lines = value.split('\r\n')
row_line = row_format.format(
invis=invis,
key=key,
pad=max_len,
colon=colon,
value=lines[0],
field=field,
**args
)
if len(lines) > 1:
row_lines = [row_line]
for line in lines[1:]:
sub_line = sub_format.format(
invis=invis,
pad=max_len + len(colon),
colon=colon,
value=line,
**args
)
row_lines.append(sub_line)
row_line = '\n'.join(row_lines)
rows.append(row_line)
return rows
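# A brief sketch of typical usage (field values are illustrative):
# keys are padded so the colons align, and '\r\n' inside a value
# continues it on a new padded line.
def demo_tabulate() -> str:
    rows = tabulate(
        ('Name', 'lion'),
        ('Timezone', 'UTC\r\nDST not observed'),
    )
    return '\n'.join(rows)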
def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]:
"""
Create pretty codeblock pages from a list of strings.
Parameters
----------
item_list: List[str]
List of strings to paginate.
block_length: int
Maximum number of strings per page.
style: str
Codeblock style to use.
Title formatting assumes the `markdown` style, and numbered lists work well with this.
However, `markdown` sometimes messes up formatting in the list.
title: str
Optional title to add to the top of each page.
Returns: List[str]
List of pages, each formatted into a codeblock,
and containing at most `block_length` of the provided strings.
"""
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
pages = []
for i, block in enumerate(page_blocks):
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
if title:
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
else:
header = pagenum
header_line = "=" * len(header)
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
return pages
def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]:
"""
Break the text into blocks of maximum length blocksize
If possible, break across nearby newlines. Otherwise just break at blocksize chars
Parameters
----------
text: str
Text to break into blocks.
blocksize: int
Maximum character length for each block.
code: bool
Whether to wrap each block in codeblocks (these are counted in the blocksize).
syntax: str
The markdown formatting language to use for the codeblocks, if applicable.
maxheight: int
The maximum number of lines in each block
Returns: List[str]
List of blocks,
each containing at most `block_size` characters,
of height at most `maxheight`.
"""
# Adjust blocksize to account for the codeblocks if required
blocksize = blocksize - 8 - len(syntax) if code else blocksize
# Build the blocks
blocks = []
while True:
# If the remaining text is already small enough, append it
if len(text) <= blocksize:
blocks.append(text)
break
text = text.strip('\n')
# Find the last newline in the prototype block
split_on = text[0:blocksize].rfind('\n')
split_on = blocksize if split_on < blocksize // 5 else split_on
# Add the block and truncate the text
blocks.append(text[0:split_on])
text = text[split_on:]
# Add the codeblock ticks and the code syntax header, if required
if code:
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
return blocks
def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str:
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's' # type: ignore
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def _parse_dur(time_str: str) -> int:
"""
Parses a user provided time duration string into a timedelta object.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
def strfdur(duration: int, short=True, show_days=False) -> str:
"""
Convert a duration given in seconds to a number of hours, minutes, and seconds.
"""
days = duration // (3600 * 24) if show_days else 0
hours = duration // 3600
if days:
hours %= 24
minutes = duration // 60 % 60
seconds = duration % 60
parts = []
if days:
unit = 'd' if short else (' days' if days != 1 else ' day')
parts.append('{}{}'.format(days, unit))
if hours:
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
parts.append('{}{}'.format(hours, unit))
if minutes:
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
parts.append('{}{}'.format(minutes, unit))
if seconds or duration == 0:
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
parts.append('{}{}'.format(seconds, unit))
if short:
return ' '.join(parts)
else:
return ', '.join(parts)
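# Worked example (values chosen arbitrarily): 3725 seconds is one hour,
# two minutes, and five seconds.
def demo_strfdur():
    assert strfdur(3725) == '1h 2m 5s'
    assert strfdur(3725, short=False) == '1 hour, 2 minutes, 5 seconds'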
def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str:
"""
Substitutes a user provided list of numbers and ranges,
and replaces the ranges by the corresponding list of numbers.
Parameters
----------
ranges_str: str
The string to ranges in.
max_match: int
The maximum number of ranges to replace.
Any ranges exceeding this will be ignored.
max_range: int
The maximum length of range to replace.
Attempting to replace a range longer than this will raise a `ValueError`.
"""
def _repl(match):
n1 = int(match.group(1))
n2 = int(match.group(2))
if n2 - n1 > max_range:
# TODO: Upgrade to SafeCancellation
raise ValueError("Provided range is too large!")
return separator.join(str(i) for i in range(n1, n2 + 1))
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]:
"""
Parses a user provided range string into a list of numbers.
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
"""
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
    _numbers = (item.strip() for item in substituted.split(separator))
numbers = [item for item in _numbers if item]
integers = [int(item) for item in numbers if item.isdigit()]
if not ignore_errors and len(integers) != len(numbers):
# TODO: Upgrade to SafeCancellation
raise ValueError(
"Couldn't parse the provided selection!\n"
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
)
return integers
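# Worked example: ranges expand inclusively before parsing.
def demo_parse_ranges():
    assert substitute_ranges("1, 5, 6-9") == "1, 5, 6,7,8,9"
    assert parse_ranges("1, 5, 6-9") == [1, 5, 6, 7, 8, 9]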
def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str:
"""
Format a message into a string with various information, such as:
the timestamp of the message, author, message content, and attachments.
Parameters
----------
msg: Message
The message to format.
mask_link: bool
Whether to mask the URLs of any attachments.
line_break: bool
Whether a line break should be used in the string.
tz: Timezone
The timezone to use in the formatted message.
clean: bool
Whether to use the clean content of the original message.
Returns: str
A formatted string containing various information:
User timezone, message author, message content, attachments
"""
timestr = "%I:%M %p, %d/%m/%Y"
if tz:
time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr)
else:
time = msg.created_at.strftime(timestr)
user = str(msg.author)
attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url]
if mask_link:
attach_list = ["[Link]({})".format(url) for url in attach_list]
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
time=time,
user=user,
line_break="\n" if line_break else "",
message=msg.clean_content if clean else msg.content,
attachments=attachments
)
def convdatestring(datestring: str) -> datetime.timedelta:
"""
Convert a date string into a datetime.timedelta object.
Parameters
----------
datestring: str
The string to convert to a datetime.timedelta object.
Returns: datetime.timedelta
A datetime.timedelta object formed from the string provided.
"""
datestring = datestring.strip(' ,')
datearray = []
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
currentnumber = ''
for char in datestring:
if char.isdigit():
currentnumber += char
else:
if currentnumber == '':
continue
datearray.append((int(currentnumber), char))
currentnumber = ''
seconds = 0
if currentnumber:
seconds += int(currentnumber)
for i in datearray:
if i[1] in funcs:
seconds += funcs[i[1]](i[0])
return datetime.timedelta(seconds=seconds)
class _rawChannel(discord.abc.Messageable):
"""
Raw messageable class representing an arbitrary channel,
    not necessarily seen by the gateway.
"""
def __init__(self, state, id):
self._state = state
self.id = id
async def _get_channel(self):
return discord.Object(self.id)
async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message:
"""
Mails a message to a channelid which may be invisible to the gateway.
Parameters:
client: discord.Client
The client to use for mailing.
Must at least have static authentication and have a valid `_connection`.
channelid: int
The channel id to mail to.
msg_args: Any
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
"""
# Create the raw channel
channel = _rawChannel(client._connection, channelid)
return await channel.send(**msg_args)
class EmbedField(NamedTuple):
name: str
value: str
inline: Optional[bool] = True
def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]):
"""
Append embed fields to an embed.
Parameters
----------
embed: discord.Embed
The embed to add the field to.
emb_fields: tuple
The values to add to a field.
name: str
The name of the field.
value: str
The value of the field.
inline: bool
Whether the embed field should be inline or not.
"""
for field in emb_fields:
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
def join_list(string: list[str], nfs=False) -> str:
"""
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
Parameters
----------
string: list
The list to join together.
nfs: bool
(no fullstops)
Whether to exclude fullstops/periods from the output messages.
If not provided, fullstops will be appended to the output.
"""
# TODO: Probably not useful with localisation
if len(string) > 1:
return "{}{} and {}{}".format((", ").join(string[:-1]),
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
else:
return "{}{}".format("".join(string), "" if nfs else ".")
def shard_of(shard_count: int, guildid: int) -> int:
"""
Calculate the shard number of a given guild.
"""
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
def jumpto(guildid: int, channelid: int, messageid: int) -> str:
    """
    Build a jump link for a message given its location.
    """
    return 'https://discord.com/channels/{}/{}/{}'.format(
        guildid,
        channelid,
        messageid
    )
def utc_now() -> datetime.datetime:
"""
Return the current timezone-aware utc timestamp.
"""
    return datetime.datetime.now(tz=datetime.timezone.utc)
def multiple_replace(string: str, rep_dict: dict[str, str]) -> str:
if rep_dict:
pattern = re.compile(
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
flags=re.DOTALL
)
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
else:
return string
def recover_context(context: Context):
for var in context:
var.set(context[var])
def parse_ids(idstr: str) -> List[int]:
"""
Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids.
Object agnostic, so all mention tokens are stripped.
Raises UserInputError if an id is invalid,
setting `orig` and `item` info fields.
"""
# Extract ids from string
splititer = (split.strip('<@!#&>, ') for split in idstr.split(','))
splits = [split for split in splititer if split]
# Check they are integers
if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None:
raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id})
# Cast to integer and return
return list(map(int, splits))
def error_embed(error, **kwargs) -> discord.Embed:
embed = discord.Embed(
colour=discord.Colour.brand_red(),
description=error,
timestamp=utc_now()
)
return embed
class DotDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
class Timezoned:
"""
ABC mixin for objects with a set timezone.
Provides several useful localised properties.
"""
__slots__ = ()
@property
def timezone(self) -> pytz.timezone:
"""
Must be implemented by the deriving class!
"""
raise NotImplementedError
@property
def now(self):
"""
Return the current time localised to the object's timezone.
"""
return datetime.datetime.now(tz=self.timezone)
@property
def today(self):
"""
Return the start of the day localised to the object's timezone.
"""
now = self.now
return now.replace(hour=0, minute=0, second=0, microsecond=0)
@property
def week_start(self):
"""
Return the start of the week in the object's timezone
"""
today = self.today
return today - datetime.timedelta(days=today.weekday())
@property
def month_start(self):
"""
Return the start of the current month in the object's timezone
"""
today = self.today
return today.replace(day=1)
def replace_multiple(format_string, mapping):
"""
    Substitutes the keys of the mapping with their corresponding values.
Substitution is non-chained, and done in a single pass via regex.
"""
if not mapping:
raise ValueError("Empty mapping passed.")
keys = list(mapping.keys())
pattern = '|'.join(f"({key})" for key in keys)
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
return string
def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str):
"""
Produces a distinguishing key for an Emoji or PartialEmoji.
Equality checks using this key should act as expected.
"""
if isinstance(emoji, _EmojiTag):
if emoji.id:
key = str(emoji.id)
else:
key = str(emoji.name)
else:
key = str(emoji)
return key
def recurse_map(func, obj, loc=None):
    # Use a fresh location stack per top-level call (avoids a shared mutable default)
    if loc is None:
        loc = []
    if isinstance(obj, dict):
        for k, v in obj.items():
            loc.append(k)
            obj[k] = recurse_map(func, v, loc)
            loc.pop()
    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            loc.append(i)
            # Pass the location stack down so func sees the full path
            obj[i] = recurse_map(func, item, loc)
            loc.pop()
    else:
        obj = func(loc, obj)
    return obj
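# A short sketch: double every leaf value, ignoring the location path.
def demo_recurse_map():
    data = {'a': [1, 2], 'b': 3}
    assert recurse_map(lambda loc, value: value * 2, data) == {'a': [2, 4], 'b': 6}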
async def check_dm(user: discord.User | discord.Member) -> bool:
"""
Check whether we can direct message the given user.
Assumes the client is initialised.
This uses an always-failing HTTP request,
so we need to be very very very careful that this is not used frequently.
Optimally only at the explicit behest of the user
(i.e. during a user instigated interaction).
"""
try:
await user.send('')
except discord.Forbidden:
return False
except discord.HTTPException:
return True
async def command_lengths(tree) -> dict[str, int]:
cmds = tree.get_commands()
payloads = [
await cmd.get_translated_payload(tree.translator)
for cmd in cmds
]
lens = {}
for command in payloads:
name = command['name']
crumbs = {}
cmd_len = lens[name] = _recurse_length(command, crumbs, (name,))
if name == 'configure' or cmd_len > 4000:
print(f"'{name}' over 4000. Breadcrumb Trail follows:")
lines = []
for loc, val in crumbs.items():
locstr = '.'.join(loc)
lines.append(f"{locstr}: {val}")
print('\n'.join(lines))
print(json.dumps(command, indent=2))
return lens
def _recurse_length(payload, breadcrumbs={}, header=()) -> int:
total = 0
total_header = (*header, '')
breadcrumbs[total_header] = 0
if isinstance(payload, dict):
# Read strings that count towards command length
# String length is length of longest localisation, including default.
for key in ('name', 'description', 'value'):
if key in payload:
value = payload[key]
if isinstance(value, str):
values = (value, *payload.get(key + '_localizations', {}).values())
maxlen = max(map(len, values))
total += maxlen
breadcrumbs[(*header, key)] = maxlen
for key, value in payload.items():
loc = (*header, key)
total += _recurse_length(value, breadcrumbs, loc)
elif isinstance(payload, list):
for i, item in enumerate(payload):
if isinstance(item, dict) and 'name' in item:
loc = (*header, f"{i}<{item['name']}>")
else:
loc = (*header, str(i))
total += _recurse_length(item, breadcrumbs, loc)
if total:
breadcrumbs[total_header] = total
else:
breadcrumbs.pop(total_header)
return total
def write_records(records: list[dict[str, Any]], stream: StringIO):
if records:
keys = records[0].keys()
stream.write(','.join(keys))
stream.write('\n')
for record in records:
stream.write(','.join(map(str, record.values())))
stream.write('\n')

191
src/utils/monitor.py Normal file
View File

@@ -0,0 +1,191 @@
import asyncio
import bisect
import logging
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
from .lib import utc_now
from .ratelimits import Bucket
logger = logging.getLogger(__name__)
Taskid = TypeVar('Taskid')
class TaskMonitor(Generic[Taskid]):
"""
Base class for a task monitor.
Stores tasks as a time-sorted list of taskids.
Subclasses may override `run_task` to implement an executor.
Adding or removing a single task has O(n) performance.
To bulk update tasks, instead use `schedule_tasks`.
Each taskid must be unique and hashable.
"""
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
# Ratelimit bucket to enforce maximum execution rate
self._bucket = bucket
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
self._wakeup: asyncio.Event = asyncio.Event()
self._monitor_task: Optional[asyncio.Task] = None
# Task data
self._tasklist: list[Taskid] = []
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
# Running map ensures we keep a reference to the running task
# And allows simpler external cancellation if required
self._running: dict[Taskid, asyncio.Future] = {}
def __repr__(self):
return (
"<"
f"{self.__class__.__name__}"
f" tasklist={len(self._tasklist)}"
f" taskmap={len(self._taskmap)}"
f" wakeup={self._wakeup.is_set()}"
f" bucket={self._bucket}"
f" running={len(self._running)}"
f" task={self._monitor_task}"
f">"
)
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Similar to `schedule_tasks`, but wipe and reset the tasklist.
"""
self._taskmap = {tid: time for tid, time in tasks}
self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid]))
self._wakeup.set()
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Schedule the given tasks.
Rather than repeatedly inserting tasks,
where the O(log n) insort is dominated by the O(n) list insertion,
we build an entirely new list, and always wake up the loop.
"""
self._taskmap |= {tid: time for tid, time in tasks}
self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid]))
self._wakeup.set()
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
"""
Insert the provided task into the tasklist.
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
"""
if self._tasklist:
nextid = self._tasklist[-1]
wake = self._taskmap[nextid] >= timestamp
wake = wake or taskid == nextid
else:
wake = True
if taskid in self._taskmap:
self._tasklist.remove(taskid)
self._taskmap[taskid] = timestamp
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
if wake:
self._wakeup.set()
def cancel_tasks(self, *taskids: Taskid) -> None:
"""
Remove all tasks with the given taskids from the tasklist.
If the next task has this taskid, wake up the monitor loop.
"""
taskids = set(taskids)
wake = (self._tasklist and self._tasklist[-1] in taskids)
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
for tid in taskids:
self._taskmap.pop(tid, None)
if wake:
self._wakeup.set()
def start(self):
if self._monitor_task and not self._monitor_task.done():
self._monitor_task.cancel()
# Start the monitor
self._monitor_task = asyncio.create_task(self.monitor())
return self._monitor_task
async def monitor(self):
"""
Start the monitor.
Executes the tasks in `self.tasks` at the specified time.
This will shield task execution from cancellation
to avoid partial states.
"""
try:
while True:
self._wakeup.clear()
if not self._tasklist:
# No tasks left, just sleep until wakeup
await self._wakeup.wait()
else:
# Get the next task, sleep until wakeup or it is ready to run
nextid = self._tasklist[-1]
nexttime = self._taskmap[nextid]
sleep_for = nexttime - utc_now().timestamp()
try:
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
except asyncio.TimeoutError:
# Ready to run the task
self._tasklist.pop()
self._taskmap.pop(nextid, None)
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
else:
# Wakeup task fired, loop again
continue
except asyncio.CancelledError:
# Log closure and wait for remaining tasks
# A second cancellation will also cancel the tasks
logger.debug(
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
f"Waiting for {len(self._running)} running tasks to complete."
)
await asyncio.gather(*self._running.values(), return_exceptions=True)
async def _run(self, taskid: Taskid) -> None:
# Execute the task, respecting the ratelimit bucket
if self._bucket is not None:
# IMPLEMENTATION NOTE:
# Bucket.wait() should guarantee not more than n tasks/second are run
# and that a request directly afterwards will _not_ raise BucketFull
# Make sure that only one waiter is actually waiting on its sleep task
# The other waiters should be waiting on a lock around the sleep task
# Waiters are executed in wait-order, so if we only let a single waiter in
# we shouldn't get collisions.
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
# Or we will lose thread-safety for BucketFull
await self._bucket.wait()
fut = asyncio.create_task(self.run_task(taskid))
try:
await asyncio.shield(fut)
except asyncio.CancelledError:
raise
except Exception:
# Protect the monitor loop from any other exceptions
logger.exception(
f"Ignoring exception in task monitor {self.__class__.__name__} while "
f"executing <taskid: {taskid}>"
)
finally:
self._running.pop(taskid)
async def run_task(self, taskid: Taskid):
"""
Execute the task with the given taskid.
Default implementation executes `self.executor` if it exists,
otherwise raises NotImplementedError.
"""
if self.executor is not None:
await self.executor(taskid)
else:
raise NotImplementedError
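# A minimal usage sketch (assumes a running event loop; ids and delays are
# illustrative). The executor coroutine receives each taskid as it falls due.
async def demo_monitor():
    async def remind(taskid: int):
        print(f"Task {taskid} fired")
    monitor = TaskMonitor(executor=remind)
    now = int(utc_now().timestamp())
    monitor.set_tasks((1, now + 1), (2, now + 2))
    monitor.start()
    await asyncio.sleep(3)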

173
src/utils/ratelimits.py Normal file
View File

@@ -0,0 +1,173 @@
import asyncio
import time
import logging
from meta.errors import SafeCancellation
from cachetools import TTLCache
logger = logging.getLogger()
class BucketFull(Exception):
"""
Throw when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Throw when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.monotonic()
self._last_full = False
self._wait_lock = asyncio.Lock()
@property
def full(self) -> bool:
"""
Return whether the bucket is 'full',
that is, whether an immediate request against the bucket will raise `BucketFull`.
"""
self._leak()
return self._level + 1 > self.max_level
@property
def overfull(self):
self._leak()
return self._level > self.max_level
@property
def delay(self):
self._leak()
if self._level + 1 > self.max_level:
            # Seconds until one unit of capacity leaks free (leak_rate is level/second)
            delay = (self._level + 1 - self.max_level) / self.leak_rate
else:
delay = 0
return delay
def _leak(self):
if self._level:
elapsed = time.monotonic() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.monotonic()
def request(self):
self._leak()
if self._level > self.max_level:
raise BucketOverFull
elif self._level == self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
def fill(self):
self._leak()
self._level = max(self._level, self.max_level + 1)
async def wait(self):
"""
Wait until the bucket has room.
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
"""
# Wrapped in a lock so that waiters are correctly handled in wait-order
# Otherwise multiple waiters will have the same delay,
# and race for the wakeup after sleep.
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
async with self._wait_lock:
# We do this in a loop in case asyncio.sleep throws us out early,
# or a synchronous request overflows the bucket while we are waiting.
while self.full:
await asyncio.sleep(self.delay)
async def wrapped(self, coro):
await self.wait()
self.request()
await coro
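# A rough sketch of the leaky bucket in use (numbers are arbitrary):
# the first five requests land immediately, then wait() paces the rest
# at the leak rate of max_level / empty_time.
async def demo_bucket():
    bucket = Bucket(max_level=5, empty_time=10)
    for _ in range(7):
        await bucket.wait()
        bucket.request()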
class RateLimit:
    def __init__(self, max_level, empty_time, error=None, cache=None):
        self.max_level = max_level
        self.empty_time = empty_time
        self.error = error or "Too many requests, please slow down!"
        # A fresh TTLCache per instance, unless an external cache is provided
        self.buckets = cache if cache is not None else TTLCache(1000, 60 * 60)
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
try:
bucket.request()
except BucketOverFull:
raise SafeCancellation(details="Bucket overflow")
except BucketFull:
raise SafeCancellation(self.error, details="Bucket full")
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator
async def limit_concurrency(aws, limit):
"""
Run provided awaitables concurrently,
ensuring that no more than `limit` are running at once.
"""
aws = iter(aws)
aws_ended = False
pending = set()
count = 0
logger.debug("Starting limited concurrency executor")
while pending or not aws_ended:
while len(pending) < limit and not aws_ended:
aw = next(aws, None)
if aw is None:
aws_ended = True
else:
pending.add(asyncio.create_task(aw))
count += 1
if not pending:
break
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
while done:
yield done.pop()
logger.debug(f"Completed {count} tasks")

8
src/utils/ui/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
import asyncio
import logging
logger = logging.getLogger(__name__)
from .hooked import *
from .leo import *
from .micros import *

59
src/utils/ui/hooked.py Normal file
View File

@@ -0,0 +1,59 @@
import time
import discord
from discord.ui.item import Item
from discord.ui.button import Button
from .leo import LeoUI
__all__ = (
'HookedItem',
'AButton',
'AsComponents'
)
class HookedItem:
"""
Mixin for Item classes allowing an instance to be used as a callback decorator.
"""
def __init__(self, *args, pass_kwargs={}, **kwargs):
super().__init__(*args, **kwargs)
self.pass_kwargs = pass_kwargs
def __call__(self, coro):
async def wrapped(interaction, **kwargs):
return await coro(interaction, self, **(self.pass_kwargs | kwargs))
self.callback = wrapped
return self
class AButton(HookedItem, Button):
...
class AsComponents(LeoUI):
"""
Simple container class to accept a number of Items and turn them into an attachable View.
"""
def __init__(self, *items, pass_kwargs={}, **kwargs):
super().__init__(**kwargs)
self.pass_kwargs = pass_kwargs
for item in items:
self.add_item(item)
async def _scheduled_task(self, item: Item, interaction: discord.Interaction):
try:
item._refresh_state(interaction, interaction.data) # type: ignore
allow = await self.interaction_check(interaction)
if not allow:
return
if self.timeout:
                # Must use the mangled View attribute; `self.__timeout_expiry`
                # here would mangle to an AsComponents attribute instead.
                self._View__timeout_expiry = time.monotonic() + self.timeout
await item.callback(interaction, **self.pass_kwargs)
except Exception as e:
return await self.on_error(interaction, e, item)
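# A short sketch of the decorator flow (names are illustrative): the AButton
# instance wraps the coroutine, and AsComponents packages items into an
# attachable View, e.g. `await message.edit(view=AsComponents(ping_button))`.
@AButton(label="Ping")
async def ping_button(interaction: discord.Interaction, pressed: Button):
    await interaction.response.send_message("Pong!", ephemeral=True)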

485
src/utils/ui/leo.py Normal file
View File

@@ -0,0 +1,485 @@
from typing import List, Optional, Any, Dict
import asyncio
import logging
import time
from contextvars import copy_context, Context
import discord
from discord.ui import Modal, View, Item
from meta.logger import log_action_stack, logging_context
from meta.errors import SafeCancellation
from . import logger
from ..lib import MessageArgs, error_embed
__all__ = (
'LeoUI',
'MessageUI',
'LeoModal',
'error_handler_for'
)
class LeoUI(View):
"""
View subclass for small-scale user interfaces.
While a 'View' provides an interface for managing a collection of components,
a `LeoUI` may also manage a message, and potentially slave Views or UIs.
The `LeoUI` also exposes more advanced cleanup and timeout methods,
and preserves the context.
"""
def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None:
super().__init__(*args, **kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._name = ui_name or self.__class__.__name__
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name])
# List of slaved views to stop when this view stops
self._slaves: List[View] = []
# TODO: Replace this with a substitutable ViewLayout class
self._layout: Optional[tuple[tuple[Item, ...], ...]] = None
@property
def _stopped(self) -> asyncio.Future:
"""
Return an future indicating whether the View has finished interacting.
Currently exposes a hidden attribute of the underlying View.
May be reimplemented in future.
"""
return self._View__stopped
def to_components(self) -> List[Dict[str, Any]]:
"""
Extending component generator to apply the set _layout, if it exists.
"""
if self._layout is not None:
# Alternative rendering using layout
components = []
for i, row in enumerate(self._layout):
# Skip empty rows
if not row:
continue
# Since we aren't relying on ViewWeights, manually check width here
if sum(item.width for item in row) > 5:
raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!")
# Create the component dict for this row
components.append({
'type': 1,
'components': [item.to_component_dict() for item in row]
})
else:
components = super().to_components()
return components
def set_layout(self, *rows: tuple[Item, ...]) -> None:
"""
Set the layout of the rendered View as a matrix of items,
or more precisely, a list of action rows.
This acts independently of the existing sorting with `_ViewWeights`,
and overrides the sorting if applied.
"""
self._layout = rows
async def cleanup(self):
"""
        Coroutine to run when timing out, stopping, or cancelling.
Generally cleans up any open resources, and removes any leftover components.
"""
logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'})
return None
def stop(self):
"""
Extends View.stop() to also stop all the slave views.
Note that stopping is idempotent, so it is okay if close() also calls stop().
"""
for slave in self._slaves:
slave.stop()
super().stop()
async def close(self, msg=None):
self.stop()
await self.cleanup()
async def pre_timeout(self):
"""
Task to execute before actually timing out.
This may cancel the timeout by refreshing or rescheduling it.
(E.g. to ask the user whether they want to keep going.)
Default implementation does nothing.
"""
return None
async def on_timeout(self):
"""
Task to execute after timeout is complete.
Default implementation calls cleanup.
"""
await self.cleanup()
async def __dispatch_timeout(self):
"""
This essentially extends View._dispatch_timeout,
to include a pre_timeout task
which may optionally refresh and hence cancel the timeout.
"""
if self._View__stopped.done():
# We are already stopped, nothing to do
return
with logging_context(action='Timeout'):
try:
await self.pre_timeout()
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
pass
except Exception:
logger.exception(
"Unhandled error caught while dispatching timeout for {self!r}.",
extra={'with_ctx': True, 'action': 'Error'}
)
# Check if we still need to timeout
if self.timeout is None:
# The timeout was removed entirely, silently walk away
return
if self._View__stopped.done():
# We stopped while waiting for the pre timeout.
# Or maybe another thread timed us out
# Either way, we are done here
return
now = time.monotonic()
if self._View__timeout_expiry is not None and now < self._View__timeout_expiry:
# The timeout was extended, make sure the timeout task is running then fade away
if self._View__timeout_task is None or self._View__timeout_task.done():
self._View__timeout_task = asyncio.create_task(self._View__timeout_task_impl())
else:
# Actually timeout, and call the post-timeout task for cleanup.
self._really_timeout()
await self.on_timeout()
def _dispatch_timeout(self):
"""
Overriding timeout method completely, to support interactive flow during timeout,
and optional refreshing of the timeout.
"""
return self._context.run(asyncio.create_task, self.__dispatch_timeout())
def _really_timeout(self):
"""
        Actually times out the View.
This copies View._dispatch_timeout, apart from the `on_timeout` dispatch,
which is now handled by `__dispatch_timeout`.
"""
if self._View__stopped.done():
return
if self._View__cancel_callback:
self._View__cancel_callback(self)
self._View__cancel_callback = None
self._View__stopped.set_result(True)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item):
"""
Default LeoUI error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except SafeCancellation as e:
if e.msg and not interaction.is_expired():
try:
if interaction.response.is_done():
await interaction.followup.send(
embed=error_embed(e.msg),
ephemeral=True
)
else:
await interaction.response.send_message(
embed=error_embed(e.msg),
ephemeral=True
)
except discord.HTTPException:
pass
logger.debug(
f"Caught a safe cancellation from LeoUI: {e.details}",
extra={'action': 'Cancel'}
)
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r} from interaction: "
f"{interaction.data}",
extra={'with_ctx': True, 'action': 'UIError'}
)
# Explicitly handle the bugsplat ourselves
splat = interaction.client.tree.bugsplat(interaction, error)
await interaction.client.tree.error_reply(interaction, splat)
class MessageUI(LeoUI):
"""
Simple single-message LeoUI, intended as a framework for UIs
attached to a single interaction response.
UIs may also be sent as regular messages by using `send(channel)` instead of `run(interaction)`.
"""
def __init__(self, *args, callerid: Optional[int] = None, **kwargs):
super().__init__(*args, **kwargs)
# ----- UI state -----
# User ID of the original caller (e.g. command author).
# Mainly used for interaction usage checks and logging
self._callerid = callerid
        # Original interaction, if this UI is sent as an interaction response
        self._original: Optional[discord.Interaction] = None
        # Message holding the UI, when the UI is sent attached to a followup
        self._message: Optional[discord.Message] = None
# Refresh lock, to avoid cache collisions on refresh
self._refresh_lock = asyncio.Lock()
@property
def channel(self):
if self._original is not None:
return self._original.channel
else:
return self._message.channel
# ----- UI API -----
async def run(self, interaction: discord.Interaction, **kwargs):
"""
Run the UI as a response or followup to the given interaction.
Should be extended if more complex run mechanics are needed
(e.g. registering listeners or setting up caches).
"""
await self.draw(interaction, **kwargs)
async def refresh(self, *args, thinking: Optional[discord.Interaction] = None, **kwargs):
"""
Reload and redraw this UI.
Primarily a hook-method for use by parents and other controllers.
        Performs a full data reload and refresh (maintaining UI state, e.g. page number).
"""
async with self._refresh_lock:
# Reload data
await self.reload()
# Redraw UI message
await self.redraw(thinking=thinking)
async def quit(self):
"""
Quit the UI.
This usually involves removing the original message,
and stopping or closing the underlying View.
"""
for child in self._slaves:
# TODO: Better to use duck typing or interface typing
if isinstance(child, MessageUI) and not child.is_finished():
asyncio.create_task(child.quit())
try:
if self._original is not None and not self._original.is_expired():
await self._original.delete_original_response()
self._original = None
if self._message is not None:
await self._message.delete()
self._message = None
except discord.HTTPException:
pass
# Note close() also runs cleanup and stop
await self.close()
# ----- UI Flow -----
async def interaction_check(self, interaction: discord.Interaction):
"""
Check the given interaction is authorised to use this UI.
Default implementation simply checks that the interaction is
from the original caller.
Extend for more complex logic.
"""
return interaction.user.id == self._callerid
async def make_message(self) -> MessageArgs:
"""
        Create the UI message body, depending on the current state.
Called upon each redraw.
Should handle caching if message construction is for some reason intensive.
Must be implemented by concrete UI subclasses.
"""
raise NotImplementedError
async def refresh_layout(self):
"""
Asynchronously refresh the message components,
and explicitly set the message component layout.
Called just before redrawing, before `make_message`.
Must be implemented by concrete UI subclasses.
"""
raise NotImplementedError
async def reload(self):
"""
Reload and recompute the underlying data for this UI.
Must be implemented by concrete UI subclasses.
"""
raise NotImplementedError
async def draw(self, interaction, force_followup=False, **kwargs):
"""
Send the UI as a response or followup to the given interaction.
If the interaction has been responded to, or `force_followup` is set,
creates a followup message instead of a response to the interaction.
"""
# Initial data loading
await self.reload()
# Set the UI layout
await self.refresh_layout()
# Fetch message arguments
args = await self.make_message()
as_followup = force_followup or interaction.response.is_done()
if as_followup:
self._message = await interaction.followup.send(**args.send_args, **kwargs, view=self)
else:
self._original = interaction
await interaction.response.send_message(**args.send_args, **kwargs, view=self)
async def send(self, channel: discord.abc.Messageable, **kwargs):
"""
Alternative to draw() which uses a discord.abc.Messageable.
"""
await self.reload()
await self.refresh_layout()
args = await self.make_message()
self._message = await channel.send(**args.send_args, view=self)
async def _redraw(self, args):
if self._original and not self._original.is_expired():
await self._original.edit_original_response(**args.edit_args, view=self)
elif self._message:
await self._message.edit(**args.edit_args, view=self)
else:
            # Interaction expired or already closed. Quietly clean up.
await self.close()
async def redraw(self, thinking: Optional[discord.Interaction] = None):
"""
Update the output message for this UI.
If a thinking interaction is provided, deletes the response while redrawing.
"""
await self.refresh_layout()
args = await self.make_message()
if thinking is not None and not thinking.is_expired() and thinking.response.is_done():
asyncio.create_task(thinking.delete_original_response())
try:
await self._redraw(args)
except discord.HTTPException as e:
# Unknown communication error, nothing we can reliably do. Exit quietly.
logger.warning(
f"Unexpected UI redraw failure occurred in {self}: {repr(e)}",
)
await self.close()
async def cleanup(self):
"""
Remove message components from interaction response, if possible.
Extend to remove listeners or clean up caches.
`cleanup` is always called when the UI is exiting,
through timeout or user-driven closure.
"""
try:
if self._original is not None and not self._original.is_expired():
await self._original.edit_original_response(view=None)
self._original = None
if self._message is not None:
await self._message.edit(view=None)
self._message = None
except discord.HTTPException:
pass
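# A skeletal concrete UI (a sketch; values are illustrative). Subclasses
# implement the three hook methods; callers then use run(interaction) or
# send(channel).
class PingUI(MessageUI):
    async def reload(self):
        # Load whatever state the message body needs
        self.latency_ms = 42
    async def refresh_layout(self):
        # No components in this sketch
        self.set_layout()
    async def make_message(self) -> MessageArgs:
        return MessageArgs(content=f"Pong! ({self.latency_ms}ms)")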
class LeoModal(Modal):
"""
Context-aware Modal class.
"""
def __init__(self, *args, context: Optional[Context] = None, **kwargs):
super().__init__(**kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__])
def _dispatch_submit(self, *args, **kwargs):
"""
Extending event dispatch to run in the instantiation context.
"""
return self._context.run(super()._dispatch_submit, *args, **kwargs)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
"""
Default LeoModal error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in {self!r}. Interaction: {interaction.data}",
extra={'with_ctx': True, 'action': 'ModalError'}
)
# Explicitly handle the bugsplat ourselves
splat = interaction.client.tree.bugsplat(interaction, error)
await interaction.client.tree.error_reply(interaction, splat)
def error_handler_for(exc):
def wrapper(coro):
coro._ui_error_handler_for_ = exc
return coro
return wrapper

329
src/utils/ui/micros.py Normal file
View File

@@ -0,0 +1,329 @@
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
import functools
import asyncio
import discord
from discord.ui import TextInput
from discord.ui.button import button
from meta.logger import logging_context
from meta.errors import ResponseTimedOut
from .leo import LeoModal, LeoUI
__all__ = (
'FastModal',
'ModalRetryUI',
'Confirm',
'input',
)
class FastModal(LeoModal):
__class_error_handlers__ = []
def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
error_handlers = {}
for base in reversed(cls.__mro__):
for name, member in base.__dict__.items():
if hasattr(member, '_ui_error_handler_for_'):
error_handlers[name] = member
cls.__class_error_handlers__ = list(error_handlers.values())
def __init__error_handlers__(self):
handlers = {}
for handler in self.__class_error_handlers__:
handlers[handler._ui_error_handler_for_] = functools.partial(handler, self)
return handlers
def __init__(self, *items: TextInput, **kwargs):
super().__init__(**kwargs)
for item in items:
self.add_item(item)
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
self._waiters: List[Callable[[discord.Interaction], Coroutine]] = []
self._error_handlers = self.__init__error_handlers__()
def error_handler(self, exception):
def wrapper(coro):
self._error_handlers[exception] = coro
return coro
return wrapper
async def wait_for(self, check=None, timeout=None):
# Wait for _result or timeout
# If we timeout, or the view times out, raise TimeoutError
# Otherwise, return the Interaction
        # This allows multiple listeners and callbacks to wait on the same modal
while True:
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
if check is not None:
if not check(result):
continue
return result
async def on_timeout(self):
self._result.set_exception(asyncio.TimeoutError)
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
def wrapper(coro):
async def wrapped_callback(interaction):
with logging_context(action=coro.__name__):
if check is not None:
if not check(interaction):
return
                    try:
                        await coro(interaction, *pass_args, **pass_kwargs)
finally:
if once:
self._waiters.remove(wrapped_callback)
            self._waiters.append(wrapped_callback)
            # Return the original coroutine so the decorated name remains callable
            return coro
        return wrapper
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
try:
# First let our error handlers have a go
# If there is no handler for this error, or the handlers themselves error,
# drop to the superclass error handler implementation.
try:
raise error
except tuple(self._error_handlers.keys()) as e:
# If an error handler is registered for this exception, run it.
for cls, handler in self._error_handlers.items():
if isinstance(e, cls):
await handler(interaction, e)
except Exception as error:
await super().on_error(interaction, error)
async def on_submit(self, interaction):
print("On submit")
old_result = self._result
self._result = asyncio.get_event_loop().create_future()
old_result.set_result(interaction)
tasks = []
for waiter in self._waiters:
task = asyncio.create_task(
waiter(interaction),
name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}"
)
tasks.append(task)
if tasks:
await asyncio.gather(*tasks)
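# Usage sketch: two ways to consume a FastModal submission. All names below
# are illustrative; `interaction` is assumed to be an unanswered Interaction.
#
#   modal = FastModal(TextInput(label="Reason"), title="Ban reason")
#   await interaction.response.send_modal(modal)
#
#   # (1) Callback style: runs on every submission of this modal.
#   @modal.submit_callback()
#   async def handle_submit(submit_interaction):
#       await submit_interaction.response.defer()
#
#   # (2) Future style: suspend until the next submission (or timeout).
#   submit_interaction = await modal.wait_for(timeout=180)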
async def input(
interaction: discord.Interaction,
title: str,
question: Optional[str] = None,
field: Optional[TextInput] = None,
timeout=180,
**kwargs,
) -> tuple[discord.Interaction, str]:
"""
Spawn a modal to accept input.
Returns an (interaction, value) pair, with interaction not yet responded to.
May raise asyncio.TimeoutError if the view times out.
"""
    if field is None:
        field = TextInput(
            label=kwargs.pop('label', question),
            **kwargs
        )
modal = FastModal(
field,
title=title,
timeout=timeout
)
await interaction.response.send_modal(modal)
interaction = await modal.wait_for()
return (interaction, field.value)
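# Usage sketch (hypothetical command body): prompt for a value and respond to
# the returned submission interaction.
#
#   try:
#       submission, value = await input(interaction, "Rename", question="New name?")
#   except asyncio.TimeoutError:
#       return
#   await submission.response.send_message(f"Renaming to {value}.")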
class ModalRetryUI(LeoUI):
def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.modal = modal
self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)}
self.message = message
self._interaction = None
if label is not None:
self.retry_button.label = label
@property
def embed(self):
return discord.Embed(
title="Uh-Oh!",
description=self.message,
colour=discord.Colour.red()
)
async def respond_to(self, interaction):
self._interaction = interaction
if interaction.response.is_done():
await interaction.followup.send(embed=self.embed, ephemeral=True, view=self)
else:
await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self)
@button(label="Retry")
    async def retry_button(self, interaction, press):
# Setting these here so they don't update in the meantime
for item, value in self.item_values.items():
item.default = value
if self._interaction is not None:
await self._interaction.delete_original_response()
self._interaction = None
await interaction.response.send_modal(self.modal)
await self.close()
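# Usage sketch: wiring a retry prompt to a single modal instance with the
# instance-level error_handler decorator. `UserInputError` is an assumed
# application error type with a `.msg` attribute.
#
#   modal = FastModal(TextInput(label="Duration"), title="Set duration")
#
#   @modal.error_handler(UserInputError)
#   async def retry_on_bad_input(interaction, error):
#       await ModalRetryUI(modal, error.msg).respond_to(interaction)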
class Confirm(LeoUI):
"""
Micro UI class implementing a confirmation question.
Parameters
----------
confirm_msg: str
The confirmation question to ask from the user.
This is set as the description of the `embed` property.
The `embed` may be further modified if required.
permitted_id: Optional[int]
The user id allowed to access this interaction.
        Other users will receive an access denied error message.
defer: bool
Whether to defer the interaction response while handling the button.
It may be useful to set this to `False` to obtain manual control
over the interaction response flow (e.g. to send a modal or ephemeral message).
The button press interaction may be accessed through `Confirm.interaction`.
Default: True
Example
-------
```
confirm = Confirm("Are you sure?", ctx.author.id)
confirm.embed.colour = discord.Colour.red()
confirm.confirm_button.label = "Yes I am sure"
confirm.cancel_button.label = "No I am not sure"
try:
result = await confirm.ask(ctx.interaction, ephemeral=True)
    except ResponseTimedOut:
return
```
"""
def __init__(
self,
confirm_msg: str,
permitted_id: Optional[int] = None,
defer: bool = True,
**kwargs
):
super().__init__(**kwargs)
self.confirm_msg = confirm_msg
self.permitted_id = permitted_id
self.defer = defer
self._embed: Optional[discord.Embed] = None
self._result: asyncio.Future[bool] = asyncio.Future()
# Indicates whether we should delete the message or the interaction response
self._is_followup: bool = False
        self._original: Optional[discord.Interaction] = None
        self._message: Optional[discord.Message] = None
        # The button press interaction, if any (see class docstring)
        self.interaction: Optional[discord.Interaction] = None
async def interaction_check(self, interaction: discord.Interaction):
return (self.permitted_id is None) or interaction.user.id == self.permitted_id
async def on_timeout(self):
# Propagate timeout to result Future
self._result.set_exception(ResponseTimedOut)
await self.cleanup()
async def cleanup(self):
"""
Cleanup the confirmation prompt by deleting it, if possible.
Ignores any Discord errors that occur during the process.
"""
try:
if self._is_followup and self._message:
await self._message.delete()
elif not self._is_followup and self._original and not self._original.is_expired():
await self._original.delete_original_response()
except discord.HTTPException:
# A user probably already deleted the message
# Anything could have happened, just ignore.
pass
@button(label="Confirm")
async def confirm_button(self, interaction: discord.Interaction, press):
if self.defer:
await interaction.response.defer()
self._result.set_result(True)
await self.close()
@button(label="Cancel")
async def cancel_button(self, interaction: discord.Interaction, press):
if self.defer:
await interaction.response.defer()
self._result.set_result(False)
await self.close()
@property
def embed(self):
"""
Confirmation embed shown to the user.
        This is cached, and may be modified directly through the usual EmbedProxy API,
or explicitly overwritten.
"""
if self._embed is None:
self._embed = discord.Embed(
colour=discord.Colour.orange(),
description=self.confirm_msg
)
return self._embed
@embed.setter
def embed(self, value):
self._embed = value
async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs):
"""
Send this confirmation prompt in response to the provided interaction.
Extra keyword arguments are passed to `Interaction.response.send_message`
        or `Interaction.followup.send`, depending on whether
the provided interaction has already been responded to.
        The `ephemeral` argument is handled specially,
since the question message can only be deleted through `Interaction.delete_original_response`.
Waits on and returns the internal `result` Future.
Returns: bool
True if the user pressed the confirm button.
False if the user pressed the cancel button.
Raises:
ResponseTimedOut:
If the user does not respond before the UI times out.
"""
self._original = interaction
if interaction.response.is_done():
# Interaction already responded to, send a follow up
if ephemeral:
raise ValueError("Cannot send an ephemeral response to a used interaction.")
self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self)
self._is_followup = True
else:
await interaction.response.send_message(
embed=self.embed, ephemeral=ephemeral, **kwargs, view=self
)
self._is_followup = False
return await self._result
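# Sketch (assumed names): with defer=False the button press interaction is
# left unanswered, so the caller can respond to it manually, e.g. with a modal.
# `some_modal` is a hypothetical discord.ui.Modal instance.
#
#   confirm = Confirm("Delete all saved data?", ctx.author.id, defer=False)
#   try:
#       result = await confirm.ask(ctx.interaction)
#   except ResponseTimedOut:
#       return
#   if result and confirm.interaction is not None:
#       await confirm.interaction.response.send_modal(some_modal)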
# TODO: Selector MicroUI for displaying options (<= 25)

9
src/wards.py Normal file

@@ -0,0 +1,9 @@
from meta import LionBot
# Raw checks, return True/False depending on whether they pass
async def sys_admin(bot: LionBot, userid: int):
"""
Checks whether the context author is listed in the configuration file as a bot admin.
"""
admins = bot.config.bot.getintlist('admins')
return userid in admins
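# Usage sketch (assumed context shape): gate a command on the sys_admin ward.
# `ctx.bot` and `ctx.author` follow the usual discord.py command context.
#
#   async def cmd_shutdown(ctx):
#       if not await sys_admin(ctx.bot, ctx.author.id):
#           return await ctx.reply("You need to be a bot admin to do this.")
#       await ctx.bot.close()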