From 5bd05a84a99ecd271e192acd4a1db4c340ac70a5 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 30 Nov 2022 16:57:26 +0200 Subject: [PATCH] rewrite: Refactor ui utils, add pagers. --- bot/utils/__init__.py | 5 + bot/utils/cog.py | 102 +++++++ bot/utils/ui.py | 526 ------------------------------------ bot/utils/ui/__init__.py | 20 ++ bot/utils/ui/hooked.py | 46 ++++ bot/utils/ui/leo.py | 247 +++++++++++++++++ bot/utils/ui/micros.py | 315 +++++++++++++++++++++ bot/utils/ui/pagers.py | 456 +++++++++++++++++++++++++++++++ bot/utils/ui/transformed.py | 91 +++++++ 9 files changed, 1282 insertions(+), 526 deletions(-) create mode 100644 bot/utils/cog.py delete mode 100644 bot/utils/ui.py create mode 100644 bot/utils/ui/__init__.py create mode 100644 bot/utils/ui/hooked.py create mode 100644 bot/utils/ui/leo.py create mode 100644 bot/utils/ui/micros.py create mode 100644 bot/utils/ui/pagers.py create mode 100644 bot/utils/ui/transformed.py diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py index ccf9ca7a..513c5034 100644 --- a/bot/utils/__init__.py +++ b/bot/utils/__init__.py @@ -1,3 +1,8 @@ from babel.translator import LocalBabel util_babel = LocalBabel('utils') + + +async def setup(bot): + from .cog import MetaUtils + await bot.add_cog(MetaUtils(bot)) diff --git a/bot/utils/cog.py b/bot/utils/cog.py new file mode 100644 index 00000000..8a1403cf --- /dev/null +++ b/bot/utils/cog.py @@ -0,0 +1,102 @@ +import discord +from discord.ext import commands as cmds +from discord import app_commands as appcmds + +from meta import LionBot, LionContext, LionCog +from .ui import BasePager + +from . import util_babel as babel + +_p = babel._p + + +class MetaUtils(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + + @cmds.hybrid_group( + name=_p('cmd:page', 'page'), + description=_p( + 'cmd:page|desc', + "Jump to a given page of the ouput of a previous command in this channel." + ), + ) + async def page_group(self, ctx: LionContext): + """ + No description. + """ + pass + + async def page_jump(self, ctx: LionContext, jumper): + pager = BasePager.get_active_pager(ctx.channel.id, ctx.author.id) + if pager is None: + await ctx.error_reply( + _p('cmd:page|error:no_pager', "No pager listening in this channel!") + ) + else: + if ctx.interaction: + await ctx.interaction.response.defer() + pager.page_num = jumper(pager) + await pager.redraw() + if ctx.interaction: + await ctx.interaction.delete_original_response() + + @page_group.command( + name=_p('cmd:page_next', 'next'), + description=_p('cmd:page_next|desc', "Jump to the next page of output.") + ) + async def next_cmd(self, ctx: LionContext): + await self.page_jump(ctx, lambda pager: pager.page_num + 1) + + @page_group.command( + name=_p('cmd:page_prev', 'prev'), + description=_p('cmd:page_prev|desc', "Jump to the previous page of output.") + ) + async def prev_cmd(self, ctx: LionContext): + await self.page_jump(ctx, lambda pager: pager.page_num - 1) + + @page_group.command( + name=_p('cmd:page_first', 'first'), + description=_p('cmd:page_first|desc', "Jump to the first page of output.") + ) + async def first_cmd(self, ctx: LionContext): + await self.page_jump(ctx, lambda pager: 0) + + @page_group.command( + name=_p('cmd:page_last', 'last'), + description=_p('cmd:page_last|desc', "Jump to the last page of output.") + ) + async def last_cmd(self, ctx: LionContext): + await self.page_jump(ctx, lambda pager: -1) + + @page_group.command( + name=_p('cmd:page_select', 'select'), + description=_p('cmd:page_select|desc', "Select a page of the output to jump to.") + ) + @appcmds.rename( + page=_p('cmd:page_select|param:page', 'page') + ) + @appcmds.describe( + page=_p('cmd:page_select|param:page|desc', "The page name or number to jump to.") + ) + async def page_cmd(self, ctx: LionContext, page: str): + pager = BasePager.get_active_pager(ctx.channel.id, ctx.author.id) + if pager is None: + await ctx.error_reply( + _p('cmd:page_select|error:no_pager', "No pager listening in this channel!") + ) + else: + await pager.page_cmd(ctx.interaction, page) + + @page_cmd.autocomplete('page') + async def page_acmpl(self, interaction: discord.Interaction, partial: str): + pager = BasePager.get_active_pager(interaction.channel_id, interaction.user.id) + if pager is None: + return [ + appcmds.Choice( + name=_p('cmd:page_select|acmpl|error:no_pager', "No active pagers in this channel!"), + value=partial + ) + ] + else: + return await pager.page_acmpl(interaction, partial) diff --git a/bot/utils/ui.py b/bot/utils/ui.py deleted file mode 100644 index 028eddb6..00000000 --- a/bot/utils/ui.py +++ /dev/null @@ -1,526 +0,0 @@ -from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict, TYPE_CHECKING -from typing_extensions import Annotated -import functools -import asyncio -import logging -import time -from enum import Enum -from contextvars import copy_context, Context -from itertools import groupby - -import discord -from discord.ui import Modal, TextInput, View, Item -from discord.ui.button import Button, button -import discord.app_commands as appcmd -from discord.app_commands.transformers import AppCommandOptionType - -from meta.logger import log_action_stack, logging_context - - -logger = logging.getLogger(__name__) - - -def create_task_in(coro, context: Context): - """ - Transitional. - Since py3.10 asyncio does not support context instantiation, - this helper method runs `asyncio.create_task(coro)` inside the given context. - """ - return context.run(asyncio.create_task, coro) - - -class HookedItem: - """ - Mixin for Item classes allowing an instance to be used as a callback decorator. - """ - def __init__(self, *args, pass_kwargs={}, **kwargs): - super().__init__(*args, **kwargs) - self.pass_kwargs = pass_kwargs - - def __call__(self, coro): - async def wrapped(view, interaction, **kwargs): - return await coro(view, interaction, self, **kwargs, **self.pass_kwargs) - self.callback = wrapped - return self - - -class AButton(HookedItem, Button): - ... - - -class LeoUI(View): - """ - View subclass for small-scale user interfaces. - - While a 'View' provides an interface for managing a collection of components, - a `LeoUI` may also manage a message, and potentially slave Views or UIs. - The `LeoUI` also exposes more advanced cleanup and timeout methods, - and preserves the context. - """ - - def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None: - super().__init__(*args, **kwargs) - - if context is None: - self._context = copy_context() - else: - self._context = context - - self._name = ui_name or self.__class__.__name__ - self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name]) - - # List of slaved views to stop when this view stops - self._slaves: List[View] = [] - - # TODO: Replace this with a substitutable ViewLayout class - self._layout: Optional[tuple[tuple[Item, ...], ...]] = None - - def to_components(self) -> List[Dict[str, Any]]: - """ - Extending component generator to apply the set _layout, if it exists. - """ - if self._layout is not None: - # Alternative rendering using layout - components = [] - for i, row in enumerate(self._layout): - # Skip empty rows - if not row: - continue - - # Since we aren't relying on ViewWeights, manually check width here - if sum(item.width for item in row) > 5: - raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!") - - # Create the component dict for this row - components.append({ - 'type': 1, - 'components': [item.to_component_dict() for item in row] - }) - else: - components = super().to_components() - - return components - - def set_layout(self, *rows: tuple[Item, ...]) -> None: - """ - Set the layout of the rendered View as a matrix of items, - or more precisely, a list of action rows. - - This acts independently of the existing sorting with `_ViewWeights`, - and overrides the sorting if applied. - """ - self._layout = rows - - async def cleanup(self): - """ - Coroutine to run when timeing out, stopping, or cancelling. - Generally cleans up any open resources, and removes any leftover components. - """ - logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'}) - return None - - def stop(self): - """ - Extends View.stop() to also stop all the slave views. - Note that stopping is idempotent, so it is okay if close() also calls stop(). - """ - for slave in self._slaves: - slave.stop() - super().stop() - - async def close(self, msg=None): - self.stop() - await self.cleanup() - - async def pre_timeout(self): - """ - Task to execute before actually timing out. - This may cancel the timeout by refreshing or rescheduling it. - (E.g. to ask the user whether they want to keep going.) - - Default implementation does nothing. - """ - return None - - async def on_timeout(self): - """ - Task to execute after timeout is complete. - Default implementation calls cleanup. - """ - await self.cleanup() - - async def __dispatch_timeout(self): - """ - This essentially extends View._dispatch_timeout, - to include a pre_timeout task - which may optionally refresh and hence cancel the timeout. - """ - if self.__stopped.done(): - # We are already stopped, nothing to do - return - - with logging_context(action='Timeout'): - try: - await self.pre_timeout() - except asyncio.TimeoutError: - pass - except asyncio.CancelledError: - pass - except Exception: - await logger.exception( - "Unhandled error caught while dispatching timeout for {self!r}.", - extra={'with_ctx': True, 'action': 'Error'} - ) - - # Check if we still need to timeout - if self.timeout is None: - # The timeout was removed entirely, silently walk away - return - - if self.__stopped.done(): - # We stopped while waiting for the pre timeout. - # Or maybe another thread timed us out - # Either way, we are done here - return - - now = time.monotonic() - if self.__timeout_expiry is not None and now < self._timeout_expiry: - # The timeout was extended, make sure the timeout task is running then fade away - if self.__timeout_task is None or self.__timeout_task.done(): - self.__timeout_task = asyncio.create_task(self.__timeout_task_impl()) - else: - # Actually timeout, and call the post-timeout task for cleanup. - self._really_timeout() - await self.on_timeout() - - def _dispatch_timeout(self): - """ - Overriding timeout method completely, to support interactive flow during timeout, - and optional refreshing of the timeout. - """ - return self._context.run(asyncio.create_task, self.dispatch_timeout()) - - def _really_timeout(self): - """ - Actuallly times out the View. - This copies View._dispatch_timeout, apart from the `on_timeout` dispatch, - which is now handled by `__dispatch_timeout`. - """ - if self.__stopped.done(): - return - - if self.__cancel_callback: - self.__cancel_callback(self) - self.__cancel_callback = None - - self.__stopped.set_result(True) - - def _dispatch_item(self, *args, **kwargs): - """Extending event dispatch to run in the instantiation context.""" - return self._context.run(super()._dispatch_item, *args, **kwargs) - - async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item): - """ - Default LeoUI error handle. - This may be tail extended by subclasses to preserve the exception stack. - """ - try: - raise error - except Exception: - logger.exception( - f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r}", - extra={'with_ctx': True, 'action': 'UIError'} - ) - - -class AsComponents(LeoUI): - """ - Simple container class to accept a number of Items and turn them into an attachable View. - """ - def __init__(self, *items, pass_kwargs={}, **kwargs): - super().__init__(**kwargs) - self.pass_kwargs = pass_kwargs - - for item in items: - item.callback = self.wrap_callback(item.callback) - self.add_item(item) - - def wrap_callback(self, coro): - async def wrapped(*args, **kwargs): - return await coro(self, *args, **kwargs, **self.pass_kwargs) - return wrapped - - -class LeoModal(Modal): - """ - Context-aware Modal class. - """ - def __init__(self, *args, context: Optional[Context] = None, **kwargs): - super().__init__(**kwargs) - - if context is None: - self._context = copy_context() - else: - self._context = context - self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__]) - - def _dispatch_submit(self, *args, **kwargs): - """ - Extending event dispatch to run in the instantiation context. - """ - return self._context.run(super()._dispatch_submit, *args, **kwargs) - - def _dispatch_item(self, *args, **kwargs): - """Extending event dispatch to run in the instantiation context.""" - return self._context.run(super()._dispatch_item, *args, **kwargs) - - async def on_error(self, interaction: discord.Interaction, error: Exception, *args): - """ - Default LeoModal error handle. - This may be tail extended by subclasses to preserve the exception stack. - """ - try: - raise error - except Exception: - logger.exception( - f"Unhandled interaction exception occurred in {self!r}", - extra={'with_ctx': True, 'action': 'ModalError'} - ) - - -def error_handler_for(exc): - def wrapper(coro): - coro._ui_error_handler_for_ = exc - return coro - return wrapper - - -class FastModal(LeoModal): - __class_error_handlers__ = [] - - def __init_subclass__(cls, **kwargs) -> None: - super().__init_subclass__(**kwargs) - error_handlers = {} - for base in reversed(cls.__mro__): - for name, member in base.__dict__.items(): - if hasattr(member, '_ui_error_handler_for_'): - error_handlers[name] = member - - cls.__class_error_handlers__ = list(error_handlers.values()) - - def __init__error_handlers__(self): - handlers = {} - for handler in self.__class_error_handlers__: - handlers[handler._ui_error_handler_for_] = functools.partial(handler, self) - return handlers - - def __init__(self, *items: TextInput, **kwargs): - super().__init__(**kwargs) - for item in items: - self.add_item(item) - self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future() - self._waiters: List[Callable[[discord.Interaction], Coroutine]] = [] - self._error_handlers = self.__init__error_handlers__() - - def error_handler(self, exception): - def wrapper(coro): - self._error_handlers[exception] = coro - return coro - return wrapper - - async def wait_for(self, check=None, timeout=None): - # Wait for _result or timeout - # If we timeout, or the view times out, raise TimeoutError - # Otherwise, return the Interaction - # This allows multiple listeners and callbacks to wait on - while True: - result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout) - if check is not None: - if not check(result): - continue - return result - - async def on_timeout(self): - self._result.set_exception(asyncio.TimeoutError) - - def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}): - def wrapper(coro): - async def wrapped_callback(interaction): - with logging_context(action=coro.__name__): - if check is not None: - if not check(interaction): - return - try: - await coro(interaction, *pass_args, **pass_kwargs) - except Exception as error: - await self.on_error(interaction, error) - finally: - if once: - self._waiters.remove(wrapped_callback) - self._waiters.append(wrapped_callback) - return wrapper - - async def on_error(self, interaction: discord.Interaction, error: Exception, *args): - try: - # First let our error handlers have a go - # If there is no handler for this error, or the handlers themselves error, - # drop to the superclass error handler implementation. - try: - raise error - except tuple(self._error_handlers.keys()) as e: - # If an error handler is registered for this exception, run it. - for cls, handler in self._error_handlers.items(): - if isinstance(e, cls): - await handler(interaction, e) - except Exception as error: - await super().on_error(interaction, error) - - async def on_submit(self, interaction): - old_result = self._result - self._result = asyncio.get_event_loop().create_future() - old_result.set_result(interaction) - - for waiter in self._waiters: - asyncio.create_task(waiter(interaction), name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}") - - -async def input( - interaction: discord.Interaction, - title: str, - question: Optional[str] = None, - field: Optional[TextInput] = None, - timeout=180, - **kwargs, -): - """ - Spawn a modal to accept input. - Returns an (interaction, value) pair, with interaction not yet responded to. - May raise asyncio.TimeoutError if the view times out. - """ - if field is None: - field = TextInput( - label=kwargs.get('label', question), - **kwargs - ) - modal = FastModal( - field, - title=title, - timeout=timeout - ) - await interaction.response.send_modal(modal) - interaction = await modal.wait_for() - return (interaction, field.value) - - -class ChoicedEnum(Enum): - @property - def choice_name(self): - return self.name - - @property - def choice_value(self): - return self.value - - @property - def choice(self): - return appcmd.Choice( - name=self.choice_name, value=self.choice_value - ) - - @classmethod - def choices(self): - return [item.choice for item in self] - - @classmethod - def make_choice_map(cls): - return {item.choice_value: item for item in cls} - - @classmethod - async def transform(cls, transformer: 'ChoicedEnumTransformer', interaction: discord.Interaction, value: Any): - return transformer._choice_map[value] - - @classmethod - def option_type(cls) -> AppCommandOptionType: - return AppCommandOptionType.string - - @classmethod - def transformer(cls, *args) -> appcmd.Transformer: - return ChoicedEnumTransformer(cls, *args) - - -class ChoicedEnumTransformer(appcmd.Transformer): - # __discord_app_commands_is_choice__ = True - - def __init__(self, enum: Type[ChoicedEnum], opt_type) -> None: - super().__init__() - - self._type = opt_type - self._enum = enum - self._choices = enum.choices() - self._choice_map = enum.make_choice_map() - - @property - def _error_display_name(self) -> str: - return self._enum.__name__ - - @property - def type(self) -> AppCommandOptionType: - return self._type - - @property - def choices(self): - return self._choices - - async def transform(self, interaction: discord.Interaction, value: Any, /) -> Any: - return await self._enum.transform(self, interaction, value) - - -if TYPE_CHECKING: - from typing_extensions import Annotated as Transformed -else: - - class Transformed: - def __class_getitem__(self, items): - cls = items[0] - options = items[1:] - - if not hasattr(cls, 'transformer'): - raise ValueError("Tranformed class must have a transformer classmethod.") - transformer = cls.transformer(*options) - return appcmd.Transform[cls, transformer] - - -class ModalRetryUI(LeoUI): - def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - self.modal = modal - self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)} - - self.message = message - - self._interaction = None - - if label is not None: - self.retry_button.label = label - - @property - def embed(self): - return discord.Embed( - description=self.message, - colour=discord.Colour.red() - ) - - async def respond_to(self, interaction): - self._interaction = interaction - await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self) - - @button(label="Retry") - async def retry_button(self, interaction, butt): - # Setting these here so they don't update in the meantime - for item, value in self.item_values.items(): - item.default = value - if self._interaction is not None: - await self._interaction.delete_original_response() - self._interaction = None - await interaction.response.send_modal(self.modal) - await self.close() diff --git a/bot/utils/ui/__init__.py b/bot/utils/ui/__init__.py new file mode 100644 index 00000000..14fa3150 --- /dev/null +++ b/bot/utils/ui/__init__.py @@ -0,0 +1,20 @@ +import asyncio +import logging +from .. import util_babel + +logger = logging.getLogger(__name__) + +from .hooked import * +from .leo import * +from .micros import * +from .pagers import * +from .transformed import * + + +# def create_task_in(coro, context: Context): +# """ +# Transitional. +# Since py3.10 asyncio does not support context instantiation, +# this helper method runs `asyncio.create_task(coro)` inside the given context. +# """ +# return context.run(asyncio.create_task, coro) diff --git a/bot/utils/ui/hooked.py b/bot/utils/ui/hooked.py new file mode 100644 index 00000000..2784e2d4 --- /dev/null +++ b/bot/utils/ui/hooked.py @@ -0,0 +1,46 @@ +from discord.ui.button import Button + +from .leo import LeoUI + +__all__ = ( + 'HookedItem', + 'AButton', + 'AsComponents' +) + + +class HookedItem: + """ + Mixin for Item classes allowing an instance to be used as a callback decorator. + """ + def __init__(self, *args, pass_kwargs={}, **kwargs): + super().__init__(*args, **kwargs) + self.pass_kwargs = pass_kwargs + + def __call__(self, coro): + async def wrapped(view, interaction, **kwargs): + return await coro(view, interaction, self, **kwargs, **self.pass_kwargs) + self.callback = wrapped + return self + + +class AButton(HookedItem, Button): + ... + + +class AsComponents(LeoUI): + """ + Simple container class to accept a number of Items and turn them into an attachable View. + """ + def __init__(self, *items, pass_kwargs={}, **kwargs): + super().__init__(**kwargs) + self.pass_kwargs = pass_kwargs + + for item in items: + item.callback = self.wrap_callback(item.callback) + self.add_item(item) + + def wrap_callback(self, coro): + async def wrapped(*args, **kwargs): + return await coro(self, *args, **kwargs, **self.pass_kwargs) + return wrapped diff --git a/bot/utils/ui/leo.py b/bot/utils/ui/leo.py new file mode 100644 index 00000000..e8d860da --- /dev/null +++ b/bot/utils/ui/leo.py @@ -0,0 +1,247 @@ +from typing import List, Optional, Any, Dict +import asyncio +import logging +import time +from contextvars import copy_context, Context + +import discord +from discord.ui import Modal, View, Item + +from meta.logger import log_action_stack, logging_context + +from . import logger + +__all__ = ( + 'LeoUI', + 'LeoModal', + 'error_handler_for' +) + + +class LeoUI(View): + """ + View subclass for small-scale user interfaces. + + While a 'View' provides an interface for managing a collection of components, + a `LeoUI` may also manage a message, and potentially slave Views or UIs. + The `LeoUI` also exposes more advanced cleanup and timeout methods, + and preserves the context. + """ + + def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None: + super().__init__(*args, **kwargs) + + if context is None: + self._context = copy_context() + else: + self._context = context + + self._name = ui_name or self.__class__.__name__ + self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name]) + + # List of slaved views to stop when this view stops + self._slaves: List[View] = [] + + # TODO: Replace this with a substitutable ViewLayout class + self._layout: Optional[tuple[tuple[Item, ...], ...]] = None + + def to_components(self) -> List[Dict[str, Any]]: + """ + Extending component generator to apply the set _layout, if it exists. + """ + if self._layout is not None: + # Alternative rendering using layout + components = [] + for i, row in enumerate(self._layout): + # Skip empty rows + if not row: + continue + + # Since we aren't relying on ViewWeights, manually check width here + if sum(item.width for item in row) > 5: + raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!") + + # Create the component dict for this row + components.append({ + 'type': 1, + 'components': [item.to_component_dict() for item in row] + }) + else: + components = super().to_components() + + return components + + def set_layout(self, *rows: tuple[Item, ...]) -> None: + """ + Set the layout of the rendered View as a matrix of items, + or more precisely, a list of action rows. + + This acts independently of the existing sorting with `_ViewWeights`, + and overrides the sorting if applied. + """ + self._layout = rows + + async def cleanup(self): + """ + Coroutine to run when timeing out, stopping, or cancelling. + Generally cleans up any open resources, and removes any leftover components. + """ + logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'}) + return None + + def stop(self): + """ + Extends View.stop() to also stop all the slave views. + Note that stopping is idempotent, so it is okay if close() also calls stop(). + """ + for slave in self._slaves: + slave.stop() + super().stop() + + async def close(self, msg=None): + self.stop() + await self.cleanup() + + async def pre_timeout(self): + """ + Task to execute before actually timing out. + This may cancel the timeout by refreshing or rescheduling it. + (E.g. to ask the user whether they want to keep going.) + + Default implementation does nothing. + """ + return None + + async def on_timeout(self): + """ + Task to execute after timeout is complete. + Default implementation calls cleanup. + """ + await self.cleanup() + + async def __dispatch_timeout(self): + """ + This essentially extends View._dispatch_timeout, + to include a pre_timeout task + which may optionally refresh and hence cancel the timeout. + """ + if self.__stopped.done(): + # We are already stopped, nothing to do + return + + with logging_context(action='Timeout'): + try: + await self.pre_timeout() + except asyncio.TimeoutError: + pass + except asyncio.CancelledError: + pass + except Exception: + await logger.exception( + "Unhandled error caught while dispatching timeout for {self!r}.", + extra={'with_ctx': True, 'action': 'Error'} + ) + + # Check if we still need to timeout + if self.timeout is None: + # The timeout was removed entirely, silently walk away + return + + if self.__stopped.done(): + # We stopped while waiting for the pre timeout. + # Or maybe another thread timed us out + # Either way, we are done here + return + + now = time.monotonic() + if self.__timeout_expiry is not None and now < self._timeout_expiry: + # The timeout was extended, make sure the timeout task is running then fade away + if self.__timeout_task is None or self.__timeout_task.done(): + self.__timeout_task = asyncio.create_task(self.__timeout_task_impl()) + else: + # Actually timeout, and call the post-timeout task for cleanup. + self._really_timeout() + await self.on_timeout() + + def _dispatch_timeout(self): + """ + Overriding timeout method completely, to support interactive flow during timeout, + and optional refreshing of the timeout. + """ + return self._context.run(asyncio.create_task, self.dispatch_timeout()) + + def _really_timeout(self): + """ + Actuallly times out the View. + This copies View._dispatch_timeout, apart from the `on_timeout` dispatch, + which is now handled by `__dispatch_timeout`. + """ + if self.__stopped.done(): + return + + if self.__cancel_callback: + self.__cancel_callback(self) + self.__cancel_callback = None + + self.__stopped.set_result(True) + + def _dispatch_item(self, *args, **kwargs): + """Extending event dispatch to run in the instantiation context.""" + return self._context.run(super()._dispatch_item, *args, **kwargs) + + async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item): + """ + Default LeoUI error handle. + This may be tail extended by subclasses to preserve the exception stack. + """ + try: + raise error + except Exception: + logger.exception( + f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r}", + extra={'with_ctx': True, 'action': 'UIError'} + ) + + +class LeoModal(Modal): + """ + Context-aware Modal class. + """ + def __init__(self, *args, context: Optional[Context] = None, **kwargs): + super().__init__(**kwargs) + + if context is None: + self._context = copy_context() + else: + self._context = context + self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__]) + + def _dispatch_submit(self, *args, **kwargs): + """ + Extending event dispatch to run in the instantiation context. + """ + return self._context.run(super()._dispatch_submit, *args, **kwargs) + + def _dispatch_item(self, *args, **kwargs): + """Extending event dispatch to run in the instantiation context.""" + return self._context.run(super()._dispatch_item, *args, **kwargs) + + async def on_error(self, interaction: discord.Interaction, error: Exception, *args): + """ + Default LeoModal error handle. + This may be tail extended by subclasses to preserve the exception stack. + """ + try: + raise error + except Exception: + logger.exception( + f"Unhandled interaction exception occurred in {self!r}", + extra={'with_ctx': True, 'action': 'ModalError'} + ) + + +def error_handler_for(exc): + def wrapper(coro): + coro._ui_error_handler_for_ = exc + return coro + return wrapper diff --git a/bot/utils/ui/micros.py b/bot/utils/ui/micros.py new file mode 100644 index 00000000..7246bffa --- /dev/null +++ b/bot/utils/ui/micros.py @@ -0,0 +1,315 @@ +from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict +import functools +import asyncio + +import discord +from discord.ui import TextInput +from discord.ui.button import button + +from meta.logger import logging_context +from meta.errors import ResponseTimedOut + +from .leo import LeoModal, LeoUI + +__all__ = ( + 'FastModal', + 'ModalRetryUI', + 'Confirm', + 'input', +) + + +class FastModal(LeoModal): + __class_error_handlers__ = [] + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + error_handlers = {} + for base in reversed(cls.__mro__): + for name, member in base.__dict__.items(): + if hasattr(member, '_ui_error_handler_for_'): + error_handlers[name] = member + + cls.__class_error_handlers__ = list(error_handlers.values()) + + def __init__error_handlers__(self): + handlers = {} + for handler in self.__class_error_handlers__: + handlers[handler._ui_error_handler_for_] = functools.partial(handler, self) + return handlers + + def __init__(self, *items: TextInput, **kwargs): + super().__init__(**kwargs) + for item in items: + self.add_item(item) + self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future() + self._waiters: List[Callable[[discord.Interaction], Coroutine]] = [] + self._error_handlers = self.__init__error_handlers__() + + def error_handler(self, exception): + def wrapper(coro): + self._error_handlers[exception] = coro + return coro + return wrapper + + async def wait_for(self, check=None, timeout=None): + # Wait for _result or timeout + # If we timeout, or the view times out, raise TimeoutError + # Otherwise, return the Interaction + # This allows multiple listeners and callbacks to wait on + while True: + result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout) + if check is not None: + if not check(result): + continue + return result + + async def on_timeout(self): + self._result.set_exception(asyncio.TimeoutError) + + def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}): + def wrapper(coro): + async def wrapped_callback(interaction): + with logging_context(action=coro.__name__): + if check is not None: + if not check(interaction): + return + try: + await coro(interaction, *pass_args, **pass_kwargs) + except Exception as error: + await self.on_error(interaction, error) + finally: + if once: + self._waiters.remove(wrapped_callback) + self._waiters.append(wrapped_callback) + return wrapper + + async def on_error(self, interaction: discord.Interaction, error: Exception, *args): + try: + # First let our error handlers have a go + # If there is no handler for this error, or the handlers themselves error, + # drop to the superclass error handler implementation. + try: + raise error + except tuple(self._error_handlers.keys()) as e: + # If an error handler is registered for this exception, run it. + for cls, handler in self._error_handlers.items(): + if isinstance(e, cls): + await handler(interaction, e) + except Exception as error: + await super().on_error(interaction, error) + + async def on_submit(self, interaction): + old_result = self._result + self._result = asyncio.get_event_loop().create_future() + old_result.set_result(interaction) + + for waiter in self._waiters: + asyncio.create_task(waiter(interaction), name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}") + + +async def input( + interaction: discord.Interaction, + title: str, + question: Optional[str] = None, + field: Optional[TextInput] = None, + timeout=180, + **kwargs, +): + """ + Spawn a modal to accept input. + Returns an (interaction, value) pair, with interaction not yet responded to. + May raise asyncio.TimeoutError if the view times out. + """ + if field is None: + field = TextInput( + label=kwargs.get('label', question), + **kwargs + ) + modal = FastModal( + field, + title=title, + timeout=timeout + ) + await interaction.response.send_modal(modal) + interaction = await modal.wait_for() + return (interaction, field.value) + + +class ModalRetryUI(LeoUI): + def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.modal = modal + self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)} + + self.message = message + + self._interaction = None + + if label is not None: + self.retry_button.label = label + + @property + def embed(self): + return discord.Embed( + description=self.message, + colour=discord.Colour.red() + ) + + async def respond_to(self, interaction): + self._interaction = interaction + await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self) + + @button(label="Retry") + async def retry_button(self, interaction, butt): + # Setting these here so they don't update in the meantime + for item, value in self.item_values.items(): + item.default = value + if self._interaction is not None: + await self._interaction.delete_original_response() + self._interaction = None + await interaction.response.send_modal(self.modal) + await self.close() + + +class Confirm(LeoUI): + """ + Micro UI class implementing a confirmation question. + + Parameters + ---------- + confirm_msg: str + The confirmation question to ask from the user. + This is set as the description of the `embed` property. + The `embed` may be further modified if required. + permitted_id: Optional[int] + The user id allowed to access this interaction. + Other users will recieve an access denied error message. + defer: bool + Whether to defer the interaction response while handling the button. + It may be useful to set this to `False` to obtain manual control + over the interaction response flow (e.g. to send a modal or ephemeral message). + The button press interaction may be accessed through `Confirm.interaction`. + Default: True + + Example + ------- + ``` + confirm = Confirm("Are you sure?", ctx.author.id) + confirm.embed.colour = discord.Colour.red() + confirm.confirm_button.label = "Yes I am sure" + confirm.cancel_button.label = "No I am not sure" + + try: + result = await confirm.ask(ctx.interaction, ephemeral=True) + except ResultTimedOut: + return + ``` + """ + def __init__( + self, + confirm_msg: str, + permitted_id: Optional[int] = None, + defer: bool = True, + **kwargs + ): + super().__init__(**kwargs) + self.confirm_msg = confirm_msg + self.permitted_id = permitted_id + self.defer = defer + + self._embed: Optional[discord.Embed] = None + self._result: asyncio.Future[bool] = asyncio.Future() + + # Indicates whether we should delete the message or the interaction response + self._is_followup: bool = False + self._original: Optional[discord.Interaction] = None + self._message: Optional[discord.Message] = None + + async def interaction_check(self, interaction: discord.Interaction): + return (self.permitted_id is None) or interaction.user.id == self.permitted_id + + async def on_timeout(self): + # Propagate timeout to result Future + self._result.set_exception(ResponseTimedOut) + await self.cleanup() + + async def cleanup(self): + """ + Cleanup the confirmation prompt by deleting it, if possible. + Ignores any Discord errors that occur during the process. + """ + try: + if self._is_followup and self._message: + await self._message.delete() + elif not self._is_followup and self._original and not self._original.is_expired(): + await self._original.delete_original_response() + except discord.HTTPException: + # A user probably already deleted the message + # Anything could have happened, just ignore. + pass + + @button(label="Confirm") + async def confirm_button(self, interaction: discord.Interaction, press): + if self.defer: + await interaction.response.defer() + self._result.set_result(True) + await self.close() + + @button(label="Cancel") + async def cancel_button(self, interaction: discord.Interaction, press): + if self.defer: + await interaction.response.defer() + self._result.set_result(False) + await self.close() + + @property + def embed(self): + """ + Confirmation embed shown to the user. + This is cached, and may be modifed directly through the usual EmbedProxy API, + or explicitly overwritten. + """ + if self._embed is None: + self._embed = discord.Embed( + colour=discord.Colour.orange(), + description=self.confirm_msg + ) + return self._embed + + @embed.setter + def embed(self, value): + self._embed = value + + async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs): + """ + Send this confirmation prompt in response to the provided interaction. + Extra keyword arguments are passed to `Interaction.response.send_message` + or `Interaction.send_followup`, depending on whether + the provided interaction has already been responded to. + + The `epehemeral` argument is handled specially, + since the question message can only be deleted through `Interaction.delete_original_response`. + + Waits on and returns the internal `result` Future. + + Returns: bool + True if the user pressed the confirm button. + False if the user pressed the cancel button. + Raises: + ResponseTimedOut: + If the user does not respond before the UI times out. + """ + self._original = interaction + if interaction.response.is_done(): + # Interaction already responded to, send a follow up + if ephemeral: + raise ValueError("Cannot send an ephemeral response to a used interaction.") + self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self) + self._is_followup = True + else: + await interaction.response.send_message( + embed=self.embed, ephemeral=ephemeral, **kwargs, view=self + ) + self._is_followup = False + return await self._result diff --git a/bot/utils/ui/pagers.py b/bot/utils/ui/pagers.py new file mode 100644 index 00000000..284cbd03 --- /dev/null +++ b/bot/utils/ui/pagers.py @@ -0,0 +1,456 @@ +from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict +from collections import defaultdict + +import discord +from discord.ui.button import Button, button +from discord import app_commands as appcmds + +from meta.logger import log_action_stack, logging_context +from meta.errors import SafeCancellation +from meta.config import conf + +from babel.translator import ctx_translator + +from ..lib import MessageArgs, error_embed +from .. import util_babel + +from .leo import LeoUI + +_p = util_babel._p + + +__all__ = ( + 'BasePager', + 'Pager', +) + + +class BasePager(LeoUI): + """ + An ABC describing the common interface for a Paging UI. + + A paging UI represents a sequence of pages, accessible by `next` and `previous` buttons, + and possibly by a dropdown (not implemented). + + A `Page` is represented as a `MessageArgs` object, which is passable to `send` and `edit` methods as required. + + Each page of a paging UI is accessed through the coroutine `get_page`. + This allows for more complex paging schemes where the pages are expensive to compute, + and not generally needed simultaneously. + In general, `get_page` should cache expensive pages, + perhaps simply with a `cached` decorator, but this is not enforced. + + The state of the base UI is represented as the current `page_num` and the `current_page`. + + This class also maintains an `active_pagers` cache, + representing all `BasePager`s that are currently running. + This allows access from external page controlling utilities, e.g. the `/page` command. + """ + # channelid -> pager.id -> list of active pagers in this channel + active_pagers: dict[int, dict[int, 'BasePager']] = defaultdict(dict) + + page_num: int + current_page: MessageArgs + _channelid: Optional[int] + + @classmethod + def get_active_pager(self, channelid, userid): + """ + Get the last active pager in the `destinationid`, which may be accessed by `userid`. + Returns None if there are no matching pagers. + """ + for pager in reversed(self.active_pagers[channelid].values()): + if pager.access_check(userid): + return pager + + def set_active(self): + if self._channelid is None: + raise ValueError("Cannot set active without a channelid.") + self.active_pagers[self._channelid][self.id] = self + + def set_inactive(self): + self.active_pagers[self._channelid].pop(self.id, None) + + def access_check(self, userid): + """ + Check whether the given userid is allowed to use this UI. + Must be overridden by subclasses. + """ + raise NotImplementedError + + async def get_page(self, page_id) -> MessageArgs: + """ + `get_page` returns the specified page number, starting from 0. + An implementation of `get_page` must: + - Always return a page (if no data is a valid state, must return a placeholder page). + - Always accept out-of-range `page_id` values. + - There is no behaviour specified for these, although they will usually be modded into the correct + range. + - In some cases (e.g. stream data where we don't have a last page), + they may simply return the last correct page instead. + + """ + raise NotImplementedError + + async def page_cmd(self, interaction: discord.Interaction, value: str): + """ + Command implementation for the paging command. + Pager subclasses should override this if they use `active_pagers`. + Default implementation is essentially a no-op, + simply replying to the interaction. + """ + await interaction.response.defer() + return + + async def page_acmpl(self, interaction: discord.Interaction, partial: str): + """ + Command autocompletion for the paging command. + Pager subclasses should override this if they use `active_pagers`. + """ + return [] + + @button(emoji=conf.emojis.getemoji('forward')) + async def next_page_button(self, interaction: discord.Interaction, press): + await interaction.response.defer() + self.page_num += 1 + await self.redraw() + + @button(emoji=conf.emojis.getemoji('backward')) + async def prev_page_button(self, interaction: discord.Interaction, press): + await interaction.response.defer() + self.page_num -= 1 + await self.redraw() + + async def refresh(self): + """ + Recalculate current computed state. + (E.g. fetch current page, set layout, disable components, etc.) + """ + self.current_page = await self.get_page(self.page_num) + + async def redraw(self): + """ + This should refresh the current state and redraw the UI. + Not implemented here, as the implementation depends on whether this is a reaction response ephemeral UI + or a message=based one. + """ + raise NotImplementedError + + +class Pager(BasePager): + """ + MicroUI to display a sequence of static pages, + supporting paging reaction and paging commands. + + Parameters + ---------- + pages: list[MessageArgs] + A non-empty list of message arguments to page. + start_from: int + The page number to display first. + Default: 0 + locked: bool + Whether to only allow the author to use the paging interface. + """ + # List of valid keys indicating movement to the next page + next_list = _p('cmd:page|pager:Pager|options:next', "n, nxt, next, forward, +") + + # List of valid keys indicating movement to the previous page + prev_list = _p('cmd:page|pager:Pager|options:prev', "p, prev, back, -") + + # List of valid keys indicating movement to the first page + first_list = _p('cmd:page|pager:Pager|options:first', "f, first, one, start") + + # List of valid keys indicating movement to the last page + last_list = _p('cmd:page|pager:Pager|options:last', "l, last, end") + + def __init__(self, pages: list[MessageArgs], + start_from=0, + show_cancel=False, delete_on_cancel=True, delete_after=False, **kwargs): + super().__init__(**kwargs) + self._pages = pages + self.page_num = start_from + self.current_page = pages[self.page_num] + + self._locked = True + self._ownerid: Optional[int] = None + self._channelid: Optional[int] = None + + if not pages: + raise ValueError("Cannot run Pager with no pages.") + + self._original: Optional[discord.Interaction] = None + self._is_followup: bool = False + self._message: Optional[discord.Message] = None + + self.show_cancel = show_cancel + self._delete_on_cancel = delete_on_cancel + self._delete_after = delete_after + + @property + def ownerid(self): + if self._ownerid is not None: + return self._ownerid + elif self._original: + return self._original.user.id + else: + return None + + def access_check(self, userid): + return not self._locked or (userid == self.ownerid) + + async def interaction_check(self, interaction: discord.Interaction): + return self.access_check(interaction.user.id) + + @button(emoji=conf.emojis.getemoji('cancel')) + async def cancel_button(self, interaction: discord.Interaction, press: Button): + await interaction.response.defer() + if self._delete_on_cancel: + self._delete_after = True + await self.close() + + async def cleanup(self): + self.set_inactive() + + # If we still have a message, delete it or clear the view + try: + if self._is_followup: + if self._message: + if self._delete_after: + await self._message.delete() + else: + await self._message.edit(view=None) + else: + if self._original and not self._original.is_expired(): + if self._delete_after: + await self._original.delete_original_response() + else: + await self._original.edit_original_response(view=None) + except discord.HTTPException: + # Nothing we can do here + pass + + async def get_page(self, page_id): + page_id %= len(self._pages) + return self._pages[page_id] + + def page_count(self): + return len(self.pages) + + async def page_cmd(self, interaction: discord.Interaction, value: str): + """ + `/page` command for the `Pager` MicroUI. + """ + await interaction.response.defer(ephemeral=True) + t = ctx_translator.get().t + nexts = {word.strip() for word in t(self.next_list).split(',')} + prevs = {word.strip() for word in t(self.prev_list).split(',')} + firsts = {word.strip() for word in t(self.first_list).split(',')} + lasts = {word.strip() for word in t(self.last_list).split(',')} + + if value: + value = value.lower().strip() + if value.isdigit(): + # Assume value is page number + self.page_num = int(value) - 1 + if self.page_num == -1: + self.page_num = 0 + elif value in firsts: + self.page_num = 0 + elif value in nexts: + self.page_num += 1 + elif value in prevs: + self.page_num -= 1 + elif value in lasts: + self.page_num = -1 + elif value.startswith('-') and value[1:].isdigit(): + self.page_num = - int(value[1:]) + else: + await interaction.edit_original_response( + embed=error_embed( + t(_p( + 'cmd:page|pager:Pager|error:parse', + "Could not understand page specification `{value}`." + )).format(value=value) + ) + ) + return + await interaction.delete_original_response() + await self.redraw() + + async def page_acmpl(self, interaction: discord.Interaction, partial: str): + """ + `/page` command autocompletion for the `Pager` MicroUI. + """ + t = ctx_translator.get().t + nexts = {word.strip() for word in t(self.next_list).split(',')} + prevs = {word.strip() for word in t(self.prev_list).split(',')} + firsts = {word.strip() for word in t(self.first_list).split(',')} + lasts = {word.strip() for word in t(self.last_list).split(',')} + + total = len(self._pages) + num = self.page_num + page_choices: dict[int, str] = {} + + # TODO: Support page names and hints? + + if len(self._pages) > 10: + # First add the general choices + if num < total-1: + page_choices[total-1] = t(_p( + 'cmd:page|acmpl|pager:Pager|choice:last', + "Last: Page {page}/{total}" + )).format(page=total, total=total) + + page_choices[num] = t(_p( + 'cmd:page|acmpl|pager:Pager|choice:current', + "Current: Page {page}/{total}" + )).format(page=num+1, total=total) + choices = [ + appcmds.Choice(name=string, value=str(num+1)) + for num, string in sorted(page_choices.items(), key=lambda t: t[0]) + ] + else: + # Particularly support page names here + choices = [ + appcmds.Choice( + name='> ' * (i == num) + t(_p( + 'cmd:page|acmpl|pager:Pager|choice:general', + "Page {page}" + )).format(page=i+1), + value=str(i+1) + ) + for i in range(0, total) + ] + + partial = partial.strip() + + if partial: + value = partial.lower().strip() + if value.isdigit(): + # Assume value is page number + page_num = int(value) - 1 + if page_num == -1: + page_num = 0 + elif value in firsts: + page_num = 0 + elif value in nexts: + page_num = self.page_num + 1 + elif value in prevs: + page_num = self.page_num - 1 + elif value in lasts: + page_num = -1 + elif value.startswith('-') and value[1:].isdigit(): + page_num = - int(value[1:]) + else: + page_num = None + + if page_num is not None: + page_num %= total + choice = appcmds.Choice( + name=t(_p( + 'cmd:page|acmpl|pager:Page|choice:select', + "Selected: Page {page}/{total}" + )).format(page=page_num+1, total=total), + value=str(page_num + 1) + ) + return [choice, *choices] + else: + return [ + appcmds.Choice( + name=t(_p( + 'cmd:page|acmpl|pager:Page|error:parse', + "No matching pages!" + )).format(page=page_num, total=total), + value=partial + ) + ] + else: + return choices + + @property + def page_row(self): + if self.show_cancel: + if len(self._pages) > 1: + return (self.prev_page_button, self.cancel_button, self.next_page_button) + else: + return (self.cancel_button,) + else: + if len(self._pages) > 1: + return (self.prev_page_button, self.next_page_button) + else: + return () + + async def refresh(self): + await super().refresh() + self.set_layout(self.page_row) + + async def redraw(self): + await self.refresh() + + if not self._original: + raise ValueError("Running run pager manually without interaction.") + + try: + if self._message: + await self._message.edit(**self.current_page.edit_args, view=self) + else: + if self._original.is_expired(): + raise SafeCancellation("This interface has expired, please try again.") + await self._original.edit_original_response(**self.current_page.edit_args, view=self) + except discord.HTTPException: + raise SafeCancellation("Could not page your results! Please try again.") + + async def run(self, interaction: discord.Interaction, ephemeral=False, locked=True, ownerid=None, **kwargs): + """ + Display the UI. + Attempts to reply to the interaction if it has not already been replied to, + otherwise send a follow-up. + + An ephemeral response must be sent as an initial interaction response. + On the other hand, a semi-persistent response (expected to last longer than the lifetime of the interaction) + must be sent as a followup. + + Extra kwargs are combined with the first page arguments and given to the relevant send method. + + Parameters + ---------- + interaction: discord.Interaction + The interaction to send the pager in response to. + ephemeral: bool + Whether to send the interaction ephemerally. + If this is true, the interaction *must* be fresh (i.e. no response done). + Default: False + locked: bool + Whether this interface is locked to the user `self.ownerid`. + Irrelevant for ephemeral messages. + Use `ownerid` to override the default owner id. + Defaults to true for fail-safety. + Default: True + ownerid: Optional[int] + The userid allowed to use this interaction. + By default, this will be the `interaction.user.id`, + presuming that this is the user which originally triggered this message. + An override may be useful if a user triggers a paging UI for someone else. + """ + if not interaction.channel_id: + raise ValueError("Cannot run pager on a channelless interaction.") + + self._original = interaction + self._ownerid = ownerid + self._locked = locked + self._channelid = interaction.channel_id + + await self.refresh() + args = self.current_page.send_args | kwargs + + if interaction.response.is_done(): + if ephemeral: + raise ValueError("Ephemeral response requires fres interaction.") + self._message = await interaction.followup.send(**args, view=self) + self._is_followup = True + else: + self._is_followup = False + await interaction.response.send_message(**args, view=self) + + self.set_active() diff --git a/bot/utils/ui/transformed.py b/bot/utils/ui/transformed.py new file mode 100644 index 00000000..3e344659 --- /dev/null +++ b/bot/utils/ui/transformed.py @@ -0,0 +1,91 @@ +from typing import Any, Type, TYPE_CHECKING +from enum import Enum + +import discord +import discord.app_commands as appcmd +from discord.app_commands.transformers import AppCommandOptionType + + +__all__ = ( + 'ChoicedEnum', + 'ChoicedEnumTransformer', + 'Transformed', +) + + +class ChoicedEnum(Enum): + @property + def choice_name(self): + return self.name + + @property + def choice_value(self): + return self.value + + @property + def choice(self): + return appcmd.Choice( + name=self.choice_name, value=self.choice_value + ) + + @classmethod + def choices(self): + return [item.choice for item in self] + + @classmethod + def make_choice_map(cls): + return {item.choice_value: item for item in cls} + + @classmethod + async def transform(cls, transformer: 'ChoicedEnumTransformer', interaction: discord.Interaction, value: Any): + return transformer._choice_map[value] + + @classmethod + def option_type(cls) -> AppCommandOptionType: + return AppCommandOptionType.string + + @classmethod + def transformer(cls, *args) -> appcmd.Transformer: + return ChoicedEnumTransformer(cls, *args) + + +class ChoicedEnumTransformer(appcmd.Transformer): + # __discord_app_commands_is_choice__ = True + + def __init__(self, enum: Type[ChoicedEnum], opt_type) -> None: + super().__init__() + + self._type = opt_type + self._enum = enum + self._choices = enum.choices() + self._choice_map = enum.make_choice_map() + + @property + def _error_display_name(self) -> str: + return self._enum.__name__ + + @property + def type(self) -> AppCommandOptionType: + return self._type + + @property + def choices(self): + return self._choices + + async def transform(self, interaction: discord.Interaction, value: Any, /) -> Any: + return await self._enum.transform(self, interaction, value) + + +if TYPE_CHECKING: + from typing_extensions import Annotated as Transformed +else: + + class Transformed: + def __class_getitem__(self, items): + cls = items[0] + options = items[1:] + + if not hasattr(cls, 'transformer'): + raise ValueError("Tranformed class must have a transformer classmethod.") + transformer = cls.transformer(*options) + return appcmd.Transform[cls, transformer]