rewrite: Restructure to include GUI.

This commit is contained in:
2022-12-23 06:44:32 +02:00
parent 2b93354248
commit f328324747
224 changed files with 8 additions and 0 deletions

8
src/utils/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
from babel.translator import LocalBabel
util_babel = LocalBabel('utils')
async def setup(bot):
from .cog import MetaUtils
await bot.add_cog(MetaUtils(bot))

102
src/utils/cog.py Normal file
View File

@@ -0,0 +1,102 @@
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionContext, LionCog
from .ui import BasePager
from . import util_babel as babel
_p = babel._p
class MetaUtils(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
@cmds.hybrid_group(
name=_p('cmd:page', 'page'),
description=_p(
'cmd:page|desc',
"Jump to a given page of the ouput of a previous command in this channel."
),
)
async def page_group(self, ctx: LionContext):
"""
No description.
"""
pass
async def page_jump(self, ctx: LionContext, jumper):
pager = BasePager.get_active_pager(ctx.channel.id, ctx.author.id)
if pager is None:
await ctx.error_reply(
_p('cmd:page|error:no_pager', "No pager listening in this channel!")
)
else:
if ctx.interaction:
await ctx.interaction.response.defer()
pager.page_num = jumper(pager)
await pager.redraw()
if ctx.interaction:
await ctx.interaction.delete_original_response()
@page_group.command(
name=_p('cmd:page_next', 'next'),
description=_p('cmd:page_next|desc', "Jump to the next page of output.")
)
async def next_cmd(self, ctx: LionContext):
await self.page_jump(ctx, lambda pager: pager.page_num + 1)
@page_group.command(
name=_p('cmd:page_prev', 'prev'),
description=_p('cmd:page_prev|desc', "Jump to the previous page of output.")
)
async def prev_cmd(self, ctx: LionContext):
await self.page_jump(ctx, lambda pager: pager.page_num - 1)
@page_group.command(
name=_p('cmd:page_first', 'first'),
description=_p('cmd:page_first|desc', "Jump to the first page of output.")
)
async def first_cmd(self, ctx: LionContext):
await self.page_jump(ctx, lambda pager: 0)
@page_group.command(
name=_p('cmd:page_last', 'last'),
description=_p('cmd:page_last|desc', "Jump to the last page of output.")
)
async def last_cmd(self, ctx: LionContext):
await self.page_jump(ctx, lambda pager: -1)
@page_group.command(
name=_p('cmd:page_select', 'select'),
description=_p('cmd:page_select|desc', "Select a page of the output to jump to.")
)
@appcmds.rename(
page=_p('cmd:page_select|param:page', 'page')
)
@appcmds.describe(
page=_p('cmd:page_select|param:page|desc', "The page name or number to jump to.")
)
async def page_cmd(self, ctx: LionContext, page: str):
pager = BasePager.get_active_pager(ctx.channel.id, ctx.author.id)
if pager is None:
await ctx.error_reply(
_p('cmd:page_select|error:no_pager', "No pager listening in this channel!")
)
else:
await pager.page_cmd(ctx.interaction, page)
@page_cmd.autocomplete('page')
async def page_acmpl(self, interaction: discord.Interaction, partial: str):
pager = BasePager.get_active_pager(interaction.channel_id, interaction.user.id)
if pager is None:
return [
appcmds.Choice(
name=_p('cmd:page_select|acmpl|error:no_pager', "No active pagers in this channel!"),
value=partial
)
]
else:
return await pager.page_acmpl(interaction, partial)

708
src/utils/lib.py Normal file
View File

@@ -0,0 +1,708 @@
from typing import NamedTuple, Optional, Sequence, Union, overload, List
import datetime
import iso8601 # type: ignore
import re
from contextvars import Context
import discord
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
from discord.ui import View
from babel.translator import ctx_translator
from . import util_babel
_, _p = util_babel._, util_babel._p
multiselect_regex = re.compile(
r"^([0-9, -]+)$",
re.DOTALL | re.IGNORECASE | re.VERBOSE
)
tick = ''
cross = ''
class MessageArgs:
"""
Utility class for storing message creation and editing arguments.
"""
# TODO: Overrides for mutually exclusive arguments, see Messageable.send
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
def __init__(self, **kwargs):
self.kwargs = kwargs
@property
def send_args(self) -> dict:
return self.kwargs
@property
def edit_args(self) -> dict:
args = {}
kept = (
'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view'
)
for k in kept:
if k in self.kwargs:
args[k] = self.kwargs[k]
if 'file' in self.kwargs:
args['attachments'] = [self.kwargs['file']]
if 'files' in self.kwargs:
args['attachments'] = self.kwargs['files']
if 'suppress_embeds' in self.kwargs:
args['suppress'] = self.kwargs['suppress_embeds']
return args
def tabulate(
*fields: tuple[str, str],
row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}",
sub_format: str = "`{invis:<{pad}}{invis}`\t{value}",
colon: str = ':',
invis: str = "",
**args
) -> list[str]:
"""
Turns a list of (property, value) pairs into
a pretty string with one `prop: value` pair each line,
padded so that the colons in each line are lined up.
Use `\\r\\n` in a value to break the line with padding.
Parameters
----------
fields: List[tuple[str, str]]
List of (key, value) pairs.
row_format: str
The format string used to format each row.
sub_format: str
The format string used to format each subline in a row.
colon: str
The colon character used.
invis: str
The invisible character used (to avoid Discord stripping the string).
Returns: List[str]
The list of resulting table rows.
Each row corresponds to one (key, value) pair from fields.
"""
max_len = max(len(field[0]) for field in fields)
rows = []
for field in fields:
key = field[0]
value = field[1]
lines = value.split('\r\n')
row_line = row_format.format(
invis=invis,
key=key,
pad=max_len,
colon=colon,
value=lines[0],
field=field,
**args
)
if len(lines) > 1:
row_lines = [row_line]
for line in lines[1:]:
sub_line = sub_format.format(
invis=invis,
pad=max_len + len(colon),
value=line,
**args
)
row_lines.append(sub_line)
row_line = '\n'.join(row_lines)
rows.append(row_line)
return rows
def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]:
"""
Create pretty codeblock pages from a list of strings.
Parameters
----------
item_list: List[str]
List of strings to paginate.
block_length: int
Maximum number of strings per page.
style: str
Codeblock style to use.
Title formatting assumes the `markdown` style, and numbered lists work well with this.
However, `markdown` sometimes messes up formatting in the list.
title: str
Optional title to add to the top of each page.
Returns: List[str]
List of pages, each formatted into a codeblock,
and containing at most `block_length` of the provided strings.
"""
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
pages = []
for i, block in enumerate(page_blocks):
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
if title:
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
else:
header = pagenum
header_line = "=" * len(header)
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
return pages
def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]:
"""
Break the text into blocks of maximum length blocksize
If possible, break across nearby newlines. Otherwise just break at blocksize chars
Parameters
----------
text: str
Text to break into blocks.
blocksize: int
Maximum character length for each block.
code: bool
Whether to wrap each block in codeblocks (these are counted in the blocksize).
syntax: str
The markdown formatting language to use for the codeblocks, if applicable.
maxheight: int
The maximum number of lines in each block
Returns: List[str]
List of blocks,
each containing at most `block_size` characters,
of height at most `maxheight`.
"""
# Adjust blocksize to account for the codeblocks if required
blocksize = blocksize - 8 - len(syntax) if code else blocksize
# Build the blocks
blocks = []
while True:
# If the remaining text is already small enough, append it
if len(text) <= blocksize:
blocks.append(text)
break
text = text.strip('\n')
# Find the last newline in the prototype block
split_on = text[0:blocksize].rfind('\n')
split_on = blocksize if split_on < blocksize // 5 else split_on
# Add the block and truncate the text
blocks.append(text[0:split_on])
text = text[split_on:]
# Add the codeblock ticks and the code syntax header, if required
if code:
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
return blocks
def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str:
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's' # type: ignore
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def _parse_dur(time_str: str) -> int:
"""
Parses a user provided time duration string into a timedelta object.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
def strfdur(duration: int, short=True, show_days=False) -> str:
"""
Convert a duration given in seconds to a number of hours, minutes, and seconds.
"""
days = duration // (3600 * 24) if show_days else 0
hours = duration // 3600
if days:
hours %= 24
minutes = duration // 60 % 60
seconds = duration % 60
parts = []
if days:
unit = 'd' if short else (' days' if days != 1 else ' day')
parts.append('{}{}'.format(days, unit))
if hours:
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
parts.append('{}{}'.format(hours, unit))
if minutes:
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
parts.append('{}{}'.format(minutes, unit))
if seconds or duration == 0:
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
parts.append('{}{}'.format(seconds, unit))
if short:
return ' '.join(parts)
else:
return ', '.join(parts)
def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str:
"""
Substitutes a user provided list of numbers and ranges,
and replaces the ranges by the corresponding list of numbers.
Parameters
----------
ranges_str: str
The string to ranges in.
max_match: int
The maximum number of ranges to replace.
Any ranges exceeding this will be ignored.
max_range: int
The maximum length of range to replace.
Attempting to replace a range longer than this will raise a `ValueError`.
"""
def _repl(match):
n1 = int(match.group(1))
n2 = int(match.group(2))
if n2 - n1 > max_range:
# TODO: Upgrade to SafeCancellation
raise ValueError("Provided range is too large!")
return separator.join(str(i) for i in range(n1, n2 + 1))
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]:
"""
Parses a user provided range string into a list of numbers.
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
"""
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
_numbers = (item.strip() for item in substituted.split(','))
numbers = [item for item in _numbers if item]
integers = [int(item) for item in numbers if item.isdigit()]
if not ignore_errors and len(integers) != len(numbers):
# TODO: Upgrade to SafeCancellation
raise ValueError(
"Couldn't parse the provided selection!\n"
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
)
return integers
def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str:
"""
Format a message into a string with various information, such as:
the timestamp of the message, author, message content, and attachments.
Parameters
----------
msg: Message
The message to format.
mask_link: bool
Whether to mask the URLs of any attachments.
line_break: bool
Whether a line break should be used in the string.
tz: Timezone
The timezone to use in the formatted message.
clean: bool
Whether to use the clean content of the original message.
Returns: str
A formatted string containing various information:
User timezone, message author, message content, attachments
"""
timestr = "%I:%M %p, %d/%m/%Y"
if tz:
time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr)
else:
time = msg.created_at.strftime(timestr)
user = str(msg.author)
attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url]
if mask_link:
attach_list = ["[Link]({})".format(url) for url in attach_list]
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
time=time,
user=user,
line_break="\n" if line_break else "",
message=msg.clean_content if clean else msg.content,
attachments=attachments
)
def convdatestring(datestring: str) -> datetime.timedelta:
"""
Convert a date string into a datetime.timedelta object.
Parameters
----------
datestring: str
The string to convert to a datetime.timedelta object.
Returns: datetime.timedelta
A datetime.timedelta object formed from the string provided.
"""
datestring = datestring.strip(' ,')
datearray = []
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
currentnumber = ''
for char in datestring:
if char.isdigit():
currentnumber += char
else:
if currentnumber == '':
continue
datearray.append((int(currentnumber), char))
currentnumber = ''
seconds = 0
if currentnumber:
seconds += int(currentnumber)
for i in datearray:
if i[1] in funcs:
seconds += funcs[i[1]](i[0])
return datetime.timedelta(seconds=seconds)
class _rawChannel(discord.abc.Messageable):
"""
Raw messageable class representing an arbitrary channel,
not necessarially seen by the gateway.
"""
def __init__(self, state, id):
self._state = state
self.id = id
async def _get_channel(self):
return discord.Object(self.id)
async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message:
"""
Mails a message to a channelid which may be invisible to the gateway.
Parameters:
client: discord.Client
The client to use for mailing.
Must at least have static authentication and have a valid `_connection`.
channelid: int
The channel id to mail to.
msg_args: Any
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
"""
# Create the raw channel
channel = _rawChannel(client._connection, channelid)
return await channel.send(**msg_args)
class EmbedField(NamedTuple):
name: str
value: str
inline: Optional[bool] = True
def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]):
"""
Append embed fields to an embed.
Parameters
----------
embed: discord.Embed
The embed to add the field to.
emb_fields: tuple
The values to add to a field.
name: str
The name of the field.
value: str
The value of the field.
inline: bool
Whether the embed field should be inline or not.
"""
for field in emb_fields:
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
def join_list(string: list[str], nfs=False) -> str:
"""
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
Parameters
----------
string: list
The list to join together.
nfs: bool
(no fullstops)
Whether to exclude fullstops/periods from the output messages.
If not provided, fullstops will be appended to the output.
"""
# TODO: Probably not useful with localisation
if len(string) > 1:
return "{}{} and {}{}".format((", ").join(string[:-1]),
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
else:
return "{}{}".format("".join(string), "" if nfs else ".")
def shard_of(shard_count: int, guildid: int) -> int:
"""
Calculate the shard number of a given guild.
"""
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
def jumpto(guildid: int, channeldid: int, messageid: int) -> str:
"""
Build a jump link for a message given its location.
"""
return 'https://discord.com/channels/{}/{}/{}'.format(
guildid,
channeldid,
messageid
)
def utc_now() -> datetime.datetime:
"""
Return the current timezone-aware utc timestamp.
"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
def multiple_replace(string: str, rep_dict: dict[str, str]) -> str:
if rep_dict:
pattern = re.compile(
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
flags=re.DOTALL
)
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
else:
return string
def recover_context(context: Context):
for var in context:
var.set(context[var])
def parse_ids(idstr: str) -> List[int]:
"""
Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids.
Object agnostic, so all mention tokens are stripped.
Raises UserInputError if an id is invalid,
setting `orig` and `item` info fields.
"""
from meta.errors import UserInputError
# Extract ids from string
splititer = (split.strip('<@!#&>, ') for split in idstr.split(','))
splits = [split for split in splititer if split]
# Check they are integers
if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None:
raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id})
# Cast to integer and return
return list(map(int, splits))
def error_embed(error, **kwargs) -> discord.Embed:
embed = discord.Embed(
colour=discord.Colour.brand_red(),
description=error,
timestamp=utc_now()
)
return embed
class DotDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
parse_dur_exps = [
(
_p(
'util:parse_dur|regex:day',
r"(?P<value>\d+)\s*(?:(d)|(day))"
),
60 * 60 * 24
),
(
_p(
'util:parse_dur|regex:hour',
r"(?P<value>\d+)\s*(?:(h)|(hour))"
),
60 * 60
),
(
_p(
'util:parse_dur|regex:minute',
r"(?P<value>\d+)\s*(?:(m)|(min))"
),
60
),
(
_p(
'util:parse_dur|regex:second',
r"(?P<value>\d+)\s*(?:(s)|(sec))"
),
1
)
]
def parse_duration(string: str) -> Optional[int]:
translator = ctx_translator.get()
if translator is None:
raise ValueError("Cannot parse duration without a translator.")
t = translator.t
seconds = 0
found = False
for expr, multiplier in parse_dur_exps:
match = re.search(t(expr), string, flags=re.IGNORECASE)
if match:
found = True
seconds += int(match.group('value')) * multiplier
return seconds if found else None

178
src/utils/monitor.py Normal file
View File

@@ -0,0 +1,178 @@
import asyncio
import bisect
import logging
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
from .lib import utc_now
from .ratelimits import Bucket
logger = logging.getLogger(__name__)
Taskid = TypeVar('Taskid')
class TaskMonitor(Generic[Taskid]):
"""
Base class for a task monitor.
Stores tasks as a time-sorted list of taskids.
Subclasses may override `run_task` to implement an executor.
Adding or removing a single task has O(n) performance.
To bulk update tasks, instead use `schedule_tasks`.
Each taskid must be unique and hashable.
"""
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
# Ratelimit bucket to enforce maximum execution rate
self._bucket = bucket
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
self._wakeup: asyncio.Event = asyncio.Event()
self._monitor_task: Optional[self.Task] = None
# Task data
self._tasklist: list[Taskid] = []
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
# Running map ensures we keep a reference to the running task
# And allows simpler external cancellation if required
self._running: dict[Taskid, asyncio.Future] = {}
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Similar to `schedule_tasks`, but wipe and reset the tasklist.
"""
self._taskmap = {tid: time for tid, time in tasks}
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * tid * self._taskmap[tid])
self._wakeup.set()
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Schedule the given tasks.
Rather than repeatedly inserting tasks,
where the O(log n) insort is dominated by the O(n) list insertion,
we build an entirely new list, and always wake up the loop.
"""
self._taskmap |= {tid: time for tid, time in tasks}
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])
self._wakeup.set()
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
"""
Insert the provided task into the tasklist.
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
"""
if self._tasklist:
nextid = self._tasklist[-1]
wake = self._taskmap[nextid] >= timestamp
wake = wake or taskid == nextid
else:
wake = False
if taskid in self._taskmap:
self._tasklist.remove(taskid)
self._taskmap[taskid] = timestamp
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
if wake:
self._wakeup.set()
def cancel_tasks(self, *taskids: Taskid) -> None:
"""
Remove all tasks with the given taskids from the tasklist.
If the next task has this taskid, wake up the monitor loop.
"""
taskids = set(taskids)
wake = (self._tasklist and self._tasklist[-1] in taskids)
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
for tid in taskids:
self._taskmap.pop(tid, None)
if wake:
self._wakeup.set()
def start(self):
if self._monitor_task and not self._monitor_task.done():
self._monitor_task.cancel()
# Start the monitor
self._monitor_task = asyncio.create_task(self.monitor())
return self._monitor_task
async def monitor(self):
"""
Start the monitor.
Executes the tasks in `self.tasks` at the specified time.
This will shield task execution from cancellation
to avoid partial states.
"""
try:
while True:
self._wakeup.clear()
if not self._tasklist:
# No tasks left, just sleep until wakeup
await self._wakeup.wait()
else:
# Get the next task, sleep until wakeup or it is ready to run
nextid = self._tasklist[-1]
nexttime = self._taskmap[nextid]
sleep_for = nexttime - utc_now().timestamp()
try:
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
except asyncio.TimeoutError:
# Ready to run the task
self._tasklist.pop()
self._taskmap.pop(nextid, None)
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
else:
# Wakeup task fired, loop again
continue
except asyncio.CancelledError:
# Log closure and wait for remaining tasks
# A second cancellation will also cancel the tasks
logger.debug(
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
f"Waiting for {len(self._running)} running tasks to complete."
)
await asyncio.gather(*self._running.values(), return_exceptions=True)
async def _run(self, taskid: Taskid) -> None:
# Execute the task, respecting the ratelimit bucket
if self._bucket is not None:
# IMPLEMENTATION NOTE:
# Bucket.wait() should guarantee not more than n tasks/second are run
# and that a request directly afterwards will _not_ raise BucketFull
# Make sure that only one waiter is actually waiting on its sleep task
# The other waiters should be waiting on a lock around the sleep task
# Waiters are executed in wait-order, so if we only let a single waiter in
# we shouldn't get collisions.
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
# Or we will lose thread-safety for BucketFull
await self._bucket.wait()
fut = asyncio.create_task(self.run_task(taskid))
try:
await asyncio.shield(fut)
except asyncio.CancelledError:
raise
except Exception:
# Protect the monitor loop from any other exceptions
logger.exception(
f"Ignoring exception in task monitor {self.__class__.__name__} while "
f"executing <taskid: {taskid}>"
)
finally:
self._running.pop(taskid)
async def run_task(self, taskid: Taskid):
"""
Execute the task with the given taskid.
Default implementation executes `self.executor` if it exists,
otherwise raises NotImplementedError.
"""
if self.executor is not None:
await self.executor(taskid)
else:
raise NotImplementedError

125
src/utils/ratelimits.py Normal file
View File

@@ -0,0 +1,125 @@
import asyncio
import time
from meta.errors import SafeCancellation
from cachetools import TTLCache
class BucketFull(Exception):
"""
Throw when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Throw when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.monotonic()
self._last_full = False
self._wait_lock = asyncio.Lock()
@property
def full(self) -> bool:
"""
Return whether the bucket is 'full',
that is, whether an immediate request against the bucket will raise `BucketFull`.
"""
self._leak()
return self._level + 1 > self.max_level
@property
def overfull(self):
self._leak()
return self._level > self.max_level
@property
def delay(self):
self._leak()
if self._level + 1 > self.max_level:
return (self._level + 1 - self.max_level) * self.leak_rate
def _leak(self):
if self._level:
elapsed = time.monotonic() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.monotonic()
def request(self):
self._leak()
if self._level + 1 > self.max_level + 1:
raise BucketOverFull
elif self._level + 1 > self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
async def wait(self):
"""
Wait until the bucket has room.
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
"""
# Wrapped in a lock so that waiters are correctly handled in wait-order
# Otherwise multiple waiters will have the same delay,
# and race for the wakeup after sleep.
async with self._wait_lock:
# We do this in a loop in case asyncio.sleep throws us out early,
# or a synchronous request overflows the bucket while we are waiting.
while self.full:
await asyncio.sleep(self.delay)
class RateLimit:
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
self.max_level = max_level
self.empty_time = empty_time
self.error = error or "Too many requests, please slow down!"
self.buckets = cache
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
try:
bucket.request()
except BucketOverFull:
raise SafeCancellation(details="Bucket overflow")
except BucketFull:
raise SafeCancellation(self.error, details="Bucket full")
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator

77
src/utils/transformers.py Normal file
View File

@@ -0,0 +1,77 @@
import discord
from discord import app_commands as appcmds
from discord.app_commands import Transformer
from discord.enums import AppCommandOptionType
from meta.errors import UserInputError
from babel.translator import ctx_translator
from .lib import parse_duration, strfdur
from . import util_babel
_, _p = util_babel._, util_babel._p
class DurationTransformer(Transformer):
"""
Duration parameter, with included autocompletion.
"""
def __init__(self, multiplier=1):
# Multiplier used for a raw integer value
self.multiplier = multiplier
@property
def type(self):
return AppCommandOptionType.string
async def transform(self, interaction: discord.Interaction, value: str) -> int:
"""
Returns the number of seconds in the parsed duration.
Raises UserInputError if the duration cannot be parsed.
"""
translator = ctx_translator.get()
t = translator.t
if value.isdigit():
return int(value) * self.multiplier
duration = parse_duration(value)
if duration is None:
raise UserInputError(
t(_p('utils:parse_dur|error', "Cannot parse `{value}` as a duration.")).format(
value=value
)
)
return duration or 0
async def autocomplete(self, interaction: discord.Interaction, partial: str):
"""
Default autocomplete for Duration parameters.
Attempts to parse the partial value as a duration, and reformat it as an autocomplete choice.
If not possible, displays an error message.
"""
translator = ctx_translator.get()
t = translator.t
if partial.isdigit():
duration = int(partial) * self.multiplier
else:
duration = parse_duration(partial)
if duration is None:
choice = appcmds.Choice(
name=t(_p(
'util:Duration|acmpl|error',
"Cannot extract duration from \"{partial}\""
)).format(partial=partial),
value=partial
)
else:
choice = appcmds.Choice(
name=strfdur(duration, short=False, show_days=True),
value=partial
)
return [choice]

20
src/utils/ui/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
import asyncio
import logging
from .. import util_babel
logger = logging.getLogger(__name__)
from .hooked import *
from .leo import *
from .micros import *
from .pagers import *
from .transformed import *
# def create_task_in(coro, context: Context):
# """
# Transitional.
# Since py3.10 asyncio does not support context instantiation,
# this helper method runs `asyncio.create_task(coro)` inside the given context.
# """
# return context.run(asyncio.create_task, coro)

59
src/utils/ui/hooked.py Normal file
View File

@@ -0,0 +1,59 @@
import time
import discord
from discord.ui.item import Item
from discord.ui.button import Button
from .leo import LeoUI
__all__ = (
'HookedItem',
'AButton',
'AsComponents'
)
class HookedItem:
"""
Mixin for Item classes allowing an instance to be used as a callback decorator.
"""
def __init__(self, *args, pass_kwargs={}, **kwargs):
super().__init__(*args, **kwargs)
self.pass_kwargs = pass_kwargs
def __call__(self, coro):
async def wrapped(interaction, **kwargs):
return await coro(interaction, self, **(self.pass_kwargs | kwargs))
self.callback = wrapped
return self
class AButton(HookedItem, Button):
...
class AsComponents(LeoUI):
"""
Simple container class to accept a number of Items and turn them into an attachable View.
"""
def __init__(self, *items, pass_kwargs={}, **kwargs):
super().__init__(**kwargs)
self.pass_kwargs = pass_kwargs
for item in items:
self.add_item(item)
async def _scheduled_task(self, item: Item, interaction: discord.Interaction):
try:
item._refresh_state(interaction, interaction.data) # type: ignore
allow = await self.interaction_check(interaction)
if not allow:
return
if self.timeout:
self.__timeout_expiry = time.monotonic() + self.timeout
await item.callback(interaction, **self.pass_kwargs)
except Exception as e:
return await self.on_error(interaction, e, item)

247
src/utils/ui/leo.py Normal file
View File

@@ -0,0 +1,247 @@
from typing import List, Optional, Any, Dict
import asyncio
import logging
import time
from contextvars import copy_context, Context
import discord
from discord.ui import Modal, View, Item
from meta.logger import log_action_stack, logging_context
from . import logger
__all__ = (
'LeoUI',
'LeoModal',
'error_handler_for'
)
class LeoUI(View):
"""
View subclass for small-scale user interfaces.
While a 'View' provides an interface for managing a collection of components,
a `LeoUI` may also manage a message, and potentially slave Views or UIs.
The `LeoUI` also exposes more advanced cleanup and timeout methods,
and preserves the context.
"""
def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None:
super().__init__(*args, **kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._name = ui_name or self.__class__.__name__
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name])
# List of slaved views to stop when this view stops
self._slaves: List[View] = []
# TODO: Replace this with a substitutable ViewLayout class
self._layout: Optional[tuple[tuple[Item, ...], ...]] = None
def to_components(self) -> List[Dict[str, Any]]:
"""
Extending component generator to apply the set _layout, if it exists.
"""
if self._layout is not None:
# Alternative rendering using layout
components = []
for i, row in enumerate(self._layout):
# Skip empty rows
if not row:
continue
# Since we aren't relying on ViewWeights, manually check width here
if sum(item.width for item in row) > 5:
raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!")
# Create the component dict for this row
components.append({
'type': 1,
'components': [item.to_component_dict() for item in row]
})
else:
components = super().to_components()
return components
def set_layout(self, *rows: tuple[Item, ...]) -> None:
"""
Set the layout of the rendered View as a matrix of items,
or more precisely, a list of action rows.
This acts independently of the existing sorting with `_ViewWeights`,
and overrides the sorting if applied.
"""
self._layout = rows
async def cleanup(self):
"""
Coroutine to run when timeing out, stopping, or cancelling.
Generally cleans up any open resources, and removes any leftover components.
"""
logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'})
return None
def stop(self):
"""
Extends View.stop() to also stop all the slave views.
Note that stopping is idempotent, so it is okay if close() also calls stop().
"""
for slave in self._slaves:
slave.stop()
super().stop()
async def close(self, msg=None):
self.stop()
await self.cleanup()
async def pre_timeout(self):
"""
Task to execute before actually timing out.
This may cancel the timeout by refreshing or rescheduling it.
(E.g. to ask the user whether they want to keep going.)
Default implementation does nothing.
"""
return None
async def on_timeout(self):
"""
Task to execute after timeout is complete.
Default implementation calls cleanup.
"""
await self.cleanup()
async def __dispatch_timeout(self):
"""
This essentially extends View._dispatch_timeout,
to include a pre_timeout task
which may optionally refresh and hence cancel the timeout.
"""
if self.__stopped.done():
# We are already stopped, nothing to do
return
with logging_context(action='Timeout'):
try:
await self.pre_timeout()
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
pass
except Exception:
await logger.exception(
"Unhandled error caught while dispatching timeout for {self!r}.",
extra={'with_ctx': True, 'action': 'Error'}
)
# Check if we still need to timeout
if self.timeout is None:
# The timeout was removed entirely, silently walk away
return
if self.__stopped.done():
# We stopped while waiting for the pre timeout.
# Or maybe another thread timed us out
# Either way, we are done here
return
now = time.monotonic()
if self.__timeout_expiry is not None and now < self._timeout_expiry:
# The timeout was extended, make sure the timeout task is running then fade away
if self.__timeout_task is None or self.__timeout_task.done():
self.__timeout_task = asyncio.create_task(self.__timeout_task_impl())
else:
# Actually timeout, and call the post-timeout task for cleanup.
self._really_timeout()
await self.on_timeout()
def _dispatch_timeout(self):
"""
Overriding timeout method completely, to support interactive flow during timeout,
and optional refreshing of the timeout.
"""
return self._context.run(asyncio.create_task, self.dispatch_timeout())
def _really_timeout(self):
"""
Actuallly times out the View.
This copies View._dispatch_timeout, apart from the `on_timeout` dispatch,
which is now handled by `__dispatch_timeout`.
"""
if self.__stopped.done():
return
if self.__cancel_callback:
self.__cancel_callback(self)
self.__cancel_callback = None
self.__stopped.set_result(True)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item):
"""
Default LeoUI error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r}",
extra={'with_ctx': True, 'action': 'UIError'}
)
class LeoModal(Modal):
"""
Context-aware Modal class.
"""
def __init__(self, *args, context: Optional[Context] = None, **kwargs):
super().__init__(**kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__])
def _dispatch_submit(self, *args, **kwargs):
"""
Extending event dispatch to run in the instantiation context.
"""
return self._context.run(super()._dispatch_submit, *args, **kwargs)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
"""
Default LeoModal error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in {self!r}",
extra={'with_ctx': True, 'action': 'ModalError'}
)
def error_handler_for(exc):
def wrapper(coro):
coro._ui_error_handler_for_ = exc
return coro
return wrapper

315
src/utils/ui/micros.py Normal file
View File

@@ -0,0 +1,315 @@
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
import functools
import asyncio
import discord
from discord.ui import TextInput
from discord.ui.button import button
from meta.logger import logging_context
from meta.errors import ResponseTimedOut
from .leo import LeoModal, LeoUI
__all__ = (
'FastModal',
'ModalRetryUI',
'Confirm',
'input',
)
class FastModal(LeoModal):
__class_error_handlers__ = []
def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
error_handlers = {}
for base in reversed(cls.__mro__):
for name, member in base.__dict__.items():
if hasattr(member, '_ui_error_handler_for_'):
error_handlers[name] = member
cls.__class_error_handlers__ = list(error_handlers.values())
def __init__error_handlers__(self):
handlers = {}
for handler in self.__class_error_handlers__:
handlers[handler._ui_error_handler_for_] = functools.partial(handler, self)
return handlers
def __init__(self, *items: TextInput, **kwargs):
super().__init__(**kwargs)
for item in items:
self.add_item(item)
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
self._waiters: List[Callable[[discord.Interaction], Coroutine]] = []
self._error_handlers = self.__init__error_handlers__()
def error_handler(self, exception):
def wrapper(coro):
self._error_handlers[exception] = coro
return coro
return wrapper
async def wait_for(self, check=None, timeout=None):
# Wait for _result or timeout
# If we timeout, or the view times out, raise TimeoutError
# Otherwise, return the Interaction
# This allows multiple listeners and callbacks to wait on
while True:
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
if check is not None:
if not check(result):
continue
return result
async def on_timeout(self):
self._result.set_exception(asyncio.TimeoutError)
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
def wrapper(coro):
async def wrapped_callback(interaction):
with logging_context(action=coro.__name__):
if check is not None:
if not check(interaction):
return
try:
await coro(interaction, *pass_args, **pass_kwargs)
except Exception as error:
await self.on_error(interaction, error)
finally:
if once:
self._waiters.remove(wrapped_callback)
self._waiters.append(wrapped_callback)
return wrapper
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
try:
# First let our error handlers have a go
# If there is no handler for this error, or the handlers themselves error,
# drop to the superclass error handler implementation.
try:
raise error
except tuple(self._error_handlers.keys()) as e:
# If an error handler is registered for this exception, run it.
for cls, handler in self._error_handlers.items():
if isinstance(e, cls):
await handler(interaction, e)
except Exception as error:
await super().on_error(interaction, error)
async def on_submit(self, interaction):
old_result = self._result
self._result = asyncio.get_event_loop().create_future()
old_result.set_result(interaction)
for waiter in self._waiters:
asyncio.create_task(waiter(interaction), name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}")
async def input(
interaction: discord.Interaction,
title: str,
question: Optional[str] = None,
field: Optional[TextInput] = None,
timeout=180,
**kwargs,
):
"""
Spawn a modal to accept input.
Returns an (interaction, value) pair, with interaction not yet responded to.
May raise asyncio.TimeoutError if the view times out.
"""
if field is None:
field = TextInput(
label=kwargs.get('label', question),
**kwargs
)
modal = FastModal(
field,
title=title,
timeout=timeout
)
await interaction.response.send_modal(modal)
interaction = await modal.wait_for()
return (interaction, field.value)
class ModalRetryUI(LeoUI):
def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.modal = modal
self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)}
self.message = message
self._interaction = None
if label is not None:
self.retry_button.label = label
@property
def embed(self):
return discord.Embed(
description=self.message,
colour=discord.Colour.red()
)
async def respond_to(self, interaction):
self._interaction = interaction
await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self)
@button(label="Retry")
async def retry_button(self, interaction, butt):
# Setting these here so they don't update in the meantime
for item, value in self.item_values.items():
item.default = value
if self._interaction is not None:
await self._interaction.delete_original_response()
self._interaction = None
await interaction.response.send_modal(self.modal)
await self.close()
class Confirm(LeoUI):
"""
Micro UI class implementing a confirmation question.
Parameters
----------
confirm_msg: str
The confirmation question to ask from the user.
This is set as the description of the `embed` property.
The `embed` may be further modified if required.
permitted_id: Optional[int]
The user id allowed to access this interaction.
Other users will recieve an access denied error message.
defer: bool
Whether to defer the interaction response while handling the button.
It may be useful to set this to `False` to obtain manual control
over the interaction response flow (e.g. to send a modal or ephemeral message).
The button press interaction may be accessed through `Confirm.interaction`.
Default: True
Example
-------
```
confirm = Confirm("Are you sure?", ctx.author.id)
confirm.embed.colour = discord.Colour.red()
confirm.confirm_button.label = "Yes I am sure"
confirm.cancel_button.label = "No I am not sure"
try:
result = await confirm.ask(ctx.interaction, ephemeral=True)
except ResultTimedOut:
return
```
"""
def __init__(
self,
confirm_msg: str,
permitted_id: Optional[int] = None,
defer: bool = True,
**kwargs
):
super().__init__(**kwargs)
self.confirm_msg = confirm_msg
self.permitted_id = permitted_id
self.defer = defer
self._embed: Optional[discord.Embed] = None
self._result: asyncio.Future[bool] = asyncio.Future()
# Indicates whether we should delete the message or the interaction response
self._is_followup: bool = False
self._original: Optional[discord.Interaction] = None
self._message: Optional[discord.Message] = None
async def interaction_check(self, interaction: discord.Interaction):
return (self.permitted_id is None) or interaction.user.id == self.permitted_id
async def on_timeout(self):
# Propagate timeout to result Future
self._result.set_exception(ResponseTimedOut)
await self.cleanup()
async def cleanup(self):
"""
Cleanup the confirmation prompt by deleting it, if possible.
Ignores any Discord errors that occur during the process.
"""
try:
if self._is_followup and self._message:
await self._message.delete()
elif not self._is_followup and self._original and not self._original.is_expired():
await self._original.delete_original_response()
except discord.HTTPException:
# A user probably already deleted the message
# Anything could have happened, just ignore.
pass
@button(label="Confirm")
async def confirm_button(self, interaction: discord.Interaction, press):
if self.defer:
await interaction.response.defer()
self._result.set_result(True)
await self.close()
@button(label="Cancel")
async def cancel_button(self, interaction: discord.Interaction, press):
if self.defer:
await interaction.response.defer()
self._result.set_result(False)
await self.close()
@property
def embed(self):
"""
Confirmation embed shown to the user.
This is cached, and may be modifed directly through the usual EmbedProxy API,
or explicitly overwritten.
"""
if self._embed is None:
self._embed = discord.Embed(
colour=discord.Colour.orange(),
description=self.confirm_msg
)
return self._embed
@embed.setter
def embed(self, value):
self._embed = value
async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs):
"""
Send this confirmation prompt in response to the provided interaction.
Extra keyword arguments are passed to `Interaction.response.send_message`
or `Interaction.send_followup`, depending on whether
the provided interaction has already been responded to.
The `epehemeral` argument is handled specially,
since the question message can only be deleted through `Interaction.delete_original_response`.
Waits on and returns the internal `result` Future.
Returns: bool
True if the user pressed the confirm button.
False if the user pressed the cancel button.
Raises:
ResponseTimedOut:
If the user does not respond before the UI times out.
"""
self._original = interaction
if interaction.response.is_done():
# Interaction already responded to, send a follow up
if ephemeral:
raise ValueError("Cannot send an ephemeral response to a used interaction.")
self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self)
self._is_followup = True
else:
await interaction.response.send_message(
embed=self.embed, ephemeral=ephemeral, **kwargs, view=self
)
self._is_followup = False
return await self._result

456
src/utils/ui/pagers.py Normal file
View File

@@ -0,0 +1,456 @@
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
from collections import defaultdict
import discord
from discord.ui.button import Button, button
from discord import app_commands as appcmds
from meta.logger import log_action_stack, logging_context
from meta.errors import SafeCancellation
from meta.config import conf
from babel.translator import ctx_translator
from ..lib import MessageArgs, error_embed
from .. import util_babel
from .leo import LeoUI
_p = util_babel._p
__all__ = (
'BasePager',
'Pager',
)
class BasePager(LeoUI):
"""
An ABC describing the common interface for a Paging UI.
A paging UI represents a sequence of pages, accessible by `next` and `previous` buttons,
and possibly by a dropdown (not implemented).
A `Page` is represented as a `MessageArgs` object, which is passable to `send` and `edit` methods as required.
Each page of a paging UI is accessed through the coroutine `get_page`.
This allows for more complex paging schemes where the pages are expensive to compute,
and not generally needed simultaneously.
In general, `get_page` should cache expensive pages,
perhaps simply with a `cached` decorator, but this is not enforced.
The state of the base UI is represented as the current `page_num` and the `current_page`.
This class also maintains an `active_pagers` cache,
representing all `BasePager`s that are currently running.
This allows access from external page controlling utilities, e.g. the `/page` command.
"""
# List of valid keys indicating movement to the next page
next_list = _p('cmd:page|pager:Pager|options:next', "n, nxt, next, forward, +")
# List of valid keys indicating movement to the previous page
prev_list = _p('cmd:page|pager:Pager|options:prev', "p, prev, back, -")
# List of valid keys indicating movement to the first page
first_list = _p('cmd:page|pager:Pager|options:first', "f, first, one, start")
# List of valid keys indicating movement to the last page
last_list = _p('cmd:page|pager:Pager|options:last', "l, last, end")
# channelid -> pager.id -> list of active pagers in this channel
active_pagers: dict[int, dict[int, 'BasePager']] = defaultdict(dict)
page_num: int
current_page: MessageArgs
_channelid: Optional[int]
@classmethod
def get_active_pager(self, channelid, userid):
"""
Get the last active pager in the `destinationid`, which may be accessed by `userid`.
Returns None if there are no matching pagers.
"""
for pager in reversed(self.active_pagers[channelid].values()):
if pager.access_check(userid):
return pager
def set_active(self):
if self._channelid is None:
raise ValueError("Cannot set active without a channelid.")
self.active_pagers[self._channelid][self.id] = self
def set_inactive(self):
self.active_pagers[self._channelid].pop(self.id, None)
def access_check(self, userid):
"""
Check whether the given userid is allowed to use this UI.
Must be overridden by subclasses.
"""
raise NotImplementedError
async def get_page(self, page_id) -> MessageArgs:
"""
`get_page` returns the specified page number, starting from 0.
An implementation of `get_page` must:
- Always return a page (if no data is a valid state, must return a placeholder page).
- Always accept out-of-range `page_id` values.
- There is no behaviour specified for these, although they will usually be modded into the correct
range.
- In some cases (e.g. stream data where we don't have a last page),
they may simply return the last correct page instead.
"""
raise NotImplementedError
async def page_cmd(self, interaction: discord.Interaction, value: str):
"""
Command implementation for the paging command.
Pager subclasses should override this if they use `active_pagers`.
Default implementation is essentially a no-op,
simply replying to the interaction.
"""
await interaction.response.defer()
return
async def page_acmpl(self, interaction: discord.Interaction, partial: str):
"""
Command autocompletion for the paging command.
Pager subclasses should override this if they use `active_pagers`.
"""
return []
@button(emoji=conf.emojis.getemoji('forward'))
async def next_page_button(self, interaction: discord.Interaction, press):
await interaction.response.defer()
self.page_num += 1
await self.redraw()
@button(emoji=conf.emojis.getemoji('backward'))
async def prev_page_button(self, interaction: discord.Interaction, press):
await interaction.response.defer()
self.page_num -= 1
await self.redraw()
async def refresh(self):
"""
Recalculate current computed state.
(E.g. fetch current page, set layout, disable components, etc.)
"""
self.current_page = await self.get_page(self.page_num)
async def redraw(self):
"""
This should refresh the current state and redraw the UI.
Not implemented here, as the implementation depends on whether this is a reaction response ephemeral UI
or a message=based one.
"""
raise NotImplementedError
class Pager(BasePager):
"""
MicroUI to display a sequence of static pages,
supporting paging reaction and paging commands.
Parameters
----------
pages: list[MessageArgs]
A non-empty list of message arguments to page.
start_from: int
The page number to display first.
Default: 0
locked: bool
Whether to only allow the author to use the paging interface.
"""
def __init__(self, pages: list[MessageArgs],
start_from=0,
show_cancel=False, delete_on_cancel=True, delete_after=False, **kwargs):
super().__init__(**kwargs)
self._pages = pages
self.page_num = start_from
self.current_page = pages[self.page_num]
self._locked = True
self._ownerid: Optional[int] = None
self._channelid: Optional[int] = None
if not pages:
raise ValueError("Cannot run Pager with no pages.")
self._original: Optional[discord.Interaction] = None
self._is_followup: bool = False
self._message: Optional[discord.Message] = None
self.show_cancel = show_cancel
self._delete_on_cancel = delete_on_cancel
self._delete_after = delete_after
@property
def ownerid(self):
if self._ownerid is not None:
return self._ownerid
elif self._original:
return self._original.user.id
else:
return None
def access_check(self, userid):
return not self._locked or (userid == self.ownerid)
async def interaction_check(self, interaction: discord.Interaction):
return self.access_check(interaction.user.id)
@button(emoji=conf.emojis.getemoji('cancel'))
async def cancel_button(self, interaction: discord.Interaction, press: Button):
await interaction.response.defer()
if self._delete_on_cancel:
self._delete_after = True
await self.close()
async def cleanup(self):
self.set_inactive()
# If we still have a message, delete it or clear the view
try:
if self._is_followup:
if self._message:
if self._delete_after:
await self._message.delete()
else:
await self._message.edit(view=None)
else:
if self._original and not self._original.is_expired():
if self._delete_after:
await self._original.delete_original_response()
else:
await self._original.edit_original_response(view=None)
except discord.HTTPException:
# Nothing we can do here
pass
async def get_page(self, page_id):
page_id %= len(self._pages)
return self._pages[page_id]
def page_count(self):
return len(self.pages)
async def page_cmd(self, interaction: discord.Interaction, value: str):
"""
`/page` command for the `Pager` MicroUI.
"""
await interaction.response.defer(ephemeral=True)
t = ctx_translator.get().t
nexts = {word.strip() for word in t(self.next_list).split(',')}
prevs = {word.strip() for word in t(self.prev_list).split(',')}
firsts = {word.strip() for word in t(self.first_list).split(',')}
lasts = {word.strip() for word in t(self.last_list).split(',')}
if value:
value = value.lower().strip()
if value.isdigit():
# Assume value is page number
self.page_num = int(value) - 1
if self.page_num == -1:
self.page_num = 0
elif value in firsts:
self.page_num = 0
elif value in nexts:
self.page_num += 1
elif value in prevs:
self.page_num -= 1
elif value in lasts:
self.page_num = -1
elif value.startswith('-') and value[1:].isdigit():
self.page_num = - int(value[1:])
else:
await interaction.edit_original_response(
embed=error_embed(
t(_p(
'cmd:page|pager:Pager|error:parse',
"Could not understand page specification `{value}`."
)).format(value=value)
)
)
return
await interaction.delete_original_response()
await self.redraw()
async def page_acmpl(self, interaction: discord.Interaction, partial: str):
"""
`/page` command autocompletion for the `Pager` MicroUI.
"""
t = ctx_translator.get().t
nexts = {word.strip() for word in t(self.next_list).split(',')}
prevs = {word.strip() for word in t(self.prev_list).split(',')}
firsts = {word.strip() for word in t(self.first_list).split(',')}
lasts = {word.strip() for word in t(self.last_list).split(',')}
total = len(self._pages)
num = self.page_num
page_choices: dict[int, str] = {}
# TODO: Support page names and hints?
if len(self._pages) > 10:
# First add the general choices
if num < total-1:
page_choices[total-1] = t(_p(
'cmd:page|acmpl|pager:Pager|choice:last',
"Last: Page {page}/{total}"
)).format(page=total, total=total)
page_choices[num] = t(_p(
'cmd:page|acmpl|pager:Pager|choice:current',
"Current: Page {page}/{total}"
)).format(page=num+1, total=total)
choices = [
appcmds.Choice(name=string, value=str(num+1))
for num, string in sorted(page_choices.items(), key=lambda t: t[0])
]
else:
# Particularly support page names here
choices = [
appcmds.Choice(
name='> ' * (i == num) + t(_p(
'cmd:page|acmpl|pager:Pager|choice:general',
"Page {page}"
)).format(page=i+1),
value=str(i+1)
)
for i in range(0, total)
]
partial = partial.strip()
if partial:
value = partial.lower().strip()
if value.isdigit():
# Assume value is page number
page_num = int(value) - 1
if page_num == -1:
page_num = 0
elif value in firsts:
page_num = 0
elif value in nexts:
page_num = self.page_num + 1
elif value in prevs:
page_num = self.page_num - 1
elif value in lasts:
page_num = -1
elif value.startswith('-') and value[1:].isdigit():
page_num = - int(value[1:])
else:
page_num = None
if page_num is not None:
page_num %= total
choice = appcmds.Choice(
name=t(_p(
'cmd:page|acmpl|pager:Page|choice:select',
"Selected: Page {page}/{total}"
)).format(page=page_num+1, total=total),
value=str(page_num + 1)
)
return [choice, *choices]
else:
return [
appcmds.Choice(
name=t(_p(
'cmd:page|acmpl|pager:Page|error:parse',
"No matching pages!"
)).format(page=page_num, total=total),
value=partial
)
]
else:
return choices
@property
def page_row(self):
if self.show_cancel:
if len(self._pages) > 1:
return (self.prev_page_button, self.cancel_button, self.next_page_button)
else:
return (self.cancel_button,)
else:
if len(self._pages) > 1:
return (self.prev_page_button, self.next_page_button)
else:
return ()
async def refresh(self):
await super().refresh()
self.set_layout(self.page_row)
async def redraw(self):
await self.refresh()
if not self._original:
raise ValueError("Running run pager manually without interaction.")
try:
if self._message:
await self._message.edit(**self.current_page.edit_args, view=self)
else:
if self._original.is_expired():
raise SafeCancellation("This interface has expired, please try again.")
await self._original.edit_original_response(**self.current_page.edit_args, view=self)
except discord.HTTPException:
raise SafeCancellation("Could not page your results! Please try again.")
async def run(self, interaction: discord.Interaction, ephemeral=False, locked=True, ownerid=None, **kwargs):
"""
Display the UI.
Attempts to reply to the interaction if it has not already been replied to,
otherwise send a follow-up.
An ephemeral response must be sent as an initial interaction response.
On the other hand, a semi-persistent response (expected to last longer than the lifetime of the interaction)
must be sent as a followup.
Extra kwargs are combined with the first page arguments and given to the relevant send method.
Parameters
----------
interaction: discord.Interaction
The interaction to send the pager in response to.
ephemeral: bool
Whether to send the interaction ephemerally.
If this is true, the interaction *must* be fresh (i.e. no response done).
Default: False
locked: bool
Whether this interface is locked to the user `self.ownerid`.
Irrelevant for ephemeral messages.
Use `ownerid` to override the default owner id.
Defaults to true for fail-safety.
Default: True
ownerid: Optional[int]
The userid allowed to use this interaction.
By default, this will be the `interaction.user.id`,
presuming that this is the user which originally triggered this message.
An override may be useful if a user triggers a paging UI for someone else.
"""
if not interaction.channel_id:
raise ValueError("Cannot run pager on a channelless interaction.")
self._original = interaction
self._ownerid = ownerid
self._locked = locked
self._channelid = interaction.channel_id
await self.refresh()
args = self.current_page.send_args | kwargs
if interaction.response.is_done():
if ephemeral:
raise ValueError("Ephemeral response requires fres interaction.")
self._message = await interaction.followup.send(**args, view=self)
self._is_followup = True
else:
self._is_followup = False
await interaction.response.send_message(**args, view=self)
self.set_active()

View File

@@ -0,0 +1,91 @@
from typing import Any, Type, TYPE_CHECKING
from enum import Enum
import discord
import discord.app_commands as appcmd
from discord.app_commands.transformers import AppCommandOptionType
__all__ = (
'ChoicedEnum',
'ChoicedEnumTransformer',
'Transformed',
)
class ChoicedEnum(Enum):
@property
def choice_name(self):
return self.name
@property
def choice_value(self):
return self.value
@property
def choice(self):
return appcmd.Choice(
name=self.choice_name, value=self.choice_value
)
@classmethod
def choices(self):
return [item.choice for item in self]
@classmethod
def make_choice_map(cls):
return {item.choice_value: item for item in cls}
@classmethod
async def transform(cls, transformer: 'ChoicedEnumTransformer', interaction: discord.Interaction, value: Any):
return transformer._choice_map[value]
@classmethod
def option_type(cls) -> AppCommandOptionType:
return AppCommandOptionType.string
@classmethod
def transformer(cls, *args) -> appcmd.Transformer:
return ChoicedEnumTransformer(cls, *args)
class ChoicedEnumTransformer(appcmd.Transformer):
# __discord_app_commands_is_choice__ = True
def __init__(self, enum: Type[ChoicedEnum], opt_type) -> None:
super().__init__()
self._type = opt_type
self._enum = enum
self._choices = enum.choices()
self._choice_map = enum.make_choice_map()
@property
def _error_display_name(self) -> str:
return self._enum.__name__
@property
def type(self) -> AppCommandOptionType:
return self._type
@property
def choices(self):
return self._choices
async def transform(self, interaction: discord.Interaction, value: Any, /) -> Any:
return await self._enum.transform(self, interaction, value)
if TYPE_CHECKING:
from typing_extensions import Annotated as Transformed
else:
class Transformed:
def __class_getitem__(self, items):
cls = items[0]
options = items[1:]
if not hasattr(cls, 'transformer'):
raise ValueError("Tranformed class must have a transformer classmethod.")
transformer = cls.transformer(*options)
return appcmd.Transform[cls, transformer]