From 1ae7c60ba83a8ef6a30c0474670e8bf1855d4f30 Mon Sep 17 00:00:00 2001 From: Interitio Date: Thu, 2 Nov 2023 10:09:06 +0200 Subject: [PATCH] Initial Template. --- .gitignore | 151 ++++++ config/example-bot.conf | 28 + config/example-secrets.conf | 6 + data/.gitignore | 0 data/schema.sql | 34 ++ requirements.txt | 7 + scripts/start_bot.py | 12 + scripts/start_debug.py | 35 ++ src/bot.py | 114 +++++ src/constants.py | 6 + src/core/__init__.py | 6 + src/core/cog.py | 76 +++ src/core/data.py | 45 ++ src/data/__init__.py | 9 + src/data/adapted.py | 40 ++ src/data/base.py | 45 ++ src/data/columns.py | 155 ++++++ src/data/conditions.py | 214 ++++++++ src/data/connector.py | 135 +++++ src/data/cursor.py | 42 ++ src/data/database.py | 47 ++ src/data/models.py | 323 ++++++++++++ src/data/queries.py | 644 +++++++++++++++++++++++ src/data/registry.py | 102 ++++ src/data/table.py | 95 ++++ src/meta/LionBot.py | 344 +++++++++++++ src/meta/LionCog.py | 58 +++ src/meta/LionContext.py | 195 +++++++ src/meta/LionTree.py | 150 ++++++ src/meta/__init__.py | 15 + src/meta/app.py | 32 ++ src/meta/args.py | 35 ++ src/meta/config.py | 146 ++++++ src/meta/context.py | 20 + src/meta/errors.py | 64 +++ src/meta/logger.py | 468 +++++++++++++++++ src/meta/monitor.py | 139 +++++ src/meta/sharding.py | 35 ++ src/modules/__init__.py | 10 + src/modules/sysadmin/__init__.py | 5 + src/modules/sysadmin/exec_cog.py | 336 ++++++++++++ src/utils/__init__.py | 0 src/utils/ansi.py | 97 ++++ src/utils/data.py | 165 ++++++ src/utils/lib.py | 847 +++++++++++++++++++++++++++++++ src/utils/monitor.py | 191 +++++++ src/utils/ratelimits.py | 173 +++++++ src/utils/ui/__init__.py | 8 + src/utils/ui/hooked.py | 59 +++ src/utils/ui/leo.py | 485 ++++++++++++++++++ src/utils/ui/micros.py | 329 ++++++++++++ src/wards.py | 9 + 52 files changed, 6786 insertions(+) create mode 100644 .gitignore create mode 100644 config/example-bot.conf create mode 100644 config/example-secrets.conf create mode 100644 data/.gitignore create mode 100644 data/schema.sql create mode 100644 requirements.txt create mode 100755 scripts/start_bot.py create mode 100755 scripts/start_debug.py create mode 100644 src/bot.py create mode 100644 src/constants.py create mode 100644 src/core/__init__.py create mode 100644 src/core/cog.py create mode 100644 src/core/data.py create mode 100644 src/data/__init__.py create mode 100644 src/data/adapted.py create mode 100644 src/data/base.py create mode 100644 src/data/columns.py create mode 100644 src/data/conditions.py create mode 100644 src/data/connector.py create mode 100644 src/data/cursor.py create mode 100644 src/data/database.py create mode 100644 src/data/models.py create mode 100644 src/data/queries.py create mode 100644 src/data/registry.py create mode 100644 src/data/table.py create mode 100644 src/meta/LionBot.py create mode 100644 src/meta/LionCog.py create mode 100644 src/meta/LionContext.py create mode 100644 src/meta/LionTree.py create mode 100644 src/meta/__init__.py create mode 100644 src/meta/app.py create mode 100644 src/meta/args.py create mode 100644 src/meta/config.py create mode 100644 src/meta/context.py create mode 100644 src/meta/errors.py create mode 100644 src/meta/logger.py create mode 100644 src/meta/monitor.py create mode 100644 src/meta/sharding.py create mode 100644 src/modules/__init__.py create mode 100644 src/modules/sysadmin/__init__.py create mode 100644 src/modules/sysadmin/exec_cog.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/ansi.py create mode 100644 src/utils/data.py create mode 100644 src/utils/lib.py create mode 100644 src/utils/monitor.py create mode 100644 src/utils/ratelimits.py create mode 100644 src/utils/ui/__init__.py create mode 100644 src/utils/ui/hooked.py create mode 100644 src/utils/ui/leo.py create mode 100644 src/utils/ui/micros.py create mode 100644 src/wards.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d0cda90 --- /dev/null +++ b/.gitignore @@ -0,0 +1,151 @@ +src/modules/test/* + +pending-rewrite/ +logs/* +notes/* +tmp/* +output/* +locales/domains + +.idea/* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +config/** diff --git a/config/example-bot.conf b/config/example-bot.conf new file mode 100644 index 0000000..9b67e4a --- /dev/null +++ b/config/example-bot.conf @@ -0,0 +1,28 @@ +[BOT] +prefix = !! + +admins = + +admin_guilds = + +shard_count = 1 + +ALSO_READ = config/emojis.conf, config/secrets.conf + + +[LOGGING] +log_file = bot.log + +general_log = +error_log = %(general_log) +critical_log = %(general_log) +warning_log = %(general_log) +warning_prefix = +error_prefix = +critical_prefix = + +[LOGGING_LEVELS] +root = DEBUG +discord = INFO +discord.http = INFO +discord.gateway = INFO diff --git a/config/example-secrets.conf b/config/example-secrets.conf new file mode 100644 index 0000000..b48b2ca --- /dev/null +++ b/config/example-secrets.conf @@ -0,0 +1,6 @@ +[STUDYLION] +token = + +[DATA] +args = dbname= +appid = diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/data/schema.sql b/data/schema.sql new file mode 100644 index 0000000..0641eed --- /dev/null +++ b/data/schema.sql @@ -0,0 +1,34 @@ +-- Metadata {{{ +CREATE TABLE VersionHistory( + version INTEGER NOT NULL, + time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + author TEXT +); +INSERT INTO VersionHistory (version, author) VALUES (1, 'Initial Creation'); + + +CREATE OR REPLACE FUNCTION update_timestamp_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW._timestamp = (now() at time zone 'utc'); + RETURN NEW; +END; +$$ language 'plpgsql'; +-- }}} + +-- App metadata {{{ + +CREATE TABLE app_config( + appname TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE bot_config( + appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE, + sponsor_prompt TEXT, + sponsor_message TEXT, + default_skin TEXT +); +-- }}} + +-- vim: set fdm=marker: diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..48ed17b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +aiohttp==3.7.4.post0 +cachetools==4.2.2 +configparser==5.0.2 +discord.py [voice] +iso8601==0.1.16 +psycopg[pool] +pytz==2021.1 diff --git a/scripts/start_bot.py b/scripts/start_bot.py new file mode 100755 index 0000000..49d6ad1 --- /dev/null +++ b/scripts/start_bot.py @@ -0,0 +1,12 @@ +# !/bin/python3 + +import sys +import os + +sys.path.insert(0, os.path.join(os.getcwd())) +sys.path.insert(0, os.path.join(os.getcwd(), "src")) + + +if __name__ == '__main__': + from bot import _main + _main() diff --git a/scripts/start_debug.py b/scripts/start_debug.py new file mode 100755 index 0000000..d4837d0 --- /dev/null +++ b/scripts/start_debug.py @@ -0,0 +1,35 @@ +# !/bin/python3 + +import sys +import os +import tracemalloc +import asyncio + + +sys.path.insert(0, os.path.join(os.getcwd())) +sys.path.insert(0, os.path.join(os.getcwd(), "src")) + +tracemalloc.start() + + +def loop_exception_handler(loop, context): + print(context) + task: asyncio.Task = context.get('task', None) + if task is not None: + addendum = f"" + message = context.get('message', '') + context['message'] = ' '.join((message, addendum)) + loop.default_exception_handler(context) + + +def main(): + loop = asyncio.get_event_loop() + loop.set_exception_handler(loop_exception_handler) + loop.set_debug(enabled=True) + + from bot import _main + _main() + + +if __name__ == '__main__': + main() diff --git a/src/bot.py b/src/bot.py new file mode 100644 index 0000000..0b61a61 --- /dev/null +++ b/src/bot.py @@ -0,0 +1,114 @@ +import asyncio +import logging + +import aiohttp +import discord +from discord.ext import commands + +from meta import LionBot, conf, sharding, appname +from meta.app import shardname +from meta.logger import log_context, log_action_stack, setup_main_logger +from meta.context import ctx_bot +from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus + +from data import Database + +from constants import DATA_VERSION + + +for name in conf.config.options('LOGGING_LEVELS', no_defaults=True): + logging.getLogger(name).setLevel(conf.logging_levels[name]) + + +logging_queue = setup_main_logger() + + +logger = logging.getLogger(__name__) + +db = Database(conf.data['args']) + + +async def _data_monitor() -> ComponentStatus: + """ + Component monitor callback for the database. + """ + data = { + 'stats': str(db.pool.get_stats()) + } + if not db.pool._opened: + level = StatusLevel.WAITING + info = "(WAITING) Database Pool is not opened." + elif db.pool._closed: + level = StatusLevel.ERRORED + info = "(ERROR) Database Pool is closed." + else: + level = StatusLevel.OKAY + info = "(OK) Database Pool statistics: {stats}" + return ComponentStatus(level, info, info, data) + + +async def main(): + log_action_stack.set(("Initialising",)) + logger.info("Initialising StudyLion") + + intents = discord.Intents.all() + intents.members = True + intents.message_content = True + intents.presences = False + + async with db.open(): + version = await db.version() + if version.version != DATA_VERSION: + error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate." + logger.critical(error) + raise RuntimeError(error) + + async with aiohttp.ClientSession() as session: + async with LionBot( + command_prefix='!leo!', + intents=intents, + appname=appname, + shardname=shardname, + db=db, + config=conf, + initial_extensions=[ + 'core', + 'modules', + ], + web_client=session, + testing_guilds=conf.bot.getintlist('admin_guilds'), + shard_id=sharding.shard_number, + shard_count=sharding.shard_count, + help_command=None, + proxy=conf.bot.get('proxy', None), + chunk_guilds_at_startup=False, + ) as lionbot: + ctx_bot.set(lionbot) + lionbot.system_monitor.add_component( + ComponentMonitor('Database', _data_monitor) + ) + try: + log_context.set(f"APP: {appname}") + logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) + await lionbot.start(conf.bot['TOKEN']) + except asyncio.CancelledError: + log_context.set(f"APP: {appname}") + logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True) + + +def _main(): + from signal import SIGINT, SIGTERM + + loop = asyncio.get_event_loop() + main_task = asyncio.ensure_future(main()) + for signal in [SIGINT, SIGTERM]: + loop.add_signal_handler(signal, main_task.cancel) + try: + loop.run_until_complete(main_task) + finally: + loop.close() + logging.shutdown() + + +if __name__ == '__main__': + _main() diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 0000000..3a34de6 --- /dev/null +++ b/src/constants.py @@ -0,0 +1,6 @@ +CONFIG_FILE = "config/bot.conf" +DATA_VERSION = 1 + +MAX_COINS = 2147483647 - 1 + +HINT_ICON = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png" diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..2cbb58f --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,6 @@ + + +async def setup(bot): + from .cog import CoreCog + + await bot.add_cog(CoreCog(bot)) diff --git a/src/core/cog.py b/src/core/cog.py new file mode 100644 index 0000000..b5eab52 --- /dev/null +++ b/src/core/cog.py @@ -0,0 +1,76 @@ +import logging +from typing import Optional +from collections import defaultdict +from weakref import WeakValueDictionary + +import discord +import discord.app_commands as appcmd + +from meta import LionBot, LionCog, LionContext +from meta.app import shardname, appname +from meta.logger import log_wrap +from utils.lib import utc_now + +from .data import CoreData + +logger = logging.getLogger(__name__) + + +class keydefaultdict(defaultdict): + def __missing__(self, key): + if self.default_factory is None: + raise KeyError(key) + else: + ret = self[key] = self.default_factory(key) + return ret + + +class CoreCog(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + self.data = CoreData() + bot.db.load_registry(self.data) + + self.app_config: Optional[CoreData.AppConfig] = None + self.bot_config: Optional[CoreData.BotConfig] = None + + self.app_cmd_cache: list[discord.app_commands.AppCommand] = [] + self.cmd_name_cache: dict[str, discord.app_commands.AppCommand] = {} + self.mention_cache: dict[str, str] = keydefaultdict(self.mention_cmd) + + async def cog_load(self): + # Fetch (and possibly create) core data rows. + self.app_config = await self.data.AppConfig.fetch_or_create(appname) + self.bot_config = await self.data.BotConfig.fetch_or_create(appname) + + # Load the app command cache + await self.reload_appcmd_cache() + + async def reload_appcmd_cache(self): + for guildid in self.bot.testing_guilds: + self.app_cmd_cache += await self.bot.tree.fetch_commands(guild=discord.Object(guildid)) + self.app_cmd_cache += await self.bot.tree.fetch_commands() + self.cmd_name_cache = {cmd.name: cmd for cmd in self.app_cmd_cache} + self.mention_cache = self._mention_cache_from(self.app_cmd_cache) + + def _mention_cache_from(self, cmds: list[appcmd.AppCommand | appcmd.AppCommandGroup]): + cache = keydefaultdict(self.mention_cmd) + for cmd in cmds: + cache[cmd.qualified_name if isinstance(cmd, appcmd.AppCommandGroup) else cmd.name] = cmd.mention + subcommands = [option for option in cmd.options if isinstance(option, appcmd.AppCommandGroup)] + if subcommands: + subcache = self._mention_cache_from(subcommands) + cache |= subcache + return cache + + def mention_cmd(self, name: str): + """ + Create an application command mention for the given names. + + If not found in cache, creates a 'fake' mention with an invalid id. + """ + if name in self.mention_cache: + mention = self.mention_cache[name] + else: + mention = f"" + return mention diff --git a/src/core/data.py b/src/core/data.py new file mode 100644 index 0000000..7bb4276 --- /dev/null +++ b/src/core/data.py @@ -0,0 +1,45 @@ +from enum import Enum +from itertools import chain +from psycopg import sql +from cachetools import TTLCache +import discord + +from meta import conf +from meta.logger import log_wrap +from data import Table, Registry, Column, RowModel, RegisterEnum +from data.models import WeakCache +from data.columns import Integer, String, Bool, Timestamp + + +class CoreData(Registry, name="core"): + class AppConfig(RowModel): + """ + Schema + ------ + CREATE TABLE app_config( + appname TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """ + _tablename_ = 'app_config' + + appname = String(primary=True) + created_at = Timestamp() + + class BotConfig(RowModel): + """ + Schema + ------ + CREATE TABLE bot_config( + appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE, + sponsor_prompt TEXT, + sponsor_message TEXT, + default_skin TEXT + ); + """ + _tablename_ = 'bot_config' + + appname = String(primary=True) + default_skin = String() + sponsor_prompt = String() + sponsor_message = String() diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..affb160 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,9 @@ +from .conditions import Condition, condition, NULL +from .database import Database +from .models import RowModel, RowTable, WeakCache +from .table import Table +from .base import Expression, RawExpr +from .columns import ColumnExpr, Column, Integer, String +from .registry import Registry, AttachableClass, Attachable +from .adapted import RegisterEnum +from .queries import ORDER, NULLS, JOINTYPE diff --git a/src/data/adapted.py b/src/data/adapted.py new file mode 100644 index 0000000..a6b4597 --- /dev/null +++ b/src/data/adapted.py @@ -0,0 +1,40 @@ +# from enum import Enum +from typing import Optional +from psycopg.types.enum import register_enum, EnumInfo +from psycopg import AsyncConnection +from .registry import Attachable, Registry + + +class RegisterEnum(Attachable): + def __init__(self, enum, name: Optional[str] = None, mapper=None): + super().__init__() + self.enum = enum + self.name = name or enum.__name__ + self.mapping = mapper(enum) if mapper is not None else self._mapper() + + def _mapper(self): + return {m: m.value[0] for m in self.enum} + + def attach_to(self, registry: Registry): + self._registry = registry + registry.init_task(self.on_init) + return self + + async def on_init(self, registry: Registry): + connector = registry._conn + if connector is None: + raise ValueError("Cannot initialise without connector!") + connector.connect_hook(self.connection_hook) + # await connector.refresh_pool() + # The below may be somewhat dangerous + # But adaption should never write to the database + await connector.map_over_pool(self.connection_hook) + # if conn := connector.conn: + # # Ensure the adaption is run in the current context as well + # await self.connection_hook(conn) + + async def connection_hook(self, conn: AsyncConnection): + info = await EnumInfo.fetch(conn, self.name) + if info is None: + raise ValueError(f"Enum {self.name} not found in database.") + register_enum(info, conn, self.enum, mapping=list(self.mapping.items())) diff --git a/src/data/base.py b/src/data/base.py new file mode 100644 index 0000000..272d588 --- /dev/null +++ b/src/data/base.py @@ -0,0 +1,45 @@ +from abc import abstractmethod +from typing import Any, Protocol, runtime_checkable +from itertools import chain +from psycopg import sql + + +@runtime_checkable +class Expression(Protocol): + __slots__ = () + + @abstractmethod + def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]: + raise NotImplementedError + + +class RawExpr(Expression): + __slots__ = ('expr', 'values') + + expr: sql.Composable + values: tuple[Any, ...] + + def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()): + self.expr = expr + self.values = values + + def as_tuple(self): + return (self.expr, self.values) + + @classmethod + def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')): + """ + Join a sequence of Expressions into a single RawExpr. + """ + tups = ( + expression.as_tuple() + for expression in expressions + ) + return cls.join_tuples(*tups, joiner=joiner) + + @classmethod + def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')): + exprs, values = zip(*tuples) + expr = joiner.join(exprs) + value = tuple(chain(*values)) + return cls(expr, value) diff --git a/src/data/columns.py b/src/data/columns.py new file mode 100644 index 0000000..252db83 --- /dev/null +++ b/src/data/columns.py @@ -0,0 +1,155 @@ +from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING +from psycopg import sql +from datetime import datetime + +from .base import RawExpr, Expression +from .conditions import Condition, Joiner +from .table import Table + + +class ColumnExpr(RawExpr): + __slots__ = () + + def __lt__(self, obj) -> Condition: + expr, values = self.as_tuple() + + if isinstance(obj, Expression): + # column < Expression + obj_expr, obj_values = obj.as_tuple() + cond_exprs = (expr, Joiner.LT, obj_expr) + cond_values = (*values, *obj_values) + else: + # column < Literal + cond_exprs = (expr, Joiner.LT, sql.Placeholder()) + cond_values = (*values, obj) + + return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) + + def __le__(self, obj) -> Condition: + expr, values = self.as_tuple() + + if isinstance(obj, Expression): + # column <= Expression + obj_expr, obj_values = obj.as_tuple() + cond_exprs = (expr, Joiner.LE, obj_expr) + cond_values = (*values, *obj_values) + else: + # column <= Literal + cond_exprs = (expr, Joiner.LE, sql.Placeholder()) + cond_values = (*values, obj) + + return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) + + def __eq__(self, obj) -> Condition: # type: ignore[override] + return Condition._expression_equality(self, obj) + + def __ne__(self, obj) -> Condition: # type: ignore[override] + return ~(self.__eq__(obj)) + + def __gt__(self, obj) -> Condition: + return ~(self.__le__(obj)) + + def __ge__(self, obj) -> Condition: + return ~(self.__lt__(obj)) + + def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr': + if isinstance(obj, Expression): + obj_expr, obj_values = obj.as_tuple() + return ColumnExpr( + sql.SQL("({} + {})").format(self.expr, obj_expr), + (*self.values, *obj_values) + ) + else: + return ColumnExpr( + sql.SQL("({} + {})").format(self.expr, sql.Placeholder()), + (*self.values, obj) + ) + + def __sub__(self, obj) -> 'ColumnExpr': + if isinstance(obj, Expression): + obj_expr, obj_values = obj.as_tuple() + return ColumnExpr( + sql.SQL("({} - {})").format(self.expr, obj_expr), + (*self.values, *obj_values) + ) + else: + return ColumnExpr( + sql.SQL("({} - {})").format(self.expr, sql.Placeholder()), + (*self.values, obj) + ) + + def __mul__(self, obj) -> 'ColumnExpr': + if isinstance(obj, Expression): + obj_expr, obj_values = obj.as_tuple() + return ColumnExpr( + sql.SQL("({} * {})").format(self.expr, obj_expr), + (*self.values, *obj_values) + ) + else: + return ColumnExpr( + sql.SQL("({} * {})").format(self.expr, sql.Placeholder()), + (*self.values, obj) + ) + + def CAST(self, target_type: sql.Composable): + return ColumnExpr( + sql.SQL("({}::{})").format(self.expr, target_type), + self.values + ) + + +T = TypeVar('T') + +if TYPE_CHECKING: + from .models import RowModel + + +class Column(ColumnExpr, Generic[T]): + def __init__(self, name: Optional[str] = None, + primary: bool = False, references: Optional['Column'] = None, + type: Optional[Type[T]] = None): + self.primary = primary + self.references = references + self.name: str = name # type: ignore + self.owner: Optional['RowModel'] = None + self._type = type + + self.expr = sql.Identifier(name) if name else sql.SQL('') + self.values = () + + def __set_name__(self, owner, name): + # Only allow setting the owner once + self.name = self.name or name + self.owner = owner + self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name) + + @overload + def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': + ... + + @overload + def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T: + ... + + def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]": + # Get value from row data or session + if obj is None: + return self + else: + return obj.data[self.name] + + +class Integer(Column[int]): + pass + + +class String(Column[str]): + pass + + +class Bool(Column[bool]): + pass + + +class Timestamp(Column[datetime]): + pass diff --git a/src/data/conditions.py b/src/data/conditions.py new file mode 100644 index 0000000..f40dff6 --- /dev/null +++ b/src/data/conditions.py @@ -0,0 +1,214 @@ +# from meta import sharding +from typing import Any, Union +from enum import Enum +from itertools import chain +from psycopg import sql + +from .base import Expression, RawExpr + + +""" +A Condition is a "logical" database expression, intended for use in Where statements. +Conditions support bitwise logical operators ~, &, |, each producing another Condition. +""" + +NULL = None + + +class Joiner(Enum): + EQUALS = ('=', '!=') + IS = ('IS', 'IS NOT') + LIKE = ('LIKE', 'NOT LIKE') + BETWEEN = ('BETWEEN', 'NOT BETWEEN') + IN = ('IN', 'NOT IN') + LT = ('<', '>=') + LE = ('<=', '>') + NONE = ('', '') + + +class Condition(Expression): + __slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values') + + def __init__(self, + expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''), + values: tuple[Any, ...] = (), negated=False + ): + self.expr1 = expr1 + self.joiner = joiner + self.negated = negated + self.expr2 = expr2 + self.values = values + + def as_tuple(self): + expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2)) + if self.negated and self.joiner is Joiner.NONE: + expr = sql.SQL("NOT ({})").format(expr) + return (expr, self.values) + + @classmethod + def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]): + """ + Construct a Condition from a sequence of Conditions, + together with some explicit column conditions. + """ + # TODO: Consider adding a _table identifier here so we can identify implicit columns + # Or just require subquery type conditions to always come from modelled tables. + implicit_conditions = ( + cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items() + ) + return cls._and(*conditions, *implicit_conditions) + + @classmethod + def _and(cls, *conditions: 'Condition'): + if not len(conditions): + raise ValueError("Cannot combine 0 Conditions") + if len(conditions) == 1: + return conditions[0] + + exprs, values = zip(*(condition.as_tuple() for condition in conditions)) + cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs)) + cond_values = tuple(chain(*values)) + + return Condition(cond_expr, values=cond_values) + + @classmethod + def _or(cls, *conditions: 'Condition'): + if not len(conditions): + raise ValueError("Cannot combine 0 Conditions") + if len(conditions) == 1: + return conditions[0] + + exprs, values = zip(*(condition.as_tuple() for condition in conditions)) + cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs)) + cond_values = tuple(chain(*values)) + + return Condition(cond_expr, values=cond_values) + + @classmethod + def _not(cls, condition: 'Condition'): + condition.negated = not condition.negated + return condition + + @classmethod + def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition': + # TODO: Check if this supports sbqueries + col_expr, col_values = column.as_tuple() + + # TODO: Also support sql.SQL? For joins? + if isinstance(value, Expression): + # column = Expression + value_expr, value_values = value.as_tuple() + cond_exprs = (col_expr, Joiner.EQUALS, value_expr) + cond_values = (*col_values, *value_values) + elif isinstance(value, (tuple, list)): + # column in (...) + # TODO: Support expressions in value tuple? + if not value: + raise ValueError("Cannot create Condition from empty iterable!") + value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value))) + cond_exprs = (col_expr, Joiner.IN, value_expr) + cond_values = (*col_values, *value) + elif value is None: + # column IS NULL + cond_exprs = (col_expr, Joiner.IS, sql.NULL) + cond_values = col_values + else: + # column = Literal + cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder()) + cond_values = (*col_values, value) + + return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values) + + def __invert__(self) -> 'Condition': + self.negated = not self.negated + return self + + def __and__(self, condition: 'Condition') -> 'Condition': + return self._and(self, condition) + + def __or__(self, condition: 'Condition') -> 'Condition': + return self._or(self, condition) + + +# Helper method to simply condition construction +def condition(*args, **kwargs) -> Condition: + return Condition.construct(*args, **kwargs) + + +# class NOT(Condition): +# __slots__ = ('value',) +# +# def __init__(self, value): +# self.value = value +# +# def apply(self, key, values, conditions): +# item = self.value +# if isinstance(item, (list, tuple)): +# if item: +# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item)))) +# values.extend(item) +# else: +# raise ValueError("Cannot check an empty iterable!") +# else: +# conditions.append("{}!={}".format(key, _replace_char)) +# values.append(item) +# +# +# class GEQ(Condition): +# __slots__ = ('value',) +# +# def __init__(self, value): +# self.value = value +# +# def apply(self, key, values, conditions): +# item = self.value +# if isinstance(item, (list, tuple)): +# raise ValueError("Cannot apply GEQ condition to a list!") +# else: +# conditions.append("{} >= {}".format(key, _replace_char)) +# values.append(item) +# +# +# class LEQ(Condition): +# __slots__ = ('value',) +# +# def __init__(self, value): +# self.value = value +# +# def apply(self, key, values, conditions): +# item = self.value +# if isinstance(item, (list, tuple)): +# raise ValueError("Cannot apply LEQ condition to a list!") +# else: +# conditions.append("{} <= {}".format(key, _replace_char)) +# values.append(item) +# +# +# class Constant(Condition): +# __slots__ = ('value',) +# +# def __init__(self, value): +# self.value = value +# +# def apply(self, key, values, conditions): +# conditions.append("{} {}".format(key, self.value)) +# +# +# class SHARDID(Condition): +# __slots__ = ('shardid', 'shard_count') +# +# def __init__(self, shardid, shard_count): +# self.shardid = shardid +# self.shard_count = shard_count +# +# def apply(self, key, values, conditions): +# if self.shard_count > 1: +# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char)) +# values.append(self.shardid) +# +# +# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count) +# +# +# NULL = Constant('IS NULL') +# NOTNULL = Constant('IS NOT NULL') diff --git a/src/data/connector.py b/src/data/connector.py new file mode 100644 index 0000000..7b25aed --- /dev/null +++ b/src/data/connector.py @@ -0,0 +1,135 @@ +from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional +import logging + +from contextvars import ContextVar +from contextlib import asynccontextmanager +import psycopg as psq +from psycopg_pool import AsyncConnectionPool +from psycopg.pq import TransactionStatus + +from .cursor import AsyncLoggingCursor + +logger = logging.getLogger(__name__) + +row_factory = psq.rows.dict_row + +ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None) + + +class Connector: + cursor_factory = AsyncLoggingCursor + + def __init__(self, conn_args): + self._conn_args = conn_args + self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory) + + self.pool = self.make_pool() + + self.conn_hooks = [] + + @property + def conn(self) -> Optional[psq.AsyncConnection]: + """ + Convenience property for the current context connection. + """ + return ctx_connection.get() + + @conn.setter + def conn(self, conn: psq.AsyncConnection): + """ + Set the contextual connection in the current context. + Always do this in an isolated context! + """ + ctx_connection.set(conn) + + def make_pool(self) -> AsyncConnectionPool: + logger.info("Initialising connection pool.", extra={'action': "Pool Init"}) + return AsyncConnectionPool( + self._conn_args, + open=False, + min_size=4, + max_size=8, + configure=self._setup_connection, + kwargs=self._conn_kwargs + ) + + async def refresh_pool(self): + """ + Refresh the pool. + + The point of this is to invalidate any existing connections so that the connection set up is run again. + Better ways should be sought (a way to + """ + logger.info("Pool refresh requested, closing and reopening.") + old_pool = self.pool + self.pool = self.make_pool() + await self.pool.open() + logger.info(f"Old pool statistics: {self.pool.get_stats()}") + await old_pool.close() + logger.info("Pool refresh complete.") + + async def map_over_pool(self, callable): + """ + Dangerous method to call a method on each connection in the pool. + + Utilises private methods of the AsyncConnectionPool. + """ + async with self.pool._lock: + conns = list(self.pool._pool) + while conns: + conn = conns.pop() + try: + await callable(conn) + except Exception: + logger.exception(f"Mapped connection task failed. {callable.__name__}") + + @asynccontextmanager + async def open(self): + try: + logger.info("Opening database pool.") + await self.pool.open() + yield + finally: + # May be a different pool! + logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}") + await self.pool.close() + + @asynccontextmanager + async def connection(self) -> psq.AsyncConnection: + """ + Asynchronous context manager to get and manage a connection. + + If the context connection is set, uses this and does not manage the lifetime. + Otherwise, requests a new connection from the pool and returns it when done. + """ + logger.debug("Database connection requested.", extra={'action': "Data Connect"}) + if (conn := self.conn): + yield conn + else: + async with self.pool.connection() as conn: + yield conn + + async def _setup_connection(self, conn: psq.AsyncConnection): + logger.debug("Initialising new connection.", extra={'action': "Conn Init"}) + for hook in self.conn_hooks: + try: + await hook(conn) + except Exception: + logger.exception("Exception encountered setting up new connection") + return conn + + def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]): + """ + Minimal decorator to register a coroutine to run on connect or reconnect. + + Note that these are only run on connect and reconnect. + If a hook is registered after connection, it will not be run. + """ + self.conn_hooks.append(coro) + return coro + + +@runtime_checkable +class Connectable(Protocol): + def bind(self, connector: Connector): + raise NotImplementedError diff --git a/src/data/cursor.py b/src/data/cursor.py new file mode 100644 index 0000000..5e01a8d --- /dev/null +++ b/src/data/cursor.py @@ -0,0 +1,42 @@ +import logging +from typing import Optional + +from psycopg import AsyncCursor, sql +from psycopg.abc import Query, Params +from psycopg._encodings import pgconn_encoding + +logger = logging.getLogger(__name__) + + +class AsyncLoggingCursor(AsyncCursor): + def mogrify_query(self, query: Query): + if isinstance(query, str): + msg = query + elif isinstance(query, (sql.SQL, sql.Composed)): + msg = query.as_string(self) + elif isinstance(query, bytes): + msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace') + else: + msg = repr(query) + return msg + + async def execute(self, query: Query, params: Optional[Params] = None, **kwargs): + if logging.DEBUG >= logger.getEffectiveLevel(): + msg = self.mogrify_query(query) + logger.debug( + "Executing query (%s) with values %s", msg, params, + extra={'action': "Query Execute"} + ) + try: + return await super().execute(query, params=params, **kwargs) + except Exception: + msg = self.mogrify_query(query) + logger.exception( + "Exception during query execution. Query (%s) with parameters %s.", + msg, params, + extra={'action': "Query Execute"}, + stack_info=True + ) + else: + # TODO: Possibly log execution time + pass diff --git a/src/data/database.py b/src/data/database.py new file mode 100644 index 0000000..255e412 --- /dev/null +++ b/src/data/database.py @@ -0,0 +1,47 @@ +from typing import TypeVar +import logging +from collections import namedtuple + +# from .cursor import AsyncLoggingCursor +from .registry import Registry +from .connector import Connector + + +logger = logging.getLogger(__name__) + +Version = namedtuple('Version', ('version', 'time', 'author')) + +T = TypeVar('T', bound=Registry) + + +class Database(Connector): + # cursor_factory = AsyncLoggingCursor + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.registries: dict[str, Registry] = {} + + def load_registry(self, registry: T) -> T: + logger.debug( + f"Loading and binding registry '{registry.name}'.", + extra={'action': f"Reg {registry.name}"} + ) + registry.bind(self) + self.registries[registry.name] = registry + return registry + + async def version(self) -> Version: + """ + Return the current schema version as a Version namedtuple. + """ + async with self.connection() as conn: + async with conn.cursor() as cursor: + # Get last entry in version table, compare against desired version + await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") + row = await cursor.fetchone() + if row: + return Version(row['version'], row['time'], row['author']) + else: + # No versions in the database + return Version(-1, None, None) diff --git a/src/data/models.py b/src/data/models.py new file mode 100644 index 0000000..54b6282 --- /dev/null +++ b/src/data/models.py @@ -0,0 +1,323 @@ +from typing import TypeVar, Type, Optional, Generic, Union +# from typing_extensions import Self +from weakref import WeakValueDictionary +from collections.abc import MutableMapping + +from psycopg.rows import DictRow + +from .table import Table +from .columns import Column +from . import queries as q +from .connector import Connector +from .registry import Registry + + +RowT = TypeVar('RowT', bound='RowModel') + + +class MISSING: + __slots__ = ('oid',) + + def __init__(self, oid): + self.oid = oid + + +class RowTable(Table, Generic[RowT]): + __slots__ = ( + 'model', + ) + + def __init__(self, name, model: Type[RowT], **kwargs): + super().__init__(name, **kwargs) + self.model = model + + @property + def columns(self): + return self.model._columns_ + + @property + def id_col(self): + return self.model._key_ + + @property + def row_cache(self): + return self.model._cache_ + + def _many_query_adapter(self, *data): + self.model._make_rows(*data) + return data + + def _single_query_adapter(self, *data): + if data: + self.model._make_rows(*data) + return data[0] + else: + return None + + def _delete_query_adapter(self, *data): + self.model._delete_rows(*data) + return data + + # New methods to fetch and create rows + async def create_row(self, *args, **kwargs) -> RowT: + data = await super().insert(*args, **kwargs) + return self.model._make_rows(data)[0] + + def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]: + # TODO: Handle list of rowids here? + return q.Select( + self.identifier, + row_adapter=self.model._make_rows, + connector=self.connector + ).where(*args, **kwargs) + + +WK = TypeVar('WK') +WV = TypeVar('WV') + + +class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]): + def __init__(self, ref_cache): + self.ref_cache = ref_cache + self.weak_cache = WeakValueDictionary() + + def __getitem__(self, key): + value = self.weak_cache[key] + self.ref_cache[key] = value + return value + + def __setitem__(self, key, value): + self.weak_cache[key] = value + self.ref_cache[key] = value + + def __delitem__(self, key): + del self.weak_cache[key] + try: + del self.ref_cache[key] + except KeyError: + pass + + def __contains__(self, key): + return key in self.weak_cache + + def __iter__(self): + return iter(self.weak_cache) + + def __len__(self): + return len(self.weak_cache) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def pop(self, key, default=None): + if key in self: + value = self[key] + del self[key] + else: + value = default + return value + + +# TODO: Implement getitem and setitem, for dynamic column access +class RowModel: + __slots__ = ('data',) + + _schema_: str = 'public' + _tablename_: Optional[str] = None + _columns_: dict[str, Column] = {} + + # Cache to keep track of registered Rows + _cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore + + _key_: tuple[str, ...] = () + _connector: Optional[Connector] = None + _registry: Optional[Registry] = None + + # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass + table: RowTable = None + + def __init_subclass__(cls: Type[RowT], table: Optional[str] = None): + """ + Set table, _columns_, and _key_. + """ + if table is not None: + cls._tablename_ = table + + if cls._tablename_ is not None: + columns = {} + for key, value in cls.__dict__.items(): + if isinstance(value, Column): + columns[key] = value + + cls._columns_ = columns + if not cls._key_: + cls._key_ = tuple(column.name for column in columns.values() if column.primary) + cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_) + if cls._cache_ is None: + cls._cache_ = WeakValueDictionary() + + def __new__(cls, data): + # Registry pattern. + # Ensure each rowid always refers to a single Model instance + if data is not None: + rowid = cls._id_from_data(data) + + cache = cls._cache_ + + if (row := cache.get(rowid, None)) is not None: + obj = row + else: + obj = cache[rowid] = super().__new__(cls) + else: + obj = super().__new__(cls) + + return obj + + @classmethod + def as_tuple(cls): + return (cls.table.identifier, ()) + + def __init__(self, data): + self.data = data + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + + @classmethod + def bind(cls, connector: Connector): + if cls.table is None: + raise ValueError("Cannot bind abstract RowModel") + cls._connector = connector + cls.table.bind(connector) + return cls + + @classmethod + def attach_to(cls, registry: Registry): + cls._registry = registry + return cls + + @property + def _dict_(self): + return {key: self.data[key] for key in self._key_} + + @property + def _rowid_(self): + return tuple(self.data[key] for key in self._key_) + + def __repr__(self): + return "{}.{}({})".format( + self.table.schema, + self.table.name, + ', '.join(repr(column.__get__(self)) for column in self._columns_.values()) + ) + + @classmethod + def _id_from_data(cls, data): + return tuple(data[key] for key in cls._key_) + + @classmethod + def _dict_from_id(cls, rowid): + return dict(zip(cls._key_, rowid)) + + @classmethod + def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]: + """ + Create or retrieve Row objects for each provided data row. + If the rows already exist in cache, updates the cached row. + """ + # TODO: Handle partial row data here somehow? + rows = [cls(data_row) for data_row in data_rows] + return rows + + @classmethod + def _delete_rows(cls, *data_rows): + """ + Remove the given rows from cache, if they exist. + May be extended to handle object deletion. + """ + cache = cls._cache_ + + for data_row in data_rows: + rowid = cls._id_from_data(data_row) + cache.pop(rowid, None) + + @classmethod + async def create(cls: Type[RowT], *args, **kwargs) -> RowT: + return await cls.table.create_row(*args, **kwargs) + + @classmethod + def fetch_where(cls: Type[RowT], *args, **kwargs): + return cls.table.fetch_rows_where(*args, **kwargs) + + @classmethod + async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]: + """ + Fetch the row with the given id, retrieving from cache where possible. + """ + row = cls._cache_.get(rowid, None) if cached else None + if row is None: + rows = await cls.fetch_where(**cls._dict_from_id(rowid)) + row = rows[0] if rows else None + if row is None: + cls._cache_[rowid] = cls(None) + elif row.data is None: + row = None + + return row + + @classmethod + async def fetch_or_create(cls, *rowid, **kwargs): + """ + Helper method to fetch a row with the given id or fields, or create it if it doesn't exist. + """ + if rowid: + row = await cls.fetch(*rowid) + else: + rows = await cls.fetch_where(**kwargs).limit(1) + row = rows[0] if rows else None + + if row is None: + creation_kwargs = kwargs + if rowid: + creation_kwargs.update(cls._dict_from_id(rowid)) + row = await cls.create(**creation_kwargs) + return row + + async def refresh(self: RowT) -> Optional[RowT]: + """ + Refresh this Row from data. + + The return value may be `None` if the row was deleted. + """ + rows = await self.table.select_where(**self._dict_) + if not rows: + return None + else: + self.data = rows[0] + return self + + async def update(self: RowT, **values) -> Optional[RowT]: + """ + Update this Row with the given values. + + Internally passes the provided `values` to the `update` Query. + The return value may be `None` if the row was deleted. + """ + data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows) + if not data: + return None + else: + return data[0] + + async def delete(self: RowT) -> Optional[RowT]: + """ + Delete this Row. + """ + data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows) + return data[0] if data is not None else None diff --git a/src/data/queries.py b/src/data/queries.py new file mode 100644 index 0000000..0232928 --- /dev/null +++ b/src/data/queries.py @@ -0,0 +1,644 @@ +from typing import Optional, TypeVar, Any, Callable, Generic, List, Union +from enum import Enum +from itertools import chain +from psycopg import AsyncConnection, AsyncCursor +from psycopg import sql +from psycopg.rows import DictRow + +import logging + +from .conditions import Condition +from .base import Expression, RawExpr +from .connector import Connector + + +logger = logging.getLogger(__name__) + + +TQueryT = TypeVar('TQueryT', bound='TableQuery') +SQueryT = TypeVar('SQueryT', bound='Select') + +QueryResult = TypeVar('QueryResult') + + +class Query(Generic[QueryResult]): + """ + ABC for an executable query statement. + """ + __slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result') + + _adapter: Callable[..., QueryResult] + + def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs): + self.connector: Optional[Connector] = connector + self.conn: Optional[AsyncConnection] = conn + self.cursor: Optional[AsyncCursor] = cursor + + if row_adapter is not None: + self._adapter = row_adapter + else: + self._adapter = self._no_adapter + + self.result: Optional[QueryResult] = None + + def bind(self, connector: Connector): + self.connector = connector + return self + + def with_cursor(self, cursor: AsyncCursor): + self.cursor = cursor + return self + + def with_connection(self, conn: AsyncConnection): + self.conn = conn + return self + + def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: + return data + + def with_adapter(self, callable: Callable[..., QueryResult]): + # NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1] + # For this to work cleanly, callable should have arg type of QR1, not any + self._adapter = callable + return self + + def with_no_adapter(self): + """ + Sets the adapater to the identity. + """ + self._adapter = self._no_adapter + return self + + def one(self): + # TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1] + return self + + def build(self) -> Expression: + raise NotImplementedError + + async def _execute(self, cursor: AsyncCursor) -> QueryResult: + query, values = self.build().as_tuple() + # TODO: Move logging out to a custom cursor + # logger.debug( + # f"Executing query ({query.as_string(cursor)}) with values {values}", + # extra={'action': "Query"} + # ) + await cursor.execute(sql.Composed((query,)), values) + data = await cursor.fetchall() + self.result = self._adapter(*data) + return self.result + + async def execute(self, cursor=None) -> QueryResult: + """ + Execute the query, optionally with the provided cursor, and return the result rows. + If no cursor is provided, and no cursor has been set with `with_cursor`, + the execution will create a new cursor from the connection and close it automatically. + """ + # Create a cursor if possible + cursor = cursor if cursor is not None else self.cursor + if self.cursor is None: + if self.conn is None: + if self.connector is None: + raise ValueError("Cannot execute query without cursor, connection, or connector.") + else: + async with self.connector.connection() as conn: + async with conn.cursor() as cursor: + data = await self._execute(cursor) + else: + async with self.conn.cursor() as cursor: + data = await self._execute(cursor) + else: + data = await self._execute(cursor) + return data + + def __await__(self): + return self.execute().__await__() + + +class TableQuery(Query[QueryResult]): + """ + ABC for an executable query statement expected to be run on a single table. + """ + __slots__ = ( + 'tableid', + 'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group' + ) + + def __init__(self, tableid, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tableid: sql.Identifier = tableid + + def options(self, **kwargs): + """ + Set some query options. + Default implementation does nothing. + Should be overridden to provide specific options. + """ + return self + + +class WhereMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.condition: Optional[Condition] = None + + def where(self, *args: Condition, **kwargs): + """ + Add a Condition to the query. + Position arguments should be Conditions, + and keyword arguments should be of the form `column=Value`, + where Value may be a Value-type or a literal value. + All provided Conditions will be and-ed together to create a new Condition. + TODO: Maybe just pass this verbatim to a condition. + """ + if args or kwargs: + condition = Condition.construct(*args, **kwargs) + if self.condition is not None: + condition = self.condition & condition + + self.condition = condition + + return self + + @property + def _where_section(self) -> Optional[Expression]: + if self.condition is not None: + return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple()) + else: + return None + + +class JOINTYPE(Enum): + LEFT = sql.SQL('LEFT JOIN') + RIGHT = sql.SQL('RIGHT JOIN') + INNER = sql.SQL('INNER JOIN') + OUTER = sql.SQL('OUTER JOIN') + FULLOUTER = sql.SQL('FULL OUTER JOIN') + + +class JoinMixin(TableQuery[QueryResult]): + __slots__ = () + # TODO: Remember to add join slots to TableQuery + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._joins: list[Expression] = [] + + def join(self, + target: Union[str, Expression], + on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None, + join_type: JOINTYPE = JOINTYPE.INNER, + natural=False): + available = (on is not None) + (using is not None) + natural + if available == 0: + raise ValueError("No conditions given for Query Join") + if available > 1: + raise ValueError("Exactly one join format must be given for Query Join") + + sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())] + if isinstance(target, str): + sections.append((sql.Identifier(target), ())) + else: + sections.append(target.as_tuple()) + + if on is not None: + sections.append((sql.SQL('ON'), ())) + sections.append(on.as_tuple()) + elif using is not None: + sections.append((sql.SQL('USING'), ())) + if isinstance(using, Expression): + sections.append(using.as_tuple()) + elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str): + cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using)) + sections.append((cols, ())) + else: + raise ValueError("Unrecognised 'using' type.") + elif natural: + sections.insert(0, (sql.SQL('NATURAL'), ())) + + expr = RawExpr.join_tuples(*sections) + self._joins.append(expr) + return self + + def leftjoin(self, *args, **kwargs): + return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs) + + @property + def _join_section(self) -> Optional[Expression]: + if self._joins: + return RawExpr.join(*self._joins) + else: + return None + + +class ExtraMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._extra: Optional[Expression] = None + + def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()): + """ + Add an extra string, and optionally values, to this query. + The extra string is inserted after any condition, and before the limit. + """ + extra_expr = RawExpr(extra, values) + if self._extra is not None: + extra_expr = RawExpr.join(self._extra, extra_expr) + self._extra = extra_expr + return self + + @property + def _extra_section(self) -> Optional[Expression]: + if self._extra is None: + return None + else: + return self._extra + + +class LimitMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._limit: Optional[int] = None + + def limit(self, limit: int): + """ + Add a limit to this query. + """ + self._limit = limit + return self + + @property + def _limit_section(self) -> Optional[Expression]: + if self._limit is not None: + return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,)) + else: + return None + + +class FromMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._from: Optional[Expression] = None + + def from_expr(self, _from: Expression): + self._from = _from + return self + + @property + def _from_section(self) -> Optional[Expression]: + if self._from is not None: + expr, values = self._from.as_tuple() + return RawExpr(sql.SQL("FROM {}").format(expr), values) + else: + return None + + +class ORDER(Enum): + ASC = sql.SQL('ASC') + DESC = sql.SQL('DESC') + + +class NULLS(Enum): + FIRST = sql.SQL('NULLS FIRST') + LAST = sql.SQL('NULLS LAST') + + +class OrderMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._order: list[Expression] = [] + + def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None): + """ + Add a single sort expression to the query. + This method stacks. + """ + if isinstance(expr, Expression): + string, values = expr.as_tuple() + else: + string = sql.Identifier(expr) + values = () + + parts = [string] + if direction is not None: + parts.append(direction.value) + if nulls is not None: + parts.append(nulls.value) + + order_string = sql.SQL(' ').join(parts) + self._order.append(RawExpr(order_string, values)) + return self + + @property + def _order_section(self) -> Optional[Expression]: + if self._order: + expr = RawExpr.join(*self._order, joiner=sql.SQL(', ')) + expr.expr = sql.SQL("ORDER BY {}").format(expr.expr) + return expr + else: + return None + + +class GroupMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._group: list[Expression] = [] + + def group_by(self, *exprs: Union[Expression, str]): + """ + Add a group expression(s) to the query. + This method stacks. + """ + for expr in exprs: + if isinstance(expr, Expression): + self._group.append(expr) + else: + self._group.append(RawExpr(sql.Identifier(expr))) + return self + + @property + def _group_section(self) -> Optional[Expression]: + if self._group: + expr = RawExpr.join(*self._group, joiner=sql.SQL(', ')) + expr.expr = sql.SQL("GROUP BY {}").format(expr.expr) + return expr + else: + return None + + +class Insert(ExtraMixin, TableQuery[QueryResult]): + """ + Query type representing a table insert query. + """ + # TODO: Support ON CONFLICT for upserts + __slots__ = ('_columns', '_values', '_conflict') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._columns: tuple[str, ...] = () + self._values: tuple[tuple[Any, ...], ...] = () + self._conflict: Optional[Expression] = None + + def insert(self, columns, *values): + """ + Insert the given data. + + Parameters + ---------- + columns: tuple[str] + Tuple of column names to insert. + + values: tuple[tuple[Any, ...], ...] + Tuple of values to insert, corresponding to the columns. + """ + if not values: + raise ValueError("Cannot insert zero rows.") + if len(values[0]) != len(columns): + raise ValueError("Number of columns does not match length of values.") + + self._columns = columns + self._values = values + return self + + def on_conflict(self, ignore=False): + # TODO lots more we can do here + # Maybe return a Conflict object that can chain itself (not the query) + if ignore: + self._conflict = RawExpr(sql.SQL('DO NOTHING')) + return self + + @property + def _conflict_section(self) -> Optional[Expression]: + if self._conflict is not None: + e, v = self._conflict.as_tuple() + expr = RawExpr( + sql.SQL("ON CONFLICT {}").format( + e + ), + v + ) + return expr + return None + + def build(self): + columns = sql.SQL(',').join(map(sql.Identifier, self._columns)) + single_value_str = sql.SQL('({})').format( + sql.SQL(',').join(sql.Placeholder() * len(self._columns)) + ) + values_str = sql.SQL(',').join(single_value_str * len(self._values)) + + # TODO: Check efficiency of inserting multiple values like this + # Also implement a Copy query + base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format( + table=self.tableid, + columns=columns, + values_str=values_str + ) + + sections = [ + RawExpr(base, tuple(chain(*self._values))), + self._conflict_section, + self._extra_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]): + """ + Select rows from a table matching provided conditions. + """ + __slots__ = ('_columns',) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._columns: tuple[Expression, ...] = () + + def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]): + """ + Set the columns and expressions to select. + If none are given, selects all columns. + """ + cols: List[Expression] = [] + if columns: + cols.extend(map(RawExpr, map(sql.Identifier, columns))) + if exprs: + for name, expr in exprs.items(): + if isinstance(expr, str): + cols.append( + RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name)) + ) + elif isinstance(expr, sql.Composable): + cols.append( + RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name)) + ) + elif isinstance(expr, Expression): + value_expr, value_values = expr.as_tuple() + cols.append(RawExpr( + value_expr + sql.SQL(' AS ') + sql.Identifier(name), + value_values + )) + if cols: + self._columns = (*self._columns, *cols) + return self + + def build(self): + if not self._columns: + columns, columns_values = sql.SQL('*'), () + else: + columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple() + + base = sql.SQL("SELECT {columns} FROM {table}").format( + columns=columns, + table=self.tableid + ) + + sections = [ + RawExpr(base, columns_values), + self._join_section, + self._where_section, + self._group_section, + self._extra_section, + self._order_section, + self._limit_section, + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]): + """ + Query type representing a table delete query. + """ + # TODO: Cascade option for delete, maybe other options + # TODO: Require a where unless specifically disabled, for safety + + def build(self): + base = sql.SQL("DELETE FROM {table}").format( + table=self.tableid, + ) + sections = [ + RawExpr(base), + self._where_section, + self._extra_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]): + __slots__ = ( + '_set', + ) + # TODO: Again, require a where unless specifically disabled + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._set: List[Expression] = [] + + def set(self, **column_values: Union[Any, Expression]): + exprs: List[Expression] = [] + for name, value in column_values.items(): + if isinstance(value, Expression): + value_tup = value.as_tuple() + else: + value_tup = (sql.Placeholder(), (value,)) + + exprs.append( + RawExpr.join_tuples( + (sql.Identifier(name), ()), + value_tup, + joiner=sql.SQL(' = ') + ) + ) + self._set.extend(exprs) + return self + + def build(self): + if not self._set: + raise ValueError("No columns provided to update.") + set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple() + + base = sql.SQL("UPDATE {table} SET {set}").format( + table=self.tableid, + set=set_expr + ) + sections = [ + RawExpr(base, set_values), + self._from_section, + self._where_section, + self._extra_section, + self._limit_section, + RawExpr(sql.SQL('RETURNING *')) + ] + + sections = (section for section in sections if section is not None) + return RawExpr.join(*sections) + + +# async def upsert(cursor, table, constraint, **values): +# """ +# Insert or on conflict update. +# """ +# valuedict = values +# keys, values = zip(*values.items()) +# +# key_str = _format_insertkeys(keys) +# value_str, values = _format_insertvalues(values) +# update_key_str, update_key_values = _format_updatestr(valuedict) +# +# if not isinstance(constraint, str): +# constraint = ", ".join(constraint) +# +# await cursor.execute( +# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format( +# table, key_str, value_str, constraint, update_key_str +# ), +# tuple((*values, *update_key_values)) +# ) +# return await cursor.fetchone() + + +# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None): +# cursor = cursor or conn.cursor() +# +# # TODO: executemany or copy syntax now +# return execute_values( +# cursor, +# """ +# UPDATE {table} +# SET {set_clause} +# FROM (VALUES {cast_row}%s) +# AS {temp_table} +# WHERE {where_clause} +# RETURNING * +# """.format( +# table=table, +# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys), +# cast_row=cast_row + ',' if cast_row else '', +# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys), +# temp_table="_t ({})".format(', '.join(set_keys + where_keys)) +# ), +# values, +# fetch=True +# ) diff --git a/src/data/registry.py b/src/data/registry.py new file mode 100644 index 0000000..c130d0f --- /dev/null +++ b/src/data/registry.py @@ -0,0 +1,102 @@ +from typing import Protocol, runtime_checkable, Optional + +from psycopg import AsyncConnection + +from .connector import Connector, Connectable + + +@runtime_checkable +class _Attachable(Connectable, Protocol): + def attach_to(self, registry: 'Registry'): + raise NotImplementedError + + +class Registry: + _attached: list[_Attachable] = [] + _name: Optional[str] = None + + def __init_subclass__(cls, name=None): + attached = [] + for _, member in cls.__dict__.items(): + if isinstance(member, _Attachable): + attached.append(member) + cls._attached = attached + cls._name = name or cls.__name__ + + def __init__(self, name=None): + self._conn: Optional[Connector] = None + self.name: str = name if name is not None else self._name + if self.name is None: + raise ValueError("A Registry must have a name!") + + self.init_tasks = [] + + for member in self._attached: + member.attach_to(self) + + def bind(self, connector: Connector): + self._conn = connector + for child in self._attached: + child.bind(connector) + + def attach(self, attachable): + self._attached.append(attachable) + if self._conn is not None: + attachable.bind(self._conn) + return attachable + + def init_task(self, coro): + """ + Initialisation tasks are run to setup the registry state. + These tasks will be run in the event loop, after connection to the database. + These tasks should be idempotent, as they may be run on reload and reconnect. + """ + self.init_tasks.append(coro) + return coro + + async def init(self): + for task in self.init_tasks: + await task(self) + return self + + +class AttachableClass: + """ABC for a default implementation of an Attachable class.""" + + _connector: Optional[Connector] = None + _registry: Optional[Registry] = None + + @classmethod + def bind(cls, connector: Connector): + cls._connector = connector + connector.connect_hook(cls.on_connect) + return cls + + @classmethod + def attach_to(cls, registry: Registry): + cls._registry = registry + return cls + + @classmethod + async def on_connect(cls, connection: AsyncConnection): + pass + + +class Attachable: + """ABC for a default implementation of an Attachable object.""" + + def __init__(self, *args, **kwargs): + self._connector: Optional[Connector] = None + self._registry: Optional[Registry] = None + + def bind(self, connector: Connector): + self._connector = connector + connector.connect_hook(self.on_connect) + return self + + def attach_to(self, registry: Registry): + self._registry = registry + return self + + async def on_connect(self, connection: AsyncConnection): + pass diff --git a/src/data/table.py b/src/data/table.py new file mode 100644 index 0000000..e20647e --- /dev/null +++ b/src/data/table.py @@ -0,0 +1,95 @@ +from typing import Optional +from psycopg.rows import DictRow +from psycopg import sql + +from . import queries as q +from .connector import Connector +from .registry import Registry + + +class Table: + """ + Transparent interface to a single table structure in the database. + Contains standard methods to access the table. + """ + + def __init__(self, name, *args, schema='public', **kwargs): + self.name: str = name + self.schema: str = schema + self.connector: Connector = None + + @property + def identifier(self): + if self.schema == 'public': + return sql.Identifier(self.name) + else: + return sql.Identifier(self.schema, self.name) + + def bind(self, connector: Connector): + self.connector = connector + return self + + def attach_to(self, registry: Registry): + self._registry = registry + return self + + def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: + return data + + def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]: + if data: + return data[0] + else: + return None + + def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]: + return data + + def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]: + return q.Select( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]: + return q.Select( + self.identifier, + row_adapter=self._single_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]: + return q.Update( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]: + return q.Delete( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).where(*args, **kwargs) + + def insert(self, **column_values) -> q.Insert[DictRow]: + return q.Insert( + self.identifier, + row_adapter=self._single_query_adapter, + connector=self.connector + ).insert(column_values.keys(), column_values.values()) + + def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]: + return q.Insert( + self.identifier, + row_adapter=self._many_query_adapter, + connector=self.connector + ).insert(*args, **kwargs) + +# def update_many(self, *args, **kwargs): +# with self.conn: +# return update_many(self.identifier, *args, **kwargs) + +# def upsert(self, *args, **kwargs): +# return upsert(self.identifier, *args, **kwargs) diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py new file mode 100644 index 0000000..32f39ee --- /dev/null +++ b/src/meta/LionBot.py @@ -0,0 +1,344 @@ +from typing import List, Literal, LiteralString, Optional, TYPE_CHECKING, overload +import logging +import asyncio +from weakref import WeakValueDictionary + +import discord +from discord.utils import MISSING +from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError +from discord.ext.commands.errors import CommandInvokeError, CheckFailure +from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError +from aiohttp import ClientSession + +from data import Database +from utils.lib import tabulate + +from .config import Conf +from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context +from .context import context +from .LionContext import LionContext +from .LionTree import LionTree +from .errors import HandledException, SafeCancellation +from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus + +if TYPE_CHECKING: + from core.cog import CoreCog + +logger = logging.getLogger(__name__) + + +class LionBot(Bot): + def __init__( + self, *args, appname: str, shardname: str, db: Database, config: Conf, + initial_extensions: List[str], web_client: ClientSession, + testing_guilds: List[int] = [], **kwargs + ): + kwargs.setdefault('tree_cls', LionTree) + super().__init__(*args, **kwargs) + self.web_client = web_client + self.testing_guilds = testing_guilds + self.initial_extensions = initial_extensions + self.db = db + self.appname = appname + self.shardname = shardname +# self.appdata = appdata + self.config = config + + self.system_monitor = SystemMonitor() + self.monitor = ComponentMonitor('LionBot', self._monitor_status) + self.system_monitor.add_component(self.monitor) + + self._locks = WeakValueDictionary() + self._running_events = set() + + @property + def core(self): + return self.get_cog('CoreCog') + + async def _monitor_status(self): + if self.is_closed(): + level = StatusLevel.ERRORED + info = "(ERROR) Websocket is closed" + data = {} + elif self.is_ws_ratelimited(): + level = StatusLevel.WAITING + info = "(WAITING) Websocket is ratelimited" + data = {} + elif not self.is_ready(): + level = StatusLevel.STARTING + info = "(STARTING) Not yet ready" + data = {} + else: + level = StatusLevel.OKAY + info = ( + "(OK) " + "Logged in with {guild_count} guilds, " + ", websocket latency {latency}, and {events} running events." + ) + data = { + 'guild_count': len(self.guilds), + 'latency': self.latency, + 'events': len(self._running_events), + } + return ComponentStatus(level, info, info, data) + + async def setup_hook(self) -> None: + log_context.set(f"APP: {self.application_id}") + + for extension in self.initial_extensions: + await self.load_extension(extension) + + for guildid in self.testing_guilds: + guild = discord.Object(guildid) + if not self.shard_count or (self.shard_id == ((guildid >> 22) % self.shard_count)): + self.tree.copy_global_to(guild=guild) + await self.tree.sync(guild=guild) + + # To make the type checker happy about fetching cogs by name + # TODO: Move this to stubs at some point + + @overload + def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog': + ... + + @overload + def get_cog(self, name: str) -> Optional[Cog]: + ... + + def get_cog(self, name: str) -> Optional[Cog]: + return super().get_cog(name) + + async def add_cog(self, cog: Cog, **kwargs): + sup = super() + @log_wrap(action=f"Attach {cog.__cog_name__}") + async def wrapper(): + logger.info(f"Attaching Cog {cog.__cog_name__}") + await sup.add_cog(cog, **kwargs) + logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.") + await wrapper() + + async def load_extension(self, name, *, package=None, **kwargs): + sup = super() + @log_wrap(action=f"Load {name.strip('.')}") + async def wrapper(): + logger.info(f"Loading extension {name} in package {package}.") + await sup.load_extension(name, package=package, **kwargs) + logger.debug(f"Loaded extension {name} in package {package}.") + await wrapper() + + async def start(self, token: str, *, reconnect: bool = True): + with logging_context(action="Login"): + start_task = asyncio.create_task(self.login(token)) + await start_task + + with logging_context(stack=("Running",)): + run_task = asyncio.create_task(self.connect(reconnect=reconnect)) + await run_task + + def dispatch(self, event_name: str, *args, **kwargs): + with logging_context(action=f"Dispatch {event_name}"): + super().dispatch(event_name, *args, **kwargs) + + def _schedule_event(self, coro, event_name, *args, **kwargs): + """ + Extends client._schedule_event to keep a persistent + background task store. + """ + task = super()._schedule_event(coro, event_name, *args, **kwargs) + self._running_events.add(task) + task.add_done_callback(lambda fut: self._running_events.discard(fut)) + + def idlock(self, snowflakeid): + lock = self._locks.get(snowflakeid, None) + if lock is None: + lock = self._locks[snowflakeid] = asyncio.Lock() + return lock + + async def on_ready(self): + logger.info( + f"Logged in as {self.application.name}\n" + f"Application id {self.application.id}\n" + f"Shard Talk identifier {self.shardname}\n" + "------------------------------\n" + f"Enabled Modules: {', '.join(self.extensions.keys())}\n" + f"Loaded Cogs: {', '.join(self.cogs.keys())}\n" + f"Registered Data: {', '.join(self.db.registries.keys())}\n" + f"Listening for {sum(1 for _ in self.walk_commands())} commands\n" + "------------------------------\n" + f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n" + "Ready to take commands!\n", + extra={'action': 'Ready'} + ) + + async def get_context(self, origin, /, *, cls=MISSING): + if cls is MISSING: + cls = LionContext + ctx = await super().get_context(origin, cls=cls) + context.set(ctx) + return ctx + + async def on_command(self, ctx: LionContext): + logger.info( + f"Executing command '{ctx.command.qualified_name}' " + f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') " + f"with interaction: {ctx.interaction.data if ctx.interaction else None}", + extra={'with_ctx': True} + ) + + async def on_command_error(self, ctx, exception): + # TODO: Some of these could have more user-feedback + logger.debug(f"Handling command error for {ctx}: {exception}") + if isinstance(ctx.command, HybridCommand) and ctx.command.app_command: + cmd_str = ctx.command.app_command.to_dict() + else: + cmd_str = str(ctx.command) + try: + raise exception + except (HybridCommandError, CommandInvokeError, appCommandInvokeError): + try: + if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)): + original = exception.original.original + raise original + else: + original = exception.original + raise original + except HandledException: + pass + except TransformerError as e: + msg = str(e) + if msg: + try: + await ctx.error_reply(msg) + except Exception: + pass + logger.debug( + f"Caught a transformer error: {repr(e)}", + extra={'action': 'BotError', 'with_ctx': True} + ) + except SafeCancellation: + if original.msg: + try: + await ctx.error_reply(original.msg) + except Exception: + pass + logger.debug( + f"Caught a safe cancellation: {original.details}", + extra={'action': 'BotError', 'with_ctx': True} + ) + except discord.Forbidden: + # Unknown uncaught Forbidden + try: + # Attempt a general error reply + await ctx.reply("I don't have enough channel or server permissions to complete that command here!") + except Exception: + # We can't send anything at all. Exit quietly, but log. + logger.warning( + f"Caught an unhandled 'Forbidden' while executing: {cmd_str}", + exc_info=True, + extra={'action': 'BotError', 'with_ctx': True} + ) + except discord.HTTPException: + logger.error( + f"Caught an unhandled 'HTTPException' while executing: {cmd_str}", + exc_info=True, + extra={'action': 'BotError', 'with_ctx': True} + ) + except asyncio.CancelledError: + pass + except asyncio.TimeoutError: + pass + except Exception as e: + logger.exception( + f"Caught an unknown CommandInvokeError while executing: {cmd_str}", + extra={'action': 'BotError', 'with_ctx': True} + ) + + error_embed = discord.Embed( + title="Something went wrong!", + colour=discord.Colour.dark_red() + ) + error_embed.description = ( + "An unexpected error occurred while processing your command!\n" + "Our development team has been notified, and the issue will be addressed soon.\n" + "If the error persists, or you have any questions, please contact our [support team]({link}) " + "and give them the extra details below." + ).format(link=self.config.bot.support_guild) + details = {} + details['error'] = f"`{repr(e)}`" + if ctx.interaction: + details['interactionid'] = f"`{ctx.interaction.id}`" + if ctx.command: + details['cmd'] = f"`{ctx.command.qualified_name}`" + if ctx.author: + details['author'] = f"`{ctx.author.id}` -- `{ctx.author}`" + if ctx.guild: + details['guild'] = f"`{ctx.guild.id}` -- `{ctx.guild.name}`" + details['my_guild_perms'] = f"`{ctx.guild.me.guild_permissions.value}`" + if ctx.author: + ownerstr = ' (owner)' if ctx.author.id == ctx.guild.owner_id else '' + details['author_guild_perms'] = f"`{ctx.author.guild_permissions.value}{ownerstr}`" + if ctx.channel.type is discord.enums.ChannelType.private: + details['channel'] = "`Direct Message`" + elif ctx.channel: + details['channel'] = f"`{ctx.channel.id}` -- `{ctx.channel.name}`" + details['my_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.guild.me).value}`" + if ctx.author: + details['author_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.author).value}`" + details['shard'] = f"`{self.shardname}`" + details['log_stack'] = f"`{log_action_stack.get()}`" + + table = '\n'.join(tabulate(*details.items())) + error_embed.add_field(name='Details', value=table) + + try: + await ctx.error_reply(embed=error_embed) + except discord.HTTPException: + pass + finally: + exception.original = HandledException(exception.original) + except CheckFailure as e: + logger.debug( + f"Command failed check: {e}: {e.args}", + extra={'action': 'BotError', 'with_ctx': True} + ) + try: + await ctx.error_reply(str(e)) + except discord.HTTPException: + pass + except Exception: + # Completely unknown exception outside of command invocation! + # Something is very wrong here, don't attempt user interaction. + logger.exception( + f"Caught an unknown top-level exception while executing: {cmd_str}", + extra={'action': 'BotError', 'with_ctx': True} + ) + + def add_command(self, command): + if not hasattr(command, '_placeholder_group_'): + super().add_command(command) + + def request_chunking_for(self, guild): + if not guild.chunked: + return asyncio.create_task( + self._connection.chunk_guild(guild, wait=False, cache=True), + name=f"Background chunkreq for {guild.id}" + ) + + async def on_interaction(self, interaction: discord.Interaction): + """ + Adds the interaction author to guild cache if appropriate. + + This gets run a little bit late, so it is possible the interaction gets handled + without the author being in case. + """ + guild = interaction.guild + user = interaction.user + if guild is not None and user is not None and isinstance(user, discord.Member): + if not guild.get_member(user.id): + guild._add_member(user) + if guild is not None and not guild.chunked: + # Getting an interaction in the guild is a good enough reason to request chunking + logger.info( + f"Unchunked guild requesting chunking after interaction." + ) + self.request_chunking_for(guild) diff --git a/src/meta/LionCog.py b/src/meta/LionCog.py new file mode 100644 index 0000000..39ca43a --- /dev/null +++ b/src/meta/LionCog.py @@ -0,0 +1,58 @@ +from typing import Any + +from discord.ext.commands import Cog +from discord.ext import commands as cmds + + +class LionCog(Cog): + # A set of other cogs that this cog depends on + depends_on: set['LionCog'] = set() + _placeholder_groups_: set[str] + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + cls._placeholder_groups_ = set() + + for base in reversed(cls.__mro__): + for elem, value in base.__dict__.items(): + if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'): + cls._placeholder_groups_.add(value.name) + + def __new__(cls, *args: Any, **kwargs: Any): + # Patch to ensure no placeholder groups are in the command list + self = super().__new__(cls) + self.__cog_commands__ = [ + command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_ + ] + return self + + async def _inject(self, bot, *args, **kwargs): + if self.depends_on: + not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)} + raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}") + + return await super()._inject(bot, *args, *kwargs) + + @classmethod + def placeholder_group(cls, group: cmds.HybridGroup): + group._placeholder_group_ = True + return group + + def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup): + """ + Crossload a placeholder group's commands into the target group + """ + if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup): + raise ValueError("Placeholder and target groups my be HypridGroups.") + if placeholder_group.name not in self._placeholder_groups_: + raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.") + if target_group.app_command is None: + raise ValueError("Target group has no app_command to crossload into.") + + for command in placeholder_group.commands: + placeholder_group.remove_command(command.name) + target_group.remove_command(command.name) + acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self) + command.app_command = acmd + target_group.add_command(command) diff --git a/src/meta/LionContext.py b/src/meta/LionContext.py new file mode 100644 index 0000000..e1b21d8 --- /dev/null +++ b/src/meta/LionContext.py @@ -0,0 +1,195 @@ +import types +import logging +from collections import namedtuple +from typing import Optional, TYPE_CHECKING + +import discord +from discord.enums import ChannelType +from discord.ext.commands import Context + +if TYPE_CHECKING: + from .LionBot import LionBot + + +logger = logging.getLogger(__name__) + + +""" +Stuff that might be useful to implement (see cmdClient): + sent_messages cache + tasks cache + error reply + usage + interaction cache + View cache? + setting access +""" + + +FlatContext = namedtuple( + 'FlatContext', + ('message', + 'interaction', + 'guild', + 'author', + 'channel', + 'alias', + 'prefix', + 'failed') +) + + +class LionContext(Context['LionBot']): + """ + Represents the context a command is invoked under. + + Extends Context to add Lion-specific methods and attributes. + Also adds several contextual wrapped utilities for simpler user during command invocation. + """ + + def __repr__(self): + parts = {} + if self.interaction is not None: + parts['iid'] = self.interaction.id + parts['itype'] = f"\"{self.interaction.type.name}\"" + if self.message is not None: + parts['mid'] = self.message.id + if self.author is not None: + parts['uid'] = self.author.id + parts['uname'] = f"\"{self.author.name}\"" + if self.channel is not None: + parts['cid'] = self.channel.id + if self.channel.type is ChannelType.private: + parts['cname'] = f"\"{self.channel.recipient}\"" + else: + parts['cname'] = f"\"{self.channel.name}\"" + if self.guild is not None: + parts['gid'] = self.guild.id + parts['gname'] = f"\"{self.guild.name}\"" + if self.command is not None: + parts['cmd'] = f"\"{self.command.qualified_name}\"" + if self.invoked_with is not None: + parts['alias'] = f"\"{self.invoked_with}\"" + if self.command_failed: + parts['failed'] = self.command_failed + + return "".format( + ' '.join(f"{name}={value}" for name, value in parts.items()) + ) + + def flatten(self): + """Flat pure-data context information, for caching and logging.""" + return FlatContext( + self.message.id, + self.interaction.id if self.interaction is not None else None, + self.guild.id if self.guild is not None else None, + self.author.id if self.author is not None else None, + self.channel.id if self.channel is not None else None, + self.invoked_with, + self.prefix, + self.command_failed + ) + + @classmethod + def util(cls, util_func): + """ + Decorator to make a utility function available as a Context instance method. + """ + setattr(cls, util_func.__name__, util_func) + logger.debug(f"Attached context utility function: {util_func.__name__}") + return util_func + + @classmethod + def wrappable_util(cls, util_func): + """ + Decorator to add a Wrappable utility function as a Context instance method. + """ + wrapped = Wrappable(util_func) + setattr(cls, util_func.__name__, wrapped) + logger.debug(f"Attached wrappable context utility function: {util_func.__name__}") + return wrapped + + async def error_reply(self, content: Optional[str] = None, **kwargs): + if content and 'embed' not in kwargs: + embed = discord.Embed( + colour=discord.Colour.red(), + description=content + ) + kwargs['embed'] = embed + content = None + + # Expect this may be run in highly unusual circumstances. + # This should never error, or at least handle all errors. + if self.interaction: + kwargs.setdefault('ephemeral', True) + try: + await self.reply(content=content, **kwargs) + except discord.HTTPException: + pass + except Exception: + logger.exception( + "Unknown exception in 'error_reply'.", + extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True} + ) + + +class Wrappable: + __slots__ = ('_func', 'wrappers') + + def __init__(self, func): + self._func = func + self.wrappers = None + + @property + def __name__(self): + return self._func.__name__ + + def add_wrapper(self, func, name=None): + self.wrappers = self.wrappers or {} + name = name or func.__name__ + self.wrappers[name] = func + logger.debug( + f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.", + extra={'action': "Wrap Util"} + ) + + def remove_wrapper(self, name): + if not self.wrappers or name not in self.wrappers: + raise ValueError( + f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'" + ) + self.wrappers.pop(name) + logger.debug( + f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.", + extra={'action': "Unwrap Util"} + ) + + def __call__(self, *args, **kwargs): + if self.wrappers: + return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs) + else: + return self._func(*args, **kwargs) + + def _wrapped(self, iter_wraps): + next_wrap = next(iter_wraps, None) + if next_wrap: + def _func(*args, **kwargs): + return next_wrap(self._wrapped(iter_wraps), *args, **kwargs) + else: + _func = self._func + return _func + + def __get__(self, instance, cls=None): + if instance is None: + return self + else: + return types.MethodType(self, instance) + + +LionContext.reply = Wrappable(LionContext.reply) + + +# @LionContext.reply.add_wrapper +# async def think(func, ctx, *args, **kwargs): +# await ctx.channel.send("thinking") +# await func(ctx, *args, **kwargs) diff --git a/src/meta/LionTree.py b/src/meta/LionTree.py new file mode 100644 index 0000000..75a3ccf --- /dev/null +++ b/src/meta/LionTree.py @@ -0,0 +1,150 @@ +import logging + +import discord +from discord import Interaction +from discord.app_commands import CommandTree +from discord.app_commands.errors import AppCommandError, CommandInvokeError +from discord.enums import InteractionType +from discord.app_commands.namespace import Namespace + +from utils.lib import tabulate + +from .logger import logging_context, set_logging_context, log_wrap, log_action_stack +from .errors import SafeCancellation +from .config import conf + +logger = logging.getLogger(__name__) + + +class LionTree(CommandTree): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._call_tasks = set() + + async def on_error(self, interaction: discord.Interaction, error) -> None: + try: + if isinstance(error, CommandInvokeError): + raise error.original + else: + raise error + except SafeCancellation: + # Assume this has already been handled + pass + except Exception: + logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'}) + if interaction.type is not InteractionType.autocomplete: + embed = self.bugsplat(interaction, error) + await self.error_reply(interaction, embed) + + async def error_reply(self, interaction, embed): + if not interaction.is_expired(): + try: + if interaction.response.is_done(): + await interaction.followup.send(embed=embed, ephemeral=True) + else: + await interaction.response.send_message(embed=embed, ephemeral=True) + except discord.HTTPException: + pass + + def bugsplat(self, interaction, e): + error_embed = discord.Embed(title="Something went wrong!", colour=discord.Colour.red()) + error_embed.description = ( + "An unexpected error occurred during this interaction!\n" + "Our development team has been notified, and the issue will be addressed soon.\n" + "If the error persists, or you have any questions, please contact our [support team]({link}) " + "and give them the extra details below." + ).format(link=interaction.client.config.bot.support_guild) + details = {} + details['error'] = f"`{repr(e)}`" + details['interactionid'] = f"`{interaction.id}`" + details['interactiontype'] = f"`{interaction.type}`" + if interaction.command: + details['cmd'] = f"`{interaction.command.qualified_name}`" + if interaction.user: + details['user'] = f"`{interaction.user.id}` -- `{interaction.user}`" + if interaction.guild: + details['guild'] = f"`{interaction.guild.id}` -- `{interaction.guild.name}`" + details['my_guild_perms'] = f"`{interaction.guild.me.guild_permissions.value}`" + if interaction.user: + ownerstr = ' (owner)' if interaction.user.id == interaction.guild.owner_id else '' + details['user_guild_perms'] = f"`{interaction.user.guild_permissions.value}{ownerstr}`" + if interaction.channel.type is discord.enums.ChannelType.private: + details['channel'] = "`Direct Message`" + elif interaction.channel: + details['channel'] = f"`{interaction.channel.id}` -- `{interaction.channel.name}`" + details['my_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.guild.me).value}`" + if interaction.user: + details['user_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.user).value}`" + details['shard'] = f"`{interaction.client.shardname}`" + details['log_stack'] = f"`{log_action_stack.get()}`" + + table = '\n'.join(tabulate(*details.items())) + error_embed.add_field(name='Details', value=table) + return error_embed + + def _from_interaction(self, interaction: Interaction) -> None: + @log_wrap(context=f"iid: {interaction.id}", isolate=False) + async def wrapper(): + try: + await self._call(interaction) + except AppCommandError as e: + await self._dispatch_error(interaction, e) + + task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker') + self._call_tasks.add(task) + task.add_done_callback(lambda fut: self._call_tasks.discard(fut)) + + async def _call(self, interaction): + if not await self.interaction_check(interaction): + interaction.command_failed = True + return + + data = interaction.data # type: ignore + type = data.get('type', 1) + if type != 1: + # Context menu command... + await self._call_context_menu(interaction, data, type) + return + + command, options = self._get_app_command_options(data) + + # Pre-fill the cached slot to prevent re-computation + interaction._cs_command = command + + # At this point options refers to the arguments of the command + # and command refers to the class type we care about + namespace = Namespace(interaction, data.get('resolved', {}), options) + + # Same pre-fill as above + interaction._cs_namespace = namespace + + # Auto complete handles the namespace differently... so at this point this is where we decide where that is. + if interaction.type is InteractionType.autocomplete: + set_logging_context(action=f"Acmp {command.qualified_name}") + focused = next((opt['name'] for opt in options if opt.get('focused')), None) + if focused is None: + raise AppCommandError( + 'This should not happen, but there is no focused element. This is a Discord bug.' + ) + try: + await command._invoke_autocomplete(interaction, focused, namespace) + except Exception as e: + await self.on_error(interaction, e) + return + + set_logging_context(action=f"Run {command.qualified_name}") + logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}") + try: + await command._invoke_with_namespace(interaction, namespace) + except AppCommandError as e: + interaction.command_failed = True + await command._invoke_error_handlers(interaction, e) + await self.on_error(interaction, e) + else: + if not interaction.command_failed: + self.client.dispatch('app_command_completion', interaction, command) + finally: + if interaction.command_failed: + logger.debug("Command completed with errors.") + else: + logger.debug("Command completed without errors.") diff --git a/src/meta/__init__.py b/src/meta/__init__.py new file mode 100644 index 0000000..5f68fe3 --- /dev/null +++ b/src/meta/__init__.py @@ -0,0 +1,15 @@ +from .LionBot import LionBot +from .LionCog import LionCog +from .LionContext import LionContext +from .LionTree import LionTree + +from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app +from .config import conf, configEmoji +from .args import args +from .app import appname, appname_from_shard, shard_from_appname +from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled +from .context import context, ctx_bot + +from . import sharding +from . import logger +from . import app diff --git a/src/meta/app.py b/src/meta/app.py new file mode 100644 index 0000000..9f0c9a2 --- /dev/null +++ b/src/meta/app.py @@ -0,0 +1,32 @@ +""" +appname: str + The base identifer for this application. + This identifies which services the app offers. +shardname: str + The specific name of the running application. + Only one process should be connecteded with a given appname. + For the bot apps, usually specifies the shard id and shard number. +""" +# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data? + +from . import sharding, conf +from .logger import log_app +from .args import args + + +appname = conf.data['appid'] +appid = appname # backwards compatibility + + +def appname_from_shard(shardid): + appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}" + return appname + + +def shard_from_appname(appname: str): + return int(appname.rsplit('_', maxsplit=1)[-1]) + + +shardname = appname_from_shard(sharding.shard_number) + +log_app.set(shardname) diff --git a/src/meta/args.py b/src/meta/args.py new file mode 100644 index 0000000..8d82a69 --- /dev/null +++ b/src/meta/args.py @@ -0,0 +1,35 @@ +import argparse + +from constants import CONFIG_FILE + +# ------------------------------ +# Parsed commandline arguments +# ------------------------------ +parser = argparse.ArgumentParser() +parser.add_argument( + '--conf', + dest='config', + default=CONFIG_FILE, + help="Path to configuration file." +) +parser.add_argument( + '--shard', + dest='shard', + default=None, + type=int, + help="Shard number to run, if applicable." +) +parser.add_argument( + '--host', + dest='host', + default='127.0.0.1', + help="IP address to run the app listener on." +) +parser.add_argument( + '--port', + dest='port', + default='5001', + help="Port to run the app listener on." +) + +args = parser.parse_args() diff --git a/src/meta/config.py b/src/meta/config.py new file mode 100644 index 0000000..9e624df --- /dev/null +++ b/src/meta/config.py @@ -0,0 +1,146 @@ +from discord import PartialEmoji +import configparser as cfgp + +from .args import args + +shard_number = args.shard + +class configEmoji(PartialEmoji): + __slots__ = ('fallback',) + + def __init__(self, *args, fallback=None, **kwargs): + super().__init__(*args, **kwargs) + self.fallback = fallback + + @classmethod + def from_str(cls, emojistr: str): + """ + Parses emoji strings of one of the following forms + ` or fallback` + `<:name:id> or fallback` + `` + `<:name:id>` + """ + splits = emojistr.rsplit(' or ', maxsplit=1) + + fallback = splits[1] if len(splits) > 1 else None + emojistr = splits[0].strip('<> ') + animated, name, id = emojistr.split(':') + return cls( + name=name, + fallback=PartialEmoji(name=fallback) if fallback is not None else None, + animated=bool(animated), + id=int(id) if id else None + ) + + +class MapDotProxy: + """ + Allows dot access to an underlying Mappable object. + """ + __slots__ = ("_map", "_converter") + + def __init__(self, mappable, converter=None): + self._map = mappable + self._converter = converter + + def __getattribute__(self, key): + _map = object.__getattribute__(self, '_map') + if key == '_map': + return _map + if key in _map: + _converter = object.__getattribute__(self, '_converter') + if _converter: + return _converter(_map[key]) + else: + return _map[key] + else: + return object.__getattribute__(_map, key) + + def __getitem__(self, key): + return self._map.__getitem__(key) + + +class ConfigParser(cfgp.ConfigParser): + """ + Extension of base ConfigParser allowing optional + section option retrieval without defaults. + """ + def options(self, section, no_defaults=False, **kwargs): + if no_defaults: + try: + return list(self._sections[section].keys()) + except KeyError: + raise cfgp.NoSectionError(section) + else: + return super().options(section, **kwargs) + + +class Conf: + def __init__(self, configfile, section_name="DEFAULT"): + self.configfile = configfile + + self.config = ConfigParser( + converters={ + "intlist": self._getintlist, + "list": self._getlist, + "emoji": configEmoji.from_str, + } + ) + + with open(configfile) as conff: + # Opening with read_file mainly to ensure the file exists + self.config.read_file(conff) + + self.section_name = section_name if section_name in self.config else 'DEFAULT' + + self.default = self.config["DEFAULT"] + self.section = MapDotProxy(self.config[self.section_name]) + self.bot = self.section + + # Config file recursion, read in configuration files specified in every "ALSO_READ" key. + more_to_read = self.section.getlist("ALSO_READ", []) + read = set() + while more_to_read: + to_read = more_to_read.pop(0) + read.add(to_read) + self.config.read(to_read) + new_paths = [path for path in self.section.getlist("ALSO_READ", []) + if path not in read and path not in more_to_read] + more_to_read.extend(new_paths) + + self.emojis = MapDotProxy( + self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section, + converter=configEmoji.from_str + ) + + global conf + conf = self + + def __getitem__(self, key): + return self.section[key].strip() + + def __getattr__(self, section): + name = section.upper() + shard_name = f"{name}-{shard_number}" + if shard_name in self.config: + return self.config[shard_name] + else: + return self.config[name] + + def get(self, name, fallback=None): + result = self.section.get(name, fallback) + return result.strip() if result else result + + def _getintlist(self, value): + return [int(item.strip()) for item in value.split(',')] + + def _getlist(self, value): + return [item.strip() for item in value.split(',')] + + def write(self): + with open(self.configfile, 'w') as conffile: + self.config.write(conffile) + + +conf = Conf(args.config, 'BOT') diff --git a/src/meta/context.py b/src/meta/context.py new file mode 100644 index 0000000..75f1df2 --- /dev/null +++ b/src/meta/context.py @@ -0,0 +1,20 @@ +""" +Namespace for various global context variables. +Allows asyncio callbacks to accurately retrieve information about the current state. +""" + + +from typing import TYPE_CHECKING, Optional + +from contextvars import ContextVar + +if TYPE_CHECKING: + from .LionBot import LionBot + from .LionContext import LionContext + + +# Contains the current command context, if applicable +context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None) + +# Contains the current LionBot instance +ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None) diff --git a/src/meta/errors.py b/src/meta/errors.py new file mode 100644 index 0000000..a5d6cbf --- /dev/null +++ b/src/meta/errors.py @@ -0,0 +1,64 @@ +from typing import Optional +from string import Template + + +class SafeCancellation(Exception): + """ + Raised to safely cancel execution of the current operation. + + If not caught, is expected to be propagated to the Tree and safely ignored there. + If a `msg` is provided, a context-aware error handler should catch and send the message to the user. + The error handler should then set the `msg` to None, to avoid double handling. + Debugging information should go in `details`, to be logged by a top-level error handler. + """ + default_message = "" + + @property + def msg(self): + return self._msg if self._msg is not None else self.default_message + + def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs): + self._msg: Optional[str] = _msg + self.details: str = details if details is not None else self.msg + super().__init__(**kwargs) + + +class UserInputError(SafeCancellation): + """ + A SafeCancellation induced from unparseable user input. + """ + default_message = "Could not understand your input." + + @property + def msg(self): + return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message + + def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs): + self.info = info + super().__init__(_msg, **kwargs) + + +class UserCancelled(SafeCancellation): + """ + A SafeCancellation induced from manual user cancellation. + + Usually silent. + """ + default_msg = None + + +class ResponseTimedOut(SafeCancellation): + """ + A SafeCancellation induced from a user interaction time-out. + """ + default_msg = "Session timed out waiting for input." + + +class HandledException(SafeCancellation): + """ + Sentinel class to indicate to error handlers that this exception has been handled. + Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler. + """ + def __init__(self, exc=None, **kwargs): + self.exc = exc + super().__init__(**kwargs) diff --git a/src/meta/logger.py b/src/meta/logger.py new file mode 100644 index 0000000..ffa97f7 --- /dev/null +++ b/src/meta/logger.py @@ -0,0 +1,468 @@ +import sys +import logging +import asyncio +from typing import List, Optional +from logging.handlers import QueueListener, QueueHandler +import queue +import multiprocessing +from contextlib import contextmanager +from io import StringIO +from functools import wraps +from contextvars import ContextVar + +import discord +from discord import Webhook, File +import aiohttp + +from .config import conf +from . import sharding +from .context import context +from utils.lib import utc_now +from utils.ratelimits import Bucket, BucketOverFull, BucketFull + + +log_logger = logging.getLogger(__name__) +log_logger.propagate = False + + +log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') +log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=()) +log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number)) + +def set_logging_context( + context: Optional[str] = None, + action: Optional[str] = None, + stack: Optional[tuple[str, ...]] = None +): + """ + Statically set the logging context variables to the given values. + + If `action` is given, pushes it onto the `log_action_stack`. + """ + if context is not None: + log_context.set(context) + if action is not None or stack is not None: + astack = log_action_stack.get() + newstack = stack if stack is not None else astack + if action is not None: + newstack = (*newstack, action) + log_action_stack.set(newstack) + + +@contextmanager +def logging_context(context=None, action=None, stack=None): + """ + Context manager for executing a block of code in a given logging context. + + This context manager should only be used around synchronous code. + This is because async code *may* get cancelled or externally garbage collected, + in which case the finally block will be executed in the wrong context. + See https://github.com/python/cpython/issues/93740 + This can be refactored nicely if this gets merged: + https://github.com/python/cpython/pull/99634 + + (It will not necessarily break on async code, + if the async code can be guaranteed to clean up in its own context.) + """ + if context is not None: + oldcontext = log_context.get() + log_context.set(context) + if action is not None or stack is not None: + astack = log_action_stack.get() + newstack = stack if stack is not None else astack + if action is not None: + newstack = (*newstack, action) + log_action_stack.set(newstack) + try: + yield + finally: + if context is not None: + log_context.set(oldcontext) + if stack is not None or action is not None: + log_action_stack.set(astack) + + +def with_log_ctx(isolate=True, **kwargs): + """ + Execute a coroutine inside a given logging context. + + If `isolate` is true, ensures that context does not leak + outside the coroutine. + + If `isolate` is false, just statically set the context, + which will leak unless the coroutine is + called in an externally copied context. + """ + def decorator(func): + @wraps(func) + async def wrapped(*w_args, **w_kwargs): + if isolate: + with logging_context(**kwargs): + # Task creation will synchronously copy the context + # This is gc safe + name = kwargs.get('action', f"log-wrapped-{func.__name__}") + task = asyncio.create_task(func(*w_args, **w_kwargs), name=name) + return await task + else: + # This will leak context changes + set_logging_context(**kwargs) + return await func(*w_args, **w_kwargs) + return wrapped + return decorator + + +# For backwards compatibility +log_wrap = with_log_ctx + + +def persist_task(task_collection: set): + """ + Coroutine decorator that ensures the coroutine is scheduled as a task + and added to the given task_collection for strong reference + when it is called. + + This is just a hack to handle discord.py events potentially + being unexpectedly garbage collected. + + Since this also implicitly schedules the coroutine as a task when it is called, + the coroutine will also be run inside an isolated context. + """ + def decorator(coro): + @wraps(coro) + async def wrapped(*w_args, **w_kwargs): + name = f"persisted-{coro.__name__}" + task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name) + task_collection.add(task) + task.add_done_callback(lambda f: task_collection.discard(f)) + await task + + +RESET_SEQ = "\033[0m" +COLOR_SEQ = "\033[3%dm" +BOLD_SEQ = "\033[1m" +"]]]" +BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) + + +def colour_escape(fmt: str) -> str: + cmap = { + '%(black)': COLOR_SEQ % BLACK, + '%(red)': COLOR_SEQ % RED, + '%(green)': COLOR_SEQ % GREEN, + '%(yellow)': COLOR_SEQ % YELLOW, + '%(blue)': COLOR_SEQ % BLUE, + '%(magenta)': COLOR_SEQ % MAGENTA, + '%(cyan)': COLOR_SEQ % CYAN, + '%(white)': COLOR_SEQ % WHITE, + '%(reset)': RESET_SEQ, + '%(bold)': BOLD_SEQ, + } + for key, value in cmap.items(): + fmt = fmt.replace(key, value) + return fmt + + +log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' + + '%(cyan)%(app)-15s%(reset)|' + + '%(cyan)%(context)-24s%(reset)|' + + '%(cyan)%(actionstr)-22s%(reset)|' + + ' %(bold)%(cyan)%(name)s:%(reset)' + + ' %(white)%(message)s%(ctxstr)s%(reset)') +log_format = colour_escape(log_format) + + +# Setup the logger +logger = logging.getLogger() +log_fmt = logging.Formatter( + fmt=log_format, + # datefmt='%Y-%m-%d %H:%M:%S' +) +logger.setLevel(logging.NOTSET) + + +class LessThanFilter(logging.Filter): + def __init__(self, exclusive_maximum, name=""): + super(LessThanFilter, self).__init__(name) + self.max_level = exclusive_maximum + + def filter(self, record): + # non-zero return means we log this message + return 1 if record.levelno < self.max_level else 0 + +class ExactLevelFilter(logging.Filter): + def __init__(self, target_level, name=""): + super().__init__(name) + self.target_level = target_level + + def filter(self, record): + return (record.levelno == self.target_level) + + +class ThreadFilter(logging.Filter): + def __init__(self, thread_name): + super().__init__("") + self.thread = thread_name + + def filter(self, record): + # non-zero return means we log this message + return 1 if record.threadName == self.thread else 0 + + +class ContextInjection(logging.Filter): + def filter(self, record): + # These guards are to allow override through _extra + # And to ensure the injection is idempotent + if not hasattr(record, 'context'): + record.context = log_context.get() + + if not hasattr(record, 'actionstr'): + action_stack = log_action_stack.get() + if hasattr(record, 'action'): + action_stack = (*action_stack, record.action) + if action_stack: + record.actionstr = ' ➔ '.join(action_stack) + else: + record.actionstr = "Unknown Action" + + if not hasattr(record, 'app'): + record.app = log_app.get() + + if not hasattr(record, 'ctx'): + if ctx := context.get(): + record.ctx = repr(ctx) + else: + record.ctx = None + + if getattr(record, 'with_ctx', False) and record.ctx: + record.ctxstr = '\n' + record.ctx + else: + record.ctxstr = "" + return True + + +logging_handler_out = logging.StreamHandler(sys.stdout) +logging_handler_out.setLevel(logging.DEBUG) +logging_handler_out.setFormatter(log_fmt) +logging_handler_out.addFilter(ContextInjection()) +logger.addHandler(logging_handler_out) +log_logger.addHandler(logging_handler_out) + +logging_handler_err = logging.StreamHandler(sys.stderr) +logging_handler_err.setLevel(logging.WARNING) +logging_handler_err.setFormatter(log_fmt) +logging_handler_err.addFilter(ContextInjection()) +logger.addHandler(logging_handler_err) +log_logger.addHandler(logging_handler_err) + + +class LocalQueueHandler(QueueHandler): + def _emit(self, record: logging.LogRecord) -> None: + # Removed the call to self.prepare(), handle task cancellation + try: + self.enqueue(record) + except asyncio.CancelledError: + raise + except Exception: + self.handleError(record) + + +class WebHookHandler(logging.StreamHandler): + def __init__(self, webhook_url, prefix="", batch=True, loop=None): + super().__init__() + self.webhook_url = webhook_url + self.prefix = prefix + self.batched = "" + self.batch = batch + self.loop = loop + self.batch_delay = 10 + self.batch_task = None + self.last_batched = None + self.waiting = [] + + self.bucket = Bucket(20, 40) + self.ignored = 0 + + self.session = None + self.webhook = None + + def get_loop(self): + if self.loop is None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + return self.loop + + def emit(self, record): + self.format(record) + self.get_loop().call_soon_threadsafe(self._post, record) + + def _post(self, record): + if self.session is None: + self.setup() + asyncio.create_task(self.post(record)) + + def setup(self): + self.session = aiohttp.ClientSession() + self.webhook = Webhook.from_url(self.webhook_url, session=self.session) + + async def post(self, record): + if record.context == 'Webhook Logger': + # Don't livelog livelog errors + # Otherwise we recurse and Cloudflare hates us + return + log_context.set("Webhook Logger") + log_action_stack.set(("Logging",)) + log_app.set(record.app) + + try: + timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S") + header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>" + context = f"\n# Context: {record.ctx}" if record.ctx else "" + message = f"{header}\n{record.msg}{context}" + + if len(message) > 1900: + as_file = True + else: + as_file = False + message = "```md\n{}\n```".format(message) + + # Post the log message(s) + if self.batch: + if len(message) > 1500: + await self._send_batched_now() + await self._send(message, as_file=as_file) + else: + self.batched += message + if len(self.batched) + len(message) > 1500: + await self._send_batched_now() + else: + asyncio.create_task(self._schedule_batched()) + else: + await self._send(message, as_file=as_file) + except Exception as ex: + print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr) + + async def _schedule_batched(self): + if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()): + # noop, don't reschedule if it is already scheduled + return + try: + self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay)) + await self.batch_task + await self._send_batched() + except asyncio.CancelledError: + return + except Exception as ex: + print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr) + + async def _send_batched_now(self): + if self.batch_task is not None and not self.batch_task.done(): + self.batch_task.cancel() + self.last_batched = None + await self._send_batched() + + async def _send_batched(self): + if self.batched: + batched = self.batched + self.batched = "" + await self._send(batched) + + async def _send(self, message, as_file=False): + try: + self.bucket.request() + except BucketOverFull: + # Silently ignore + self.ignored += 1 + return + except BucketFull: + logger.warning( + "Can't keep up! " + f"Ignoring records on live-logger {self.webhook.id}." + ) + self.ignored += 1 + return + else: + if self.ignored > 0: + logger.warning( + "Can't keep up! " + f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing." + ) + self.ignored = 0 + + try: + if as_file or len(message) > 1900: + with StringIO(message) as fp: + fp.seek(0) + await self.webhook.send( + f"{self.prefix}\n`{message.splitlines()[0]}`", + file=File(fp, filename="logs.md"), + username=log_app.get() + ) + else: + await self.webhook.send(self.prefix + '\n' + message, username=log_app.get()) + except discord.HTTPException: + logger.exception( + "Live logger errored. Slowing down live logger." + ) + self.bucket.fill() + + +handlers = [] +if webhook := conf.logging['general_log']: + handler = WebHookHandler(webhook, batch=True) + handlers.append(handler) + +if webhook := conf.logging['warning_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True) + handler.addFilter(ExactLevelFilter(logging.WARNING)) + handler.setLevel(logging.WARNING) + handlers.append(handler) + +if webhook := conf.logging['error_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True) + handler.setLevel(logging.ERROR) + handlers.append(handler) + +if webhook := conf.logging['critical_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False) + handler.setLevel(logging.CRITICAL) + handlers.append(handler) + + +def make_queue_handler(queue): + qhandler = QueueHandler(queue) + qhandler.setLevel(logging.INFO) + qhandler.addFilter(ContextInjection()) + return qhandler + + +def setup_main_logger(multiprocess=False): + q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue() + if handlers: + # First create a separate loop to run the handlers on + import threading + + def run_loop(loop): + asyncio.set_event_loop(loop) + try: + loop.run_forever() + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + loop = asyncio.new_event_loop() + loop_thread = threading.Thread(target=lambda: run_loop(loop)) + loop_thread.daemon = True + loop_thread.start() + + for handler in handlers: + handler.loop = loop + + qhandler = make_queue_handler(q) + # qhandler.addFilter(ThreadFilter('MainThread')) + logger.addHandler(qhandler) + + listener = QueueListener( + q, *handlers, respect_handler_level=True + ) + listener.start() + return q diff --git a/src/meta/monitor.py b/src/meta/monitor.py new file mode 100644 index 0000000..474c51f --- /dev/null +++ b/src/meta/monitor.py @@ -0,0 +1,139 @@ +import logging +import asyncio +from enum import IntEnum +from collections import deque, ChainMap +import datetime as dt + +logger = logging.getLogger(__name__) + + +class StatusLevel(IntEnum): + ERRORED = -2 + UNSURE = -1 + WAITING = 0 + STARTING = 1 + OKAY = 2 + + @property + def symbol(self): + return symbols[self] + + +symbols = { + StatusLevel.ERRORED: '🟥', + StatusLevel.UNSURE: '🟧', + StatusLevel.WAITING: '⬜', + StatusLevel.STARTING: '🟫', + StatusLevel.OKAY: '🟩', +} + + +class ComponentStatus: + def __init__(self, level: StatusLevel, short_formatstr: str, long_formatstr: str, data: dict = {}): + self.level = level + self.short_formatstr = short_formatstr + self.long_formatstr = long_formatstr + self.data = data + self.created_at = dt.datetime.now(tz=dt.timezone.utc) + + def format_args(self): + extra = { + 'created_at': self.created_at, + 'level': self.level, + 'symbol': self.level.symbol, + } + return ChainMap(extra, self.data) + + @property + def short(self): + return self.short_formatstr.format(**self.format_args()) + + @property + def long(self): + return self.long_formatstr.format(**self.format_args()) + + +class ComponentMonitor: + _name = None + + def __init__(self, name=None, callback=None): + self._callback = callback + self.name = name or self._name + if not self.name: + raise ValueError("ComponentMonitor must have a name") + + async def _make_status(self, *args, **kwargs): + if self._callback is not None: + return await self._callback(*args, **kwargs) + else: + raise NotImplementedError + + async def status(self) -> ComponentStatus: + try: + status = await self._make_status() + except Exception as e: + logger.exception( + f"Status callback for component '{self.name}' failed. This should not happen." + ) + status = ComponentStatus( + level=StatusLevel.UNSURE, + short_formatstr="Status callback for '{name}' failed with error '{error}'", + long_formatstr="Status callback for '{name}' failed with error '{error}'", + data={ + 'name': self.name, + 'error': repr(e) + } + ) + return status + + +class SystemMonitor: + def __init__(self): + self.components = {} + self.recent = deque(maxlen=10) + + def add_component(self, component: ComponentMonitor): + self.components[component.name] = component + return component + + async def request(self): + """ + Request status from each component. + """ + tasks = { + name: asyncio.create_task(comp.status()) + for name, comp in self.components.items() + } + await asyncio.gather(*tasks.values()) + status = { + name: await fut for name, fut in tasks.items() + } + self.recent.append(status) + return status + + async def _format_summary(self, status_dict: dict[str, ComponentStatus]): + """ + Format a one line summary from a status dict. + """ + freq = {level: 0 for level in StatusLevel} + for status in status_dict.values(): + freq[status.level] += 1 + + summary = '\t'.join(f"{level.symbol} {count}" for level, count in freq.items() if count) + return summary + + async def _format_overview(self, status_dict: dict[str, ComponentStatus]): + """ + Format an overview (one line per component) from a status dict. + """ + lines = [] + for name, status in status_dict.items(): + lines.append(f"{status.level.symbol} {name}: {status.short}") + summary = await self._format_summary(status_dict) + return '\n'.join((summary, *lines)) + + async def get_summary(self): + return await self._format_summary(await self.request()) + + async def get_overview(self): + return await self._format_overview(await self.request()) diff --git a/src/meta/sharding.py b/src/meta/sharding.py new file mode 100644 index 0000000..14da402 --- /dev/null +++ b/src/meta/sharding.py @@ -0,0 +1,35 @@ +from .args import args +from .config import conf + +from psycopg import sql +from data.conditions import Condition, Joiner + + +shard_number = args.shard or 0 + +shard_count = conf.bot.getint('shard_count', 1) + +sharded = (shard_count > 0) + + +def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition: + """ + Condition constructor for filtering by shard id. + + Example Usage + ------------- + Query.where(_shard_condition('guildid', 10, 1)) + """ + return Condition( + sql.SQL("({guildid} >> 22) %% {shard_count}").format( + guildid=sql.Identifier(guild_column), + shard_count=sql.Literal(shard_count) + ), + Joiner.EQUALS, + sql.Placeholder(), + (shard_id,) + ) + + +# Pre-built Condition for filtering by current shard. +THIS_SHARD = SHARDID(shard_number) diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..3f6887e --- /dev/null +++ b/src/modules/__init__.py @@ -0,0 +1,10 @@ +this_package = 'modules' + +active = [ + '.sysadmin', +] + + +async def setup(bot): + for ext in active: + await bot.load_extension(ext, package=this_package) diff --git a/src/modules/sysadmin/__init__.py b/src/modules/sysadmin/__init__.py new file mode 100644 index 0000000..4a96ce6 --- /dev/null +++ b/src/modules/sysadmin/__init__.py @@ -0,0 +1,5 @@ + +async def setup(bot): + from .exec_cog import Exec + + await bot.add_cog(Exec(bot)) diff --git a/src/modules/sysadmin/exec_cog.py b/src/modules/sysadmin/exec_cog.py new file mode 100644 index 0000000..776292a --- /dev/null +++ b/src/modules/sysadmin/exec_cog.py @@ -0,0 +1,336 @@ +import io +import ast +import sys +import types +import asyncio +import traceback +import builtins +import inspect +import logging +from io import StringIO + +from typing import Callable, Any, Optional + +from enum import Enum + +import discord +from discord.ext import commands +from discord.ext.commands.errors import CheckFailure +from discord.ui import TextInput, View +from discord.ui.button import button +import discord.app_commands as appcmd + +from meta.logger import logging_context, log_wrap +from meta import conf +from meta.context import context, ctx_bot +from meta.LionContext import LionContext +from meta.LionCog import LionCog +from meta.LionBot import LionBot + +from utils.ui import FastModal, input + +from wards import sys_admin + + +def _(arg): return arg + +def _p(ctx, arg): return arg + + +logger = logging.getLogger(__name__) + +class ExecModal(FastModal, title="Execute"): + code: TextInput = TextInput( + label="Code to execute", + style=discord.TextStyle.long, + required=True + ) + + +class ExecStyle(Enum): + EXEC = 'exec' + EVAL = 'eval' + + +class ExecUI(View): + def __init__(self, ctx, code=None, style=ExecStyle.EXEC, ephemeral=True) -> None: + super().__init__() + + self.ctx: LionContext = ctx + self.interaction: Optional[discord.Interaction] = ctx.interaction + self.code: Optional[str] = code + self.style: ExecStyle = style + self.ephemeral: bool = ephemeral + + self._modal: Optional[ExecModal] = None + self._msg: Optional[discord.Message] = None + + async def interaction_check(self, interaction: discord.Interaction): + """Only allow the original author to use this View""" + if interaction.user.id != self.ctx.author.id: + await interaction.response.send_message( + ("You cannot use this interface!"), + ephemeral=True + ) + return False + else: + return True + + async def run(self): + if self.code is None: + if (interaction := self.interaction) is not None: + self.interaction = None + await interaction.response.send_modal(self.get_modal()) + await self.wait() + else: + # Complain + # TODO: error_reply + await self.ctx.reply("Pls give code.") + else: + await self.interaction.response.defer(thinking=True, ephemeral=self.ephemeral) + await self.compile() + await self.wait() + + @button(label="Recompile") + async def recompile_button(self, interaction, butt): + # Interaction response with modal + await interaction.response.send_modal(self.get_modal()) + + @button(label="Show Source") + async def source_button(self, interaction, butt): + if len(self.code) > 1900: + # Send as file + with StringIO(self.code) as fp: + fp.seek(0) + file = discord.File(fp, filename="source.py") + await interaction.response.send_message(file=file, ephemeral=True) + else: + # Send as message + await interaction.response.send_message( + content=f"```py\n{self.code}```", + ephemeral=True + ) + + def create_modal(self) -> ExecModal: + modal = ExecModal() + + @modal.submit_callback() + async def exec_submit(interaction: discord.Interaction): + if self.interaction is None: + self.interaction = interaction + await interaction.response.defer(thinking=True) + else: + await interaction.response.defer() + + # Set code + self.code = modal.code.value + + # Call compile + await self.compile() + + return modal + + def get_modal(self): + self._modal = self.create_modal() + self._modal.code.default = self.code + return self._modal + + async def compile(self): + # Call _async + result = await _async(self.code, style=self.style.value) + + # Display output + await self.show_output(result) + + async def show_output(self, output): + # Format output + # If output message exists and not ephemeral, edit + # Otherwise, send message, add buttons + if len(output) > 1900: + # Send as file + with StringIO(output) as fp: + fp.seek(0) + args = { + 'content': None, + 'attachments': [discord.File(fp, filename="output.md")] + } + else: + args = { + 'content': f"```md\n{output}```", + 'attachments': [] + } + + if self._msg is None: + if self.interaction is not None: + msg = await self.interaction.edit_original_response(**args, view=self) + else: + # Send new message + if args['content'] is None: + args['file'] = args.pop('attachments')[0] + msg = await self.ctx.reply(**args, ephemeral=self.ephemeral, view=self) + + if not self.ephemeral: + self._msg = msg + else: + if self.interaction is not None: + await self.interaction.edit_original_response(**args, view=self) + else: + # Edit message + await self._msg.edit(**args) + + +def mk_print(fp: io.StringIO) -> Callable[..., None]: + def _print(*args, file: Any = fp, **kwargs): + return print(*args, file=file, **kwargs) + return _print + + +def mk_status_printer(bot, printer): + async def _status(details=False): + if details: + status = await bot.system_monitor.get_overview() + else: + status = await bot.system_monitor.get_summary() + printer(status) + return status + return _status + + +@log_wrap(action="Code Exec") +async def _async(to_eval: str, style='exec'): + newline = '\n' * ('\n' in to_eval) + logger.info( + f"Exec code with {style}: {newline}{to_eval}" + ) + + output = io.StringIO() + _print = mk_print(output) + + scope: dict[str, Any] = dict(sys.modules) + scope['__builtins__'] = builtins + scope.update(builtins.__dict__) + scope['ctx'] = ctx = context.get() + scope['bot'] = ctx_bot.get() + scope['print'] = _print # type: ignore + scope['print_status'] = mk_status_printer(scope['bot'], _print) + + try: + if ctx and ctx.message: + source_str = f"" + elif ctx and ctx.interaction: + source_str = f"" + else: + source_str = "Unknown async" + + code = compile( + to_eval, + source_str, + style, + ast.PyCF_ALLOW_TOP_LEVEL_AWAIT + ) + func = types.FunctionType(code, scope) + + ret = func() + if inspect.iscoroutine(ret): + ret = await ret + if ret is not None: + _print(repr(ret)) + except Exception: + _, exc, tb = sys.exc_info() + _print("".join(traceback.format_tb(tb))) + _print(f"{type(exc).__name__}: {exc}") + + result = output.getvalue().strip() + newline = '\n' * ('\n' in result) + logger.info( + f"Exec complete, output: {newline}{result}" + ) + return result + + +class Exec(LionCog): + guild_ids = conf.bot.getintlist('admin_guilds') + + def __init__(self, bot: LionBot): + self.bot = bot + + async def cog_check(self, ctx: LionContext) -> bool: # type: ignore + passed = await sys_admin(ctx.bot, ctx.author.id) + if passed: + return True + else: + raise CheckFailure( + "You must be a bot owner to do this!" + ) + + @commands.hybrid_command( + name=_('async'), + description=_("Execute arbitrary code with Exec") + ) + @appcmd.describe( + string="Code to execute.", + ) + async def async_cmd(self, ctx: LionContext, + string: Optional[str] = None, + ): + await ExecUI(ctx, string, ExecStyle.EXEC, ephemeral=False).run() + + @commands.hybrid_command( + name=_('reload'), + description=_("Reload a given LionBot extension. Launches an ExecUI.") + ) + @appcmd.describe( + extension=_("Name of the extension to reload. See autocomplete for options."), + force=_("Whether to force an extension reload even if it doesn't exist.") + ) + @appcmd.guilds(*guild_ids) + async def reload_cmd(self, ctx: LionContext, extension: str, force: Optional[bool] = False): + """ + This is essentially just a friendly wrapper to reload an extension. + It is equivalent to running "await bot.reload_extension(extension)" in eval, + with a slightly nicer interface through the autocomplete and error handling. + """ + exists = (extension in self.bot.extensions) + if not (force or exists): + embed = discord.Embed(description=f"Unknown extension {extension}", colour=discord.Colour.red()) + await ctx.reply(embed=embed) + else: + # Uses an ExecUI to simplify error handling and re-execution + if exists: + string = f"await bot.reload_extension('{extension}')" + else: + string = f"await bot.load_extension('{extension}')" + await ExecUI(ctx, string, ExecStyle.EVAL).run() + + @reload_cmd.autocomplete('extension') + async def reload_extension_acmpl(self, interaction: discord.Interaction, partial: str): + keys = set(self.bot.extensions.keys()) + results = [ + appcmd.Choice(name=key, value=key) + for key in keys + if partial.lower() in key.lower() + ] + if not results: + results = [ + appcmd.Choice(name=f"No extensions found matching {partial}", value="None") + ] + return results[:25] + + @commands.hybrid_command( + name=_('shutdown'), + description=_("Shutdown (or restart) the client.") + ) + @appcmd.guilds(*guild_ids) + async def shutdown_cmd(self, ctx: LionContext): + """ + Shutdown the client. + Maybe do something friendly here? + """ + logger.info("Shutting down on admin request.") + await ctx.reply( + embed=discord.Embed( + description=f"Understood {ctx.author.mention}, cleaning up and shutting down!", + colour=discord.Colour.orange() + ) + ) + await self.bot.close() diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/ansi.py b/src/utils/ansi.py new file mode 100644 index 0000000..11f2852 --- /dev/null +++ b/src/utils/ansi.py @@ -0,0 +1,97 @@ +""" +Minimal library for making Discord Ansi colour codes. +""" +from enum import StrEnum + + +PREFIX = u'\u001b' + + +class TextColour(StrEnum): + Gray = '30' + Red = '31' + Green = '32' + Yellow = '33' + Blue = '34' + Pink = '35' + Cyan = '36' + White = '37' + + def __str__(self) -> str: + return AnsiColour(fg=self).as_str() + + def __call__(self): + return AnsiColour(fg=self) + + +class BgColour(StrEnum): + FireflyDarkBlue = '40' + Orange = '41' + MarbleBlue = '42' + GrayTurq = '43' + Gray = '44' + Indigo = '45' + LightGray = '46' + White = '47' + + def __str__(self) -> str: + return AnsiColour(bg=self).as_str() + + def __call__(self): + return AnsiColour(bg=self) + + +class Format(StrEnum): + NORMAL = '0' + BOLD = '1' + UNDERLINE = '4' + NOOP = '9' + + def __str__(self) -> str: + return AnsiColour(self).as_str() + + def __call__(self): + return AnsiColour(self) + + +class AnsiColour: + def __init__(self, *flags, fg=None, bg=None): + self.text_colour = fg + self.background_colour = bg + self.reset = (Format.NORMAL in flags) + self._flags = set(flags) + self._flags.discard(Format.NORMAL) + + @property + def flags(self): + return (*((Format.NORMAL,) if self.reset else ()), *self._flags) + + def as_str(self): + parts = [] + if self.reset: + parts.append(Format.NORMAL) + elif not self.flags: + parts.append(Format.NOOP) + + parts.extend(self._flags) + + for c in (self.text_colour, self.background_colour): + if c is not None: + parts.append(c) + + partstr = ';'.join(part.value for part in parts) + return f"{PREFIX}[{partstr}m" # ] + + def __str__(self): + return self.as_str() + + def __add__(self, obj: 'AnsiColour'): + text_colour = obj.text_colour or self.text_colour + background_colour = obj.background_colour or self.background_colour + flags = (*self.flags, *obj.flags) + return AnsiColour(*flags, fg=text_colour, bg=background_colour) + + +RESET = AnsiColour(Format.NORMAL) +BOLD = AnsiColour(Format.BOLD) +UNDERLINE = AnsiColour(Format.UNDERLINE) diff --git a/src/utils/data.py b/src/utils/data.py new file mode 100644 index 0000000..a590430 --- /dev/null +++ b/src/utils/data.py @@ -0,0 +1,165 @@ +""" +Some useful pre-built Conditions for data queries. +""" +from typing import Optional, Any +from itertools import chain + +from psycopg import sql +from data.conditions import Condition, Joiner +from data.columns import ColumnExpr +from data.base import Expression +from constants import MAX_COINS + + +def MULTIVALUE_IN(columns: tuple[str, ...], *data: tuple[Any, ...]) -> Condition: + """ + Condition constructor for filtering by multiple column equalities. + + Example Usage + ------------- + Query.where(MULTIVALUE_IN(('guildid', 'userid'), (1, 2), (3, 4))) + """ + if not data: + raise ValueError("Cannot create empty multivalue condition.") + left = sql.SQL("({})").format( + sql.SQL(', ').join( + sql.Identifier(key) + for key in columns + ) + ) + right_item = sql.SQL('({})').format( + sql.SQL(', ').join( + sql.Placeholder() + for _ in columns + ) + ) + right = sql.SQL("({})").format( + sql.SQL(', ').join( + right_item + for _ in data + ) + ) + return Condition( + left, + Joiner.IN, + right, + chain(*data) + ) + + +def MEMBERS(*memberids: tuple[int, int], guild_column='guildid', user_column='userid') -> Condition: + """ + Condition constructor for filtering member tables by guild and user id simultaneously. + + Example Usage + ------------- + Query.where(MEMBERS((1234,12), (5678,34))) + """ + if not memberids: + raise ValueError("Cannot create a condition with no members") + return Condition( + sql.SQL("({guildid}, {userid})").format( + guildid=sql.Identifier(guild_column), + userid=sql.Identifier(user_column) + ), + Joiner.IN, + sql.SQL("({})").format( + sql.SQL(', ').join( + sql.SQL("({}, {})").format( + sql.Placeholder(), + sql.Placeholder() + ) for _ in memberids + ) + ), + chain(*memberids) + ) + + +def as_duration(expr: Expression) -> ColumnExpr: + """ + Convert an integer expression into a duration expression. + """ + expr_expr, expr_values = expr.as_tuple() + return ColumnExpr( + sql.SQL("({} * interval '1 second')").format(expr_expr), + expr_values + ) + + +class TemporaryTable(Expression): + """ + Create a temporary table expression to be used in From or With clauses. + + Example + ------- + ``` + tmp_table = TemporaryTable('_col1', '_col2', name='data') + tmp_table.values((1, 2), (3, 4)) + + real_table.update_where(col1=tmp_table['_col1']).set(col2=tmp_table['_col2']).from_(tmp_table) + ``` + """ + + def __init__(self, *columns: str, name: str = '_t', types: Optional[tuple[str, ...]] = None): + self.name = name + self.columns = columns + self.types = types + if types and len(types) != len(columns): + raise ValueError("Number of types does not much number of columns!") + + self._table_columns = { + col: ColumnExpr(sql.Identifier(name, col)) + for col in columns + } + + self.values = [] + + def __getitem__(self, key) -> sql.Identifier: + return self._table_columns[key] + + def as_tuple(self): + """ + (VALUES {}) + AS + name (col1, col2) + """ + if not self.values: + raise ValueError("Cannot flatten CTE with no values.") + + single_value = sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() for _ in self.columns)) + if self.types: + first_value = sql.SQL("({})").format( + sql.SQL(", ").join( + sql.SQL("{}::{}").format(sql.Placeholder(), sql.SQL(cast)) + for cast in self.types + ) + ) + else: + first_value = single_value + + value_placeholder = sql.SQL("(VALUES {})").format( + sql.SQL(", ").join( + (first_value, *(single_value for _ in self.values[1:])) + ) + ) + expr = sql.SQL("{values} AS {name} ({columns})").format( + values=value_placeholder, + name=sql.Identifier(self.name), + columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.columns) + ) + values = chain(*self.values) + return (expr, values) + + def set_values(self, *data): + self.values = data + + +def SAFECOINS(expr: Expression) -> Expression: + expr_expr, expr_values = expr.as_tuple() + return ColumnExpr( + sql.SQL("LEAST({}, {})").format( + expr_expr, + sql.Literal(MAX_COINS) + ), + expr_values + ) diff --git a/src/utils/lib.py b/src/utils/lib.py new file mode 100644 index 0000000..30657ed --- /dev/null +++ b/src/utils/lib.py @@ -0,0 +1,847 @@ +from io import StringIO +from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any +import collections +import datetime +import datetime as dt +import iso8601 # type: ignore +import pytz +import re +import json +from contextvars import Context + +import discord +from discord.partial_emoji import _EmojiTag +from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage +from discord.ui import View + +from meta.errors import UserInputError + + +multiselect_regex = re.compile( + r"^([0-9, -]+)$", + re.DOTALL | re.IGNORECASE | re.VERBOSE +) +tick = '✅' +cross = '❌' + +MISSING = object() + + +class MessageArgs: + """ + Utility class for storing message creation and editing arguments. + """ + # TODO: Overrides for mutually exclusive arguments, see Messageable.send + + @overload + def __init__( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ) -> None: + ... + + @overload + def __init__( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ) -> None: + ... + + @overload + def __init__( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ) -> None: + ... + + @overload + def __init__( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ) -> None: + ... + + def __init__(self, **kwargs): + self.kwargs = kwargs + + @property + def send_args(self) -> dict: + if self.kwargs.get('view', MISSING) is None: + kwargs = self.kwargs.copy() + kwargs.pop('view') + else: + kwargs = self.kwargs + + return kwargs + + @property + def edit_args(self) -> dict: + args = {} + kept = ( + 'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view' + ) + for k in kept: + if k in self.kwargs: + args[k] = self.kwargs[k] + + if 'file' in self.kwargs: + args['attachments'] = [self.kwargs['file']] + + if 'files' in self.kwargs: + args['attachments'] = self.kwargs['files'] + + if 'suppress_embeds' in self.kwargs: + args['suppress'] = self.kwargs['suppress_embeds'] + + return args + + +def tabulate( + *fields: tuple[str, str], + row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}", + sub_format: str = "`{invis:<{pad}}{colon}`\t{value}", + colon: str = ':', + invis: str = "​", + **args +) -> list[str]: + """ + Turns a list of (property, value) pairs into + a pretty string with one `prop: value` pair each line, + padded so that the colons in each line are lined up. + Use `\\r\\n` in a value to break the line with padding. + + Parameters + ---------- + fields: List[tuple[str, str]] + List of (key, value) pairs. + row_format: str + The format string used to format each row. + sub_format: str + The format string used to format each subline in a row. + colon: str + The colon character used. + invis: str + The invisible character used (to avoid Discord stripping the string). + + Returns: List[str] + The list of resulting table rows. + Each row corresponds to one (key, value) pair from fields. + """ + max_len = max(len(field[0]) for field in fields) + + rows = [] + for field in fields: + key = field[0] + value = field[1] + lines = value.split('\r\n') + + row_line = row_format.format( + invis=invis, + key=key, + pad=max_len, + colon=colon, + value=lines[0], + field=field, + **args + ) + if len(lines) > 1: + row_lines = [row_line] + for line in lines[1:]: + sub_line = sub_format.format( + invis=invis, + pad=max_len + len(colon), + colon=colon, + value=line, + **args + ) + row_lines.append(sub_line) + row_line = '\n'.join(row_lines) + rows.append(row_line) + return rows + + +def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]: + """ + Create pretty codeblock pages from a list of strings. + + Parameters + ---------- + item_list: List[str] + List of strings to paginate. + block_length: int + Maximum number of strings per page. + style: str + Codeblock style to use. + Title formatting assumes the `markdown` style, and numbered lists work well with this. + However, `markdown` sometimes messes up formatting in the list. + title: str + Optional title to add to the top of each page. + + Returns: List[str] + List of pages, each formatted into a codeblock, + and containing at most `block_length` of the provided strings. + """ + lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)] + page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)] + pages = [] + for i, block in enumerate(page_blocks): + pagenum = "Page {}/{}".format(i + 1, len(page_blocks)) + if title: + header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title + else: + header = pagenum + header_line = "=" * len(header) + full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else "" + pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block))) + return pages + + +def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]: + """ + Break the text into blocks of maximum length blocksize + If possible, break across nearby newlines. Otherwise just break at blocksize chars + + Parameters + ---------- + text: str + Text to break into blocks. + blocksize: int + Maximum character length for each block. + code: bool + Whether to wrap each block in codeblocks (these are counted in the blocksize). + syntax: str + The markdown formatting language to use for the codeblocks, if applicable. + maxheight: int + The maximum number of lines in each block + + Returns: List[str] + List of blocks, + each containing at most `block_size` characters, + of height at most `maxheight`. + """ + # Adjust blocksize to account for the codeblocks if required + blocksize = blocksize - 8 - len(syntax) if code else blocksize + + # Build the blocks + blocks = [] + while True: + # If the remaining text is already small enough, append it + if len(text) <= blocksize: + blocks.append(text) + break + text = text.strip('\n') + + # Find the last newline in the prototype block + split_on = text[0:blocksize].rfind('\n') + split_on = blocksize if split_on < blocksize // 5 else split_on + + # Add the block and truncate the text + blocks.append(text[0:split_on]) + text = text[split_on:] + + # Add the codeblock ticks and the code syntax header, if required + if code: + blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks] + + return blocks + + +def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str: + """ + Convert a datetime.timedelta object into an easily readable duration string. + + Parameters + ---------- + delta: datetime.timedelta + The timedelta object to convert into a readable string. + sec: bool + Whether to include the seconds from the timedelta object in the string. + minutes: bool + Whether to include the minutes from the timedelta object in the string. + short: bool + Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s"). + + Returns: str + A string containing a time from the datetime.timedelta object, in a readable format. + Time units will be abbreviated if short was set to True. + """ + output = [[delta.days, 'd' if short else ' day'], + [delta.seconds // 3600, 'h' if short else ' hour']] + if minutes: + output.append([delta.seconds // 60 % 60, 'm' if short else ' minute']) + if sec: + output.append([delta.seconds % 60, 's' if short else ' second']) + for i in range(len(output)): + if output[i][0] != 1 and not short: + output[i][1] += 's' # type: ignore + reply_msg = [] + if output[0][0] != 0: + reply_msg.append("{}{} ".format(output[0][0], output[0][1])) + if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2: + reply_msg.append("{}{} ".format(output[1][0], output[1][1])) + for i in range(2, len(output) - 1): + reply_msg.append("{}{} ".format(output[i][0], output[i][1])) + if not short and reply_msg: + reply_msg.append("and ") + reply_msg.append("{}{}".format(output[-1][0], output[-1][1])) + return "".join(reply_msg) + + +def _parse_dur(time_str: str) -> int: + """ + Parses a user provided time duration string into a timedelta object. + + Parameters + ---------- + time_str: str + The time string to parse. String can include days, hours, minutes, and seconds. + + Returns: int + The number of seconds the duration represents. + """ + funcs = {'d': lambda x: x * 24 * 60 * 60, + 'h': lambda x: x * 60 * 60, + 'm': lambda x: x * 60, + 's': lambda x: x} + time_str = time_str.strip(" ,") + found = re.findall(r'(\d+)\s?(\w+?)', time_str) + seconds = 0 + for bit in found: + if bit[1] in funcs: + seconds += funcs[bit[1]](int(bit[0])) + return seconds + + +def strfdur(duration: int, short=True, show_days=False) -> str: + """ + Convert a duration given in seconds to a number of hours, minutes, and seconds. + """ + days = duration // (3600 * 24) if show_days else 0 + hours = duration // 3600 + if days: + hours %= 24 + minutes = duration // 60 % 60 + seconds = duration % 60 + + parts = [] + if days: + unit = 'd' if short else (' days' if days != 1 else ' day') + parts.append('{}{}'.format(days, unit)) + if hours: + unit = 'h' if short else (' hours' if hours != 1 else ' hour') + parts.append('{}{}'.format(hours, unit)) + if minutes: + unit = 'm' if short else (' minutes' if minutes != 1 else ' minute') + parts.append('{}{}'.format(minutes, unit)) + if seconds or duration == 0: + unit = 's' if short else (' seconds' if seconds != 1 else ' second') + parts.append('{}{}'.format(seconds, unit)) + + if short: + return ' '.join(parts) + else: + return ', '.join(parts) + + +def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str: + """ + Substitutes a user provided list of numbers and ranges, + and replaces the ranges by the corresponding list of numbers. + + Parameters + ---------- + ranges_str: str + The string to ranges in. + max_match: int + The maximum number of ranges to replace. + Any ranges exceeding this will be ignored. + max_range: int + The maximum length of range to replace. + Attempting to replace a range longer than this will raise a `ValueError`. + """ + def _repl(match): + n1 = int(match.group(1)) + n2 = int(match.group(2)) + if n2 - n1 > max_range: + # TODO: Upgrade to SafeCancellation + raise ValueError("Provided range is too large!") + return separator.join(str(i) for i in range(n1, n2 + 1)) + + return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match) + + +def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]: + """ + Parses a user provided range string into a list of numbers. + Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`. + """ + substituted = substitute_ranges(ranges_str, separator=separator, **kwargs) + _numbers = (item.strip() for item in substituted.split(',')) + numbers = [item for item in _numbers if item] + integers = [int(item) for item in numbers if item.isdigit()] + + if not ignore_errors and len(integers) != len(numbers): + # TODO: Upgrade to SafeCancellation + raise ValueError( + "Couldn't parse the provided selection!\n" + "Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`." + ) + + return integers + + +def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str: + """ + Format a message into a string with various information, such as: + the timestamp of the message, author, message content, and attachments. + + Parameters + ---------- + msg: Message + The message to format. + mask_link: bool + Whether to mask the URLs of any attachments. + line_break: bool + Whether a line break should be used in the string. + tz: Timezone + The timezone to use in the formatted message. + clean: bool + Whether to use the clean content of the original message. + + Returns: str + A formatted string containing various information: + User timezone, message author, message content, attachments + """ + timestr = "%I:%M %p, %d/%m/%Y" + if tz: + time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr) + else: + time = msg.created_at.strftime(timestr) + user = str(msg.author) + attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url] + if mask_link: + attach_list = ["[Link]({})".format(url) for url in attach_list] + attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else "" + return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format( + time=time, + user=user, + line_break="\n" if line_break else "", + message=msg.clean_content if clean else msg.content, + attachments=attachments + ) + + +def convdatestring(datestring: str) -> datetime.timedelta: + """ + Convert a date string into a datetime.timedelta object. + + Parameters + ---------- + datestring: str + The string to convert to a datetime.timedelta object. + + Returns: datetime.timedelta + A datetime.timedelta object formed from the string provided. + """ + datestring = datestring.strip(' ,') + datearray = [] + funcs = {'d': lambda x: x * 24 * 60 * 60, + 'h': lambda x: x * 60 * 60, + 'm': lambda x: x * 60, + 's': lambda x: x} + currentnumber = '' + for char in datestring: + if char.isdigit(): + currentnumber += char + else: + if currentnumber == '': + continue + datearray.append((int(currentnumber), char)) + currentnumber = '' + seconds = 0 + if currentnumber: + seconds += int(currentnumber) + for i in datearray: + if i[1] in funcs: + seconds += funcs[i[1]](i[0]) + return datetime.timedelta(seconds=seconds) + + +class _rawChannel(discord.abc.Messageable): + """ + Raw messageable class representing an arbitrary channel, + not necessarially seen by the gateway. + """ + def __init__(self, state, id): + self._state = state + self.id = id + + async def _get_channel(self): + return discord.Object(self.id) + + +async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message: + """ + Mails a message to a channelid which may be invisible to the gateway. + + Parameters: + client: discord.Client + The client to use for mailing. + Must at least have static authentication and have a valid `_connection`. + channelid: int + The channel id to mail to. + msg_args: Any + Message keyword arguments which are passed transparently to `_rawChannel.send(...)`. + """ + # Create the raw channel + channel = _rawChannel(client._connection, channelid) + return await channel.send(**msg_args) + + +class EmbedField(NamedTuple): + name: str + value: str + inline: Optional[bool] = True + + +def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]): + """ + Append embed fields to an embed. + Parameters + ---------- + embed: discord.Embed + The embed to add the field to. + emb_fields: tuple + The values to add to a field. + name: str + The name of the field. + value: str + The value of the field. + inline: bool + Whether the embed field should be inline or not. + """ + for field in emb_fields: + embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2])) + + +def join_list(string: list[str], nfs=False) -> str: + """ + Join a list together, separated with commas, plus add "and" to the beginning of the last value. + Parameters + ---------- + string: list + The list to join together. + nfs: bool + (no fullstops) + Whether to exclude fullstops/periods from the output messages. + If not provided, fullstops will be appended to the output. + """ + # TODO: Probably not useful with localisation + if len(string) > 1: + return "{}{} and {}{}".format((", ").join(string[:-1]), + "," if len(string) > 2 else "", string[-1], "" if nfs else ".") + else: + return "{}{}".format("".join(string), "" if nfs else ".") + + +def shard_of(shard_count: int, guildid: int) -> int: + """ + Calculate the shard number of a given guild. + """ + return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0 + + +def jumpto(guildid: int, channeldid: int, messageid: int) -> str: + """ + Build a jump link for a message given its location. + """ + return 'https://discord.com/channels/{}/{}/{}'.format( + guildid, + channeldid, + messageid + ) + + +def utc_now() -> datetime.datetime: + """ + Return the current timezone-aware utc timestamp. + """ + return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) + + +def multiple_replace(string: str, rep_dict: dict[str, str]) -> str: + if rep_dict: + pattern = re.compile( + "|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]), + flags=re.DOTALL + ) + return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string) + else: + return string + + +def recover_context(context: Context): + for var in context: + var.set(context[var]) + + +def parse_ids(idstr: str) -> List[int]: + """ + Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids. + + Object agnostic, so all mention tokens are stripped. + Raises UserInputError if an id is invalid, + setting `orig` and `item` info fields. + """ + from meta.errors import UserInputError + + # Extract ids from string + splititer = (split.strip('<@!#&>, ') for split in idstr.split(',')) + splits = [split for split in splititer if split] + + # Check they are integers + if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None: + raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id}) + + # Cast to integer and return + return list(map(int, splits)) + + +def error_embed(error, **kwargs) -> discord.Embed: + embed = discord.Embed( + colour=discord.Colour.brand_red(), + description=error, + timestamp=utc_now() + ) + return embed + + +class DotDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +class Timezoned: + """ + ABC mixin for objects with a set timezone. + + Provides several useful localised properties. + """ + __slots__ = () + + @property + def timezone(self) -> pytz.timezone: + """ + Must be implemented by the deriving class! + """ + raise NotImplementedError + + @property + def now(self): + """ + Return the current time localised to the object's timezone. + """ + return datetime.datetime.now(tz=self.timezone) + + @property + def today(self): + """ + Return the start of the day localised to the object's timezone. + """ + now = self.now + return now.replace(hour=0, minute=0, second=0, microsecond=0) + + @property + def week_start(self): + """ + Return the start of the week in the object's timezone + """ + today = self.today + return today - datetime.timedelta(days=today.weekday()) + + @property + def month_start(self): + """ + Return the start of the current month in the object's timezone + """ + today = self.today + return today.replace(day=1) + + +def replace_multiple(format_string, mapping): + """ + Subsistutes the keys from the format_dict with their corresponding values. + + Substitution is non-chained, and done in a single pass via regex. + """ + if not mapping: + raise ValueError("Empty mapping passed.") + + keys = list(mapping.keys()) + pattern = '|'.join(f"({key})" for key in keys) + string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string) + return string + + +def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str): + """ + Produces a distinguishing key for an Emoji or PartialEmoji. + + Equality checks using this key should act as expected. + """ + if isinstance(emoji, _EmojiTag): + if emoji.id: + key = str(emoji.id) + else: + key = str(emoji.name) + else: + key = str(emoji) + + return key + +def recurse_map(func, obj, loc=[]): + if isinstance(obj, dict): + for k, v in obj.items(): + loc.append(k) + obj[k] = recurse_map(func, v, loc) + loc.pop() + elif isinstance(obj, list): + for i, item in enumerate(obj): + loc.append(i) + obj[i] = recurse_map(func, item) + loc.pop() + else: + obj = func(loc, obj) + return obj + +async def check_dm(user: discord.User | discord.Member) -> bool: + """ + Check whether we can direct message the given user. + + Assumes the client is initialised. + This uses an always-failing HTTP request, + so we need to be very very very careful that this is not used frequently. + Optimally only at the explicit behest of the user + (i.e. during a user instigated interaction). + """ + try: + await user.send('') + except discord.Forbidden: + return False + except discord.HTTPException: + return True + + +async def command_lengths(tree) -> dict[str, int]: + cmds = tree.get_commands() + payloads = [ + await cmd.get_translated_payload(tree.translator) + for cmd in cmds + ] + lens = {} + for command in payloads: + name = command['name'] + crumbs = {} + cmd_len = lens[name] = _recurse_length(command, crumbs, (name,)) + if name == 'configure' or cmd_len > 4000: + print(f"'{name}' over 4000. Breadcrumb Trail follows:") + lines = [] + for loc, val in crumbs.items(): + locstr = '.'.join(loc) + lines.append(f"{locstr}: {val}") + print('\n'.join(lines)) + print(json.dumps(command, indent=2)) + return lens + +def _recurse_length(payload, breadcrumbs={}, header=()) -> int: + total = 0 + total_header = (*header, '') + breadcrumbs[total_header] = 0 + + if isinstance(payload, dict): + # Read strings that count towards command length + # String length is length of longest localisation, including default. + for key in ('name', 'description', 'value'): + if key in payload: + value = payload[key] + if isinstance(value, str): + values = (value, *payload.get(key + '_localizations', {}).values()) + maxlen = max(map(len, values)) + total += maxlen + breadcrumbs[(*header, key)] = maxlen + + for key, value in payload.items(): + loc = (*header, key) + total += _recurse_length(value, breadcrumbs, loc) + elif isinstance(payload, list): + for i, item in enumerate(payload): + if isinstance(item, dict) and 'name' in item: + loc = (*header, f"{i}<{item['name']}>") + else: + loc = (*header, str(i)) + total += _recurse_length(item, breadcrumbs, loc) + + if total: + breadcrumbs[total_header] = total + else: + breadcrumbs.pop(total_header) + + return total + +def write_records(records: list[dict[str, Any]], stream: StringIO): + if records: + keys = records[0].keys() + stream.write(','.join(keys)) + stream.write('\n') + for record in records: + stream.write(','.join(map(str, record.values()))) + stream.write('\n') diff --git a/src/utils/monitor.py b/src/utils/monitor.py new file mode 100644 index 0000000..96aedeb --- /dev/null +++ b/src/utils/monitor.py @@ -0,0 +1,191 @@ +import asyncio +import bisect +import logging +from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any + +from .lib import utc_now +from .ratelimits import Bucket + + +logger = logging.getLogger(__name__) + +Taskid = TypeVar('Taskid') + + +class TaskMonitor(Generic[Taskid]): + """ + Base class for a task monitor. + + Stores tasks as a time-sorted list of taskids. + Subclasses may override `run_task` to implement an executor. + + Adding or removing a single task has O(n) performance. + To bulk update tasks, instead use `schedule_tasks`. + + Each taskid must be unique and hashable. + """ + + def __init__(self, executor=None, bucket: Optional[Bucket] = None): + # Ratelimit bucket to enforce maximum execution rate + self._bucket = bucket + + self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor + + self._wakeup: asyncio.Event = asyncio.Event() + self._monitor_task: Optional[asyncio.Task] = None + + # Task data + self._tasklist: list[Taskid] = [] + self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp + + # Running map ensures we keep a reference to the running task + # And allows simpler external cancellation if required + self._running: dict[Taskid, asyncio.Future] = {} + + def __repr__(self): + return ( + "<" + f"{self.__class__.__name__}" + f" tasklist={len(self._tasklist)}" + f" taskmap={len(self._taskmap)}" + f" wakeup={self._wakeup.is_set()}" + f" bucket={self._bucket}" + f" running={len(self._running)}" + f" task={self._monitor_task}" + f">" + ) + + def set_tasks(self, *tasks: tuple[Taskid, int]) -> None: + """ + Similar to `schedule_tasks`, but wipe and reset the tasklist. + """ + self._taskmap = {tid: time for tid, time in tasks} + self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])) + self._wakeup.set() + + def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None: + """ + Schedule the given tasks. + + Rather than repeatedly inserting tasks, + where the O(log n) insort is dominated by the O(n) list insertion, + we build an entirely new list, and always wake up the loop. + """ + self._taskmap |= {tid: time for tid, time in tasks} + self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])) + self._wakeup.set() + + def schedule_task(self, taskid: Taskid, timestamp: int) -> None: + """ + Insert the provided task into the tasklist. + If the new task has a lower timestamp than the next task, wakes up the monitor loop. + """ + if self._tasklist: + nextid = self._tasklist[-1] + wake = self._taskmap[nextid] >= timestamp + wake = wake or taskid == nextid + else: + wake = True + if taskid in self._taskmap: + self._tasklist.remove(taskid) + self._taskmap[taskid] = timestamp + bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t]) + if wake: + self._wakeup.set() + + def cancel_tasks(self, *taskids: Taskid) -> None: + """ + Remove all tasks with the given taskids from the tasklist. + If the next task has this taskid, wake up the monitor loop. + """ + taskids = set(taskids) + wake = (self._tasklist and self._tasklist[-1] in taskids) + self._tasklist = [tid for tid in self._tasklist if tid not in taskids] + for tid in taskids: + self._taskmap.pop(tid, None) + if wake: + self._wakeup.set() + + def start(self): + if self._monitor_task and not self._monitor_task.done(): + self._monitor_task.cancel() + # Start the monitor + self._monitor_task = asyncio.create_task(self.monitor()) + return self._monitor_task + + async def monitor(self): + """ + Start the monitor. + Executes the tasks in `self.tasks` at the specified time. + + This will shield task execution from cancellation + to avoid partial states. + """ + try: + while True: + self._wakeup.clear() + if not self._tasklist: + # No tasks left, just sleep until wakeup + await self._wakeup.wait() + else: + # Get the next task, sleep until wakeup or it is ready to run + nextid = self._tasklist[-1] + nexttime = self._taskmap[nextid] + sleep_for = nexttime - utc_now().timestamp() + try: + await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for) + except asyncio.TimeoutError: + # Ready to run the task + self._tasklist.pop() + self._taskmap.pop(nextid, None) + self._running[nextid] = asyncio.ensure_future(self._run(nextid)) + else: + # Wakeup task fired, loop again + continue + except asyncio.CancelledError: + # Log closure and wait for remaining tasks + # A second cancellation will also cancel the tasks + logger.debug( + f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. " + f"Waiting for {len(self._running)} running tasks to complete." + ) + await asyncio.gather(*self._running.values(), return_exceptions=True) + + async def _run(self, taskid: Taskid) -> None: + # Execute the task, respecting the ratelimit bucket + if self._bucket is not None: + # IMPLEMENTATION NOTE: + # Bucket.wait() should guarantee not more than n tasks/second are run + # and that a request directly afterwards will _not_ raise BucketFull + # Make sure that only one waiter is actually waiting on its sleep task + # The other waiters should be waiting on a lock around the sleep task + # Waiters are executed in wait-order, so if we only let a single waiter in + # we shouldn't get collisions. + # Furthermore, make sure we do _not_ pass back to the event loop after waiting + # Or we will lose thread-safety for BucketFull + await self._bucket.wait() + fut = asyncio.create_task(self.run_task(taskid)) + try: + await asyncio.shield(fut) + except asyncio.CancelledError: + raise + except Exception: + # Protect the monitor loop from any other exceptions + logger.exception( + f"Ignoring exception in task monitor {self.__class__.__name__} while " + f"executing " + ) + finally: + self._running.pop(taskid) + + async def run_task(self, taskid: Taskid): + """ + Execute the task with the given taskid. + + Default implementation executes `self.executor` if it exists, + otherwise raises NotImplementedError. + """ + if self.executor is not None: + await self.executor(taskid) + else: + raise NotImplementedError diff --git a/src/utils/ratelimits.py b/src/utils/ratelimits.py new file mode 100644 index 0000000..7322336 --- /dev/null +++ b/src/utils/ratelimits.py @@ -0,0 +1,173 @@ +import asyncio +import time +import logging + +from meta.errors import SafeCancellation + +from cachetools import TTLCache + +logger = logging.getLogger() + + + +class BucketFull(Exception): + """ + Throw when a requested Bucket is already full + """ + pass + + +class BucketOverFull(BucketFull): + """ + Throw when a requested Bucket is overfull + """ + pass + + +class Bucket: + __slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock') + + def __init__(self, max_level, empty_time): + self.max_level = max_level + self.empty_time = empty_time + self.leak_rate = max_level / empty_time + + self._level = 0 + self._last_checked = time.monotonic() + + self._last_full = False + self._wait_lock = asyncio.Lock() + + @property + def full(self) -> bool: + """ + Return whether the bucket is 'full', + that is, whether an immediate request against the bucket will raise `BucketFull`. + """ + self._leak() + return self._level + 1 > self.max_level + + @property + def overfull(self): + self._leak() + return self._level > self.max_level + + @property + def delay(self): + self._leak() + if self._level + 1 > self.max_level: + delay = (self._level + 1 - self.max_level) * self.leak_rate + else: + delay = 0 + return delay + + def _leak(self): + if self._level: + elapsed = time.monotonic() - self._last_checked + self._level = max(0, self._level - (elapsed * self.leak_rate)) + + self._last_checked = time.monotonic() + + def request(self): + self._leak() + if self._level > self.max_level: + raise BucketOverFull + elif self._level == self.max_level: + self._level += 1 + if self._last_full: + raise BucketOverFull + else: + self._last_full = True + raise BucketFull + else: + self._last_full = False + self._level += 1 + + def fill(self): + self._leak() + self._level = max(self._level, self.max_level + 1) + + async def wait(self): + """ + Wait until the bucket has room. + + Guarantees that a `request` directly afterwards will not raise `BucketFull`. + """ + # Wrapped in a lock so that waiters are correctly handled in wait-order + # Otherwise multiple waiters will have the same delay, + # and race for the wakeup after sleep. + # Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order + async with self._wait_lock: + # We do this in a loop in case asyncio.sleep throws us out early, + # or a synchronous request overflows the bucket while we are waiting. + while self.full: + await asyncio.sleep(self.delay) + + async def wrapped(self, coro): + await self.wait() + self.request() + await coro + + +class RateLimit: + def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)): + self.max_level = max_level + self.empty_time = empty_time + + self.error = error or "Too many requests, please slow down!" + self.buckets = cache + + def request_for(self, key): + if not (bucket := self.buckets.get(key, None)): + bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time) + + try: + bucket.request() + except BucketOverFull: + raise SafeCancellation(details="Bucket overflow") + except BucketFull: + raise SafeCancellation(self.error, details="Bucket full") + + def ward(self, member=True, key=None): + """ + Command ratelimit decorator. + """ + key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id)) + + def decorator(func): + async def wrapper(ctx, *args, **kwargs): + self.request_for(key(ctx)) + return await func(ctx, *args, **kwargs) + return wrapper + return decorator + + +async def limit_concurrency(aws, limit): + """ + Run provided awaitables concurrently, + ensuring that no more than `limit` are running at once. + """ + aws = iter(aws) + aws_ended = False + pending = set() + count = 0 + logger.debug("Starting limited concurrency executor") + + while pending or not aws_ended: + while len(pending) < limit and not aws_ended: + aw = next(aws, None) + if aw is None: + aws_ended = True + else: + pending.add(asyncio.create_task(aw)) + count += 1 + + if not pending: + break + + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + while done: + yield done.pop() + logger.debug(f"Completed {count} tasks") diff --git a/src/utils/ui/__init__.py b/src/utils/ui/__init__.py new file mode 100644 index 0000000..633b3a9 --- /dev/null +++ b/src/utils/ui/__init__.py @@ -0,0 +1,8 @@ +import asyncio +import logging + +logger = logging.getLogger(__name__) + +from .hooked import * +from .leo import * +from .micros import * diff --git a/src/utils/ui/hooked.py b/src/utils/ui/hooked.py new file mode 100644 index 0000000..075476d --- /dev/null +++ b/src/utils/ui/hooked.py @@ -0,0 +1,59 @@ +import time + +import discord +from discord.ui.item import Item +from discord.ui.button import Button + +from .leo import LeoUI + +__all__ = ( + 'HookedItem', + 'AButton', + 'AsComponents' +) + + +class HookedItem: + """ + Mixin for Item classes allowing an instance to be used as a callback decorator. + """ + def __init__(self, *args, pass_kwargs={}, **kwargs): + super().__init__(*args, **kwargs) + self.pass_kwargs = pass_kwargs + + def __call__(self, coro): + async def wrapped(interaction, **kwargs): + return await coro(interaction, self, **(self.pass_kwargs | kwargs)) + self.callback = wrapped + return self + + +class AButton(HookedItem, Button): + ... + + +class AsComponents(LeoUI): + """ + Simple container class to accept a number of Items and turn them into an attachable View. + """ + def __init__(self, *items, pass_kwargs={}, **kwargs): + super().__init__(**kwargs) + self.pass_kwargs = pass_kwargs + + for item in items: + self.add_item(item) + + async def _scheduled_task(self, item: Item, interaction: discord.Interaction): + try: + item._refresh_state(interaction, interaction.data) # type: ignore + + allow = await self.interaction_check(interaction) + if not allow: + return + + if self.timeout: + self.__timeout_expiry = time.monotonic() + self.timeout + + await item.callback(interaction, **self.pass_kwargs) + except Exception as e: + return await self.on_error(interaction, e, item) diff --git a/src/utils/ui/leo.py b/src/utils/ui/leo.py new file mode 100644 index 0000000..eebaea4 --- /dev/null +++ b/src/utils/ui/leo.py @@ -0,0 +1,485 @@ +from typing import List, Optional, Any, Dict +import asyncio +import logging +import time +from contextvars import copy_context, Context + +import discord +from discord.ui import Modal, View, Item + +from meta.logger import log_action_stack, logging_context +from meta.errors import SafeCancellation + +from . import logger +from ..lib import MessageArgs, error_embed + +__all__ = ( + 'LeoUI', + 'MessageUI', + 'LeoModal', + 'error_handler_for' +) + + +class LeoUI(View): + """ + View subclass for small-scale user interfaces. + + While a 'View' provides an interface for managing a collection of components, + a `LeoUI` may also manage a message, and potentially slave Views or UIs. + The `LeoUI` also exposes more advanced cleanup and timeout methods, + and preserves the context. + """ + + def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None: + super().__init__(*args, **kwargs) + + if context is None: + self._context = copy_context() + else: + self._context = context + + self._name = ui_name or self.__class__.__name__ + self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name]) + + # List of slaved views to stop when this view stops + self._slaves: List[View] = [] + + # TODO: Replace this with a substitutable ViewLayout class + self._layout: Optional[tuple[tuple[Item, ...], ...]] = None + + @property + def _stopped(self) -> asyncio.Future: + """ + Return an future indicating whether the View has finished interacting. + + Currently exposes a hidden attribute of the underlying View. + May be reimplemented in future. + """ + return self._View__stopped + + def to_components(self) -> List[Dict[str, Any]]: + """ + Extending component generator to apply the set _layout, if it exists. + """ + if self._layout is not None: + # Alternative rendering using layout + components = [] + for i, row in enumerate(self._layout): + # Skip empty rows + if not row: + continue + + # Since we aren't relying on ViewWeights, manually check width here + if sum(item.width for item in row) > 5: + raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!") + + # Create the component dict for this row + components.append({ + 'type': 1, + 'components': [item.to_component_dict() for item in row] + }) + else: + components = super().to_components() + + return components + + def set_layout(self, *rows: tuple[Item, ...]) -> None: + """ + Set the layout of the rendered View as a matrix of items, + or more precisely, a list of action rows. + + This acts independently of the existing sorting with `_ViewWeights`, + and overrides the sorting if applied. + """ + self._layout = rows + + async def cleanup(self): + """ + Coroutine to run when timeing out, stopping, or cancelling. + Generally cleans up any open resources, and removes any leftover components. + """ + logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'}) + return None + + def stop(self): + """ + Extends View.stop() to also stop all the slave views. + Note that stopping is idempotent, so it is okay if close() also calls stop(). + """ + for slave in self._slaves: + slave.stop() + super().stop() + + async def close(self, msg=None): + self.stop() + await self.cleanup() + + async def pre_timeout(self): + """ + Task to execute before actually timing out. + This may cancel the timeout by refreshing or rescheduling it. + (E.g. to ask the user whether they want to keep going.) + + Default implementation does nothing. + """ + return None + + async def on_timeout(self): + """ + Task to execute after timeout is complete. + Default implementation calls cleanup. + """ + await self.cleanup() + + async def __dispatch_timeout(self): + """ + This essentially extends View._dispatch_timeout, + to include a pre_timeout task + which may optionally refresh and hence cancel the timeout. + """ + if self._View__stopped.done(): + # We are already stopped, nothing to do + return + + with logging_context(action='Timeout'): + try: + await self.pre_timeout() + except asyncio.TimeoutError: + pass + except asyncio.CancelledError: + pass + except Exception: + logger.exception( + "Unhandled error caught while dispatching timeout for {self!r}.", + extra={'with_ctx': True, 'action': 'Error'} + ) + + # Check if we still need to timeout + if self.timeout is None: + # The timeout was removed entirely, silently walk away + return + + if self._View__stopped.done(): + # We stopped while waiting for the pre timeout. + # Or maybe another thread timed us out + # Either way, we are done here + return + + now = time.monotonic() + if self._View__timeout_expiry is not None and now < self._View__timeout_expiry: + # The timeout was extended, make sure the timeout task is running then fade away + if self._View__timeout_task is None or self._View__timeout_task.done(): + self._View__timeout_task = asyncio.create_task(self._View__timeout_task_impl()) + else: + # Actually timeout, and call the post-timeout task for cleanup. + self._really_timeout() + await self.on_timeout() + + def _dispatch_timeout(self): + """ + Overriding timeout method completely, to support interactive flow during timeout, + and optional refreshing of the timeout. + """ + return self._context.run(asyncio.create_task, self.__dispatch_timeout()) + + def _really_timeout(self): + """ + Actuallly times out the View. + This copies View._dispatch_timeout, apart from the `on_timeout` dispatch, + which is now handled by `__dispatch_timeout`. + """ + if self._View__stopped.done(): + return + + if self._View__cancel_callback: + self._View__cancel_callback(self) + self._View__cancel_callback = None + + self._View__stopped.set_result(True) + + def _dispatch_item(self, *args, **kwargs): + """Extending event dispatch to run in the instantiation context.""" + return self._context.run(super()._dispatch_item, *args, **kwargs) + + async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item): + """ + Default LeoUI error handle. + This may be tail extended by subclasses to preserve the exception stack. + """ + try: + raise error + except SafeCancellation as e: + if e.msg and not interaction.is_expired(): + try: + if interaction.response.is_done(): + await interaction.followup.send( + embed=error_embed(e.msg), + ephemeral=True + ) + else: + await interaction.response.send_message( + embed=error_embed(e.msg), + ephemeral=True + ) + except discord.HTTPException: + pass + logger.debug( + f"Caught a safe cancellation from LeoUI: {e.details}", + extra={'action': 'Cancel'} + ) + except Exception: + logger.exception( + f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r} from interaction: " + f"{interaction.data}", + extra={'with_ctx': True, 'action': 'UIError'} + ) + # Explicitly handle the bugsplat ourselves + splat = interaction.client.tree.bugsplat(interaction, error) + await interaction.client.tree.error_reply(interaction, splat) + + +class MessageUI(LeoUI): + """ + Simple single-message LeoUI, intended as a framework for UIs + attached to a single interaction response. + + UIs may also be sent as regular messages by using `send(channel)` instead of `run(interaction)`. + """ + + def __init__(self, *args, callerid: Optional[int] = None, **kwargs): + super().__init__(*args, **kwargs) + + # ----- UI state ----- + # User ID of the original caller (e.g. command author). + # Mainly used for interaction usage checks and logging + self._callerid = callerid + + # Original interaction, if this UI is sent as an interaction response + self._original: discord.Interaction = None + + # Message holding the UI, when the UI is sent attached to a followup + self._message: discord.Message = None + + # Refresh lock, to avoid cache collisions on refresh + self._refresh_lock = asyncio.Lock() + + @property + def channel(self): + if self._original is not None: + return self._original.channel + else: + return self._message.channel + + # ----- UI API ----- + async def run(self, interaction: discord.Interaction, **kwargs): + """ + Run the UI as a response or followup to the given interaction. + + Should be extended if more complex run mechanics are needed + (e.g. registering listeners or setting up caches). + """ + await self.draw(interaction, **kwargs) + + async def refresh(self, *args, thinking: Optional[discord.Interaction] = None, **kwargs): + """ + Reload and redraw this UI. + + Primarily a hook-method for use by parents and other controllers. + Performs a full data and reload and refresh (maintaining UI state, e.g. page n). + """ + async with self._refresh_lock: + # Reload data + await self.reload() + # Redraw UI message + await self.redraw(thinking=thinking) + + async def quit(self): + """ + Quit the UI. + + This usually involves removing the original message, + and stopping or closing the underlying View. + """ + for child in self._slaves: + # TODO: Better to use duck typing or interface typing + if isinstance(child, MessageUI) and not child.is_finished(): + asyncio.create_task(child.quit()) + try: + if self._original is not None and not self._original.is_expired(): + await self._original.delete_original_response() + self._original = None + if self._message is not None: + await self._message.delete() + self._message = None + except discord.HTTPException: + pass + + # Note close() also runs cleanup and stop + await self.close() + + # ----- UI Flow ----- + async def interaction_check(self, interaction: discord.Interaction): + """ + Check the given interaction is authorised to use this UI. + + Default implementation simply checks that the interaction is + from the original caller. + Extend for more complex logic. + """ + return interaction.user.id == self._callerid + + async def make_message(self) -> MessageArgs: + """ + Create the UI message body, depening on the current state. + + Called upon each redraw. + Should handle caching if message construction is for some reason intensive. + + Must be implemented by concrete UI subclasses. + """ + raise NotImplementedError + + async def refresh_layout(self): + """ + Asynchronously refresh the message components, + and explicitly set the message component layout. + + Called just before redrawing, before `make_message`. + + Must be implemented by concrete UI subclasses. + """ + raise NotImplementedError + + async def reload(self): + """ + Reload and recompute the underlying data for this UI. + + Must be implemented by concrete UI subclasses. + """ + raise NotImplementedError + + async def draw(self, interaction, force_followup=False, **kwargs): + """ + Send the UI as a response or followup to the given interaction. + + If the interaction has been responded to, or `force_followup` is set, + creates a followup message instead of a response to the interaction. + """ + # Initial data loading + await self.reload() + # Set the UI layout + await self.refresh_layout() + # Fetch message arguments + args = await self.make_message() + + as_followup = force_followup or interaction.response.is_done() + if as_followup: + self._message = await interaction.followup.send(**args.send_args, **kwargs, view=self) + else: + self._original = interaction + await interaction.response.send_message(**args.send_args, **kwargs, view=self) + + async def send(self, channel: discord.abc.Messageable, **kwargs): + """ + Alternative to draw() which uses a discord.abc.Messageable. + """ + await self.reload() + await self.refresh_layout() + args = await self.make_message() + self._message = await channel.send(**args.send_args, view=self) + + async def _redraw(self, args): + if self._original and not self._original.is_expired(): + await self._original.edit_original_response(**args.edit_args, view=self) + elif self._message: + await self._message.edit(**args.edit_args, view=self) + else: + # Interaction expired or already closed. Quietly cleanup. + await self.close() + + async def redraw(self, thinking: Optional[discord.Interaction] = None): + """ + Update the output message for this UI. + + If a thinking interaction is provided, deletes the response while redrawing. + """ + await self.refresh_layout() + args = await self.make_message() + + if thinking is not None and not thinking.is_expired() and thinking.response.is_done(): + asyncio.create_task(thinking.delete_original_response()) + + try: + await self._redraw(args) + except discord.HTTPException as e: + # Unknown communication error, nothing we can reliably do. Exit quietly. + logger.warning( + f"Unexpected UI redraw failure occurred in {self}: {repr(e)}", + ) + await self.close() + + async def cleanup(self): + """ + Remove message components from interaction response, if possible. + + Extend to remove listeners or clean up caches. + `cleanup` is always called when the UI is exiting, + through timeout or user-driven closure. + """ + try: + if self._original is not None and not self._original.is_expired(): + await self._original.edit_original_response(view=None) + self._original = None + if self._message is not None: + await self._message.edit(view=None) + self._message = None + except discord.HTTPException: + pass + + +class LeoModal(Modal): + """ + Context-aware Modal class. + """ + def __init__(self, *args, context: Optional[Context] = None, **kwargs): + super().__init__(**kwargs) + + if context is None: + self._context = copy_context() + else: + self._context = context + self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__]) + + def _dispatch_submit(self, *args, **kwargs): + """ + Extending event dispatch to run in the instantiation context. + """ + return self._context.run(super()._dispatch_submit, *args, **kwargs) + + def _dispatch_item(self, *args, **kwargs): + """Extending event dispatch to run in the instantiation context.""" + return self._context.run(super()._dispatch_item, *args, **kwargs) + + async def on_error(self, interaction: discord.Interaction, error: Exception, *args): + """ + Default LeoModal error handle. + This may be tail extended by subclasses to preserve the exception stack. + """ + try: + raise error + except Exception: + logger.exception( + f"Unhandled interaction exception occurred in {self!r}. Interaction: {interaction.data}", + extra={'with_ctx': True, 'action': 'ModalError'} + ) + # Explicitly handle the bugsplat ourselves + splat = interaction.client.tree.bugsplat(interaction, error) + await interaction.client.tree.error_reply(interaction, splat) + + +def error_handler_for(exc): + def wrapper(coro): + coro._ui_error_handler_for_ = exc + return coro + return wrapper diff --git a/src/utils/ui/micros.py b/src/utils/ui/micros.py new file mode 100644 index 0000000..eebf418 --- /dev/null +++ b/src/utils/ui/micros.py @@ -0,0 +1,329 @@ +from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict +import functools +import asyncio + +import discord +from discord.ui import TextInput +from discord.ui.button import button + +from meta.logger import logging_context +from meta.errors import ResponseTimedOut + +from .leo import LeoModal, LeoUI + +__all__ = ( + 'FastModal', + 'ModalRetryUI', + 'Confirm', + 'input', +) + + +class FastModal(LeoModal): + __class_error_handlers__ = [] + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + error_handlers = {} + for base in reversed(cls.__mro__): + for name, member in base.__dict__.items(): + if hasattr(member, '_ui_error_handler_for_'): + error_handlers[name] = member + + cls.__class_error_handlers__ = list(error_handlers.values()) + + def __init__error_handlers__(self): + handlers = {} + for handler in self.__class_error_handlers__: + handlers[handler._ui_error_handler_for_] = functools.partial(handler, self) + return handlers + + def __init__(self, *items: TextInput, **kwargs): + super().__init__(**kwargs) + for item in items: + self.add_item(item) + self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future() + self._waiters: List[Callable[[discord.Interaction], Coroutine]] = [] + self._error_handlers = self.__init__error_handlers__() + + def error_handler(self, exception): + def wrapper(coro): + self._error_handlers[exception] = coro + return coro + return wrapper + + async def wait_for(self, check=None, timeout=None): + # Wait for _result or timeout + # If we timeout, or the view times out, raise TimeoutError + # Otherwise, return the Interaction + # This allows multiple listeners and callbacks to wait on + while True: + result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout) + if check is not None: + if not check(result): + continue + return result + + async def on_timeout(self): + self._result.set_exception(asyncio.TimeoutError) + + def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}): + def wrapper(coro): + async def wrapped_callback(interaction): + with logging_context(action=coro.__name__): + if check is not None: + if not check(interaction): + return + try: + await coro(interaction, *pass_args, **pass_kwargs) + except Exception: + raise + finally: + if once: + self._waiters.remove(wrapped_callback) + self._waiters.append(wrapped_callback) + return wrapper + + async def on_error(self, interaction: discord.Interaction, error: Exception, *args): + try: + # First let our error handlers have a go + # If there is no handler for this error, or the handlers themselves error, + # drop to the superclass error handler implementation. + try: + raise error + except tuple(self._error_handlers.keys()) as e: + # If an error handler is registered for this exception, run it. + for cls, handler in self._error_handlers.items(): + if isinstance(e, cls): + await handler(interaction, e) + except Exception as error: + await super().on_error(interaction, error) + + async def on_submit(self, interaction): + print("On submit") + old_result = self._result + self._result = asyncio.get_event_loop().create_future() + old_result.set_result(interaction) + + tasks = [] + for waiter in self._waiters: + task = asyncio.create_task( + waiter(interaction), + name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}" + ) + tasks.append(task) + if tasks: + await asyncio.gather(*tasks) + + +async def input( + interaction: discord.Interaction, + title: str, + question: Optional[str] = None, + field: Optional[TextInput] = None, + timeout=180, + **kwargs, +) -> tuple[discord.Interaction, str]: + """ + Spawn a modal to accept input. + Returns an (interaction, value) pair, with interaction not yet responded to. + May raise asyncio.TimeoutError if the view times out. + """ + if field is None: + field = TextInput( + label=kwargs.get('label', question), + **kwargs + ) + modal = FastModal( + field, + title=title, + timeout=timeout + ) + await interaction.response.send_modal(modal) + interaction = await modal.wait_for() + return (interaction, field.value) + + +class ModalRetryUI(LeoUI): + def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.modal = modal + self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)} + + self.message = message + + self._interaction = None + + if label is not None: + self.retry_button.label = label + + @property + def embed(self): + return discord.Embed( + title="Uh-Oh!", + description=self.message, + colour=discord.Colour.red() + ) + + async def respond_to(self, interaction): + self._interaction = interaction + if interaction.response.is_done(): + await interaction.followup.send(embed=self.embed, ephemeral=True, view=self) + else: + await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self) + + @button(label="Retry") + async def retry_button(self, interaction, butt): + # Setting these here so they don't update in the meantime + for item, value in self.item_values.items(): + item.default = value + if self._interaction is not None: + await self._interaction.delete_original_response() + self._interaction = None + await interaction.response.send_modal(self.modal) + await self.close() + + +class Confirm(LeoUI): + """ + Micro UI class implementing a confirmation question. + + Parameters + ---------- + confirm_msg: str + The confirmation question to ask from the user. + This is set as the description of the `embed` property. + The `embed` may be further modified if required. + permitted_id: Optional[int] + The user id allowed to access this interaction. + Other users will recieve an access denied error message. + defer: bool + Whether to defer the interaction response while handling the button. + It may be useful to set this to `False` to obtain manual control + over the interaction response flow (e.g. to send a modal or ephemeral message). + The button press interaction may be accessed through `Confirm.interaction`. + Default: True + + Example + ------- + ``` + confirm = Confirm("Are you sure?", ctx.author.id) + confirm.embed.colour = discord.Colour.red() + confirm.confirm_button.label = "Yes I am sure" + confirm.cancel_button.label = "No I am not sure" + + try: + result = await confirm.ask(ctx.interaction, ephemeral=True) + except ResultTimedOut: + return + ``` + """ + def __init__( + self, + confirm_msg: str, + permitted_id: Optional[int] = None, + defer: bool = True, + **kwargs + ): + super().__init__(**kwargs) + self.confirm_msg = confirm_msg + self.permitted_id = permitted_id + self.defer = defer + + self._embed: Optional[discord.Embed] = None + self._result: asyncio.Future[bool] = asyncio.Future() + + # Indicates whether we should delete the message or the interaction response + self._is_followup: bool = False + self._original: Optional[discord.Interaction] = None + self._message: Optional[discord.Message] = None + + async def interaction_check(self, interaction: discord.Interaction): + return (self.permitted_id is None) or interaction.user.id == self.permitted_id + + async def on_timeout(self): + # Propagate timeout to result Future + self._result.set_exception(ResponseTimedOut) + await self.cleanup() + + async def cleanup(self): + """ + Cleanup the confirmation prompt by deleting it, if possible. + Ignores any Discord errors that occur during the process. + """ + try: + if self._is_followup and self._message: + await self._message.delete() + elif not self._is_followup and self._original and not self._original.is_expired(): + await self._original.delete_original_response() + except discord.HTTPException: + # A user probably already deleted the message + # Anything could have happened, just ignore. + pass + + @button(label="Confirm") + async def confirm_button(self, interaction: discord.Interaction, press): + if self.defer: + await interaction.response.defer() + self._result.set_result(True) + await self.close() + + @button(label="Cancel") + async def cancel_button(self, interaction: discord.Interaction, press): + if self.defer: + await interaction.response.defer() + self._result.set_result(False) + await self.close() + + @property + def embed(self): + """ + Confirmation embed shown to the user. + This is cached, and may be modifed directly through the usual EmbedProxy API, + or explicitly overwritten. + """ + if self._embed is None: + self._embed = discord.Embed( + colour=discord.Colour.orange(), + description=self.confirm_msg + ) + return self._embed + + @embed.setter + def embed(self, value): + self._embed = value + + async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs): + """ + Send this confirmation prompt in response to the provided interaction. + Extra keyword arguments are passed to `Interaction.response.send_message` + or `Interaction.send_followup`, depending on whether + the provided interaction has already been responded to. + + The `epehemeral` argument is handled specially, + since the question message can only be deleted through `Interaction.delete_original_response`. + + Waits on and returns the internal `result` Future. + + Returns: bool + True if the user pressed the confirm button. + False if the user pressed the cancel button. + Raises: + ResponseTimedOut: + If the user does not respond before the UI times out. + """ + self._original = interaction + if interaction.response.is_done(): + # Interaction already responded to, send a follow up + if ephemeral: + raise ValueError("Cannot send an ephemeral response to a used interaction.") + self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self) + self._is_followup = True + else: + await interaction.response.send_message( + embed=self.embed, ephemeral=ephemeral, **kwargs, view=self + ) + self._is_followup = False + return await self._result + +# TODO: Selector MicroUI for displaying options (<= 25) diff --git a/src/wards.py b/src/wards.py new file mode 100644 index 0000000..4f1647f --- /dev/null +++ b/src/wards.py @@ -0,0 +1,9 @@ +from meta import LionBot + +# Raw checks, return True/False depending on whether they pass +async def sys_admin(bot: LionBot, userid: int): + """ + Checks whether the context author is listed in the configuration file as a bot admin. + """ + admins = bot.config.bot.getintlist('admins') + return userid in admins