rewrite: Restructure to include GUI.
This commit is contained in:
20
src/utils/ui/__init__.py
Normal file
20
src/utils/ui/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from .. import util_babel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .hooked import *
|
||||
from .leo import *
|
||||
from .micros import *
|
||||
from .pagers import *
|
||||
from .transformed import *
|
||||
|
||||
|
||||
# def create_task_in(coro, context: Context):
|
||||
# """
|
||||
# Transitional.
|
||||
# Since py3.10 asyncio does not support context instantiation,
|
||||
# this helper method runs `asyncio.create_task(coro)` inside the given context.
|
||||
# """
|
||||
# return context.run(asyncio.create_task, coro)
|
||||
59
src/utils/ui/hooked.py
Normal file
59
src/utils/ui/hooked.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import time
|
||||
|
||||
import discord
|
||||
from discord.ui.item import Item
|
||||
from discord.ui.button import Button
|
||||
|
||||
from .leo import LeoUI
|
||||
|
||||
__all__ = (
|
||||
'HookedItem',
|
||||
'AButton',
|
||||
'AsComponents'
|
||||
)
|
||||
|
||||
|
||||
class HookedItem:
|
||||
"""
|
||||
Mixin for Item classes allowing an instance to be used as a callback decorator.
|
||||
"""
|
||||
def __init__(self, *args, pass_kwargs={}, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.pass_kwargs = pass_kwargs
|
||||
|
||||
def __call__(self, coro):
|
||||
async def wrapped(interaction, **kwargs):
|
||||
return await coro(interaction, self, **(self.pass_kwargs | kwargs))
|
||||
self.callback = wrapped
|
||||
return self
|
||||
|
||||
|
||||
class AButton(HookedItem, Button):
|
||||
...
|
||||
|
||||
|
||||
class AsComponents(LeoUI):
|
||||
"""
|
||||
Simple container class to accept a number of Items and turn them into an attachable View.
|
||||
"""
|
||||
def __init__(self, *items, pass_kwargs={}, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pass_kwargs = pass_kwargs
|
||||
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
|
||||
async def _scheduled_task(self, item: Item, interaction: discord.Interaction):
|
||||
try:
|
||||
item._refresh_state(interaction, interaction.data) # type: ignore
|
||||
|
||||
allow = await self.interaction_check(interaction)
|
||||
if not allow:
|
||||
return
|
||||
|
||||
if self.timeout:
|
||||
self.__timeout_expiry = time.monotonic() + self.timeout
|
||||
|
||||
await item.callback(interaction, **self.pass_kwargs)
|
||||
except Exception as e:
|
||||
return await self.on_error(interaction, e, item)
|
||||
247
src/utils/ui/leo.py
Normal file
247
src/utils/ui/leo.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from typing import List, Optional, Any, Dict
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from contextvars import copy_context, Context
|
||||
|
||||
import discord
|
||||
from discord.ui import Modal, View, Item
|
||||
|
||||
from meta.logger import log_action_stack, logging_context
|
||||
|
||||
from . import logger
|
||||
|
||||
__all__ = (
|
||||
'LeoUI',
|
||||
'LeoModal',
|
||||
'error_handler_for'
|
||||
)
|
||||
|
||||
|
||||
class LeoUI(View):
|
||||
"""
|
||||
View subclass for small-scale user interfaces.
|
||||
|
||||
While a 'View' provides an interface for managing a collection of components,
|
||||
a `LeoUI` may also manage a message, and potentially slave Views or UIs.
|
||||
The `LeoUI` also exposes more advanced cleanup and timeout methods,
|
||||
and preserves the context.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if context is None:
|
||||
self._context = copy_context()
|
||||
else:
|
||||
self._context = context
|
||||
|
||||
self._name = ui_name or self.__class__.__name__
|
||||
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name])
|
||||
|
||||
# List of slaved views to stop when this view stops
|
||||
self._slaves: List[View] = []
|
||||
|
||||
# TODO: Replace this with a substitutable ViewLayout class
|
||||
self._layout: Optional[tuple[tuple[Item, ...], ...]] = None
|
||||
|
||||
def to_components(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extending component generator to apply the set _layout, if it exists.
|
||||
"""
|
||||
if self._layout is not None:
|
||||
# Alternative rendering using layout
|
||||
components = []
|
||||
for i, row in enumerate(self._layout):
|
||||
# Skip empty rows
|
||||
if not row:
|
||||
continue
|
||||
|
||||
# Since we aren't relying on ViewWeights, manually check width here
|
||||
if sum(item.width for item in row) > 5:
|
||||
raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!")
|
||||
|
||||
# Create the component dict for this row
|
||||
components.append({
|
||||
'type': 1,
|
||||
'components': [item.to_component_dict() for item in row]
|
||||
})
|
||||
else:
|
||||
components = super().to_components()
|
||||
|
||||
return components
|
||||
|
||||
def set_layout(self, *rows: tuple[Item, ...]) -> None:
|
||||
"""
|
||||
Set the layout of the rendered View as a matrix of items,
|
||||
or more precisely, a list of action rows.
|
||||
|
||||
This acts independently of the existing sorting with `_ViewWeights`,
|
||||
and overrides the sorting if applied.
|
||||
"""
|
||||
self._layout = rows
|
||||
|
||||
async def cleanup(self):
|
||||
"""
|
||||
Coroutine to run when timeing out, stopping, or cancelling.
|
||||
Generally cleans up any open resources, and removes any leftover components.
|
||||
"""
|
||||
logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'})
|
||||
return None
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Extends View.stop() to also stop all the slave views.
|
||||
Note that stopping is idempotent, so it is okay if close() also calls stop().
|
||||
"""
|
||||
for slave in self._slaves:
|
||||
slave.stop()
|
||||
super().stop()
|
||||
|
||||
async def close(self, msg=None):
|
||||
self.stop()
|
||||
await self.cleanup()
|
||||
|
||||
async def pre_timeout(self):
|
||||
"""
|
||||
Task to execute before actually timing out.
|
||||
This may cancel the timeout by refreshing or rescheduling it.
|
||||
(E.g. to ask the user whether they want to keep going.)
|
||||
|
||||
Default implementation does nothing.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def on_timeout(self):
|
||||
"""
|
||||
Task to execute after timeout is complete.
|
||||
Default implementation calls cleanup.
|
||||
"""
|
||||
await self.cleanup()
|
||||
|
||||
async def __dispatch_timeout(self):
|
||||
"""
|
||||
This essentially extends View._dispatch_timeout,
|
||||
to include a pre_timeout task
|
||||
which may optionally refresh and hence cancel the timeout.
|
||||
"""
|
||||
if self.__stopped.done():
|
||||
# We are already stopped, nothing to do
|
||||
return
|
||||
|
||||
with logging_context(action='Timeout'):
|
||||
try:
|
||||
await self.pre_timeout()
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
await logger.exception(
|
||||
"Unhandled error caught while dispatching timeout for {self!r}.",
|
||||
extra={'with_ctx': True, 'action': 'Error'}
|
||||
)
|
||||
|
||||
# Check if we still need to timeout
|
||||
if self.timeout is None:
|
||||
# The timeout was removed entirely, silently walk away
|
||||
return
|
||||
|
||||
if self.__stopped.done():
|
||||
# We stopped while waiting for the pre timeout.
|
||||
# Or maybe another thread timed us out
|
||||
# Either way, we are done here
|
||||
return
|
||||
|
||||
now = time.monotonic()
|
||||
if self.__timeout_expiry is not None and now < self._timeout_expiry:
|
||||
# The timeout was extended, make sure the timeout task is running then fade away
|
||||
if self.__timeout_task is None or self.__timeout_task.done():
|
||||
self.__timeout_task = asyncio.create_task(self.__timeout_task_impl())
|
||||
else:
|
||||
# Actually timeout, and call the post-timeout task for cleanup.
|
||||
self._really_timeout()
|
||||
await self.on_timeout()
|
||||
|
||||
def _dispatch_timeout(self):
|
||||
"""
|
||||
Overriding timeout method completely, to support interactive flow during timeout,
|
||||
and optional refreshing of the timeout.
|
||||
"""
|
||||
return self._context.run(asyncio.create_task, self.dispatch_timeout())
|
||||
|
||||
def _really_timeout(self):
|
||||
"""
|
||||
Actuallly times out the View.
|
||||
This copies View._dispatch_timeout, apart from the `on_timeout` dispatch,
|
||||
which is now handled by `__dispatch_timeout`.
|
||||
"""
|
||||
if self.__stopped.done():
|
||||
return
|
||||
|
||||
if self.__cancel_callback:
|
||||
self.__cancel_callback(self)
|
||||
self.__cancel_callback = None
|
||||
|
||||
self.__stopped.set_result(True)
|
||||
|
||||
def _dispatch_item(self, *args, **kwargs):
|
||||
"""Extending event dispatch to run in the instantiation context."""
|
||||
return self._context.run(super()._dispatch_item, *args, **kwargs)
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item):
|
||||
"""
|
||||
Default LeoUI error handle.
|
||||
This may be tail extended by subclasses to preserve the exception stack.
|
||||
"""
|
||||
try:
|
||||
raise error
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r}",
|
||||
extra={'with_ctx': True, 'action': 'UIError'}
|
||||
)
|
||||
|
||||
|
||||
class LeoModal(Modal):
|
||||
"""
|
||||
Context-aware Modal class.
|
||||
"""
|
||||
def __init__(self, *args, context: Optional[Context] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if context is None:
|
||||
self._context = copy_context()
|
||||
else:
|
||||
self._context = context
|
||||
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__])
|
||||
|
||||
def _dispatch_submit(self, *args, **kwargs):
|
||||
"""
|
||||
Extending event dispatch to run in the instantiation context.
|
||||
"""
|
||||
return self._context.run(super()._dispatch_submit, *args, **kwargs)
|
||||
|
||||
def _dispatch_item(self, *args, **kwargs):
|
||||
"""Extending event dispatch to run in the instantiation context."""
|
||||
return self._context.run(super()._dispatch_item, *args, **kwargs)
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
|
||||
"""
|
||||
Default LeoModal error handle.
|
||||
This may be tail extended by subclasses to preserve the exception stack.
|
||||
"""
|
||||
try:
|
||||
raise error
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unhandled interaction exception occurred in {self!r}",
|
||||
extra={'with_ctx': True, 'action': 'ModalError'}
|
||||
)
|
||||
|
||||
|
||||
def error_handler_for(exc):
|
||||
def wrapper(coro):
|
||||
coro._ui_error_handler_for_ = exc
|
||||
return coro
|
||||
return wrapper
|
||||
315
src/utils/ui/micros.py
Normal file
315
src/utils/ui/micros.py
Normal file
@@ -0,0 +1,315 @@
|
||||
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
|
||||
import functools
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord.ui import TextInput
|
||||
from discord.ui.button import button
|
||||
|
||||
from meta.logger import logging_context
|
||||
from meta.errors import ResponseTimedOut
|
||||
|
||||
from .leo import LeoModal, LeoUI
|
||||
|
||||
__all__ = (
|
||||
'FastModal',
|
||||
'ModalRetryUI',
|
||||
'Confirm',
|
||||
'input',
|
||||
)
|
||||
|
||||
|
||||
class FastModal(LeoModal):
|
||||
__class_error_handlers__ = []
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
error_handlers = {}
|
||||
for base in reversed(cls.__mro__):
|
||||
for name, member in base.__dict__.items():
|
||||
if hasattr(member, '_ui_error_handler_for_'):
|
||||
error_handlers[name] = member
|
||||
|
||||
cls.__class_error_handlers__ = list(error_handlers.values())
|
||||
|
||||
def __init__error_handlers__(self):
|
||||
handlers = {}
|
||||
for handler in self.__class_error_handlers__:
|
||||
handlers[handler._ui_error_handler_for_] = functools.partial(handler, self)
|
||||
return handlers
|
||||
|
||||
def __init__(self, *items: TextInput, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
|
||||
self._waiters: List[Callable[[discord.Interaction], Coroutine]] = []
|
||||
self._error_handlers = self.__init__error_handlers__()
|
||||
|
||||
def error_handler(self, exception):
|
||||
def wrapper(coro):
|
||||
self._error_handlers[exception] = coro
|
||||
return coro
|
||||
return wrapper
|
||||
|
||||
async def wait_for(self, check=None, timeout=None):
|
||||
# Wait for _result or timeout
|
||||
# If we timeout, or the view times out, raise TimeoutError
|
||||
# Otherwise, return the Interaction
|
||||
# This allows multiple listeners and callbacks to wait on
|
||||
while True:
|
||||
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
|
||||
if check is not None:
|
||||
if not check(result):
|
||||
continue
|
||||
return result
|
||||
|
||||
async def on_timeout(self):
|
||||
self._result.set_exception(asyncio.TimeoutError)
|
||||
|
||||
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
|
||||
def wrapper(coro):
|
||||
async def wrapped_callback(interaction):
|
||||
with logging_context(action=coro.__name__):
|
||||
if check is not None:
|
||||
if not check(interaction):
|
||||
return
|
||||
try:
|
||||
await coro(interaction, *pass_args, **pass_kwargs)
|
||||
except Exception as error:
|
||||
await self.on_error(interaction, error)
|
||||
finally:
|
||||
if once:
|
||||
self._waiters.remove(wrapped_callback)
|
||||
self._waiters.append(wrapped_callback)
|
||||
return wrapper
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
|
||||
try:
|
||||
# First let our error handlers have a go
|
||||
# If there is no handler for this error, or the handlers themselves error,
|
||||
# drop to the superclass error handler implementation.
|
||||
try:
|
||||
raise error
|
||||
except tuple(self._error_handlers.keys()) as e:
|
||||
# If an error handler is registered for this exception, run it.
|
||||
for cls, handler in self._error_handlers.items():
|
||||
if isinstance(e, cls):
|
||||
await handler(interaction, e)
|
||||
except Exception as error:
|
||||
await super().on_error(interaction, error)
|
||||
|
||||
async def on_submit(self, interaction):
|
||||
old_result = self._result
|
||||
self._result = asyncio.get_event_loop().create_future()
|
||||
old_result.set_result(interaction)
|
||||
|
||||
for waiter in self._waiters:
|
||||
asyncio.create_task(waiter(interaction), name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}")
|
||||
|
||||
|
||||
async def input(
|
||||
interaction: discord.Interaction,
|
||||
title: str,
|
||||
question: Optional[str] = None,
|
||||
field: Optional[TextInput] = None,
|
||||
timeout=180,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Spawn a modal to accept input.
|
||||
Returns an (interaction, value) pair, with interaction not yet responded to.
|
||||
May raise asyncio.TimeoutError if the view times out.
|
||||
"""
|
||||
if field is None:
|
||||
field = TextInput(
|
||||
label=kwargs.get('label', question),
|
||||
**kwargs
|
||||
)
|
||||
modal = FastModal(
|
||||
field,
|
||||
title=title,
|
||||
timeout=timeout
|
||||
)
|
||||
await interaction.response.send_modal(modal)
|
||||
interaction = await modal.wait_for()
|
||||
return (interaction, field.value)
|
||||
|
||||
|
||||
class ModalRetryUI(LeoUI):
|
||||
def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.modal = modal
|
||||
self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)}
|
||||
|
||||
self.message = message
|
||||
|
||||
self._interaction = None
|
||||
|
||||
if label is not None:
|
||||
self.retry_button.label = label
|
||||
|
||||
@property
|
||||
def embed(self):
|
||||
return discord.Embed(
|
||||
description=self.message,
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
|
||||
async def respond_to(self, interaction):
|
||||
self._interaction = interaction
|
||||
await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self)
|
||||
|
||||
@button(label="Retry")
|
||||
async def retry_button(self, interaction, butt):
|
||||
# Setting these here so they don't update in the meantime
|
||||
for item, value in self.item_values.items():
|
||||
item.default = value
|
||||
if self._interaction is not None:
|
||||
await self._interaction.delete_original_response()
|
||||
self._interaction = None
|
||||
await interaction.response.send_modal(self.modal)
|
||||
await self.close()
|
||||
|
||||
|
||||
class Confirm(LeoUI):
|
||||
"""
|
||||
Micro UI class implementing a confirmation question.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
confirm_msg: str
|
||||
The confirmation question to ask from the user.
|
||||
This is set as the description of the `embed` property.
|
||||
The `embed` may be further modified if required.
|
||||
permitted_id: Optional[int]
|
||||
The user id allowed to access this interaction.
|
||||
Other users will recieve an access denied error message.
|
||||
defer: bool
|
||||
Whether to defer the interaction response while handling the button.
|
||||
It may be useful to set this to `False` to obtain manual control
|
||||
over the interaction response flow (e.g. to send a modal or ephemeral message).
|
||||
The button press interaction may be accessed through `Confirm.interaction`.
|
||||
Default: True
|
||||
|
||||
Example
|
||||
-------
|
||||
```
|
||||
confirm = Confirm("Are you sure?", ctx.author.id)
|
||||
confirm.embed.colour = discord.Colour.red()
|
||||
confirm.confirm_button.label = "Yes I am sure"
|
||||
confirm.cancel_button.label = "No I am not sure"
|
||||
|
||||
try:
|
||||
result = await confirm.ask(ctx.interaction, ephemeral=True)
|
||||
except ResultTimedOut:
|
||||
return
|
||||
```
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
confirm_msg: str,
|
||||
permitted_id: Optional[int] = None,
|
||||
defer: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.confirm_msg = confirm_msg
|
||||
self.permitted_id = permitted_id
|
||||
self.defer = defer
|
||||
|
||||
self._embed: Optional[discord.Embed] = None
|
||||
self._result: asyncio.Future[bool] = asyncio.Future()
|
||||
|
||||
# Indicates whether we should delete the message or the interaction response
|
||||
self._is_followup: bool = False
|
||||
self._original: Optional[discord.Interaction] = None
|
||||
self._message: Optional[discord.Message] = None
|
||||
|
||||
async def interaction_check(self, interaction: discord.Interaction):
|
||||
return (self.permitted_id is None) or interaction.user.id == self.permitted_id
|
||||
|
||||
async def on_timeout(self):
|
||||
# Propagate timeout to result Future
|
||||
self._result.set_exception(ResponseTimedOut)
|
||||
await self.cleanup()
|
||||
|
||||
async def cleanup(self):
|
||||
"""
|
||||
Cleanup the confirmation prompt by deleting it, if possible.
|
||||
Ignores any Discord errors that occur during the process.
|
||||
"""
|
||||
try:
|
||||
if self._is_followup and self._message:
|
||||
await self._message.delete()
|
||||
elif not self._is_followup and self._original and not self._original.is_expired():
|
||||
await self._original.delete_original_response()
|
||||
except discord.HTTPException:
|
||||
# A user probably already deleted the message
|
||||
# Anything could have happened, just ignore.
|
||||
pass
|
||||
|
||||
@button(label="Confirm")
|
||||
async def confirm_button(self, interaction: discord.Interaction, press):
|
||||
if self.defer:
|
||||
await interaction.response.defer()
|
||||
self._result.set_result(True)
|
||||
await self.close()
|
||||
|
||||
@button(label="Cancel")
|
||||
async def cancel_button(self, interaction: discord.Interaction, press):
|
||||
if self.defer:
|
||||
await interaction.response.defer()
|
||||
self._result.set_result(False)
|
||||
await self.close()
|
||||
|
||||
@property
|
||||
def embed(self):
|
||||
"""
|
||||
Confirmation embed shown to the user.
|
||||
This is cached, and may be modifed directly through the usual EmbedProxy API,
|
||||
or explicitly overwritten.
|
||||
"""
|
||||
if self._embed is None:
|
||||
self._embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=self.confirm_msg
|
||||
)
|
||||
return self._embed
|
||||
|
||||
@embed.setter
|
||||
def embed(self, value):
|
||||
self._embed = value
|
||||
|
||||
async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs):
|
||||
"""
|
||||
Send this confirmation prompt in response to the provided interaction.
|
||||
Extra keyword arguments are passed to `Interaction.response.send_message`
|
||||
or `Interaction.send_followup`, depending on whether
|
||||
the provided interaction has already been responded to.
|
||||
|
||||
The `epehemeral` argument is handled specially,
|
||||
since the question message can only be deleted through `Interaction.delete_original_response`.
|
||||
|
||||
Waits on and returns the internal `result` Future.
|
||||
|
||||
Returns: bool
|
||||
True if the user pressed the confirm button.
|
||||
False if the user pressed the cancel button.
|
||||
Raises:
|
||||
ResponseTimedOut:
|
||||
If the user does not respond before the UI times out.
|
||||
"""
|
||||
self._original = interaction
|
||||
if interaction.response.is_done():
|
||||
# Interaction already responded to, send a follow up
|
||||
if ephemeral:
|
||||
raise ValueError("Cannot send an ephemeral response to a used interaction.")
|
||||
self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self)
|
||||
self._is_followup = True
|
||||
else:
|
||||
await interaction.response.send_message(
|
||||
embed=self.embed, ephemeral=ephemeral, **kwargs, view=self
|
||||
)
|
||||
self._is_followup = False
|
||||
return await self._result
|
||||
456
src/utils/ui/pagers.py
Normal file
456
src/utils/ui/pagers.py
Normal file
@@ -0,0 +1,456 @@
|
||||
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
import discord
|
||||
from discord.ui.button import Button, button
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta.logger import log_action_stack, logging_context
|
||||
from meta.errors import SafeCancellation
|
||||
from meta.config import conf
|
||||
|
||||
from babel.translator import ctx_translator
|
||||
|
||||
from ..lib import MessageArgs, error_embed
|
||||
from .. import util_babel
|
||||
|
||||
from .leo import LeoUI
|
||||
|
||||
_p = util_babel._p
|
||||
|
||||
|
||||
__all__ = (
|
||||
'BasePager',
|
||||
'Pager',
|
||||
)
|
||||
|
||||
|
||||
class BasePager(LeoUI):
|
||||
"""
|
||||
An ABC describing the common interface for a Paging UI.
|
||||
|
||||
A paging UI represents a sequence of pages, accessible by `next` and `previous` buttons,
|
||||
and possibly by a dropdown (not implemented).
|
||||
|
||||
A `Page` is represented as a `MessageArgs` object, which is passable to `send` and `edit` methods as required.
|
||||
|
||||
Each page of a paging UI is accessed through the coroutine `get_page`.
|
||||
This allows for more complex paging schemes where the pages are expensive to compute,
|
||||
and not generally needed simultaneously.
|
||||
In general, `get_page` should cache expensive pages,
|
||||
perhaps simply with a `cached` decorator, but this is not enforced.
|
||||
|
||||
The state of the base UI is represented as the current `page_num` and the `current_page`.
|
||||
|
||||
This class also maintains an `active_pagers` cache,
|
||||
representing all `BasePager`s that are currently running.
|
||||
This allows access from external page controlling utilities, e.g. the `/page` command.
|
||||
"""
|
||||
# List of valid keys indicating movement to the next page
|
||||
next_list = _p('cmd:page|pager:Pager|options:next', "n, nxt, next, forward, +")
|
||||
|
||||
# List of valid keys indicating movement to the previous page
|
||||
prev_list = _p('cmd:page|pager:Pager|options:prev', "p, prev, back, -")
|
||||
|
||||
# List of valid keys indicating movement to the first page
|
||||
first_list = _p('cmd:page|pager:Pager|options:first', "f, first, one, start")
|
||||
|
||||
# List of valid keys indicating movement to the last page
|
||||
last_list = _p('cmd:page|pager:Pager|options:last', "l, last, end")
|
||||
# channelid -> pager.id -> list of active pagers in this channel
|
||||
active_pagers: dict[int, dict[int, 'BasePager']] = defaultdict(dict)
|
||||
|
||||
page_num: int
|
||||
current_page: MessageArgs
|
||||
_channelid: Optional[int]
|
||||
|
||||
@classmethod
|
||||
def get_active_pager(self, channelid, userid):
|
||||
"""
|
||||
Get the last active pager in the `destinationid`, which may be accessed by `userid`.
|
||||
Returns None if there are no matching pagers.
|
||||
"""
|
||||
for pager in reversed(self.active_pagers[channelid].values()):
|
||||
if pager.access_check(userid):
|
||||
return pager
|
||||
|
||||
def set_active(self):
|
||||
if self._channelid is None:
|
||||
raise ValueError("Cannot set active without a channelid.")
|
||||
self.active_pagers[self._channelid][self.id] = self
|
||||
|
||||
def set_inactive(self):
|
||||
self.active_pagers[self._channelid].pop(self.id, None)
|
||||
|
||||
def access_check(self, userid):
|
||||
"""
|
||||
Check whether the given userid is allowed to use this UI.
|
||||
Must be overridden by subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_page(self, page_id) -> MessageArgs:
|
||||
"""
|
||||
`get_page` returns the specified page number, starting from 0.
|
||||
An implementation of `get_page` must:
|
||||
- Always return a page (if no data is a valid state, must return a placeholder page).
|
||||
- Always accept out-of-range `page_id` values.
|
||||
- There is no behaviour specified for these, although they will usually be modded into the correct
|
||||
range.
|
||||
- In some cases (e.g. stream data where we don't have a last page),
|
||||
they may simply return the last correct page instead.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def page_cmd(self, interaction: discord.Interaction, value: str):
|
||||
"""
|
||||
Command implementation for the paging command.
|
||||
Pager subclasses should override this if they use `active_pagers`.
|
||||
Default implementation is essentially a no-op,
|
||||
simply replying to the interaction.
|
||||
"""
|
||||
await interaction.response.defer()
|
||||
return
|
||||
|
||||
async def page_acmpl(self, interaction: discord.Interaction, partial: str):
|
||||
"""
|
||||
Command autocompletion for the paging command.
|
||||
Pager subclasses should override this if they use `active_pagers`.
|
||||
"""
|
||||
return []
|
||||
|
||||
@button(emoji=conf.emojis.getemoji('forward'))
|
||||
async def next_page_button(self, interaction: discord.Interaction, press):
|
||||
await interaction.response.defer()
|
||||
self.page_num += 1
|
||||
await self.redraw()
|
||||
|
||||
@button(emoji=conf.emojis.getemoji('backward'))
|
||||
async def prev_page_button(self, interaction: discord.Interaction, press):
|
||||
await interaction.response.defer()
|
||||
self.page_num -= 1
|
||||
await self.redraw()
|
||||
|
||||
async def refresh(self):
|
||||
"""
|
||||
Recalculate current computed state.
|
||||
(E.g. fetch current page, set layout, disable components, etc.)
|
||||
"""
|
||||
self.current_page = await self.get_page(self.page_num)
|
||||
|
||||
async def redraw(self):
|
||||
"""
|
||||
This should refresh the current state and redraw the UI.
|
||||
Not implemented here, as the implementation depends on whether this is a reaction response ephemeral UI
|
||||
or a message=based one.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Pager(BasePager):
|
||||
"""
|
||||
MicroUI to display a sequence of static pages,
|
||||
supporting paging reaction and paging commands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pages: list[MessageArgs]
|
||||
A non-empty list of message arguments to page.
|
||||
start_from: int
|
||||
The page number to display first.
|
||||
Default: 0
|
||||
locked: bool
|
||||
Whether to only allow the author to use the paging interface.
|
||||
"""
|
||||
|
||||
def __init__(self, pages: list[MessageArgs],
|
||||
start_from=0,
|
||||
show_cancel=False, delete_on_cancel=True, delete_after=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._pages = pages
|
||||
self.page_num = start_from
|
||||
self.current_page = pages[self.page_num]
|
||||
|
||||
self._locked = True
|
||||
self._ownerid: Optional[int] = None
|
||||
self._channelid: Optional[int] = None
|
||||
|
||||
if not pages:
|
||||
raise ValueError("Cannot run Pager with no pages.")
|
||||
|
||||
self._original: Optional[discord.Interaction] = None
|
||||
self._is_followup: bool = False
|
||||
self._message: Optional[discord.Message] = None
|
||||
|
||||
self.show_cancel = show_cancel
|
||||
self._delete_on_cancel = delete_on_cancel
|
||||
self._delete_after = delete_after
|
||||
|
||||
@property
|
||||
def ownerid(self):
|
||||
if self._ownerid is not None:
|
||||
return self._ownerid
|
||||
elif self._original:
|
||||
return self._original.user.id
|
||||
else:
|
||||
return None
|
||||
|
||||
def access_check(self, userid):
|
||||
return not self._locked or (userid == self.ownerid)
|
||||
|
||||
async def interaction_check(self, interaction: discord.Interaction):
|
||||
return self.access_check(interaction.user.id)
|
||||
|
||||
@button(emoji=conf.emojis.getemoji('cancel'))
|
||||
async def cancel_button(self, interaction: discord.Interaction, press: Button):
|
||||
await interaction.response.defer()
|
||||
if self._delete_on_cancel:
|
||||
self._delete_after = True
|
||||
await self.close()
|
||||
|
||||
async def cleanup(self):
|
||||
self.set_inactive()
|
||||
|
||||
# If we still have a message, delete it or clear the view
|
||||
try:
|
||||
if self._is_followup:
|
||||
if self._message:
|
||||
if self._delete_after:
|
||||
await self._message.delete()
|
||||
else:
|
||||
await self._message.edit(view=None)
|
||||
else:
|
||||
if self._original and not self._original.is_expired():
|
||||
if self._delete_after:
|
||||
await self._original.delete_original_response()
|
||||
else:
|
||||
await self._original.edit_original_response(view=None)
|
||||
except discord.HTTPException:
|
||||
# Nothing we can do here
|
||||
pass
|
||||
|
||||
async def get_page(self, page_id):
|
||||
page_id %= len(self._pages)
|
||||
return self._pages[page_id]
|
||||
|
||||
def page_count(self):
|
||||
return len(self.pages)
|
||||
|
||||
async def page_cmd(self, interaction: discord.Interaction, value: str):
|
||||
"""
|
||||
`/page` command for the `Pager` MicroUI.
|
||||
"""
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
t = ctx_translator.get().t
|
||||
nexts = {word.strip() for word in t(self.next_list).split(',')}
|
||||
prevs = {word.strip() for word in t(self.prev_list).split(',')}
|
||||
firsts = {word.strip() for word in t(self.first_list).split(',')}
|
||||
lasts = {word.strip() for word in t(self.last_list).split(',')}
|
||||
|
||||
if value:
|
||||
value = value.lower().strip()
|
||||
if value.isdigit():
|
||||
# Assume value is page number
|
||||
self.page_num = int(value) - 1
|
||||
if self.page_num == -1:
|
||||
self.page_num = 0
|
||||
elif value in firsts:
|
||||
self.page_num = 0
|
||||
elif value in nexts:
|
||||
self.page_num += 1
|
||||
elif value in prevs:
|
||||
self.page_num -= 1
|
||||
elif value in lasts:
|
||||
self.page_num = -1
|
||||
elif value.startswith('-') and value[1:].isdigit():
|
||||
self.page_num = - int(value[1:])
|
||||
else:
|
||||
await interaction.edit_original_response(
|
||||
embed=error_embed(
|
||||
t(_p(
|
||||
'cmd:page|pager:Pager|error:parse',
|
||||
"Could not understand page specification `{value}`."
|
||||
)).format(value=value)
|
||||
)
|
||||
)
|
||||
return
|
||||
await interaction.delete_original_response()
|
||||
await self.redraw()
|
||||
|
||||
async def page_acmpl(self, interaction: discord.Interaction, partial: str):
|
||||
"""
|
||||
`/page` command autocompletion for the `Pager` MicroUI.
|
||||
"""
|
||||
t = ctx_translator.get().t
|
||||
nexts = {word.strip() for word in t(self.next_list).split(',')}
|
||||
prevs = {word.strip() for word in t(self.prev_list).split(',')}
|
||||
firsts = {word.strip() for word in t(self.first_list).split(',')}
|
||||
lasts = {word.strip() for word in t(self.last_list).split(',')}
|
||||
|
||||
total = len(self._pages)
|
||||
num = self.page_num
|
||||
page_choices: dict[int, str] = {}
|
||||
|
||||
# TODO: Support page names and hints?
|
||||
|
||||
if len(self._pages) > 10:
|
||||
# First add the general choices
|
||||
if num < total-1:
|
||||
page_choices[total-1] = t(_p(
|
||||
'cmd:page|acmpl|pager:Pager|choice:last',
|
||||
"Last: Page {page}/{total}"
|
||||
)).format(page=total, total=total)
|
||||
|
||||
page_choices[num] = t(_p(
|
||||
'cmd:page|acmpl|pager:Pager|choice:current',
|
||||
"Current: Page {page}/{total}"
|
||||
)).format(page=num+1, total=total)
|
||||
choices = [
|
||||
appcmds.Choice(name=string, value=str(num+1))
|
||||
for num, string in sorted(page_choices.items(), key=lambda t: t[0])
|
||||
]
|
||||
else:
|
||||
# Particularly support page names here
|
||||
choices = [
|
||||
appcmds.Choice(
|
||||
name='> ' * (i == num) + t(_p(
|
||||
'cmd:page|acmpl|pager:Pager|choice:general',
|
||||
"Page {page}"
|
||||
)).format(page=i+1),
|
||||
value=str(i+1)
|
||||
)
|
||||
for i in range(0, total)
|
||||
]
|
||||
|
||||
partial = partial.strip()
|
||||
|
||||
if partial:
|
||||
value = partial.lower().strip()
|
||||
if value.isdigit():
|
||||
# Assume value is page number
|
||||
page_num = int(value) - 1
|
||||
if page_num == -1:
|
||||
page_num = 0
|
||||
elif value in firsts:
|
||||
page_num = 0
|
||||
elif value in nexts:
|
||||
page_num = self.page_num + 1
|
||||
elif value in prevs:
|
||||
page_num = self.page_num - 1
|
||||
elif value in lasts:
|
||||
page_num = -1
|
||||
elif value.startswith('-') and value[1:].isdigit():
|
||||
page_num = - int(value[1:])
|
||||
else:
|
||||
page_num = None
|
||||
|
||||
if page_num is not None:
|
||||
page_num %= total
|
||||
choice = appcmds.Choice(
|
||||
name=t(_p(
|
||||
'cmd:page|acmpl|pager:Page|choice:select',
|
||||
"Selected: Page {page}/{total}"
|
||||
)).format(page=page_num+1, total=total),
|
||||
value=str(page_num + 1)
|
||||
)
|
||||
return [choice, *choices]
|
||||
else:
|
||||
return [
|
||||
appcmds.Choice(
|
||||
name=t(_p(
|
||||
'cmd:page|acmpl|pager:Page|error:parse',
|
||||
"No matching pages!"
|
||||
)).format(page=page_num, total=total),
|
||||
value=partial
|
||||
)
|
||||
]
|
||||
else:
|
||||
return choices
|
||||
|
||||
@property
|
||||
def page_row(self):
|
||||
if self.show_cancel:
|
||||
if len(self._pages) > 1:
|
||||
return (self.prev_page_button, self.cancel_button, self.next_page_button)
|
||||
else:
|
||||
return (self.cancel_button,)
|
||||
else:
|
||||
if len(self._pages) > 1:
|
||||
return (self.prev_page_button, self.next_page_button)
|
||||
else:
|
||||
return ()
|
||||
|
||||
async def refresh(self):
|
||||
await super().refresh()
|
||||
self.set_layout(self.page_row)
|
||||
|
||||
async def redraw(self):
|
||||
await self.refresh()
|
||||
|
||||
if not self._original:
|
||||
raise ValueError("Running run pager manually without interaction.")
|
||||
|
||||
try:
|
||||
if self._message:
|
||||
await self._message.edit(**self.current_page.edit_args, view=self)
|
||||
else:
|
||||
if self._original.is_expired():
|
||||
raise SafeCancellation("This interface has expired, please try again.")
|
||||
await self._original.edit_original_response(**self.current_page.edit_args, view=self)
|
||||
except discord.HTTPException:
|
||||
raise SafeCancellation("Could not page your results! Please try again.")
|
||||
|
||||
async def run(self, interaction: discord.Interaction, ephemeral=False, locked=True, ownerid=None, **kwargs):
|
||||
"""
|
||||
Display the UI.
|
||||
Attempts to reply to the interaction if it has not already been replied to,
|
||||
otherwise send a follow-up.
|
||||
|
||||
An ephemeral response must be sent as an initial interaction response.
|
||||
On the other hand, a semi-persistent response (expected to last longer than the lifetime of the interaction)
|
||||
must be sent as a followup.
|
||||
|
||||
Extra kwargs are combined with the first page arguments and given to the relevant send method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
interaction: discord.Interaction
|
||||
The interaction to send the pager in response to.
|
||||
ephemeral: bool
|
||||
Whether to send the interaction ephemerally.
|
||||
If this is true, the interaction *must* be fresh (i.e. no response done).
|
||||
Default: False
|
||||
locked: bool
|
||||
Whether this interface is locked to the user `self.ownerid`.
|
||||
Irrelevant for ephemeral messages.
|
||||
Use `ownerid` to override the default owner id.
|
||||
Defaults to true for fail-safety.
|
||||
Default: True
|
||||
ownerid: Optional[int]
|
||||
The userid allowed to use this interaction.
|
||||
By default, this will be the `interaction.user.id`,
|
||||
presuming that this is the user which originally triggered this message.
|
||||
An override may be useful if a user triggers a paging UI for someone else.
|
||||
"""
|
||||
if not interaction.channel_id:
|
||||
raise ValueError("Cannot run pager on a channelless interaction.")
|
||||
|
||||
self._original = interaction
|
||||
self._ownerid = ownerid
|
||||
self._locked = locked
|
||||
self._channelid = interaction.channel_id
|
||||
|
||||
await self.refresh()
|
||||
args = self.current_page.send_args | kwargs
|
||||
|
||||
if interaction.response.is_done():
|
||||
if ephemeral:
|
||||
raise ValueError("Ephemeral response requires fres interaction.")
|
||||
self._message = await interaction.followup.send(**args, view=self)
|
||||
self._is_followup = True
|
||||
else:
|
||||
self._is_followup = False
|
||||
await interaction.response.send_message(**args, view=self)
|
||||
|
||||
self.set_active()
|
||||
91
src/utils/ui/transformed.py
Normal file
91
src/utils/ui/transformed.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Any, Type, TYPE_CHECKING
|
||||
from enum import Enum
|
||||
|
||||
import discord
|
||||
import discord.app_commands as appcmd
|
||||
from discord.app_commands.transformers import AppCommandOptionType
|
||||
|
||||
|
||||
__all__ = (
|
||||
'ChoicedEnum',
|
||||
'ChoicedEnumTransformer',
|
||||
'Transformed',
|
||||
)
|
||||
|
||||
|
||||
class ChoicedEnum(Enum):
|
||||
@property
|
||||
def choice_name(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def choice_value(self):
|
||||
return self.value
|
||||
|
||||
@property
|
||||
def choice(self):
|
||||
return appcmd.Choice(
|
||||
name=self.choice_name, value=self.choice_value
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def choices(self):
|
||||
return [item.choice for item in self]
|
||||
|
||||
@classmethod
|
||||
def make_choice_map(cls):
|
||||
return {item.choice_value: item for item in cls}
|
||||
|
||||
@classmethod
|
||||
async def transform(cls, transformer: 'ChoicedEnumTransformer', interaction: discord.Interaction, value: Any):
|
||||
return transformer._choice_map[value]
|
||||
|
||||
@classmethod
|
||||
def option_type(cls) -> AppCommandOptionType:
|
||||
return AppCommandOptionType.string
|
||||
|
||||
@classmethod
|
||||
def transformer(cls, *args) -> appcmd.Transformer:
|
||||
return ChoicedEnumTransformer(cls, *args)
|
||||
|
||||
|
||||
class ChoicedEnumTransformer(appcmd.Transformer):
|
||||
# __discord_app_commands_is_choice__ = True
|
||||
|
||||
def __init__(self, enum: Type[ChoicedEnum], opt_type) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._type = opt_type
|
||||
self._enum = enum
|
||||
self._choices = enum.choices()
|
||||
self._choice_map = enum.make_choice_map()
|
||||
|
||||
@property
|
||||
def _error_display_name(self) -> str:
|
||||
return self._enum.__name__
|
||||
|
||||
@property
|
||||
def type(self) -> AppCommandOptionType:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def choices(self):
|
||||
return self._choices
|
||||
|
||||
async def transform(self, interaction: discord.Interaction, value: Any, /) -> Any:
|
||||
return await self._enum.transform(self, interaction, value)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Annotated as Transformed
|
||||
else:
|
||||
|
||||
class Transformed:
|
||||
def __class_getitem__(self, items):
|
||||
cls = items[0]
|
||||
options = items[1:]
|
||||
|
||||
if not hasattr(cls, 'transformer'):
|
||||
raise ValueError("Tranformed class must have a transformer classmethod.")
|
||||
transformer = cls.transformer(*options)
|
||||
return appcmd.Transform[cls, transformer]
|
||||
Reference in New Issue
Block a user