From 9e7fb77530998e5941e7bc43b7662f524f26d5e4 Mon Sep 17 00:00:00 2001 From: Conatum Date: Fri, 18 Nov 2022 08:45:45 +0200 Subject: [PATCH] rewrite: UI util library. --- bot/utils/ui.py | 519 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 489 insertions(+), 30 deletions(-) diff --git a/bot/utils/ui.py b/bot/utils/ui.py index c09fee41..51423738 100644 --- a/bot/utils/ui.py +++ b/bot/utils/ui.py @@ -1,29 +1,338 @@ -from typing import List, Coroutine +from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict, TYPE_CHECKING +from typing_extensions import Annotated +import functools import asyncio import logging -from contextvars import copy_context +import time +from enum import Enum +from contextvars import copy_context, Context +from itertools import groupby import discord -from discord.ui import Modal +from discord.ui import Modal, TextInput, View, Item +from discord.ui.button import Button, button +import discord.app_commands as appcmd +from discord.app_commands.transformers import AppCommandOptionType -from .lib import recover_context +from meta.logger import log_action_stack, logging_context -class FastModal(Modal): - def __init__(self, *items, **kwargs): +logger = logging.getLogger(__name__) + + +def create_task_in(coro, context: Context): + """ + Transitional. + Since py3.10 asyncio does not support context instantiation, + this helper method runs `asyncio.create_task(coro)` inside the given context. + """ + return context.run(asyncio.create_task, coro) + + +class HookedItem: + """ + Mixin for Item classes allowing an instance to be used as a callback decorator. + """ + def __init__(self, *args, pass_kwargs={}, **kwargs): + super().__init__(*args, **kwargs) + self.pass_kwargs = pass_kwargs + + def __call__(self, coro): + async def wrapped(view, interaction, **kwargs): + return await coro(view, interaction, self, **kwargs, **self.pass_kwargs) + self.callback = wrapped + return self + + +class AButton(HookedItem, Button): + ... + + +class LeoUI(View): + """ + View subclass for small-scale user interfaces. + + While a 'View' provides an interface for managing a collection of components, + a `LeoUI` may also manage a message, and potentially slave Views or UIs. + The `LeoUI` also exposes more advanced cleanup and timeout methods, + and preserves the context. + """ + + def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None: + super().__init__(*args, **kwargs) + + if context is None: + self._context = copy_context() + else: + self._context = context + + self._name = ui_name or self.__class__.__name__ + self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name]) + + # List of slaved views to stop when this view stops + self._slaves: List[View] = [] + + # TODO: Replace this with a substitutable ViewLayout class + self._layout: Optional[tuple[tuple[Item, ...], ...]] = None + + def to_components(self) -> List[Dict[str, Any]]: + """ + Extending component generator to apply the set _layout, if it exists. + """ + if self._layout is not None: + # Alternative rendering using layout + components = [] + for i, row in enumerate(self._layout): + # Skip empty rows + if not row: + continue + + # Since we aren't relying on ViewWeights, manually check width here + if sum(item.width for item in row) > 5: + raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!") + + # Create the component dict for this row + components.append({ + 'type': 1, + 'components': [item.to_component_dict() for item in row] + }) + else: + components = super().to_components() + + return components + + def set_layout(self, *rows: tuple[Item, ...]) -> None: + """ + Set the layout of the rendered View as a matrix of items, + or more precisely, a list of action rows. + + This acts independently of the existing sorting with `_ViewWeights`, + and overrides the sorting if applied. + """ + self._layout = rows + + async def cleanup(self): + """ + Coroutine to run when timeing out, stopping, or cancelling. + Generally cleans up any open resources, and removes any leftover components. + """ + logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'}) + return None + + def stop(self): + """ + Extends View.stop() to also stop all the slave views. + Note that stopping is idempotent, so it is okay if close() also calls stop(). + """ + for slave in self._slaves: + slave.stop() + super().stop() + + async def close(self): + self.stop() + await self.cleanup() + + async def pre_timeout(self): + """ + Task to execute before actually timing out. + This may cancel the timeout by refreshing or rescheduling it. + (E.g. to ask the user whether they want to keep going.) + + Default implementation does nothing. + """ + return None + + async def on_timeout(self): + """ + Task to execute after timeout is complete. + Default implementation calls cleanup. + """ + await self.cleanup() + + async def __dispatch_timeout(self): + """ + This essentially extends View._dispatch_timeout, + to include a pre_timeout task + which may optionally refresh and hence cancel the timeout. + """ + if self.__stopped.done(): + # We are already stopped, nothing to do + return + + with logging_context(action='Timeout'): + try: + await self.pre_timeout() + except asyncio.TimeoutError: + pass + except asyncio.CancelledError: + pass + except Exception: + await logger.exception( + "Unhandled error caught while dispatching timeout for {self!r}.", + extra={'with_ctx': True, 'action': 'Error'} + ) + + # Check if we still need to timeout + if self.timeout is None: + # The timeout was removed entirely, silently walk away + return + + if self.__stopped.done(): + # We stopped while waiting for the pre timeout. + # Or maybe another thread timed us out + # Either way, we are done here + return + + now = time.monotonic() + if self.__timeout_expiry is not None and now < self._timeout_expiry: + # The timeout was extended, make sure the timeout task is running then fade away + if self.__timeout_task is None or self.__timeout_task.done(): + self.__timeout_task = asyncio.create_task(self.__timeout_task_impl()) + else: + # Actually timeout, and call the post-timeout task for cleanup. + self._really_timeout() + await self.on_timeout() + + def _dispatch_timeout(self): + """ + Overriding timeout method completely, to support interactive flow during timeout, + and optional refreshing of the timeout. + """ + return self._context.run(asyncio.create_task, self.dispatch_timeout()) + + def _really_timeout(self): + """ + Actuallly times out the View. + This copies View._dispatch_timeout, apart from the `on_timeout` dispatch, + which is now handled by `__dispatch_timeout`. + """ + if self.__stopped.done(): + return + + if self.__cancel_callback: + self.__cancel_callback(self) + self.__cancel_callback = None + + self.__stopped.set_result(True) + + def _dispatch_item(self, *args, **kwargs): + """Extending event dispatch to run in the instantiation context.""" + return self._context.run(super()._dispatch_item, *args, **kwargs) + + async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item): + """ + Default LeoUI error handle. + This may be tail extended by subclasses to preserve the exception stack. + """ + try: + raise error + except Exception: + logger.exception( + f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r}", + extra={'with_ctx': True, 'action': 'UIError'} + ) + + +class AsComponents(LeoUI): + """ + Simple container class to accept a number of Items and turn them into an attachable View. + """ + def __init__(self, *items, pass_kwargs={}, **kwargs): + super().__init__(**kwargs) + self.pass_kwargs = pass_kwargs + + for item in items: + item.callback = self.wrap_callback(item.callback) + self.add_item(item) + + def wrap_callback(self, coro): + async def wrapped(*args, **kwargs): + return await coro(self, *args, **kwargs, **self.pass_kwargs) + return wrapped + + +class LeoModal(Modal): + """ + Context-aware Modal class. + """ + def __init__(self, *args, context: Optional[Context] = None, **kwargs): + super().__init__(**kwargs) + + if context is None: + self._context = copy_context() + else: + self._context = context + self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__]) + + def _dispatch_submit(self, *args, **kwargs): + """ + Extending event dispatch to run in the instantiation context. + """ + return self._context.run(super()._dispatch_submit, *args, **kwargs) + + def _dispatch_item(self, *args, **kwargs): + """Extending event dispatch to run in the instantiation context.""" + return self._context.run(super()._dispatch_item, *args, **kwargs) + + async def on_error(self, interaction: discord.Interaction, error: Exception, *args): + """ + Default LeoModal error handle. + This may be tail extended by subclasses to preserve the exception stack. + """ + try: + raise error + except Exception: + logger.exception( + f"Unhandled interaction exception occurred in {self!r}", + extra={'with_ctx': True, 'action': 'ModalError'} + ) + + +def error_handler_for(exc): + def wrapper(coro): + coro._ui_error_handler_for_ = exc + return coro + return wrapper + + +class FastModal(LeoModal): + __class_error_handlers__ = [] + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + error_handlers = {} + for base in reversed(cls.__mro__): + for name, member in base.__dict__.items(): + if hasattr(member, '_ui_error_handler_for_'): + error_handlers[name] = member + + cls.__class_error_handlers__ = list(error_handlers.values()) + + def __init__error_handlers__(self): + handlers = {} + for handler in self.__class_error_handlers__: + handlers[handler._ui_error_handler_for_] = functools.partial(handler, self) + return handlers + + def __init__(self, *items: TextInput, **kwargs): super().__init__(**kwargs) for item in items: self.add_item(item) self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future() - self._waiters: List[Coroutine[discord.Interaction]] = [] - self._context = copy_context() + self._waiters: List[Callable[[discord.Interaction], Coroutine]] = [] + self._error_handlers = self.__init__error_handlers__() + + def error_handler(self, exception): + def wrapper(coro): + self._error_handlers[exception] = coro + return coro + return wrapper async def wait_for(self, check=None, timeout=None): # Wait for _result or timeout # If we timeout, or the view times out, raise TimeoutError # Otherwise, return the Interaction # This allows multiple listeners and callbacks to wait on - # TODO: Wait on the timeout as well while True: result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout) if check is not None: @@ -31,37 +340,187 @@ class FastModal(Modal): continue return result + async def on_timeout(self): + self._result.set_exception(asyncio.TimeoutError) + def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}): def wrapper(coro): async def wrapped_callback(interaction): - if check is not None: - if not check(interaction): - return - try: - await coro(interaction, *pass_args, **pass_kwargs) - except Exception: - # TODO: Log exception - logging.exception( - f"Exception occurred executing FastModal callback '{coro.__name__}'" - ) - if once: - self._waiters.remove(wrapped_callback) + with logging_context(action=coro.__name__): + if check is not None: + if not check(interaction): + return + try: + await coro(interaction, *pass_args, **pass_kwargs) + except Exception as error: + await self.on_error(interaction, error) + finally: + if once: + self._waiters.remove(wrapped_callback) self._waiters.append(wrapped_callback) return wrapper - async def on_submit(self, interaction): - # Transitional patch to re-instantiate the current context - # Not required in py 3.11, instead pass a context parameter to create_task - recover_context(self._context) + async def on_error(self, interaction: discord.Interaction, error: Exception, *args): + try: + # First let our error handlers have a go + # If there is no handler for this error, or the handlers themselves error, + # drop to the superclass error handler implementation. + try: + raise error + except tuple(self._error_handlers.keys()) as e: + # If an error handler is registered for this exception, run it. + for cls, handler in self._error_handlers.items(): + if isinstance(e, cls): + await handler(interaction, e) + except Exception as error: + await super().on_error(interaction, error) + async def on_submit(self, interaction): old_result = self._result self._result = asyncio.get_event_loop().create_future() old_result.set_result(interaction) for waiter in self._waiters: - asyncio.create_task(waiter(interaction)) + asyncio.create_task(waiter(interaction), name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}") - async def on_error(self, interaction, error): - # This should never happen, since on_submit has its own error handling - # TODO: Logging - logging.error("Submit error occured in FastModal") + +async def input( + interaction: discord.Interaction, + title: str, + question: Optional[str] = None, + field: Optional[TextInput] = None, + timeout=180, + **kwargs, +): + """ + Spawn a modal to accept input. + Returns an (interaction, value) pair, with interaction not yet responded to. + May raise asyncio.TimeoutError if the view times out. + """ + if field is None: + field = TextInput( + label=kwargs.get('label', question), + **kwargs + ) + modal = FastModal( + field, + title=title, + timeout=timeout + ) + await interaction.response.send_modal(modal) + interaction = await modal.wait_for() + return (interaction, field.value) + + +class ChoicedEnum(Enum): + @property + def choice_name(self): + return self.name + + @property + def choice_value(self): + return self.value + + @property + def choice(self): + return appcmd.Choice( + name=self.choice_name, value=self.choice_value + ) + + @classmethod + def choices(self): + return [item.choice for item in self] + + @classmethod + def make_choice_map(cls): + return {item.choice_value: item for item in cls} + + @classmethod + async def transform(cls, transformer: 'ChoicedEnumTransformer', interaction: discord.Interaction, value: Any): + return transformer._choice_map[value] + + @classmethod + def option_type(cls) -> AppCommandOptionType: + return AppCommandOptionType.string + + @classmethod + def transformer(cls, *args) -> appcmd.Transformer: + return ChoicedEnumTransformer(cls, *args) + + +class ChoicedEnumTransformer(appcmd.Transformer): + # __discord_app_commands_is_choice__ = True + + def __init__(self, enum: Type[ChoicedEnum], opt_type) -> None: + super().__init__() + + self._type = opt_type + self._enum = enum + self._choices = enum.choices() + self._choice_map = enum.make_choice_map() + + @property + def _error_display_name(self) -> str: + return self._enum.__name__ + + @property + def type(self) -> AppCommandOptionType: + return self._type + + @property + def choices(self): + return self._choices + + async def transform(self, interaction: discord.Interaction, value: Any, /) -> Any: + return await self._enum.transform(self, interaction, value) + + +if TYPE_CHECKING: + from typing_extensions import Annotated as Transformed +else: + + class Transformed: + def __class_getitem__(self, items): + cls = items[0] + options = items[1:] + + if not hasattr(cls, 'transformer'): + raise ValueError("Tranformed class must have a transformer classmethod.") + transformer = cls.transformer(*options) + return appcmd.Transform[cls, transformer] + + +class ModalRetryUI(LeoUI): + def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.modal = modal + self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)} + + self.message = message + + self._interaction = None + + if label is not None: + self.retry_button.label = label + + @property + def embed(self): + return discord.Embed( + description=self.message, + colour=discord.Colour.red() + ) + + async def respond_to(self, interaction): + self._interaction = interaction + await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self) + + @button(label="Retry") + async def retry_button(self, interaction, butt): + # Setting these here so they don't update in the meantime + for item, value in self.item_values.items(): + item.default = value + if self._interaction is not None: + await self._interaction.delete_original_response() + self._interaction = None + await interaction.response.send_modal(self.modal) + await self.close()