From c33cc325cef8f901203a1ffa9638fbb8ea1529e5 Mon Sep 17 00:00:00 2001 From: Tuxverse <> Date: Mon, 25 Aug 2025 13:19:00 +1000 Subject: [PATCH] Initial commit --- .gitignore | 151 ++++ .gitmodules | 9 + data/.gitignore | 0 data/schema.sql | 38 + requirements.txt | 7 + scripts/start_bot.py | 12 + scripts/start_debug.py | 35 + src/babel/__init__.py | 3 + src/babel/enums.py | 81 ++ src/babel/translator.py | 108 +++ src/babel/utils.py | 20 + src/bot.py | 107 +++ src/botdata.py | 26 + src/constants.py | 7 + src/core/__init__.py | 8 + src/core/cog.py | 76 ++ src/core/data.py | 45 ++ src/core/setting_types.py | 227 ++++++ src/meta/LionBot.py | 373 +++++++++ src/meta/LionCog.py | 58 ++ src/meta/LionContext.py | 195 +++++ src/meta/LionTree.py | 148 ++++ src/meta/__init__.py | 15 + src/meta/app.py | 32 + src/meta/args.py | 35 + src/meta/config.py | 146 ++++ src/meta/context.py | 20 + src/meta/errors.py | 64 ++ src/meta/logger.py | 468 +++++++++++ src/meta/monitor.py | 139 ++++ src/meta/sharding.py | 35 + src/modules/__init__.py | 9 + src/settings/__init__.py | 7 + src/settings/base.py | 166 ++++ src/settings/data.py | 233 ++++++ src/settings/groups.py | 204 +++++ src/settings/mock.py | 13 + src/settings/setting_types.py | 1393 +++++++++++++++++++++++++++++++++ src/settings/ui.py | 512 ++++++++++++ src/utils/__init__.py | 0 src/utils/ansi.py | 97 +++ src/utils/data.py | 165 ++++ src/utils/lib.py | 879 +++++++++++++++++++++ src/utils/monitor.py | 191 +++++ src/utils/ratelimits.py | 173 ++++ src/utils/ui/__init__.py | 12 + src/utils/ui/hooked.py | 59 ++ src/utils/ui/leo.py | 485 ++++++++++++ src/utils/ui/micros.py | 329 ++++++++ src/utils/ui/msgeditor.py | 1070 +++++++++++++++++++++++++ src/wards.py | 9 + 51 files changed, 8694 insertions(+) create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 data/.gitignore create mode 100644 data/schema.sql create mode 100644 requirements.txt create mode 100755 scripts/start_bot.py create mode 100755 scripts/start_debug.py create mode 100644 src/babel/__init__.py create mode 100644 src/babel/enums.py create mode 100644 src/babel/translator.py create mode 100644 src/babel/utils.py create mode 100644 src/bot.py create mode 100644 src/botdata.py create mode 100644 src/constants.py create mode 100644 src/core/__init__.py create mode 100644 src/core/cog.py create mode 100644 src/core/data.py create mode 100644 src/core/setting_types.py create mode 100644 src/meta/LionBot.py create mode 100644 src/meta/LionCog.py create mode 100644 src/meta/LionContext.py create mode 100644 src/meta/LionTree.py create mode 100644 src/meta/__init__.py create mode 100644 src/meta/app.py create mode 100644 src/meta/args.py create mode 100644 src/meta/config.py create mode 100644 src/meta/context.py create mode 100644 src/meta/errors.py create mode 100644 src/meta/logger.py create mode 100644 src/meta/monitor.py create mode 100644 src/meta/sharding.py create mode 100644 src/modules/__init__.py create mode 100644 src/settings/__init__.py create mode 100644 src/settings/base.py create mode 100644 src/settings/data.py create mode 100644 src/settings/groups.py create mode 100644 src/settings/mock.py create mode 100644 src/settings/setting_types.py create mode 100644 src/settings/ui.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/ansi.py create mode 100644 src/utils/data.py create mode 100644 src/utils/lib.py create mode 100644 src/utils/monitor.py create mode 100644 src/utils/ratelimits.py create mode 100644 src/utils/ui/__init__.py create mode 100644 src/utils/ui/hooked.py create mode 100644 src/utils/ui/leo.py create mode 100644 src/utils/ui/micros.py create mode 100644 src/utils/ui/msgeditor.py create mode 100644 src/wards.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d0cda90 --- /dev/null +++ b/.gitignore @@ -0,0 +1,151 @@ +src/modules/test/* + +pending-rewrite/ +logs/* +notes/* +tmp/* +output/* +locales/domains + +.idea/* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +config/** diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..562cc5e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,9 @@ +[submodule "src/modules/voicefix"] + path = src/modules/voicefix + url = git@github.com:Intery/StudyLion-voicefix.git +[submodule "src/modules/streamalerts"] + path = src/modules/streamalerts + url = git@github.com:Intery/StudyLion-streamalerts.git +[submodule "src/data"] + path = src/data + url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/data/schema.sql b/data/schema.sql new file mode 100644 index 0000000..52c49c5 --- /dev/null +++ b/data/schema.sql @@ -0,0 +1,38 @@ +-- Metadata {{{ +CREATE TABLE version_history( + component TEXT NOT NULL, + from_version INTEGER NOT NULL, + to_version INTEGER NOT NULL, + author TEXT NOT NULL, + _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(), +); +INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation'); + + +CREATE OR REPLACE FUNCTION update_timestamp_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW._timestamp = (now() at time zone 'utc'); + RETURN NEW; +END; +$$ language 'plpgsql'; +-- }}} + +-- App metadata {{{ + +CREATE TABLE app_config( + appname TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE bot_config( + appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE, + sponsor_prompt TEXT, + sponsor_message TEXT, + default_skin TEXT +); +-- }}} + +-- TODO: Profile data + +-- vim: set fdm=marker: diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..15f10a4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +aiohttp +cachetools +configparser +discord.py [voice] +iso8601 +psycopg[pool] +pytz diff --git a/scripts/start_bot.py b/scripts/start_bot.py new file mode 100755 index 0000000..49d6ad1 --- /dev/null +++ b/scripts/start_bot.py @@ -0,0 +1,12 @@ +# !/bin/python3 + +import sys +import os + +sys.path.insert(0, os.path.join(os.getcwd())) +sys.path.insert(0, os.path.join(os.getcwd(), "src")) + + +if __name__ == '__main__': + from bot import _main + _main() diff --git a/scripts/start_debug.py b/scripts/start_debug.py new file mode 100755 index 0000000..d4837d0 --- /dev/null +++ b/scripts/start_debug.py @@ -0,0 +1,35 @@ +# !/bin/python3 + +import sys +import os +import tracemalloc +import asyncio + + +sys.path.insert(0, os.path.join(os.getcwd())) +sys.path.insert(0, os.path.join(os.getcwd(), "src")) + +tracemalloc.start() + + +def loop_exception_handler(loop, context): + print(context) + task: asyncio.Task = context.get('task', None) + if task is not None: + addendum = f"" + message = context.get('message', '') + context['message'] = ' '.join((message, addendum)) + loop.default_exception_handler(context) + + +def main(): + loop = asyncio.get_event_loop() + loop.set_exception_handler(loop_exception_handler) + loop.set_debug(enabled=True) + + from bot import _main + _main() + + +if __name__ == '__main__': + main() diff --git a/src/babel/__init__.py b/src/babel/__init__.py new file mode 100644 index 0000000..48da68e --- /dev/null +++ b/src/babel/__init__.py @@ -0,0 +1,3 @@ +from .translator import SOURCE_LOCALE, LeoBabel, LocalBabel, LazyStr, ctx_locale, ctx_translator + +babel = LocalBabel('babel') diff --git a/src/babel/enums.py b/src/babel/enums.py new file mode 100644 index 0000000..9e73da3 --- /dev/null +++ b/src/babel/enums.py @@ -0,0 +1,81 @@ +from enum import Enum +from . import babel + +_p = babel._p + + +class LocaleMap(Enum): + american_english = 'en-US' + british_english = 'en-GB' + bulgarian = 'bg' + chinese = 'zh-CN' + taiwan_chinese = 'zh-TW' + croatian = 'hr' + czech = 'cs' + danish = 'da' + dutch = 'nl' + finnish = 'fi' + french = 'fr' + german = 'de' + greek = 'el' + hindi = 'hi' + hungarian = 'hu' + italian = 'it' + japanese = 'ja' + korean = 'ko' + lithuanian = 'lt' + norwegian = 'no' + polish = 'pl' + brazil_portuguese = 'pt-BR' + romanian = 'ro' + russian = 'ru' + spain_spanish = 'es-ES' + swedish = 'sv-SE' + thai = 'th' + turkish = 'tr' + ukrainian = 'uk' + vietnamese = 'vi' + hebrew = 'he-IL' + + +# Original Discord names +locale_names = { + 'id': (_p('localenames|locale:id', "Indonesian"), "Bahasa Indonesia"), + 'da': (_p('localenames|locale:da', "Danish"), "Dansk"), + 'de': (_p('localenames|locale:de', "German"), "Deutsch"), + 'en-GB': (_p('localenames|locale:en-GB', "English, UK"), "English, UK"), + 'en-US': (_p('localenames|locale:en-US', "English, US"), "English, US"), + 'es-ES': (_p('localenames|locale:es-ES', "Spanish"), "Español"), + 'fr': (_p('localenames|locale:fr', "French"), "Français"), + 'hr': (_p('localenames|locale:hr', "Croatian"), "Hrvatski"), + 'it': (_p('localenames|locale:it', "Italian"), "Italiano"), + 'lt': (_p('localenames|locale:lt', "Lithuanian"), "Lietuviškai"), + 'hu': (_p('localenames|locale:hu', "Hungarian"), "Magyar"), + 'nl': (_p('localenames|locale:nl', "Dutch"), "Nederlands"), + 'no': (_p('localenames|locale:no', "Norwegian"), "Norsk"), + 'pl': (_p('localenames|locale:pl', "Polish"), "Polski"), + 'pt-BR': (_p('localenames|locale:pt-BR', "Portuguese, Brazilian"), "Português do Brasil"), + 'ro': (_p('localenames|locale:ro', "Romanian, Romania"), "Română"), + 'fi': (_p('localenames|locale:fi', "Finnish"), "Suomi"), + 'sv-SE': (_p('localenames|locale:sv-SE', "Swedish"), "Svenska"), + 'vi': (_p('localenames|locale:vi', "Vietnamese"), "Tiếng Việt"), + 'tr': (_p('localenames|locale:tr', "Turkish"), "Türkçe"), + 'cs': (_p('localenames|locale:cs', "Czech"), "Čeština"), + 'el': (_p('localenames|locale:el', "Greek"), "Ελληνικά"), + 'bg': (_p('localenames|locale:bg', "Bulgarian"), "български"), + 'ru': (_p('localenames|locale:ru', "Russian"), "Pусский"), + 'uk': (_p('localenames|locale:uk', "Ukrainian"), "Українська"), + 'hi': (_p('localenames|locale:hi', "Hindi"), "हिन्दी"), + 'th': (_p('localenames|locale:th', "Thai"), "ไทย"), + 'zh-CN': (_p('localenames|locale:zh-CN', "Chinese, China"), "中文"), + 'ja': (_p('localenames|locale:ja', "Japanese"), "日本語"), + 'zh-TW': (_p('localenames|locale:zh-TW', "Chinese, Taiwan"), "繁體中文"), + 'ko': (_p('localenames|locale:ko', "Korean"), "한국어"), +} + +# More names for languages not supported by Discord +locale_names |= { + 'he': (_p('localenames|locale:he', "Hebrew"), "Hebrew"), + 'he-IL': (_p('localenames|locale:he-IL', "Hebrew"), "Hebrew"), + 'ceaser': (_p('localenames|locale:test', "Test Language"), "dfbtfs"), +} diff --git a/src/babel/translator.py b/src/babel/translator.py new file mode 100644 index 0000000..f3ec1d0 --- /dev/null +++ b/src/babel/translator.py @@ -0,0 +1,108 @@ +from typing import Optional +import logging +from contextvars import ContextVar +from collections import defaultdict +from enum import Enum + +import gettext + +from discord.app_commands import Translator, locale_str +from discord.enums import Locale + + +logger = logging.getLogger(__name__) + + +SOURCE_LOCALE = 'en_GB' +ctx_locale: ContextVar[str] = ContextVar('locale', default=SOURCE_LOCALE) +ctx_translator: ContextVar['LeoBabel'] = ContextVar('translator', default=None) # type: ignore + +null = gettext.NullTranslations() + + +class LeoBabel(Translator): + def __init__(self): + self.supported_locales = {loc.name for loc in Locale} + self.supported_domains = {} + self.translators = defaultdict(dict) # locale -> domain -> GNUTranslator + + async def load(self): + pass + + async def unload(self): + self.translators.clear() + + def get_translator(self, locale: Optional[str], domain): + return null + + def t(self, lazystr, locale=None): + return lazystr._translate_with(null) + + async def translate(self, string: locale_str, locale: Locale, context): + if not isinstance(string, LazyStr): + return string + else: + return string.message + +ctx_translator.set(LeoBabel()) + +class Method(Enum): + GETTEXT = 'gettext' + NGETTEXT = 'ngettext' + PGETTEXT = 'pgettext' + NPGETTEXT = 'npgettext' + + +class LocalBabel: + def __init__(self, domain): + self.domain = domain + + @property + def methods(self): + return (self._, self._n, self._p, self._np) + + def _(self, message): + return LazyStr(Method.GETTEXT, message, domain=self.domain) + + def _n(self, singular, plural, n): + return LazyStr(Method.NGETTEXT, singular, plural, n, domain=self.domain) + + def _p(self, context, message): + return LazyStr(Method.PGETTEXT, context, message, domain=self.domain) + + def _np(self, context, singular, plural, n): + return LazyStr(Method.NPGETTEXT, context, singular, plural, n, domain=self.domain) + + +class LazyStr(locale_str): + __slots__ = ('method', 'args', 'domain', 'locale') + + def __init__(self, method, *args, locale=None, domain=None): + self.method = method + self.args = args + self.domain = domain + self.locale = locale + + @property + def message(self): + return self._translate_with(null) + + @property + def extras(self): + return {'locale': self.locale, 'domain': self.domain} + + def __str__(self): + return self.message + + def _translate_with(self, translator: gettext.GNUTranslations): + method = getattr(translator, self.method.value) + return method(*self.args) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.method}, {self.args!r}, locale={self.locale}, domain={self.domain})' + + def __eq__(self, obj: object) -> bool: + return isinstance(obj, locale_str) and self.message == obj.message + + def __hash__(self) -> int: + return hash(self.args) diff --git a/src/babel/utils.py b/src/babel/utils.py new file mode 100644 index 0000000..e3faa1c --- /dev/null +++ b/src/babel/utils.py @@ -0,0 +1,20 @@ +from .translator import ctx_translator +from . import babel + +_, _p, _np = babel._, babel._p, babel._np + + +MONTHS = _p( + 'utils|months', + "January,February,March,April,May,June,July,August,September,October,November,December" +) + +SHORT_MONTHS = _p( + 'utils|short_months', + "Jan,Feb,Mar,Apr,May,Jun,Jul,Aug,Sep,Oct,Nov,Dec" +) + + +def local_month(month, short=False): + string = MONTHS if not short else SHORT_MONTHS + return ctx_translator.get().t(string).split(',')[month-1] diff --git a/src/bot.py b/src/bot.py new file mode 100644 index 0000000..2fa4221 --- /dev/null +++ b/src/bot.py @@ -0,0 +1,107 @@ +import asyncio +import logging + +import aiohttp +import discord +from discord.ext import commands + +from meta import LionBot, conf, sharding, appname +from meta.app import shardname +from meta.logger import log_context, log_action_stack, setup_main_logger +from meta.context import ctx_bot +from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus + +from data import Database + + +for name in conf.config.options('LOGGING_LEVELS', no_defaults=True): + logging.getLogger(name).setLevel(conf.logging_levels[name]) + + +logging_queue = setup_main_logger() + + +logger = logging.getLogger(__name__) + +db = Database(conf.data['args']) + + +async def _data_monitor() -> ComponentStatus: + """ + Component monitor callback for the database. + """ + data = { + 'stats': str(db.pool.get_stats()) + } + if not db.pool._opened: + level = StatusLevel.WAITING + info = "(WAITING) Database Pool is not opened." + elif db.pool._closed: + level = StatusLevel.ERRORED + info = "(ERROR) Database Pool is closed." + else: + level = StatusLevel.OKAY + info = "(OK) Database Pool statistics: {stats}" + return ComponentStatus(level, info, info, data) + + +async def main(): + log_action_stack.set(("Initialising",)) + logger.info("Initialising StudyLion") + + intents = discord.Intents.all() + intents.members = True + intents.message_content = True + intents.presences = False + + async with db.open(): + + async with aiohttp.ClientSession() as session: + async with LionBot( + command_prefix=conf.bot.get('prefix', '!!'), + intents=intents, + appname=appname, + shardname=shardname, + db=db, + config=conf, + initial_extensions=[ + 'core', + 'modules', + ], + web_client=session, + testing_guilds=conf.bot.getintlist('admin_guilds'), + shard_id=sharding.shard_number, + shard_count=sharding.shard_count, + help_command=None, + proxy=conf.bot.get('proxy', None), + chunk_guilds_at_startup=True, + ) as lionbot: + ctx_bot.set(lionbot) + lionbot.system_monitor.add_component( + ComponentMonitor('Database', _data_monitor) + ) + try: + log_context.set(f"APP: {appname}") + logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) + await lionbot.start(conf.bot['TOKEN']) + except asyncio.CancelledError: + log_context.set(f"APP: {appname}") + logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True) + + +def _main(): + from signal import SIGINT, SIGTERM + + loop = asyncio.get_event_loop() + main_task = asyncio.ensure_future(main()) + for signal in [SIGINT, SIGTERM]: + loop.add_signal_handler(signal, main_task.cancel) + try: + loop.run_until_complete(main_task) + finally: + loop.close() + logging.shutdown() + + +if __name__ == '__main__': + _main() diff --git a/src/botdata.py b/src/botdata.py new file mode 100644 index 0000000..025e44d --- /dev/null +++ b/src/botdata.py @@ -0,0 +1,26 @@ +from data import Registry, RowModel, Table +from data.columns import String, Timestamp, Integer, Bool + + +class VersionHistory(RowModel): + """ + CREATE TABLE version_history( + component TEXT NOT NULL, + from_version INTEGER NOT NULL, + to_version INTEGER NOT NULL, + author TEXT NOT NULL, + _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ); + """ + _tablename_ = 'version_history' + _cache_ = {} + + component = String() + from_version = Integer() + to_version = Integer() + author = String() + _timestamp = Timestamp() + + +class BotData(Registry): + version_history = VersionHistory.table diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 0000000..ad63a2c --- /dev/null +++ b/src/constants.py @@ -0,0 +1,7 @@ +CONFIG_FILE = "config/bot.conf" + +HINT_ICON = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png" + +SCHEMA_VERSIONS = { + 'ROOT': 1, +} diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..6be5d80 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,8 @@ +from babel import LocalBabel + +babel = LocalBabel('core') + +async def setup(bot): + from .cog import CoreCog + + await bot.add_cog(CoreCog(bot)) diff --git a/src/core/cog.py b/src/core/cog.py new file mode 100644 index 0000000..b5eab52 --- /dev/null +++ b/src/core/cog.py @@ -0,0 +1,76 @@ +import logging +from typing import Optional +from collections import defaultdict +from weakref import WeakValueDictionary + +import discord +import discord.app_commands as appcmd + +from meta import LionBot, LionCog, LionContext +from meta.app import shardname, appname +from meta.logger import log_wrap +from utils.lib import utc_now + +from .data import CoreData + +logger = logging.getLogger(__name__) + + +class keydefaultdict(defaultdict): + def __missing__(self, key): + if self.default_factory is None: + raise KeyError(key) + else: + ret = self[key] = self.default_factory(key) + return ret + + +class CoreCog(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + self.data = CoreData() + bot.db.load_registry(self.data) + + self.app_config: Optional[CoreData.AppConfig] = None + self.bot_config: Optional[CoreData.BotConfig] = None + + self.app_cmd_cache: list[discord.app_commands.AppCommand] = [] + self.cmd_name_cache: dict[str, discord.app_commands.AppCommand] = {} + self.mention_cache: dict[str, str] = keydefaultdict(self.mention_cmd) + + async def cog_load(self): + # Fetch (and possibly create) core data rows. + self.app_config = await self.data.AppConfig.fetch_or_create(appname) + self.bot_config = await self.data.BotConfig.fetch_or_create(appname) + + # Load the app command cache + await self.reload_appcmd_cache() + + async def reload_appcmd_cache(self): + for guildid in self.bot.testing_guilds: + self.app_cmd_cache += await self.bot.tree.fetch_commands(guild=discord.Object(guildid)) + self.app_cmd_cache += await self.bot.tree.fetch_commands() + self.cmd_name_cache = {cmd.name: cmd for cmd in self.app_cmd_cache} + self.mention_cache = self._mention_cache_from(self.app_cmd_cache) + + def _mention_cache_from(self, cmds: list[appcmd.AppCommand | appcmd.AppCommandGroup]): + cache = keydefaultdict(self.mention_cmd) + for cmd in cmds: + cache[cmd.qualified_name if isinstance(cmd, appcmd.AppCommandGroup) else cmd.name] = cmd.mention + subcommands = [option for option in cmd.options if isinstance(option, appcmd.AppCommandGroup)] + if subcommands: + subcache = self._mention_cache_from(subcommands) + cache |= subcache + return cache + + def mention_cmd(self, name: str): + """ + Create an application command mention for the given names. + + If not found in cache, creates a 'fake' mention with an invalid id. + """ + if name in self.mention_cache: + mention = self.mention_cache[name] + else: + mention = f"" + return mention diff --git a/src/core/data.py b/src/core/data.py new file mode 100644 index 0000000..7bb4276 --- /dev/null +++ b/src/core/data.py @@ -0,0 +1,45 @@ +from enum import Enum +from itertools import chain +from psycopg import sql +from cachetools import TTLCache +import discord + +from meta import conf +from meta.logger import log_wrap +from data import Table, Registry, Column, RowModel, RegisterEnum +from data.models import WeakCache +from data.columns import Integer, String, Bool, Timestamp + + +class CoreData(Registry, name="core"): + class AppConfig(RowModel): + """ + Schema + ------ + CREATE TABLE app_config( + appname TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """ + _tablename_ = 'app_config' + + appname = String(primary=True) + created_at = Timestamp() + + class BotConfig(RowModel): + """ + Schema + ------ + CREATE TABLE bot_config( + appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE, + sponsor_prompt TEXT, + sponsor_message TEXT, + default_skin TEXT + ); + """ + _tablename_ = 'bot_config' + + appname = String(primary=True) + default_skin = String() + sponsor_prompt = String() + sponsor_message = String() diff --git a/src/core/setting_types.py b/src/core/setting_types.py new file mode 100644 index 0000000..777b782 --- /dev/null +++ b/src/core/setting_types.py @@ -0,0 +1,227 @@ +""" +Additional abstract setting types useful for StudyLion settings. +""" +from typing import Optional +import json +import traceback + +import discord +from discord.enums import TextStyle + +from settings.base import ParentID +from settings.setting_types import IntegerSetting, StringSetting +from meta import conf +from meta.errors import UserInputError +from babel.translator import ctx_translator +from utils.lib import MessageArgs + +from . import babel + +_p = babel._p + + +class MessageSetting(StringSetting): + """ + Typed Setting ABC representing a message sent to Discord. + + Data is a json-formatted string dict with at least one of the fields 'content', 'embed', 'embeds' + Value is the corresponding dictionary + """ + # TODO: Extend to support format keys + + _accepts = _p( + 'settype:message|accepts', + "JSON formatted raw message data" + ) + + @staticmethod + async def download_attachment(attached: discord.Attachment): + """ + Download a discord.Attachment with some basic filetype and file size validation. + """ + t = ctx_translator.get().t + + error = None + decoded = None + if attached.content_type and not ('json' in attached.content_type): + error = t(_p( + 'settype:message|download|error:not_json', + "The attached message data is not a JSON file!" + )) + elif attached.size > 10000: + error = t(_p( + 'settype:message|download|error:size', + "The attached message data is too large!" + )) + else: + content = await attached.read() + try: + decoded = content.decode('UTF-8') + except UnicodeDecodeError: + error = t(_p( + 'settype:message|download|error:decoding', + "Could not decode the message data. Please ensure it is saved with the `UTF-8` encoding." + )) + + if error is not None: + raise UserInputError(error) + else: + return decoded + + @classmethod + def value_to_args(cls, parent_id: ParentID, value: dict, **kwargs) -> MessageArgs: + if not value: + return None + + args = {} + args['content'] = value.get('content', "") + if 'embed' in value: + embed = discord.Embed.from_dict(value['embed']) + args['embed'] = embed + if 'embeds' in value: + embeds = [] + for embed_data in value['embeds']: + embeds.append(discord.Embed.from_dict(embed_data)) + args['embeds'] = embeds + return MessageArgs(**args) + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value: Optional[dict], **kwargs): + if value and any(value.get(key, None) for key in ('content', 'embed', 'embeds')): + data = json.dumps(value) + else: + data = None + return data + + @classmethod + def _data_to_value(cls, parent_id: ParentID, data: Optional[str], **kwargs): + if data: + value = json.loads(data) + else: + value = None + return value + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Provided user string can be downright random. + + If it isn't json-formatted, treat it as the content of the message. + If it is, do basic checking on the length and embeds. + """ + string = string.strip() + if not string or string.lower() == 'none': + return None + + t = ctx_translator.get().t + + error_tip = t(_p( + 'settype:message|error_suffix', + "You can view, test, and fix your embed using the online [embed builder]({link})." + )).format( + link="https://glitchii.github.io/embedbuilder/?editor=json" + ) + + if string.startswith('{') and string.endswith('}'): + # Assume the string is a json-formatted message dict + try: + value = json.loads(string) + except json.JSONDecodeError as err: + error = t(_p( + 'settype:message|error:invalid_json', + "The provided message data was not a valid JSON document!\n" + "`{error}`" + )).format(error=str(err)) + raise UserInputError(error + '\n' + error_tip) + + if not isinstance(value, dict) or not any(value.get(key, None) for key in ('content', 'embed', 'embeds')): + error = t(_p( + 'settype:message|error:json_missing_keys', + "Message data must be a JSON object with at least one of the following fields: " + "`content`, `embed`, `embeds`" + )) + raise UserInputError(error + '\n' + error_tip) + + embed_data = value.get('embed', None) + if not isinstance(embed_data, dict): + error = t(_p( + 'settype:message|error:json_embed_type', + "`embed` field must be a valid JSON object." + )) + raise UserInputError(error + '\n' + error_tip) + + embeds_data = value.get('embeds', []) + if not isinstance(embeds_data, list): + error = t(_p( + 'settype:message|error:json_embeds_type', + "`embeds` field must be a list." + )) + raise UserInputError(error + '\n' + error_tip) + + if embed_data and embeds_data: + error = t(_p( + 'settype:message|error:json_embed_embeds', + "Message data cannot include both `embed` and `embeds`." + )) + raise UserInputError(error + '\n' + error_tip) + + content_data = value.get('content', "") + if not isinstance(content_data, str): + error = t(_p( + 'settype:message|error:json_content_type', + "`content` field must be a string." + )) + raise UserInputError(error + '\n' + error_tip) + + # Validate embeds, which is the most likely place for something to go wrong + embeds = [embed_data] if embed_data else embeds_data + try: + for embed in embeds: + discord.Embed.from_dict(embed) + except Exception as e: + # from_dict may raise a range of possible exceptions. + raw_error = ''.join( + traceback.TracebackException.from_exception(e).format_exception_only() + ) + error = t(_p( + 'ui:settype:message|error:embed_conversion', + "Could not parse the message embed data.\n" + "**Error:** `{exception}`" + )).format(exception=raw_error) + raise UserInputError(error + '\n' + error_tip) + + # At this point, the message will at least successfully convert into MessageArgs + # There are numerous ways it could still be invalid, e.g. invalid urls, or too-long fields + # or the total message content being too long, or too many fields, etc + # This will need to be caught in anything which displays a message parsed from user data. + else: + # Either the string is not json formatted, or the formatting is broken + # Assume the string is a content message + value = { + 'content': string + } + return json.dumps(value) + + @classmethod + def _format_data(cls, parent_id: ParentID, data: Optional[str], **kwargs): + if not data: + return None + + value = cls._data_to_value(parent_id, data, **kwargs) + content = value.get('content', "") + if 'embed' in value or 'embeds' in value or len(content) > 100: + t = ctx_translator.get().t + formatted = t(_p( + 'settype:message|format:too_long', + "Too long to display! See Preview." + )) + else: + formatted = content + + return formatted + + @property + def input_field(self): + field = super().input_field + field.style = TextStyle.long + return field diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py new file mode 100644 index 0000000..d31065e --- /dev/null +++ b/src/meta/LionBot.py @@ -0,0 +1,373 @@ +from typing import List, Literal, LiteralString, Optional, TYPE_CHECKING, overload +import logging +import asyncio +from weakref import WeakValueDictionary + +from constants import SCHEMA_VERSIONS +import discord +from discord.utils import MISSING +from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError +from discord.ext.commands.errors import CommandInvokeError, CheckFailure +from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError +from aiohttp import ClientSession + +from data import Database +from utils.lib import tabulate +from babel.translator import LeoBabel +from botdata import BotData, VersionHistory + +from .config import Conf +from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context +from .context import context +from .LionContext import LionContext +from .LionTree import LionTree +from .errors import HandledException, SafeCancellation +from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus + +if TYPE_CHECKING: + from core.cog import CoreCog + +logger = logging.getLogger(__name__) + + +class LionBot(Bot): + def __init__( + self, *args, appname: str, shardname: str, db: Database, config: Conf, + initial_extensions: List[str], web_client: ClientSession, + testing_guilds: List[int] = [], **kwargs + ): + kwargs.setdefault('tree_cls', LionTree) + super().__init__(*args, **kwargs) + self.web_client = web_client + self.testing_guilds = testing_guilds + self.initial_extensions = initial_extensions + self.db = db + self.appname = appname + self.shardname = shardname +# self.appdata = appdata + self.data: BotData = db.load_registry(BotData()) + self.config = config + self.translator = LeoBabel() + + self.system_monitor = SystemMonitor() + self.monitor = ComponentMonitor('LionBot', self._monitor_status) + self.system_monitor.add_component(self.monitor) + + self._locks = WeakValueDictionary() + self._running_events = set() + + @property + def dbconn(self): + return self.db + + @property + def core(self): + return self.get_cog('CoreCog') + + async def _monitor_status(self): + if self.is_closed(): + level = StatusLevel.ERRORED + info = "(ERROR) Websocket is closed" + data = {} + elif self.is_ws_ratelimited(): + level = StatusLevel.WAITING + info = "(WAITING) Websocket is ratelimited" + data = {} + elif not self.is_ready(): + level = StatusLevel.STARTING + info = "(STARTING) Not yet ready" + data = {} + else: + level = StatusLevel.OKAY + info = ( + "(OK) " + "Logged in with {guild_count} guilds, " + ", websocket latency {latency}, and {events} running events." + ) + data = { + 'guild_count': len(self.guilds), + 'latency': self.latency, + 'events': len(self._running_events), + } + return ComponentStatus(level, info, info, data) + + async def setup_hook(self) -> None: + log_context.set(f"APP: {self.application_id}") + + for extension in self.initial_extensions: + await self.load_extension(extension) + + for guildid in self.testing_guilds: + guild = discord.Object(guildid) + if not self.shard_count or (self.shard_id == ((guildid >> 22) % self.shard_count)): + self.tree.copy_global_to(guild=guild) + await self.tree.sync(guild=guild) + + # To make the type checker happy about fetching cogs by name + # TODO: Move this to stubs at some point + + @overload + def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog': + ... + + @overload + def get_cog(self, name: str) -> Optional[Cog]: + ... + + def get_cog(self, name: str) -> Optional[Cog]: + return super().get_cog(name) + + async def add_cog(self, cog: Cog, **kwargs): + sup = super() + @log_wrap(action=f"Attach {cog.__cog_name__}") + async def wrapper(): + logger.info(f"Attaching Cog {cog.__cog_name__}") + await sup.add_cog(cog, **kwargs) + logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.") + await wrapper() + + async def load_extension(self, name, *, package=None, **kwargs): + sup = super() + @log_wrap(action=f"Load {name.strip('.')}") + async def wrapper(): + logger.info(f"Loading extension {name} in package {package}.") + await sup.load_extension(name, package=package, **kwargs) + logger.debug(f"Loaded extension {name} in package {package}.") + await wrapper() + + async def start(self, token: str, *, reconnect: bool = True): + await self.data.init() + for component, req in SCHEMA_VERSIONS.items(): + await self.version_check(component, req) + + with logging_context(action="Login"): + start_task = asyncio.create_task(self.login(token)) + await start_task + + with logging_context(stack=("Running",)): + run_task = asyncio.create_task(self.connect(reconnect=reconnect)) + await run_task + + async def version_check(self, component: str, req_version: int): + # Query the database to confirm that the given component is listed with the given version. + # Typically done upon loading a component + rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1) + + version = rows[0].to_version if rows else 0 + + if version != req_version: + raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'") + else: + logger.debug( + "Component %s passed version check with version %s", + component, + version + ) + return True + + + def dispatch(self, event_name: str, *args, **kwargs): + with logging_context(action=f"Dispatch {event_name}"): + super().dispatch(event_name, *args, **kwargs) + + def _schedule_event(self, coro, event_name, *args, **kwargs): + """ + Extends client._schedule_event to keep a persistent + background task store. + """ + task = super()._schedule_event(coro, event_name, *args, **kwargs) + self._running_events.add(task) + task.add_done_callback(lambda fut: self._running_events.discard(fut)) + + def idlock(self, snowflakeid): + lock = self._locks.get(snowflakeid, None) + if lock is None: + lock = self._locks[snowflakeid] = asyncio.Lock() + return lock + + async def on_ready(self): + logger.info( + f"Logged in as {self.application.name}\n" + f"Application id {self.application.id}\n" + f"Shard Talk identifier {self.shardname}\n" + "------------------------------\n" + f"Enabled Modules: {', '.join(self.extensions.keys())}\n" + f"Loaded Cogs: {', '.join(self.cogs.keys())}\n" + f"Registered Data: {', '.join(self.db.registries.keys())}\n" + f"Listening for {sum(1 for _ in self.walk_commands())} commands\n" + "------------------------------\n" + f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n" + "Ready to take commands!\n", + extra={'action': 'Ready'} + ) + + async def get_context(self, origin, /, *, cls=MISSING): + if cls is MISSING: + cls = LionContext + ctx = await super().get_context(origin, cls=cls) + context.set(ctx) + return ctx + + async def on_command(self, ctx: LionContext): + logger.info( + f"Executing command '{ctx.command.qualified_name}' " + f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') " + f"with interaction: {ctx.interaction.data if ctx.interaction else None}", + extra={'with_ctx': True} + ) + + async def on_command_error(self, ctx, exception): + # TODO: Some of these could have more user-feedback + logger.debug(f"Handling command error for {ctx}: {exception}") + if isinstance(ctx.command, HybridCommand) and ctx.command.app_command: + cmd_str = ctx.command.app_command.to_dict(self.tree) + else: + cmd_str = str(ctx.command) + try: + raise exception + except (HybridCommandError, CommandInvokeError, appCommandInvokeError): + try: + if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)): + original = exception.original.original + raise original + else: + original = exception.original + raise original + except HandledException: + pass + except TransformerError as e: + msg = str(e) + if msg: + try: + await ctx.error_reply(msg) + except Exception: + pass + logger.debug( + f"Caught a transformer error: {repr(e)}", + extra={'action': 'BotError', 'with_ctx': True} + ) + except SafeCancellation: + if original.msg: + try: + await ctx.error_reply(original.msg) + except Exception: + pass + logger.debug( + f"Caught a safe cancellation: {original.details}", + extra={'action': 'BotError', 'with_ctx': True} + ) + except discord.Forbidden: + # Unknown uncaught Forbidden + try: + # Attempt a general error reply + await ctx.reply("I don't have enough channel or server permissions to complete that command here!") + except Exception: + # We can't send anything at all. Exit quietly, but log. + logger.warning( + f"Caught an unhandled 'Forbidden' while executing: {cmd_str}", + exc_info=True, + extra={'action': 'BotError', 'with_ctx': True} + ) + except discord.HTTPException: + logger.error( + f"Caught an unhandled 'HTTPException' while executing: {cmd_str}", + exc_info=True, + extra={'action': 'BotError', 'with_ctx': True} + ) + except asyncio.CancelledError: + pass + except asyncio.TimeoutError: + pass + except Exception as e: + logger.exception( + f"Caught an unknown CommandInvokeError while executing: {cmd_str}", + extra={'action': 'BotError', 'with_ctx': True} + ) + + error_embed = discord.Embed( + title="Something went wrong!", + colour=discord.Colour.dark_red() + ) + error_embed.description = ( + "An unexpected error occurred while processing your command!\n" + "Our development team has been notified, and the issue will be addressed soon.\n" + ) + details = {} + details['error'] = f"`{repr(e)}`" + if ctx.interaction: + details['interactionid'] = f"`{ctx.interaction.id}`" + if ctx.command: + details['cmd'] = f"`{ctx.command.qualified_name}`" + if ctx.author: + details['author'] = f"`{ctx.author.id}` -- `{ctx.author}`" + if ctx.guild: + details['guild'] = f"`{ctx.guild.id}` -- `{ctx.guild.name}`" + details['my_guild_perms'] = f"`{ctx.guild.me.guild_permissions.value}`" + if ctx.author: + ownerstr = ' (owner)' if ctx.author.id == ctx.guild.owner_id else '' + details['author_guild_perms'] = f"`{ctx.author.guild_permissions.value}{ownerstr}`" + if ctx.channel.type is discord.enums.ChannelType.private: + details['channel'] = "`Direct Message`" + elif ctx.channel: + details['channel'] = f"`{ctx.channel.id}` -- `{ctx.channel.name}`" + details['my_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.guild.me).value}`" + if ctx.author: + details['author_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.author).value}`" + details['shard'] = f"`{self.shardname}`" + details['log_stack'] = f"`{log_action_stack.get()}`" + + table = '\n'.join(tabulate(*details.items())) + error_embed.add_field(name='Details', value=table) + + try: + await ctx.error_reply(embed=error_embed) + except discord.HTTPException: + pass + finally: + exception.original = HandledException(exception.original) + except CheckFailure as e: + logger.debug( + f"Command failed check: {e}: {e.args}", + extra={'action': 'BotError', 'with_ctx': True} + ) + try: + await ctx.error_reply(str(e)) + except discord.HTTPException: + pass + except Exception: + # Completely unknown exception outside of command invocation! + # Something is very wrong here, don't attempt user interaction. + logger.exception( + f"Caught an unknown top-level exception while executing: {cmd_str}", + extra={'action': 'BotError', 'with_ctx': True} + ) + + def add_command(self, command): + if not hasattr(command, '_placeholder_group_'): + super().add_command(command) + + def request_chunking_for(self, guild): + if not guild.chunked: + return asyncio.create_task( + self._connection.chunk_guild(guild, wait=False, cache=True), + name=f"Background chunkreq for {guild.id}" + ) + + async def on_interaction(self, interaction: discord.Interaction): + """ + Adds the interaction author to guild cache if appropriate. + + This gets run a little bit late, so it is possible the interaction gets handled + without the author being in case. + """ + guild = interaction.guild + user = interaction.user + if guild is not None and user is not None and isinstance(user, discord.Member): + if not guild.get_member(user.id): + guild._add_member(user) + if guild is not None and not guild.chunked: + # Getting an interaction in the guild is a good enough reason to request chunking + logger.info( + f"Unchunked guild requesting chunking after interaction." + ) + self.request_chunking_for(guild) diff --git a/src/meta/LionCog.py b/src/meta/LionCog.py new file mode 100644 index 0000000..39ca43a --- /dev/null +++ b/src/meta/LionCog.py @@ -0,0 +1,58 @@ +from typing import Any + +from discord.ext.commands import Cog +from discord.ext import commands as cmds + + +class LionCog(Cog): + # A set of other cogs that this cog depends on + depends_on: set['LionCog'] = set() + _placeholder_groups_: set[str] + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + cls._placeholder_groups_ = set() + + for base in reversed(cls.__mro__): + for elem, value in base.__dict__.items(): + if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'): + cls._placeholder_groups_.add(value.name) + + def __new__(cls, *args: Any, **kwargs: Any): + # Patch to ensure no placeholder groups are in the command list + self = super().__new__(cls) + self.__cog_commands__ = [ + command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_ + ] + return self + + async def _inject(self, bot, *args, **kwargs): + if self.depends_on: + not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)} + raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}") + + return await super()._inject(bot, *args, *kwargs) + + @classmethod + def placeholder_group(cls, group: cmds.HybridGroup): + group._placeholder_group_ = True + return group + + def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup): + """ + Crossload a placeholder group's commands into the target group + """ + if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup): + raise ValueError("Placeholder and target groups my be HypridGroups.") + if placeholder_group.name not in self._placeholder_groups_: + raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.") + if target_group.app_command is None: + raise ValueError("Target group has no app_command to crossload into.") + + for command in placeholder_group.commands: + placeholder_group.remove_command(command.name) + target_group.remove_command(command.name) + acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self) + command.app_command = acmd + target_group.add_command(command) diff --git a/src/meta/LionContext.py b/src/meta/LionContext.py new file mode 100644 index 0000000..e1b21d8 --- /dev/null +++ b/src/meta/LionContext.py @@ -0,0 +1,195 @@ +import types +import logging +from collections import namedtuple +from typing import Optional, TYPE_CHECKING + +import discord +from discord.enums import ChannelType +from discord.ext.commands import Context + +if TYPE_CHECKING: + from .LionBot import LionBot + + +logger = logging.getLogger(__name__) + + +""" +Stuff that might be useful to implement (see cmdClient): + sent_messages cache + tasks cache + error reply + usage + interaction cache + View cache? + setting access +""" + + +FlatContext = namedtuple( + 'FlatContext', + ('message', + 'interaction', + 'guild', + 'author', + 'channel', + 'alias', + 'prefix', + 'failed') +) + + +class LionContext(Context['LionBot']): + """ + Represents the context a command is invoked under. + + Extends Context to add Lion-specific methods and attributes. + Also adds several contextual wrapped utilities for simpler user during command invocation. + """ + + def __repr__(self): + parts = {} + if self.interaction is not None: + parts['iid'] = self.interaction.id + parts['itype'] = f"\"{self.interaction.type.name}\"" + if self.message is not None: + parts['mid'] = self.message.id + if self.author is not None: + parts['uid'] = self.author.id + parts['uname'] = f"\"{self.author.name}\"" + if self.channel is not None: + parts['cid'] = self.channel.id + if self.channel.type is ChannelType.private: + parts['cname'] = f"\"{self.channel.recipient}\"" + else: + parts['cname'] = f"\"{self.channel.name}\"" + if self.guild is not None: + parts['gid'] = self.guild.id + parts['gname'] = f"\"{self.guild.name}\"" + if self.command is not None: + parts['cmd'] = f"\"{self.command.qualified_name}\"" + if self.invoked_with is not None: + parts['alias'] = f"\"{self.invoked_with}\"" + if self.command_failed: + parts['failed'] = self.command_failed + + return "".format( + ' '.join(f"{name}={value}" for name, value in parts.items()) + ) + + def flatten(self): + """Flat pure-data context information, for caching and logging.""" + return FlatContext( + self.message.id, + self.interaction.id if self.interaction is not None else None, + self.guild.id if self.guild is not None else None, + self.author.id if self.author is not None else None, + self.channel.id if self.channel is not None else None, + self.invoked_with, + self.prefix, + self.command_failed + ) + + @classmethod + def util(cls, util_func): + """ + Decorator to make a utility function available as a Context instance method. + """ + setattr(cls, util_func.__name__, util_func) + logger.debug(f"Attached context utility function: {util_func.__name__}") + return util_func + + @classmethod + def wrappable_util(cls, util_func): + """ + Decorator to add a Wrappable utility function as a Context instance method. + """ + wrapped = Wrappable(util_func) + setattr(cls, util_func.__name__, wrapped) + logger.debug(f"Attached wrappable context utility function: {util_func.__name__}") + return wrapped + + async def error_reply(self, content: Optional[str] = None, **kwargs): + if content and 'embed' not in kwargs: + embed = discord.Embed( + colour=discord.Colour.red(), + description=content + ) + kwargs['embed'] = embed + content = None + + # Expect this may be run in highly unusual circumstances. + # This should never error, or at least handle all errors. + if self.interaction: + kwargs.setdefault('ephemeral', True) + try: + await self.reply(content=content, **kwargs) + except discord.HTTPException: + pass + except Exception: + logger.exception( + "Unknown exception in 'error_reply'.", + extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True} + ) + + +class Wrappable: + __slots__ = ('_func', 'wrappers') + + def __init__(self, func): + self._func = func + self.wrappers = None + + @property + def __name__(self): + return self._func.__name__ + + def add_wrapper(self, func, name=None): + self.wrappers = self.wrappers or {} + name = name or func.__name__ + self.wrappers[name] = func + logger.debug( + f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.", + extra={'action': "Wrap Util"} + ) + + def remove_wrapper(self, name): + if not self.wrappers or name not in self.wrappers: + raise ValueError( + f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'" + ) + self.wrappers.pop(name) + logger.debug( + f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.", + extra={'action': "Unwrap Util"} + ) + + def __call__(self, *args, **kwargs): + if self.wrappers: + return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs) + else: + return self._func(*args, **kwargs) + + def _wrapped(self, iter_wraps): + next_wrap = next(iter_wraps, None) + if next_wrap: + def _func(*args, **kwargs): + return next_wrap(self._wrapped(iter_wraps), *args, **kwargs) + else: + _func = self._func + return _func + + def __get__(self, instance, cls=None): + if instance is None: + return self + else: + return types.MethodType(self, instance) + + +LionContext.reply = Wrappable(LionContext.reply) + + +# @LionContext.reply.add_wrapper +# async def think(func, ctx, *args, **kwargs): +# await ctx.channel.send("thinking") +# await func(ctx, *args, **kwargs) diff --git a/src/meta/LionTree.py b/src/meta/LionTree.py new file mode 100644 index 0000000..a0697f0 --- /dev/null +++ b/src/meta/LionTree.py @@ -0,0 +1,148 @@ +import logging + +import discord +from discord import Interaction +from discord.app_commands import CommandTree +from discord.app_commands.errors import AppCommandError, CommandInvokeError +from discord.enums import InteractionType +from discord.app_commands.namespace import Namespace + +from utils.lib import tabulate + +from .logger import logging_context, set_logging_context, log_wrap, log_action_stack +from .errors import SafeCancellation +from .config import conf + +logger = logging.getLogger(__name__) + + +class LionTree(CommandTree): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._call_tasks = set() + + async def on_error(self, interaction: discord.Interaction, error) -> None: + try: + if isinstance(error, CommandInvokeError): + raise error.original + else: + raise error + except SafeCancellation: + # Assume this has already been handled + pass + except Exception: + logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'}) + if interaction.type is not InteractionType.autocomplete: + embed = self.bugsplat(interaction, error) + await self.error_reply(interaction, embed) + + async def error_reply(self, interaction, embed): + if not interaction.is_expired(): + try: + if interaction.response.is_done(): + await interaction.followup.send(embed=embed, ephemeral=True) + else: + await interaction.response.send_message(embed=embed, ephemeral=True) + except discord.HTTPException: + pass + + def bugsplat(self, interaction, e): + error_embed = discord.Embed(title="Something went wrong!", colour=discord.Colour.red()) + error_embed.description = ( + "An unexpected error occurred during this interaction!\n" + "Our development team has been notified, and the issue will be addressed soon.\n" + ) + details = {} + details['error'] = f"`{repr(e)}`" + details['interactionid'] = f"`{interaction.id}`" + details['interactiontype'] = f"`{interaction.type}`" + if interaction.command: + details['cmd'] = f"`{interaction.command.qualified_name}`" + if interaction.user: + details['user'] = f"`{interaction.user.id}` -- `{interaction.user}`" + if interaction.guild: + details['guild'] = f"`{interaction.guild.id}` -- `{interaction.guild.name}`" + details['my_guild_perms'] = f"`{interaction.guild.me.guild_permissions.value}`" + if interaction.user: + ownerstr = ' (owner)' if interaction.user.id == interaction.guild.owner_id else '' + details['user_guild_perms'] = f"`{interaction.user.guild_permissions.value}{ownerstr}`" + if interaction.channel.type is discord.enums.ChannelType.private: + details['channel'] = "`Direct Message`" + elif interaction.channel: + details['channel'] = f"`{interaction.channel.id}` -- `{interaction.channel.name}`" + details['my_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.guild.me).value}`" + if interaction.user: + details['user_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.user).value}`" + details['shard'] = f"`{interaction.client.shardname}`" + details['log_stack'] = f"`{log_action_stack.get()}`" + + table = '\n'.join(tabulate(*details.items())) + error_embed.add_field(name='Details', value=table) + return error_embed + + def _from_interaction(self, interaction: Interaction) -> None: + @log_wrap(context=f"iid: {interaction.id}", isolate=False) + async def wrapper(): + try: + await self._call(interaction) + except AppCommandError as e: + await self._dispatch_error(interaction, e) + + task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker') + self._call_tasks.add(task) + task.add_done_callback(lambda fut: self._call_tasks.discard(fut)) + + async def _call(self, interaction): + if not await self.interaction_check(interaction): + interaction.command_failed = True + return + + data = interaction.data # type: ignore + type = data.get('type', 1) + if type != 1: + # Context menu command... + await self._call_context_menu(interaction, data, type) + return + + command, options = self._get_app_command_options(data) + + # Pre-fill the cached slot to prevent re-computation + interaction._cs_command = command + + # At this point options refers to the arguments of the command + # and command refers to the class type we care about + namespace = Namespace(interaction, data.get('resolved', {}), options) + + # Same pre-fill as above + interaction._cs_namespace = namespace + + # Auto complete handles the namespace differently... so at this point this is where we decide where that is. + if interaction.type is InteractionType.autocomplete: + set_logging_context(action=f"Acmp {command.qualified_name}") + focused = next((opt['name'] for opt in options if opt.get('focused')), None) + if focused is None: + raise AppCommandError( + 'This should not happen, but there is no focused element. This is a Discord bug.' + ) + try: + await command._invoke_autocomplete(interaction, focused, namespace) + except Exception as e: + await self.on_error(interaction, e) + return + + set_logging_context(action=f"Run {command.qualified_name}") + logger.debug(f"Running command '{command.qualified_name}': {command.to_dict(self)}") + try: + await command._invoke_with_namespace(interaction, namespace) + except AppCommandError as e: + interaction.command_failed = True + await command._invoke_error_handlers(interaction, e) + await self.on_error(interaction, e) + else: + if not interaction.command_failed: + self.client.dispatch('app_command_completion', interaction, command) + finally: + if interaction.command_failed: + logger.debug("Command completed with errors.") + else: + logger.debug("Command completed without errors.") diff --git a/src/meta/__init__.py b/src/meta/__init__.py new file mode 100644 index 0000000..5f68fe3 --- /dev/null +++ b/src/meta/__init__.py @@ -0,0 +1,15 @@ +from .LionBot import LionBot +from .LionCog import LionCog +from .LionContext import LionContext +from .LionTree import LionTree + +from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app +from .config import conf, configEmoji +from .args import args +from .app import appname, appname_from_shard, shard_from_appname +from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled +from .context import context, ctx_bot + +from . import sharding +from . import logger +from . import app diff --git a/src/meta/app.py b/src/meta/app.py new file mode 100644 index 0000000..9f0c9a2 --- /dev/null +++ b/src/meta/app.py @@ -0,0 +1,32 @@ +""" +appname: str + The base identifer for this application. + This identifies which services the app offers. +shardname: str + The specific name of the running application. + Only one process should be connecteded with a given appname. + For the bot apps, usually specifies the shard id and shard number. +""" +# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data? + +from . import sharding, conf +from .logger import log_app +from .args import args + + +appname = conf.data['appid'] +appid = appname # backwards compatibility + + +def appname_from_shard(shardid): + appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}" + return appname + + +def shard_from_appname(appname: str): + return int(appname.rsplit('_', maxsplit=1)[-1]) + + +shardname = appname_from_shard(sharding.shard_number) + +log_app.set(shardname) diff --git a/src/meta/args.py b/src/meta/args.py new file mode 100644 index 0000000..8d82a69 --- /dev/null +++ b/src/meta/args.py @@ -0,0 +1,35 @@ +import argparse + +from constants import CONFIG_FILE + +# ------------------------------ +# Parsed commandline arguments +# ------------------------------ +parser = argparse.ArgumentParser() +parser.add_argument( + '--conf', + dest='config', + default=CONFIG_FILE, + help="Path to configuration file." +) +parser.add_argument( + '--shard', + dest='shard', + default=None, + type=int, + help="Shard number to run, if applicable." +) +parser.add_argument( + '--host', + dest='host', + default='127.0.0.1', + help="IP address to run the app listener on." +) +parser.add_argument( + '--port', + dest='port', + default='5001', + help="Port to run the app listener on." +) + +args = parser.parse_args() diff --git a/src/meta/config.py b/src/meta/config.py new file mode 100644 index 0000000..9e624df --- /dev/null +++ b/src/meta/config.py @@ -0,0 +1,146 @@ +from discord import PartialEmoji +import configparser as cfgp + +from .args import args + +shard_number = args.shard + +class configEmoji(PartialEmoji): + __slots__ = ('fallback',) + + def __init__(self, *args, fallback=None, **kwargs): + super().__init__(*args, **kwargs) + self.fallback = fallback + + @classmethod + def from_str(cls, emojistr: str): + """ + Parses emoji strings of one of the following forms + ` or fallback` + `<:name:id> or fallback` + `` + `<:name:id>` + """ + splits = emojistr.rsplit(' or ', maxsplit=1) + + fallback = splits[1] if len(splits) > 1 else None + emojistr = splits[0].strip('<> ') + animated, name, id = emojistr.split(':') + return cls( + name=name, + fallback=PartialEmoji(name=fallback) if fallback is not None else None, + animated=bool(animated), + id=int(id) if id else None + ) + + +class MapDotProxy: + """ + Allows dot access to an underlying Mappable object. + """ + __slots__ = ("_map", "_converter") + + def __init__(self, mappable, converter=None): + self._map = mappable + self._converter = converter + + def __getattribute__(self, key): + _map = object.__getattribute__(self, '_map') + if key == '_map': + return _map + if key in _map: + _converter = object.__getattribute__(self, '_converter') + if _converter: + return _converter(_map[key]) + else: + return _map[key] + else: + return object.__getattribute__(_map, key) + + def __getitem__(self, key): + return self._map.__getitem__(key) + + +class ConfigParser(cfgp.ConfigParser): + """ + Extension of base ConfigParser allowing optional + section option retrieval without defaults. + """ + def options(self, section, no_defaults=False, **kwargs): + if no_defaults: + try: + return list(self._sections[section].keys()) + except KeyError: + raise cfgp.NoSectionError(section) + else: + return super().options(section, **kwargs) + + +class Conf: + def __init__(self, configfile, section_name="DEFAULT"): + self.configfile = configfile + + self.config = ConfigParser( + converters={ + "intlist": self._getintlist, + "list": self._getlist, + "emoji": configEmoji.from_str, + } + ) + + with open(configfile) as conff: + # Opening with read_file mainly to ensure the file exists + self.config.read_file(conff) + + self.section_name = section_name if section_name in self.config else 'DEFAULT' + + self.default = self.config["DEFAULT"] + self.section = MapDotProxy(self.config[self.section_name]) + self.bot = self.section + + # Config file recursion, read in configuration files specified in every "ALSO_READ" key. + more_to_read = self.section.getlist("ALSO_READ", []) + read = set() + while more_to_read: + to_read = more_to_read.pop(0) + read.add(to_read) + self.config.read(to_read) + new_paths = [path for path in self.section.getlist("ALSO_READ", []) + if path not in read and path not in more_to_read] + more_to_read.extend(new_paths) + + self.emojis = MapDotProxy( + self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section, + converter=configEmoji.from_str + ) + + global conf + conf = self + + def __getitem__(self, key): + return self.section[key].strip() + + def __getattr__(self, section): + name = section.upper() + shard_name = f"{name}-{shard_number}" + if shard_name in self.config: + return self.config[shard_name] + else: + return self.config[name] + + def get(self, name, fallback=None): + result = self.section.get(name, fallback) + return result.strip() if result else result + + def _getintlist(self, value): + return [int(item.strip()) for item in value.split(',')] + + def _getlist(self, value): + return [item.strip() for item in value.split(',')] + + def write(self): + with open(self.configfile, 'w') as conffile: + self.config.write(conffile) + + +conf = Conf(args.config, 'BOT') diff --git a/src/meta/context.py b/src/meta/context.py new file mode 100644 index 0000000..75f1df2 --- /dev/null +++ b/src/meta/context.py @@ -0,0 +1,20 @@ +""" +Namespace for various global context variables. +Allows asyncio callbacks to accurately retrieve information about the current state. +""" + + +from typing import TYPE_CHECKING, Optional + +from contextvars import ContextVar + +if TYPE_CHECKING: + from .LionBot import LionBot + from .LionContext import LionContext + + +# Contains the current command context, if applicable +context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None) + +# Contains the current LionBot instance +ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None) diff --git a/src/meta/errors.py b/src/meta/errors.py new file mode 100644 index 0000000..a5d6cbf --- /dev/null +++ b/src/meta/errors.py @@ -0,0 +1,64 @@ +from typing import Optional +from string import Template + + +class SafeCancellation(Exception): + """ + Raised to safely cancel execution of the current operation. + + If not caught, is expected to be propagated to the Tree and safely ignored there. + If a `msg` is provided, a context-aware error handler should catch and send the message to the user. + The error handler should then set the `msg` to None, to avoid double handling. + Debugging information should go in `details`, to be logged by a top-level error handler. + """ + default_message = "" + + @property + def msg(self): + return self._msg if self._msg is not None else self.default_message + + def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs): + self._msg: Optional[str] = _msg + self.details: str = details if details is not None else self.msg + super().__init__(**kwargs) + + +class UserInputError(SafeCancellation): + """ + A SafeCancellation induced from unparseable user input. + """ + default_message = "Could not understand your input." + + @property + def msg(self): + return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message + + def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs): + self.info = info + super().__init__(_msg, **kwargs) + + +class UserCancelled(SafeCancellation): + """ + A SafeCancellation induced from manual user cancellation. + + Usually silent. + """ + default_msg = None + + +class ResponseTimedOut(SafeCancellation): + """ + A SafeCancellation induced from a user interaction time-out. + """ + default_msg = "Session timed out waiting for input." + + +class HandledException(SafeCancellation): + """ + Sentinel class to indicate to error handlers that this exception has been handled. + Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler. + """ + def __init__(self, exc=None, **kwargs): + self.exc = exc + super().__init__(**kwargs) diff --git a/src/meta/logger.py b/src/meta/logger.py new file mode 100644 index 0000000..ffa97f7 --- /dev/null +++ b/src/meta/logger.py @@ -0,0 +1,468 @@ +import sys +import logging +import asyncio +from typing import List, Optional +from logging.handlers import QueueListener, QueueHandler +import queue +import multiprocessing +from contextlib import contextmanager +from io import StringIO +from functools import wraps +from contextvars import ContextVar + +import discord +from discord import Webhook, File +import aiohttp + +from .config import conf +from . import sharding +from .context import context +from utils.lib import utc_now +from utils.ratelimits import Bucket, BucketOverFull, BucketFull + + +log_logger = logging.getLogger(__name__) +log_logger.propagate = False + + +log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') +log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=()) +log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number)) + +def set_logging_context( + context: Optional[str] = None, + action: Optional[str] = None, + stack: Optional[tuple[str, ...]] = None +): + """ + Statically set the logging context variables to the given values. + + If `action` is given, pushes it onto the `log_action_stack`. + """ + if context is not None: + log_context.set(context) + if action is not None or stack is not None: + astack = log_action_stack.get() + newstack = stack if stack is not None else astack + if action is not None: + newstack = (*newstack, action) + log_action_stack.set(newstack) + + +@contextmanager +def logging_context(context=None, action=None, stack=None): + """ + Context manager for executing a block of code in a given logging context. + + This context manager should only be used around synchronous code. + This is because async code *may* get cancelled or externally garbage collected, + in which case the finally block will be executed in the wrong context. + See https://github.com/python/cpython/issues/93740 + This can be refactored nicely if this gets merged: + https://github.com/python/cpython/pull/99634 + + (It will not necessarily break on async code, + if the async code can be guaranteed to clean up in its own context.) + """ + if context is not None: + oldcontext = log_context.get() + log_context.set(context) + if action is not None or stack is not None: + astack = log_action_stack.get() + newstack = stack if stack is not None else astack + if action is not None: + newstack = (*newstack, action) + log_action_stack.set(newstack) + try: + yield + finally: + if context is not None: + log_context.set(oldcontext) + if stack is not None or action is not None: + log_action_stack.set(astack) + + +def with_log_ctx(isolate=True, **kwargs): + """ + Execute a coroutine inside a given logging context. + + If `isolate` is true, ensures that context does not leak + outside the coroutine. + + If `isolate` is false, just statically set the context, + which will leak unless the coroutine is + called in an externally copied context. + """ + def decorator(func): + @wraps(func) + async def wrapped(*w_args, **w_kwargs): + if isolate: + with logging_context(**kwargs): + # Task creation will synchronously copy the context + # This is gc safe + name = kwargs.get('action', f"log-wrapped-{func.__name__}") + task = asyncio.create_task(func(*w_args, **w_kwargs), name=name) + return await task + else: + # This will leak context changes + set_logging_context(**kwargs) + return await func(*w_args, **w_kwargs) + return wrapped + return decorator + + +# For backwards compatibility +log_wrap = with_log_ctx + + +def persist_task(task_collection: set): + """ + Coroutine decorator that ensures the coroutine is scheduled as a task + and added to the given task_collection for strong reference + when it is called. + + This is just a hack to handle discord.py events potentially + being unexpectedly garbage collected. + + Since this also implicitly schedules the coroutine as a task when it is called, + the coroutine will also be run inside an isolated context. + """ + def decorator(coro): + @wraps(coro) + async def wrapped(*w_args, **w_kwargs): + name = f"persisted-{coro.__name__}" + task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name) + task_collection.add(task) + task.add_done_callback(lambda f: task_collection.discard(f)) + await task + + +RESET_SEQ = "\033[0m" +COLOR_SEQ = "\033[3%dm" +BOLD_SEQ = "\033[1m" +"]]]" +BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) + + +def colour_escape(fmt: str) -> str: + cmap = { + '%(black)': COLOR_SEQ % BLACK, + '%(red)': COLOR_SEQ % RED, + '%(green)': COLOR_SEQ % GREEN, + '%(yellow)': COLOR_SEQ % YELLOW, + '%(blue)': COLOR_SEQ % BLUE, + '%(magenta)': COLOR_SEQ % MAGENTA, + '%(cyan)': COLOR_SEQ % CYAN, + '%(white)': COLOR_SEQ % WHITE, + '%(reset)': RESET_SEQ, + '%(bold)': BOLD_SEQ, + } + for key, value in cmap.items(): + fmt = fmt.replace(key, value) + return fmt + + +log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' + + '%(cyan)%(app)-15s%(reset)|' + + '%(cyan)%(context)-24s%(reset)|' + + '%(cyan)%(actionstr)-22s%(reset)|' + + ' %(bold)%(cyan)%(name)s:%(reset)' + + ' %(white)%(message)s%(ctxstr)s%(reset)') +log_format = colour_escape(log_format) + + +# Setup the logger +logger = logging.getLogger() +log_fmt = logging.Formatter( + fmt=log_format, + # datefmt='%Y-%m-%d %H:%M:%S' +) +logger.setLevel(logging.NOTSET) + + +class LessThanFilter(logging.Filter): + def __init__(self, exclusive_maximum, name=""): + super(LessThanFilter, self).__init__(name) + self.max_level = exclusive_maximum + + def filter(self, record): + # non-zero return means we log this message + return 1 if record.levelno < self.max_level else 0 + +class ExactLevelFilter(logging.Filter): + def __init__(self, target_level, name=""): + super().__init__(name) + self.target_level = target_level + + def filter(self, record): + return (record.levelno == self.target_level) + + +class ThreadFilter(logging.Filter): + def __init__(self, thread_name): + super().__init__("") + self.thread = thread_name + + def filter(self, record): + # non-zero return means we log this message + return 1 if record.threadName == self.thread else 0 + + +class ContextInjection(logging.Filter): + def filter(self, record): + # These guards are to allow override through _extra + # And to ensure the injection is idempotent + if not hasattr(record, 'context'): + record.context = log_context.get() + + if not hasattr(record, 'actionstr'): + action_stack = log_action_stack.get() + if hasattr(record, 'action'): + action_stack = (*action_stack, record.action) + if action_stack: + record.actionstr = ' ➔ '.join(action_stack) + else: + record.actionstr = "Unknown Action" + + if not hasattr(record, 'app'): + record.app = log_app.get() + + if not hasattr(record, 'ctx'): + if ctx := context.get(): + record.ctx = repr(ctx) + else: + record.ctx = None + + if getattr(record, 'with_ctx', False) and record.ctx: + record.ctxstr = '\n' + record.ctx + else: + record.ctxstr = "" + return True + + +logging_handler_out = logging.StreamHandler(sys.stdout) +logging_handler_out.setLevel(logging.DEBUG) +logging_handler_out.setFormatter(log_fmt) +logging_handler_out.addFilter(ContextInjection()) +logger.addHandler(logging_handler_out) +log_logger.addHandler(logging_handler_out) + +logging_handler_err = logging.StreamHandler(sys.stderr) +logging_handler_err.setLevel(logging.WARNING) +logging_handler_err.setFormatter(log_fmt) +logging_handler_err.addFilter(ContextInjection()) +logger.addHandler(logging_handler_err) +log_logger.addHandler(logging_handler_err) + + +class LocalQueueHandler(QueueHandler): + def _emit(self, record: logging.LogRecord) -> None: + # Removed the call to self.prepare(), handle task cancellation + try: + self.enqueue(record) + except asyncio.CancelledError: + raise + except Exception: + self.handleError(record) + + +class WebHookHandler(logging.StreamHandler): + def __init__(self, webhook_url, prefix="", batch=True, loop=None): + super().__init__() + self.webhook_url = webhook_url + self.prefix = prefix + self.batched = "" + self.batch = batch + self.loop = loop + self.batch_delay = 10 + self.batch_task = None + self.last_batched = None + self.waiting = [] + + self.bucket = Bucket(20, 40) + self.ignored = 0 + + self.session = None + self.webhook = None + + def get_loop(self): + if self.loop is None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + return self.loop + + def emit(self, record): + self.format(record) + self.get_loop().call_soon_threadsafe(self._post, record) + + def _post(self, record): + if self.session is None: + self.setup() + asyncio.create_task(self.post(record)) + + def setup(self): + self.session = aiohttp.ClientSession() + self.webhook = Webhook.from_url(self.webhook_url, session=self.session) + + async def post(self, record): + if record.context == 'Webhook Logger': + # Don't livelog livelog errors + # Otherwise we recurse and Cloudflare hates us + return + log_context.set("Webhook Logger") + log_action_stack.set(("Logging",)) + log_app.set(record.app) + + try: + timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S") + header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>" + context = f"\n# Context: {record.ctx}" if record.ctx else "" + message = f"{header}\n{record.msg}{context}" + + if len(message) > 1900: + as_file = True + else: + as_file = False + message = "```md\n{}\n```".format(message) + + # Post the log message(s) + if self.batch: + if len(message) > 1500: + await self._send_batched_now() + await self._send(message, as_file=as_file) + else: + self.batched += message + if len(self.batched) + len(message) > 1500: + await self._send_batched_now() + else: + asyncio.create_task(self._schedule_batched()) + else: + await self._send(message, as_file=as_file) + except Exception as ex: + print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr) + + async def _schedule_batched(self): + if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()): + # noop, don't reschedule if it is already scheduled + return + try: + self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay)) + await self.batch_task + await self._send_batched() + except asyncio.CancelledError: + return + except Exception as ex: + print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr) + + async def _send_batched_now(self): + if self.batch_task is not None and not self.batch_task.done(): + self.batch_task.cancel() + self.last_batched = None + await self._send_batched() + + async def _send_batched(self): + if self.batched: + batched = self.batched + self.batched = "" + await self._send(batched) + + async def _send(self, message, as_file=False): + try: + self.bucket.request() + except BucketOverFull: + # Silently ignore + self.ignored += 1 + return + except BucketFull: + logger.warning( + "Can't keep up! " + f"Ignoring records on live-logger {self.webhook.id}." + ) + self.ignored += 1 + return + else: + if self.ignored > 0: + logger.warning( + "Can't keep up! " + f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing." + ) + self.ignored = 0 + + try: + if as_file or len(message) > 1900: + with StringIO(message) as fp: + fp.seek(0) + await self.webhook.send( + f"{self.prefix}\n`{message.splitlines()[0]}`", + file=File(fp, filename="logs.md"), + username=log_app.get() + ) + else: + await self.webhook.send(self.prefix + '\n' + message, username=log_app.get()) + except discord.HTTPException: + logger.exception( + "Live logger errored. Slowing down live logger." + ) + self.bucket.fill() + + +handlers = [] +if webhook := conf.logging['general_log']: + handler = WebHookHandler(webhook, batch=True) + handlers.append(handler) + +if webhook := conf.logging['warning_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True) + handler.addFilter(ExactLevelFilter(logging.WARNING)) + handler.setLevel(logging.WARNING) + handlers.append(handler) + +if webhook := conf.logging['error_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True) + handler.setLevel(logging.ERROR) + handlers.append(handler) + +if webhook := conf.logging['critical_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False) + handler.setLevel(logging.CRITICAL) + handlers.append(handler) + + +def make_queue_handler(queue): + qhandler = QueueHandler(queue) + qhandler.setLevel(logging.INFO) + qhandler.addFilter(ContextInjection()) + return qhandler + + +def setup_main_logger(multiprocess=False): + q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue() + if handlers: + # First create a separate loop to run the handlers on + import threading + + def run_loop(loop): + asyncio.set_event_loop(loop) + try: + loop.run_forever() + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + loop = asyncio.new_event_loop() + loop_thread = threading.Thread(target=lambda: run_loop(loop)) + loop_thread.daemon = True + loop_thread.start() + + for handler in handlers: + handler.loop = loop + + qhandler = make_queue_handler(q) + # qhandler.addFilter(ThreadFilter('MainThread')) + logger.addHandler(qhandler) + + listener = QueueListener( + q, *handlers, respect_handler_level=True + ) + listener.start() + return q diff --git a/src/meta/monitor.py b/src/meta/monitor.py new file mode 100644 index 0000000..474c51f --- /dev/null +++ b/src/meta/monitor.py @@ -0,0 +1,139 @@ +import logging +import asyncio +from enum import IntEnum +from collections import deque, ChainMap +import datetime as dt + +logger = logging.getLogger(__name__) + + +class StatusLevel(IntEnum): + ERRORED = -2 + UNSURE = -1 + WAITING = 0 + STARTING = 1 + OKAY = 2 + + @property + def symbol(self): + return symbols[self] + + +symbols = { + StatusLevel.ERRORED: '🟥', + StatusLevel.UNSURE: '🟧', + StatusLevel.WAITING: '⬜', + StatusLevel.STARTING: '🟫', + StatusLevel.OKAY: '🟩', +} + + +class ComponentStatus: + def __init__(self, level: StatusLevel, short_formatstr: str, long_formatstr: str, data: dict = {}): + self.level = level + self.short_formatstr = short_formatstr + self.long_formatstr = long_formatstr + self.data = data + self.created_at = dt.datetime.now(tz=dt.timezone.utc) + + def format_args(self): + extra = { + 'created_at': self.created_at, + 'level': self.level, + 'symbol': self.level.symbol, + } + return ChainMap(extra, self.data) + + @property + def short(self): + return self.short_formatstr.format(**self.format_args()) + + @property + def long(self): + return self.long_formatstr.format(**self.format_args()) + + +class ComponentMonitor: + _name = None + + def __init__(self, name=None, callback=None): + self._callback = callback + self.name = name or self._name + if not self.name: + raise ValueError("ComponentMonitor must have a name") + + async def _make_status(self, *args, **kwargs): + if self._callback is not None: + return await self._callback(*args, **kwargs) + else: + raise NotImplementedError + + async def status(self) -> ComponentStatus: + try: + status = await self._make_status() + except Exception as e: + logger.exception( + f"Status callback for component '{self.name}' failed. This should not happen." + ) + status = ComponentStatus( + level=StatusLevel.UNSURE, + short_formatstr="Status callback for '{name}' failed with error '{error}'", + long_formatstr="Status callback for '{name}' failed with error '{error}'", + data={ + 'name': self.name, + 'error': repr(e) + } + ) + return status + + +class SystemMonitor: + def __init__(self): + self.components = {} + self.recent = deque(maxlen=10) + + def add_component(self, component: ComponentMonitor): + self.components[component.name] = component + return component + + async def request(self): + """ + Request status from each component. + """ + tasks = { + name: asyncio.create_task(comp.status()) + for name, comp in self.components.items() + } + await asyncio.gather(*tasks.values()) + status = { + name: await fut for name, fut in tasks.items() + } + self.recent.append(status) + return status + + async def _format_summary(self, status_dict: dict[str, ComponentStatus]): + """ + Format a one line summary from a status dict. + """ + freq = {level: 0 for level in StatusLevel} + for status in status_dict.values(): + freq[status.level] += 1 + + summary = '\t'.join(f"{level.symbol} {count}" for level, count in freq.items() if count) + return summary + + async def _format_overview(self, status_dict: dict[str, ComponentStatus]): + """ + Format an overview (one line per component) from a status dict. + """ + lines = [] + for name, status in status_dict.items(): + lines.append(f"{status.level.symbol} {name}: {status.short}") + summary = await self._format_summary(status_dict) + return '\n'.join((summary, *lines)) + + async def get_summary(self): + return await self._format_summary(await self.request()) + + async def get_overview(self): + return await self._format_overview(await self.request()) diff --git a/src/meta/sharding.py b/src/meta/sharding.py new file mode 100644 index 0000000..14da402 --- /dev/null +++ b/src/meta/sharding.py @@ -0,0 +1,35 @@ +from .args import args +from .config import conf + +from psycopg import sql +from data.conditions import Condition, Joiner + + +shard_number = args.shard or 0 + +shard_count = conf.bot.getint('shard_count', 1) + +sharded = (shard_count > 0) + + +def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition: + """ + Condition constructor for filtering by shard id. + + Example Usage + ------------- + Query.where(_shard_condition('guildid', 10, 1)) + """ + return Condition( + sql.SQL("({guildid} >> 22) %% {shard_count}").format( + guildid=sql.Identifier(guild_column), + shard_count=sql.Literal(shard_count) + ), + Joiner.EQUALS, + sql.Placeholder(), + (shard_id,) + ) + + +# Pre-built Condition for filtering by current shard. +THIS_SHARD = SHARDID(shard_number) diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..07e226e --- /dev/null +++ b/src/modules/__init__.py @@ -0,0 +1,9 @@ +this_package = 'modules' + +active = [ +] + + +async def setup(bot): + for ext in active: + await bot.load_extension(ext, package=this_package) diff --git a/src/settings/__init__.py b/src/settings/__init__.py new file mode 100644 index 0000000..9c38704 --- /dev/null +++ b/src/settings/__init__.py @@ -0,0 +1,7 @@ +from babel.translator import LocalBabel +babel = LocalBabel('settings_base') + +from .data import ModelData, ListData +from .base import BaseSetting +from .ui import SettingWidget, InteractiveSetting +from .groups import SettingDotDict, SettingGroup, ModelSettings, ModelSetting diff --git a/src/settings/base.py b/src/settings/base.py new file mode 100644 index 0000000..0cbfd0d --- /dev/null +++ b/src/settings/base.py @@ -0,0 +1,166 @@ +from typing import Generic, TypeVar, Type, Optional, overload + + +""" +Setting metclass? +Parse setting docstring to generate default info? +Or just put it in the decorator we are already using +""" + + +# Typing using Generic[parent_id_type, data_type, value_type] +# value generic, could be Union[?, UNSET] +ParentID = TypeVar('ParentID') +SettingData = TypeVar('SettingData') +SettingValue = TypeVar('SettingValue') + +T = TypeVar('T', bound='BaseSetting') + + +class BaseSetting(Generic[ParentID, SettingData, SettingValue]): + """ + Abstract base class describing a stored configuration setting. + A setting consists of logic to load the setting from storage, + present it in a readable form, understand user entered values, + and write it again in storage. + Additionally, the setting has attributes attached describing + the setting in a user-friendly manner for display purposes. + """ + setting_id: str # Unique source identifier for the setting + + _default: Optional[SettingData] = None # Default data value for the setting + + def __init__(self, parent_id: ParentID, data: Optional[SettingData], **kwargs): + self.parent_id = parent_id + self._data = data + self.kwargs = kwargs + + # Instance generation + @classmethod + async def get(cls: Type[T], parent_id: ParentID, **kwargs) -> T: + """ + Return a setting instance initialised from the stored value, associated with the given parent id. + """ + data = await cls._reader(parent_id, **kwargs) + return cls(parent_id, data, **kwargs) + + # Main interface + @property + def data(self) -> Optional[SettingData]: + """ + Retrieves the current internal setting data if it is set, otherwise the default data + """ + return self._data if self._data is not None else self.default + + @data.setter + def data(self, new_data: Optional[SettingData]): + """ + Sets the internal raw data. + Does not write the changes. + """ + self._data = new_data + + @property + def default(self) -> Optional[SettingData]: + """ + Retrieves the default value for this setting. + Settings should override this if the default depends on the object id. + """ + return self._default + + @property + def value(self) -> SettingValue: # Actually optional *if* _default is None + """ + Context-aware object or objects associated with the setting. + """ + return self._data_to_value(self.parent_id, self.data) # type: ignore + + @value.setter + def value(self, new_value: Optional[SettingValue]): + """ + Setter which reads the discord-aware object and converts it to data. + Does not write the new value. + """ + self._data = self._data_from_value(self.parent_id, new_value) + + async def write(self, **kwargs) -> None: + """ + Write current data to the database. + For settings which override this, + ensure you handle deletion of values when internal data is None. + """ + await self._writer(self.parent_id, self._data, **kwargs) + + # Raw converters + @overload + @classmethod + def _data_from_value(cls: Type[T], parent_id: ParentID, value: SettingValue, **kwargs) -> SettingData: + ... + + @overload + @classmethod + def _data_from_value(cls: Type[T], parent_id: ParentID, value: None, **kwargs) -> None: + ... + + @classmethod + def _data_from_value( + cls: Type[T], parent_id: ParentID, value: Optional[SettingValue], **kwargs + ) -> Optional[SettingData]: + """ + Convert a high-level setting value to internal data. + Must be overridden by the setting. + Be aware of UNSET values, these should always pass through as None + to provide an unsetting interface. + """ + raise NotImplementedError + + @overload + @classmethod + def _data_to_value(cls: Type[T], parent_id: ParentID, data: SettingData, **kwargs) -> SettingValue: + ... + + @overload + @classmethod + def _data_to_value(cls: Type[T], parent_id: ParentID, data: None, **kwargs) -> None: + ... + + @classmethod + def _data_to_value( + cls: Type[T], parent_id: ParentID, data: Optional[SettingData], **kwargs + ) -> Optional[SettingValue]: + """ + Convert internal data to high-level setting value. + Must be overriden by the setting. + """ + raise NotImplementedError + + # Database access + @classmethod + async def _reader(cls: Type[T], parent_id: ParentID, **kwargs) -> Optional[SettingData]: + """ + Retrieve the setting data associated with the given parent_id. + May be None if the setting is not set. + Must be overridden by the setting. + """ + raise NotImplementedError + + @classmethod + async def _writer(cls: Type[T], parent_id: ParentID, data: Optional[SettingData], **kwargs) -> None: + """ + Write provided setting data to storage. + Must be overridden by the setting unless the `write` method is overridden. + If the data is None, the setting is UNSET and should be deleted. + """ + raise NotImplementedError + + @classmethod + async def setup(cls, bot): + """ + Initialisation task to be executed during client initialisation. + May be used for e.g. populating a cache or required client setup. + + Main application must execute the initialisation task before the setting is used. + Further, the task must always be executable, if the setting is loaded. + Conditional initialisation should go in the relevant module's init tasks. + """ + return None diff --git a/src/settings/data.py b/src/settings/data.py new file mode 100644 index 0000000..9f627ad --- /dev/null +++ b/src/settings/data.py @@ -0,0 +1,233 @@ +from typing import Type +import json + +from data import RowModel, Table, ORDER +from meta.logger import log_wrap, set_logging_context + + +class ModelData: + """ + Mixin for settings stored in a single row and column of a Model. + Assumes that the parent_id is the identity key of the Model. + + This does not create a reference to the Row. + """ + # Table storing the desired data + _model: Type[RowModel] + + # Column with the desired data + _column: str + + # Whether to create a row if not found + _create_row = False + + # High level data cache to use, leave as None to disable cache. + _cache = None # Map[id -> value] + + @classmethod + def _read_from_row(cls, parent_id, row, **kwargs): + data = row[cls._column] + + if cls._cache is not None: + cls._cache[parent_id] = data + + return data + + @classmethod + async def _reader(cls, parent_id, use_cache=True, **kwargs): + """ + Read in the requested column associated to the parent id. + """ + if cls._cache is not None and parent_id in cls._cache and use_cache: + return cls._cache[parent_id] + + model = cls._model + if cls._create_row: + row = await model.fetch_or_create(parent_id) + else: + row = await model.fetch(parent_id) + data = row[cls._column] if row else None + + if cls._cache is not None: + cls._cache[parent_id] = data + + return data + + @classmethod + async def _writer(cls, parent_id, data, **kwargs): + """ + Write the provided entry to the table. + This does *not* create the row if it does not exist. + It only updates. + """ + # TODO: Better way of getting the key? + # TODO: Transaction + if not isinstance(parent_id, tuple): + parent_id = (parent_id, ) + model = cls._model + rows = await model.table.update_where( + **model._dict_from_id(parent_id) + ).set( + **{cls._column: data} + ) + # If we didn't update any rows, create a new row + if not rows: + await model.fetch_or_create(**model._dict_from_id(parent_id), **{cls._column: data}) + + if cls._cache is not None: + cls._cache[parent_id] = data + + +class ListData: + """ + Mixin for list types implemented on a Table. + Implements a reader and writer. + This assumes the list is the only data stored in the table, + and removes list entries by deleting rows. + """ + setting_id: str + + # Table storing the setting data + _table_interface: Table + + # Name of the column storing the id + _id_column: str + + # Name of the column storing the data to read + _data_column: str + + # Name of column storing the order index to use, if any. Assumed to be Serial on writing. + _order_column: str + _order_type: ORDER = ORDER.ASC + + # High level data cache to use, set to None to disable cache. + _cache = None # Map[id -> value] + + @classmethod + @log_wrap(isolate=True) + async def _reader(cls, parent_id, use_cache=True, **kwargs): + """ + Read in all entries associated to the given id. + """ + set_logging_context(action="Read cls.setting_id") + if cls._cache is not None and parent_id in cls._cache and use_cache: + return cls._cache[parent_id] + + table = cls._table_interface # type: Table + query = table.select_where(**{cls._id_column: parent_id}).select(cls._data_column) + if cls._order_column: + query.order_by(cls._order_column, direction=cls._order_type) + + rows = await query + data = [row[cls._data_column] for row in rows] + + if cls._cache is not None: + cls._cache[parent_id] = data + + return data + + @classmethod + @log_wrap(isolate=True) + async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs): + """ + Write the provided list to storage. + """ + set_logging_context(action="Write cls.setting_id") + table = cls._table_interface + async with table.connector.connection() as conn: + table.connector.conn = conn + async with conn.transaction(): + # Handle None input as an empty list + if data is None: + data = [] + + current = await cls._reader(id, use_cache=False, **kwargs) + if not cls._order_column and (add_only or remove_only): + to_insert = [item for item in data if item not in current] if not remove_only else [] + to_remove = data if remove_only else ( + [item for item in current if item not in data] if not add_only else [] + ) + + # Handle required deletions + if to_remove: + params = { + cls._id_column: id, + cls._data_column: to_remove + } + await table.delete_where(**params) + + # Handle required insertions + if to_insert: + columns = (cls._id_column, cls._data_column) + values = [(id, value) for value in to_insert] + await table.insert_many(columns, *values) + + if cls._cache is not None: + new_current = [item for item in current + to_insert if item not in to_remove] + cls._cache[id] = new_current + else: + # Remove all and add all to preserve order + delete_params = {cls._id_column: id} + await table.delete_where(**delete_params) + + if data: + columns = (cls._id_column, cls._data_column) + values = [(id, value) for value in data] + await table.insert_many(columns, *values) + + if cls._cache is not None: + cls._cache[id] = data + + +class KeyValueData: + """ + Mixin for settings implemented in a Key-Value table. + The underlying table should have a Unique constraint on the `(_id_column, _key_column)` pair. + """ + _table_interface: Table + + _id_column: str + + _key_column: str + + _value_column: str + + _key: str + + @classmethod + async def _reader(cls, id, **kwargs): + params = { + cls._id_column: id, + cls._key_column: cls._key + } + + row = await cls._table_interface.select_one_where(**params).select(cls._value_column) + data = row[cls._value_column] if row else None + + if data is not None: + data = json.loads(data) + + return data + + @classmethod + async def _writer(cls, id, data, **kwargs): + params = { + cls._id_column: id, + cls._key_column: cls._key + } + if data is not None: + values = { + cls._value_column: json.dumps(data) + } + rows = await cls._table_interface.update_where(**params).set(**values) + if not rows: + await cls._table_interface.insert_many( + (cls._id_column, cls._key_column, cls._value_column), + (id, cls._key, json.dumps(data)) + ) + else: + await cls._table_interface.delete_where(**params) + + +# class UserInputError(SafeCancellation): +# pass diff --git a/src/settings/groups.py b/src/settings/groups.py new file mode 100644 index 0000000..3719a7b --- /dev/null +++ b/src/settings/groups.py @@ -0,0 +1,204 @@ +from typing import Generic, Type, TypeVar, Optional, overload + +from data import RowModel + +from .data import ModelData +from .ui import InteractiveSetting +from .base import BaseSetting + +from utils.lib import tabulate + + +T = TypeVar('T', bound=InteractiveSetting) + + +class SettingDotDict(Generic[T], dict[str, Type[T]]): + """ + Dictionary structure allowing simple dot access to items. + """ + __getattr__ = dict.__getitem__ # type: ignore + __setattr__ = dict.__setitem__ # type: ignore + __delattr__ = dict.__delitem__ # type: ignore + + +class SettingGroup: + """ + A SettingGroup is a collection of settings under one name. + """ + __initial_settings__: list[Type[InteractiveSetting]] = [] + + _title: Optional[str] = None + _description: Optional[str] = None + + def __init_subclass__(cls, title: Optional[str] = None): + cls._title = title or cls._title + cls._description = cls._description or cls.__doc__ + + settings: list[Type[InteractiveSetting]] = [] + for item in cls.__dict__.values(): + if isinstance(item, type) and issubclass(item, InteractiveSetting): + settings.append(item) + cls.__initial_settings__ = settings + + def __init_settings__(self): + settings = SettingDotDict() + for setting in self.__initial_settings__: + settings[setting.__name__] = setting + return settings + + def __init__(self, title=None, description=None) -> None: + self.title: str = title or self._title or self.__class__.__name__ + self.description: str = description or self._description or "" + self.settings: SettingDotDict[InteractiveSetting] = self.__init_settings__() + + def attach(self, cls: Type[T], name: Optional[str] = None): + name = name or cls.setting_id + self.settings[name] = cls + return cls + + def detach(self, cls): + return self.settings.pop(cls.__name__, None) + + def update(self, smap): + self.settings.update(smap.settings) + + def reduce(self, *keys): + for key in keys: + self.settings.pop(key, None) + return + + async def make_setting_table(self, parent_id, **kwargs): + """ + Convenience method for generating a rendered setting table. + """ + rows = [] + for setting in self.settings.values(): + if not setting._virtual: + set = await setting.get(parent_id, **kwargs) + name = set.display_name + value = str(set.formatted) + rows.append((name, value, set.hover_desc)) + table_rows = tabulate( + *rows, + row_format="[`{invis}{key:<{pad}}{colon}`](https://lionbot.org \"{field[2]}\")\t{value}" + ) + return '\n'.join(table_rows) + + +class ModelSetting(ModelData, BaseSetting): + ... + + +class ModelConfig: + """ + A ModelConfig provides a central point of configuration for any object described by a single Model. + + An instance of a ModelConfig represents configuration for a single object + (given by a single row of the corresponding Model). + + The ModelConfig also supports registration of non-model configuration, + to support associated settings (e.g. list-settings) for the object. + + This is an ABC, and must be subclassed for each object-type. + """ + settings: SettingDotDict + _model_settings: set + model: Type[RowModel] + + def __init__(self, parent_id, row, **kwargs): + self.parent_id = parent_id + self.row = row + self.kwargs = kwargs + + @classmethod + def register_setting(cls, setting_cls): + """ + Decorator to register a non-model setting as part of the object configuration. + + The setting class may be re-accessed through the `settings` class attr. + + Subclasses may provide alternative access pathways to key non-model settings. + """ + cls.settings[setting_cls.setting_id] = setting_cls + return setting_cls + + @classmethod + def register_model_setting(cls, model_setting_cls): + """ + Decorator to register a model setting as part of the object configuration. + + The setting class may be accessed through the `settings` class attr. + + A fresh setting instance may also be retrieved (using cached data) + through the `get` instance method. + + Subclasses are recommended to provide model settings as properties + for simplified access and type checking. + """ + cls._model_settings.add(model_setting_cls.setting_id) + return cls.register_setting(model_setting_cls) + + def get(self, setting_id): + """ + Retrieve a freshly initialised copy of the given model-setting. + + The given `setting_id` must have been previously registered through `register_model_setting`. + This uses cached data, and so is not guaranteed to be up-to-date. + """ + if setting_id not in self._model_settings: + # TODO: Log + raise ValueError + setting_cls = self.settings[setting_id] + data = setting_cls._read_from_row(self.parent_id, self.row, **self.kwargs) + return setting_cls(self.parent_id, data, **self.kwargs) + + +class ModelSettings: + """ + A ModelSettings instance aggregates multiple `ModelSetting` instances + bound to the same parent id on a single Model. + + This enables a single point of access + for settings of a given Model, + with support for caching or deriving as needed. + + This is an abstract base class, + and should be subclassed to define the contained settings. + """ + _settings: SettingDotDict = SettingDotDict() + model: Type[RowModel] + + def __init__(self, parent_id, row, **kwargs): + self.parent_id = parent_id + self.row = row + self.kwargs = kwargs + + @classmethod + async def fetch(cls, *parent_id, **kwargs): + """ + Load an instance of this ModelSetting with the given parent_id + and setting keyword arguments. + """ + row = await cls.model.fetch_or_create(*parent_id) + return cls(parent_id, row, **kwargs) + + @classmethod + def attach(self, setting_cls): + """ + Decorator to attach the given setting class to this modelsetting. + """ + # This violates the interface principle, use structured typing instead? + if not (issubclass(setting_cls, BaseSetting) and issubclass(setting_cls, ModelData)): + raise ValueError( + f"The provided setting class must be `ModelSetting`, not {setting_cls.__class__.__name__}." + ) + self._settings[setting_cls.setting_id] = setting_cls + return setting_cls + + def get(self, setting_id): + setting_cls = self._settings.get(setting_id) + data = setting_cls._read_from_row(self.parent_id, self.row, **self.kwargs) + return setting_cls(self.parent_id, data, **self.kwargs) + + def __getitem__(self, setting_id): + return self.get(setting_id) diff --git a/src/settings/mock.py b/src/settings/mock.py new file mode 100644 index 0000000..e80a4c9 --- /dev/null +++ b/src/settings/mock.py @@ -0,0 +1,13 @@ +import discord +from discord import app_commands + + +class LocalString: + def __init__(self, string): + self.string = string + + def as_string(self): + return self.string + + +_ = LocalString diff --git a/src/settings/setting_types.py b/src/settings/setting_types.py new file mode 100644 index 0000000..4cf4177 --- /dev/null +++ b/src/settings/setting_types.py @@ -0,0 +1,1393 @@ +from typing import Optional, Union, TYPE_CHECKING, TypeVar, Generic, Any, TypeAlias, Type +from enum import Enum + +import pytz +import discord +import discord.app_commands as appcmds + +import itertools +import datetime as dt +from discord import ui +from discord.ui.button import button, Button, ButtonStyle +from dateutil.parser import parse, ParserError + +from meta.context import ctx_bot +from meta.errors import UserInputError +from utils.lib import strfdur, parse_duration +from babel.translator import ctx_translator, LazyStr + +from .base import ParentID +from .ui import InteractiveSetting, SettingWidget +from . import babel + +_, _p = babel._, babel._p + + +if TYPE_CHECKING: + from discord.guild import GuildChannel + + +# TODO: Localise this file + + +class StringSetting(InteractiveSetting[ParentID, str, str]): + """ + Setting type mixin describing an arbitrary string type. + + Options + ------- + _maxlen: int + Maximum length of string to accept in `_parse_string`. + Default: 4000 + + _quote: bool + Whether to display the string with backticks. + Default: True + """ + + _accepts = _p('settype:string|accepts', "Any Text") + + _maxlen: int = 4000 + _quote: bool = True + + @property + def input_formatted(self) -> str: + """ + Return the current data string. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Return the provided value string as the data string. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Return the provided data string as the value string. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + """ + Parse the user input `string` into StringSetting data. + Provides some minor input validation. + Treats an empty string as a `None` value. + """ + t = ctx_translator.get().t + if len(string) > cls._maxlen: + raise UserInputError( + t(_p( + 'settype:string|error', + "Provided string is too long! Maximum length: {maxlen} characters." + )).format(maxlen=cls._maxlen) + ) + elif len(string) == 0: + return None + else: + return string + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Optionally (see `_quote`) wrap the data string in backticks. + """ + if data: + return "`{}`".format(data) if cls._quote else str(data) + else: + return None + + +class EmojiSetting(InteractiveSetting[ParentID, str, str]): + """ + Setting type representing a stored emoji. + + The emoji is stored in a single string field, and at no time is guaranteed to be a valid emoji. + """ + _accepts = _p('settype:emoji|accepts', "Paste a builtin emoji, custom emoji, or emoji id.") + + @property + def input_formatted(self) -> str: + """ + Return the current data string. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Return the provided value string as the data string. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Return the provided data string as the value string. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + """ + Parse the given user entered emoji string. + + Accepts unicode (builtin) emojis, custom emojis, and custom emoji ids. + """ + t = ctx_translator.get().t + + provided = string + string = string.strip(' :<>') + if string.startswith('a:'): + string = string[2:] + + if not string or string.lower() == 'none': + emojistr = None + elif string.isdigit(): + # Assume emoji id + emojistr = f"" + elif ':' in string: + # Assume custom emoji + emojistr = provided.strip() + elif string.isascii(): + # Probably not an emoji + raise UserInputError( + t(_p( + 'settype:emoji|error:parse', + "Could not parse `{provided}` as a Discord emoji. " + "Supported formats are builtin emojis (e.g. `{builtin}`), " + "custom emojis (e.g. {custom}), " + "or custom emoji ids (e.g. `{custom_id}`)." + )).format( + provided=provided, + builtin="🤔", + custom="*`<`*`CuteLeo:942499177135480942`*`>`*", + custom_id="942499177135480942", + ) + ) + else: + # We don't have a good way of testing for emoji unicode + # So just assume anything with unicode is an emoji. + emojistr = string + + return emojistr + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Optionally (see `_quote`) wrap the data string in backticks. + """ + if data: + return data + else: + return None + + @property + def as_partial(self) -> Optional[discord.PartialEmoji]: + return self._parse_emoji(self.data) + + @staticmethod + def _parse_emoji(emojistr: str): + """ + Converts a provided string into a PartialEmoji. + Deos not validate the emoji string. + """ + if not emojistr: + return None + elif ":" in emojistr: + emojistr = emojistr.strip('<>') + splits = emojistr.split(":") + if len(splits) == 3: + animated, name, id = splits + animated = bool(animated) + return discord.PartialEmoji(name=name, animated=animated, id=int(id)) + else: + return discord.PartialEmoji(name=emojistr) + + +CT = TypeVar('CT', 'GuildChannel', 'discord.Object', 'discord.Thread') +MCT = TypeVar('MCT', discord.TextChannel, discord.Thread, discord.VoiceChannel, discord.Object) + + +class ChannelSetting(Generic[ParentID, CT], InteractiveSetting[ParentID, int, CT]): + """ + Setting type mixin describing a Guild Channel. + + Options + ------- + _selector_placeholder: str + Placeholder to use in the Widget selector. + Default: "Select a channel" + + channel_types: list[discord.ChannelType] + List of guild channel types to accept. + Default: [] + """ + _accepts = _p('settype:channel|accepts', "A channel name or id") + + _selector_placeholder = "Select a Channel" + channel_types: list[discord.ChannelType] = [] + _allow_object = False + + @classmethod + def _data_from_value(cls, parent_id, value, **kwargs): + """ + Returns the id of the provided channel. + """ + if value is not None: + return value.id + + @classmethod + def _data_to_value(cls, parent_id, data, **kwargs): + """ + Searches for the provided channel id in the current channel cache. + If the channel cannot be found, returns a `discord.Object` instead. + """ + if data is not None: + bot = ctx_bot.get() + channel = bot.get_channel(data) + if channel is None and cls._allow_object: + channel = discord.Object(id=data) + return channel + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + if not string or string.lower() == 'none': + return None + + t = ctx_translator.get().t + bot = ctx_bot.get() + channel = None + guild = bot.get_guild(parent_id) + + if string.isdigit(): + maybe_id = int(string) + channel = guild.get_channel(maybe_id) + else: + channel = next((channel for channel in guild.channels if channel.name.lower() == string.lower()), None) + + if channel is None: + raise UserInputError(t(_p( + 'settype:channel|parse|error:not_found', + "Channel `{string}` could not be found in this guild!".format(string=string) + ))) + return channel.id + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Returns a manually formatted channel mention. + """ + if data: + return "<#{}>".format(data) + + @property + def input_formatted(self) -> str: + data = self._data + return str(data) if data else '' + + class Widget(SettingWidget['ChannelSetting']): + def update_children(self): + self.update_child( + self.channel_selector, { + 'channel_types': self.setting.channel_types, + 'placeholder': self.setting._selector_placeholder + } + ) + + def make_exports(self): + return [self.channel_selector] + + @ui.select( + cls=ui.ChannelSelect, + channel_types=[discord.ChannelType.text], + placeholder="Select a Channel", + max_values=1, + min_values=0 + ) + async def channel_selector(self, interaction: discord.Interaction, select: discord.ui.ChannelSelect) -> None: + await interaction.response.defer(thinking=True, ephemeral=True) + if select.values: + channel = select.values[0] + await self.setting.interactive_set(channel.id, interaction) + else: + await self.setting.interactive_set(None, interaction) + + +class VoiceChannelSetting(ChannelSetting): + """ + Setting type mixin representing a discord VoiceChannel. + Implemented as a narrowed `ChannelSetting`. + See `ChannelSetting` for options. + """ + channel_types = [discord.ChannelType.voice] + + +class MessageablelSetting(ChannelSetting): + """ + Setting type mixin representing a discord Messageable guild channel. + Implemented as a narrowed `ChannelSetting`. + See `ChannelSetting` for options. + """ + channel_types = [discord.ChannelType.text, discord.ChannelType.voice, discord.ChannelType.public_thread] + + @classmethod + def _data_to_value(cls, parent_id, data, **kwargs): + """ + Searches for the provided channel id in the current channel cache. + If the channel cannot be found, returns a `discord.PartialMessageable` instead. + """ + if data is not None: + bot = ctx_bot.get() + channel = bot.get_channel(data) + if channel is None: + channel = bot.get_partial_messageable(data, guild_id=parent_id) + return channel + + +class RoleSetting(InteractiveSetting[ParentID, int, Union[discord.Role, discord.Object]]): + """ + Setting type mixin describing a Guild Role. + + Options + ------- + _selector_placeholder: str + Placeholder to use in the Widget selector. + Default: "Select a Role" + """ + _accepts = _p('settype:role|accepts', "A role name or id") + + _selector_placeholder = "Select a Role" + _allow_object = False + + @classmethod + def _get_guildid(cls, parent_id: int, **kwargs) -> int: + """ + Fetch the current guildid. + Assumes that the guilid is either passed as a kwarg or is the object id. + Should be overridden in other cases. + """ + return kwargs.get('guildid', parent_id) + + @classmethod + def _data_from_value(cls, parent_id, value, **kwargs): + """ + Returns the id of the provided role. + """ + if value is not None: + return value.id + + @classmethod + def _data_to_value(cls, parent_id, data, **kwargs): + """ + Searches for the provided role id in the current channel cache. + If the channel cannot be found, returns a `discord.Object` instead. + """ + if data is not None: + role = None + + guildid = cls._get_guildid(parent_id, **kwargs) + bot = ctx_bot.get() + guild = bot.get_guild(guildid) + if guild is not None: + role = guild.get_role(data) + if role is None and cls._allow_object: + role = discord.Object(id=data) + return role + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + if not string or string.lower() == 'none': + return None + guildid = cls._get_guildid(parent_id, **kwargs) + + t = ctx_translator.get().t + bot = ctx_bot.get() + role = None + guild = bot.get_guild(guildid) + if guild is None: + raise ValueError("Attempting to parse role string with no guild.") + + if string.isdigit(): + maybe_id = int(string) + role = guild.get_role(maybe_id) + else: + role = next((role for role in guild.roles if role.name.lower() == string.lower()), None) + + if role is None: + raise UserInputError(t(_p( + 'settype:role|parse|error:not_found', + "Role `{string}` could not be found in this guild!".format(string=string) + ))) + return role.id + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Returns a manually formatted role mention. + """ + if data: + return "<@&{}>".format(data) + else: + return None + + @property + def input_formatted(self) -> str: + data = self._data + return str(data) if data else '' + + class Widget(SettingWidget['RoleSetting']): + def update_children(self): + self.update_child( + self.role_selector, + {'placeholder': self.setting._selector_placeholder} + ) + + def make_exports(self): + return [self.role_selector] + + @ui.select( + cls=ui.RoleSelect, + placeholder="Select a Role", + max_values=1, + min_values=0 + ) + async def role_selector(self, interaction: discord.Interaction, select: discord.ui.RoleSelect) -> None: + await interaction.response.defer(thinking=True, ephemeral=True) + if select.values: + role = select.values[0] + await self.setting.interactive_set(role.id, interaction) + else: + await self.setting.interactive_set(None, interaction) + + +class BoolSetting(InteractiveSetting[ParentID, bool, bool]): + """ + Setting type mixin describing a boolean. + + Options + ------- + _truthy: Set + Set of strings that are considered "truthy" in the parser. + Not case sensitive. + Default: {"yes", "true", "on", "enable", "enabled"} + + _falsey: Set + Set of strings that are considered "falsey" in the parser. + Not case sensitive. + Default: {"no", "false", "off", "disable", "disabled"} + + _outputs: tuple[str, str, str] + Strings to represent 'True', 'False', and 'None' values respectively. + Default: {True: "On", False: "Off", None: "Not Set"} + """ + + _accepts = _p('settype:bool|accepts', "Enabled/Disabled") + + # Values that are accepted as truthy and falsey by the parser + _truthy = _p( + 'settype:bool|parse:truthy_values', + "enabled|yes|true|on|enable|1" + ) + _falsey = _p( + 'settype:bool|parse:falsey_values', + 'disabled|no|false|off|disable|0' + ) + + # The user-friendly output strings to use for each value + _outputs = { + True: _p('settype:bool|output:true', "On"), + False: _p('settype:bool|output:false', "Off"), + None: _p('settype:bool|output:none', "Not Set"), + } + + # Button labels + _true_button_args: dict[str, Any] = {} + _false_button_args: dict[str, Any] = {} + _reset_button_args: dict[str, Any] = {} + + @classmethod + def truthy_values(cls) -> set: + t = ctx_translator.get().t + return t(cls._truthy).lower().split('|') + + @classmethod + def falsey_values(cls) -> set: + t = ctx_translator.get().t + return t(cls._falsey).lower().split('|') + + @property + def input_formatted(self) -> str: + """ + Return the current data string. + """ + if self._data is not None: + t = ctx_translator.get().t + output = t(self._outputs[self._data]) + input_set = self.truthy_values() if self._data else self.falsey_values() + + if output.lower() in input_set: + return output + else: + return next(iter(input_set)) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Directly return provided value bool as data bool. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Directly return provided data bool as value bool. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Looks up the provided string in the truthy and falsey tables. + """ + _userstr = string.lower() + if not _userstr or _userstr == "none": + return None + if _userstr in cls.truthy_values(): + return True + elif _userstr in cls.falsey_values(): + return False + else: + raise UserInputError("Could not parse `{}` as a boolean.".format(string)) + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Use provided _outputs dictionary to format data. + """ + t = ctx_translator.get().t + return t(cls._outputs[data]) + + class Widget(SettingWidget['BoolSetting']): + def update_children(self): + self.update_child(self.true_button, self.setting._true_button_args) + self.update_child(self.false_button, self.setting._false_button_args) + self.update_child(self.reset_button, self.setting._reset_button_args) + self.order_children(self.true_button, self.false_button, self.reset_button) + + def make_exports(self): + return [self.true_button, self.false_button, self.reset_button] + + @button(style=ButtonStyle.secondary, label="On", row=4) + async def true_button(self, interaction: discord.Interaction, button: Button): + await interaction.response.defer(thinking=True, ephemeral=True) + await self.setting.interactive_set(True, interaction) + + @button(style=ButtonStyle.secondary, label="Off", row=4) + async def false_button(self, interaction: discord.Interaction, button: Button): + await interaction.response.defer(thinking=True, ephemeral=True) + await self.setting.interactive_set(False, interaction) + + +class IntegerSetting(InteractiveSetting[ParentID, int, int]): + """ + Setting type mixin describing a ranged integer. + As usual, override `_parse_string` to customise error messages. + + Options + ------- + _min: int + A minimum integer to accept. + Default: -2147483647 + + _max: int + A maximum integer to accept. + Default: 2147483647 + """ + _min = -2147483647 + _max = 2147483647 + + _accepts = _p('settype:integer|accepts', "An integer") + + @property + def input_formatted(self) -> str: + """ + Return a string representation of the set integer. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Directly return value integer as data integer. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Directly return data integer as value integer. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into an integer. + """ + if not string: + return None + try: + num = int(string) + except Exception: + raise UserInputError("Couldn't parse provided integer.") from None + + if num > cls._max: + raise UserInputError("Provided integer was too large!") + elif num < cls._min: + raise UserInputError("Provided integer was too small!") + + return num + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Returns the stringified integer in backticks. + """ + if data is not None: + return f"`{data}`" + + +class PartialEmojiSetting(InteractiveSetting[ParentID, str, discord.PartialEmoji]): + """ + Setting type mixin describing an Emoji string. + + Options + ------- + None + """ + + _accepts = _p('settype:emoji|desc', "Unicode or custom emoji") + + @staticmethod + def _parse_emoji(emojistr): + """ + Converts a provided string into a PartialEmoji. + If the string is badly formatted, returns None. + """ + if ":" in emojistr: + emojistr = emojistr.strip('<>') + splits = emojistr.split(":") + if len(splits) == 3: + animated, name, id = splits + animated = bool(animated) + return discord.PartialEmoji(name=name, animated=animated, id=int(id)) + else: + # TODO: Check whether this is a valid emoji + return discord.PartialEmoji(name=emojistr) + + @property + def input_formatted(self) -> str: + """ + Return the current data string. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Stringify the value emoji into a consistent data string. + """ + return str(value) if value is not None else None + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Convert the stored string into an emoji, through parse_emoji. + This may return None if the parsing syntax changes. + """ + return cls._parse_emoji(data) if data is not None else None + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + """ + Parse the provided string into a PartialEmoji if possible. + """ + if string: + emoji = cls._parse_emoji(string) + if emoji is None: + raise UserInputError("Could not understand provided emoji!") + return str(emoji) + return None + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Emojis are pretty much self-formatting. Just return the data directly. + """ + return data + + +class GuildIDSetting(InteractiveSetting[ParentID, int, int]): + """ + Setting type mixin describing a guildid. + This acts like a pure integer type, apart from the formatting. + + Options + ------- + """ + _accepts = _p('settype:guildid|accepts', "Any Snowflake ID") + # TODO: Consider autocomplete for guilds the user is in + + @property + def input_formatted(self) -> str: + """ + Return a string representation of the stored snowflake. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Directly return value integer as data integer. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Directly return data integer as value integer. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into an integer. + """ + if not string: + return None + try: + num = int(string) + except Exception: + raise UserInputError("Couldn't parse provided guildid.") from None + return num + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Return the stored snowflake as a string. + If the guild is in cache, attach the name as well. + """ + if data is not None: + bot = ctx_bot.get() + guild = bot.get_guild(data) + if guild is not None: + return f"`{data}` ({guild.name})" + else: + return f"`{data}`" + + +TZT: TypeAlias = pytz.BaseTzInfo + + +class TimezoneSetting(InteractiveSetting[ParentID, str, TZT]): + """ + Typed Setting ABC representing timezone information. + """ + # TODO: Consider configuration UI for timezone by continent and country + # Do any continents have more than 25 countries? + # Maybe list e.g. Europe (Austria - Iceland) and Europe (Ireland - Ukraine) separately + + # TODO Definitely need autocomplete here + _accepts = _p( + 'settype:timezone|accepts', + "A timezone name from the 'tz database' (e.g. 'Europe/London')" + ) + + @property + def input_formatted(self) -> str: + """ + Return a string representation of the stored timezone. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Use str to transform the pytz timezone into a string. + """ + if value: + return str(value) + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Use pytz to convert the stored timezone string to a timezone. + """ + if data: + return pytz.timezone(data) + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into an integer. + """ + # TODO: Localise + # TODO: Another selection case. + if not string: + return None + try: + timezone = pytz.timezone(string) + except pytz.exceptions.UnknownTimeZoneError: + timezones = [tz for tz in pytz.all_timezones if string.lower() in tz.lower()] + if len(timezones) == 1: + timezone = timezones[0] + elif timezones: + raise UserInputError("Multiple matching timezones found!") + # TODO: Add a selector-message here instead of dying instantly + # Maybe only post a selector if there are less than 25 options! + + # result = await ctx.selector( + # "Multiple matching timezones found, please select one.", + # timezones + # ) + # timezone = timezones[result] + else: + raise UserInputError( + "Unknown timezone `{}`. " + "Please provide a TZ name from " + "[this list](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones)".format(string) + ) from None + return str(timezone) + + def _desc_table(self) -> list[str]: + translator = ctx_translator.get() + t = translator.t + + lines = super()._desc_table() + lines.append(( + t(_p( + 'settype:timezone|summary_table|field:supported|key', + "Supported" + )), + t(_p( + 'settype:timezone|summary_table|field:supported|value', + "Any timezone from the [tz database]({link})." + )).format(link="https://en.wikipedia.org/wiki/List_of_tz_database_time_zones") + )) + return lines + + @classmethod + async def parse_acmpl(cls, interaction: discord.Interaction, partial: str): + bot = interaction.client + t = bot.translator.t + + timezones = pytz.all_timezones + matching = [tz for tz in timezones if partial.strip().lower() in tz.lower()][:25] + if not matching: + choices = [ + appcmds.Choice( + name=t(_p( + 'set_type:timezone|acmpl|no_matching', + "No timezones matching '{input}'!" + )).format(input=partial)[:100], + value=partial + ) + ] + else: + choices = [] + for tz in matching: + timezone = pytz.timezone(tz) + now = dt.datetime.now(timezone) + nowstr = now.strftime("%H:%M") + name = t(_p( + 'set_type:timezone|acmpl|choice', + "{tz} (Currently {now})" + )).format(tz=tz, now=nowstr) + choice = appcmds.Choice( + name=name[:100], + value=tz + ) + choices.append(choice) + return choices + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Return the stored snowflake as a string. + If the guild is in cache, attach the name as well. + """ + if data is not None: + return f"`{data}`" + + +class TimestampSetting(InteractiveSetting[ParentID, str, dt.datetime]): + """ + Typed Setting ABC representing a fixed point in time. + + Data is assumed to be a timezone aware datetime object. + Value is the same as data. + Parsing accepts YYYY-MM-DD [HH:MM] [+TZ] + Display uses a discord timestamp. + """ + _accepts = _p( + 'settype:timestamp|accepts', + "A timestamp in the form YYYY-MM-DD HH:MM" + ) + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + return value + + @classmethod + def _data_to_value(cls, parent_id: ParentID, data, **kwargs): + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + string = string.strip() + if string.lower() in ('', 'none', '0'): + ts = None + else: + local_tz = await cls._timezone_from_id(parent_id, **kwargs) + now = dt.datetime.now(tz=local_tz) + default = now.replace( + hour=0, minute=0, + second=0, microsecond=0 + ) + try: + ts = parse(string, fuzzy=True, default=default) + except ParserError: + t = ctx_translator.get().t + raise UserInputError(t(_p( + 'settype:timestamp|parse|error:invalid', + "Could not parse `{provided}` as a timestamp. Please use `YYYY-MM-DD HH:MM` format." + )).format(provided=string)) + return ts + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + if data is not None: + return "".format(int(data.timestamp())) + + @classmethod + async def _timezone_from_id(cls, parent_id: ParentID, **kwargs): + """ + Extract the parsing timezone from the given parent id. + + Should generally be overriden for interactive settings. + """ + return pytz.UTC + + @property + def input_formatted(self) -> str: + if self._data: + formatted = self._data.strftime('%Y-%m-%d %H:%M') + else: + formatted = '' + return formatted + + +class RawSetting(InteractiveSetting[ParentID, Any, Any]): + """ + Basic implementation of an interactive setting with identical value and data type. + """ + _accepts = _p('settype:raw|accepts', "Anything") + + @property + def input_formatted(self) -> str: + return str(self._data) if self._data is not None else '' + + @classmethod + def _data_from_value(cls, parent_id, value, **kwargs): + return value + + @classmethod + def _data_to_value(cls, parent_id, data, **kwargs): + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + return string + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + return str(data) if data is not None else None + + +ET = TypeVar('ET', bound='Enum') + + +class EnumSetting(InteractiveSetting[ParentID, ET, ET]): + """ + Typed InteractiveSetting ABC representing a stored Enum. + The Enum is assumed to be data adapted (e.g. through RegisterEnum). + + The embed of an enum setting should usually be overridden to describe the options. + + The default widget is implemented as a select menu, + although it may also make sense to implement using colour-changing buttons. + + Options + ------- + _enum: Enum + The Enum to act as a setting interface to. + _outputs: dict[Enum, str] + A map of enum items to output strings. + Describes how the enum should be formatted. + _inputs: dict[Enum, str] + A map of accepted input strings (not case sensitive) to enum items. + This should almost always include the strings from `_outputs`. + """ + + _enum: Type[ET] + _outputs: dict[ET, LazyStr] + _input_patterns: dict[ET: LazyStr] + _input_formatted: dict[ET: LazyStr] + + _accepts = _p('settype:enum|accepts', "A valid option.") + + @property + def input_formatted(self) -> str: + """ + Return the output string for the current data. + This assumes the output strings are accepted as inputs! + """ + t = ctx_translator.get().t + if self._data is not None: + return t(self._input_formatted[self._data]) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Return the provided value enum item as the data enum item. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Return the provided data enum item as the value enum item. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into an enum item. + """ + if not string: + return None + + string = string.lower() + t = ctx_translator.get().t + + found = None + for enumitem, pattern in cls._input_patterns.items(): + item_keys = set(t(pattern).lower().split('|')) + if string in item_keys: + found = enumitem + break + + if not found: + raise UserInputError( + t(_p( + 'settype:enum|parse|error:not_found', + "`{provided}` is not a valid option!" + )).format(provided=string) + ) + + return found + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Format the enum using the provided output map. + """ + t = ctx_translator.get().t + if data is not None: + if data not in cls._outputs: + raise ValueError(f"Enum item {data} unmapped.") + return t(cls._outputs[data]) + + +class DurationSetting(InteractiveSetting[ParentID, int, int]): + """ + Typed InteractiveSetting ABC representing a stored duration. + Stored and retrieved as an integer number of seconds. + Shown and set as a "duration string", e.g. "24h 10m 20s". + + Options + ------- + _max: int + Upper limit on the stored duration, in seconds. + Default: 60 * 60 * 24 * 365 + _min: Optional[int] + Lower limit on the stored duration, in seconds. + The duration can never be negative. + _default_multiplier: int + Default multiplier to use to convert the number when it is provided alone. + E.g. 1 for seconds, or 60 for minutes. + Default: 1 + allow_zero: bool + Whether to allow a zero duration. + The duration parser typically returns 0 when no duration is found, + so this may be useful for error checking. + Default: False + _show_days: bool + Whether to show days in the formatted output. + Default: False + """ + + _accepts = _p( + 'settype:duration|accepts', + "A number of days, hours, minutes, and seconds, e.g. `2d 4h 10s`." + ) + + # Set an upper limit on the duration + _max = 60 * 60 * 24 * 365 + _min = None + + # Default multiplier when the number is provided alone + # 1 for seconds, 60 from minutes, etc + _default_multiplier = None + + # Whether to allow empty durations + # This is particularly useful since the duration parser will return 0 for most non-duration strings + allow_zero = False + + # Whether to show days on the output + _show_days = False + + @property + def input_formatted(self) -> str: + """ + Return the formatted duration, which is accepted as input. + """ + if self._data is not None: + return strfdur(self._data, short=True, show_days=self._show_days) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Passthrough the provided duration in seconds. + """ + return value + + @classmethod + def _data_to_value(cls, parent_id: ParentID, data, **kwargs): + """ + Passthrough the provided duration in seconds. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into a duration. + """ + if not string: + return None + + if cls._default_multiplier and string.isdigit(): + num = int(string) * cls._default_multiplier + else: + num = parse_duration(string) + + if num is None: + raise UserInputError("Could not parse the provided duration!") + + if num == 0 and not cls.allow_zero: + raise UserInputError( + "The provided duration cannot be `0`! (Please enter in the format `1d 2h 3m 4s`.)" + ) + + if cls._max is not None and num > cls._max: + raise UserInputError( + "Duration cannot be longer than `{}`!".format( + strfdur(cls._max, short=False, show_days=cls._show_days) + ) + ) + if cls._min is not None and num < cls._min: + raise UserInputError( + "Duration cannot be shorter than `{}`!".format( + strfdur(cls._min, short=False, show_days=cls._show_days) + ) + ) + + return num + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Format the enum using the provided output map. + """ + if data is not None: + return "`{}`".format(strfdur(data, short=False, show_days=cls._show_days)) + + +class ListSetting: + """ + Mixin to implement a setting type representing a list of existing settings. + + Does not implement a Widget, + since arbitrary combinations of setting widgets are undefined. + """ + # Base setting type to make the list from + _setting = None # type: Type[InteractiveSetting] + + # Whether 'None' values are filtered out of the data when creating values + _allow_null_values = False # type: bool + + # Whether duplicate data values should be filtered out + _force_unique = False + + @classmethod + def _data_from_value(cls, parent_id: ParentID, values, **kwargs): + """ + Returns the setting type data for each value in the value list + """ + if values is None: + # Special behaviour here, store an empty list instead of None + return [] + else: + return [cls._setting._data_from_value(parent_id, value) for value in values] + + @classmethod + def _data_to_value(cls, parent_id: ParentID, data, **kwargs): + """ + Returns the setting type value for each entry in the data list + """ + if data is None: + return [] + else: + values = [cls._setting._data_to_value(parent_id, entry) for entry in data] + + # Filter out null values if required + if not cls._allow_null_values: + values = [value for value in values if value is not None] + return values + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Splits the user string across `,` to break up the list. + """ + if not string: + return [] + else: + data = [] + items = (item.strip() for item in string.split(',')) + items = (item for item in items if item) + data = [await cls._setting._parse_string(parent_id, item, **kwargs) for item in items] + + if cls._force_unique: + data = list(set(data)) + return data + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Format the list by adding `,` between each formatted item + """ + if data: + formatted_items = [] + for item in data: + formatted_item = cls._setting._format_data(id, item) + if formatted_item is not None: + formatted_items.append(formatted_item) + return ", ".join(formatted_items) + + @property + def input_formatted(self): + """ + Format the list by adding `,` between each input formatted item. + """ + if self._data: + formatted_items = [] + for item in self._data: + formatted_item = self._setting(self.parent_id, item).input_formatted + if formatted_item: + formatted_items.append(formatted_item) + return ", ".join(formatted_items) + else: + return "" + + +class ChannelListSetting(ListSetting, InteractiveSetting): + """ + List of channels + """ + _accepts = _p( + 'settype:channel_list|accepts', + "Comma separated list of channel ids." + ) + _setting = ChannelSetting + + +class RoleListSetting(ListSetting, InteractiveSetting): + """ + List of roles + """ + _accepts = _p( + 'settype:role_list|accepts', + 'Comma separated list of role ids.' + ) + _setting = RoleSetting + + @property + def members(self): + roles = self.value + return list(set(itertools.chain(*(role.members for role in roles)))) + + +class StringListSetting(InteractiveSetting, ListSetting): + """ + List of strings + """ + _accepts = _p( + 'settype:stringlist|accepts', + 'Comma separated strings.' + ) + _setting = StringSetting + + +class GuildIDListSetting(ListSetting, InteractiveSetting): + """ + List of guildids. + """ + _accepts = _p( + 'settype:guildidlist|accepts', + 'Comma separated list of guild ids.' + ) + + _setting = GuildIDSetting diff --git a/src/settings/ui.py b/src/settings/ui.py new file mode 100644 index 0000000..b2a263e --- /dev/null +++ b/src/settings/ui.py @@ -0,0 +1,512 @@ +from typing import Optional, Callable, Any, Dict, Coroutine, Generic, TypeVar, List +import asyncio +from contextvars import copy_context + +import discord +from discord import ui +from discord.ui.button import ButtonStyle, Button, button +from discord.ui.modal import Modal +from discord.ui.text_input import TextInput +from meta.errors import UserInputError + +from utils.lib import tabulate, recover_context +from utils.ui import FastModal +from meta.config import conf +from meta.context import ctx_bot +from babel.translator import ctx_translator, LazyStr + +from .base import BaseSetting, ParentID, SettingData, SettingValue +from . import babel + +_p = babel._p + + +ST = TypeVar('ST', bound='InteractiveSetting') + + +class SettingModal(FastModal): + input_field: TextInput = TextInput(label="Edit Setting") + + def update_field(self, new_field): + self.remove_item(self.input_field) + self.add_item(new_field) + self.input_field = new_field + + +class SettingWidget(Generic[ST], ui.View): + # TODO: Permission restrictions and callback! + # Context variables for permitted user(s)? Subclass ui.View with PermittedView? + # Don't need to descend permissions to Modal + # Maybe combine with timeout manager + + def __init__(self, setting: ST, auto_write=True, **kwargs): + self.setting = setting + self.update_children() + super().__init__(**kwargs) + self.auto_write = auto_write + + self._interaction: Optional[discord.Interaction] = None + self._modal: Optional[SettingModal] = None + self._exports: List[ui.Item] = self.make_exports() + + self._context = copy_context() + + def update_children(self): + """ + Method called before base View initialisation. + Allows updating the children components (usually explicitly defined callbacks), + before Item instantiation. + """ + pass + + def order_children(self, *children): + """ + Helper method to set and order the children using bound methods. + """ + child_map = {child.__name__: child for child in self.__view_children_items__} + self.__view_children_items__ = [child_map[child.__name__] for child in children] + + def update_child(self, child, new_args): + args = getattr(child, '__discord_ui_model_kwargs__') + args |= new_args + + def make_exports(self): + """ + Called post-instantiation to populate self._exports. + """ + return self.children + + def refresh(self): + """ + Update widget components from current setting data, if applicable. + E.g. to update the default entry in a select list after a choice has been made, + or update button colours. + This does not trigger a discord ui update, + that is the responsibility of the interaction handler. + """ + pass + + async def show(self, interaction: discord.Interaction, key: Any = None, override=False, **kwargs): + """ + Complete standard setting widget UI flow for this setting. + The SettingWidget components may be attached to other messages as needed, + and they may be triggered individually, + but this coroutine defines the standard interface. + Intended for use by any interaction which wants to "open the setting". + + Extra keyword arguments are passed directly to the interaction reply (for e.g. ephemeral). + """ + if key is None: + # By default, only have one widget listener per interaction. + key = ('widget', interaction.id) + + # If there is already a widget listening on this key, respect override + if self.setting.get_listener(key) and not override: + # Refuse to spawn another widget + return + + async def update_callback(new_data): + self.setting.data = new_data + await interaction.edit_original_response(embed=self.setting.embed, view=self, **kwargs) + + self.setting.register_callback(key)(update_callback) + await interaction.response.send_message(embed=self.setting.embed, view=self, **kwargs) + await self.wait() + try: + # Try and detach the view, since we aren't handling events anymore. + await interaction.edit_original_response(view=None) + except discord.HTTPException: + pass + self.setting.deregister_callback(key) + + def attach(self, group_view: ui.View): + """ + Attach this setting widget to a view representing several settings. + """ + for item in self._exports: + group_view.add_item(item) + + @button(style=ButtonStyle.secondary, label="Edit", row=4) + async def edit_button(self, interaction: discord.Interaction, button: ui.Button): + """ + Spawn a simple edit modal, + populated with `setting.input_field`. + """ + recover_context(self._context) + # Spawn the setting modal + await interaction.response.send_modal(self.modal) + + @button(style=ButtonStyle.danger, label="Reset", row=4) + async def reset_button(self, interaction: discord.Interaction, button: Button): + recover_context(self._context) + await interaction.response.defer(thinking=True, ephemeral=True) + await self.setting.interactive_set(None, interaction) + + @property + def modal(self) -> Modal: + """ + Build a Modal dialogue for updating the setting. + Refreshes (and re-attaches) the input field each time this is called. + """ + if self._modal is not None: + self._modal.update_field(self.setting.input_field) + return self._modal + + # TODO: Attach shared timeouts to the modal + self._modal = modal = SettingModal( + title=f"Edit {self.setting.display_name}", + ) + modal.update_field(self.setting.input_field) + + @modal.submit_callback() + async def edit_submit(interaction: discord.Interaction): + # TODO: Catch and handle UserInputError + await interaction.response.defer(thinking=True, ephemeral=True) + data = await self.setting._parse_string(self.setting.parent_id, modal.input_field.value) + await self.setting.interactive_set(data, interaction) + + return modal + + +class InteractiveSetting(BaseSetting[ParentID, SettingData, SettingValue]): + __slots__ = ('_widget',) + + # Configuration interface descriptions + _display_name: LazyStr # User readable name of the setting + _desc: LazyStr # User readable brief description of the setting + _long_desc: LazyStr # User readable long description of the setting + _accepts: LazyStr # User readable description of the acceptable values + _set_cmd: str = None + _notset_str: LazyStr = _p('setting|formatted|notset', "Not Set") + _virtual: bool = False # Whether the setting should be hidden from tables and dashboards + _required: bool = False + + Widget = SettingWidget + + # A list of callback coroutines to call when the setting updates + # This can be used globally to refresh state when the setting updates, + # Or locallly to e.g. refresh an active widget. + # The callbacks are called on write, so they may be bypassed by direct use of _writer! + _listeners_: Dict[Any, Callable[[Optional[SettingData]], Coroutine[Any, Any, None]]] = {} + + # Optional client event to dispatch when theis setting has been written + # Event handlers should be of the form Callable[ParentID, SettingData] + _event: Optional[str] = None + + # Interaction ward that should be validated via interaction_check + _write_ward: Optional[Callable[[discord.Interaction], Coroutine[Any, Any, bool]]] = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._widget: Optional[SettingWidget] = None + + @property + def long_desc(self): + t = ctx_translator.get().t + bot = ctx_bot.get() + return t(self._long_desc).format( + bot=bot, + cmds=bot.core.mention_cache + ) + + @property + def display_name(self): + t = ctx_translator.get().t + return t(self._display_name) + + @property + def desc(self): + t = ctx_translator.get().t + return t(self._desc) + + @property + def accepts(self): + t = ctx_translator.get().t + return t(self._accepts) + + async def write(self, **kwargs) -> None: + await super().write(**kwargs) + self.dispatch_update() + for listener in self._listeners_.values(): + asyncio.create_task(listener(self.data)) + + def dispatch_update(self): + """ + Dispatch a client event along `self._event`, if set. + + Override to modify the target event handler arguments. + By default, event handlers should be of the form: + Callable[[ParentID, SettingData], Coroutine[Any, Any, None]] + """ + if self._event is not None and (bot := ctx_bot.get()) is not None: + bot.dispatch(self._event, self.parent_id, self) + + def get_listener(self, key): + return self._listeners_.get(key, None) + + @classmethod + def register_callback(cls, name=None): + def wrapped(coro): + cls._listeners_[name or coro.__name__] = coro + return coro + return wrapped + + @classmethod + def deregister_callback(cls, name): + cls._listeners_.pop(name, None) + + @property + def update_message(self): + """ + Response message sent when the setting has successfully been updated. + Should generally be one line. + """ + if self.data is None: + return "Setting reset!" + else: + return f"Setting Updated! New value: {self.formatted}" + + @property + def hover_desc(self): + """ + This no longer works since Discord changed the hover rules. + + return '\n'.join(( + self.display_name, + '=' * len(self.display_name), + self.desc, + f"\nAccepts: {self.accepts}" + )) + """ + return self.desc + + async def update_response(self, interaction: discord.Interaction, message: Optional[str] = None, **kwargs): + """ + Respond to an interaction which triggered a setting update. + Usually just wraps `update_message` in an embed and sends it back. + Passes any extra `kwargs` to the message creation method. + """ + embed = discord.Embed( + description=f"{str(conf.emojis.tick)} {message or self.update_message}", + colour=discord.Color.green() + ) + if interaction.response.is_done(): + await interaction.edit_original_response(embed=embed, **kwargs) + else: + await interaction.response.send_message(embed=embed, **kwargs) + + async def interactive_set(self, new_data: Optional[SettingData], interaction: discord.Interaction, **kwargs): + self.data = new_data + await self.write() + await self.update_response(interaction, **kwargs) + + async def format_in(self, bot, **kwargs): + """ + Formatted version of the setting given an asynchronous context with client. + """ + return self.formatted + + @property + def embed_field(self): + """ + Returns a {name, value} pair for use in an Embed field. + """ + name = self.display_name + value = f"{self.long_desc}\n{self.desc_table}" + if len(value) > 1024: + t = ctx_translator.get().t + desc_table = '\n'.join( + tabulate( + *self._desc_table( + show_value=t(_p( + 'setting|embed_field|too_long', + "Too long to display here!" + )) + ) + ) + ) + value = f"{self.long_desc}\n{desc_table}" + if len(value) > 1024: + # Forcibly trim + value = value[:1020] + '...' + return {'name': name, 'value': value} + + @property + def set_str(self): + if self._set_cmd is not None: + bot = ctx_bot.get() + if bot: + return bot.core.mention_cmd(self._set_cmd) + else: + return f"`/{self._set_cmd}`" + + @property + def notset_str(self): + t = ctx_translator.get().t + return t(self._notset_str) + + @property + def embed(self): + """ + Returns a full embed describing this setting. + """ + t = ctx_translator.get().t + embed = discord.Embed( + title=t(_p( + 'setting|summary_embed|title', + "Configuration options for `{name}`" + )).format(name=self.display_name), + ) + embed.description = "{}\n{}".format(self.long_desc.format(self=self), self.desc_table) + return embed + + def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]: + t = ctx_translator.get().t + lines = [] + + # Currently line + lines.append(( + t(_p('setting|summary_table|field:currently|key', "Currently")), + show_value or (self.formatted or self.notset_str) + )) + + # Default line + if (default := self.default) is not None: + lines.append(( + t(_p('setting|summary_table|field:default|key', "By Default")), + self._format_data(self.parent_id, default) or 'None' + )) + + # Set using line + if (set_str := self.set_str) is not None: + lines.append(( + t(_p('setting|summary_table|field:set|key', "Set Using")), + set_str + )) + return lines + + @property + def desc_table(self) -> str: + return '\n'.join(tabulate(*self._desc_table())) + + @property + def input_field(self) -> TextInput: + """ + TextInput field used for string-based setting modification. + May be added to external modal for grouped setting editing. + This property is not persistent, and creates a new field each time. + """ + return TextInput( + label=self.display_name, + placeholder=self.accepts, + default=self.input_formatted[:4000] if self.input_formatted else None, + required=self._required + ) + + @property + def widget(self): + """ + Returns the Discord UI View associated with the current setting. + """ + if self._widget is None: + self._widget = self.Widget(self) + return self._widget + + @classmethod + def set_widget(cls, WidgetCls): + """ + Convenience decorator to create the widget class for this setting. + """ + cls.Widget = WidgetCls + return WidgetCls + + @property + def formatted(self): + """ + Default user-readable form of the setting. + Should be a short single line. + """ + return self._format_data(self.parent_id, self.data, **self.kwargs) or self.notset_str + + @property + def input_formatted(self) -> str: + """ + Format the current value as a default value for an input field. + Returned string must be acceptable through parse_string. + Does not take into account defaults. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @property + def summary(self): + """ + Formatted summary of the data. + May be implemented in `_format_data(..., summary=True, ...)` or overidden. + """ + return self._format_data(self.parent_id, self.data, summary=True, **self.kwargs) + + @classmethod + async def from_string(cls, parent_id, userstr: str, **kwargs): + """ + Return a setting instance initialised from a parsed user string. + """ + data = await cls._parse_string(parent_id, userstr, **kwargs) + return cls(parent_id, data, **kwargs) + + @classmethod + async def from_value(cls, parent_id, value, **kwargs): + await cls._check_value(parent_id, value, **kwargs) + data = cls._data_from_value(parent_id, value, **kwargs) + return cls(parent_id, data, **kwargs) + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs) -> Optional[SettingData]: + """ + Parse user provided string (usually from a TextInput) into raw setting data. + Must be overriden by the setting if the setting is user-configurable. + Returns None if the setting was unset. + """ + raise NotImplementedError + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Convert raw setting data into a formatted user-readable string, + representing the current value. + """ + raise NotImplementedError + + @classmethod + async def _check_value(cls, parent_id, value, **kwargs): + """ + Check the provided value is valid. + + Many setting update methods now provide Discord objects instead of raw data or user strings. + This method may be used for value-checking such a value. + + Raises UserInputError if the value fails validation. + """ + pass + + @classmethod + async def interaction_check(cls, parent_id, interaction: discord.Interaction, **kwargs): + if cls._write_ward is not None and not await cls._write_ward(interaction): + # TODO: Combine the check system so we can do customised errors here + t = ctx_translator.get().t + raise UserInputError(t(_p( + 'setting|interaction_check|error', + "You do not have sufficient permissions to do this!" + ))) + + +""" +command callback for set command? +autocomplete for set command? + +Might be better in a ConfigSetting subclass. +But also mix into the base setting types. +""" diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/ansi.py b/src/utils/ansi.py new file mode 100644 index 0000000..11f2852 --- /dev/null +++ b/src/utils/ansi.py @@ -0,0 +1,97 @@ +""" +Minimal library for making Discord Ansi colour codes. +""" +from enum import StrEnum + + +PREFIX = u'\u001b' + + +class TextColour(StrEnum): + Gray = '30' + Red = '31' + Green = '32' + Yellow = '33' + Blue = '34' + Pink = '35' + Cyan = '36' + White = '37' + + def __str__(self) -> str: + return AnsiColour(fg=self).as_str() + + def __call__(self): + return AnsiColour(fg=self) + + +class BgColour(StrEnum): + FireflyDarkBlue = '40' + Orange = '41' + MarbleBlue = '42' + GrayTurq = '43' + Gray = '44' + Indigo = '45' + LightGray = '46' + White = '47' + + def __str__(self) -> str: + return AnsiColour(bg=self).as_str() + + def __call__(self): + return AnsiColour(bg=self) + + +class Format(StrEnum): + NORMAL = '0' + BOLD = '1' + UNDERLINE = '4' + NOOP = '9' + + def __str__(self) -> str: + return AnsiColour(self).as_str() + + def __call__(self): + return AnsiColour(self) + + +class AnsiColour: + def __init__(self, *flags, fg=None, bg=None): + self.text_colour = fg + self.background_colour = bg + self.reset = (Format.NORMAL in flags) + self._flags = set(flags) + self._flags.discard(Format.NORMAL) + + @property + def flags(self): + return (*((Format.NORMAL,) if self.reset else ()), *self._flags) + + def as_str(self): + parts = [] + if self.reset: + parts.append(Format.NORMAL) + elif not self.flags: + parts.append(Format.NOOP) + + parts.extend(self._flags) + + for c in (self.text_colour, self.background_colour): + if c is not None: + parts.append(c) + + partstr = ';'.join(part.value for part in parts) + return f"{PREFIX}[{partstr}m" # ] + + def __str__(self): + return self.as_str() + + def __add__(self, obj: 'AnsiColour'): + text_colour = obj.text_colour or self.text_colour + background_colour = obj.background_colour or self.background_colour + flags = (*self.flags, *obj.flags) + return AnsiColour(*flags, fg=text_colour, bg=background_colour) + + +RESET = AnsiColour(Format.NORMAL) +BOLD = AnsiColour(Format.BOLD) +UNDERLINE = AnsiColour(Format.UNDERLINE) diff --git a/src/utils/data.py b/src/utils/data.py new file mode 100644 index 0000000..a590430 --- /dev/null +++ b/src/utils/data.py @@ -0,0 +1,165 @@ +""" +Some useful pre-built Conditions for data queries. +""" +from typing import Optional, Any +from itertools import chain + +from psycopg import sql +from data.conditions import Condition, Joiner +from data.columns import ColumnExpr +from data.base import Expression +from constants import MAX_COINS + + +def MULTIVALUE_IN(columns: tuple[str, ...], *data: tuple[Any, ...]) -> Condition: + """ + Condition constructor for filtering by multiple column equalities. + + Example Usage + ------------- + Query.where(MULTIVALUE_IN(('guildid', 'userid'), (1, 2), (3, 4))) + """ + if not data: + raise ValueError("Cannot create empty multivalue condition.") + left = sql.SQL("({})").format( + sql.SQL(', ').join( + sql.Identifier(key) + for key in columns + ) + ) + right_item = sql.SQL('({})').format( + sql.SQL(', ').join( + sql.Placeholder() + for _ in columns + ) + ) + right = sql.SQL("({})").format( + sql.SQL(', ').join( + right_item + for _ in data + ) + ) + return Condition( + left, + Joiner.IN, + right, + chain(*data) + ) + + +def MEMBERS(*memberids: tuple[int, int], guild_column='guildid', user_column='userid') -> Condition: + """ + Condition constructor for filtering member tables by guild and user id simultaneously. + + Example Usage + ------------- + Query.where(MEMBERS((1234,12), (5678,34))) + """ + if not memberids: + raise ValueError("Cannot create a condition with no members") + return Condition( + sql.SQL("({guildid}, {userid})").format( + guildid=sql.Identifier(guild_column), + userid=sql.Identifier(user_column) + ), + Joiner.IN, + sql.SQL("({})").format( + sql.SQL(', ').join( + sql.SQL("({}, {})").format( + sql.Placeholder(), + sql.Placeholder() + ) for _ in memberids + ) + ), + chain(*memberids) + ) + + +def as_duration(expr: Expression) -> ColumnExpr: + """ + Convert an integer expression into a duration expression. + """ + expr_expr, expr_values = expr.as_tuple() + return ColumnExpr( + sql.SQL("({} * interval '1 second')").format(expr_expr), + expr_values + ) + + +class TemporaryTable(Expression): + """ + Create a temporary table expression to be used in From or With clauses. + + Example + ------- + ``` + tmp_table = TemporaryTable('_col1', '_col2', name='data') + tmp_table.values((1, 2), (3, 4)) + + real_table.update_where(col1=tmp_table['_col1']).set(col2=tmp_table['_col2']).from_(tmp_table) + ``` + """ + + def __init__(self, *columns: str, name: str = '_t', types: Optional[tuple[str, ...]] = None): + self.name = name + self.columns = columns + self.types = types + if types and len(types) != len(columns): + raise ValueError("Number of types does not much number of columns!") + + self._table_columns = { + col: ColumnExpr(sql.Identifier(name, col)) + for col in columns + } + + self.values = [] + + def __getitem__(self, key) -> sql.Identifier: + return self._table_columns[key] + + def as_tuple(self): + """ + (VALUES {}) + AS + name (col1, col2) + """ + if not self.values: + raise ValueError("Cannot flatten CTE with no values.") + + single_value = sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() for _ in self.columns)) + if self.types: + first_value = sql.SQL("({})").format( + sql.SQL(", ").join( + sql.SQL("{}::{}").format(sql.Placeholder(), sql.SQL(cast)) + for cast in self.types + ) + ) + else: + first_value = single_value + + value_placeholder = sql.SQL("(VALUES {})").format( + sql.SQL(", ").join( + (first_value, *(single_value for _ in self.values[1:])) + ) + ) + expr = sql.SQL("{values} AS {name} ({columns})").format( + values=value_placeholder, + name=sql.Identifier(self.name), + columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.columns) + ) + values = chain(*self.values) + return (expr, values) + + def set_values(self, *data): + self.values = data + + +def SAFECOINS(expr: Expression) -> Expression: + expr_expr, expr_values = expr.as_tuple() + return ColumnExpr( + sql.SQL("LEAST({}, {})").format( + expr_expr, + sql.Literal(MAX_COINS) + ), + expr_values + ) diff --git a/src/utils/lib.py b/src/utils/lib.py new file mode 100644 index 0000000..d7796fc --- /dev/null +++ b/src/utils/lib.py @@ -0,0 +1,879 @@ +from io import StringIO +from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any +import collections +import datetime +import datetime as dt +import iso8601 # type: ignore +import pytz +import re +import json +from contextvars import Context + +import discord +from discord.partial_emoji import _EmojiTag +from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage +from discord.ui import View + +from meta.errors import UserInputError + + +multiselect_regex = re.compile( + r"^([0-9, -]+)$", + re.DOTALL | re.IGNORECASE | re.VERBOSE +) +tick = '✅' +cross = '❌' + +MISSING = object() + + +class MessageArgs: + """ + Utility class for storing message creation and editing arguments. + """ + # TODO: Overrides for mutually exclusive arguments, see Messageable.send + + @overload + def __init__( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ) -> None: + ... + + @overload + def __init__( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ) -> None: + ... + + @overload + def __init__( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ) -> None: + ... + + @overload + def __init__( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ) -> None: + ... + + def __init__(self, **kwargs): + self.kwargs = kwargs + + @property + def send_args(self) -> dict: + if self.kwargs.get('view', MISSING) is None: + kwargs = self.kwargs.copy() + kwargs.pop('view') + else: + kwargs = self.kwargs + + return kwargs + + @property + def edit_args(self) -> dict: + args = {} + kept = ( + 'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view' + ) + for k in kept: + if k in self.kwargs: + args[k] = self.kwargs[k] + + if 'file' in self.kwargs: + args['attachments'] = [self.kwargs['file']] + + if 'files' in self.kwargs: + args['attachments'] = self.kwargs['files'] + + if 'suppress_embeds' in self.kwargs: + args['suppress'] = self.kwargs['suppress_embeds'] + + return args + + +def tabulate( + *fields: tuple[str, str], + row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}", + sub_format: str = "`{invis:<{pad}}{colon}`\t{value}", + colon: str = ':', + invis: str = "​", + **args +) -> list[str]: + """ + Turns a list of (property, value) pairs into + a pretty string with one `prop: value` pair each line, + padded so that the colons in each line are lined up. + Use `\\r\\n` in a value to break the line with padding. + + Parameters + ---------- + fields: List[tuple[str, str]] + List of (key, value) pairs. + row_format: str + The format string used to format each row. + sub_format: str + The format string used to format each subline in a row. + colon: str + The colon character used. + invis: str + The invisible character used (to avoid Discord stripping the string). + + Returns: List[str] + The list of resulting table rows. + Each row corresponds to one (key, value) pair from fields. + """ + max_len = max(len(field[0]) for field in fields) + + rows = [] + for field in fields: + key = field[0] + value = field[1] + lines = value.split('\r\n') + + row_line = row_format.format( + invis=invis, + key=key, + pad=max_len, + colon=colon, + value=lines[0], + field=field, + **args + ) + if len(lines) > 1: + row_lines = [row_line] + for line in lines[1:]: + sub_line = sub_format.format( + invis=invis, + pad=max_len + len(colon), + colon=colon, + value=line, + **args + ) + row_lines.append(sub_line) + row_line = '\n'.join(row_lines) + rows.append(row_line) + return rows + + +def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]: + """ + Create pretty codeblock pages from a list of strings. + + Parameters + ---------- + item_list: List[str] + List of strings to paginate. + block_length: int + Maximum number of strings per page. + style: str + Codeblock style to use. + Title formatting assumes the `markdown` style, and numbered lists work well with this. + However, `markdown` sometimes messes up formatting in the list. + title: str + Optional title to add to the top of each page. + + Returns: List[str] + List of pages, each formatted into a codeblock, + and containing at most `block_length` of the provided strings. + """ + lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)] + page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)] + pages = [] + for i, block in enumerate(page_blocks): + pagenum = "Page {}/{}".format(i + 1, len(page_blocks)) + if title: + header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title + else: + header = pagenum + header_line = "=" * len(header) + full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else "" + pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block))) + return pages + + +def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]: + """ + Break the text into blocks of maximum length blocksize + If possible, break across nearby newlines. Otherwise just break at blocksize chars + + Parameters + ---------- + text: str + Text to break into blocks. + blocksize: int + Maximum character length for each block. + code: bool + Whether to wrap each block in codeblocks (these are counted in the blocksize). + syntax: str + The markdown formatting language to use for the codeblocks, if applicable. + maxheight: int + The maximum number of lines in each block + + Returns: List[str] + List of blocks, + each containing at most `block_size` characters, + of height at most `maxheight`. + """ + # Adjust blocksize to account for the codeblocks if required + blocksize = blocksize - 8 - len(syntax) if code else blocksize + + # Build the blocks + blocks = [] + while True: + # If the remaining text is already small enough, append it + if len(text) <= blocksize: + blocks.append(text) + break + text = text.strip('\n') + + # Find the last newline in the prototype block + split_on = text[0:blocksize].rfind('\n') + split_on = blocksize if split_on < blocksize // 5 else split_on + + # Add the block and truncate the text + blocks.append(text[0:split_on]) + text = text[split_on:] + + # Add the codeblock ticks and the code syntax header, if required + if code: + blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks] + + return blocks + + +def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str: + """ + Convert a datetime.timedelta object into an easily readable duration string. + + Parameters + ---------- + delta: datetime.timedelta + The timedelta object to convert into a readable string. + sec: bool + Whether to include the seconds from the timedelta object in the string. + minutes: bool + Whether to include the minutes from the timedelta object in the string. + short: bool + Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s"). + + Returns: str + A string containing a time from the datetime.timedelta object, in a readable format. + Time units will be abbreviated if short was set to True. + """ + output = [[delta.days, 'd' if short else ' day'], + [delta.seconds // 3600, 'h' if short else ' hour']] + if minutes: + output.append([delta.seconds // 60 % 60, 'm' if short else ' minute']) + if sec: + output.append([delta.seconds % 60, 's' if short else ' second']) + for i in range(len(output)): + if output[i][0] != 1 and not short: + output[i][1] += 's' # type: ignore + reply_msg = [] + if output[0][0] != 0: + reply_msg.append("{}{} ".format(output[0][0], output[0][1])) + if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2: + reply_msg.append("{}{} ".format(output[1][0], output[1][1])) + for i in range(2, len(output) - 1): + reply_msg.append("{}{} ".format(output[i][0], output[i][1])) + if not short and reply_msg: + reply_msg.append("and ") + reply_msg.append("{}{}".format(output[-1][0], output[-1][1])) + return "".join(reply_msg) + + +def _parse_dur(time_str: str) -> int: + """ + Parses a user provided time duration string into a timedelta object. + + Parameters + ---------- + time_str: str + The time string to parse. String can include days, hours, minutes, and seconds. + + Returns: int + The number of seconds the duration represents. + """ + funcs = {'d': lambda x: x * 24 * 60 * 60, + 'h': lambda x: x * 60 * 60, + 'm': lambda x: x * 60, + 's': lambda x: x} + time_str = time_str.strip(" ,") + found = re.findall(r'(\d+)\s?(\w+?)', time_str) + seconds = 0 + for bit in found: + if bit[1] in funcs: + seconds += funcs[bit[1]](int(bit[0])) + return seconds + + +def strfdur(duration: int, short=True, show_days=False) -> str: + """ + Convert a duration given in seconds to a number of hours, minutes, and seconds. + """ + days = duration // (3600 * 24) if show_days else 0 + hours = duration // 3600 + if days: + hours %= 24 + minutes = duration // 60 % 60 + seconds = duration % 60 + + parts = [] + if days: + unit = 'd' if short else (' days' if days != 1 else ' day') + parts.append('{}{}'.format(days, unit)) + if hours: + unit = 'h' if short else (' hours' if hours != 1 else ' hour') + parts.append('{}{}'.format(hours, unit)) + if minutes: + unit = 'm' if short else (' minutes' if minutes != 1 else ' minute') + parts.append('{}{}'.format(minutes, unit)) + if seconds or duration == 0: + unit = 's' if short else (' seconds' if seconds != 1 else ' second') + parts.append('{}{}'.format(seconds, unit)) + + if short: + return ' '.join(parts) + else: + return ', '.join(parts) + + +def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str: + """ + Substitutes a user provided list of numbers and ranges, + and replaces the ranges by the corresponding list of numbers. + + Parameters + ---------- + ranges_str: str + The string to ranges in. + max_match: int + The maximum number of ranges to replace. + Any ranges exceeding this will be ignored. + max_range: int + The maximum length of range to replace. + Attempting to replace a range longer than this will raise a `ValueError`. + """ + def _repl(match): + n1 = int(match.group(1)) + n2 = int(match.group(2)) + if n2 - n1 > max_range: + # TODO: Upgrade to SafeCancellation + raise ValueError("Provided range is too large!") + return separator.join(str(i) for i in range(n1, n2 + 1)) + + return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match) + + +def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]: + """ + Parses a user provided range string into a list of numbers. + Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`. + """ + substituted = substitute_ranges(ranges_str, separator=separator, **kwargs) + _numbers = (item.strip() for item in substituted.split(',')) + numbers = [item for item in _numbers if item] + integers = [int(item) for item in numbers if item.isdigit()] + + if not ignore_errors and len(integers) != len(numbers): + # TODO: Upgrade to SafeCancellation + raise ValueError( + "Couldn't parse the provided selection!\n" + "Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`." + ) + + return integers + + +def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str: + """ + Format a message into a string with various information, such as: + the timestamp of the message, author, message content, and attachments. + + Parameters + ---------- + msg: Message + The message to format. + mask_link: bool + Whether to mask the URLs of any attachments. + line_break: bool + Whether a line break should be used in the string. + tz: Timezone + The timezone to use in the formatted message. + clean: bool + Whether to use the clean content of the original message. + + Returns: str + A formatted string containing various information: + User timezone, message author, message content, attachments + """ + timestr = "%I:%M %p, %d/%m/%Y" + if tz: + time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr) + else: + time = msg.created_at.strftime(timestr) + user = str(msg.author) + attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url] + if mask_link: + attach_list = ["[Link]({})".format(url) for url in attach_list] + attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else "" + return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format( + time=time, + user=user, + line_break="\n" if line_break else "", + message=msg.clean_content if clean else msg.content, + attachments=attachments + ) + + +def convdatestring(datestring: str) -> datetime.timedelta: + """ + Convert a date string into a datetime.timedelta object. + + Parameters + ---------- + datestring: str + The string to convert to a datetime.timedelta object. + + Returns: datetime.timedelta + A datetime.timedelta object formed from the string provided. + """ + datestring = datestring.strip(' ,') + datearray = [] + funcs = {'d': lambda x: x * 24 * 60 * 60, + 'h': lambda x: x * 60 * 60, + 'm': lambda x: x * 60, + 's': lambda x: x} + currentnumber = '' + for char in datestring: + if char.isdigit(): + currentnumber += char + else: + if currentnumber == '': + continue + datearray.append((int(currentnumber), char)) + currentnumber = '' + seconds = 0 + if currentnumber: + seconds += int(currentnumber) + for i in datearray: + if i[1] in funcs: + seconds += funcs[i[1]](i[0]) + return datetime.timedelta(seconds=seconds) + + +class _rawChannel(discord.abc.Messageable): + """ + Raw messageable class representing an arbitrary channel, + not necessarially seen by the gateway. + """ + def __init__(self, state, id): + self._state = state + self.id = id + + async def _get_channel(self): + return discord.Object(self.id) + + +async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message: + """ + Mails a message to a channelid which may be invisible to the gateway. + + Parameters: + client: discord.Client + The client to use for mailing. + Must at least have static authentication and have a valid `_connection`. + channelid: int + The channel id to mail to. + msg_args: Any + Message keyword arguments which are passed transparently to `_rawChannel.send(...)`. + """ + # Create the raw channel + channel = _rawChannel(client._connection, channelid) + return await channel.send(**msg_args) + + +class EmbedField(NamedTuple): + name: str + value: str + inline: Optional[bool] = True + + +def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]): + """ + Append embed fields to an embed. + Parameters + ---------- + embed: discord.Embed + The embed to add the field to. + emb_fields: tuple + The values to add to a field. + name: str + The name of the field. + value: str + The value of the field. + inline: bool + Whether the embed field should be inline or not. + """ + for field in emb_fields: + embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2])) + + +def join_list(string: list[str], nfs=False) -> str: + """ + Join a list together, separated with commas, plus add "and" to the beginning of the last value. + Parameters + ---------- + string: list + The list to join together. + nfs: bool + (no fullstops) + Whether to exclude fullstops/periods from the output messages. + If not provided, fullstops will be appended to the output. + """ + # TODO: Probably not useful with localisation + if len(string) > 1: + return "{}{} and {}{}".format((", ").join(string[:-1]), + "," if len(string) > 2 else "", string[-1], "" if nfs else ".") + else: + return "{}{}".format("".join(string), "" if nfs else ".") + + +def shard_of(shard_count: int, guildid: int) -> int: + """ + Calculate the shard number of a given guild. + """ + return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0 + + +def jumpto(guildid: int, channeldid: int, messageid: int) -> str: + """ + Build a jump link for a message given its location. + """ + return 'https://discord.com/channels/{}/{}/{}'.format( + guildid, + channeldid, + messageid + ) + + +def utc_now() -> datetime.datetime: + """ + Return the current timezone-aware utc timestamp. + """ + return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) + + +def multiple_replace(string: str, rep_dict: dict[str, str]) -> str: + if rep_dict: + pattern = re.compile( + "|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]), + flags=re.DOTALL + ) + return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string) + else: + return string + + +def recover_context(context: Context): + for var in context: + var.set(context[var]) + + +def parse_ids(idstr: str) -> List[int]: + """ + Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids. + + Object agnostic, so all mention tokens are stripped. + Raises UserInputError if an id is invalid, + setting `orig` and `item` info fields. + """ + from meta.errors import UserInputError + + # Extract ids from string + splititer = (split.strip('<@!#&>, ') for split in idstr.split(',')) + splits = [split for split in splititer if split] + + # Check they are integers + if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None: + raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id}) + + # Cast to integer and return + return list(map(int, splits)) + + +def error_embed(error, **kwargs) -> discord.Embed: + embed = discord.Embed( + colour=discord.Colour.brand_red(), + description=error, + timestamp=utc_now() + ) + return embed + + +class DotDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +class Timezoned: + """ + ABC mixin for objects with a set timezone. + + Provides several useful localised properties. + """ + __slots__ = () + + @property + def timezone(self) -> pytz.timezone: + """ + Must be implemented by the deriving class! + """ + raise NotImplementedError + + @property + def now(self): + """ + Return the current time localised to the object's timezone. + """ + return datetime.datetime.now(tz=self.timezone) + + @property + def today(self): + """ + Return the start of the day localised to the object's timezone. + """ + now = self.now + return now.replace(hour=0, minute=0, second=0, microsecond=0) + + @property + def week_start(self): + """ + Return the start of the week in the object's timezone + """ + today = self.today + return today - datetime.timedelta(days=today.weekday()) + + @property + def month_start(self): + """ + Return the start of the current month in the object's timezone + """ + today = self.today + return today.replace(day=1) + + +def replace_multiple(format_string, mapping): + """ + Subsistutes the keys from the format_dict with their corresponding values. + + Substitution is non-chained, and done in a single pass via regex. + """ + if not mapping: + raise ValueError("Empty mapping passed.") + + keys = list(mapping.keys()) + pattern = '|'.join(f"({key})" for key in keys) + string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string) + return string + + +def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str): + """ + Produces a distinguishing key for an Emoji or PartialEmoji. + + Equality checks using this key should act as expected. + """ + if isinstance(emoji, _EmojiTag): + if emoji.id: + key = str(emoji.id) + else: + key = str(emoji.name) + else: + key = str(emoji) + + return key + +def recurse_map(func, obj, loc=[]): + if isinstance(obj, dict): + for k, v in obj.items(): + loc.append(k) + obj[k] = recurse_map(func, v, loc) + loc.pop() + elif isinstance(obj, list): + for i, item in enumerate(obj): + loc.append(i) + obj[i] = recurse_map(func, item) + loc.pop() + else: + obj = func(loc, obj) + return obj + +async def check_dm(user: discord.User | discord.Member) -> bool: + """ + Check whether we can direct message the given user. + + Assumes the client is initialised. + This uses an always-failing HTTP request, + so we need to be very very very careful that this is not used frequently. + Optimally only at the explicit behest of the user + (i.e. during a user instigated interaction). + """ + try: + await user.send('') + except discord.Forbidden: + return False + except discord.HTTPException: + return True + + +async def command_lengths(tree) -> dict[str, int]: + cmds = tree.get_commands() + payloads = [ + await cmd.get_translated_payload(tree.translator) + for cmd in cmds + ] + lens = {} + for command in payloads: + name = command['name'] + crumbs = {} + cmd_len = lens[name] = _recurse_length(command, crumbs, (name,)) + if name == 'configure' or cmd_len > 4000: + print(f"'{name}' over 4000. Breadcrumb Trail follows:") + lines = [] + for loc, val in crumbs.items(): + locstr = '.'.join(loc) + lines.append(f"{locstr}: {val}") + print('\n'.join(lines)) + print(json.dumps(command, indent=2)) + return lens + +def _recurse_length(payload, breadcrumbs={}, header=()) -> int: + total = 0 + total_header = (*header, '') + breadcrumbs[total_header] = 0 + + if isinstance(payload, dict): + # Read strings that count towards command length + # String length is length of longest localisation, including default. + for key in ('name', 'description', 'value'): + if key in payload: + value = payload[key] + if isinstance(value, str): + values = (value, *payload.get(key + '_localizations', {}).values()) + maxlen = max(map(len, values)) + total += maxlen + breadcrumbs[(*header, key)] = maxlen + + for key, value in payload.items(): + loc = (*header, key) + total += _recurse_length(value, breadcrumbs, loc) + elif isinstance(payload, list): + for i, item in enumerate(payload): + if isinstance(item, dict) and 'name' in item: + loc = (*header, f"{i}<{item['name']}>") + else: + loc = (*header, str(i)) + total += _recurse_length(item, breadcrumbs, loc) + + if total: + breadcrumbs[total_header] = total + else: + breadcrumbs.pop(total_header) + + return total + +def write_records(records: list[dict[str, Any]], stream: StringIO): + if records: + keys = records[0].keys() + stream.write(','.join(keys)) + stream.write('\n') + for record in records: + stream.write(','.join(map(str, record.values()))) + stream.write('\n') + + +parse_dur_exps = [ + ( + r"(?P\d+)\s*(?:(d)|(day))", + 60 * 60 * 24, + ), + ( + r"(?P\d+)\s*(?:(h)|(hour))", + 60 * 60 + ), + ( + r"(?P\d+)\s*(?:(m)|(min))", + 60 + ), + ( + r"(?P\d+)\s*(?:(s)|(sec))", + 1 + ) +] + + +def parse_duration(string: str) -> Optional[int]: + seconds = 0 + found = False + for expr, multiplier in parse_dur_exps: + match = re.search(expr, string, flags=re.IGNORECASE) + if match: + found = True + seconds += int(match.group('value')) * multiplier + + return seconds if found else None diff --git a/src/utils/monitor.py b/src/utils/monitor.py new file mode 100644 index 0000000..96aedeb --- /dev/null +++ b/src/utils/monitor.py @@ -0,0 +1,191 @@ +import asyncio +import bisect +import logging +from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any + +from .lib import utc_now +from .ratelimits import Bucket + + +logger = logging.getLogger(__name__) + +Taskid = TypeVar('Taskid') + + +class TaskMonitor(Generic[Taskid]): + """ + Base class for a task monitor. + + Stores tasks as a time-sorted list of taskids. + Subclasses may override `run_task` to implement an executor. + + Adding or removing a single task has O(n) performance. + To bulk update tasks, instead use `schedule_tasks`. + + Each taskid must be unique and hashable. + """ + + def __init__(self, executor=None, bucket: Optional[Bucket] = None): + # Ratelimit bucket to enforce maximum execution rate + self._bucket = bucket + + self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor + + self._wakeup: asyncio.Event = asyncio.Event() + self._monitor_task: Optional[asyncio.Task] = None + + # Task data + self._tasklist: list[Taskid] = [] + self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp + + # Running map ensures we keep a reference to the running task + # And allows simpler external cancellation if required + self._running: dict[Taskid, asyncio.Future] = {} + + def __repr__(self): + return ( + "<" + f"{self.__class__.__name__}" + f" tasklist={len(self._tasklist)}" + f" taskmap={len(self._taskmap)}" + f" wakeup={self._wakeup.is_set()}" + f" bucket={self._bucket}" + f" running={len(self._running)}" + f" task={self._monitor_task}" + f">" + ) + + def set_tasks(self, *tasks: tuple[Taskid, int]) -> None: + """ + Similar to `schedule_tasks`, but wipe and reset the tasklist. + """ + self._taskmap = {tid: time for tid, time in tasks} + self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])) + self._wakeup.set() + + def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None: + """ + Schedule the given tasks. + + Rather than repeatedly inserting tasks, + where the O(log n) insort is dominated by the O(n) list insertion, + we build an entirely new list, and always wake up the loop. + """ + self._taskmap |= {tid: time for tid, time in tasks} + self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])) + self._wakeup.set() + + def schedule_task(self, taskid: Taskid, timestamp: int) -> None: + """ + Insert the provided task into the tasklist. + If the new task has a lower timestamp than the next task, wakes up the monitor loop. + """ + if self._tasklist: + nextid = self._tasklist[-1] + wake = self._taskmap[nextid] >= timestamp + wake = wake or taskid == nextid + else: + wake = True + if taskid in self._taskmap: + self._tasklist.remove(taskid) + self._taskmap[taskid] = timestamp + bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t]) + if wake: + self._wakeup.set() + + def cancel_tasks(self, *taskids: Taskid) -> None: + """ + Remove all tasks with the given taskids from the tasklist. + If the next task has this taskid, wake up the monitor loop. + """ + taskids = set(taskids) + wake = (self._tasklist and self._tasklist[-1] in taskids) + self._tasklist = [tid for tid in self._tasklist if tid not in taskids] + for tid in taskids: + self._taskmap.pop(tid, None) + if wake: + self._wakeup.set() + + def start(self): + if self._monitor_task and not self._monitor_task.done(): + self._monitor_task.cancel() + # Start the monitor + self._monitor_task = asyncio.create_task(self.monitor()) + return self._monitor_task + + async def monitor(self): + """ + Start the monitor. + Executes the tasks in `self.tasks` at the specified time. + + This will shield task execution from cancellation + to avoid partial states. + """ + try: + while True: + self._wakeup.clear() + if not self._tasklist: + # No tasks left, just sleep until wakeup + await self._wakeup.wait() + else: + # Get the next task, sleep until wakeup or it is ready to run + nextid = self._tasklist[-1] + nexttime = self._taskmap[nextid] + sleep_for = nexttime - utc_now().timestamp() + try: + await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for) + except asyncio.TimeoutError: + # Ready to run the task + self._tasklist.pop() + self._taskmap.pop(nextid, None) + self._running[nextid] = asyncio.ensure_future(self._run(nextid)) + else: + # Wakeup task fired, loop again + continue + except asyncio.CancelledError: + # Log closure and wait for remaining tasks + # A second cancellation will also cancel the tasks + logger.debug( + f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. " + f"Waiting for {len(self._running)} running tasks to complete." + ) + await asyncio.gather(*self._running.values(), return_exceptions=True) + + async def _run(self, taskid: Taskid) -> None: + # Execute the task, respecting the ratelimit bucket + if self._bucket is not None: + # IMPLEMENTATION NOTE: + # Bucket.wait() should guarantee not more than n tasks/second are run + # and that a request directly afterwards will _not_ raise BucketFull + # Make sure that only one waiter is actually waiting on its sleep task + # The other waiters should be waiting on a lock around the sleep task + # Waiters are executed in wait-order, so if we only let a single waiter in + # we shouldn't get collisions. + # Furthermore, make sure we do _not_ pass back to the event loop after waiting + # Or we will lose thread-safety for BucketFull + await self._bucket.wait() + fut = asyncio.create_task(self.run_task(taskid)) + try: + await asyncio.shield(fut) + except asyncio.CancelledError: + raise + except Exception: + # Protect the monitor loop from any other exceptions + logger.exception( + f"Ignoring exception in task monitor {self.__class__.__name__} while " + f"executing " + ) + finally: + self._running.pop(taskid) + + async def run_task(self, taskid: Taskid): + """ + Execute the task with the given taskid. + + Default implementation executes `self.executor` if it exists, + otherwise raises NotImplementedError. + """ + if self.executor is not None: + await self.executor(taskid) + else: + raise NotImplementedError diff --git a/src/utils/ratelimits.py b/src/utils/ratelimits.py new file mode 100644 index 0000000..7322336 --- /dev/null +++ b/src/utils/ratelimits.py @@ -0,0 +1,173 @@ +import asyncio +import time +import logging + +from meta.errors import SafeCancellation + +from cachetools import TTLCache + +logger = logging.getLogger() + + + +class BucketFull(Exception): + """ + Throw when a requested Bucket is already full + """ + pass + + +class BucketOverFull(BucketFull): + """ + Throw when a requested Bucket is overfull + """ + pass + + +class Bucket: + __slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock') + + def __init__(self, max_level, empty_time): + self.max_level = max_level + self.empty_time = empty_time + self.leak_rate = max_level / empty_time + + self._level = 0 + self._last_checked = time.monotonic() + + self._last_full = False + self._wait_lock = asyncio.Lock() + + @property + def full(self) -> bool: + """ + Return whether the bucket is 'full', + that is, whether an immediate request against the bucket will raise `BucketFull`. + """ + self._leak() + return self._level + 1 > self.max_level + + @property + def overfull(self): + self._leak() + return self._level > self.max_level + + @property + def delay(self): + self._leak() + if self._level + 1 > self.max_level: + delay = (self._level + 1 - self.max_level) * self.leak_rate + else: + delay = 0 + return delay + + def _leak(self): + if self._level: + elapsed = time.monotonic() - self._last_checked + self._level = max(0, self._level - (elapsed * self.leak_rate)) + + self._last_checked = time.monotonic() + + def request(self): + self._leak() + if self._level > self.max_level: + raise BucketOverFull + elif self._level == self.max_level: + self._level += 1 + if self._last_full: + raise BucketOverFull + else: + self._last_full = True + raise BucketFull + else: + self._last_full = False + self._level += 1 + + def fill(self): + self._leak() + self._level = max(self._level, self.max_level + 1) + + async def wait(self): + """ + Wait until the bucket has room. + + Guarantees that a `request` directly afterwards will not raise `BucketFull`. + """ + # Wrapped in a lock so that waiters are correctly handled in wait-order + # Otherwise multiple waiters will have the same delay, + # and race for the wakeup after sleep. + # Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order + async with self._wait_lock: + # We do this in a loop in case asyncio.sleep throws us out early, + # or a synchronous request overflows the bucket while we are waiting. + while self.full: + await asyncio.sleep(self.delay) + + async def wrapped(self, coro): + await self.wait() + self.request() + await coro + + +class RateLimit: + def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)): + self.max_level = max_level + self.empty_time = empty_time + + self.error = error or "Too many requests, please slow down!" + self.buckets = cache + + def request_for(self, key): + if not (bucket := self.buckets.get(key, None)): + bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time) + + try: + bucket.request() + except BucketOverFull: + raise SafeCancellation(details="Bucket overflow") + except BucketFull: + raise SafeCancellation(self.error, details="Bucket full") + + def ward(self, member=True, key=None): + """ + Command ratelimit decorator. + """ + key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id)) + + def decorator(func): + async def wrapper(ctx, *args, **kwargs): + self.request_for(key(ctx)) + return await func(ctx, *args, **kwargs) + return wrapper + return decorator + + +async def limit_concurrency(aws, limit): + """ + Run provided awaitables concurrently, + ensuring that no more than `limit` are running at once. + """ + aws = iter(aws) + aws_ended = False + pending = set() + count = 0 + logger.debug("Starting limited concurrency executor") + + while pending or not aws_ended: + while len(pending) < limit and not aws_ended: + aw = next(aws, None) + if aw is None: + aws_ended = True + else: + pending.add(asyncio.create_task(aw)) + count += 1 + + if not pending: + break + + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + while done: + yield done.pop() + logger.debug(f"Completed {count} tasks") diff --git a/src/utils/ui/__init__.py b/src/utils/ui/__init__.py new file mode 100644 index 0000000..fd28a5b --- /dev/null +++ b/src/utils/ui/__init__.py @@ -0,0 +1,12 @@ +import asyncio +import logging +from babel.translator import LocalBabel + +logger = logging.getLogger(__name__) + +util_babel = LocalBabel('utils') + +from .hooked import * +from .leo import * +from .micros import * +from .msgeditor import * diff --git a/src/utils/ui/hooked.py b/src/utils/ui/hooked.py new file mode 100644 index 0000000..075476d --- /dev/null +++ b/src/utils/ui/hooked.py @@ -0,0 +1,59 @@ +import time + +import discord +from discord.ui.item import Item +from discord.ui.button import Button + +from .leo import LeoUI + +__all__ = ( + 'HookedItem', + 'AButton', + 'AsComponents' +) + + +class HookedItem: + """ + Mixin for Item classes allowing an instance to be used as a callback decorator. + """ + def __init__(self, *args, pass_kwargs={}, **kwargs): + super().__init__(*args, **kwargs) + self.pass_kwargs = pass_kwargs + + def __call__(self, coro): + async def wrapped(interaction, **kwargs): + return await coro(interaction, self, **(self.pass_kwargs | kwargs)) + self.callback = wrapped + return self + + +class AButton(HookedItem, Button): + ... + + +class AsComponents(LeoUI): + """ + Simple container class to accept a number of Items and turn them into an attachable View. + """ + def __init__(self, *items, pass_kwargs={}, **kwargs): + super().__init__(**kwargs) + self.pass_kwargs = pass_kwargs + + for item in items: + self.add_item(item) + + async def _scheduled_task(self, item: Item, interaction: discord.Interaction): + try: + item._refresh_state(interaction, interaction.data) # type: ignore + + allow = await self.interaction_check(interaction) + if not allow: + return + + if self.timeout: + self.__timeout_expiry = time.monotonic() + self.timeout + + await item.callback(interaction, **self.pass_kwargs) + except Exception as e: + return await self.on_error(interaction, e, item) diff --git a/src/utils/ui/leo.py b/src/utils/ui/leo.py new file mode 100644 index 0000000..eebaea4 --- /dev/null +++ b/src/utils/ui/leo.py @@ -0,0 +1,485 @@ +from typing import List, Optional, Any, Dict +import asyncio +import logging +import time +from contextvars import copy_context, Context + +import discord +from discord.ui import Modal, View, Item + +from meta.logger import log_action_stack, logging_context +from meta.errors import SafeCancellation + +from . import logger +from ..lib import MessageArgs, error_embed + +__all__ = ( + 'LeoUI', + 'MessageUI', + 'LeoModal', + 'error_handler_for' +) + + +class LeoUI(View): + """ + View subclass for small-scale user interfaces. + + While a 'View' provides an interface for managing a collection of components, + a `LeoUI` may also manage a message, and potentially slave Views or UIs. + The `LeoUI` also exposes more advanced cleanup and timeout methods, + and preserves the context. + """ + + def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None: + super().__init__(*args, **kwargs) + + if context is None: + self._context = copy_context() + else: + self._context = context + + self._name = ui_name or self.__class__.__name__ + self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name]) + + # List of slaved views to stop when this view stops + self._slaves: List[View] = [] + + # TODO: Replace this with a substitutable ViewLayout class + self._layout: Optional[tuple[tuple[Item, ...], ...]] = None + + @property + def _stopped(self) -> asyncio.Future: + """ + Return an future indicating whether the View has finished interacting. + + Currently exposes a hidden attribute of the underlying View. + May be reimplemented in future. + """ + return self._View__stopped + + def to_components(self) -> List[Dict[str, Any]]: + """ + Extending component generator to apply the set _layout, if it exists. + """ + if self._layout is not None: + # Alternative rendering using layout + components = [] + for i, row in enumerate(self._layout): + # Skip empty rows + if not row: + continue + + # Since we aren't relying on ViewWeights, manually check width here + if sum(item.width for item in row) > 5: + raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!") + + # Create the component dict for this row + components.append({ + 'type': 1, + 'components': [item.to_component_dict() for item in row] + }) + else: + components = super().to_components() + + return components + + def set_layout(self, *rows: tuple[Item, ...]) -> None: + """ + Set the layout of the rendered View as a matrix of items, + or more precisely, a list of action rows. + + This acts independently of the existing sorting with `_ViewWeights`, + and overrides the sorting if applied. + """ + self._layout = rows + + async def cleanup(self): + """ + Coroutine to run when timeing out, stopping, or cancelling. + Generally cleans up any open resources, and removes any leftover components. + """ + logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'}) + return None + + def stop(self): + """ + Extends View.stop() to also stop all the slave views. + Note that stopping is idempotent, so it is okay if close() also calls stop(). + """ + for slave in self._slaves: + slave.stop() + super().stop() + + async def close(self, msg=None): + self.stop() + await self.cleanup() + + async def pre_timeout(self): + """ + Task to execute before actually timing out. + This may cancel the timeout by refreshing or rescheduling it. + (E.g. to ask the user whether they want to keep going.) + + Default implementation does nothing. + """ + return None + + async def on_timeout(self): + """ + Task to execute after timeout is complete. + Default implementation calls cleanup. + """ + await self.cleanup() + + async def __dispatch_timeout(self): + """ + This essentially extends View._dispatch_timeout, + to include a pre_timeout task + which may optionally refresh and hence cancel the timeout. + """ + if self._View__stopped.done(): + # We are already stopped, nothing to do + return + + with logging_context(action='Timeout'): + try: + await self.pre_timeout() + except asyncio.TimeoutError: + pass + except asyncio.CancelledError: + pass + except Exception: + logger.exception( + "Unhandled error caught while dispatching timeout for {self!r}.", + extra={'with_ctx': True, 'action': 'Error'} + ) + + # Check if we still need to timeout + if self.timeout is None: + # The timeout was removed entirely, silently walk away + return + + if self._View__stopped.done(): + # We stopped while waiting for the pre timeout. + # Or maybe another thread timed us out + # Either way, we are done here + return + + now = time.monotonic() + if self._View__timeout_expiry is not None and now < self._View__timeout_expiry: + # The timeout was extended, make sure the timeout task is running then fade away + if self._View__timeout_task is None or self._View__timeout_task.done(): + self._View__timeout_task = asyncio.create_task(self._View__timeout_task_impl()) + else: + # Actually timeout, and call the post-timeout task for cleanup. + self._really_timeout() + await self.on_timeout() + + def _dispatch_timeout(self): + """ + Overriding timeout method completely, to support interactive flow during timeout, + and optional refreshing of the timeout. + """ + return self._context.run(asyncio.create_task, self.__dispatch_timeout()) + + def _really_timeout(self): + """ + Actuallly times out the View. + This copies View._dispatch_timeout, apart from the `on_timeout` dispatch, + which is now handled by `__dispatch_timeout`. + """ + if self._View__stopped.done(): + return + + if self._View__cancel_callback: + self._View__cancel_callback(self) + self._View__cancel_callback = None + + self._View__stopped.set_result(True) + + def _dispatch_item(self, *args, **kwargs): + """Extending event dispatch to run in the instantiation context.""" + return self._context.run(super()._dispatch_item, *args, **kwargs) + + async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item): + """ + Default LeoUI error handle. + This may be tail extended by subclasses to preserve the exception stack. + """ + try: + raise error + except SafeCancellation as e: + if e.msg and not interaction.is_expired(): + try: + if interaction.response.is_done(): + await interaction.followup.send( + embed=error_embed(e.msg), + ephemeral=True + ) + else: + await interaction.response.send_message( + embed=error_embed(e.msg), + ephemeral=True + ) + except discord.HTTPException: + pass + logger.debug( + f"Caught a safe cancellation from LeoUI: {e.details}", + extra={'action': 'Cancel'} + ) + except Exception: + logger.exception( + f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r} from interaction: " + f"{interaction.data}", + extra={'with_ctx': True, 'action': 'UIError'} + ) + # Explicitly handle the bugsplat ourselves + splat = interaction.client.tree.bugsplat(interaction, error) + await interaction.client.tree.error_reply(interaction, splat) + + +class MessageUI(LeoUI): + """ + Simple single-message LeoUI, intended as a framework for UIs + attached to a single interaction response. + + UIs may also be sent as regular messages by using `send(channel)` instead of `run(interaction)`. + """ + + def __init__(self, *args, callerid: Optional[int] = None, **kwargs): + super().__init__(*args, **kwargs) + + # ----- UI state ----- + # User ID of the original caller (e.g. command author). + # Mainly used for interaction usage checks and logging + self._callerid = callerid + + # Original interaction, if this UI is sent as an interaction response + self._original: discord.Interaction = None + + # Message holding the UI, when the UI is sent attached to a followup + self._message: discord.Message = None + + # Refresh lock, to avoid cache collisions on refresh + self._refresh_lock = asyncio.Lock() + + @property + def channel(self): + if self._original is not None: + return self._original.channel + else: + return self._message.channel + + # ----- UI API ----- + async def run(self, interaction: discord.Interaction, **kwargs): + """ + Run the UI as a response or followup to the given interaction. + + Should be extended if more complex run mechanics are needed + (e.g. registering listeners or setting up caches). + """ + await self.draw(interaction, **kwargs) + + async def refresh(self, *args, thinking: Optional[discord.Interaction] = None, **kwargs): + """ + Reload and redraw this UI. + + Primarily a hook-method for use by parents and other controllers. + Performs a full data and reload and refresh (maintaining UI state, e.g. page n). + """ + async with self._refresh_lock: + # Reload data + await self.reload() + # Redraw UI message + await self.redraw(thinking=thinking) + + async def quit(self): + """ + Quit the UI. + + This usually involves removing the original message, + and stopping or closing the underlying View. + """ + for child in self._slaves: + # TODO: Better to use duck typing or interface typing + if isinstance(child, MessageUI) and not child.is_finished(): + asyncio.create_task(child.quit()) + try: + if self._original is not None and not self._original.is_expired(): + await self._original.delete_original_response() + self._original = None + if self._message is not None: + await self._message.delete() + self._message = None + except discord.HTTPException: + pass + + # Note close() also runs cleanup and stop + await self.close() + + # ----- UI Flow ----- + async def interaction_check(self, interaction: discord.Interaction): + """ + Check the given interaction is authorised to use this UI. + + Default implementation simply checks that the interaction is + from the original caller. + Extend for more complex logic. + """ + return interaction.user.id == self._callerid + + async def make_message(self) -> MessageArgs: + """ + Create the UI message body, depening on the current state. + + Called upon each redraw. + Should handle caching if message construction is for some reason intensive. + + Must be implemented by concrete UI subclasses. + """ + raise NotImplementedError + + async def refresh_layout(self): + """ + Asynchronously refresh the message components, + and explicitly set the message component layout. + + Called just before redrawing, before `make_message`. + + Must be implemented by concrete UI subclasses. + """ + raise NotImplementedError + + async def reload(self): + """ + Reload and recompute the underlying data for this UI. + + Must be implemented by concrete UI subclasses. + """ + raise NotImplementedError + + async def draw(self, interaction, force_followup=False, **kwargs): + """ + Send the UI as a response or followup to the given interaction. + + If the interaction has been responded to, or `force_followup` is set, + creates a followup message instead of a response to the interaction. + """ + # Initial data loading + await self.reload() + # Set the UI layout + await self.refresh_layout() + # Fetch message arguments + args = await self.make_message() + + as_followup = force_followup or interaction.response.is_done() + if as_followup: + self._message = await interaction.followup.send(**args.send_args, **kwargs, view=self) + else: + self._original = interaction + await interaction.response.send_message(**args.send_args, **kwargs, view=self) + + async def send(self, channel: discord.abc.Messageable, **kwargs): + """ + Alternative to draw() which uses a discord.abc.Messageable. + """ + await self.reload() + await self.refresh_layout() + args = await self.make_message() + self._message = await channel.send(**args.send_args, view=self) + + async def _redraw(self, args): + if self._original and not self._original.is_expired(): + await self._original.edit_original_response(**args.edit_args, view=self) + elif self._message: + await self._message.edit(**args.edit_args, view=self) + else: + # Interaction expired or already closed. Quietly cleanup. + await self.close() + + async def redraw(self, thinking: Optional[discord.Interaction] = None): + """ + Update the output message for this UI. + + If a thinking interaction is provided, deletes the response while redrawing. + """ + await self.refresh_layout() + args = await self.make_message() + + if thinking is not None and not thinking.is_expired() and thinking.response.is_done(): + asyncio.create_task(thinking.delete_original_response()) + + try: + await self._redraw(args) + except discord.HTTPException as e: + # Unknown communication error, nothing we can reliably do. Exit quietly. + logger.warning( + f"Unexpected UI redraw failure occurred in {self}: {repr(e)}", + ) + await self.close() + + async def cleanup(self): + """ + Remove message components from interaction response, if possible. + + Extend to remove listeners or clean up caches. + `cleanup` is always called when the UI is exiting, + through timeout or user-driven closure. + """ + try: + if self._original is not None and not self._original.is_expired(): + await self._original.edit_original_response(view=None) + self._original = None + if self._message is not None: + await self._message.edit(view=None) + self._message = None + except discord.HTTPException: + pass + + +class LeoModal(Modal): + """ + Context-aware Modal class. + """ + def __init__(self, *args, context: Optional[Context] = None, **kwargs): + super().__init__(**kwargs) + + if context is None: + self._context = copy_context() + else: + self._context = context + self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__]) + + def _dispatch_submit(self, *args, **kwargs): + """ + Extending event dispatch to run in the instantiation context. + """ + return self._context.run(super()._dispatch_submit, *args, **kwargs) + + def _dispatch_item(self, *args, **kwargs): + """Extending event dispatch to run in the instantiation context.""" + return self._context.run(super()._dispatch_item, *args, **kwargs) + + async def on_error(self, interaction: discord.Interaction, error: Exception, *args): + """ + Default LeoModal error handle. + This may be tail extended by subclasses to preserve the exception stack. + """ + try: + raise error + except Exception: + logger.exception( + f"Unhandled interaction exception occurred in {self!r}. Interaction: {interaction.data}", + extra={'with_ctx': True, 'action': 'ModalError'} + ) + # Explicitly handle the bugsplat ourselves + splat = interaction.client.tree.bugsplat(interaction, error) + await interaction.client.tree.error_reply(interaction, splat) + + +def error_handler_for(exc): + def wrapper(coro): + coro._ui_error_handler_for_ = exc + return coro + return wrapper diff --git a/src/utils/ui/micros.py b/src/utils/ui/micros.py new file mode 100644 index 0000000..eebf418 --- /dev/null +++ b/src/utils/ui/micros.py @@ -0,0 +1,329 @@ +from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict +import functools +import asyncio + +import discord +from discord.ui import TextInput +from discord.ui.button import button + +from meta.logger import logging_context +from meta.errors import ResponseTimedOut + +from .leo import LeoModal, LeoUI + +__all__ = ( + 'FastModal', + 'ModalRetryUI', + 'Confirm', + 'input', +) + + +class FastModal(LeoModal): + __class_error_handlers__ = [] + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + error_handlers = {} + for base in reversed(cls.__mro__): + for name, member in base.__dict__.items(): + if hasattr(member, '_ui_error_handler_for_'): + error_handlers[name] = member + + cls.__class_error_handlers__ = list(error_handlers.values()) + + def __init__error_handlers__(self): + handlers = {} + for handler in self.__class_error_handlers__: + handlers[handler._ui_error_handler_for_] = functools.partial(handler, self) + return handlers + + def __init__(self, *items: TextInput, **kwargs): + super().__init__(**kwargs) + for item in items: + self.add_item(item) + self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future() + self._waiters: List[Callable[[discord.Interaction], Coroutine]] = [] + self._error_handlers = self.__init__error_handlers__() + + def error_handler(self, exception): + def wrapper(coro): + self._error_handlers[exception] = coro + return coro + return wrapper + + async def wait_for(self, check=None, timeout=None): + # Wait for _result or timeout + # If we timeout, or the view times out, raise TimeoutError + # Otherwise, return the Interaction + # This allows multiple listeners and callbacks to wait on + while True: + result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout) + if check is not None: + if not check(result): + continue + return result + + async def on_timeout(self): + self._result.set_exception(asyncio.TimeoutError) + + def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}): + def wrapper(coro): + async def wrapped_callback(interaction): + with logging_context(action=coro.__name__): + if check is not None: + if not check(interaction): + return + try: + await coro(interaction, *pass_args, **pass_kwargs) + except Exception: + raise + finally: + if once: + self._waiters.remove(wrapped_callback) + self._waiters.append(wrapped_callback) + return wrapper + + async def on_error(self, interaction: discord.Interaction, error: Exception, *args): + try: + # First let our error handlers have a go + # If there is no handler for this error, or the handlers themselves error, + # drop to the superclass error handler implementation. + try: + raise error + except tuple(self._error_handlers.keys()) as e: + # If an error handler is registered for this exception, run it. + for cls, handler in self._error_handlers.items(): + if isinstance(e, cls): + await handler(interaction, e) + except Exception as error: + await super().on_error(interaction, error) + + async def on_submit(self, interaction): + print("On submit") + old_result = self._result + self._result = asyncio.get_event_loop().create_future() + old_result.set_result(interaction) + + tasks = [] + for waiter in self._waiters: + task = asyncio.create_task( + waiter(interaction), + name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}" + ) + tasks.append(task) + if tasks: + await asyncio.gather(*tasks) + + +async def input( + interaction: discord.Interaction, + title: str, + question: Optional[str] = None, + field: Optional[TextInput] = None, + timeout=180, + **kwargs, +) -> tuple[discord.Interaction, str]: + """ + Spawn a modal to accept input. + Returns an (interaction, value) pair, with interaction not yet responded to. + May raise asyncio.TimeoutError if the view times out. + """ + if field is None: + field = TextInput( + label=kwargs.get('label', question), + **kwargs + ) + modal = FastModal( + field, + title=title, + timeout=timeout + ) + await interaction.response.send_modal(modal) + interaction = await modal.wait_for() + return (interaction, field.value) + + +class ModalRetryUI(LeoUI): + def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.modal = modal + self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)} + + self.message = message + + self._interaction = None + + if label is not None: + self.retry_button.label = label + + @property + def embed(self): + return discord.Embed( + title="Uh-Oh!", + description=self.message, + colour=discord.Colour.red() + ) + + async def respond_to(self, interaction): + self._interaction = interaction + if interaction.response.is_done(): + await interaction.followup.send(embed=self.embed, ephemeral=True, view=self) + else: + await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self) + + @button(label="Retry") + async def retry_button(self, interaction, butt): + # Setting these here so they don't update in the meantime + for item, value in self.item_values.items(): + item.default = value + if self._interaction is not None: + await self._interaction.delete_original_response() + self._interaction = None + await interaction.response.send_modal(self.modal) + await self.close() + + +class Confirm(LeoUI): + """ + Micro UI class implementing a confirmation question. + + Parameters + ---------- + confirm_msg: str + The confirmation question to ask from the user. + This is set as the description of the `embed` property. + The `embed` may be further modified if required. + permitted_id: Optional[int] + The user id allowed to access this interaction. + Other users will recieve an access denied error message. + defer: bool + Whether to defer the interaction response while handling the button. + It may be useful to set this to `False` to obtain manual control + over the interaction response flow (e.g. to send a modal or ephemeral message). + The button press interaction may be accessed through `Confirm.interaction`. + Default: True + + Example + ------- + ``` + confirm = Confirm("Are you sure?", ctx.author.id) + confirm.embed.colour = discord.Colour.red() + confirm.confirm_button.label = "Yes I am sure" + confirm.cancel_button.label = "No I am not sure" + + try: + result = await confirm.ask(ctx.interaction, ephemeral=True) + except ResultTimedOut: + return + ``` + """ + def __init__( + self, + confirm_msg: str, + permitted_id: Optional[int] = None, + defer: bool = True, + **kwargs + ): + super().__init__(**kwargs) + self.confirm_msg = confirm_msg + self.permitted_id = permitted_id + self.defer = defer + + self._embed: Optional[discord.Embed] = None + self._result: asyncio.Future[bool] = asyncio.Future() + + # Indicates whether we should delete the message or the interaction response + self._is_followup: bool = False + self._original: Optional[discord.Interaction] = None + self._message: Optional[discord.Message] = None + + async def interaction_check(self, interaction: discord.Interaction): + return (self.permitted_id is None) or interaction.user.id == self.permitted_id + + async def on_timeout(self): + # Propagate timeout to result Future + self._result.set_exception(ResponseTimedOut) + await self.cleanup() + + async def cleanup(self): + """ + Cleanup the confirmation prompt by deleting it, if possible. + Ignores any Discord errors that occur during the process. + """ + try: + if self._is_followup and self._message: + await self._message.delete() + elif not self._is_followup and self._original and not self._original.is_expired(): + await self._original.delete_original_response() + except discord.HTTPException: + # A user probably already deleted the message + # Anything could have happened, just ignore. + pass + + @button(label="Confirm") + async def confirm_button(self, interaction: discord.Interaction, press): + if self.defer: + await interaction.response.defer() + self._result.set_result(True) + await self.close() + + @button(label="Cancel") + async def cancel_button(self, interaction: discord.Interaction, press): + if self.defer: + await interaction.response.defer() + self._result.set_result(False) + await self.close() + + @property + def embed(self): + """ + Confirmation embed shown to the user. + This is cached, and may be modifed directly through the usual EmbedProxy API, + or explicitly overwritten. + """ + if self._embed is None: + self._embed = discord.Embed( + colour=discord.Colour.orange(), + description=self.confirm_msg + ) + return self._embed + + @embed.setter + def embed(self, value): + self._embed = value + + async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs): + """ + Send this confirmation prompt in response to the provided interaction. + Extra keyword arguments are passed to `Interaction.response.send_message` + or `Interaction.send_followup`, depending on whether + the provided interaction has already been responded to. + + The `epehemeral` argument is handled specially, + since the question message can only be deleted through `Interaction.delete_original_response`. + + Waits on and returns the internal `result` Future. + + Returns: bool + True if the user pressed the confirm button. + False if the user pressed the cancel button. + Raises: + ResponseTimedOut: + If the user does not respond before the UI times out. + """ + self._original = interaction + if interaction.response.is_done(): + # Interaction already responded to, send a follow up + if ephemeral: + raise ValueError("Cannot send an ephemeral response to a used interaction.") + self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self) + self._is_followup = True + else: + await interaction.response.send_message( + embed=self.embed, ephemeral=ephemeral, **kwargs, view=self + ) + self._is_followup = False + return await self._result + +# TODO: Selector MicroUI for displaying options (<= 25) diff --git a/src/utils/ui/msgeditor.py b/src/utils/ui/msgeditor.py new file mode 100644 index 0000000..fb2d907 --- /dev/null +++ b/src/utils/ui/msgeditor.py @@ -0,0 +1,1070 @@ +from typing import Optional +import asyncio +import copy +import json +import datetime as dt +from io import StringIO + +import discord +from discord.ui.button import button, Button, ButtonStyle +from discord.ui.select import select, Select, SelectOption +from discord.ui.text_input import TextInput, TextStyle + +from meta import conf, LionBot +from meta.errors import UserInputError, ResponseTimedOut + +from ..lib import MessageArgs, utc_now + +from . import MessageUI, util_babel, error_handler_for, FastModal, ModalRetryUI, Confirm, AsComponents, AButton + + +_p = util_babel._p + + +class MsgEditorInput(FastModal): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @error_handler_for(UserInputError) + async def rerequest(self, interaction, error): + await ModalRetryUI(self, error.msg).respond_to(interaction) + + +class MsgEditor(MessageUI): + def __init__(self, bot: LionBot, initial_data: dict, formatter=None, callback=None, **kwargs): + self.bot = bot + self.history = [initial_data] # Last item in history is current state + self.future = [] # Last item in future is next state + + self._formatter = formatter + self._callback = callback + + super().__init__(**kwargs) + + @property + def data(self): + return self.history[-1] + + # ----- API ----- + async def format_data(self, data): + """ + Format a MessageData dict for rendering. + + May be extended or overridden for custom formatting. + By default, uses the provided `formatter` callback (if provided). + """ + if self._formatter is not None: + return await self._formatter(data) + else: + return data + + + def copy_data(self): + return copy.deepcopy(self.history[-1]) + + async def save(self): + ... + + async def push_change(self, new_data): + # Cleanup the data + if (embed_data := new_data.get('embed', None)) is not None and not embed_data: + new_data.pop('embed') + + t = self.bot.translator.t + if 'embed' not in new_data and not new_data.get('content', None): + raise UserInputError( + t(_p( + 'ui:msg_editor|error:empty', + "Rendering failed! The message content and embed cannot both be empty." + )) + ) + + if 'embed' in new_data: + try: + formatted_data = copy.deepcopy(new_data) + discord.Embed.from_dict(await self.format_data(formatted_data['embed'])) + except Exception as e: + raise UserInputError( + t(_p( + 'ui:msg_editor|error:embed_failed', + "Rendering failed! Could not parse the embed.\n" + "Error: {error}" + )).format(error=str(e)) + ) + + # Push the state and try displaying it + self.history.append(new_data) + old_future = self.future + self.future = [] + try: + await self.refresh() + except discord.HTTPException as e: + # State failed, rollback and error + self.history.pop() + self.future = old_future + raise UserInputError( + t(_p( + 'ui:msg_editor|error:invalid_change', + "Rendering failed! The message was not modified.\n" + "Error: `{text}`" + )).format(text=e.text) + ) + + # ----- UI Components ----- + + # -- Content Only mode -- + @button(label="EDIT_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def edit_button(self, press: discord.Interaction, pressed: Button): + """ + Open an editor for the message content + """ + data = self.copy_data() + + t = self.bot.translator.t + content_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:content|field:content|label', + "Message Content" + )), + style=TextStyle.long, + required=False, + default=data.get('content', ""), + max_length=2000 + ) + modal = MsgEditorInput( + content_field, + title=t(_p('ui:msg_editor|modal:content|title', "Content Editor")) + ) + + @modal.submit_callback() + async def content_modal_callback(interaction: discord.Interaction): + new_content = content_field.value + data['content'] = new_content + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def edit_button_refresh(self): + t = self.bot.translator.t + button = self.edit_button + button.label = t(_p( + 'ui:msg_editor|button:edit|label', + "Edit Content" + )) + + @button(label="ADD_EMBED_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def add_embed_button(self, press: discord.Interaction, pressed: Button): + """ + Attach an embed with some simple fields filled. + """ + await press.response.defer() + t = self.bot.translator.t + + sample_embed = { + "title": t(_p('ui:msg_editor|button:add_embed|sample_embed|title', "Title Placeholder")), + "description": t(_p('ui:msg_editor|button:add_embed|sample_embed|description', "Description Placeholder")), + } + data = self.copy_data() + data['embed'] = sample_embed + await self.push_change(data) + + async def add_embed_button_refresh(self): + t = self.bot.translator.t + button = self.add_embed_button + button.label = t(_p( + 'ui:msg_editor|button:add_embed|label', + "Add Embed" + )) + + @button(label="RM_EMBED_BUTTON_PLACEHOLDER", style=ButtonStyle.red) + async def rm_embed_button(self, press: discord.Interaction, pressed: Button): + """ + Remove the existing embed from the message. + """ + await press.response.defer() + t = self.bot.translator.t + data = self.copy_data() + data.pop('embed', None) + data.pop('embeds', None) + if not data.get('content', '').strip(): + data['content'] = t(_p( + 'ui:msg_editor|button:rm_embed|sample_content', + "Content Placeholder" + )) + await self.push_change(data) + + async def rm_embed_button_refresh(self): + t = self.bot.translator.t + button = self.rm_embed_button + button.label = t(_p( + 'ui:msg_editor|button:rm_embed|label', + "Remove Embed" + )) + + # -- Embed Mode -- + + @button(label="BODY_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def body_button(self, press: discord.Interaction, pressed: Button): + """ + Edit the Content, Description, Title, and Colour + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + + t = self.bot.translator.t + + content_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:body|field:content|label', + "Message Content" + )), + style=TextStyle.long, + required=False, + default=data.get('content', ""), + max_length=2000 + ) + + desc_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:body|field:desc|label', + "Embed Description" + )), + style=TextStyle.long, + required=False, + default=embed_data.get('description', ""), + max_length=4000 + ) + + title_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:body|field:title|label', + "Embed Title" + )), + style=TextStyle.short, + required=False, + default=embed_data.get('title', ""), + max_length=256 + ) + + colour_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:body|field:colour|label', + "Embed Colour" + )), + style=TextStyle.short, + required=False, + default=str(discord.Colour(value=embed_data['color'])) if 'color' in embed_data else '', + placeholder=str(discord.Colour.orange()), + max_length=7, + min_length=7 + ) + + modal = MsgEditorInput( + content_field, + title_field, + desc_field, + colour_field, + title=t(_p('ui:msg_editor|modal:body|title', "Message Body Editor")) + ) + + @modal.submit_callback() + async def body_modal_callback(interaction: discord.Interaction): + data['content'] = content_field.value + + if desc_field.value: + embed_data['description'] = desc_field.value + else: + embed_data.pop('description', None) + + if title_field.value: + embed_data['title'] = title_field.value + else: + embed_data.pop('title', None) + + if colour_field.value: + colourstr = colour_field.value + try: + colour = discord.Colour.from_str(colourstr) + except ValueError: + raise UserInputError( + t(_p( + 'ui:msg_editor|button:body|error:invalid_colour', + "Invalid colour format! Please enter colours as hex codes, e.g. `#E67E22`" + )) + ) + embed_data['color'] = colour.value + else: + embed_data.pop('color', None) + + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def body_button_refresh(self): + t = self.bot.translator.t + button = self.body_button + button.label = t(_p( + 'ui:msg_editor|button:body|label', + "Body" + )) + + @button(label="AUTHOR_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def author_button(self, press: discord.Interaction, pressed: Button): + """ + Edit the embed author (author name/link/image url) + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + author_data = embed_data.get('author', {}) + + t = self.bot.translator.t + + name_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:author|field:name|label', + "Author Name" + )), + style=TextStyle.short, + required=False, + default=author_data.get('name', ''), + max_length=256 + ) + + link_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:author|field:link|label', + "Author URL" + )), + style=TextStyle.short, + required=False, + default=author_data.get('url', ''), + ) + + image_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:author|field:image|label', + "Author Image URL" + )), + style=TextStyle.short, + required=False, + default=author_data.get('icon_url', ''), + ) + + modal = MsgEditorInput( + name_field, + link_field, + image_field, + title=t(_p('ui:msg_editor|modal:author|title', "Embed Author Editor")) + ) + + @modal.submit_callback() + async def author_modal_callback(interaction: discord.Interaction): + if (name := name_field.value): + author_data['name'] = name + author_data['icon_url'] = image_field.value + author_data['url'] = link_field.value + embed_data['author'] = author_data + else: + embed_data.pop('author', None) + + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def author_button_refresh(self): + t = self.bot.translator.t + button = self.author_button + button.label = t(_p( + 'ui:msg_editor|button:author|label', + "Author" + )) + + @button(label="FOOTER_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def footer_button(self, press: discord.Interaction, pressed: Button): + """ + Open the Footer editor (edit footer icon, text, timestamp). + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + footer_data = embed_data.get('footer', {}) + + t = self.bot.translator.t + + text_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:footer|field:text|label', + "Footer Text" + )), + style=TextStyle.long, + required=False, + default=footer_data.get('text', ''), + max_length=2048 + ) + + image_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:footer|field:image|label', + "Footer Image URL" + )), + style=TextStyle.short, + required=False, + default=footer_data.get('icon_url', ''), + ) + + timestamp_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:footer|field:timestamp|label', + "Embed Timestamp (in ISO format)" + )), + style=TextStyle.short, + required=False, + default=embed_data.get('timestamp', ''), + placeholder=utc_now().replace(microsecond=0).isoformat(sep=' ') + ) + + modal = MsgEditorInput( + text_field, + image_field, + timestamp_field, + title=t(_p('ui:msg_editor|modal:footer|title', "Embed Footer Editor")) + ) + + @modal.submit_callback() + async def footer_modal_callback(interaction: discord.Interaction): + if (text := text_field.value): + footer_data['text'] = text + footer_data['icon_url'] = image_field.value + embed_data['footer'] = footer_data + else: + embed_data.pop('footer', None) + + if (ts := timestamp_field.value): + if ts.isdigit(): + # Treat as UTC timestamp + timestamp = dt.datetime.fromtimestamp(int(ts), dt.timezone.utc) + ts = timestamp.isoformat() + to_validate = ts + elif self._formatter: + to_validate = await self._formatter(ts) + else: + to_validate = ts + try: + dt.datetime.fromisoformat(to_validate) + except ValueError: + raise UserInputError( + t(_p( + 'ui:msg_editor|button:footer|error:invalid_timestamp', + "Invalid timestamp! Please enter the timestamp in ISO format." + )) + ) + embed_data['timestamp'] = ts + else: + embed_data.pop('timestamp', None) + + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def footer_button_refresh(self): + t = self.bot.translator.t + button = self.footer_button + button.label = t(_p( + 'ui:msg_editor|button:footer|label', + "Footer" + )) + + @button(label="IMAGES_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def images_button(self, press: discord.Interaction, pressed: Button): + """ + Edit the embed images (thumbnail and main image). + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + thumb_data = embed_data.get('thumbnail', {}) + image_data = embed_data.get('image', {}) + + t = self.bot.translator.t + + thumb_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:images|field:thumb|label', + "Thumbnail Image URL" + )), + style=TextStyle.short, + required=False, + default=thumb_data.get('url', ''), + ) + + image_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:images|field:image|label', + "Embed Image URL" + )), + style=TextStyle.short, + required=False, + default=image_data.get('url', ''), + ) + + modal = MsgEditorInput( + thumb_field, + image_field, + title=t(_p('ui:msg_editor|modal:images|title', "Embed images Editor")) + ) + + @modal.submit_callback() + async def images_modal_callback(interaction: discord.Interaction): + if (thumb_url := thumb_field.value): + thumb_data['url'] = thumb_url + embed_data['thumbnail'] = thumb_data + else: + embed_data.pop('thumbnail', None) + + if (image_url := image_field.value): + image_data['url'] = image_url + embed_data['image'] = image_data + else: + embed_data.pop('image', None) + + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def images_button_refresh(self): + t = self.bot.translator.t + button = self.images_button + button.label = t(_p( + 'ui:msg_editor|button:images|label', + "Images" + )) + + @button(label="ADD_FIELD_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def add_field_button(self, press: discord.Interaction, pressed: Button): + """ + Add an embed field (position, name, value, inline) + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + orig_fields = field_data.copy() + + t = self.bot.translator.t + + position_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:add_field|field:position|label', + "Field number to insert at" + )), + style=TextStyle.short, + required=True, + default=str(len(field_data)), + ) + + name_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:add_field|field:name|label', + "Field name" + )), + style=TextStyle.short, + required=False, + max_length=256, + ) + + value_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:add_field|field:value|label', + "Field value" + )), + style=TextStyle.long, + required=True, + max_length=1024, + ) + + inline_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:add_field|field:inline|label', + "Whether the field is inline" + )), + placeholder=t(_p( + 'ui:msg_editor|modal:add_field|field:inline|placeholder', + "True/False" + )), + style=TextStyle.short, + required=True, + max_length=256, + default='True', + ) + + modal = MsgEditorInput( + name_field, + value_field, + position_field, + inline_field, + title=t(_p('ui:msg_editor|modal:add_field|title', "Add Embed Field")) + ) + + @modal.submit_callback() + async def add_field_modal_callback(interaction: discord.Interaction): + if inline_field.value.lower() == 'true': + inline = True + else: + inline = False + field = { + 'name': name_field.value, + 'value': value_field.value, + 'inline': inline + } + try: + position = int(position_field.value) + except ValueError: + raise UserInputError( + t(_p( + 'ui:msg_editor|modal:add_field|error:position_not_int', + "The field position must be an integer!" + )) + ) + field_data = orig_fields.copy() + field_data.insert(position, field) + embed_data['fields'] = field_data + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def add_field_button_refresh(self): + t = self.bot.translator.t + button = self.add_field_button + button.label = t(_p( + 'ui:msg_editor|button:add_field|label', + "Add Field" + )) + data = self.history[-1] + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + button.disabled = (len(field_data) >= 25) + + def _field_option(self, index, field_data): + t = self.bot.translator.t + + name = field_data.get('name', "") + value = field_data['value'] + + if not name: + name = t(_p( + 'ui:msg_editor|format_field|name_placeholder', + "-" + )) + + name = f"{index+1}. {name}" + if len(name) > 100: + name = name[:97] + '...' + + if len(value) > 100: + value = value[:97] + '...' + + return SelectOption(label=name, description=value, value=str(index)) + + @select(cls=Select, placeholder="EDIT_FIELD_MENU_PLACEHOLDER", max_values=1) + async def edit_field_menu(self, selection: discord.Interaction, selected: Select): + if not selected.values: + await selection.response.defer() + return + + index = int(selected.values[0]) + data = self.copy_data() + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + field = field_data[index] + + t = self.bot.translator.t + + name_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:edit_field|field:name|label', + "Field name" + )), + style=TextStyle.short, + default=field.get('name', ''), + required=False, + max_length=256, + ) + + value_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:edit_field|field:value|label', + "Field value" + )), + style=TextStyle.long, + default=field.get('value', ''), + required=True, + max_length=1024, + ) + + inline_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:edit_field|field:inline|label', + "Whether the field is inline" + )), + placeholder=t(_p( + 'ui:msg_editor|modal:edit_field|field:inline|placeholder', + "True/False" + )), + default='True' if field.get('inline', True) else 'False', + style=TextStyle.short, + required=True, + max_length=256, + ) + + modal = MsgEditorInput( + name_field, + value_field, + inline_field, + title=t(_p('ui:msg_editor|modal:edit_field|title', "Edit Embed Field")) + ) + + @modal.submit_callback() + async def edit_field_modal_callback(interaction: discord.Interaction): + if inline_field.value.lower() == 'true': + inline = True + else: + inline = False + field = { + 'name': name_field.value, + 'value': value_field.value, + 'inline': inline + } + field_data[index] = field + embed_data['fields'] = field_data + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await selection.response.send_modal(modal) + + async def edit_field_menu_refresh(self): + t = self.bot.translator.t + menu = self.edit_field_menu + menu.placeholder = t(_p( + 'ui:msg_editor|menu:edit_field|placeholder', + "Edit Embed Field" + )) + data = self.history[-1] + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + + if len(field_data) == 0: + menu.disabled = True + menu.options = [ + SelectOption(label='Dummy') + ] + else: + menu.disabled = False + menu.options = [ + self._field_option(i, field) + for i, field in enumerate(field_data) + ] + + @select(cls=Select, placeholder="DELETE_FIELD_MENU_PLACEHOLDER", max_values=1) + async def delete_field_menu(self, selection: discord.Interaction, selected: Select): + if not selected.values: + await selection.response.defer() + return + + index = int(selected.values[0]) + data = self.copy_data() + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + field_data.pop(index) + if not field_data: + embed_data.pop('fields') + await self.push_change(data) + await selection.response.defer() + + async def delete_field_menu_refresh(self): + t = self.bot.translator.t + menu = self.delete_field_menu + menu.placeholder = t(_p( + 'ui:msg_deleteor|menu:delete_field|placeholder', + "Remove Embed Field" + )) + data = self.history[-1] + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + + if len(field_data) == 0: + menu.disabled = True + menu.options = [ + SelectOption(label='Dummy') + ] + else: + menu.disabled = False + menu.options = [ + self._field_option(i, field) + for i, field in enumerate(field_data) + ] + + # -- Shared -- + @button(label="SAVE_BUTTON_PLACEHOLDER", style=ButtonStyle.green) + async def save_button(self, press: discord.Interaction, pressed: Button): + """ + Saving simply resets the undo stack and calls the callback function. + Presumably the callback is hooked up to data or similar. + """ + await press.response.defer(thinking=True, ephemeral=True) + if self._callback is not None: + await self._callback(self.data) + self.history = self.history[-1:] + await self.refresh(thinking=press) + + async def save_button_refresh(self): + t = self.bot.translator.t + button = self.save_button + button.label = t(_p( + 'ui:msg_editor|button:save|label', + "Save" + )) + if len(self.history) > 1: + original = json.dumps(self.history[0]) + current = json.dumps(self.history[-1]) + button.disabled = (original == current) + else: + button.disabled = True + + @button(label="DOWNLOAD_BUTTON_PLACEHOLDER", style=ButtonStyle.grey) + async def download_button(self, press: discord.Interaction, pressed: Button): + """ + Reply ephemerally with a formatted json version of the message content. + """ + data = json.dumps(self.history[-1], indent=2) + with StringIO(data) as fp: + fp.seek(0) + file = discord.File(fp, filename='message.json') + await press.response.send_message(file=file, ephemeral=True) + + async def download_button_refresh(self): + t = self.bot.translator.t + button = self.download_button + button.label = t(_p( + 'ui:msg_editor|button:download|label', + "Download" + )) + + @button(label="UNDO_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def undo_button(self, press: discord.Interaction, pressed: Button): + """ + Pop the history stack. + """ + if len(self.history) > 1: + state = self.history.pop() + self.future.append(state) + await press.response.defer() + await self.refresh() + + async def undo_button_refresh(self): + t = self.bot.translator.t + button = self.undo_button + button.label = t(_p( + 'ui:msg_editor|button:undo|label', + "Undo" + )) + button.disabled = (len(self.history) <= 1) + + @button(label="REDO_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def redo_button(self, press: discord.Interaction, pressed: Button): + """ + Pop the future stack. + """ + if len(self.future) > 0: + state = self.future.pop() + self.history.append(state) + await press.response.defer() + await self.refresh() + + async def redo_button_refresh(self): + t = self.bot.translator.t + button = self.redo_button + button.label = t(_p( + 'ui:msg_editor|button:redo|label', + "Redo" + )) + button.disabled = (len(self.future) == 0) + + @button(style=ButtonStyle.grey, emoji=conf.emojis.cancel) + async def quit_button(self, press: discord.Interaction, pressed: Button): + # Confirm quit if there are unsaved changes + unsaved = False + if len(self.history) > 1: + original = json.dumps(self.history[0]) + current = json.dumps(self.history[-1]) + if original != current: + unsaved = True + + # Confirmation prompt + if unsaved: + t = self.bot.translator.t + confirm_msg = t(_p( + 'ui:msg_editor|button:quit|confirm', + "You have unsaved changes! Are you sure you want to quit?" + )) + confirm = Confirm(confirm_msg, self._callerid) + confirm.confirm_button.label = t(_p( + 'ui:msg_editor|button:quit|confirm|button:yes', + "Yes, Quit Now" + )) + confirm.confirm_button.style = ButtonStyle.red + confirm.cancel_button.style = ButtonStyle.green + confirm.cancel_button.label = t(_p( + 'ui:msg_editor|button:quit|confirm|button:no', + "No, Go Back" + )) + try: + result = await confirm.ask(press, ephemeral=True) + except ResponseTimedOut: + result = False + + if result: + await self.quit() + else: + await self.quit() + + # ----- UI Flow ----- + async def make_message(self) -> MessageArgs: + data = self.copy_data() + await self.format_data(data) + + args = {} + args['content'] = data.get('content', '') + + if 'embed' in data: + args['embed'] = discord.Embed.from_dict(data['embed']) + else: + args['embed'] = None + + return MessageArgs(**args) + + async def refresh_layout(self): + to_refresh = ( + self.edit_button_refresh(), + self.add_embed_button_refresh(), + self.body_button_refresh(), + self.author_button_refresh(), + self.footer_button_refresh(), + self.images_button_refresh(), + self.add_field_button_refresh(), + self.edit_field_menu_refresh(), + self.delete_field_menu_refresh(), + self.save_button_refresh(), + self.download_button_refresh(), + self.undo_button_refresh(), + self.redo_button_refresh(), + self.rm_embed_button_refresh(), + ) + await asyncio.gather(*to_refresh) + + if self.history[-1].get('embed', None): + self.set_layout( + (self.body_button, self.author_button, self.footer_button, self.images_button, self.add_field_button), + (self.edit_field_menu,), + (self.delete_field_menu,), + (self.rm_embed_button,), + (self.save_button, self.download_button, self.undo_button, self.redo_button, self.quit_button), + ) + else: + self.set_layout( + (self.edit_button, self.add_embed_button), + (self.save_button, self.download_button, self.undo_button, self.redo_button, self.quit_button), + ) + + async def reload(self): + # All data is handled by components, so nothing to do here + pass + + async def redraw(self, thinking: Optional[discord.Interaction] = None): + """ + Overriding MessageUI.redraw to propagate exception. + """ + await self.refresh_layout() + args = await self.make_message() + + if thinking is not None and not thinking.is_expired() and thinking.response.is_done(): + asyncio.create_task(thinking.delete_original_response()) + + if self._original and not self._original.is_expired(): + await self._original.edit_original_response(**args.edit_args, view=self) + elif self._message: + await self._message.edit(**args.edit_args, view=self) + else: + # Interaction expired or already closed. Quietly cleanup. + await self.close() + + async def pre_timeout(self): + unsaved = False + if len(self.history) > 1: + original = json.dumps(self.history[0]) + current = json.dumps(self.history[-1]) + if original != current: + unsaved = True + + # Timeout confirmation + if unsaved: + t = self.bot.translator.t + grace_period = 60 + grace_time = utc_now() + dt.timedelta(seconds=grace_period) + embed = discord.Embed( + title=t(_p( + 'ui:msg_editor|timeout_warning|title', + "Warning!" + )), + description=t(_p( + 'ui:msg_editor|timeout_warning|desc', + "This interface will time out {timestamp}. Press 'Continue' below to keep editing." + )).format( + timestamp=discord.utils.format_dt(grace_time, style='R') + ), + ) + + components = None + stopped = False + + @AButton(label=t(_p('ui:msg_editor|timeout_warning|continue', "Continue")), style=ButtonStyle.green) + async def cont_button(interaction: discord.Interaction, pressed): + await interaction.response.defer() + await interaction.message.delete() + nonlocal stopped + stopped = True + # TODO: Clean up this mess. It works, but needs to be refactored to a timeout confirmation mixin. + # TODO: Consider moving the message to the interaction response + self._refresh_timeout() + components.stop() + + components = AsComponents(cont_button, timeout=grace_period) + message = await self._original.channel.send(content=f"<@{self._callerid}>", embed=embed, view=components) + await components.wait() + + if not stopped: + try: + await message.delete() + except discord.HTTPException: + pass diff --git a/src/wards.py b/src/wards.py new file mode 100644 index 0000000..4f1647f --- /dev/null +++ b/src/wards.py @@ -0,0 +1,9 @@ +from meta import LionBot + +# Raw checks, return True/False depending on whether they pass +async def sys_admin(bot: LionBot, userid: int): + """ + Checks whether the context author is listed in the configuration file as a bot admin. + """ + admins = bot.config.bot.getintlist('admins') + return userid in admins