diff --git a/data/schema.sql b/data/schema.sql index 3d508a2..74e0b84 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -51,6 +51,56 @@ CREATE TABLE channel_links( ); +-- }}} + +-- Stream Alerts {{{ + +-- DROP TABLE IF EXISTS stream_alerts; +-- DROP TABLE IF EXISTS streams; +-- DROP TABLE IF EXISTS alert_channels; +-- DROP TABLE IF EXISTS streamers; + +CREATE TABLE streamers( + userid BIGINT PRIMARY KEY, + login_name TEXT NOT NULL, + display_name TEXT NOT NULL +); + +CREATE TABLE alert_channels( + subscriptionid SERIAL PRIMARY KEY, + guildid BIGINT NOT NULL, + channelid BIGINT NOT NULL, + streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE, + created_by BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + paused BOOLEAN NOT NULL DEFAULT FALSE, + end_delete BOOLEAN NOT NULL DEFAULT FALSE, + live_message TEXT, + end_message TEXT +); +CREATE INDEX alert_channels_guilds ON alert_channels (guildid); +CREATE UNIQUE INDEX alert_channels_channelid_streamerid ON alert_channels (channelid, streamerid); + +CREATE TABLE streams( + streamid SERIAL PRIMARY KEY, + streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE, + start_at TIMESTAMPTZ NOT NULL, + twitch_stream_id BIGINT, + game_name TEXT, + title TEXT, + end_at TIMESTAMPTZ +); + +CREATE TABLE stream_alerts( + alertid SERIAL PRIMARY KEY, + streamid INTEGER NOT NULL REFERENCES streams (streamid) ON DELETE CASCADE, + subscriptionid INTEGER NOT NULL REFERENCES alert_channels (subscriptionid) ON DELETE CASCADE, + sent_at TIMESTAMPTZ NOT NULL, + messageid BIGINT NOT NULL, + resolved_at TIMESTAMPTZ +); + + -- }}} -- vim: set fdm=marker: diff --git a/requirements.txt b/requirements.txt index 48ed17b..402aaf8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ discord.py [voice] iso8601==0.1.16 psycopg[pool] pytz==2021.1 +twitchAPI diff --git a/src/babel/__init__.py b/src/babel/__init__.py new file mode 100644 index 0000000..48da68e --- /dev/null +++ b/src/babel/__init__.py @@ -0,0 +1,3 @@ +from .translator import SOURCE_LOCALE, LeoBabel, LocalBabel, LazyStr, ctx_locale, ctx_translator + +babel = LocalBabel('babel') diff --git a/src/babel/enums.py b/src/babel/enums.py new file mode 100644 index 0000000..9e73da3 --- /dev/null +++ b/src/babel/enums.py @@ -0,0 +1,81 @@ +from enum import Enum +from . import babel + +_p = babel._p + + +class LocaleMap(Enum): + american_english = 'en-US' + british_english = 'en-GB' + bulgarian = 'bg' + chinese = 'zh-CN' + taiwan_chinese = 'zh-TW' + croatian = 'hr' + czech = 'cs' + danish = 'da' + dutch = 'nl' + finnish = 'fi' + french = 'fr' + german = 'de' + greek = 'el' + hindi = 'hi' + hungarian = 'hu' + italian = 'it' + japanese = 'ja' + korean = 'ko' + lithuanian = 'lt' + norwegian = 'no' + polish = 'pl' + brazil_portuguese = 'pt-BR' + romanian = 'ro' + russian = 'ru' + spain_spanish = 'es-ES' + swedish = 'sv-SE' + thai = 'th' + turkish = 'tr' + ukrainian = 'uk' + vietnamese = 'vi' + hebrew = 'he-IL' + + +# Original Discord names +locale_names = { + 'id': (_p('localenames|locale:id', "Indonesian"), "Bahasa Indonesia"), + 'da': (_p('localenames|locale:da', "Danish"), "Dansk"), + 'de': (_p('localenames|locale:de', "German"), "Deutsch"), + 'en-GB': (_p('localenames|locale:en-GB', "English, UK"), "English, UK"), + 'en-US': (_p('localenames|locale:en-US', "English, US"), "English, US"), + 'es-ES': (_p('localenames|locale:es-ES', "Spanish"), "Español"), + 'fr': (_p('localenames|locale:fr', "French"), "Français"), + 'hr': (_p('localenames|locale:hr', "Croatian"), "Hrvatski"), + 'it': (_p('localenames|locale:it', "Italian"), "Italiano"), + 'lt': (_p('localenames|locale:lt', "Lithuanian"), "Lietuviškai"), + 'hu': (_p('localenames|locale:hu', "Hungarian"), "Magyar"), + 'nl': (_p('localenames|locale:nl', "Dutch"), "Nederlands"), + 'no': (_p('localenames|locale:no', "Norwegian"), "Norsk"), + 'pl': (_p('localenames|locale:pl', "Polish"), "Polski"), + 'pt-BR': (_p('localenames|locale:pt-BR', "Portuguese, Brazilian"), "Português do Brasil"), + 'ro': (_p('localenames|locale:ro', "Romanian, Romania"), "Română"), + 'fi': (_p('localenames|locale:fi', "Finnish"), "Suomi"), + 'sv-SE': (_p('localenames|locale:sv-SE', "Swedish"), "Svenska"), + 'vi': (_p('localenames|locale:vi', "Vietnamese"), "Tiếng Việt"), + 'tr': (_p('localenames|locale:tr', "Turkish"), "Türkçe"), + 'cs': (_p('localenames|locale:cs', "Czech"), "Čeština"), + 'el': (_p('localenames|locale:el', "Greek"), "Ελληνικά"), + 'bg': (_p('localenames|locale:bg', "Bulgarian"), "български"), + 'ru': (_p('localenames|locale:ru', "Russian"), "Pусский"), + 'uk': (_p('localenames|locale:uk', "Ukrainian"), "Українська"), + 'hi': (_p('localenames|locale:hi', "Hindi"), "हिन्दी"), + 'th': (_p('localenames|locale:th', "Thai"), "ไทย"), + 'zh-CN': (_p('localenames|locale:zh-CN', "Chinese, China"), "中文"), + 'ja': (_p('localenames|locale:ja', "Japanese"), "日本語"), + 'zh-TW': (_p('localenames|locale:zh-TW', "Chinese, Taiwan"), "繁體中文"), + 'ko': (_p('localenames|locale:ko', "Korean"), "한국어"), +} + +# More names for languages not supported by Discord +locale_names |= { + 'he': (_p('localenames|locale:he', "Hebrew"), "Hebrew"), + 'he-IL': (_p('localenames|locale:he-IL', "Hebrew"), "Hebrew"), + 'ceaser': (_p('localenames|locale:test', "Test Language"), "dfbtfs"), +} diff --git a/src/babel/translator.py b/src/babel/translator.py new file mode 100644 index 0000000..f3ec1d0 --- /dev/null +++ b/src/babel/translator.py @@ -0,0 +1,108 @@ +from typing import Optional +import logging +from contextvars import ContextVar +from collections import defaultdict +from enum import Enum + +import gettext + +from discord.app_commands import Translator, locale_str +from discord.enums import Locale + + +logger = logging.getLogger(__name__) + + +SOURCE_LOCALE = 'en_GB' +ctx_locale: ContextVar[str] = ContextVar('locale', default=SOURCE_LOCALE) +ctx_translator: ContextVar['LeoBabel'] = ContextVar('translator', default=None) # type: ignore + +null = gettext.NullTranslations() + + +class LeoBabel(Translator): + def __init__(self): + self.supported_locales = {loc.name for loc in Locale} + self.supported_domains = {} + self.translators = defaultdict(dict) # locale -> domain -> GNUTranslator + + async def load(self): + pass + + async def unload(self): + self.translators.clear() + + def get_translator(self, locale: Optional[str], domain): + return null + + def t(self, lazystr, locale=None): + return lazystr._translate_with(null) + + async def translate(self, string: locale_str, locale: Locale, context): + if not isinstance(string, LazyStr): + return string + else: + return string.message + +ctx_translator.set(LeoBabel()) + +class Method(Enum): + GETTEXT = 'gettext' + NGETTEXT = 'ngettext' + PGETTEXT = 'pgettext' + NPGETTEXT = 'npgettext' + + +class LocalBabel: + def __init__(self, domain): + self.domain = domain + + @property + def methods(self): + return (self._, self._n, self._p, self._np) + + def _(self, message): + return LazyStr(Method.GETTEXT, message, domain=self.domain) + + def _n(self, singular, plural, n): + return LazyStr(Method.NGETTEXT, singular, plural, n, domain=self.domain) + + def _p(self, context, message): + return LazyStr(Method.PGETTEXT, context, message, domain=self.domain) + + def _np(self, context, singular, plural, n): + return LazyStr(Method.NPGETTEXT, context, singular, plural, n, domain=self.domain) + + +class LazyStr(locale_str): + __slots__ = ('method', 'args', 'domain', 'locale') + + def __init__(self, method, *args, locale=None, domain=None): + self.method = method + self.args = args + self.domain = domain + self.locale = locale + + @property + def message(self): + return self._translate_with(null) + + @property + def extras(self): + return {'locale': self.locale, 'domain': self.domain} + + def __str__(self): + return self.message + + def _translate_with(self, translator: gettext.GNUTranslations): + method = getattr(translator, self.method.value) + return method(*self.args) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.method}, {self.args!r}, locale={self.locale}, domain={self.domain})' + + def __eq__(self, obj: object) -> bool: + return isinstance(obj, locale_str) and self.message == obj.message + + def __hash__(self) -> int: + return hash(self.args) diff --git a/src/babel/utils.py b/src/babel/utils.py new file mode 100644 index 0000000..e3faa1c --- /dev/null +++ b/src/babel/utils.py @@ -0,0 +1,20 @@ +from .translator import ctx_translator +from . import babel + +_, _p, _np = babel._, babel._p, babel._np + + +MONTHS = _p( + 'utils|months', + "January,February,March,April,May,June,July,August,September,October,November,December" +) + +SHORT_MONTHS = _p( + 'utils|short_months', + "Jan,Feb,Mar,Apr,May,Jun,Jul,Aug,Sep,Oct,Nov,Dec" +) + + +def local_month(month, short=False): + string = MONTHS if not short else SHORT_MONTHS + return ctx_translator.get().t(string).split(',')[month-1] diff --git a/src/core/__init__.py b/src/core/__init__.py index 2cbb58f..6be5d80 100644 --- a/src/core/__init__.py +++ b/src/core/__init__.py @@ -1,4 +1,6 @@ +from babel import LocalBabel +babel = LocalBabel('core') async def setup(bot): from .cog import CoreCog diff --git a/src/core/setting_types.py b/src/core/setting_types.py new file mode 100644 index 0000000..777b782 --- /dev/null +++ b/src/core/setting_types.py @@ -0,0 +1,227 @@ +""" +Additional abstract setting types useful for StudyLion settings. +""" +from typing import Optional +import json +import traceback + +import discord +from discord.enums import TextStyle + +from settings.base import ParentID +from settings.setting_types import IntegerSetting, StringSetting +from meta import conf +from meta.errors import UserInputError +from babel.translator import ctx_translator +from utils.lib import MessageArgs + +from . import babel + +_p = babel._p + + +class MessageSetting(StringSetting): + """ + Typed Setting ABC representing a message sent to Discord. + + Data is a json-formatted string dict with at least one of the fields 'content', 'embed', 'embeds' + Value is the corresponding dictionary + """ + # TODO: Extend to support format keys + + _accepts = _p( + 'settype:message|accepts', + "JSON formatted raw message data" + ) + + @staticmethod + async def download_attachment(attached: discord.Attachment): + """ + Download a discord.Attachment with some basic filetype and file size validation. + """ + t = ctx_translator.get().t + + error = None + decoded = None + if attached.content_type and not ('json' in attached.content_type): + error = t(_p( + 'settype:message|download|error:not_json', + "The attached message data is not a JSON file!" + )) + elif attached.size > 10000: + error = t(_p( + 'settype:message|download|error:size', + "The attached message data is too large!" + )) + else: + content = await attached.read() + try: + decoded = content.decode('UTF-8') + except UnicodeDecodeError: + error = t(_p( + 'settype:message|download|error:decoding', + "Could not decode the message data. Please ensure it is saved with the `UTF-8` encoding." + )) + + if error is not None: + raise UserInputError(error) + else: + return decoded + + @classmethod + def value_to_args(cls, parent_id: ParentID, value: dict, **kwargs) -> MessageArgs: + if not value: + return None + + args = {} + args['content'] = value.get('content', "") + if 'embed' in value: + embed = discord.Embed.from_dict(value['embed']) + args['embed'] = embed + if 'embeds' in value: + embeds = [] + for embed_data in value['embeds']: + embeds.append(discord.Embed.from_dict(embed_data)) + args['embeds'] = embeds + return MessageArgs(**args) + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value: Optional[dict], **kwargs): + if value and any(value.get(key, None) for key in ('content', 'embed', 'embeds')): + data = json.dumps(value) + else: + data = None + return data + + @classmethod + def _data_to_value(cls, parent_id: ParentID, data: Optional[str], **kwargs): + if data: + value = json.loads(data) + else: + value = None + return value + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Provided user string can be downright random. + + If it isn't json-formatted, treat it as the content of the message. + If it is, do basic checking on the length and embeds. + """ + string = string.strip() + if not string or string.lower() == 'none': + return None + + t = ctx_translator.get().t + + error_tip = t(_p( + 'settype:message|error_suffix', + "You can view, test, and fix your embed using the online [embed builder]({link})." + )).format( + link="https://glitchii.github.io/embedbuilder/?editor=json" + ) + + if string.startswith('{') and string.endswith('}'): + # Assume the string is a json-formatted message dict + try: + value = json.loads(string) + except json.JSONDecodeError as err: + error = t(_p( + 'settype:message|error:invalid_json', + "The provided message data was not a valid JSON document!\n" + "`{error}`" + )).format(error=str(err)) + raise UserInputError(error + '\n' + error_tip) + + if not isinstance(value, dict) or not any(value.get(key, None) for key in ('content', 'embed', 'embeds')): + error = t(_p( + 'settype:message|error:json_missing_keys', + "Message data must be a JSON object with at least one of the following fields: " + "`content`, `embed`, `embeds`" + )) + raise UserInputError(error + '\n' + error_tip) + + embed_data = value.get('embed', None) + if not isinstance(embed_data, dict): + error = t(_p( + 'settype:message|error:json_embed_type', + "`embed` field must be a valid JSON object." + )) + raise UserInputError(error + '\n' + error_tip) + + embeds_data = value.get('embeds', []) + if not isinstance(embeds_data, list): + error = t(_p( + 'settype:message|error:json_embeds_type', + "`embeds` field must be a list." + )) + raise UserInputError(error + '\n' + error_tip) + + if embed_data and embeds_data: + error = t(_p( + 'settype:message|error:json_embed_embeds', + "Message data cannot include both `embed` and `embeds`." + )) + raise UserInputError(error + '\n' + error_tip) + + content_data = value.get('content', "") + if not isinstance(content_data, str): + error = t(_p( + 'settype:message|error:json_content_type', + "`content` field must be a string." + )) + raise UserInputError(error + '\n' + error_tip) + + # Validate embeds, which is the most likely place for something to go wrong + embeds = [embed_data] if embed_data else embeds_data + try: + for embed in embeds: + discord.Embed.from_dict(embed) + except Exception as e: + # from_dict may raise a range of possible exceptions. + raw_error = ''.join( + traceback.TracebackException.from_exception(e).format_exception_only() + ) + error = t(_p( + 'ui:settype:message|error:embed_conversion', + "Could not parse the message embed data.\n" + "**Error:** `{exception}`" + )).format(exception=raw_error) + raise UserInputError(error + '\n' + error_tip) + + # At this point, the message will at least successfully convert into MessageArgs + # There are numerous ways it could still be invalid, e.g. invalid urls, or too-long fields + # or the total message content being too long, or too many fields, etc + # This will need to be caught in anything which displays a message parsed from user data. + else: + # Either the string is not json formatted, or the formatting is broken + # Assume the string is a content message + value = { + 'content': string + } + return json.dumps(value) + + @classmethod + def _format_data(cls, parent_id: ParentID, data: Optional[str], **kwargs): + if not data: + return None + + value = cls._data_to_value(parent_id, data, **kwargs) + content = value.get('content', "") + if 'embed' in value or 'embeds' in value or len(content) > 100: + t = ctx_translator.get().t + formatted = t(_p( + 'settype:message|format:too_long', + "Too long to display! See Preview." + )) + else: + formatted = content + + return formatted + + @property + def input_field(self): + field = super().input_field + field.style = TextStyle.long + return field diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index b7e3e65..88fa0b2 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -12,6 +12,7 @@ from aiohttp import ClientSession from data import Database from utils.lib import tabulate +from babel.translator import LeoBabel from .config import Conf from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context @@ -43,6 +44,7 @@ class LionBot(Bot): self.shardname = shardname # self.appdata = appdata self.config = config + self.translator = LeoBabel() self.system_monitor = SystemMonitor() self.monitor = ComponentMonitor('LionBot', self._monitor_status) diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 099c81b..a137bc9 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -3,6 +3,7 @@ this_package = 'modules' active = [ '.sysadmin', '.voicefix', + '.streamalerts', ] diff --git a/src/modules/streamalerts/__init__.py b/src/modules/streamalerts/__init__.py new file mode 100644 index 0000000..4fd5c1b --- /dev/null +++ b/src/modules/streamalerts/__init__.py @@ -0,0 +1,8 @@ +import logging +from meta import LionBot + +logger = logging.getLogger(__name__) + +async def setup(bot: LionBot): + from .cog import AlertCog + await bot.add_cog(AlertCog(bot)) diff --git a/src/modules/streamalerts/cog.py b/src/modules/streamalerts/cog.py new file mode 100644 index 0000000..8762cd7 --- /dev/null +++ b/src/modules/streamalerts/cog.py @@ -0,0 +1,609 @@ +import asyncio +from typing import Optional + +import discord +from discord.ext import commands as cmds +from discord import app_commands as appcmds + +from twitchAPI.twitch import Twitch +from twitchAPI.helper import first + +from meta import LionBot, LionCog, LionContext +from meta.errors import UserInputError +from meta.logger import log_wrap +from utils.lib import utc_now +from data.conditions import NULL + +from . import logger +from .data import AlertsData +from .settings import AlertConfig, AlertSettings +from .editor import AlertEditorUI + + +class AlertCog(LionCog): + POLL_PERIOD = 60 + + def __init__(self, bot: LionBot): + self.bot = bot + self.data = bot.db.load_registry(AlertsData()) + self.twitch = None + self.alert_settings = AlertSettings() + + self.poll_task = None + self.event_tasks = set() + + # Cache of currently live streams, maps streamerid -> stream + self.live_streams = {} + + # Cache of streamers we are watching state changes for + # Map of streamerid -> streamer + self.watching = {} + + async def cog_load(self): + await self.data.init() + + await self.twitch_login() + await self.load_subs() + self.poll_task = asyncio.create_task(self.poll_live()) + + async def twitch_login(self): + # TODO: Probably abstract this out to core or a dedicated core cog + # Also handle refresh tokens + if self.twitch is not None: + await self.twitch.close() + self.twitch = None + + self.twitch = await Twitch( + self.bot.config.twitch['app_id'].strip(), + self.bot.config.twitch['app_secret'].strip() + ) + + async def load_subs(self): + # Load active subscriptions + active_subs = await self.data.AlertChannel.fetch_where() + to_watch = {sub.streamerid for sub in active_subs} + live_streams = await self.data.Stream.fetch_where( + self.data.Stream.end_at != NULL + ) + to_watch.union(stream.streamerid for stream in live_streams) + + # Load associated streamers + watching = {} + if to_watch: + streamers = await self.data.Streamer.fetch_where( + userid=list(to_watch) + ) + for streamer in streamers: + watching[streamer.userid] = streamer + + self.watching = watching + self.live_streams = {stream.streamerid: stream for stream in live_streams} + + logger.info( + f"Watching {len(watching)} streamers for state changes. " + f"Loaded {len(live_streams)} (previously) live streams into cache." + ) + + async def poll_live(self): + # Every PERIOD seconds, + # request get_streams for the streamers we are currently watching. + # Check if they are in the live_stream cache, + # and update cache and data and fire-and-forget start/stop events as required. + # TODO: Logging + # TODO: Error handling so the poll loop doesn't die from temporary errors + # And when it does die it gets logged properly. + if not self.twitch: + raise ValueError("Attempting to start alert poll-loop before twitch set.") + + block_i = 0 + + self.polling = True + while self.polling: + await asyncio.sleep(self.POLL_PERIOD) + + to_request = list(self.watching.keys()) + if not to_request: + continue + # Each loop we request the 'next' slice of 100 userids + blocks = [to_request[i:i+100] for i in range(0, len(to_request), 100)] + block_i += 1 + block_i %= len(blocks) + block = blocks[block_i] + + streaming = {} + async for stream in self.twitch.get_streams(user_id=block, first=100): + # Note we set page size to 100 + # So we should never get repeat or missed streams + # Since we can request a max of 100 userids anyway. + streaming[stream.user_id] = stream + + started = set(streaming.keys()).difference(self.live_streams.keys()) + ended = set(self.live_streams.keys()).difference(streaming.keys()) + + for streamerid in started: + stream = streaming[streamerid] + stream_data = await self.data.Stream.create( + streamerid=stream.user_id, + start_at=stream.started_at, + twitch_stream_id=stream.id, + game_name=stream.game_name, + title=stream.title, + ) + self.live_streams[streamerid] = stream_data + task = asyncio.create_task(self.on_stream_start(stream_data)) + self.event_tasks.add(task) + task.add_done_callback(self.event_tasks.discard) + + for streamerid in ended: + stream_data = self.live_streams.pop(streamerid) + await stream_data.update(end_at=utc_now()) + task = asyncio.create_task(self.on_stream_end(stream_data)) + self.event_tasks.add(task) + task.add_done_callback(self.event_tasks.discard) + + async def on_stream_start(self, stream_data): + # Get channel subscriptions listening for this streamer + uid = stream_data.streamerid + logger.info(f"Streamer started streaming! {stream_data=}") + subbed = await self.data.AlertChannel.fetch_where(streamerid=uid) + + # Fulfill those alerts + for sub in subbed: + try: + # If the sub is paused, don't create the alert + await self.sub_alert(sub, stream_data) + except discord.HTTPException: + # TODO: Needs to be handled more gracefully at user level + # Retry logic? + logger.warning( + f"Could not complete subscription {sub=} for {stream_data=}", exc_info=True + ) + except Exception: + logger.exception( + f"Unexpected exception completing {sub=} for {stream_data=}" + ) + raise + + async def subscription_error(self, subscription, stream_data, err_msg): + """ + Handle a subscription fulfill failure. + Stores the error message for user display, + and deletes the subscription after some number of errors. + # TODO + """ + logger.warning( + f"Subscription error {subscription=} {stream_data=} {err_msg=}" + ) + + async def sub_alert(self, subscription, stream_data): + # Base alert behaviour is just to send a message + # and create an alert row + + channel = self.bot.get_channel(subscription.channelid) + if channel is None or not isinstance(channel, discord.abc.Messageable): + # Subscription channel is gone! + # Or the Discord channel cache died + await self.subscription_error( + subscription, stream_data, + "Subscription channel no longer exists." + ) + return + permissions = channel.permissions_for(channel.guild.me) + if not (permissions.send_messages and permissions.embed_links): + await self.subscription_error( + subscription, stream_data, + "Insufficient permissions to post alert message." + ) + return + + # Build message + streamer = await self.data.Streamer.fetch(stream_data.streamerid) + if not streamer: + # Streamer was deleted while handling the alert + # Just quietly ignore + # Don't error out because the stream data row won't exist anymore + logger.warning( + f"Cancelling alert for subscription {subscription.subscriptionid}" + " because the streamer no longer exists." + ) + return + + alert_config = AlertConfig(subscription.subscriptionid, subscription) + paused = alert_config.get(self.alert_settings.AlertPaused.setting_id) + if paused.value: + logger.info(f"Skipping alert for subscription {subscription=} because it is paused.") + return + + live_message = alert_config.get(self.alert_settings.AlertMessage.setting_id) + + formatter = await live_message.generate_formatter(self.bot, stream_data, streamer) + formatted = await formatter(live_message.value) + args = live_message.value_to_args(subscription.subscriptionid, formatted) + + try: + message = await channel.send(**args.send_args) + except discord.HTTPException as e: + logger.warning( + f"Message send failure while sending streamalert {subscription.subscriptionid}", + exc_info=True + ) + await self.subscription_error( + subscription, stream_data, + "Failed to post live alert." + ) + return + + # Store sent alert + alert = await self.data.StreamAlert.create( + streamid=stream_data.streamid, + subscriptionid=subscription.subscriptionid, + sent_at=utc_now(), + messageid=message.id + ) + logger.debug( + f"Fulfilled subscription {subscription.subscriptionid} with alert {alert.alertid}" + ) + + async def on_stream_end(self, stream_data): + # Get channel subscriptions listening for this streamer + uid = stream_data.streamerid + logger.info(f"Streamer stopped streaming! {stream_data=}") + subbed = await self.data.AlertChannel.fetch_where(streamerid=uid) + + # Resolve subscriptions + for sub in subbed: + try: + await self.sub_resolve(sub, stream_data) + except discord.HTTPException: + # TODO: Needs to be handled more gracefully at user level + # Retry logic? + logger.warning( + f"Could not resolve subscription {sub=} for {stream_data=}", exc_info=True + ) + except Exception: + logger.exception( + f"Unexpected exception resolving {sub=} for {stream_data=}" + ) + raise + + async def sub_resolve(self, subscription, stream_data): + # Check if there is a current active alert to resolve + alerts = await self.data.StreamAlert.fetch_where( + streamid=stream_data.streamid, + subscriptionid=subscription.subscriptionid, + ) + if not alerts: + logger.info( + f"Resolution requested for subscription {subscription.subscriptionid} with stream {stream_data.streamid} " + "but no active alerts were found." + ) + return + alert = alerts[0] + if alert.resolved_at is not None: + # Alert was already resolved + # This is okay, Twitch might have just sent the stream ending twice + logger.info( + f"Resolution requested for subscription {subscription.subscriptionid} with stream {stream_data.streamid} " + "but alert was already resolved." + ) + return + + # Check if message is to be deleted or edited (or nothing) + alert_config = AlertConfig(subscription.subscriptionid, subscription) + del_setting = alert_config.get(self.alert_settings.AlertEndDelete.setting_id) + edit_setting = alert_config.get(self.alert_settings.AlertEndMessage.setting_id) + + if (delmsg := del_setting.value) or (edit_setting.value): + # Find the message + message = None + channel = self.bot.get_channel(subscription.channelid) + if channel: + try: + message = await channel.fetch_message(alert.messageid) + except discord.HTTPException: + # Message was probably deleted already + # Or permissions were changed + # Or Discord connection broke + pass + else: + # Channel went after posting the alert + # Or Discord cache sucks + # Nothing we can do, just mark it handled + pass + if message: + if delmsg: + # Delete the message + try: + await message.delete() + except discord.HTTPException: + logger.warning( + f"Discord exception while del-resolve live alert {alert=}", + exc_info=True + ) + else: + # Edit message with custom arguments + streamer = await self.data.Streamer.fetch(stream_data.streamerid) + formatter = await edit_setting.generate_formatter(self.bot, stream_data, streamer) + formatted = await formatter(edit_setting.value) + args = edit_setting.value_to_args(subscription.subscriptionid, formatted) + try: + await message.edit(**args.edit_args) + except discord.HTTPException: + logger.warning( + f"Discord exception while edit-resolve live alert {alert=}", + exc_info=True + ) + else: + # Explicitly don't need to do anything to the alert + pass + + # Save alert as resolved + await alert.update(resolved_at=utc_now()) + + async def cog_unload(self): + if self.poll_task is not None and not self.poll_task.cancelled(): + self.poll_task.cancel() + + if self.twitch is not None: + await self.twitch.close() + self.twitch = None + + # ----- Commands ----- + @cmds.hybrid_group( + name='streamalert', + description=( + "Create and configure stream live-alerts." + ) + ) + @cmds.guild_only() + @appcmds.default_permissions(manage_channels=True) + async def streamalert_group(self, ctx: LionContext): + # Placeholder group, method not used + raise NotImplementedError + + @streamalert_group.command( + name='create', + description=( + "Subscribe a Discord channel to notifications when a Twitch stream goes live." + ) + ) + @appcmds.describe( + streamer="Name of the twitch channel to watch.", + channel="Which Discord channel to send live alerts in.", + message="Custom message to send when the channel goes live (may be edited later)." + ) + @appcmds.default_permissions(manage_channels=True) + async def streamalert_create_cmd(self, ctx: LionContext, + streamer: str, + channel: discord.TextChannel, + message: Optional[str]): + # Type guards + assert ctx.guild is not None, "Guild-only command has no guild ctx." + assert self.twitch is not None, "Twitch command run with no twitch obj." + + # Wards + if not channel.permissions_for(ctx.author).manage_channels: + await ctx.error_reply( + "Sorry, you need the `MANAGE_CHANNELS` permission " + "to add a stream alert to a channel." + ) + return + + # Look up the specified streamer + tw_user = await first(self.twitch.get_users(logins=[streamer])) + if not tw_user: + await ctx.error_reply( + f"Sorry, could not find `{streamer}` on Twitch! " + "Make sure you use the name in their channel url." + ) + return + + # Create streamer data if it doesn't already exist + streamer_data = await self.data.Streamer.fetch_or_create( + tw_user.id, + login_name=tw_user.login, + display_name=tw_user.display_name, + ) + + # Add subscription to alerts list + sub_data = await self.data.AlertChannel.create( + streamerid=streamer_data.userid, + guildid=channel.guild.id, + channelid=channel.id, + created_by=ctx.author.id, + paused=False + ) + + # Add to watchlist + self.watching[streamer_data.userid] = streamer_data + + # Open AlertEditorUI for the new subscription + # TODO + await ctx.reply("StreamAlert Created.") + + async def alert_acmpl(self, interaction: discord.Interaction, partial: str): + if not interaction.guild: + raise ValueError("Cannot acmpl alert in guildless interaction.") + + # Get all alerts in the server + alerts = await self.data.AlertChannel.fetch_where(guildid=interaction.guild_id) + + if not alerts: + # No alerts available + options = [ + appcmds.Choice( + name="No stream alerts are set up in this server!", + value=partial + ) + ] + else: + options = [] + for alert in alerts: + streamer = await self.data.Streamer.fetch(alert.streamerid) + if streamer is None: + # Should be impossible by foreign key condition + # Might be a stale cache + continue + channel = interaction.guild.get_channel(alert.channelid) + display = f"{streamer.display_name} in #{channel.name if channel else 'unknown'}" + if partial.lower() in display.lower(): + # Matching option + options.append(appcmds.Choice(name=display, value=str(alert.subscriptionid))) + if not options: + options.append( + appcmds.Choice( + name=f"No stream alerts matching {partial}"[:25], + value=partial + ) + ) + return options + + async def resolve_alert(self, interaction: discord.Interaction, alert_str: str): + if not interaction.guild: + raise ValueError("Resolving alert outside of a guild.") + # Expect alert_str to be the integer subscriptionid + if not alert_str.isdigit(): + raise UserInputError( + f"No stream alerts in this server matching `{alert_str}`!" + ) + alert = await self.data.AlertChannel.fetch(int(alert_str)) + if not alert or not alert.guildid == interaction.guild_id: + raise UserInputError( + "Could not find the selected alert! Please try again." + ) + return alert + + @streamalert_group.command( + name='edit', + description=( + "Update settings for an existing Twitch stream alert." + ) + ) + @appcmds.describe( + alert="Which alert do you want to edit?", + # TODO: Other settings here + ) + @appcmds.default_permissions(manage_channels=True) + async def streamalert_edit_cmd(self, ctx: LionContext, alert: str): + # Type guards + assert ctx.guild is not None, "Guild-only command has no guild ctx." + assert self.twitch is not None, "Twitch command run with no twitch obj." + assert ctx.interaction is not None, "Twitch command needs interaction ctx." + + # Look up provided alert + sub_data = await self.resolve_alert(ctx.interaction, alert) + + # Check user permissions for editing this alert + channel = ctx.guild.get_channel(sub_data.channelid) + permlevel = channel if channel else ctx.guild + if not permlevel.permissions_for(ctx.author).manage_channels: + await ctx.error_reply( + "Sorry, you need the `MANAGE_CHANNELS` permission " + "in this channel to edit the stream alert." + ) + return + # If edit options have been given, save edits and retouch cache if needed + # If not, open AlertEditorUI + ui = AlertEditorUI(bot=self.bot, sub_data=sub_data, callerid=ctx.author.id) + await ui.run(ctx.interaction) + await ui.wait() + + @streamalert_edit_cmd.autocomplete('alert') + async def streamalert_edit_cmd_alert_acmpl(self, interaction, partial): + return await self.alert_acmpl(interaction, partial) + + @streamalert_group.command( + name='pause', + description=( + "Pause a streamalert." + ) + ) + @appcmds.describe( + alert="Which alert do you want to pause?", + ) + @appcmds.default_permissions(manage_channels=True) + async def streamalert_pause_cmd(self, ctx: LionContext, alert: str): + # Type guards + assert ctx.guild is not None, "Guild-only command has no guild ctx." + assert self.twitch is not None, "Twitch command run with no twitch obj." + assert ctx.interaction is not None, "Twitch command needs interaction ctx." + + # Look up provided alert + sub_data = await self.resolve_alert(ctx.interaction, alert) + + # Check user permissions for editing this alert + channel = ctx.guild.get_channel(sub_data.channelid) + permlevel = channel if channel else ctx.guild + if not permlevel.permissions_for(ctx.author).manage_channels: + await ctx.error_reply( + "Sorry, you need the `MANAGE_CHANNELS` permission " + "in this channel to edit the stream alert." + ) + return + + await sub_data.update(paused=True) + await ctx.reply("This alert is now paused!") + + @streamalert_group.command( + name='unpause', + description=( + "Resume a streamalert." + ) + ) + @appcmds.describe( + alert="Which alert do you want to unpause?", + ) + @appcmds.default_permissions(manage_channels=True) + async def streamalert_unpause_cmd(self, ctx: LionContext, alert: str): + # Type guards + assert ctx.guild is not None, "Guild-only command has no guild ctx." + assert self.twitch is not None, "Twitch command run with no twitch obj." + assert ctx.interaction is not None, "Twitch command needs interaction ctx." + + # Look up provided alert + sub_data = await self.resolve_alert(ctx.interaction, alert) + + # Check user permissions for editing this alert + channel = ctx.guild.get_channel(sub_data.channelid) + permlevel = channel if channel else ctx.guild + if not permlevel.permissions_for(ctx.author).manage_channels: + await ctx.error_reply( + "Sorry, you need the `MANAGE_CHANNELS` permission " + "in this channel to edit the stream alert." + ) + return + + await sub_data.update(paused=False) + await ctx.reply("This alert has been unpaused!") + + @streamalert_group.command( + name='remove', + description=( + "Deactivate a streamalert entirely (see /streamalert pause to temporarily pause it)." + ) + ) + @appcmds.describe( + alert="Which alert do you want to remove?", + ) + @appcmds.default_permissions(manage_channels=True) + async def streamalert_remove_cmd(self, ctx: LionContext, alert: str): + # Type guards + assert ctx.guild is not None, "Guild-only command has no guild ctx." + assert self.twitch is not None, "Twitch command run with no twitch obj." + assert ctx.interaction is not None, "Twitch command needs interaction ctx." + + # Look up provided alert + sub_data = await self.resolve_alert(ctx.interaction, alert) + + # Check user permissions for editing this alert + channel = ctx.guild.get_channel(sub_data.channelid) + permlevel = channel if channel else ctx.guild + if not permlevel.permissions_for(ctx.author).manage_channels: + await ctx.error_reply( + "Sorry, you need the `MANAGE_CHANNELS` permission " + "in this channel to edit the stream alert." + ) + return + + await sub_data.delete() + await ctx.reply("This alert has been deleted.") diff --git a/src/modules/streamalerts/data.py b/src/modules/streamalerts/data.py new file mode 100644 index 0000000..645907f --- /dev/null +++ b/src/modules/streamalerts/data.py @@ -0,0 +1,105 @@ +from data import Registry, RowModel +from data.columns import Integer, Bool, Timestamp, String +from data.models import WeakCache +from cachetools import TTLCache + + +class AlertsData(Registry): + class Streamer(RowModel): + """ + Schema + ------ + CREATE TABLE streamers( + userid BIGINT PRIMARY KEY, + login_name TEXT NOT NULL, + display_name TEXT NOT NULL + ); + """ + _tablename_ = 'streamers' + _cache_ = {} + + userid = Integer(primary=True) + login_name = String() + display_name = String() + + class AlertChannel(RowModel): + """ + Schema + ------ + CREATE TABLE alert_channels( + subscriptionid SERIAL PRIMARY KEY, + guildid BIGINT NOT NULL, + channelid BIGINT NOT NULL, + streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE, + created_by BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + paused BOOLEAN NOT NULL DEFAULT FALSE, + end_delete BOOLEAN NOT NULL DEFAULT FALSE, + live_message TEXT, + end_message TEXT + ); + CREATE INDEX alert_channels_guilds ON alert_channels (guildid); + CREATE UNIQUE INDEX alert_channels_channelid_streamerid ON alert_channels (channelid, streamerid); + """ + _tablename_ = 'alert_channels' + _cache_ = {} + + subscriptionid = Integer(primary=True) + guildid = Integer() + channelid = Integer() + streamerid = Integer() + display_name = Integer() + created_by = Integer() + created_at = Timestamp() + paused = Bool() + end_delete = Bool() + live_message = String() + end_message = String() + + class Stream(RowModel): + """ + Schema + ------ + CREATE TABLE streams( + streamid SERIAL PRIMARY KEY, + streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE, + start_at TIMESTAMPTZ NOT NULL, + twitch_stream_id BIGINT, + game_name TEXT, + title TEXT, + end_at TIMESTAMPTZ + ); + """ + _tablename_ = 'streams' + _cache_ = WeakCache(TTLCache(maxsize=100, ttl=24*60*60)) + + streamid = Integer(primary=True) + streamerid = Integer() + start_at = Timestamp() + twitch_stream_id = Integer() + game_name = String() + title = String() + end_at = Timestamp() + + class StreamAlert(RowModel): + """ + Schema + ------ + CREATE TABLE stream_alerts( + alertid SERIAL PRIMARY KEY, + streamid INTEGER NOT NULL REFERENCES streams (streamid) ON DELETE CASCADE, + subscriptionid INTEGER NOT NULL REFERENCES alert_channels (subscriptionid) ON DELETE CASCADE, + sent_at TIMESTAMPTZ NOT NULL, + messageid BIGINT NOT NULL, + resolved_at TIMESTAMPTZ + ); + """ + _tablename_ = 'stream_alerts' + _cache_ = WeakCache(TTLCache(maxsize=1000, ttl=24*60*60)) + + alertid = Integer(primary=True) + streamid = Integer() + subscriptionid = Integer() + sent_at = Timestamp() + messageid = Integer() + resolved_at = Timestamp() diff --git a/src/modules/streamalerts/editor.py b/src/modules/streamalerts/editor.py new file mode 100644 index 0000000..e87f175 --- /dev/null +++ b/src/modules/streamalerts/editor.py @@ -0,0 +1,369 @@ +import asyncio +import datetime as dt +from collections import namedtuple +from functools import wraps +from typing import TYPE_CHECKING + +import discord +from discord.ui.button import button, Button, ButtonStyle +from discord.ui.select import select, Select, SelectOption, ChannelSelect + +from meta import LionBot, conf + +from utils.lib import MessageArgs, tabulate, utc_now +from utils.ui import MessageUI +from utils.ui.msgeditor import MsgEditor + +from .settings import AlertSettings as Settings +from .settings import AlertConfig as Config +from .data import AlertsData + +if TYPE_CHECKING: + from .cog import AlertCog + + +FakeStream = namedtuple( + 'FakeStream', + ["streamid", "streamerid", "start_at", "twitch_stream_id", "game_name", "title", "end_at"] +) + + +class AlertEditorUI(MessageUI): + setting_classes = ( + Settings.AlertPaused, + Settings.AlertEndDelete, + Settings.AlertEndMessage, + Settings.AlertMessage, + Settings.AlertChannel, + ) + + def __init__(self, bot: LionBot, sub_data: AlertsData.AlertChannel, **kwargs): + super().__init__(**kwargs) + + self.bot = bot + self.sub_data = sub_data + self.subid = sub_data.subscriptionid + self.cog: 'AlertCog' = bot.get_cog('AlertCog') + self.config = Config(self.subid, sub_data) + + # ----- UI API ----- + def preview_stream_data(self): + # TODO: Probably makes sense to factor this out to the cog + # Or even generate it in the formatters themselves + data = self.sub_data + return FakeStream( + -1, + data.streamerid, + utc_now() - dt.timedelta(hours=1), + -1, + "Discord Admin", + "Testing Go Live Message", + utc_now() + ) + + def call_and_refresh(self, func): + """ + Generate a wrapper which runs coroutine 'func' and then refreshes the UI. + """ + # TODO: Check whether the UI has finished interaction + @wraps(func) + async def wrapped(*args, **kwargs): + await func(*args, **kwargs) + await self.refresh() + return wrapped + + # ----- UI Components ----- + + # Pause button + @button(label="PAUSE_PLACEHOLDER", style=ButtonStyle.blurple) + async def pause_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=True, ephemeral=True) + setting = self.config.get(Settings.AlertPaused.setting_id) + setting.value = not setting.value + await setting.write() + await self.refresh(thinking=press) + + async def pause_button_refresh(self): + button = self.pause_button + if self.config.get(Settings.AlertPaused.setting_id).value: + button.label = "UnPause" + button.style = ButtonStyle.grey + else: + button.label = "Pause" + button.style = ButtonStyle.green + + # Delete button + @button(label="Delete Alert", style=ButtonStyle.red) + async def delete_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=True, ephemeral=True) + await self.sub_data.delete() + embed = discord.Embed( + colour=discord.Colour.brand_green(), + description="Stream alert removed." + ) + await press.edit_original_response(embed=embed) + await self.close() + + # Close button + @button(emoji=conf.emojis.cancel, style=ButtonStyle.red) + async def close_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=False) + await self.close() + + # Edit Alert button + @button(label="Edit Alert", style=ButtonStyle.blurple) + async def edit_alert_button(self, press: discord.Interaction, pressed: Button): + # Spawn MsgEditor for the live alert + await press.response.defer(thinking=True, ephemeral=True) + + setting = self.config.get(Settings.AlertMessage.setting_id) + + stream = self.preview_stream_data() + streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid) + + editor = MsgEditor( + self.bot, + setting.value, + callback=self.call_and_refresh(setting.editor_callback), + formatter=await setting.generate_formatter(self.bot, stream, streamer), + callerid=press.user.id + ) + self._slaves.append(editor) + await editor.run(press) + + # Edit End message + @button(label="Edit Ending Alert", style=ButtonStyle.blurple) + async def edit_end_button(self, press: discord.Interaction, pressed: Button): + # Spawn MsgEditor for the ending alert + await press.response.defer(thinking=True, ephemeral=True) + await self.open_end_editor(press) + + async def open_end_editor(self, respond_to: discord.Interaction): + setting = self.config.get(Settings.AlertEndMessage.setting_id) + # Start from current live alert data if not set + if not setting.value: + alert_setting = self.config.get(Settings.AlertMessage.setting_id) + setting.value = alert_setting.value + + stream = self.preview_stream_data() + streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid) + + editor = MsgEditor( + self.bot, + setting.value, + callback=self.call_and_refresh(setting.editor_callback), + formatter=await setting.generate_formatter(self.bot, stream, streamer), + callerid=respond_to.user.id + ) + self._slaves.append(editor) + await editor.run(respond_to) + return editor + + # Ending Mode Menu + @select( + cls=Select, + placeholder="Select action to take when the stream ends", + options=[SelectOption(label="DUMMY")], + min_values=0, max_values=1 + ) + async def ending_mode_menu(self, selection: discord.Interaction, selected: Select): + if not selected.values: + await selection.response.defer() + return + + await selection.response.defer(thinking=True, ephemeral=True) + value = selected.values[0] + + if value == '0': + # In Do Nothing case, + # Ensure Delete is off and custom edit message is unset + setting = self.config.get(Settings.AlertEndDelete.setting_id) + if setting.value: + setting.value = False + await setting.write() + setting = self.config.get(Settings.AlertEndMessage.setting_id) + if setting.value: + setting.value = None + await setting.write() + + await self.refresh(thinking=selection) + elif value == '1': + # In Delete Alert case, + # Set the delete setting to True + setting = self.config.get(Settings.AlertEndDelete.setting_id) + if not setting.value: + setting.value = True + await setting.write() + + await self.refresh(thinking=selection) + elif value == '2': + # In Edit Message case, + # Set the delete setting to False, + setting = self.config.get(Settings.AlertEndDelete.setting_id) + if setting.value: + setting.value = False + await setting.write() + + # And open the edit message editor + await self.open_end_editor(selection) + await self.refresh() + + async def ending_mode_menu_refresh(self): + # Build menu options + options = [ + SelectOption( + label="Do Nothing", + description="Don't modify the live alert message.", + value="0", + ), + SelectOption( + label="Delete Alert After Stream", + description="Delete the live alert message.", + value="1", + ), + SelectOption( + label="Edit Alert After Stream", + description="Edit the live alert message to a custom message. Opens editor.", + value="2", + ), + ] + + # Calculate the correct default + if self.config.get(Settings.AlertEndDelete.setting_id).value: + options[1].default = True + elif self.config.get(Settings.AlertEndMessage.setting_id).value: + options[2].default = True + + self.ending_mode_menu.options = options + + # Edit channel menu + @select(cls=ChannelSelect, + placeholder="Select Alert Channel", + channel_types=[discord.ChannelType.text, discord.ChannelType.voice], + min_values=0, max_values=1) + async def channel_menu(self, selection: discord.Interaction, selected): + if selected.values: + await selection.response.defer(thinking=True, ephemeral=True) + setting = self.config.get(Settings.AlertChannel.setting_id) + setting.value = selected.values[0] + await setting.write() + await self.refresh(thinking=selection) + else: + await selection.response.defer(thinking=False) + + async def channel_menu_refresh(self): + # current = self.config.get(Settings.AlertChannel.setting_id).value + # TODO: Check if discord-typed menus can have defaults yet + # Impl in stable dpy, but not released to pip yet + ... + + # ----- UI Flow ----- + async def make_message(self) -> MessageArgs: + streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid) + if streamer is None: + raise ValueError("Streamer row does not exist in AlertEditor") + name = streamer.display_name + + # Build relevant setting table + table_map = {} + table_map['Channel'] = self.config.get(Settings.AlertChannel.setting_id).formatted + table_map['Streamer'] = f"https://www.twitch.tv/{streamer.login_name}" + table_map['Paused'] = self.config.get(Settings.AlertPaused.setting_id).formatted + + prop_table = '\n'.join(tabulate(*table_map.items())) + + embed = discord.Embed( + colour=discord.Colour.dark_green(), + title=f"Stream Alert for {name}", + description=prop_table, + timestamp=utc_now() + ) + + message_setting = self.config.get(Settings.AlertMessage.setting_id) + message_desc_lines = [ + f"An alert message will be posted to {table_map['Channel']}.", + f"Press `{self.edit_alert_button.label}`" + " to preview or edit the alert.", + "The following keys will be substituted in the alert message." + ] + keytable = tabulate(*message_setting._subkey_desc.items()) + for line in keytable: + message_desc_lines.append(f"> {line}") + + embed.add_field( + name=f"When {name} goes live", + value='\n'.join(message_desc_lines), + inline=False + ) + + # Determine the ending behaviour + del_setting = self.config.get(Settings.AlertEndDelete.setting_id) + end_msg_setting = self.config.get(Settings.AlertEndMessage.setting_id) + + if del_setting.value: + # Deleting + end_msg_desc = "The live alert message will be deleted." + ... + elif end_msg_setting.value: + # Editing + lines = [ + "The live alert message will edited to the configured message.", + f"Press `{self.edit_end_button.label}` to preview or edit the message.", + "The following substitution keys are supported " + "*in addition* to the live alert keys." + ] + keytable = tabulate( + *[(k, v) for k, v in end_msg_setting._subkey_desc.items() if k not in message_setting._subkey_desc] + ) + for line in keytable: + lines.append(f"> {line}") + end_msg_desc = '\n'.join(lines) + else: + # Doing nothing + end_msg_desc = "The live alert message will not be changed." + + embed.add_field( + name=f"When {name} ends their stream", + value=end_msg_desc, + inline=False + ) + + return MessageArgs(embed=embed) + + async def reload(self): + await self.sub_data.refresh() + # Note self.config references the sub_data, and doesn't need reloading. + + async def refresh_layout(self): + to_refresh = ( + self.pause_button_refresh(), + self.channel_menu_refresh(), + self.ending_mode_menu_refresh(), + ) + await asyncio.gather(*to_refresh) + + show_end_edit = ( + not self.config.get(Settings.AlertEndDelete.setting_id).value + and + self.config.get(Settings.AlertEndMessage.setting_id).value + ) + + + if not show_end_edit: + # Don't show edit end button + buttons = ( + self.edit_alert_button, + self.pause_button, self.delete_button, self.close_button + ) + else: + buttons = ( + self.edit_alert_button, self.edit_end_button, + self.pause_button, self.delete_button, self.close_button + ) + + self.set_layout( + buttons, + (self.ending_mode_menu,), + (self.channel_menu,), + ) + diff --git a/src/modules/streamalerts/settings.py b/src/modules/streamalerts/settings.py new file mode 100644 index 0000000..42e2e94 --- /dev/null +++ b/src/modules/streamalerts/settings.py @@ -0,0 +1,264 @@ +from typing import Optional, Any +import json + +from meta.LionBot import LionBot +from settings import ModelData +from settings.groups import SettingGroup, ModelConfig, SettingDotDict +from settings.setting_types import BoolSetting, ChannelSetting +from core.setting_types import MessageSetting +from babel.translator import LocalBabel +from utils.lib import recurse_map, replace_multiple, tabulate + +from .data import AlertsData + + +babel = LocalBabel('streamalerts') +_p = babel._p + + +class AlertConfig(ModelConfig): + settings = SettingDotDict() + _model_settings = set() + model = AlertsData.AlertChannel + + +class AlertSettings(SettingGroup): + @AlertConfig.register_model_setting + class AlertMessage(ModelData, MessageSetting): + setting_id = 'alert_live_message' + _display_name = _p('', 'live_message') + + _desc = _p( + '', + 'Message sent to the channel when the streamer goes live.' + ) + _long_desc = _p( + '', + 'Message sent to the attached channel when the Twitch streamer goes live.' + ) + _accepts = _p('', 'JSON formatted greeting message data') + _default = json.dumps({'content': "**{display_name}** just went live at {channel_link}"}) + + _model = AlertsData.AlertChannel + _column = AlertsData.AlertChannel.live_message.name + + _subkey_desc = { + '{display_name}': "Twitch channel name (with capitalisation)", + '{login_name}': "Twitch channel login name (as in url)", + '{channel_link}': "Link to the live twitch channel", + '{stream_start}': "Numeric timestamp when stream went live", + } + # TODO: More stuff + + @property + def update_message(self) -> str: + return "The go-live notification message has been updated!" + + @classmethod + async def generate_formatter(cls, bot: LionBot, stream: AlertsData.Stream, streamer: AlertsData.Streamer, **kwargs): + """ + Generate a formatter function for this message + from the provided stream and streamer data. + + The formatter function accepts and returns a message data dict. + """ + async def formatter(data_dict: Optional[dict[str, Any]]): + if not data_dict: + return None + + mapping = { + '{display_name}': streamer.display_name, + '{login_name}': streamer.login_name, + '{channel_link}': f"https://www.twitch.tv/{streamer.login_name}", + '{stream_start}': int(stream.start_at.timestamp()), + } + + recurse_map( + lambda loc, value: replace_multiple(value, mapping) if isinstance(value, str) else value, + data_dict, + ) + return data_dict + return formatter + + async def editor_callback(self, editor_data): + self.value = editor_data + await self.write() + + def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]: + lines = super()._desc_table(show_value=show_value) + keytable = tabulate(*self._subkey_desc.items(), colon='') + expline = ( + "The following placeholders will be substituted with their values." + ) + keyfield = ( + "Placeholders", + expline + '\n' + '\n'.join(f"> {line}" for line in keytable) + ) + lines.append(keyfield) + return lines + + @AlertConfig.register_model_setting + class AlertEndMessage(ModelData, MessageSetting): + """ + Custom ending message to edit the live alert to. + If not set, doesn't edit the alert. + """ + setting_id = 'alert_end_message' + _display_name = _p('', 'end_message') + + _desc = _p( + '', + 'Optional message to edit the live alert with when the stream ends.' + ) + _long_desc = _p( + '', + "If set, and `end_delete` is not on, " + "the live alert will be edited with this custom message " + "when the stream ends." + ) + _accepts = _p('', 'JSON formatted greeting message data') + _default = None + + _model = AlertsData.AlertChannel + _column = AlertsData.AlertChannel.end_message.name + + _subkey_desc = { + '{display_name}': "Twitch channel name (with capitalisation)", + '{login_name}': "Twitch channel login name (as in url)", + '{channel_link}': "Link to the live twitch channel", + '{stream_start}': "Numeric timestamp when stream went live", + '{stream_end}': "Numeric timestamp when stream ended", + } + + @property + def update_message(self) -> str: + if self.value: + return "The stream ending message has been updated." + else: + return "The stream ending message has been unset." + + @classmethod + async def generate_formatter(cls, bot: LionBot, stream: AlertsData.Stream, streamer: AlertsData.Streamer, **kwargs): + """ + Generate a formatter function for this message + from the provided stream and streamer data. + + The formatter function accepts and returns a message data dict. + """ + # TODO: Fake stream data maker (namedtuple?) for previewing + async def formatter(data_dict: Optional[dict[str, Any]]): + if not data_dict: + return None + + mapping = { + '{display_name}': streamer.display_name, + '{login_name}': streamer.login_name, + '{channel_link}': f"https://www.twitch.tv/{streamer.login_name}", + '{stream_start}': int(stream.start_at.timestamp()), + '{stream_end}': int(stream.end_at.timestamp()), + } + + recurse_map( + lambda loc, value: replace_multiple(value, mapping) if isinstance(value, str) else value, + data_dict, + ) + return data_dict + return formatter + + async def editor_callback(self, editor_data): + self.value = editor_data + await self.write() + + def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]: + lines = super()._desc_table(show_value=show_value) + keytable = tabulate(*self._subkey_desc.items(), colon='') + expline = ( + "The following placeholders will be substituted with their values." + ) + keyfield = ( + "Placeholders", + expline + '\n' + '\n'.join(f"> {line}" for line in keytable) + ) + lines.append(keyfield) + return lines + ... + + @AlertConfig.register_model_setting + class AlertEndDelete(ModelData, BoolSetting): + """ + Whether to delete the live alert after the stream ends. + """ + setting_id = 'alert_end_delete' + _display_name = _p('', 'end_delete') + _desc = _p( + '', + 'Whether to delete the live alert after the stream ends.' + ) + _long_desc = _p( + '', + "If enabled, the live alert message will be deleted when the stream ends. " + "This overrides the `end_message` setting." + ) + _default = False + + _model = AlertsData.AlertChannel + _column = AlertsData.AlertChannel.end_delete.name + + @property + def update_message(self) -> str: + if self.value: + return "The live alert will be deleted at the end of the stream." + else: + return "The live alert will not be deleted when the stream ends." + + @AlertConfig.register_model_setting + class AlertPaused(ModelData, BoolSetting): + """ + Whether this live alert is currently paused. + """ + setting_id = 'alert_paused' + _display_name = _p('', 'paused') + _desc = _p( + '', + "Whether the alert is currently paused." + ) + _long_desc = _p( + '', + "Paused alerts will not trigger live notifications, " + "although the streams will still be tracked internally." + ) + _default = False + + _model = AlertsData.AlertChannel + _column = AlertsData.AlertChannel.paused.name + + @property + def update_message(self): + if self.value: + return "This alert is now paused" + else: + return "This alert has been unpaused" + + @AlertConfig.register_model_setting + class AlertChannel(ModelData, ChannelSetting): + """ + The channel associated to this alert. + """ + setting_id = 'alert_channel' + _display_name = _p('', 'channel') + _desc = _p( + '', + "The Discord channel this live alert will be sent in." + ) + _long_desc = _desc + + # Note that this cannot actually be None, + # as there is no UI pathway to unset the setting. + _default = None + + _model = AlertsData.AlertChannel + _column = AlertsData.AlertChannel.channelid.name + + @property + def update_message(self): + return f"This alert will now be posted to {self.value.channel.mention}" diff --git a/src/settings/__init__.py b/src/settings/__init__.py new file mode 100644 index 0000000..9c38704 --- /dev/null +++ b/src/settings/__init__.py @@ -0,0 +1,7 @@ +from babel.translator import LocalBabel +babel = LocalBabel('settings_base') + +from .data import ModelData, ListData +from .base import BaseSetting +from .ui import SettingWidget, InteractiveSetting +from .groups import SettingDotDict, SettingGroup, ModelSettings, ModelSetting diff --git a/src/settings/base.py b/src/settings/base.py new file mode 100644 index 0000000..0cbfd0d --- /dev/null +++ b/src/settings/base.py @@ -0,0 +1,166 @@ +from typing import Generic, TypeVar, Type, Optional, overload + + +""" +Setting metclass? +Parse setting docstring to generate default info? +Or just put it in the decorator we are already using +""" + + +# Typing using Generic[parent_id_type, data_type, value_type] +# value generic, could be Union[?, UNSET] +ParentID = TypeVar('ParentID') +SettingData = TypeVar('SettingData') +SettingValue = TypeVar('SettingValue') + +T = TypeVar('T', bound='BaseSetting') + + +class BaseSetting(Generic[ParentID, SettingData, SettingValue]): + """ + Abstract base class describing a stored configuration setting. + A setting consists of logic to load the setting from storage, + present it in a readable form, understand user entered values, + and write it again in storage. + Additionally, the setting has attributes attached describing + the setting in a user-friendly manner for display purposes. + """ + setting_id: str # Unique source identifier for the setting + + _default: Optional[SettingData] = None # Default data value for the setting + + def __init__(self, parent_id: ParentID, data: Optional[SettingData], **kwargs): + self.parent_id = parent_id + self._data = data + self.kwargs = kwargs + + # Instance generation + @classmethod + async def get(cls: Type[T], parent_id: ParentID, **kwargs) -> T: + """ + Return a setting instance initialised from the stored value, associated with the given parent id. + """ + data = await cls._reader(parent_id, **kwargs) + return cls(parent_id, data, **kwargs) + + # Main interface + @property + def data(self) -> Optional[SettingData]: + """ + Retrieves the current internal setting data if it is set, otherwise the default data + """ + return self._data if self._data is not None else self.default + + @data.setter + def data(self, new_data: Optional[SettingData]): + """ + Sets the internal raw data. + Does not write the changes. + """ + self._data = new_data + + @property + def default(self) -> Optional[SettingData]: + """ + Retrieves the default value for this setting. + Settings should override this if the default depends on the object id. + """ + return self._default + + @property + def value(self) -> SettingValue: # Actually optional *if* _default is None + """ + Context-aware object or objects associated with the setting. + """ + return self._data_to_value(self.parent_id, self.data) # type: ignore + + @value.setter + def value(self, new_value: Optional[SettingValue]): + """ + Setter which reads the discord-aware object and converts it to data. + Does not write the new value. + """ + self._data = self._data_from_value(self.parent_id, new_value) + + async def write(self, **kwargs) -> None: + """ + Write current data to the database. + For settings which override this, + ensure you handle deletion of values when internal data is None. + """ + await self._writer(self.parent_id, self._data, **kwargs) + + # Raw converters + @overload + @classmethod + def _data_from_value(cls: Type[T], parent_id: ParentID, value: SettingValue, **kwargs) -> SettingData: + ... + + @overload + @classmethod + def _data_from_value(cls: Type[T], parent_id: ParentID, value: None, **kwargs) -> None: + ... + + @classmethod + def _data_from_value( + cls: Type[T], parent_id: ParentID, value: Optional[SettingValue], **kwargs + ) -> Optional[SettingData]: + """ + Convert a high-level setting value to internal data. + Must be overridden by the setting. + Be aware of UNSET values, these should always pass through as None + to provide an unsetting interface. + """ + raise NotImplementedError + + @overload + @classmethod + def _data_to_value(cls: Type[T], parent_id: ParentID, data: SettingData, **kwargs) -> SettingValue: + ... + + @overload + @classmethod + def _data_to_value(cls: Type[T], parent_id: ParentID, data: None, **kwargs) -> None: + ... + + @classmethod + def _data_to_value( + cls: Type[T], parent_id: ParentID, data: Optional[SettingData], **kwargs + ) -> Optional[SettingValue]: + """ + Convert internal data to high-level setting value. + Must be overriden by the setting. + """ + raise NotImplementedError + + # Database access + @classmethod + async def _reader(cls: Type[T], parent_id: ParentID, **kwargs) -> Optional[SettingData]: + """ + Retrieve the setting data associated with the given parent_id. + May be None if the setting is not set. + Must be overridden by the setting. + """ + raise NotImplementedError + + @classmethod + async def _writer(cls: Type[T], parent_id: ParentID, data: Optional[SettingData], **kwargs) -> None: + """ + Write provided setting data to storage. + Must be overridden by the setting unless the `write` method is overridden. + If the data is None, the setting is UNSET and should be deleted. + """ + raise NotImplementedError + + @classmethod + async def setup(cls, bot): + """ + Initialisation task to be executed during client initialisation. + May be used for e.g. populating a cache or required client setup. + + Main application must execute the initialisation task before the setting is used. + Further, the task must always be executable, if the setting is loaded. + Conditional initialisation should go in the relevant module's init tasks. + """ + return None diff --git a/src/settings/data.py b/src/settings/data.py new file mode 100644 index 0000000..9f627ad --- /dev/null +++ b/src/settings/data.py @@ -0,0 +1,233 @@ +from typing import Type +import json + +from data import RowModel, Table, ORDER +from meta.logger import log_wrap, set_logging_context + + +class ModelData: + """ + Mixin for settings stored in a single row and column of a Model. + Assumes that the parent_id is the identity key of the Model. + + This does not create a reference to the Row. + """ + # Table storing the desired data + _model: Type[RowModel] + + # Column with the desired data + _column: str + + # Whether to create a row if not found + _create_row = False + + # High level data cache to use, leave as None to disable cache. + _cache = None # Map[id -> value] + + @classmethod + def _read_from_row(cls, parent_id, row, **kwargs): + data = row[cls._column] + + if cls._cache is not None: + cls._cache[parent_id] = data + + return data + + @classmethod + async def _reader(cls, parent_id, use_cache=True, **kwargs): + """ + Read in the requested column associated to the parent id. + """ + if cls._cache is not None and parent_id in cls._cache and use_cache: + return cls._cache[parent_id] + + model = cls._model + if cls._create_row: + row = await model.fetch_or_create(parent_id) + else: + row = await model.fetch(parent_id) + data = row[cls._column] if row else None + + if cls._cache is not None: + cls._cache[parent_id] = data + + return data + + @classmethod + async def _writer(cls, parent_id, data, **kwargs): + """ + Write the provided entry to the table. + This does *not* create the row if it does not exist. + It only updates. + """ + # TODO: Better way of getting the key? + # TODO: Transaction + if not isinstance(parent_id, tuple): + parent_id = (parent_id, ) + model = cls._model + rows = await model.table.update_where( + **model._dict_from_id(parent_id) + ).set( + **{cls._column: data} + ) + # If we didn't update any rows, create a new row + if not rows: + await model.fetch_or_create(**model._dict_from_id(parent_id), **{cls._column: data}) + + if cls._cache is not None: + cls._cache[parent_id] = data + + +class ListData: + """ + Mixin for list types implemented on a Table. + Implements a reader and writer. + This assumes the list is the only data stored in the table, + and removes list entries by deleting rows. + """ + setting_id: str + + # Table storing the setting data + _table_interface: Table + + # Name of the column storing the id + _id_column: str + + # Name of the column storing the data to read + _data_column: str + + # Name of column storing the order index to use, if any. Assumed to be Serial on writing. + _order_column: str + _order_type: ORDER = ORDER.ASC + + # High level data cache to use, set to None to disable cache. + _cache = None # Map[id -> value] + + @classmethod + @log_wrap(isolate=True) + async def _reader(cls, parent_id, use_cache=True, **kwargs): + """ + Read in all entries associated to the given id. + """ + set_logging_context(action="Read cls.setting_id") + if cls._cache is not None and parent_id in cls._cache and use_cache: + return cls._cache[parent_id] + + table = cls._table_interface # type: Table + query = table.select_where(**{cls._id_column: parent_id}).select(cls._data_column) + if cls._order_column: + query.order_by(cls._order_column, direction=cls._order_type) + + rows = await query + data = [row[cls._data_column] for row in rows] + + if cls._cache is not None: + cls._cache[parent_id] = data + + return data + + @classmethod + @log_wrap(isolate=True) + async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs): + """ + Write the provided list to storage. + """ + set_logging_context(action="Write cls.setting_id") + table = cls._table_interface + async with table.connector.connection() as conn: + table.connector.conn = conn + async with conn.transaction(): + # Handle None input as an empty list + if data is None: + data = [] + + current = await cls._reader(id, use_cache=False, **kwargs) + if not cls._order_column and (add_only or remove_only): + to_insert = [item for item in data if item not in current] if not remove_only else [] + to_remove = data if remove_only else ( + [item for item in current if item not in data] if not add_only else [] + ) + + # Handle required deletions + if to_remove: + params = { + cls._id_column: id, + cls._data_column: to_remove + } + await table.delete_where(**params) + + # Handle required insertions + if to_insert: + columns = (cls._id_column, cls._data_column) + values = [(id, value) for value in to_insert] + await table.insert_many(columns, *values) + + if cls._cache is not None: + new_current = [item for item in current + to_insert if item not in to_remove] + cls._cache[id] = new_current + else: + # Remove all and add all to preserve order + delete_params = {cls._id_column: id} + await table.delete_where(**delete_params) + + if data: + columns = (cls._id_column, cls._data_column) + values = [(id, value) for value in data] + await table.insert_many(columns, *values) + + if cls._cache is not None: + cls._cache[id] = data + + +class KeyValueData: + """ + Mixin for settings implemented in a Key-Value table. + The underlying table should have a Unique constraint on the `(_id_column, _key_column)` pair. + """ + _table_interface: Table + + _id_column: str + + _key_column: str + + _value_column: str + + _key: str + + @classmethod + async def _reader(cls, id, **kwargs): + params = { + cls._id_column: id, + cls._key_column: cls._key + } + + row = await cls._table_interface.select_one_where(**params).select(cls._value_column) + data = row[cls._value_column] if row else None + + if data is not None: + data = json.loads(data) + + return data + + @classmethod + async def _writer(cls, id, data, **kwargs): + params = { + cls._id_column: id, + cls._key_column: cls._key + } + if data is not None: + values = { + cls._value_column: json.dumps(data) + } + rows = await cls._table_interface.update_where(**params).set(**values) + if not rows: + await cls._table_interface.insert_many( + (cls._id_column, cls._key_column, cls._value_column), + (id, cls._key, json.dumps(data)) + ) + else: + await cls._table_interface.delete_where(**params) + + +# class UserInputError(SafeCancellation): +# pass diff --git a/src/settings/groups.py b/src/settings/groups.py new file mode 100644 index 0000000..3719a7b --- /dev/null +++ b/src/settings/groups.py @@ -0,0 +1,204 @@ +from typing import Generic, Type, TypeVar, Optional, overload + +from data import RowModel + +from .data import ModelData +from .ui import InteractiveSetting +from .base import BaseSetting + +from utils.lib import tabulate + + +T = TypeVar('T', bound=InteractiveSetting) + + +class SettingDotDict(Generic[T], dict[str, Type[T]]): + """ + Dictionary structure allowing simple dot access to items. + """ + __getattr__ = dict.__getitem__ # type: ignore + __setattr__ = dict.__setitem__ # type: ignore + __delattr__ = dict.__delitem__ # type: ignore + + +class SettingGroup: + """ + A SettingGroup is a collection of settings under one name. + """ + __initial_settings__: list[Type[InteractiveSetting]] = [] + + _title: Optional[str] = None + _description: Optional[str] = None + + def __init_subclass__(cls, title: Optional[str] = None): + cls._title = title or cls._title + cls._description = cls._description or cls.__doc__ + + settings: list[Type[InteractiveSetting]] = [] + for item in cls.__dict__.values(): + if isinstance(item, type) and issubclass(item, InteractiveSetting): + settings.append(item) + cls.__initial_settings__ = settings + + def __init_settings__(self): + settings = SettingDotDict() + for setting in self.__initial_settings__: + settings[setting.__name__] = setting + return settings + + def __init__(self, title=None, description=None) -> None: + self.title: str = title or self._title or self.__class__.__name__ + self.description: str = description or self._description or "" + self.settings: SettingDotDict[InteractiveSetting] = self.__init_settings__() + + def attach(self, cls: Type[T], name: Optional[str] = None): + name = name or cls.setting_id + self.settings[name] = cls + return cls + + def detach(self, cls): + return self.settings.pop(cls.__name__, None) + + def update(self, smap): + self.settings.update(smap.settings) + + def reduce(self, *keys): + for key in keys: + self.settings.pop(key, None) + return + + async def make_setting_table(self, parent_id, **kwargs): + """ + Convenience method for generating a rendered setting table. + """ + rows = [] + for setting in self.settings.values(): + if not setting._virtual: + set = await setting.get(parent_id, **kwargs) + name = set.display_name + value = str(set.formatted) + rows.append((name, value, set.hover_desc)) + table_rows = tabulate( + *rows, + row_format="[`{invis}{key:<{pad}}{colon}`](https://lionbot.org \"{field[2]}\")\t{value}" + ) + return '\n'.join(table_rows) + + +class ModelSetting(ModelData, BaseSetting): + ... + + +class ModelConfig: + """ + A ModelConfig provides a central point of configuration for any object described by a single Model. + + An instance of a ModelConfig represents configuration for a single object + (given by a single row of the corresponding Model). + + The ModelConfig also supports registration of non-model configuration, + to support associated settings (e.g. list-settings) for the object. + + This is an ABC, and must be subclassed for each object-type. + """ + settings: SettingDotDict + _model_settings: set + model: Type[RowModel] + + def __init__(self, parent_id, row, **kwargs): + self.parent_id = parent_id + self.row = row + self.kwargs = kwargs + + @classmethod + def register_setting(cls, setting_cls): + """ + Decorator to register a non-model setting as part of the object configuration. + + The setting class may be re-accessed through the `settings` class attr. + + Subclasses may provide alternative access pathways to key non-model settings. + """ + cls.settings[setting_cls.setting_id] = setting_cls + return setting_cls + + @classmethod + def register_model_setting(cls, model_setting_cls): + """ + Decorator to register a model setting as part of the object configuration. + + The setting class may be accessed through the `settings` class attr. + + A fresh setting instance may also be retrieved (using cached data) + through the `get` instance method. + + Subclasses are recommended to provide model settings as properties + for simplified access and type checking. + """ + cls._model_settings.add(model_setting_cls.setting_id) + return cls.register_setting(model_setting_cls) + + def get(self, setting_id): + """ + Retrieve a freshly initialised copy of the given model-setting. + + The given `setting_id` must have been previously registered through `register_model_setting`. + This uses cached data, and so is not guaranteed to be up-to-date. + """ + if setting_id not in self._model_settings: + # TODO: Log + raise ValueError + setting_cls = self.settings[setting_id] + data = setting_cls._read_from_row(self.parent_id, self.row, **self.kwargs) + return setting_cls(self.parent_id, data, **self.kwargs) + + +class ModelSettings: + """ + A ModelSettings instance aggregates multiple `ModelSetting` instances + bound to the same parent id on a single Model. + + This enables a single point of access + for settings of a given Model, + with support for caching or deriving as needed. + + This is an abstract base class, + and should be subclassed to define the contained settings. + """ + _settings: SettingDotDict = SettingDotDict() + model: Type[RowModel] + + def __init__(self, parent_id, row, **kwargs): + self.parent_id = parent_id + self.row = row + self.kwargs = kwargs + + @classmethod + async def fetch(cls, *parent_id, **kwargs): + """ + Load an instance of this ModelSetting with the given parent_id + and setting keyword arguments. + """ + row = await cls.model.fetch_or_create(*parent_id) + return cls(parent_id, row, **kwargs) + + @classmethod + def attach(self, setting_cls): + """ + Decorator to attach the given setting class to this modelsetting. + """ + # This violates the interface principle, use structured typing instead? + if not (issubclass(setting_cls, BaseSetting) and issubclass(setting_cls, ModelData)): + raise ValueError( + f"The provided setting class must be `ModelSetting`, not {setting_cls.__class__.__name__}." + ) + self._settings[setting_cls.setting_id] = setting_cls + return setting_cls + + def get(self, setting_id): + setting_cls = self._settings.get(setting_id) + data = setting_cls._read_from_row(self.parent_id, self.row, **self.kwargs) + return setting_cls(self.parent_id, data, **self.kwargs) + + def __getitem__(self, setting_id): + return self.get(setting_id) diff --git a/src/settings/mock.py b/src/settings/mock.py new file mode 100644 index 0000000..e80a4c9 --- /dev/null +++ b/src/settings/mock.py @@ -0,0 +1,13 @@ +import discord +from discord import app_commands + + +class LocalString: + def __init__(self, string): + self.string = string + + def as_string(self): + return self.string + + +_ = LocalString diff --git a/src/settings/setting_types.py b/src/settings/setting_types.py new file mode 100644 index 0000000..4cf4177 --- /dev/null +++ b/src/settings/setting_types.py @@ -0,0 +1,1393 @@ +from typing import Optional, Union, TYPE_CHECKING, TypeVar, Generic, Any, TypeAlias, Type +from enum import Enum + +import pytz +import discord +import discord.app_commands as appcmds + +import itertools +import datetime as dt +from discord import ui +from discord.ui.button import button, Button, ButtonStyle +from dateutil.parser import parse, ParserError + +from meta.context import ctx_bot +from meta.errors import UserInputError +from utils.lib import strfdur, parse_duration +from babel.translator import ctx_translator, LazyStr + +from .base import ParentID +from .ui import InteractiveSetting, SettingWidget +from . import babel + +_, _p = babel._, babel._p + + +if TYPE_CHECKING: + from discord.guild import GuildChannel + + +# TODO: Localise this file + + +class StringSetting(InteractiveSetting[ParentID, str, str]): + """ + Setting type mixin describing an arbitrary string type. + + Options + ------- + _maxlen: int + Maximum length of string to accept in `_parse_string`. + Default: 4000 + + _quote: bool + Whether to display the string with backticks. + Default: True + """ + + _accepts = _p('settype:string|accepts', "Any Text") + + _maxlen: int = 4000 + _quote: bool = True + + @property + def input_formatted(self) -> str: + """ + Return the current data string. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Return the provided value string as the data string. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Return the provided data string as the value string. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + """ + Parse the user input `string` into StringSetting data. + Provides some minor input validation. + Treats an empty string as a `None` value. + """ + t = ctx_translator.get().t + if len(string) > cls._maxlen: + raise UserInputError( + t(_p( + 'settype:string|error', + "Provided string is too long! Maximum length: {maxlen} characters." + )).format(maxlen=cls._maxlen) + ) + elif len(string) == 0: + return None + else: + return string + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Optionally (see `_quote`) wrap the data string in backticks. + """ + if data: + return "`{}`".format(data) if cls._quote else str(data) + else: + return None + + +class EmojiSetting(InteractiveSetting[ParentID, str, str]): + """ + Setting type representing a stored emoji. + + The emoji is stored in a single string field, and at no time is guaranteed to be a valid emoji. + """ + _accepts = _p('settype:emoji|accepts', "Paste a builtin emoji, custom emoji, or emoji id.") + + @property + def input_formatted(self) -> str: + """ + Return the current data string. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Return the provided value string as the data string. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Return the provided data string as the value string. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + """ + Parse the given user entered emoji string. + + Accepts unicode (builtin) emojis, custom emojis, and custom emoji ids. + """ + t = ctx_translator.get().t + + provided = string + string = string.strip(' :<>') + if string.startswith('a:'): + string = string[2:] + + if not string or string.lower() == 'none': + emojistr = None + elif string.isdigit(): + # Assume emoji id + emojistr = f"" + elif ':' in string: + # Assume custom emoji + emojistr = provided.strip() + elif string.isascii(): + # Probably not an emoji + raise UserInputError( + t(_p( + 'settype:emoji|error:parse', + "Could not parse `{provided}` as a Discord emoji. " + "Supported formats are builtin emojis (e.g. `{builtin}`), " + "custom emojis (e.g. {custom}), " + "or custom emoji ids (e.g. `{custom_id}`)." + )).format( + provided=provided, + builtin="🤔", + custom="*`<`*`CuteLeo:942499177135480942`*`>`*", + custom_id="942499177135480942", + ) + ) + else: + # We don't have a good way of testing for emoji unicode + # So just assume anything with unicode is an emoji. + emojistr = string + + return emojistr + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Optionally (see `_quote`) wrap the data string in backticks. + """ + if data: + return data + else: + return None + + @property + def as_partial(self) -> Optional[discord.PartialEmoji]: + return self._parse_emoji(self.data) + + @staticmethod + def _parse_emoji(emojistr: str): + """ + Converts a provided string into a PartialEmoji. + Deos not validate the emoji string. + """ + if not emojistr: + return None + elif ":" in emojistr: + emojistr = emojistr.strip('<>') + splits = emojistr.split(":") + if len(splits) == 3: + animated, name, id = splits + animated = bool(animated) + return discord.PartialEmoji(name=name, animated=animated, id=int(id)) + else: + return discord.PartialEmoji(name=emojistr) + + +CT = TypeVar('CT', 'GuildChannel', 'discord.Object', 'discord.Thread') +MCT = TypeVar('MCT', discord.TextChannel, discord.Thread, discord.VoiceChannel, discord.Object) + + +class ChannelSetting(Generic[ParentID, CT], InteractiveSetting[ParentID, int, CT]): + """ + Setting type mixin describing a Guild Channel. + + Options + ------- + _selector_placeholder: str + Placeholder to use in the Widget selector. + Default: "Select a channel" + + channel_types: list[discord.ChannelType] + List of guild channel types to accept. + Default: [] + """ + _accepts = _p('settype:channel|accepts', "A channel name or id") + + _selector_placeholder = "Select a Channel" + channel_types: list[discord.ChannelType] = [] + _allow_object = False + + @classmethod + def _data_from_value(cls, parent_id, value, **kwargs): + """ + Returns the id of the provided channel. + """ + if value is not None: + return value.id + + @classmethod + def _data_to_value(cls, parent_id, data, **kwargs): + """ + Searches for the provided channel id in the current channel cache. + If the channel cannot be found, returns a `discord.Object` instead. + """ + if data is not None: + bot = ctx_bot.get() + channel = bot.get_channel(data) + if channel is None and cls._allow_object: + channel = discord.Object(id=data) + return channel + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + if not string or string.lower() == 'none': + return None + + t = ctx_translator.get().t + bot = ctx_bot.get() + channel = None + guild = bot.get_guild(parent_id) + + if string.isdigit(): + maybe_id = int(string) + channel = guild.get_channel(maybe_id) + else: + channel = next((channel for channel in guild.channels if channel.name.lower() == string.lower()), None) + + if channel is None: + raise UserInputError(t(_p( + 'settype:channel|parse|error:not_found', + "Channel `{string}` could not be found in this guild!".format(string=string) + ))) + return channel.id + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Returns a manually formatted channel mention. + """ + if data: + return "<#{}>".format(data) + + @property + def input_formatted(self) -> str: + data = self._data + return str(data) if data else '' + + class Widget(SettingWidget['ChannelSetting']): + def update_children(self): + self.update_child( + self.channel_selector, { + 'channel_types': self.setting.channel_types, + 'placeholder': self.setting._selector_placeholder + } + ) + + def make_exports(self): + return [self.channel_selector] + + @ui.select( + cls=ui.ChannelSelect, + channel_types=[discord.ChannelType.text], + placeholder="Select a Channel", + max_values=1, + min_values=0 + ) + async def channel_selector(self, interaction: discord.Interaction, select: discord.ui.ChannelSelect) -> None: + await interaction.response.defer(thinking=True, ephemeral=True) + if select.values: + channel = select.values[0] + await self.setting.interactive_set(channel.id, interaction) + else: + await self.setting.interactive_set(None, interaction) + + +class VoiceChannelSetting(ChannelSetting): + """ + Setting type mixin representing a discord VoiceChannel. + Implemented as a narrowed `ChannelSetting`. + See `ChannelSetting` for options. + """ + channel_types = [discord.ChannelType.voice] + + +class MessageablelSetting(ChannelSetting): + """ + Setting type mixin representing a discord Messageable guild channel. + Implemented as a narrowed `ChannelSetting`. + See `ChannelSetting` for options. + """ + channel_types = [discord.ChannelType.text, discord.ChannelType.voice, discord.ChannelType.public_thread] + + @classmethod + def _data_to_value(cls, parent_id, data, **kwargs): + """ + Searches for the provided channel id in the current channel cache. + If the channel cannot be found, returns a `discord.PartialMessageable` instead. + """ + if data is not None: + bot = ctx_bot.get() + channel = bot.get_channel(data) + if channel is None: + channel = bot.get_partial_messageable(data, guild_id=parent_id) + return channel + + +class RoleSetting(InteractiveSetting[ParentID, int, Union[discord.Role, discord.Object]]): + """ + Setting type mixin describing a Guild Role. + + Options + ------- + _selector_placeholder: str + Placeholder to use in the Widget selector. + Default: "Select a Role" + """ + _accepts = _p('settype:role|accepts', "A role name or id") + + _selector_placeholder = "Select a Role" + _allow_object = False + + @classmethod + def _get_guildid(cls, parent_id: int, **kwargs) -> int: + """ + Fetch the current guildid. + Assumes that the guilid is either passed as a kwarg or is the object id. + Should be overridden in other cases. + """ + return kwargs.get('guildid', parent_id) + + @classmethod + def _data_from_value(cls, parent_id, value, **kwargs): + """ + Returns the id of the provided role. + """ + if value is not None: + return value.id + + @classmethod + def _data_to_value(cls, parent_id, data, **kwargs): + """ + Searches for the provided role id in the current channel cache. + If the channel cannot be found, returns a `discord.Object` instead. + """ + if data is not None: + role = None + + guildid = cls._get_guildid(parent_id, **kwargs) + bot = ctx_bot.get() + guild = bot.get_guild(guildid) + if guild is not None: + role = guild.get_role(data) + if role is None and cls._allow_object: + role = discord.Object(id=data) + return role + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + if not string or string.lower() == 'none': + return None + guildid = cls._get_guildid(parent_id, **kwargs) + + t = ctx_translator.get().t + bot = ctx_bot.get() + role = None + guild = bot.get_guild(guildid) + if guild is None: + raise ValueError("Attempting to parse role string with no guild.") + + if string.isdigit(): + maybe_id = int(string) + role = guild.get_role(maybe_id) + else: + role = next((role for role in guild.roles if role.name.lower() == string.lower()), None) + + if role is None: + raise UserInputError(t(_p( + 'settype:role|parse|error:not_found', + "Role `{string}` could not be found in this guild!".format(string=string) + ))) + return role.id + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Returns a manually formatted role mention. + """ + if data: + return "<@&{}>".format(data) + else: + return None + + @property + def input_formatted(self) -> str: + data = self._data + return str(data) if data else '' + + class Widget(SettingWidget['RoleSetting']): + def update_children(self): + self.update_child( + self.role_selector, + {'placeholder': self.setting._selector_placeholder} + ) + + def make_exports(self): + return [self.role_selector] + + @ui.select( + cls=ui.RoleSelect, + placeholder="Select a Role", + max_values=1, + min_values=0 + ) + async def role_selector(self, interaction: discord.Interaction, select: discord.ui.RoleSelect) -> None: + await interaction.response.defer(thinking=True, ephemeral=True) + if select.values: + role = select.values[0] + await self.setting.interactive_set(role.id, interaction) + else: + await self.setting.interactive_set(None, interaction) + + +class BoolSetting(InteractiveSetting[ParentID, bool, bool]): + """ + Setting type mixin describing a boolean. + + Options + ------- + _truthy: Set + Set of strings that are considered "truthy" in the parser. + Not case sensitive. + Default: {"yes", "true", "on", "enable", "enabled"} + + _falsey: Set + Set of strings that are considered "falsey" in the parser. + Not case sensitive. + Default: {"no", "false", "off", "disable", "disabled"} + + _outputs: tuple[str, str, str] + Strings to represent 'True', 'False', and 'None' values respectively. + Default: {True: "On", False: "Off", None: "Not Set"} + """ + + _accepts = _p('settype:bool|accepts', "Enabled/Disabled") + + # Values that are accepted as truthy and falsey by the parser + _truthy = _p( + 'settype:bool|parse:truthy_values', + "enabled|yes|true|on|enable|1" + ) + _falsey = _p( + 'settype:bool|parse:falsey_values', + 'disabled|no|false|off|disable|0' + ) + + # The user-friendly output strings to use for each value + _outputs = { + True: _p('settype:bool|output:true', "On"), + False: _p('settype:bool|output:false', "Off"), + None: _p('settype:bool|output:none', "Not Set"), + } + + # Button labels + _true_button_args: dict[str, Any] = {} + _false_button_args: dict[str, Any] = {} + _reset_button_args: dict[str, Any] = {} + + @classmethod + def truthy_values(cls) -> set: + t = ctx_translator.get().t + return t(cls._truthy).lower().split('|') + + @classmethod + def falsey_values(cls) -> set: + t = ctx_translator.get().t + return t(cls._falsey).lower().split('|') + + @property + def input_formatted(self) -> str: + """ + Return the current data string. + """ + if self._data is not None: + t = ctx_translator.get().t + output = t(self._outputs[self._data]) + input_set = self.truthy_values() if self._data else self.falsey_values() + + if output.lower() in input_set: + return output + else: + return next(iter(input_set)) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Directly return provided value bool as data bool. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Directly return provided data bool as value bool. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Looks up the provided string in the truthy and falsey tables. + """ + _userstr = string.lower() + if not _userstr or _userstr == "none": + return None + if _userstr in cls.truthy_values(): + return True + elif _userstr in cls.falsey_values(): + return False + else: + raise UserInputError("Could not parse `{}` as a boolean.".format(string)) + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Use provided _outputs dictionary to format data. + """ + t = ctx_translator.get().t + return t(cls._outputs[data]) + + class Widget(SettingWidget['BoolSetting']): + def update_children(self): + self.update_child(self.true_button, self.setting._true_button_args) + self.update_child(self.false_button, self.setting._false_button_args) + self.update_child(self.reset_button, self.setting._reset_button_args) + self.order_children(self.true_button, self.false_button, self.reset_button) + + def make_exports(self): + return [self.true_button, self.false_button, self.reset_button] + + @button(style=ButtonStyle.secondary, label="On", row=4) + async def true_button(self, interaction: discord.Interaction, button: Button): + await interaction.response.defer(thinking=True, ephemeral=True) + await self.setting.interactive_set(True, interaction) + + @button(style=ButtonStyle.secondary, label="Off", row=4) + async def false_button(self, interaction: discord.Interaction, button: Button): + await interaction.response.defer(thinking=True, ephemeral=True) + await self.setting.interactive_set(False, interaction) + + +class IntegerSetting(InteractiveSetting[ParentID, int, int]): + """ + Setting type mixin describing a ranged integer. + As usual, override `_parse_string` to customise error messages. + + Options + ------- + _min: int + A minimum integer to accept. + Default: -2147483647 + + _max: int + A maximum integer to accept. + Default: 2147483647 + """ + _min = -2147483647 + _max = 2147483647 + + _accepts = _p('settype:integer|accepts', "An integer") + + @property + def input_formatted(self) -> str: + """ + Return a string representation of the set integer. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Directly return value integer as data integer. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Directly return data integer as value integer. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into an integer. + """ + if not string: + return None + try: + num = int(string) + except Exception: + raise UserInputError("Couldn't parse provided integer.") from None + + if num > cls._max: + raise UserInputError("Provided integer was too large!") + elif num < cls._min: + raise UserInputError("Provided integer was too small!") + + return num + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Returns the stringified integer in backticks. + """ + if data is not None: + return f"`{data}`" + + +class PartialEmojiSetting(InteractiveSetting[ParentID, str, discord.PartialEmoji]): + """ + Setting type mixin describing an Emoji string. + + Options + ------- + None + """ + + _accepts = _p('settype:emoji|desc', "Unicode or custom emoji") + + @staticmethod + def _parse_emoji(emojistr): + """ + Converts a provided string into a PartialEmoji. + If the string is badly formatted, returns None. + """ + if ":" in emojistr: + emojistr = emojistr.strip('<>') + splits = emojistr.split(":") + if len(splits) == 3: + animated, name, id = splits + animated = bool(animated) + return discord.PartialEmoji(name=name, animated=animated, id=int(id)) + else: + # TODO: Check whether this is a valid emoji + return discord.PartialEmoji(name=emojistr) + + @property + def input_formatted(self) -> str: + """ + Return the current data string. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Stringify the value emoji into a consistent data string. + """ + return str(value) if value is not None else None + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Convert the stored string into an emoji, through parse_emoji. + This may return None if the parsing syntax changes. + """ + return cls._parse_emoji(data) if data is not None else None + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + """ + Parse the provided string into a PartialEmoji if possible. + """ + if string: + emoji = cls._parse_emoji(string) + if emoji is None: + raise UserInputError("Could not understand provided emoji!") + return str(emoji) + return None + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Emojis are pretty much self-formatting. Just return the data directly. + """ + return data + + +class GuildIDSetting(InteractiveSetting[ParentID, int, int]): + """ + Setting type mixin describing a guildid. + This acts like a pure integer type, apart from the formatting. + + Options + ------- + """ + _accepts = _p('settype:guildid|accepts', "Any Snowflake ID") + # TODO: Consider autocomplete for guilds the user is in + + @property + def input_formatted(self) -> str: + """ + Return a string representation of the stored snowflake. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Directly return value integer as data integer. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Directly return data integer as value integer. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into an integer. + """ + if not string: + return None + try: + num = int(string) + except Exception: + raise UserInputError("Couldn't parse provided guildid.") from None + return num + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Return the stored snowflake as a string. + If the guild is in cache, attach the name as well. + """ + if data is not None: + bot = ctx_bot.get() + guild = bot.get_guild(data) + if guild is not None: + return f"`{data}` ({guild.name})" + else: + return f"`{data}`" + + +TZT: TypeAlias = pytz.BaseTzInfo + + +class TimezoneSetting(InteractiveSetting[ParentID, str, TZT]): + """ + Typed Setting ABC representing timezone information. + """ + # TODO: Consider configuration UI for timezone by continent and country + # Do any continents have more than 25 countries? + # Maybe list e.g. Europe (Austria - Iceland) and Europe (Ireland - Ukraine) separately + + # TODO Definitely need autocomplete here + _accepts = _p( + 'settype:timezone|accepts', + "A timezone name from the 'tz database' (e.g. 'Europe/London')" + ) + + @property + def input_formatted(self) -> str: + """ + Return a string representation of the stored timezone. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Use str to transform the pytz timezone into a string. + """ + if value: + return str(value) + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Use pytz to convert the stored timezone string to a timezone. + """ + if data: + return pytz.timezone(data) + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into an integer. + """ + # TODO: Localise + # TODO: Another selection case. + if not string: + return None + try: + timezone = pytz.timezone(string) + except pytz.exceptions.UnknownTimeZoneError: + timezones = [tz for tz in pytz.all_timezones if string.lower() in tz.lower()] + if len(timezones) == 1: + timezone = timezones[0] + elif timezones: + raise UserInputError("Multiple matching timezones found!") + # TODO: Add a selector-message here instead of dying instantly + # Maybe only post a selector if there are less than 25 options! + + # result = await ctx.selector( + # "Multiple matching timezones found, please select one.", + # timezones + # ) + # timezone = timezones[result] + else: + raise UserInputError( + "Unknown timezone `{}`. " + "Please provide a TZ name from " + "[this list](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones)".format(string) + ) from None + return str(timezone) + + def _desc_table(self) -> list[str]: + translator = ctx_translator.get() + t = translator.t + + lines = super()._desc_table() + lines.append(( + t(_p( + 'settype:timezone|summary_table|field:supported|key', + "Supported" + )), + t(_p( + 'settype:timezone|summary_table|field:supported|value', + "Any timezone from the [tz database]({link})." + )).format(link="https://en.wikipedia.org/wiki/List_of_tz_database_time_zones") + )) + return lines + + @classmethod + async def parse_acmpl(cls, interaction: discord.Interaction, partial: str): + bot = interaction.client + t = bot.translator.t + + timezones = pytz.all_timezones + matching = [tz for tz in timezones if partial.strip().lower() in tz.lower()][:25] + if not matching: + choices = [ + appcmds.Choice( + name=t(_p( + 'set_type:timezone|acmpl|no_matching', + "No timezones matching '{input}'!" + )).format(input=partial)[:100], + value=partial + ) + ] + else: + choices = [] + for tz in matching: + timezone = pytz.timezone(tz) + now = dt.datetime.now(timezone) + nowstr = now.strftime("%H:%M") + name = t(_p( + 'set_type:timezone|acmpl|choice', + "{tz} (Currently {now})" + )).format(tz=tz, now=nowstr) + choice = appcmds.Choice( + name=name[:100], + value=tz + ) + choices.append(choice) + return choices + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Return the stored snowflake as a string. + If the guild is in cache, attach the name as well. + """ + if data is not None: + return f"`{data}`" + + +class TimestampSetting(InteractiveSetting[ParentID, str, dt.datetime]): + """ + Typed Setting ABC representing a fixed point in time. + + Data is assumed to be a timezone aware datetime object. + Value is the same as data. + Parsing accepts YYYY-MM-DD [HH:MM] [+TZ] + Display uses a discord timestamp. + """ + _accepts = _p( + 'settype:timestamp|accepts', + "A timestamp in the form YYYY-MM-DD HH:MM" + ) + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + return value + + @classmethod + def _data_to_value(cls, parent_id: ParentID, data, **kwargs): + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + string = string.strip() + if string.lower() in ('', 'none', '0'): + ts = None + else: + local_tz = await cls._timezone_from_id(parent_id, **kwargs) + now = dt.datetime.now(tz=local_tz) + default = now.replace( + hour=0, minute=0, + second=0, microsecond=0 + ) + try: + ts = parse(string, fuzzy=True, default=default) + except ParserError: + t = ctx_translator.get().t + raise UserInputError(t(_p( + 'settype:timestamp|parse|error:invalid', + "Could not parse `{provided}` as a timestamp. Please use `YYYY-MM-DD HH:MM` format." + )).format(provided=string)) + return ts + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + if data is not None: + return "".format(int(data.timestamp())) + + @classmethod + async def _timezone_from_id(cls, parent_id: ParentID, **kwargs): + """ + Extract the parsing timezone from the given parent id. + + Should generally be overriden for interactive settings. + """ + return pytz.UTC + + @property + def input_formatted(self) -> str: + if self._data: + formatted = self._data.strftime('%Y-%m-%d %H:%M') + else: + formatted = '' + return formatted + + +class RawSetting(InteractiveSetting[ParentID, Any, Any]): + """ + Basic implementation of an interactive setting with identical value and data type. + """ + _accepts = _p('settype:raw|accepts', "Anything") + + @property + def input_formatted(self) -> str: + return str(self._data) if self._data is not None else '' + + @classmethod + def _data_from_value(cls, parent_id, value, **kwargs): + return value + + @classmethod + def _data_to_value(cls, parent_id, data, **kwargs): + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + return string + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + return str(data) if data is not None else None + + +ET = TypeVar('ET', bound='Enum') + + +class EnumSetting(InteractiveSetting[ParentID, ET, ET]): + """ + Typed InteractiveSetting ABC representing a stored Enum. + The Enum is assumed to be data adapted (e.g. through RegisterEnum). + + The embed of an enum setting should usually be overridden to describe the options. + + The default widget is implemented as a select menu, + although it may also make sense to implement using colour-changing buttons. + + Options + ------- + _enum: Enum + The Enum to act as a setting interface to. + _outputs: dict[Enum, str] + A map of enum items to output strings. + Describes how the enum should be formatted. + _inputs: dict[Enum, str] + A map of accepted input strings (not case sensitive) to enum items. + This should almost always include the strings from `_outputs`. + """ + + _enum: Type[ET] + _outputs: dict[ET, LazyStr] + _input_patterns: dict[ET: LazyStr] + _input_formatted: dict[ET: LazyStr] + + _accepts = _p('settype:enum|accepts', "A valid option.") + + @property + def input_formatted(self) -> str: + """ + Return the output string for the current data. + This assumes the output strings are accepted as inputs! + """ + t = ctx_translator.get().t + if self._data is not None: + return t(self._input_formatted[self._data]) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Return the provided value enum item as the data enum item. + """ + return value + + @classmethod + def _data_to_value(cls, id, data, **kwargs): + """ + Return the provided data enum item as the value enum item. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into an enum item. + """ + if not string: + return None + + string = string.lower() + t = ctx_translator.get().t + + found = None + for enumitem, pattern in cls._input_patterns.items(): + item_keys = set(t(pattern).lower().split('|')) + if string in item_keys: + found = enumitem + break + + if not found: + raise UserInputError( + t(_p( + 'settype:enum|parse|error:not_found', + "`{provided}` is not a valid option!" + )).format(provided=string) + ) + + return found + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Format the enum using the provided output map. + """ + t = ctx_translator.get().t + if data is not None: + if data not in cls._outputs: + raise ValueError(f"Enum item {data} unmapped.") + return t(cls._outputs[data]) + + +class DurationSetting(InteractiveSetting[ParentID, int, int]): + """ + Typed InteractiveSetting ABC representing a stored duration. + Stored and retrieved as an integer number of seconds. + Shown and set as a "duration string", e.g. "24h 10m 20s". + + Options + ------- + _max: int + Upper limit on the stored duration, in seconds. + Default: 60 * 60 * 24 * 365 + _min: Optional[int] + Lower limit on the stored duration, in seconds. + The duration can never be negative. + _default_multiplier: int + Default multiplier to use to convert the number when it is provided alone. + E.g. 1 for seconds, or 60 for minutes. + Default: 1 + allow_zero: bool + Whether to allow a zero duration. + The duration parser typically returns 0 when no duration is found, + so this may be useful for error checking. + Default: False + _show_days: bool + Whether to show days in the formatted output. + Default: False + """ + + _accepts = _p( + 'settype:duration|accepts', + "A number of days, hours, minutes, and seconds, e.g. `2d 4h 10s`." + ) + + # Set an upper limit on the duration + _max = 60 * 60 * 24 * 365 + _min = None + + # Default multiplier when the number is provided alone + # 1 for seconds, 60 from minutes, etc + _default_multiplier = None + + # Whether to allow empty durations + # This is particularly useful since the duration parser will return 0 for most non-duration strings + allow_zero = False + + # Whether to show days on the output + _show_days = False + + @property + def input_formatted(self) -> str: + """ + Return the formatted duration, which is accepted as input. + """ + if self._data is not None: + return strfdur(self._data, short=True, show_days=self._show_days) + else: + return "" + + @classmethod + def _data_from_value(cls, parent_id: ParentID, value, **kwargs): + """ + Passthrough the provided duration in seconds. + """ + return value + + @classmethod + def _data_to_value(cls, parent_id: ParentID, data, **kwargs): + """ + Passthrough the provided duration in seconds. + """ + return data + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Parse the user input into a duration. + """ + if not string: + return None + + if cls._default_multiplier and string.isdigit(): + num = int(string) * cls._default_multiplier + else: + num = parse_duration(string) + + if num is None: + raise UserInputError("Could not parse the provided duration!") + + if num == 0 and not cls.allow_zero: + raise UserInputError( + "The provided duration cannot be `0`! (Please enter in the format `1d 2h 3m 4s`.)" + ) + + if cls._max is not None and num > cls._max: + raise UserInputError( + "Duration cannot be longer than `{}`!".format( + strfdur(cls._max, short=False, show_days=cls._show_days) + ) + ) + if cls._min is not None and num < cls._min: + raise UserInputError( + "Duration cannot be shorter than `{}`!".format( + strfdur(cls._min, short=False, show_days=cls._show_days) + ) + ) + + return num + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Format the enum using the provided output map. + """ + if data is not None: + return "`{}`".format(strfdur(data, short=False, show_days=cls._show_days)) + + +class ListSetting: + """ + Mixin to implement a setting type representing a list of existing settings. + + Does not implement a Widget, + since arbitrary combinations of setting widgets are undefined. + """ + # Base setting type to make the list from + _setting = None # type: Type[InteractiveSetting] + + # Whether 'None' values are filtered out of the data when creating values + _allow_null_values = False # type: bool + + # Whether duplicate data values should be filtered out + _force_unique = False + + @classmethod + def _data_from_value(cls, parent_id: ParentID, values, **kwargs): + """ + Returns the setting type data for each value in the value list + """ + if values is None: + # Special behaviour here, store an empty list instead of None + return [] + else: + return [cls._setting._data_from_value(parent_id, value) for value in values] + + @classmethod + def _data_to_value(cls, parent_id: ParentID, data, **kwargs): + """ + Returns the setting type value for each entry in the data list + """ + if data is None: + return [] + else: + values = [cls._setting._data_to_value(parent_id, entry) for entry in data] + + # Filter out null values if required + if not cls._allow_null_values: + values = [value for value in values if value is not None] + return values + + @classmethod + async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs): + """ + Splits the user string across `,` to break up the list. + """ + if not string: + return [] + else: + data = [] + items = (item.strip() for item in string.split(',')) + items = (item for item in items if item) + data = [await cls._setting._parse_string(parent_id, item, **kwargs) for item in items] + + if cls._force_unique: + data = list(set(data)) + return data + + @classmethod + def _format_data(cls, parent_id: ParentID, data, **kwargs): + """ + Format the list by adding `,` between each formatted item + """ + if data: + formatted_items = [] + for item in data: + formatted_item = cls._setting._format_data(id, item) + if formatted_item is not None: + formatted_items.append(formatted_item) + return ", ".join(formatted_items) + + @property + def input_formatted(self): + """ + Format the list by adding `,` between each input formatted item. + """ + if self._data: + formatted_items = [] + for item in self._data: + formatted_item = self._setting(self.parent_id, item).input_formatted + if formatted_item: + formatted_items.append(formatted_item) + return ", ".join(formatted_items) + else: + return "" + + +class ChannelListSetting(ListSetting, InteractiveSetting): + """ + List of channels + """ + _accepts = _p( + 'settype:channel_list|accepts', + "Comma separated list of channel ids." + ) + _setting = ChannelSetting + + +class RoleListSetting(ListSetting, InteractiveSetting): + """ + List of roles + """ + _accepts = _p( + 'settype:role_list|accepts', + 'Comma separated list of role ids.' + ) + _setting = RoleSetting + + @property + def members(self): + roles = self.value + return list(set(itertools.chain(*(role.members for role in roles)))) + + +class StringListSetting(InteractiveSetting, ListSetting): + """ + List of strings + """ + _accepts = _p( + 'settype:stringlist|accepts', + 'Comma separated strings.' + ) + _setting = StringSetting + + +class GuildIDListSetting(ListSetting, InteractiveSetting): + """ + List of guildids. + """ + _accepts = _p( + 'settype:guildidlist|accepts', + 'Comma separated list of guild ids.' + ) + + _setting = GuildIDSetting diff --git a/src/settings/ui.py b/src/settings/ui.py new file mode 100644 index 0000000..b2a263e --- /dev/null +++ b/src/settings/ui.py @@ -0,0 +1,512 @@ +from typing import Optional, Callable, Any, Dict, Coroutine, Generic, TypeVar, List +import asyncio +from contextvars import copy_context + +import discord +from discord import ui +from discord.ui.button import ButtonStyle, Button, button +from discord.ui.modal import Modal +from discord.ui.text_input import TextInput +from meta.errors import UserInputError + +from utils.lib import tabulate, recover_context +from utils.ui import FastModal +from meta.config import conf +from meta.context import ctx_bot +from babel.translator import ctx_translator, LazyStr + +from .base import BaseSetting, ParentID, SettingData, SettingValue +from . import babel + +_p = babel._p + + +ST = TypeVar('ST', bound='InteractiveSetting') + + +class SettingModal(FastModal): + input_field: TextInput = TextInput(label="Edit Setting") + + def update_field(self, new_field): + self.remove_item(self.input_field) + self.add_item(new_field) + self.input_field = new_field + + +class SettingWidget(Generic[ST], ui.View): + # TODO: Permission restrictions and callback! + # Context variables for permitted user(s)? Subclass ui.View with PermittedView? + # Don't need to descend permissions to Modal + # Maybe combine with timeout manager + + def __init__(self, setting: ST, auto_write=True, **kwargs): + self.setting = setting + self.update_children() + super().__init__(**kwargs) + self.auto_write = auto_write + + self._interaction: Optional[discord.Interaction] = None + self._modal: Optional[SettingModal] = None + self._exports: List[ui.Item] = self.make_exports() + + self._context = copy_context() + + def update_children(self): + """ + Method called before base View initialisation. + Allows updating the children components (usually explicitly defined callbacks), + before Item instantiation. + """ + pass + + def order_children(self, *children): + """ + Helper method to set and order the children using bound methods. + """ + child_map = {child.__name__: child for child in self.__view_children_items__} + self.__view_children_items__ = [child_map[child.__name__] for child in children] + + def update_child(self, child, new_args): + args = getattr(child, '__discord_ui_model_kwargs__') + args |= new_args + + def make_exports(self): + """ + Called post-instantiation to populate self._exports. + """ + return self.children + + def refresh(self): + """ + Update widget components from current setting data, if applicable. + E.g. to update the default entry in a select list after a choice has been made, + or update button colours. + This does not trigger a discord ui update, + that is the responsibility of the interaction handler. + """ + pass + + async def show(self, interaction: discord.Interaction, key: Any = None, override=False, **kwargs): + """ + Complete standard setting widget UI flow for this setting. + The SettingWidget components may be attached to other messages as needed, + and they may be triggered individually, + but this coroutine defines the standard interface. + Intended for use by any interaction which wants to "open the setting". + + Extra keyword arguments are passed directly to the interaction reply (for e.g. ephemeral). + """ + if key is None: + # By default, only have one widget listener per interaction. + key = ('widget', interaction.id) + + # If there is already a widget listening on this key, respect override + if self.setting.get_listener(key) and not override: + # Refuse to spawn another widget + return + + async def update_callback(new_data): + self.setting.data = new_data + await interaction.edit_original_response(embed=self.setting.embed, view=self, **kwargs) + + self.setting.register_callback(key)(update_callback) + await interaction.response.send_message(embed=self.setting.embed, view=self, **kwargs) + await self.wait() + try: + # Try and detach the view, since we aren't handling events anymore. + await interaction.edit_original_response(view=None) + except discord.HTTPException: + pass + self.setting.deregister_callback(key) + + def attach(self, group_view: ui.View): + """ + Attach this setting widget to a view representing several settings. + """ + for item in self._exports: + group_view.add_item(item) + + @button(style=ButtonStyle.secondary, label="Edit", row=4) + async def edit_button(self, interaction: discord.Interaction, button: ui.Button): + """ + Spawn a simple edit modal, + populated with `setting.input_field`. + """ + recover_context(self._context) + # Spawn the setting modal + await interaction.response.send_modal(self.modal) + + @button(style=ButtonStyle.danger, label="Reset", row=4) + async def reset_button(self, interaction: discord.Interaction, button: Button): + recover_context(self._context) + await interaction.response.defer(thinking=True, ephemeral=True) + await self.setting.interactive_set(None, interaction) + + @property + def modal(self) -> Modal: + """ + Build a Modal dialogue for updating the setting. + Refreshes (and re-attaches) the input field each time this is called. + """ + if self._modal is not None: + self._modal.update_field(self.setting.input_field) + return self._modal + + # TODO: Attach shared timeouts to the modal + self._modal = modal = SettingModal( + title=f"Edit {self.setting.display_name}", + ) + modal.update_field(self.setting.input_field) + + @modal.submit_callback() + async def edit_submit(interaction: discord.Interaction): + # TODO: Catch and handle UserInputError + await interaction.response.defer(thinking=True, ephemeral=True) + data = await self.setting._parse_string(self.setting.parent_id, modal.input_field.value) + await self.setting.interactive_set(data, interaction) + + return modal + + +class InteractiveSetting(BaseSetting[ParentID, SettingData, SettingValue]): + __slots__ = ('_widget',) + + # Configuration interface descriptions + _display_name: LazyStr # User readable name of the setting + _desc: LazyStr # User readable brief description of the setting + _long_desc: LazyStr # User readable long description of the setting + _accepts: LazyStr # User readable description of the acceptable values + _set_cmd: str = None + _notset_str: LazyStr = _p('setting|formatted|notset', "Not Set") + _virtual: bool = False # Whether the setting should be hidden from tables and dashboards + _required: bool = False + + Widget = SettingWidget + + # A list of callback coroutines to call when the setting updates + # This can be used globally to refresh state when the setting updates, + # Or locallly to e.g. refresh an active widget. + # The callbacks are called on write, so they may be bypassed by direct use of _writer! + _listeners_: Dict[Any, Callable[[Optional[SettingData]], Coroutine[Any, Any, None]]] = {} + + # Optional client event to dispatch when theis setting has been written + # Event handlers should be of the form Callable[ParentID, SettingData] + _event: Optional[str] = None + + # Interaction ward that should be validated via interaction_check + _write_ward: Optional[Callable[[discord.Interaction], Coroutine[Any, Any, bool]]] = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._widget: Optional[SettingWidget] = None + + @property + def long_desc(self): + t = ctx_translator.get().t + bot = ctx_bot.get() + return t(self._long_desc).format( + bot=bot, + cmds=bot.core.mention_cache + ) + + @property + def display_name(self): + t = ctx_translator.get().t + return t(self._display_name) + + @property + def desc(self): + t = ctx_translator.get().t + return t(self._desc) + + @property + def accepts(self): + t = ctx_translator.get().t + return t(self._accepts) + + async def write(self, **kwargs) -> None: + await super().write(**kwargs) + self.dispatch_update() + for listener in self._listeners_.values(): + asyncio.create_task(listener(self.data)) + + def dispatch_update(self): + """ + Dispatch a client event along `self._event`, if set. + + Override to modify the target event handler arguments. + By default, event handlers should be of the form: + Callable[[ParentID, SettingData], Coroutine[Any, Any, None]] + """ + if self._event is not None and (bot := ctx_bot.get()) is not None: + bot.dispatch(self._event, self.parent_id, self) + + def get_listener(self, key): + return self._listeners_.get(key, None) + + @classmethod + def register_callback(cls, name=None): + def wrapped(coro): + cls._listeners_[name or coro.__name__] = coro + return coro + return wrapped + + @classmethod + def deregister_callback(cls, name): + cls._listeners_.pop(name, None) + + @property + def update_message(self): + """ + Response message sent when the setting has successfully been updated. + Should generally be one line. + """ + if self.data is None: + return "Setting reset!" + else: + return f"Setting Updated! New value: {self.formatted}" + + @property + def hover_desc(self): + """ + This no longer works since Discord changed the hover rules. + + return '\n'.join(( + self.display_name, + '=' * len(self.display_name), + self.desc, + f"\nAccepts: {self.accepts}" + )) + """ + return self.desc + + async def update_response(self, interaction: discord.Interaction, message: Optional[str] = None, **kwargs): + """ + Respond to an interaction which triggered a setting update. + Usually just wraps `update_message` in an embed and sends it back. + Passes any extra `kwargs` to the message creation method. + """ + embed = discord.Embed( + description=f"{str(conf.emojis.tick)} {message or self.update_message}", + colour=discord.Color.green() + ) + if interaction.response.is_done(): + await interaction.edit_original_response(embed=embed, **kwargs) + else: + await interaction.response.send_message(embed=embed, **kwargs) + + async def interactive_set(self, new_data: Optional[SettingData], interaction: discord.Interaction, **kwargs): + self.data = new_data + await self.write() + await self.update_response(interaction, **kwargs) + + async def format_in(self, bot, **kwargs): + """ + Formatted version of the setting given an asynchronous context with client. + """ + return self.formatted + + @property + def embed_field(self): + """ + Returns a {name, value} pair for use in an Embed field. + """ + name = self.display_name + value = f"{self.long_desc}\n{self.desc_table}" + if len(value) > 1024: + t = ctx_translator.get().t + desc_table = '\n'.join( + tabulate( + *self._desc_table( + show_value=t(_p( + 'setting|embed_field|too_long', + "Too long to display here!" + )) + ) + ) + ) + value = f"{self.long_desc}\n{desc_table}" + if len(value) > 1024: + # Forcibly trim + value = value[:1020] + '...' + return {'name': name, 'value': value} + + @property + def set_str(self): + if self._set_cmd is not None: + bot = ctx_bot.get() + if bot: + return bot.core.mention_cmd(self._set_cmd) + else: + return f"`/{self._set_cmd}`" + + @property + def notset_str(self): + t = ctx_translator.get().t + return t(self._notset_str) + + @property + def embed(self): + """ + Returns a full embed describing this setting. + """ + t = ctx_translator.get().t + embed = discord.Embed( + title=t(_p( + 'setting|summary_embed|title', + "Configuration options for `{name}`" + )).format(name=self.display_name), + ) + embed.description = "{}\n{}".format(self.long_desc.format(self=self), self.desc_table) + return embed + + def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]: + t = ctx_translator.get().t + lines = [] + + # Currently line + lines.append(( + t(_p('setting|summary_table|field:currently|key', "Currently")), + show_value or (self.formatted or self.notset_str) + )) + + # Default line + if (default := self.default) is not None: + lines.append(( + t(_p('setting|summary_table|field:default|key', "By Default")), + self._format_data(self.parent_id, default) or 'None' + )) + + # Set using line + if (set_str := self.set_str) is not None: + lines.append(( + t(_p('setting|summary_table|field:set|key', "Set Using")), + set_str + )) + return lines + + @property + def desc_table(self) -> str: + return '\n'.join(tabulate(*self._desc_table())) + + @property + def input_field(self) -> TextInput: + """ + TextInput field used for string-based setting modification. + May be added to external modal for grouped setting editing. + This property is not persistent, and creates a new field each time. + """ + return TextInput( + label=self.display_name, + placeholder=self.accepts, + default=self.input_formatted[:4000] if self.input_formatted else None, + required=self._required + ) + + @property + def widget(self): + """ + Returns the Discord UI View associated with the current setting. + """ + if self._widget is None: + self._widget = self.Widget(self) + return self._widget + + @classmethod + def set_widget(cls, WidgetCls): + """ + Convenience decorator to create the widget class for this setting. + """ + cls.Widget = WidgetCls + return WidgetCls + + @property + def formatted(self): + """ + Default user-readable form of the setting. + Should be a short single line. + """ + return self._format_data(self.parent_id, self.data, **self.kwargs) or self.notset_str + + @property + def input_formatted(self) -> str: + """ + Format the current value as a default value for an input field. + Returned string must be acceptable through parse_string. + Does not take into account defaults. + """ + if self._data is not None: + return str(self._data) + else: + return "" + + @property + def summary(self): + """ + Formatted summary of the data. + May be implemented in `_format_data(..., summary=True, ...)` or overidden. + """ + return self._format_data(self.parent_id, self.data, summary=True, **self.kwargs) + + @classmethod + async def from_string(cls, parent_id, userstr: str, **kwargs): + """ + Return a setting instance initialised from a parsed user string. + """ + data = await cls._parse_string(parent_id, userstr, **kwargs) + return cls(parent_id, data, **kwargs) + + @classmethod + async def from_value(cls, parent_id, value, **kwargs): + await cls._check_value(parent_id, value, **kwargs) + data = cls._data_from_value(parent_id, value, **kwargs) + return cls(parent_id, data, **kwargs) + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs) -> Optional[SettingData]: + """ + Parse user provided string (usually from a TextInput) into raw setting data. + Must be overriden by the setting if the setting is user-configurable. + Returns None if the setting was unset. + """ + raise NotImplementedError + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + """ + Convert raw setting data into a formatted user-readable string, + representing the current value. + """ + raise NotImplementedError + + @classmethod + async def _check_value(cls, parent_id, value, **kwargs): + """ + Check the provided value is valid. + + Many setting update methods now provide Discord objects instead of raw data or user strings. + This method may be used for value-checking such a value. + + Raises UserInputError if the value fails validation. + """ + pass + + @classmethod + async def interaction_check(cls, parent_id, interaction: discord.Interaction, **kwargs): + if cls._write_ward is not None and not await cls._write_ward(interaction): + # TODO: Combine the check system so we can do customised errors here + t = ctx_translator.get().t + raise UserInputError(t(_p( + 'setting|interaction_check|error', + "You do not have sufficient permissions to do this!" + ))) + + +""" +command callback for set command? +autocomplete for set command? + +Might be better in a ConfigSetting subclass. +But also mix into the base setting types. +""" diff --git a/src/utils/lib.py b/src/utils/lib.py index 30657ed..d7796fc 100644 --- a/src/utils/lib.py +++ b/src/utils/lib.py @@ -845,3 +845,35 @@ def write_records(records: list[dict[str, Any]], stream: StringIO): for record in records: stream.write(','.join(map(str, record.values()))) stream.write('\n') + + +parse_dur_exps = [ + ( + r"(?P\d+)\s*(?:(d)|(day))", + 60 * 60 * 24, + ), + ( + r"(?P\d+)\s*(?:(h)|(hour))", + 60 * 60 + ), + ( + r"(?P\d+)\s*(?:(m)|(min))", + 60 + ), + ( + r"(?P\d+)\s*(?:(s)|(sec))", + 1 + ) +] + + +def parse_duration(string: str) -> Optional[int]: + seconds = 0 + found = False + for expr, multiplier in parse_dur_exps: + match = re.search(expr, string, flags=re.IGNORECASE) + if match: + found = True + seconds += int(match.group('value')) * multiplier + + return seconds if found else None diff --git a/src/utils/ui/__init__.py b/src/utils/ui/__init__.py index 633b3a9..fd28a5b 100644 --- a/src/utils/ui/__init__.py +++ b/src/utils/ui/__init__.py @@ -1,8 +1,12 @@ import asyncio import logging +from babel.translator import LocalBabel logger = logging.getLogger(__name__) +util_babel = LocalBabel('utils') + from .hooked import * from .leo import * from .micros import * +from .msgeditor import * diff --git a/src/utils/ui/msgeditor.py b/src/utils/ui/msgeditor.py new file mode 100644 index 0000000..5e1cca2 --- /dev/null +++ b/src/utils/ui/msgeditor.py @@ -0,0 +1,1057 @@ +from typing import Optional +import asyncio +import copy +import json +import datetime as dt +from io import StringIO + +import discord +from discord.ui.button import button, Button, ButtonStyle +from discord.ui.select import select, Select, SelectOption +from discord.ui.text_input import TextInput, TextStyle + +from meta import conf, LionBot +from meta.errors import UserInputError, ResponseTimedOut + +from ..lib import MessageArgs, utc_now + +from . import MessageUI, util_babel, error_handler_for, FastModal, ModalRetryUI, Confirm, AsComponents, AButton + + +_p = util_babel._p + + +class MsgEditorInput(FastModal): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @error_handler_for(UserInputError) + async def rerequest(self, interaction, error): + await ModalRetryUI(self, error.msg).respond_to(interaction) + + +class MsgEditor(MessageUI): + def __init__(self, bot: LionBot, initial_data: dict, formatter=None, callback=None, **kwargs): + self.bot = bot + self.history = [initial_data] # Last item in history is current state + self.future = [] # Last item in future is next state + + self._formatter = formatter + self._callback = callback + + super().__init__(**kwargs) + + @property + def data(self): + return self.history[-1] + + # ----- API ----- + async def format_data(self, data): + """ + Format a MessageData dict for rendering. + + May be extended or overridden for custom formatting. + By default, uses the provided `formatter` callback (if provided). + """ + if self._formatter is not None: + await self._formatter(data) + + def copy_data(self): + return copy.deepcopy(self.history[-1]) + + async def save(self): + ... + + async def push_change(self, new_data): + # Cleanup the data + if (embed_data := new_data.get('embed', None)) is not None and not embed_data: + new_data.pop('embed') + + t = self.bot.translator.t + if 'embed' not in new_data and not new_data.get('content', None): + raise UserInputError( + t(_p( + 'ui:msg_editor|error:empty', + "Rendering failed! The message content and embed cannot both be empty." + )) + ) + + if 'embed' in new_data: + try: + discord.Embed.from_dict(new_data['embed']) + except Exception as e: + raise UserInputError( + t(_p( + 'ui:msg_editor|error:embed_failed', + "Rendering failed! Could not parse the embed.\n" + "Error: {error}" + )).format(error=str(e)) + ) + + # Push the state and try displaying it + self.history.append(new_data) + old_future = self.future + self.future = [] + try: + await self.refresh() + except discord.HTTPException as e: + # State failed, rollback and error + self.history.pop() + self.future = old_future + raise UserInputError( + t(_p( + 'ui:msg_editor|error:invalid_change', + "Rendering failed! The message was not modified.\n" + "Error: `{text}`" + )).format(text=e.text) + ) + + # ----- UI Components ----- + + # -- Content Only mode -- + @button(label="EDIT_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def edit_button(self, press: discord.Interaction, pressed: Button): + """ + Open an editor for the message content + """ + data = self.copy_data() + + t = self.bot.translator.t + content_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:content|field:content|label', + "Message Content" + )), + style=TextStyle.long, + required=False, + default=data.get('content', ""), + max_length=2000 + ) + modal = MsgEditorInput( + content_field, + title=t(_p('ui:msg_editor|modal:content|title', "Content Editor")) + ) + + @modal.submit_callback() + async def content_modal_callback(interaction: discord.Interaction): + new_content = content_field.value + data['content'] = new_content + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def edit_button_refresh(self): + t = self.bot.translator.t + button = self.edit_button + button.label = t(_p( + 'ui:msg_editor|button:edit|label', + "Edit Content" + )) + + @button(label="ADD_EMBED_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def add_embed_button(self, press: discord.Interaction, pressed: Button): + """ + Attach an embed with some simple fields filled. + """ + await press.response.defer() + t = self.bot.translator.t + + sample_embed = { + "title": t(_p('ui:msg_editor|button:add_embed|sample_embed|title', "Title Placeholder")), + "description": t(_p('ui:msg_editor|button:add_embed|sample_embed|description', "Description Placeholder")), + } + data = self.copy_data() + data['embed'] = sample_embed + await self.push_change(data) + + async def add_embed_button_refresh(self): + t = self.bot.translator.t + button = self.add_embed_button + button.label = t(_p( + 'ui:msg_editor|button:add_embed|label', + "Add Embed" + )) + + @button(label="RM_EMBED_BUTTON_PLACEHOLDER", style=ButtonStyle.red) + async def rm_embed_button(self, press: discord.Interaction, pressed: Button): + """ + Remove the existing embed from the message. + """ + await press.response.defer() + t = self.bot.translator.t + data = self.copy_data() + data.pop('embed', None) + data.pop('embeds', None) + if not data.get('content', '').strip(): + data['content'] = t(_p( + 'ui:msg_editor|button:rm_embed|sample_content', + "Content Placeholder" + )) + await self.push_change(data) + + async def rm_embed_button_refresh(self): + t = self.bot.translator.t + button = self.rm_embed_button + button.label = t(_p( + 'ui:msg_editor|button:rm_embed|label', + "Remove Embed" + )) + + # -- Embed Mode -- + + @button(label="BODY_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def body_button(self, press: discord.Interaction, pressed: Button): + """ + Edit the Content, Description, Title, and Colour + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + + t = self.bot.translator.t + + content_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:body|field:content|label', + "Message Content" + )), + style=TextStyle.long, + required=False, + default=data.get('content', ""), + max_length=2000 + ) + + desc_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:body|field:desc|label', + "Embed Description" + )), + style=TextStyle.long, + required=False, + default=embed_data.get('description', ""), + max_length=4000 + ) + + title_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:body|field:title|label', + "Embed Title" + )), + style=TextStyle.short, + required=False, + default=embed_data.get('title', ""), + max_length=256 + ) + + colour_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:body|field:colour|label', + "Embed Colour" + )), + style=TextStyle.short, + required=False, + default=str(discord.Colour(value=embed_data['color'])) if 'color' in embed_data else '', + placeholder=str(discord.Colour.orange()), + max_length=7, + min_length=7 + ) + + modal = MsgEditorInput( + content_field, + title_field, + desc_field, + colour_field, + title=t(_p('ui:msg_editor|modal:body|title', "Message Body Editor")) + ) + + @modal.submit_callback() + async def body_modal_callback(interaction: discord.Interaction): + data['content'] = content_field.value + + if desc_field.value: + embed_data['description'] = desc_field.value + else: + embed_data.pop('description', None) + + if title_field.value: + embed_data['title'] = title_field.value + else: + embed_data.pop('title', None) + + if colour_field.value: + colourstr = colour_field.value + try: + colour = discord.Colour.from_str(colourstr) + except ValueError: + raise UserInputError( + t(_p( + 'ui:msg_editor|button:body|error:invalid_colour', + "Invalid colour format! Please enter colours as hex codes, e.g. `#E67E22`" + )) + ) + embed_data['color'] = colour.value + else: + embed_data.pop('color', None) + + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def body_button_refresh(self): + t = self.bot.translator.t + button = self.body_button + button.label = t(_p( + 'ui:msg_editor|button:body|label', + "Body" + )) + + @button(label="AUTHOR_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def author_button(self, press: discord.Interaction, pressed: Button): + """ + Edit the embed author (author name/link/image url) + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + author_data = embed_data.get('author', {}) + + t = self.bot.translator.t + + name_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:author|field:name|label', + "Author Name" + )), + style=TextStyle.short, + required=False, + default=author_data.get('name', ''), + max_length=256 + ) + + link_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:author|field:link|label', + "Author URL" + )), + style=TextStyle.short, + required=False, + default=author_data.get('url', ''), + ) + + image_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:author|field:image|label', + "Author Image URL" + )), + style=TextStyle.short, + required=False, + default=author_data.get('icon_url', ''), + ) + + modal = MsgEditorInput( + name_field, + link_field, + image_field, + title=t(_p('ui:msg_editor|modal:author|title', "Embed Author Editor")) + ) + + @modal.submit_callback() + async def author_modal_callback(interaction: discord.Interaction): + if (name := name_field.value): + author_data['name'] = name + author_data['icon_url'] = image_field.value + author_data['url'] = link_field.value + embed_data['author'] = author_data + else: + embed_data.pop('author', None) + + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def author_button_refresh(self): + t = self.bot.translator.t + button = self.author_button + button.label = t(_p( + 'ui:msg_editor|button:author|label', + "Author" + )) + + @button(label="FOOTER_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def footer_button(self, press: discord.Interaction, pressed: Button): + """ + Open the Footer editor (edit footer icon, text, timestamp). + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + footer_data = embed_data.get('footer', {}) + + t = self.bot.translator.t + + text_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:footer|field:text|label', + "Footer Text" + )), + style=TextStyle.long, + required=False, + default=footer_data.get('text', ''), + max_length=2048 + ) + + image_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:footer|field:image|label', + "Footer Image URL" + )), + style=TextStyle.short, + required=False, + default=footer_data.get('icon_url', ''), + ) + + timestamp_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:footer|field:timestamp|label', + "Embed Timestamp (in ISO format)" + )), + style=TextStyle.short, + required=False, + default=embed_data.get('timestamp', ''), + placeholder=utc_now().replace(microsecond=0).isoformat(sep=' ') + ) + + modal = MsgEditorInput( + text_field, + image_field, + timestamp_field, + title=t(_p('ui:msg_editor|modal:footer|title', "Embed Footer Editor")) + ) + + @modal.submit_callback() + async def footer_modal_callback(interaction: discord.Interaction): + if (text := text_field.value): + footer_data['text'] = text + footer_data['icon_url'] = image_field.value + embed_data['footer'] = footer_data + else: + embed_data.pop('footer', None) + + if (ts := timestamp_field.value): + try: + dt.datetime.fromisoformat(ts) + except ValueError: + raise UserInputError( + t(_p( + 'ui:msg_editor|button:footer|error:invalid_timestamp', + "Invalid timestamp! Please enter the timestamp in ISO format." + )) + ) + embed_data['timestamp'] = ts + else: + embed_data.pop('timestamp', None) + + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def footer_button_refresh(self): + t = self.bot.translator.t + button = self.footer_button + button.label = t(_p( + 'ui:msg_editor|button:footer|label', + "Footer" + )) + + @button(label="IMAGES_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def images_button(self, press: discord.Interaction, pressed: Button): + """ + Edit the embed images (thumbnail and main image). + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + thumb_data = embed_data.get('thumbnail', {}) + image_data = embed_data.get('image', {}) + + t = self.bot.translator.t + + thumb_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:images|field:thumb|label', + "Thumbnail Image URL" + )), + style=TextStyle.short, + required=False, + default=thumb_data.get('url', ''), + ) + + image_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:images|field:image|label', + "Embed Image URL" + )), + style=TextStyle.short, + required=False, + default=image_data.get('url', ''), + ) + + modal = MsgEditorInput( + thumb_field, + image_field, + title=t(_p('ui:msg_editor|modal:images|title', "Embed images Editor")) + ) + + @modal.submit_callback() + async def images_modal_callback(interaction: discord.Interaction): + if (thumb_url := thumb_field.value): + thumb_data['url'] = thumb_url + embed_data['thumbnail'] = thumb_data + else: + embed_data.pop('thumbnail', None) + + if (image_url := image_field.value): + image_data['url'] = image_url + embed_data['image'] = image_data + else: + embed_data.pop('image', None) + + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def images_button_refresh(self): + t = self.bot.translator.t + button = self.images_button + button.label = t(_p( + 'ui:msg_editor|button:images|label', + "Images" + )) + + @button(label="ADD_FIELD_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def add_field_button(self, press: discord.Interaction, pressed: Button): + """ + Add an embed field (position, name, value, inline) + """ + data = self.copy_data() + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + orig_fields = field_data.copy() + + t = self.bot.translator.t + + position_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:add_field|field:position|label', + "Field number to insert at" + )), + style=TextStyle.short, + required=True, + default=str(len(field_data)), + ) + + name_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:add_field|field:name|label', + "Field name" + )), + style=TextStyle.short, + required=False, + max_length=256, + ) + + value_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:add_field|field:value|label', + "Field value" + )), + style=TextStyle.long, + required=True, + max_length=1024, + ) + + inline_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:add_field|field:inline|label', + "Whether the field is inline" + )), + placeholder=t(_p( + 'ui:msg_editor|modal:add_field|field:inline|placeholder', + "True/False" + )), + style=TextStyle.short, + required=True, + max_length=256, + default='True', + ) + + modal = MsgEditorInput( + name_field, + value_field, + position_field, + inline_field, + title=t(_p('ui:msg_editor|modal:add_field|title', "Add Embed Field")) + ) + + @modal.submit_callback() + async def add_field_modal_callback(interaction: discord.Interaction): + if inline_field.value.lower() == 'true': + inline = True + else: + inline = False + field = { + 'name': name_field.value, + 'value': value_field.value, + 'inline': inline + } + try: + position = int(position_field.value) + except ValueError: + raise UserInputError( + t(_p( + 'ui:msg_editor|modal:add_field|error:position_not_int', + "The field position must be an integer!" + )) + ) + field_data = orig_fields.copy() + field_data.insert(position, field) + embed_data['fields'] = field_data + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await press.response.send_modal(modal) + + async def add_field_button_refresh(self): + t = self.bot.translator.t + button = self.add_field_button + button.label = t(_p( + 'ui:msg_editor|button:add_field|label', + "Add Field" + )) + data = self.history[-1] + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + button.disabled = (len(field_data) >= 25) + + def _field_option(self, index, field_data): + t = self.bot.translator.t + + name = field_data.get('name', "") + value = field_data['value'] + + if not name: + name = t(_p( + 'ui:msg_editor|format_field|name_placeholder', + "-" + )) + + name = f"{index+1}. {name}" + if len(name) > 100: + name = name[:97] + '...' + + if len(value) > 100: + value = value[:97] + '...' + + return SelectOption(label=name, description=value, value=str(index)) + + @select(cls=Select, placeholder="EDIT_FIELD_MENU_PLACEHOLDER", max_values=1) + async def edit_field_menu(self, selection: discord.Interaction, selected: Select): + if not selected.values: + await selection.response.defer() + return + + index = int(selected.values[0]) + data = self.copy_data() + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + field = field_data[index] + + t = self.bot.translator.t + + name_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:edit_field|field:name|label', + "Field name" + )), + style=TextStyle.short, + default=field.get('name', ''), + required=False, + max_length=256, + ) + + value_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:edit_field|field:value|label', + "Field value" + )), + style=TextStyle.long, + default=field.get('value', ''), + required=True, + max_length=1024, + ) + + inline_field = TextInput( + label=t(_p( + 'ui:msg_editor|modal:edit_field|field:inline|label', + "Whether the field is inline" + )), + placeholder=t(_p( + 'ui:msg_editor|modal:edit_field|field:inline|placeholder', + "True/False" + )), + default='True' if field.get('inline', True) else 'False', + style=TextStyle.short, + required=True, + max_length=256, + ) + + modal = MsgEditorInput( + name_field, + value_field, + inline_field, + title=t(_p('ui:msg_editor|modal:edit_field|title', "Edit Embed Field")) + ) + + @modal.submit_callback() + async def edit_field_modal_callback(interaction: discord.Interaction): + if inline_field.value.lower() == 'true': + inline = True + else: + inline = False + field = { + 'name': name_field.value, + 'value': value_field.value, + 'inline': inline + } + field_data[index] = field + embed_data['fields'] = field_data + data['embed'] = embed_data + + await self.push_change(data) + + await interaction.response.defer() + + await selection.response.send_modal(modal) + + async def edit_field_menu_refresh(self): + t = self.bot.translator.t + menu = self.edit_field_menu + menu.placeholder = t(_p( + 'ui:msg_editor|menu:edit_field|placeholder', + "Edit Embed Field" + )) + data = self.history[-1] + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + + if len(field_data) == 0: + menu.disabled = True + menu.options = [ + SelectOption(label='Dummy') + ] + else: + menu.disabled = False + menu.options = [ + self._field_option(i, field) + for i, field in enumerate(field_data) + ] + + @select(cls=Select, placeholder="DELETE_FIELD_MENU_PLACEHOLDER", max_values=1) + async def delete_field_menu(self, selection: discord.Interaction, selected: Select): + if not selected.values: + await selection.response.defer() + return + + index = int(selected.values[0]) + data = self.copy_data() + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + field_data.pop(index) + if not field_data: + embed_data.pop('fields') + await self.push_change(data) + await selection.response.defer() + + async def delete_field_menu_refresh(self): + t = self.bot.translator.t + menu = self.delete_field_menu + menu.placeholder = t(_p( + 'ui:msg_deleteor|menu:delete_field|placeholder', + "Remove Embed Field" + )) + data = self.history[-1] + embed_data = data.get('embed', {}) + field_data = embed_data.get('fields', []) + + if len(field_data) == 0: + menu.disabled = True + menu.options = [ + SelectOption(label='Dummy') + ] + else: + menu.disabled = False + menu.options = [ + self._field_option(i, field) + for i, field in enumerate(field_data) + ] + + # -- Shared -- + @button(label="SAVE_BUTTON_PLACEHOLDER", style=ButtonStyle.green) + async def save_button(self, press: discord.Interaction, pressed: Button): + """ + Saving simply resets the undo stack and calls the callback function. + Presumably the callback is hooked up to data or similar. + """ + await press.response.defer(thinking=True, ephemeral=True) + if self._callback is not None: + await self._callback(self.data) + self.history = self.history[-1:] + await self.refresh(thinking=press) + + async def save_button_refresh(self): + t = self.bot.translator.t + button = self.save_button + button.label = t(_p( + 'ui:msg_editor|button:save|label', + "Save" + )) + if len(self.history) > 1: + original = json.dumps(self.history[0]) + current = json.dumps(self.history[-1]) + button.disabled = (original == current) + else: + button.disabled = True + + @button(label="DOWNLOAD_BUTTON_PLACEHOLDER", style=ButtonStyle.grey) + async def download_button(self, press: discord.Interaction, pressed: Button): + """ + Reply ephemerally with a formatted json version of the message content. + """ + data = json.dumps(self.history[-1], indent=2) + with StringIO(data) as fp: + fp.seek(0) + file = discord.File(fp, filename='message.json') + await press.response.send_message(file=file, ephemeral=True) + + async def download_button_refresh(self): + t = self.bot.translator.t + button = self.download_button + button.label = t(_p( + 'ui:msg_editor|button:download|label', + "Download" + )) + + @button(label="UNDO_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def undo_button(self, press: discord.Interaction, pressed: Button): + """ + Pop the history stack. + """ + if len(self.history) > 1: + state = self.history.pop() + self.future.append(state) + await press.response.defer() + await self.refresh() + + async def undo_button_refresh(self): + t = self.bot.translator.t + button = self.undo_button + button.label = t(_p( + 'ui:msg_editor|button:undo|label', + "Undo" + )) + button.disabled = (len(self.history) <= 1) + + @button(label="REDO_BUTTON_PLACEHOLDER", style=ButtonStyle.blurple) + async def redo_button(self, press: discord.Interaction, pressed: Button): + """ + Pop the future stack. + """ + if len(self.future) > 0: + state = self.future.pop() + self.history.append(state) + await press.response.defer() + await self.refresh() + + async def redo_button_refresh(self): + t = self.bot.translator.t + button = self.redo_button + button.label = t(_p( + 'ui:msg_editor|button:redo|label', + "Redo" + )) + button.disabled = (len(self.future) == 0) + + @button(style=ButtonStyle.grey, emoji=conf.emojis.cancel) + async def quit_button(self, press: discord.Interaction, pressed: Button): + # Confirm quit if there are unsaved changes + unsaved = False + if len(self.history) > 1: + original = json.dumps(self.history[0]) + current = json.dumps(self.history[-1]) + if original != current: + unsaved = True + + # Confirmation prompt + if unsaved: + t = self.bot.translator.t + confirm_msg = t(_p( + 'ui:msg_editor|button:quit|confirm', + "You have unsaved changes! Are you sure you want to quit?" + )) + confirm = Confirm(confirm_msg, self._callerid) + confirm.confirm_button.label = t(_p( + 'ui:msg_editor|button:quit|confirm|button:yes', + "Yes, Quit Now" + )) + confirm.confirm_button.style = ButtonStyle.red + confirm.cancel_button.style = ButtonStyle.green + confirm.cancel_button.label = t(_p( + 'ui:msg_editor|button:quit|confirm|button:no', + "No, Go Back" + )) + try: + result = await confirm.ask(press, ephemeral=True) + except ResponseTimedOut: + result = False + + if result: + await self.quit() + else: + await self.quit() + + # ----- UI Flow ----- + async def make_message(self) -> MessageArgs: + data = self.copy_data() + await self.format_data(data) + + args = {} + args['content'] = data.get('content', '') + + if 'embed' in data: + args['embed'] = discord.Embed.from_dict(data['embed']) + else: + args['embed'] = None + + return MessageArgs(**args) + + async def refresh_layout(self): + to_refresh = ( + self.edit_button_refresh(), + self.add_embed_button_refresh(), + self.body_button_refresh(), + self.author_button_refresh(), + self.footer_button_refresh(), + self.images_button_refresh(), + self.add_field_button_refresh(), + self.edit_field_menu_refresh(), + self.delete_field_menu_refresh(), + self.save_button_refresh(), + self.download_button_refresh(), + self.undo_button_refresh(), + self.redo_button_refresh(), + self.rm_embed_button_refresh(), + ) + await asyncio.gather(*to_refresh) + + if self.history[-1].get('embed', None): + self.set_layout( + (self.body_button, self.author_button, self.footer_button, self.images_button, self.add_field_button), + (self.edit_field_menu,), + (self.delete_field_menu,), + (self.rm_embed_button,), + (self.save_button, self.download_button, self.undo_button, self.redo_button, self.quit_button), + ) + else: + self.set_layout( + (self.edit_button, self.add_embed_button), + (self.save_button, self.download_button, self.undo_button, self.redo_button, self.quit_button), + ) + + async def reload(self): + # All data is handled by components, so nothing to do here + pass + + async def redraw(self, thinking: Optional[discord.Interaction] = None): + """ + Overriding MessageUI.redraw to propagate exception. + """ + await self.refresh_layout() + args = await self.make_message() + + if thinking is not None and not thinking.is_expired() and thinking.response.is_done(): + asyncio.create_task(thinking.delete_original_response()) + + if self._original and not self._original.is_expired(): + await self._original.edit_original_response(**args.edit_args, view=self) + elif self._message: + await self._message.edit(**args.edit_args, view=self) + else: + # Interaction expired or already closed. Quietly cleanup. + await self.close() + + async def pre_timeout(self): + unsaved = False + if len(self.history) > 1: + original = json.dumps(self.history[0]) + current = json.dumps(self.history[-1]) + if original != current: + unsaved = True + + # Timeout confirmation + if unsaved: + t = self.bot.translator.t + grace_period = 60 + grace_time = utc_now() + dt.timedelta(seconds=grace_period) + embed = discord.Embed( + title=t(_p( + 'ui:msg_editor|timeout_warning|title', + "Warning!" + )), + description=t(_p( + 'ui:msg_editor|timeout_warning|desc', + "This interface will time out {timestamp}. Press 'Continue' below to keep editing." + )).format( + timestamp=discord.utils.format_dt(grace_time, style='R') + ), + ) + + components = None + stopped = False + + @AButton(label=t(_p('ui:msg_editor|timeout_warning|continue', "Continue")), style=ButtonStyle.green) + async def cont_button(interaction: discord.Interaction, pressed): + await interaction.response.defer() + await interaction.message.delete() + nonlocal stopped + stopped = True + # TODO: Clean up this mess. It works, but needs to be refactored to a timeout confirmation mixin. + # TODO: Consider moving the message to the interaction response + self._refresh_timeout() + components.stop() + + components = AsComponents(cont_button, timeout=grace_period) + message = await self._original.channel.send(content=f"<@{self._callerid}>", embed=embed, view=components) + await components.wait() + + if not stopped: + try: + await message.delete() + except discord.HTTPException: + pass