rewrite: Refactor ui utils, add pagers.

This commit is contained in:
2022-11-30 16:57:26 +02:00
parent dd8609fac0
commit 5bd05a84a9
9 changed files with 1282 additions and 526 deletions

View File

@@ -1,3 +1,8 @@
from babel.translator import LocalBabel from babel.translator import LocalBabel
util_babel = LocalBabel('utils') util_babel = LocalBabel('utils')
async def setup(bot):
from .cog import MetaUtils
await bot.add_cog(MetaUtils(bot))

102
bot/utils/cog.py Normal file
View File

@@ -0,0 +1,102 @@
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionContext, LionCog
from .ui import BasePager
from . import util_babel as babel
_p = babel._p
class MetaUtils(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
@cmds.hybrid_group(
name=_p('cmd:page', 'page'),
description=_p(
'cmd:page|desc',
"Jump to a given page of the ouput of a previous command in this channel."
),
)
async def page_group(self, ctx: LionContext):
"""
No description.
"""
pass
async def page_jump(self, ctx: LionContext, jumper):
pager = BasePager.get_active_pager(ctx.channel.id, ctx.author.id)
if pager is None:
await ctx.error_reply(
_p('cmd:page|error:no_pager', "No pager listening in this channel!")
)
else:
if ctx.interaction:
await ctx.interaction.response.defer()
pager.page_num = jumper(pager)
await pager.redraw()
if ctx.interaction:
await ctx.interaction.delete_original_response()
@page_group.command(
name=_p('cmd:page_next', 'next'),
description=_p('cmd:page_next|desc', "Jump to the next page of output.")
)
async def next_cmd(self, ctx: LionContext):
await self.page_jump(ctx, lambda pager: pager.page_num + 1)
@page_group.command(
name=_p('cmd:page_prev', 'prev'),
description=_p('cmd:page_prev|desc', "Jump to the previous page of output.")
)
async def prev_cmd(self, ctx: LionContext):
await self.page_jump(ctx, lambda pager: pager.page_num - 1)
@page_group.command(
name=_p('cmd:page_first', 'first'),
description=_p('cmd:page_first|desc', "Jump to the first page of output.")
)
async def first_cmd(self, ctx: LionContext):
await self.page_jump(ctx, lambda pager: 0)
@page_group.command(
name=_p('cmd:page_last', 'last'),
description=_p('cmd:page_last|desc', "Jump to the last page of output.")
)
async def last_cmd(self, ctx: LionContext):
await self.page_jump(ctx, lambda pager: -1)
@page_group.command(
name=_p('cmd:page_select', 'select'),
description=_p('cmd:page_select|desc', "Select a page of the output to jump to.")
)
@appcmds.rename(
page=_p('cmd:page_select|param:page', 'page')
)
@appcmds.describe(
page=_p('cmd:page_select|param:page|desc', "The page name or number to jump to.")
)
async def page_cmd(self, ctx: LionContext, page: str):
pager = BasePager.get_active_pager(ctx.channel.id, ctx.author.id)
if pager is None:
await ctx.error_reply(
_p('cmd:page_select|error:no_pager', "No pager listening in this channel!")
)
else:
await pager.page_cmd(ctx.interaction, page)
@page_cmd.autocomplete('page')
async def page_acmpl(self, interaction: discord.Interaction, partial: str):
pager = BasePager.get_active_pager(interaction.channel_id, interaction.user.id)
if pager is None:
return [
appcmds.Choice(
name=_p('cmd:page_select|acmpl|error:no_pager', "No active pagers in this channel!"),
value=partial
)
]
else:
return await pager.page_acmpl(interaction, partial)

View File

@@ -1,526 +0,0 @@
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict, TYPE_CHECKING
from typing_extensions import Annotated
import functools
import asyncio
import logging
import time
from enum import Enum
from contextvars import copy_context, Context
from itertools import groupby
import discord
from discord.ui import Modal, TextInput, View, Item
from discord.ui.button import Button, button
import discord.app_commands as appcmd
from discord.app_commands.transformers import AppCommandOptionType
from meta.logger import log_action_stack, logging_context
logger = logging.getLogger(__name__)
def create_task_in(coro, context: Context):
"""
Transitional.
Since py3.10 asyncio does not support context instantiation,
this helper method runs `asyncio.create_task(coro)` inside the given context.
"""
return context.run(asyncio.create_task, coro)
class HookedItem:
"""
Mixin for Item classes allowing an instance to be used as a callback decorator.
"""
def __init__(self, *args, pass_kwargs={}, **kwargs):
super().__init__(*args, **kwargs)
self.pass_kwargs = pass_kwargs
def __call__(self, coro):
async def wrapped(view, interaction, **kwargs):
return await coro(view, interaction, self, **kwargs, **self.pass_kwargs)
self.callback = wrapped
return self
class AButton(HookedItem, Button):
...
class LeoUI(View):
"""
View subclass for small-scale user interfaces.
While a 'View' provides an interface for managing a collection of components,
a `LeoUI` may also manage a message, and potentially slave Views or UIs.
The `LeoUI` also exposes more advanced cleanup and timeout methods,
and preserves the context.
"""
def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None:
super().__init__(*args, **kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._name = ui_name or self.__class__.__name__
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name])
# List of slaved views to stop when this view stops
self._slaves: List[View] = []
# TODO: Replace this with a substitutable ViewLayout class
self._layout: Optional[tuple[tuple[Item, ...], ...]] = None
def to_components(self) -> List[Dict[str, Any]]:
"""
Extending component generator to apply the set _layout, if it exists.
"""
if self._layout is not None:
# Alternative rendering using layout
components = []
for i, row in enumerate(self._layout):
# Skip empty rows
if not row:
continue
# Since we aren't relying on ViewWeights, manually check width here
if sum(item.width for item in row) > 5:
raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!")
# Create the component dict for this row
components.append({
'type': 1,
'components': [item.to_component_dict() for item in row]
})
else:
components = super().to_components()
return components
def set_layout(self, *rows: tuple[Item, ...]) -> None:
"""
Set the layout of the rendered View as a matrix of items,
or more precisely, a list of action rows.
This acts independently of the existing sorting with `_ViewWeights`,
and overrides the sorting if applied.
"""
self._layout = rows
async def cleanup(self):
"""
Coroutine to run when timeing out, stopping, or cancelling.
Generally cleans up any open resources, and removes any leftover components.
"""
logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'})
return None
def stop(self):
"""
Extends View.stop() to also stop all the slave views.
Note that stopping is idempotent, so it is okay if close() also calls stop().
"""
for slave in self._slaves:
slave.stop()
super().stop()
async def close(self, msg=None):
self.stop()
await self.cleanup()
async def pre_timeout(self):
"""
Task to execute before actually timing out.
This may cancel the timeout by refreshing or rescheduling it.
(E.g. to ask the user whether they want to keep going.)
Default implementation does nothing.
"""
return None
async def on_timeout(self):
"""
Task to execute after timeout is complete.
Default implementation calls cleanup.
"""
await self.cleanup()
async def __dispatch_timeout(self):
"""
This essentially extends View._dispatch_timeout,
to include a pre_timeout task
which may optionally refresh and hence cancel the timeout.
"""
if self.__stopped.done():
# We are already stopped, nothing to do
return
with logging_context(action='Timeout'):
try:
await self.pre_timeout()
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
pass
except Exception:
await logger.exception(
"Unhandled error caught while dispatching timeout for {self!r}.",
extra={'with_ctx': True, 'action': 'Error'}
)
# Check if we still need to timeout
if self.timeout is None:
# The timeout was removed entirely, silently walk away
return
if self.__stopped.done():
# We stopped while waiting for the pre timeout.
# Or maybe another thread timed us out
# Either way, we are done here
return
now = time.monotonic()
if self.__timeout_expiry is not None and now < self._timeout_expiry:
# The timeout was extended, make sure the timeout task is running then fade away
if self.__timeout_task is None or self.__timeout_task.done():
self.__timeout_task = asyncio.create_task(self.__timeout_task_impl())
else:
# Actually timeout, and call the post-timeout task for cleanup.
self._really_timeout()
await self.on_timeout()
def _dispatch_timeout(self):
"""
Overriding timeout method completely, to support interactive flow during timeout,
and optional refreshing of the timeout.
"""
return self._context.run(asyncio.create_task, self.dispatch_timeout())
def _really_timeout(self):
"""
Actuallly times out the View.
This copies View._dispatch_timeout, apart from the `on_timeout` dispatch,
which is now handled by `__dispatch_timeout`.
"""
if self.__stopped.done():
return
if self.__cancel_callback:
self.__cancel_callback(self)
self.__cancel_callback = None
self.__stopped.set_result(True)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item):
"""
Default LeoUI error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r}",
extra={'with_ctx': True, 'action': 'UIError'}
)
class AsComponents(LeoUI):
"""
Simple container class to accept a number of Items and turn them into an attachable View.
"""
def __init__(self, *items, pass_kwargs={}, **kwargs):
super().__init__(**kwargs)
self.pass_kwargs = pass_kwargs
for item in items:
item.callback = self.wrap_callback(item.callback)
self.add_item(item)
def wrap_callback(self, coro):
async def wrapped(*args, **kwargs):
return await coro(self, *args, **kwargs, **self.pass_kwargs)
return wrapped
class LeoModal(Modal):
"""
Context-aware Modal class.
"""
def __init__(self, *args, context: Optional[Context] = None, **kwargs):
super().__init__(**kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__])
def _dispatch_submit(self, *args, **kwargs):
"""
Extending event dispatch to run in the instantiation context.
"""
return self._context.run(super()._dispatch_submit, *args, **kwargs)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
"""
Default LeoModal error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in {self!r}",
extra={'with_ctx': True, 'action': 'ModalError'}
)
def error_handler_for(exc):
def wrapper(coro):
coro._ui_error_handler_for_ = exc
return coro
return wrapper
class FastModal(LeoModal):
__class_error_handlers__ = []
def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
error_handlers = {}
for base in reversed(cls.__mro__):
for name, member in base.__dict__.items():
if hasattr(member, '_ui_error_handler_for_'):
error_handlers[name] = member
cls.__class_error_handlers__ = list(error_handlers.values())
def __init__error_handlers__(self):
handlers = {}
for handler in self.__class_error_handlers__:
handlers[handler._ui_error_handler_for_] = functools.partial(handler, self)
return handlers
def __init__(self, *items: TextInput, **kwargs):
super().__init__(**kwargs)
for item in items:
self.add_item(item)
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
self._waiters: List[Callable[[discord.Interaction], Coroutine]] = []
self._error_handlers = self.__init__error_handlers__()
def error_handler(self, exception):
def wrapper(coro):
self._error_handlers[exception] = coro
return coro
return wrapper
async def wait_for(self, check=None, timeout=None):
# Wait for _result or timeout
# If we timeout, or the view times out, raise TimeoutError
# Otherwise, return the Interaction
# This allows multiple listeners and callbacks to wait on
while True:
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
if check is not None:
if not check(result):
continue
return result
async def on_timeout(self):
self._result.set_exception(asyncio.TimeoutError)
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
def wrapper(coro):
async def wrapped_callback(interaction):
with logging_context(action=coro.__name__):
if check is not None:
if not check(interaction):
return
try:
await coro(interaction, *pass_args, **pass_kwargs)
except Exception as error:
await self.on_error(interaction, error)
finally:
if once:
self._waiters.remove(wrapped_callback)
self._waiters.append(wrapped_callback)
return wrapper
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
try:
# First let our error handlers have a go
# If there is no handler for this error, or the handlers themselves error,
# drop to the superclass error handler implementation.
try:
raise error
except tuple(self._error_handlers.keys()) as e:
# If an error handler is registered for this exception, run it.
for cls, handler in self._error_handlers.items():
if isinstance(e, cls):
await handler(interaction, e)
except Exception as error:
await super().on_error(interaction, error)
async def on_submit(self, interaction):
old_result = self._result
self._result = asyncio.get_event_loop().create_future()
old_result.set_result(interaction)
for waiter in self._waiters:
asyncio.create_task(waiter(interaction), name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}")
async def input(
interaction: discord.Interaction,
title: str,
question: Optional[str] = None,
field: Optional[TextInput] = None,
timeout=180,
**kwargs,
):
"""
Spawn a modal to accept input.
Returns an (interaction, value) pair, with interaction not yet responded to.
May raise asyncio.TimeoutError if the view times out.
"""
if field is None:
field = TextInput(
label=kwargs.get('label', question),
**kwargs
)
modal = FastModal(
field,
title=title,
timeout=timeout
)
await interaction.response.send_modal(modal)
interaction = await modal.wait_for()
return (interaction, field.value)
class ChoicedEnum(Enum):
@property
def choice_name(self):
return self.name
@property
def choice_value(self):
return self.value
@property
def choice(self):
return appcmd.Choice(
name=self.choice_name, value=self.choice_value
)
@classmethod
def choices(self):
return [item.choice for item in self]
@classmethod
def make_choice_map(cls):
return {item.choice_value: item for item in cls}
@classmethod
async def transform(cls, transformer: 'ChoicedEnumTransformer', interaction: discord.Interaction, value: Any):
return transformer._choice_map[value]
@classmethod
def option_type(cls) -> AppCommandOptionType:
return AppCommandOptionType.string
@classmethod
def transformer(cls, *args) -> appcmd.Transformer:
return ChoicedEnumTransformer(cls, *args)
class ChoicedEnumTransformer(appcmd.Transformer):
# __discord_app_commands_is_choice__ = True
def __init__(self, enum: Type[ChoicedEnum], opt_type) -> None:
super().__init__()
self._type = opt_type
self._enum = enum
self._choices = enum.choices()
self._choice_map = enum.make_choice_map()
@property
def _error_display_name(self) -> str:
return self._enum.__name__
@property
def type(self) -> AppCommandOptionType:
return self._type
@property
def choices(self):
return self._choices
async def transform(self, interaction: discord.Interaction, value: Any, /) -> Any:
return await self._enum.transform(self, interaction, value)
if TYPE_CHECKING:
from typing_extensions import Annotated as Transformed
else:
class Transformed:
def __class_getitem__(self, items):
cls = items[0]
options = items[1:]
if not hasattr(cls, 'transformer'):
raise ValueError("Tranformed class must have a transformer classmethod.")
transformer = cls.transformer(*options)
return appcmd.Transform[cls, transformer]
class ModalRetryUI(LeoUI):
def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.modal = modal
self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)}
self.message = message
self._interaction = None
if label is not None:
self.retry_button.label = label
@property
def embed(self):
return discord.Embed(
description=self.message,
colour=discord.Colour.red()
)
async def respond_to(self, interaction):
self._interaction = interaction
await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self)
@button(label="Retry")
async def retry_button(self, interaction, butt):
# Setting these here so they don't update in the meantime
for item, value in self.item_values.items():
item.default = value
if self._interaction is not None:
await self._interaction.delete_original_response()
self._interaction = None
await interaction.response.send_modal(self.modal)
await self.close()

20
bot/utils/ui/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
import asyncio
import logging
from .. import util_babel
logger = logging.getLogger(__name__)
from .hooked import *
from .leo import *
from .micros import *
from .pagers import *
from .transformed import *
# def create_task_in(coro, context: Context):
# """
# Transitional.
# Since py3.10 asyncio does not support context instantiation,
# this helper method runs `asyncio.create_task(coro)` inside the given context.
# """
# return context.run(asyncio.create_task, coro)

46
bot/utils/ui/hooked.py Normal file
View File

@@ -0,0 +1,46 @@
from discord.ui.button import Button
from .leo import LeoUI
__all__ = (
'HookedItem',
'AButton',
'AsComponents'
)
class HookedItem:
"""
Mixin for Item classes allowing an instance to be used as a callback decorator.
"""
def __init__(self, *args, pass_kwargs={}, **kwargs):
super().__init__(*args, **kwargs)
self.pass_kwargs = pass_kwargs
def __call__(self, coro):
async def wrapped(view, interaction, **kwargs):
return await coro(view, interaction, self, **kwargs, **self.pass_kwargs)
self.callback = wrapped
return self
class AButton(HookedItem, Button):
...
class AsComponents(LeoUI):
"""
Simple container class to accept a number of Items and turn them into an attachable View.
"""
def __init__(self, *items, pass_kwargs={}, **kwargs):
super().__init__(**kwargs)
self.pass_kwargs = pass_kwargs
for item in items:
item.callback = self.wrap_callback(item.callback)
self.add_item(item)
def wrap_callback(self, coro):
async def wrapped(*args, **kwargs):
return await coro(self, *args, **kwargs, **self.pass_kwargs)
return wrapped

247
bot/utils/ui/leo.py Normal file
View File

@@ -0,0 +1,247 @@
from typing import List, Optional, Any, Dict
import asyncio
import logging
import time
from contextvars import copy_context, Context
import discord
from discord.ui import Modal, View, Item
from meta.logger import log_action_stack, logging_context
from . import logger
__all__ = (
'LeoUI',
'LeoModal',
'error_handler_for'
)
class LeoUI(View):
"""
View subclass for small-scale user interfaces.
While a 'View' provides an interface for managing a collection of components,
a `LeoUI` may also manage a message, and potentially slave Views or UIs.
The `LeoUI` also exposes more advanced cleanup and timeout methods,
and preserves the context.
"""
def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None:
super().__init__(*args, **kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._name = ui_name or self.__class__.__name__
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name])
# List of slaved views to stop when this view stops
self._slaves: List[View] = []
# TODO: Replace this with a substitutable ViewLayout class
self._layout: Optional[tuple[tuple[Item, ...], ...]] = None
def to_components(self) -> List[Dict[str, Any]]:
"""
Extending component generator to apply the set _layout, if it exists.
"""
if self._layout is not None:
# Alternative rendering using layout
components = []
for i, row in enumerate(self._layout):
# Skip empty rows
if not row:
continue
# Since we aren't relying on ViewWeights, manually check width here
if sum(item.width for item in row) > 5:
raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!")
# Create the component dict for this row
components.append({
'type': 1,
'components': [item.to_component_dict() for item in row]
})
else:
components = super().to_components()
return components
def set_layout(self, *rows: tuple[Item, ...]) -> None:
"""
Set the layout of the rendered View as a matrix of items,
or more precisely, a list of action rows.
This acts independently of the existing sorting with `_ViewWeights`,
and overrides the sorting if applied.
"""
self._layout = rows
async def cleanup(self):
"""
Coroutine to run when timeing out, stopping, or cancelling.
Generally cleans up any open resources, and removes any leftover components.
"""
logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'})
return None
def stop(self):
"""
Extends View.stop() to also stop all the slave views.
Note that stopping is idempotent, so it is okay if close() also calls stop().
"""
for slave in self._slaves:
slave.stop()
super().stop()
async def close(self, msg=None):
self.stop()
await self.cleanup()
async def pre_timeout(self):
"""
Task to execute before actually timing out.
This may cancel the timeout by refreshing or rescheduling it.
(E.g. to ask the user whether they want to keep going.)
Default implementation does nothing.
"""
return None
async def on_timeout(self):
"""
Task to execute after timeout is complete.
Default implementation calls cleanup.
"""
await self.cleanup()
async def __dispatch_timeout(self):
"""
This essentially extends View._dispatch_timeout,
to include a pre_timeout task
which may optionally refresh and hence cancel the timeout.
"""
if self.__stopped.done():
# We are already stopped, nothing to do
return
with logging_context(action='Timeout'):
try:
await self.pre_timeout()
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
pass
except Exception:
await logger.exception(
"Unhandled error caught while dispatching timeout for {self!r}.",
extra={'with_ctx': True, 'action': 'Error'}
)
# Check if we still need to timeout
if self.timeout is None:
# The timeout was removed entirely, silently walk away
return
if self.__stopped.done():
# We stopped while waiting for the pre timeout.
# Or maybe another thread timed us out
# Either way, we are done here
return
now = time.monotonic()
if self.__timeout_expiry is not None and now < self._timeout_expiry:
# The timeout was extended, make sure the timeout task is running then fade away
if self.__timeout_task is None or self.__timeout_task.done():
self.__timeout_task = asyncio.create_task(self.__timeout_task_impl())
else:
# Actually timeout, and call the post-timeout task for cleanup.
self._really_timeout()
await self.on_timeout()
def _dispatch_timeout(self):
"""
Overriding timeout method completely, to support interactive flow during timeout,
and optional refreshing of the timeout.
"""
return self._context.run(asyncio.create_task, self.dispatch_timeout())
def _really_timeout(self):
"""
Actuallly times out the View.
This copies View._dispatch_timeout, apart from the `on_timeout` dispatch,
which is now handled by `__dispatch_timeout`.
"""
if self.__stopped.done():
return
if self.__cancel_callback:
self.__cancel_callback(self)
self.__cancel_callback = None
self.__stopped.set_result(True)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item):
"""
Default LeoUI error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r}",
extra={'with_ctx': True, 'action': 'UIError'}
)
class LeoModal(Modal):
"""
Context-aware Modal class.
"""
def __init__(self, *args, context: Optional[Context] = None, **kwargs):
super().__init__(**kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__])
def _dispatch_submit(self, *args, **kwargs):
"""
Extending event dispatch to run in the instantiation context.
"""
return self._context.run(super()._dispatch_submit, *args, **kwargs)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
"""
Default LeoModal error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in {self!r}",
extra={'with_ctx': True, 'action': 'ModalError'}
)
def error_handler_for(exc):
def wrapper(coro):
coro._ui_error_handler_for_ = exc
return coro
return wrapper

315
bot/utils/ui/micros.py Normal file
View File

@@ -0,0 +1,315 @@
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
import functools
import asyncio
import discord
from discord.ui import TextInput
from discord.ui.button import button
from meta.logger import logging_context
from meta.errors import ResponseTimedOut
from .leo import LeoModal, LeoUI
__all__ = (
'FastModal',
'ModalRetryUI',
'Confirm',
'input',
)
class FastModal(LeoModal):
__class_error_handlers__ = []
def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
error_handlers = {}
for base in reversed(cls.__mro__):
for name, member in base.__dict__.items():
if hasattr(member, '_ui_error_handler_for_'):
error_handlers[name] = member
cls.__class_error_handlers__ = list(error_handlers.values())
def __init__error_handlers__(self):
handlers = {}
for handler in self.__class_error_handlers__:
handlers[handler._ui_error_handler_for_] = functools.partial(handler, self)
return handlers
def __init__(self, *items: TextInput, **kwargs):
super().__init__(**kwargs)
for item in items:
self.add_item(item)
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
self._waiters: List[Callable[[discord.Interaction], Coroutine]] = []
self._error_handlers = self.__init__error_handlers__()
def error_handler(self, exception):
def wrapper(coro):
self._error_handlers[exception] = coro
return coro
return wrapper
async def wait_for(self, check=None, timeout=None):
# Wait for _result or timeout
# If we timeout, or the view times out, raise TimeoutError
# Otherwise, return the Interaction
# This allows multiple listeners and callbacks to wait on
while True:
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
if check is not None:
if not check(result):
continue
return result
async def on_timeout(self):
self._result.set_exception(asyncio.TimeoutError)
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
def wrapper(coro):
async def wrapped_callback(interaction):
with logging_context(action=coro.__name__):
if check is not None:
if not check(interaction):
return
try:
await coro(interaction, *pass_args, **pass_kwargs)
except Exception as error:
await self.on_error(interaction, error)
finally:
if once:
self._waiters.remove(wrapped_callback)
self._waiters.append(wrapped_callback)
return wrapper
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
try:
# First let our error handlers have a go
# If there is no handler for this error, or the handlers themselves error,
# drop to the superclass error handler implementation.
try:
raise error
except tuple(self._error_handlers.keys()) as e:
# If an error handler is registered for this exception, run it.
for cls, handler in self._error_handlers.items():
if isinstance(e, cls):
await handler(interaction, e)
except Exception as error:
await super().on_error(interaction, error)
async def on_submit(self, interaction):
old_result = self._result
self._result = asyncio.get_event_loop().create_future()
old_result.set_result(interaction)
for waiter in self._waiters:
asyncio.create_task(waiter(interaction), name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}")
async def input(
interaction: discord.Interaction,
title: str,
question: Optional[str] = None,
field: Optional[TextInput] = None,
timeout=180,
**kwargs,
):
"""
Spawn a modal to accept input.
Returns an (interaction, value) pair, with interaction not yet responded to.
May raise asyncio.TimeoutError if the view times out.
"""
if field is None:
field = TextInput(
label=kwargs.get('label', question),
**kwargs
)
modal = FastModal(
field,
title=title,
timeout=timeout
)
await interaction.response.send_modal(modal)
interaction = await modal.wait_for()
return (interaction, field.value)
class ModalRetryUI(LeoUI):
def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.modal = modal
self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)}
self.message = message
self._interaction = None
if label is not None:
self.retry_button.label = label
@property
def embed(self):
return discord.Embed(
description=self.message,
colour=discord.Colour.red()
)
async def respond_to(self, interaction):
self._interaction = interaction
await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self)
@button(label="Retry")
async def retry_button(self, interaction, butt):
# Setting these here so they don't update in the meantime
for item, value in self.item_values.items():
item.default = value
if self._interaction is not None:
await self._interaction.delete_original_response()
self._interaction = None
await interaction.response.send_modal(self.modal)
await self.close()
class Confirm(LeoUI):
"""
Micro UI class implementing a confirmation question.
Parameters
----------
confirm_msg: str
The confirmation question to ask from the user.
This is set as the description of the `embed` property.
The `embed` may be further modified if required.
permitted_id: Optional[int]
The user id allowed to access this interaction.
Other users will recieve an access denied error message.
defer: bool
Whether to defer the interaction response while handling the button.
It may be useful to set this to `False` to obtain manual control
over the interaction response flow (e.g. to send a modal or ephemeral message).
The button press interaction may be accessed through `Confirm.interaction`.
Default: True
Example
-------
```
confirm = Confirm("Are you sure?", ctx.author.id)
confirm.embed.colour = discord.Colour.red()
confirm.confirm_button.label = "Yes I am sure"
confirm.cancel_button.label = "No I am not sure"
try:
result = await confirm.ask(ctx.interaction, ephemeral=True)
except ResultTimedOut:
return
```
"""
def __init__(
self,
confirm_msg: str,
permitted_id: Optional[int] = None,
defer: bool = True,
**kwargs
):
super().__init__(**kwargs)
self.confirm_msg = confirm_msg
self.permitted_id = permitted_id
self.defer = defer
self._embed: Optional[discord.Embed] = None
self._result: asyncio.Future[bool] = asyncio.Future()
# Indicates whether we should delete the message or the interaction response
self._is_followup: bool = False
self._original: Optional[discord.Interaction] = None
self._message: Optional[discord.Message] = None
async def interaction_check(self, interaction: discord.Interaction):
return (self.permitted_id is None) or interaction.user.id == self.permitted_id
async def on_timeout(self):
# Propagate timeout to result Future
self._result.set_exception(ResponseTimedOut)
await self.cleanup()
async def cleanup(self):
"""
Cleanup the confirmation prompt by deleting it, if possible.
Ignores any Discord errors that occur during the process.
"""
try:
if self._is_followup and self._message:
await self._message.delete()
elif not self._is_followup and self._original and not self._original.is_expired():
await self._original.delete_original_response()
except discord.HTTPException:
# A user probably already deleted the message
# Anything could have happened, just ignore.
pass
@button(label="Confirm")
async def confirm_button(self, interaction: discord.Interaction, press):
if self.defer:
await interaction.response.defer()
self._result.set_result(True)
await self.close()
@button(label="Cancel")
async def cancel_button(self, interaction: discord.Interaction, press):
if self.defer:
await interaction.response.defer()
self._result.set_result(False)
await self.close()
@property
def embed(self):
"""
Confirmation embed shown to the user.
This is cached, and may be modifed directly through the usual EmbedProxy API,
or explicitly overwritten.
"""
if self._embed is None:
self._embed = discord.Embed(
colour=discord.Colour.orange(),
description=self.confirm_msg
)
return self._embed
@embed.setter
def embed(self, value):
self._embed = value
async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs):
"""
Send this confirmation prompt in response to the provided interaction.
Extra keyword arguments are passed to `Interaction.response.send_message`
or `Interaction.send_followup`, depending on whether
the provided interaction has already been responded to.
The `epehemeral` argument is handled specially,
since the question message can only be deleted through `Interaction.delete_original_response`.
Waits on and returns the internal `result` Future.
Returns: bool
True if the user pressed the confirm button.
False if the user pressed the cancel button.
Raises:
ResponseTimedOut:
If the user does not respond before the UI times out.
"""
self._original = interaction
if interaction.response.is_done():
# Interaction already responded to, send a follow up
if ephemeral:
raise ValueError("Cannot send an ephemeral response to a used interaction.")
self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self)
self._is_followup = True
else:
await interaction.response.send_message(
embed=self.embed, ephemeral=ephemeral, **kwargs, view=self
)
self._is_followup = False
return await self._result

456
bot/utils/ui/pagers.py Normal file
View File

@@ -0,0 +1,456 @@
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
from collections import defaultdict
import discord
from discord.ui.button import Button, button
from discord import app_commands as appcmds
from meta.logger import log_action_stack, logging_context
from meta.errors import SafeCancellation
from meta.config import conf
from babel.translator import ctx_translator
from ..lib import MessageArgs, error_embed
from .. import util_babel
from .leo import LeoUI
_p = util_babel._p
__all__ = (
'BasePager',
'Pager',
)
class BasePager(LeoUI):
"""
An ABC describing the common interface for a Paging UI.
A paging UI represents a sequence of pages, accessible by `next` and `previous` buttons,
and possibly by a dropdown (not implemented).
A `Page` is represented as a `MessageArgs` object, which is passable to `send` and `edit` methods as required.
Each page of a paging UI is accessed through the coroutine `get_page`.
This allows for more complex paging schemes where the pages are expensive to compute,
and not generally needed simultaneously.
In general, `get_page` should cache expensive pages,
perhaps simply with a `cached` decorator, but this is not enforced.
The state of the base UI is represented as the current `page_num` and the `current_page`.
This class also maintains an `active_pagers` cache,
representing all `BasePager`s that are currently running.
This allows access from external page controlling utilities, e.g. the `/page` command.
"""
# channelid -> pager.id -> list of active pagers in this channel
active_pagers: dict[int, dict[int, 'BasePager']] = defaultdict(dict)
page_num: int
current_page: MessageArgs
_channelid: Optional[int]
@classmethod
def get_active_pager(self, channelid, userid):
"""
Get the last active pager in the `destinationid`, which may be accessed by `userid`.
Returns None if there are no matching pagers.
"""
for pager in reversed(self.active_pagers[channelid].values()):
if pager.access_check(userid):
return pager
def set_active(self):
if self._channelid is None:
raise ValueError("Cannot set active without a channelid.")
self.active_pagers[self._channelid][self.id] = self
def set_inactive(self):
self.active_pagers[self._channelid].pop(self.id, None)
def access_check(self, userid):
"""
Check whether the given userid is allowed to use this UI.
Must be overridden by subclasses.
"""
raise NotImplementedError
async def get_page(self, page_id) -> MessageArgs:
"""
`get_page` returns the specified page number, starting from 0.
An implementation of `get_page` must:
- Always return a page (if no data is a valid state, must return a placeholder page).
- Always accept out-of-range `page_id` values.
- There is no behaviour specified for these, although they will usually be modded into the correct
range.
- In some cases (e.g. stream data where we don't have a last page),
they may simply return the last correct page instead.
"""
raise NotImplementedError
async def page_cmd(self, interaction: discord.Interaction, value: str):
"""
Command implementation for the paging command.
Pager subclasses should override this if they use `active_pagers`.
Default implementation is essentially a no-op,
simply replying to the interaction.
"""
await interaction.response.defer()
return
async def page_acmpl(self, interaction: discord.Interaction, partial: str):
"""
Command autocompletion for the paging command.
Pager subclasses should override this if they use `active_pagers`.
"""
return []
@button(emoji=conf.emojis.getemoji('forward'))
async def next_page_button(self, interaction: discord.Interaction, press):
await interaction.response.defer()
self.page_num += 1
await self.redraw()
@button(emoji=conf.emojis.getemoji('backward'))
async def prev_page_button(self, interaction: discord.Interaction, press):
await interaction.response.defer()
self.page_num -= 1
await self.redraw()
async def refresh(self):
"""
Recalculate current computed state.
(E.g. fetch current page, set layout, disable components, etc.)
"""
self.current_page = await self.get_page(self.page_num)
async def redraw(self):
"""
This should refresh the current state and redraw the UI.
Not implemented here, as the implementation depends on whether this is a reaction response ephemeral UI
or a message=based one.
"""
raise NotImplementedError
class Pager(BasePager):
"""
MicroUI to display a sequence of static pages,
supporting paging reaction and paging commands.
Parameters
----------
pages: list[MessageArgs]
A non-empty list of message arguments to page.
start_from: int
The page number to display first.
Default: 0
locked: bool
Whether to only allow the author to use the paging interface.
"""
# List of valid keys indicating movement to the next page
next_list = _p('cmd:page|pager:Pager|options:next', "n, nxt, next, forward, +")
# List of valid keys indicating movement to the previous page
prev_list = _p('cmd:page|pager:Pager|options:prev', "p, prev, back, -")
# List of valid keys indicating movement to the first page
first_list = _p('cmd:page|pager:Pager|options:first', "f, first, one, start")
# List of valid keys indicating movement to the last page
last_list = _p('cmd:page|pager:Pager|options:last', "l, last, end")
def __init__(self, pages: list[MessageArgs],
start_from=0,
show_cancel=False, delete_on_cancel=True, delete_after=False, **kwargs):
super().__init__(**kwargs)
self._pages = pages
self.page_num = start_from
self.current_page = pages[self.page_num]
self._locked = True
self._ownerid: Optional[int] = None
self._channelid: Optional[int] = None
if not pages:
raise ValueError("Cannot run Pager with no pages.")
self._original: Optional[discord.Interaction] = None
self._is_followup: bool = False
self._message: Optional[discord.Message] = None
self.show_cancel = show_cancel
self._delete_on_cancel = delete_on_cancel
self._delete_after = delete_after
@property
def ownerid(self):
if self._ownerid is not None:
return self._ownerid
elif self._original:
return self._original.user.id
else:
return None
def access_check(self, userid):
return not self._locked or (userid == self.ownerid)
async def interaction_check(self, interaction: discord.Interaction):
return self.access_check(interaction.user.id)
@button(emoji=conf.emojis.getemoji('cancel'))
async def cancel_button(self, interaction: discord.Interaction, press: Button):
await interaction.response.defer()
if self._delete_on_cancel:
self._delete_after = True
await self.close()
async def cleanup(self):
self.set_inactive()
# If we still have a message, delete it or clear the view
try:
if self._is_followup:
if self._message:
if self._delete_after:
await self._message.delete()
else:
await self._message.edit(view=None)
else:
if self._original and not self._original.is_expired():
if self._delete_after:
await self._original.delete_original_response()
else:
await self._original.edit_original_response(view=None)
except discord.HTTPException:
# Nothing we can do here
pass
async def get_page(self, page_id):
page_id %= len(self._pages)
return self._pages[page_id]
def page_count(self):
return len(self.pages)
async def page_cmd(self, interaction: discord.Interaction, value: str):
"""
`/page` command for the `Pager` MicroUI.
"""
await interaction.response.defer(ephemeral=True)
t = ctx_translator.get().t
nexts = {word.strip() for word in t(self.next_list).split(',')}
prevs = {word.strip() for word in t(self.prev_list).split(',')}
firsts = {word.strip() for word in t(self.first_list).split(',')}
lasts = {word.strip() for word in t(self.last_list).split(',')}
if value:
value = value.lower().strip()
if value.isdigit():
# Assume value is page number
self.page_num = int(value) - 1
if self.page_num == -1:
self.page_num = 0
elif value in firsts:
self.page_num = 0
elif value in nexts:
self.page_num += 1
elif value in prevs:
self.page_num -= 1
elif value in lasts:
self.page_num = -1
elif value.startswith('-') and value[1:].isdigit():
self.page_num = - int(value[1:])
else:
await interaction.edit_original_response(
embed=error_embed(
t(_p(
'cmd:page|pager:Pager|error:parse',
"Could not understand page specification `{value}`."
)).format(value=value)
)
)
return
await interaction.delete_original_response()
await self.redraw()
async def page_acmpl(self, interaction: discord.Interaction, partial: str):
"""
`/page` command autocompletion for the `Pager` MicroUI.
"""
t = ctx_translator.get().t
nexts = {word.strip() for word in t(self.next_list).split(',')}
prevs = {word.strip() for word in t(self.prev_list).split(',')}
firsts = {word.strip() for word in t(self.first_list).split(',')}
lasts = {word.strip() for word in t(self.last_list).split(',')}
total = len(self._pages)
num = self.page_num
page_choices: dict[int, str] = {}
# TODO: Support page names and hints?
if len(self._pages) > 10:
# First add the general choices
if num < total-1:
page_choices[total-1] = t(_p(
'cmd:page|acmpl|pager:Pager|choice:last',
"Last: Page {page}/{total}"
)).format(page=total, total=total)
page_choices[num] = t(_p(
'cmd:page|acmpl|pager:Pager|choice:current',
"Current: Page {page}/{total}"
)).format(page=num+1, total=total)
choices = [
appcmds.Choice(name=string, value=str(num+1))
for num, string in sorted(page_choices.items(), key=lambda t: t[0])
]
else:
# Particularly support page names here
choices = [
appcmds.Choice(
name='> ' * (i == num) + t(_p(
'cmd:page|acmpl|pager:Pager|choice:general',
"Page {page}"
)).format(page=i+1),
value=str(i+1)
)
for i in range(0, total)
]
partial = partial.strip()
if partial:
value = partial.lower().strip()
if value.isdigit():
# Assume value is page number
page_num = int(value) - 1
if page_num == -1:
page_num = 0
elif value in firsts:
page_num = 0
elif value in nexts:
page_num = self.page_num + 1
elif value in prevs:
page_num = self.page_num - 1
elif value in lasts:
page_num = -1
elif value.startswith('-') and value[1:].isdigit():
page_num = - int(value[1:])
else:
page_num = None
if page_num is not None:
page_num %= total
choice = appcmds.Choice(
name=t(_p(
'cmd:page|acmpl|pager:Page|choice:select',
"Selected: Page {page}/{total}"
)).format(page=page_num+1, total=total),
value=str(page_num + 1)
)
return [choice, *choices]
else:
return [
appcmds.Choice(
name=t(_p(
'cmd:page|acmpl|pager:Page|error:parse',
"No matching pages!"
)).format(page=page_num, total=total),
value=partial
)
]
else:
return choices
@property
def page_row(self):
if self.show_cancel:
if len(self._pages) > 1:
return (self.prev_page_button, self.cancel_button, self.next_page_button)
else:
return (self.cancel_button,)
else:
if len(self._pages) > 1:
return (self.prev_page_button, self.next_page_button)
else:
return ()
async def refresh(self):
await super().refresh()
self.set_layout(self.page_row)
async def redraw(self):
await self.refresh()
if not self._original:
raise ValueError("Running run pager manually without interaction.")
try:
if self._message:
await self._message.edit(**self.current_page.edit_args, view=self)
else:
if self._original.is_expired():
raise SafeCancellation("This interface has expired, please try again.")
await self._original.edit_original_response(**self.current_page.edit_args, view=self)
except discord.HTTPException:
raise SafeCancellation("Could not page your results! Please try again.")
async def run(self, interaction: discord.Interaction, ephemeral=False, locked=True, ownerid=None, **kwargs):
"""
Display the UI.
Attempts to reply to the interaction if it has not already been replied to,
otherwise send a follow-up.
An ephemeral response must be sent as an initial interaction response.
On the other hand, a semi-persistent response (expected to last longer than the lifetime of the interaction)
must be sent as a followup.
Extra kwargs are combined with the first page arguments and given to the relevant send method.
Parameters
----------
interaction: discord.Interaction
The interaction to send the pager in response to.
ephemeral: bool
Whether to send the interaction ephemerally.
If this is true, the interaction *must* be fresh (i.e. no response done).
Default: False
locked: bool
Whether this interface is locked to the user `self.ownerid`.
Irrelevant for ephemeral messages.
Use `ownerid` to override the default owner id.
Defaults to true for fail-safety.
Default: True
ownerid: Optional[int]
The userid allowed to use this interaction.
By default, this will be the `interaction.user.id`,
presuming that this is the user which originally triggered this message.
An override may be useful if a user triggers a paging UI for someone else.
"""
if not interaction.channel_id:
raise ValueError("Cannot run pager on a channelless interaction.")
self._original = interaction
self._ownerid = ownerid
self._locked = locked
self._channelid = interaction.channel_id
await self.refresh()
args = self.current_page.send_args | kwargs
if interaction.response.is_done():
if ephemeral:
raise ValueError("Ephemeral response requires fres interaction.")
self._message = await interaction.followup.send(**args, view=self)
self._is_followup = True
else:
self._is_followup = False
await interaction.response.send_message(**args, view=self)
self.set_active()

View File

@@ -0,0 +1,91 @@
from typing import Any, Type, TYPE_CHECKING
from enum import Enum
import discord
import discord.app_commands as appcmd
from discord.app_commands.transformers import AppCommandOptionType
__all__ = (
'ChoicedEnum',
'ChoicedEnumTransformer',
'Transformed',
)
class ChoicedEnum(Enum):
@property
def choice_name(self):
return self.name
@property
def choice_value(self):
return self.value
@property
def choice(self):
return appcmd.Choice(
name=self.choice_name, value=self.choice_value
)
@classmethod
def choices(self):
return [item.choice for item in self]
@classmethod
def make_choice_map(cls):
return {item.choice_value: item for item in cls}
@classmethod
async def transform(cls, transformer: 'ChoicedEnumTransformer', interaction: discord.Interaction, value: Any):
return transformer._choice_map[value]
@classmethod
def option_type(cls) -> AppCommandOptionType:
return AppCommandOptionType.string
@classmethod
def transformer(cls, *args) -> appcmd.Transformer:
return ChoicedEnumTransformer(cls, *args)
class ChoicedEnumTransformer(appcmd.Transformer):
# __discord_app_commands_is_choice__ = True
def __init__(self, enum: Type[ChoicedEnum], opt_type) -> None:
super().__init__()
self._type = opt_type
self._enum = enum
self._choices = enum.choices()
self._choice_map = enum.make_choice_map()
@property
def _error_display_name(self) -> str:
return self._enum.__name__
@property
def type(self) -> AppCommandOptionType:
return self._type
@property
def choices(self):
return self._choices
async def transform(self, interaction: discord.Interaction, value: Any, /) -> Any:
return await self._enum.transform(self, interaction, value)
if TYPE_CHECKING:
from typing_extensions import Annotated as Transformed
else:
class Transformed:
def __class_getitem__(self, items):
cls = items[0]
options = items[1:]
if not hasattr(cls, 'transformer'):
raise ValueError("Tranformed class must have a transformer classmethod.")
transformer = cls.transformer(*options)
return appcmd.Transform[cls, transformer]