Initial Template.

This commit is contained in:
2023-11-02 10:09:06 +02:00
commit 1ae7c60ba8
52 changed files with 6786 additions and 0 deletions

0
src/utils/__init__.py Normal file
View File

97
src/utils/ansi.py Normal file
View File

@@ -0,0 +1,97 @@
"""
Minimal library for making Discord Ansi colour codes.
"""
from enum import StrEnum
PREFIX = u'\u001b'
class TextColour(StrEnum):
Gray = '30'
Red = '31'
Green = '32'
Yellow = '33'
Blue = '34'
Pink = '35'
Cyan = '36'
White = '37'
def __str__(self) -> str:
return AnsiColour(fg=self).as_str()
def __call__(self):
return AnsiColour(fg=self)
class BgColour(StrEnum):
FireflyDarkBlue = '40'
Orange = '41'
MarbleBlue = '42'
GrayTurq = '43'
Gray = '44'
Indigo = '45'
LightGray = '46'
White = '47'
def __str__(self) -> str:
return AnsiColour(bg=self).as_str()
def __call__(self):
return AnsiColour(bg=self)
class Format(StrEnum):
NORMAL = '0'
BOLD = '1'
UNDERLINE = '4'
NOOP = '9'
def __str__(self) -> str:
return AnsiColour(self).as_str()
def __call__(self):
return AnsiColour(self)
class AnsiColour:
def __init__(self, *flags, fg=None, bg=None):
self.text_colour = fg
self.background_colour = bg
self.reset = (Format.NORMAL in flags)
self._flags = set(flags)
self._flags.discard(Format.NORMAL)
@property
def flags(self):
return (*((Format.NORMAL,) if self.reset else ()), *self._flags)
def as_str(self):
parts = []
if self.reset:
parts.append(Format.NORMAL)
elif not self.flags:
parts.append(Format.NOOP)
parts.extend(self._flags)
for c in (self.text_colour, self.background_colour):
if c is not None:
parts.append(c)
partstr = ';'.join(part.value for part in parts)
return f"{PREFIX}[{partstr}m" # ]
def __str__(self):
return self.as_str()
def __add__(self, obj: 'AnsiColour'):
text_colour = obj.text_colour or self.text_colour
background_colour = obj.background_colour or self.background_colour
flags = (*self.flags, *obj.flags)
return AnsiColour(*flags, fg=text_colour, bg=background_colour)
RESET = AnsiColour(Format.NORMAL)
BOLD = AnsiColour(Format.BOLD)
UNDERLINE = AnsiColour(Format.UNDERLINE)

165
src/utils/data.py Normal file
View File

@@ -0,0 +1,165 @@
"""
Some useful pre-built Conditions for data queries.
"""
from typing import Optional, Any
from itertools import chain
from psycopg import sql
from data.conditions import Condition, Joiner
from data.columns import ColumnExpr
from data.base import Expression
from constants import MAX_COINS
def MULTIVALUE_IN(columns: tuple[str, ...], *data: tuple[Any, ...]) -> Condition:
"""
Condition constructor for filtering by multiple column equalities.
Example Usage
-------------
Query.where(MULTIVALUE_IN(('guildid', 'userid'), (1, 2), (3, 4)))
"""
if not data:
raise ValueError("Cannot create empty multivalue condition.")
left = sql.SQL("({})").format(
sql.SQL(', ').join(
sql.Identifier(key)
for key in columns
)
)
right_item = sql.SQL('({})').format(
sql.SQL(', ').join(
sql.Placeholder()
for _ in columns
)
)
right = sql.SQL("({})").format(
sql.SQL(', ').join(
right_item
for _ in data
)
)
return Condition(
left,
Joiner.IN,
right,
chain(*data)
)
def MEMBERS(*memberids: tuple[int, int], guild_column='guildid', user_column='userid') -> Condition:
"""
Condition constructor for filtering member tables by guild and user id simultaneously.
Example Usage
-------------
Query.where(MEMBERS((1234,12), (5678,34)))
"""
if not memberids:
raise ValueError("Cannot create a condition with no members")
return Condition(
sql.SQL("({guildid}, {userid})").format(
guildid=sql.Identifier(guild_column),
userid=sql.Identifier(user_column)
),
Joiner.IN,
sql.SQL("({})").format(
sql.SQL(', ').join(
sql.SQL("({}, {})").format(
sql.Placeholder(),
sql.Placeholder()
) for _ in memberids
)
),
chain(*memberids)
)
def as_duration(expr: Expression) -> ColumnExpr:
"""
Convert an integer expression into a duration expression.
"""
expr_expr, expr_values = expr.as_tuple()
return ColumnExpr(
sql.SQL("({} * interval '1 second')").format(expr_expr),
expr_values
)
class TemporaryTable(Expression):
"""
Create a temporary table expression to be used in From or With clauses.
Example
-------
```
tmp_table = TemporaryTable('_col1', '_col2', name='data')
tmp_table.values((1, 2), (3, 4))
real_table.update_where(col1=tmp_table['_col1']).set(col2=tmp_table['_col2']).from_(tmp_table)
```
"""
def __init__(self, *columns: str, name: str = '_t', types: Optional[tuple[str, ...]] = None):
self.name = name
self.columns = columns
self.types = types
if types and len(types) != len(columns):
raise ValueError("Number of types does not much number of columns!")
self._table_columns = {
col: ColumnExpr(sql.Identifier(name, col))
for col in columns
}
self.values = []
def __getitem__(self, key) -> sql.Identifier:
return self._table_columns[key]
def as_tuple(self):
"""
(VALUES {})
AS
name (col1, col2)
"""
if not self.values:
raise ValueError("Cannot flatten CTE with no values.")
single_value = sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() for _ in self.columns))
if self.types:
first_value = sql.SQL("({})").format(
sql.SQL(", ").join(
sql.SQL("{}::{}").format(sql.Placeholder(), sql.SQL(cast))
for cast in self.types
)
)
else:
first_value = single_value
value_placeholder = sql.SQL("(VALUES {})").format(
sql.SQL(", ").join(
(first_value, *(single_value for _ in self.values[1:]))
)
)
expr = sql.SQL("{values} AS {name} ({columns})").format(
values=value_placeholder,
name=sql.Identifier(self.name),
columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.columns)
)
values = chain(*self.values)
return (expr, values)
def set_values(self, *data):
self.values = data
def SAFECOINS(expr: Expression) -> Expression:
expr_expr, expr_values = expr.as_tuple()
return ColumnExpr(
sql.SQL("LEAST({}, {})").format(
expr_expr,
sql.Literal(MAX_COINS)
),
expr_values
)

847
src/utils/lib.py Normal file
View File

@@ -0,0 +1,847 @@
from io import StringIO
from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any
import collections
import datetime
import datetime as dt
import iso8601 # type: ignore
import pytz
import re
import json
from contextvars import Context
import discord
from discord.partial_emoji import _EmojiTag
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
from discord.ui import View
from meta.errors import UserInputError
multiselect_regex = re.compile(
r"^([0-9, -]+)$",
re.DOTALL | re.IGNORECASE | re.VERBOSE
)
tick = ''
cross = ''
MISSING = object()
class MessageArgs:
"""
Utility class for storing message creation and editing arguments.
"""
# TODO: Overrides for mutually exclusive arguments, see Messageable.send
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
def __init__(self, **kwargs):
self.kwargs = kwargs
@property
def send_args(self) -> dict:
if self.kwargs.get('view', MISSING) is None:
kwargs = self.kwargs.copy()
kwargs.pop('view')
else:
kwargs = self.kwargs
return kwargs
@property
def edit_args(self) -> dict:
args = {}
kept = (
'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view'
)
for k in kept:
if k in self.kwargs:
args[k] = self.kwargs[k]
if 'file' in self.kwargs:
args['attachments'] = [self.kwargs['file']]
if 'files' in self.kwargs:
args['attachments'] = self.kwargs['files']
if 'suppress_embeds' in self.kwargs:
args['suppress'] = self.kwargs['suppress_embeds']
return args
def tabulate(
*fields: tuple[str, str],
row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}",
sub_format: str = "`{invis:<{pad}}{colon}`\t{value}",
colon: str = ':',
invis: str = "",
**args
) -> list[str]:
"""
Turns a list of (property, value) pairs into
a pretty string with one `prop: value` pair each line,
padded so that the colons in each line are lined up.
Use `\\r\\n` in a value to break the line with padding.
Parameters
----------
fields: List[tuple[str, str]]
List of (key, value) pairs.
row_format: str
The format string used to format each row.
sub_format: str
The format string used to format each subline in a row.
colon: str
The colon character used.
invis: str
The invisible character used (to avoid Discord stripping the string).
Returns: List[str]
The list of resulting table rows.
Each row corresponds to one (key, value) pair from fields.
"""
max_len = max(len(field[0]) for field in fields)
rows = []
for field in fields:
key = field[0]
value = field[1]
lines = value.split('\r\n')
row_line = row_format.format(
invis=invis,
key=key,
pad=max_len,
colon=colon,
value=lines[0],
field=field,
**args
)
if len(lines) > 1:
row_lines = [row_line]
for line in lines[1:]:
sub_line = sub_format.format(
invis=invis,
pad=max_len + len(colon),
colon=colon,
value=line,
**args
)
row_lines.append(sub_line)
row_line = '\n'.join(row_lines)
rows.append(row_line)
return rows
def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]:
"""
Create pretty codeblock pages from a list of strings.
Parameters
----------
item_list: List[str]
List of strings to paginate.
block_length: int
Maximum number of strings per page.
style: str
Codeblock style to use.
Title formatting assumes the `markdown` style, and numbered lists work well with this.
However, `markdown` sometimes messes up formatting in the list.
title: str
Optional title to add to the top of each page.
Returns: List[str]
List of pages, each formatted into a codeblock,
and containing at most `block_length` of the provided strings.
"""
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
pages = []
for i, block in enumerate(page_blocks):
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
if title:
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
else:
header = pagenum
header_line = "=" * len(header)
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
return pages
def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]:
"""
Break the text into blocks of maximum length blocksize
If possible, break across nearby newlines. Otherwise just break at blocksize chars
Parameters
----------
text: str
Text to break into blocks.
blocksize: int
Maximum character length for each block.
code: bool
Whether to wrap each block in codeblocks (these are counted in the blocksize).
syntax: str
The markdown formatting language to use for the codeblocks, if applicable.
maxheight: int
The maximum number of lines in each block
Returns: List[str]
List of blocks,
each containing at most `block_size` characters,
of height at most `maxheight`.
"""
# Adjust blocksize to account for the codeblocks if required
blocksize = blocksize - 8 - len(syntax) if code else blocksize
# Build the blocks
blocks = []
while True:
# If the remaining text is already small enough, append it
if len(text) <= blocksize:
blocks.append(text)
break
text = text.strip('\n')
# Find the last newline in the prototype block
split_on = text[0:blocksize].rfind('\n')
split_on = blocksize if split_on < blocksize // 5 else split_on
# Add the block and truncate the text
blocks.append(text[0:split_on])
text = text[split_on:]
# Add the codeblock ticks and the code syntax header, if required
if code:
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
return blocks
def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str:
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's' # type: ignore
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def _parse_dur(time_str: str) -> int:
"""
Parses a user provided time duration string into a timedelta object.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
def strfdur(duration: int, short=True, show_days=False) -> str:
"""
Convert a duration given in seconds to a number of hours, minutes, and seconds.
"""
days = duration // (3600 * 24) if show_days else 0
hours = duration // 3600
if days:
hours %= 24
minutes = duration // 60 % 60
seconds = duration % 60
parts = []
if days:
unit = 'd' if short else (' days' if days != 1 else ' day')
parts.append('{}{}'.format(days, unit))
if hours:
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
parts.append('{}{}'.format(hours, unit))
if minutes:
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
parts.append('{}{}'.format(minutes, unit))
if seconds or duration == 0:
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
parts.append('{}{}'.format(seconds, unit))
if short:
return ' '.join(parts)
else:
return ', '.join(parts)
def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str:
"""
Substitutes a user provided list of numbers and ranges,
and replaces the ranges by the corresponding list of numbers.
Parameters
----------
ranges_str: str
The string to ranges in.
max_match: int
The maximum number of ranges to replace.
Any ranges exceeding this will be ignored.
max_range: int
The maximum length of range to replace.
Attempting to replace a range longer than this will raise a `ValueError`.
"""
def _repl(match):
n1 = int(match.group(1))
n2 = int(match.group(2))
if n2 - n1 > max_range:
# TODO: Upgrade to SafeCancellation
raise ValueError("Provided range is too large!")
return separator.join(str(i) for i in range(n1, n2 + 1))
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]:
"""
Parses a user provided range string into a list of numbers.
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
"""
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
_numbers = (item.strip() for item in substituted.split(','))
numbers = [item for item in _numbers if item]
integers = [int(item) for item in numbers if item.isdigit()]
if not ignore_errors and len(integers) != len(numbers):
# TODO: Upgrade to SafeCancellation
raise ValueError(
"Couldn't parse the provided selection!\n"
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
)
return integers
def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str:
"""
Format a message into a string with various information, such as:
the timestamp of the message, author, message content, and attachments.
Parameters
----------
msg: Message
The message to format.
mask_link: bool
Whether to mask the URLs of any attachments.
line_break: bool
Whether a line break should be used in the string.
tz: Timezone
The timezone to use in the formatted message.
clean: bool
Whether to use the clean content of the original message.
Returns: str
A formatted string containing various information:
User timezone, message author, message content, attachments
"""
timestr = "%I:%M %p, %d/%m/%Y"
if tz:
time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr)
else:
time = msg.created_at.strftime(timestr)
user = str(msg.author)
attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url]
if mask_link:
attach_list = ["[Link]({})".format(url) for url in attach_list]
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
time=time,
user=user,
line_break="\n" if line_break else "",
message=msg.clean_content if clean else msg.content,
attachments=attachments
)
def convdatestring(datestring: str) -> datetime.timedelta:
"""
Convert a date string into a datetime.timedelta object.
Parameters
----------
datestring: str
The string to convert to a datetime.timedelta object.
Returns: datetime.timedelta
A datetime.timedelta object formed from the string provided.
"""
datestring = datestring.strip(' ,')
datearray = []
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
currentnumber = ''
for char in datestring:
if char.isdigit():
currentnumber += char
else:
if currentnumber == '':
continue
datearray.append((int(currentnumber), char))
currentnumber = ''
seconds = 0
if currentnumber:
seconds += int(currentnumber)
for i in datearray:
if i[1] in funcs:
seconds += funcs[i[1]](i[0])
return datetime.timedelta(seconds=seconds)
class _rawChannel(discord.abc.Messageable):
"""
Raw messageable class representing an arbitrary channel,
not necessarially seen by the gateway.
"""
def __init__(self, state, id):
self._state = state
self.id = id
async def _get_channel(self):
return discord.Object(self.id)
async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message:
"""
Mails a message to a channelid which may be invisible to the gateway.
Parameters:
client: discord.Client
The client to use for mailing.
Must at least have static authentication and have a valid `_connection`.
channelid: int
The channel id to mail to.
msg_args: Any
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
"""
# Create the raw channel
channel = _rawChannel(client._connection, channelid)
return await channel.send(**msg_args)
class EmbedField(NamedTuple):
name: str
value: str
inline: Optional[bool] = True
def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]):
"""
Append embed fields to an embed.
Parameters
----------
embed: discord.Embed
The embed to add the field to.
emb_fields: tuple
The values to add to a field.
name: str
The name of the field.
value: str
The value of the field.
inline: bool
Whether the embed field should be inline or not.
"""
for field in emb_fields:
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
def join_list(string: list[str], nfs=False) -> str:
"""
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
Parameters
----------
string: list
The list to join together.
nfs: bool
(no fullstops)
Whether to exclude fullstops/periods from the output messages.
If not provided, fullstops will be appended to the output.
"""
# TODO: Probably not useful with localisation
if len(string) > 1:
return "{}{} and {}{}".format((", ").join(string[:-1]),
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
else:
return "{}{}".format("".join(string), "" if nfs else ".")
def shard_of(shard_count: int, guildid: int) -> int:
"""
Calculate the shard number of a given guild.
"""
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
def jumpto(guildid: int, channeldid: int, messageid: int) -> str:
"""
Build a jump link for a message given its location.
"""
return 'https://discord.com/channels/{}/{}/{}'.format(
guildid,
channeldid,
messageid
)
def utc_now() -> datetime.datetime:
"""
Return the current timezone-aware utc timestamp.
"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
def multiple_replace(string: str, rep_dict: dict[str, str]) -> str:
if rep_dict:
pattern = re.compile(
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
flags=re.DOTALL
)
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
else:
return string
def recover_context(context: Context):
for var in context:
var.set(context[var])
def parse_ids(idstr: str) -> List[int]:
"""
Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids.
Object agnostic, so all mention tokens are stripped.
Raises UserInputError if an id is invalid,
setting `orig` and `item` info fields.
"""
from meta.errors import UserInputError
# Extract ids from string
splititer = (split.strip('<@!#&>, ') for split in idstr.split(','))
splits = [split for split in splititer if split]
# Check they are integers
if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None:
raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id})
# Cast to integer and return
return list(map(int, splits))
def error_embed(error, **kwargs) -> discord.Embed:
embed = discord.Embed(
colour=discord.Colour.brand_red(),
description=error,
timestamp=utc_now()
)
return embed
class DotDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
class Timezoned:
"""
ABC mixin for objects with a set timezone.
Provides several useful localised properties.
"""
__slots__ = ()
@property
def timezone(self) -> pytz.timezone:
"""
Must be implemented by the deriving class!
"""
raise NotImplementedError
@property
def now(self):
"""
Return the current time localised to the object's timezone.
"""
return datetime.datetime.now(tz=self.timezone)
@property
def today(self):
"""
Return the start of the day localised to the object's timezone.
"""
now = self.now
return now.replace(hour=0, minute=0, second=0, microsecond=0)
@property
def week_start(self):
"""
Return the start of the week in the object's timezone
"""
today = self.today
return today - datetime.timedelta(days=today.weekday())
@property
def month_start(self):
"""
Return the start of the current month in the object's timezone
"""
today = self.today
return today.replace(day=1)
def replace_multiple(format_string, mapping):
"""
Subsistutes the keys from the format_dict with their corresponding values.
Substitution is non-chained, and done in a single pass via regex.
"""
if not mapping:
raise ValueError("Empty mapping passed.")
keys = list(mapping.keys())
pattern = '|'.join(f"({key})" for key in keys)
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
return string
def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str):
"""
Produces a distinguishing key for an Emoji or PartialEmoji.
Equality checks using this key should act as expected.
"""
if isinstance(emoji, _EmojiTag):
if emoji.id:
key = str(emoji.id)
else:
key = str(emoji.name)
else:
key = str(emoji)
return key
def recurse_map(func, obj, loc=[]):
if isinstance(obj, dict):
for k, v in obj.items():
loc.append(k)
obj[k] = recurse_map(func, v, loc)
loc.pop()
elif isinstance(obj, list):
for i, item in enumerate(obj):
loc.append(i)
obj[i] = recurse_map(func, item)
loc.pop()
else:
obj = func(loc, obj)
return obj
async def check_dm(user: discord.User | discord.Member) -> bool:
"""
Check whether we can direct message the given user.
Assumes the client is initialised.
This uses an always-failing HTTP request,
so we need to be very very very careful that this is not used frequently.
Optimally only at the explicit behest of the user
(i.e. during a user instigated interaction).
"""
try:
await user.send('')
except discord.Forbidden:
return False
except discord.HTTPException:
return True
async def command_lengths(tree) -> dict[str, int]:
cmds = tree.get_commands()
payloads = [
await cmd.get_translated_payload(tree.translator)
for cmd in cmds
]
lens = {}
for command in payloads:
name = command['name']
crumbs = {}
cmd_len = lens[name] = _recurse_length(command, crumbs, (name,))
if name == 'configure' or cmd_len > 4000:
print(f"'{name}' over 4000. Breadcrumb Trail follows:")
lines = []
for loc, val in crumbs.items():
locstr = '.'.join(loc)
lines.append(f"{locstr}: {val}")
print('\n'.join(lines))
print(json.dumps(command, indent=2))
return lens
def _recurse_length(payload, breadcrumbs={}, header=()) -> int:
total = 0
total_header = (*header, '')
breadcrumbs[total_header] = 0
if isinstance(payload, dict):
# Read strings that count towards command length
# String length is length of longest localisation, including default.
for key in ('name', 'description', 'value'):
if key in payload:
value = payload[key]
if isinstance(value, str):
values = (value, *payload.get(key + '_localizations', {}).values())
maxlen = max(map(len, values))
total += maxlen
breadcrumbs[(*header, key)] = maxlen
for key, value in payload.items():
loc = (*header, key)
total += _recurse_length(value, breadcrumbs, loc)
elif isinstance(payload, list):
for i, item in enumerate(payload):
if isinstance(item, dict) and 'name' in item:
loc = (*header, f"{i}<{item['name']}>")
else:
loc = (*header, str(i))
total += _recurse_length(item, breadcrumbs, loc)
if total:
breadcrumbs[total_header] = total
else:
breadcrumbs.pop(total_header)
return total
def write_records(records: list[dict[str, Any]], stream: StringIO):
if records:
keys = records[0].keys()
stream.write(','.join(keys))
stream.write('\n')
for record in records:
stream.write(','.join(map(str, record.values())))
stream.write('\n')

191
src/utils/monitor.py Normal file
View File

@@ -0,0 +1,191 @@
import asyncio
import bisect
import logging
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
from .lib import utc_now
from .ratelimits import Bucket
logger = logging.getLogger(__name__)
Taskid = TypeVar('Taskid')
class TaskMonitor(Generic[Taskid]):
"""
Base class for a task monitor.
Stores tasks as a time-sorted list of taskids.
Subclasses may override `run_task` to implement an executor.
Adding or removing a single task has O(n) performance.
To bulk update tasks, instead use `schedule_tasks`.
Each taskid must be unique and hashable.
"""
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
# Ratelimit bucket to enforce maximum execution rate
self._bucket = bucket
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
self._wakeup: asyncio.Event = asyncio.Event()
self._monitor_task: Optional[asyncio.Task] = None
# Task data
self._tasklist: list[Taskid] = []
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
# Running map ensures we keep a reference to the running task
# And allows simpler external cancellation if required
self._running: dict[Taskid, asyncio.Future] = {}
def __repr__(self):
return (
"<"
f"{self.__class__.__name__}"
f" tasklist={len(self._tasklist)}"
f" taskmap={len(self._taskmap)}"
f" wakeup={self._wakeup.is_set()}"
f" bucket={self._bucket}"
f" running={len(self._running)}"
f" task={self._monitor_task}"
f">"
)
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Similar to `schedule_tasks`, but wipe and reset the tasklist.
"""
self._taskmap = {tid: time for tid, time in tasks}
self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid]))
self._wakeup.set()
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Schedule the given tasks.
Rather than repeatedly inserting tasks,
where the O(log n) insort is dominated by the O(n) list insertion,
we build an entirely new list, and always wake up the loop.
"""
self._taskmap |= {tid: time for tid, time in tasks}
self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid]))
self._wakeup.set()
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
"""
Insert the provided task into the tasklist.
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
"""
if self._tasklist:
nextid = self._tasklist[-1]
wake = self._taskmap[nextid] >= timestamp
wake = wake or taskid == nextid
else:
wake = True
if taskid in self._taskmap:
self._tasklist.remove(taskid)
self._taskmap[taskid] = timestamp
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
if wake:
self._wakeup.set()
def cancel_tasks(self, *taskids: Taskid) -> None:
"""
Remove all tasks with the given taskids from the tasklist.
If the next task has this taskid, wake up the monitor loop.
"""
taskids = set(taskids)
wake = (self._tasklist and self._tasklist[-1] in taskids)
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
for tid in taskids:
self._taskmap.pop(tid, None)
if wake:
self._wakeup.set()
def start(self):
if self._monitor_task and not self._monitor_task.done():
self._monitor_task.cancel()
# Start the monitor
self._monitor_task = asyncio.create_task(self.monitor())
return self._monitor_task
async def monitor(self):
"""
Start the monitor.
Executes the tasks in `self.tasks` at the specified time.
This will shield task execution from cancellation
to avoid partial states.
"""
try:
while True:
self._wakeup.clear()
if not self._tasklist:
# No tasks left, just sleep until wakeup
await self._wakeup.wait()
else:
# Get the next task, sleep until wakeup or it is ready to run
nextid = self._tasklist[-1]
nexttime = self._taskmap[nextid]
sleep_for = nexttime - utc_now().timestamp()
try:
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
except asyncio.TimeoutError:
# Ready to run the task
self._tasklist.pop()
self._taskmap.pop(nextid, None)
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
else:
# Wakeup task fired, loop again
continue
except asyncio.CancelledError:
# Log closure and wait for remaining tasks
# A second cancellation will also cancel the tasks
logger.debug(
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
f"Waiting for {len(self._running)} running tasks to complete."
)
await asyncio.gather(*self._running.values(), return_exceptions=True)
async def _run(self, taskid: Taskid) -> None:
# Execute the task, respecting the ratelimit bucket
if self._bucket is not None:
# IMPLEMENTATION NOTE:
# Bucket.wait() should guarantee not more than n tasks/second are run
# and that a request directly afterwards will _not_ raise BucketFull
# Make sure that only one waiter is actually waiting on its sleep task
# The other waiters should be waiting on a lock around the sleep task
# Waiters are executed in wait-order, so if we only let a single waiter in
# we shouldn't get collisions.
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
# Or we will lose thread-safety for BucketFull
await self._bucket.wait()
fut = asyncio.create_task(self.run_task(taskid))
try:
await asyncio.shield(fut)
except asyncio.CancelledError:
raise
except Exception:
# Protect the monitor loop from any other exceptions
logger.exception(
f"Ignoring exception in task monitor {self.__class__.__name__} while "
f"executing <taskid: {taskid}>"
)
finally:
self._running.pop(taskid)
async def run_task(self, taskid: Taskid):
"""
Execute the task with the given taskid.
Default implementation executes `self.executor` if it exists,
otherwise raises NotImplementedError.
"""
if self.executor is not None:
await self.executor(taskid)
else:
raise NotImplementedError

173
src/utils/ratelimits.py Normal file
View File

@@ -0,0 +1,173 @@
import asyncio
import time
import logging
from meta.errors import SafeCancellation
from cachetools import TTLCache
logger = logging.getLogger()
class BucketFull(Exception):
"""
Throw when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Throw when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.monotonic()
self._last_full = False
self._wait_lock = asyncio.Lock()
@property
def full(self) -> bool:
"""
Return whether the bucket is 'full',
that is, whether an immediate request against the bucket will raise `BucketFull`.
"""
self._leak()
return self._level + 1 > self.max_level
@property
def overfull(self):
self._leak()
return self._level > self.max_level
@property
def delay(self):
self._leak()
if self._level + 1 > self.max_level:
delay = (self._level + 1 - self.max_level) * self.leak_rate
else:
delay = 0
return delay
def _leak(self):
if self._level:
elapsed = time.monotonic() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.monotonic()
def request(self):
self._leak()
if self._level > self.max_level:
raise BucketOverFull
elif self._level == self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
def fill(self):
self._leak()
self._level = max(self._level, self.max_level + 1)
async def wait(self):
"""
Wait until the bucket has room.
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
"""
# Wrapped in a lock so that waiters are correctly handled in wait-order
# Otherwise multiple waiters will have the same delay,
# and race for the wakeup after sleep.
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
async with self._wait_lock:
# We do this in a loop in case asyncio.sleep throws us out early,
# or a synchronous request overflows the bucket while we are waiting.
while self.full:
await asyncio.sleep(self.delay)
async def wrapped(self, coro):
await self.wait()
self.request()
await coro
class RateLimit:
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
self.max_level = max_level
self.empty_time = empty_time
self.error = error or "Too many requests, please slow down!"
self.buckets = cache
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
try:
bucket.request()
except BucketOverFull:
raise SafeCancellation(details="Bucket overflow")
except BucketFull:
raise SafeCancellation(self.error, details="Bucket full")
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator
async def limit_concurrency(aws, limit):
"""
Run provided awaitables concurrently,
ensuring that no more than `limit` are running at once.
"""
aws = iter(aws)
aws_ended = False
pending = set()
count = 0
logger.debug("Starting limited concurrency executor")
while pending or not aws_ended:
while len(pending) < limit and not aws_ended:
aw = next(aws, None)
if aw is None:
aws_ended = True
else:
pending.add(asyncio.create_task(aw))
count += 1
if not pending:
break
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
while done:
yield done.pop()
logger.debug(f"Completed {count} tasks")

8
src/utils/ui/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
import asyncio
import logging
logger = logging.getLogger(__name__)
from .hooked import *
from .leo import *
from .micros import *

59
src/utils/ui/hooked.py Normal file
View File

@@ -0,0 +1,59 @@
import time
import discord
from discord.ui.item import Item
from discord.ui.button import Button
from .leo import LeoUI
__all__ = (
'HookedItem',
'AButton',
'AsComponents'
)
class HookedItem:
"""
Mixin for Item classes allowing an instance to be used as a callback decorator.
"""
def __init__(self, *args, pass_kwargs={}, **kwargs):
super().__init__(*args, **kwargs)
self.pass_kwargs = pass_kwargs
def __call__(self, coro):
async def wrapped(interaction, **kwargs):
return await coro(interaction, self, **(self.pass_kwargs | kwargs))
self.callback = wrapped
return self
class AButton(HookedItem, Button):
...
class AsComponents(LeoUI):
"""
Simple container class to accept a number of Items and turn them into an attachable View.
"""
def __init__(self, *items, pass_kwargs={}, **kwargs):
super().__init__(**kwargs)
self.pass_kwargs = pass_kwargs
for item in items:
self.add_item(item)
async def _scheduled_task(self, item: Item, interaction: discord.Interaction):
try:
item._refresh_state(interaction, interaction.data) # type: ignore
allow = await self.interaction_check(interaction)
if not allow:
return
if self.timeout:
self.__timeout_expiry = time.monotonic() + self.timeout
await item.callback(interaction, **self.pass_kwargs)
except Exception as e:
return await self.on_error(interaction, e, item)

485
src/utils/ui/leo.py Normal file
View File

@@ -0,0 +1,485 @@
from typing import List, Optional, Any, Dict
import asyncio
import logging
import time
from contextvars import copy_context, Context
import discord
from discord.ui import Modal, View, Item
from meta.logger import log_action_stack, logging_context
from meta.errors import SafeCancellation
from . import logger
from ..lib import MessageArgs, error_embed
__all__ = (
'LeoUI',
'MessageUI',
'LeoModal',
'error_handler_for'
)
class LeoUI(View):
"""
View subclass for small-scale user interfaces.
While a 'View' provides an interface for managing a collection of components,
a `LeoUI` may also manage a message, and potentially slave Views or UIs.
The `LeoUI` also exposes more advanced cleanup and timeout methods,
and preserves the context.
"""
def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None:
super().__init__(*args, **kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._name = ui_name or self.__class__.__name__
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name])
# List of slaved views to stop when this view stops
self._slaves: List[View] = []
# TODO: Replace this with a substitutable ViewLayout class
self._layout: Optional[tuple[tuple[Item, ...], ...]] = None
@property
def _stopped(self) -> asyncio.Future:
"""
Return an future indicating whether the View has finished interacting.
Currently exposes a hidden attribute of the underlying View.
May be reimplemented in future.
"""
return self._View__stopped
def to_components(self) -> List[Dict[str, Any]]:
"""
Extending component generator to apply the set _layout, if it exists.
"""
if self._layout is not None:
# Alternative rendering using layout
components = []
for i, row in enumerate(self._layout):
# Skip empty rows
if not row:
continue
# Since we aren't relying on ViewWeights, manually check width here
if sum(item.width for item in row) > 5:
raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!")
# Create the component dict for this row
components.append({
'type': 1,
'components': [item.to_component_dict() for item in row]
})
else:
components = super().to_components()
return components
def set_layout(self, *rows: tuple[Item, ...]) -> None:
"""
Set the layout of the rendered View as a matrix of items,
or more precisely, a list of action rows.
This acts independently of the existing sorting with `_ViewWeights`,
and overrides the sorting if applied.
"""
self._layout = rows
async def cleanup(self):
"""
Coroutine to run when timeing out, stopping, or cancelling.
Generally cleans up any open resources, and removes any leftover components.
"""
logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'})
return None
def stop(self):
"""
Extends View.stop() to also stop all the slave views.
Note that stopping is idempotent, so it is okay if close() also calls stop().
"""
for slave in self._slaves:
slave.stop()
super().stop()
async def close(self, msg=None):
self.stop()
await self.cleanup()
async def pre_timeout(self):
"""
Task to execute before actually timing out.
This may cancel the timeout by refreshing or rescheduling it.
(E.g. to ask the user whether they want to keep going.)
Default implementation does nothing.
"""
return None
async def on_timeout(self):
"""
Task to execute after timeout is complete.
Default implementation calls cleanup.
"""
await self.cleanup()
async def __dispatch_timeout(self):
"""
This essentially extends View._dispatch_timeout,
to include a pre_timeout task
which may optionally refresh and hence cancel the timeout.
"""
if self._View__stopped.done():
# We are already stopped, nothing to do
return
with logging_context(action='Timeout'):
try:
await self.pre_timeout()
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
pass
except Exception:
logger.exception(
"Unhandled error caught while dispatching timeout for {self!r}.",
extra={'with_ctx': True, 'action': 'Error'}
)
# Check if we still need to timeout
if self.timeout is None:
# The timeout was removed entirely, silently walk away
return
if self._View__stopped.done():
# We stopped while waiting for the pre timeout.
# Or maybe another thread timed us out
# Either way, we are done here
return
now = time.monotonic()
if self._View__timeout_expiry is not None and now < self._View__timeout_expiry:
# The timeout was extended, make sure the timeout task is running then fade away
if self._View__timeout_task is None or self._View__timeout_task.done():
self._View__timeout_task = asyncio.create_task(self._View__timeout_task_impl())
else:
# Actually timeout, and call the post-timeout task for cleanup.
self._really_timeout()
await self.on_timeout()
def _dispatch_timeout(self):
"""
Overriding timeout method completely, to support interactive flow during timeout,
and optional refreshing of the timeout.
"""
return self._context.run(asyncio.create_task, self.__dispatch_timeout())
def _really_timeout(self):
"""
Actuallly times out the View.
This copies View._dispatch_timeout, apart from the `on_timeout` dispatch,
which is now handled by `__dispatch_timeout`.
"""
if self._View__stopped.done():
return
if self._View__cancel_callback:
self._View__cancel_callback(self)
self._View__cancel_callback = None
self._View__stopped.set_result(True)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item):
"""
Default LeoUI error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except SafeCancellation as e:
if e.msg and not interaction.is_expired():
try:
if interaction.response.is_done():
await interaction.followup.send(
embed=error_embed(e.msg),
ephemeral=True
)
else:
await interaction.response.send_message(
embed=error_embed(e.msg),
ephemeral=True
)
except discord.HTTPException:
pass
logger.debug(
f"Caught a safe cancellation from LeoUI: {e.details}",
extra={'action': 'Cancel'}
)
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r} from interaction: "
f"{interaction.data}",
extra={'with_ctx': True, 'action': 'UIError'}
)
# Explicitly handle the bugsplat ourselves
splat = interaction.client.tree.bugsplat(interaction, error)
await interaction.client.tree.error_reply(interaction, splat)
class MessageUI(LeoUI):
"""
Simple single-message LeoUI, intended as a framework for UIs
attached to a single interaction response.
UIs may also be sent as regular messages by using `send(channel)` instead of `run(interaction)`.
"""
def __init__(self, *args, callerid: Optional[int] = None, **kwargs):
super().__init__(*args, **kwargs)
# ----- UI state -----
# User ID of the original caller (e.g. command author).
# Mainly used for interaction usage checks and logging
self._callerid = callerid
# Original interaction, if this UI is sent as an interaction response
self._original: discord.Interaction = None
# Message holding the UI, when the UI is sent attached to a followup
self._message: discord.Message = None
# Refresh lock, to avoid cache collisions on refresh
self._refresh_lock = asyncio.Lock()
@property
def channel(self):
if self._original is not None:
return self._original.channel
else:
return self._message.channel
# ----- UI API -----
async def run(self, interaction: discord.Interaction, **kwargs):
"""
Run the UI as a response or followup to the given interaction.
Should be extended if more complex run mechanics are needed
(e.g. registering listeners or setting up caches).
"""
await self.draw(interaction, **kwargs)
async def refresh(self, *args, thinking: Optional[discord.Interaction] = None, **kwargs):
"""
Reload and redraw this UI.
Primarily a hook-method for use by parents and other controllers.
Performs a full data and reload and refresh (maintaining UI state, e.g. page n).
"""
async with self._refresh_lock:
# Reload data
await self.reload()
# Redraw UI message
await self.redraw(thinking=thinking)
async def quit(self):
"""
Quit the UI.
This usually involves removing the original message,
and stopping or closing the underlying View.
"""
for child in self._slaves:
# TODO: Better to use duck typing or interface typing
if isinstance(child, MessageUI) and not child.is_finished():
asyncio.create_task(child.quit())
try:
if self._original is not None and not self._original.is_expired():
await self._original.delete_original_response()
self._original = None
if self._message is not None:
await self._message.delete()
self._message = None
except discord.HTTPException:
pass
# Note close() also runs cleanup and stop
await self.close()
# ----- UI Flow -----
async def interaction_check(self, interaction: discord.Interaction):
"""
Check the given interaction is authorised to use this UI.
Default implementation simply checks that the interaction is
from the original caller.
Extend for more complex logic.
"""
return interaction.user.id == self._callerid
async def make_message(self) -> MessageArgs:
"""
Create the UI message body, depening on the current state.
Called upon each redraw.
Should handle caching if message construction is for some reason intensive.
Must be implemented by concrete UI subclasses.
"""
raise NotImplementedError
async def refresh_layout(self):
"""
Asynchronously refresh the message components,
and explicitly set the message component layout.
Called just before redrawing, before `make_message`.
Must be implemented by concrete UI subclasses.
"""
raise NotImplementedError
async def reload(self):
"""
Reload and recompute the underlying data for this UI.
Must be implemented by concrete UI subclasses.
"""
raise NotImplementedError
async def draw(self, interaction, force_followup=False, **kwargs):
"""
Send the UI as a response or followup to the given interaction.
If the interaction has been responded to, or `force_followup` is set,
creates a followup message instead of a response to the interaction.
"""
# Initial data loading
await self.reload()
# Set the UI layout
await self.refresh_layout()
# Fetch message arguments
args = await self.make_message()
as_followup = force_followup or interaction.response.is_done()
if as_followup:
self._message = await interaction.followup.send(**args.send_args, **kwargs, view=self)
else:
self._original = interaction
await interaction.response.send_message(**args.send_args, **kwargs, view=self)
async def send(self, channel: discord.abc.Messageable, **kwargs):
"""
Alternative to draw() which uses a discord.abc.Messageable.
"""
await self.reload()
await self.refresh_layout()
args = await self.make_message()
self._message = await channel.send(**args.send_args, view=self)
async def _redraw(self, args):
if self._original and not self._original.is_expired():
await self._original.edit_original_response(**args.edit_args, view=self)
elif self._message:
await self._message.edit(**args.edit_args, view=self)
else:
# Interaction expired or already closed. Quietly cleanup.
await self.close()
async def redraw(self, thinking: Optional[discord.Interaction] = None):
"""
Update the output message for this UI.
If a thinking interaction is provided, deletes the response while redrawing.
"""
await self.refresh_layout()
args = await self.make_message()
if thinking is not None and not thinking.is_expired() and thinking.response.is_done():
asyncio.create_task(thinking.delete_original_response())
try:
await self._redraw(args)
except discord.HTTPException as e:
# Unknown communication error, nothing we can reliably do. Exit quietly.
logger.warning(
f"Unexpected UI redraw failure occurred in {self}: {repr(e)}",
)
await self.close()
async def cleanup(self):
"""
Remove message components from interaction response, if possible.
Extend to remove listeners or clean up caches.
`cleanup` is always called when the UI is exiting,
through timeout or user-driven closure.
"""
try:
if self._original is not None and not self._original.is_expired():
await self._original.edit_original_response(view=None)
self._original = None
if self._message is not None:
await self._message.edit(view=None)
self._message = None
except discord.HTTPException:
pass
class LeoModal(Modal):
"""
Context-aware Modal class.
"""
def __init__(self, *args, context: Optional[Context] = None, **kwargs):
super().__init__(**kwargs)
if context is None:
self._context = copy_context()
else:
self._context = context
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__])
def _dispatch_submit(self, *args, **kwargs):
"""
Extending event dispatch to run in the instantiation context.
"""
return self._context.run(super()._dispatch_submit, *args, **kwargs)
def _dispatch_item(self, *args, **kwargs):
"""Extending event dispatch to run in the instantiation context."""
return self._context.run(super()._dispatch_item, *args, **kwargs)
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
"""
Default LeoModal error handle.
This may be tail extended by subclasses to preserve the exception stack.
"""
try:
raise error
except Exception:
logger.exception(
f"Unhandled interaction exception occurred in {self!r}. Interaction: {interaction.data}",
extra={'with_ctx': True, 'action': 'ModalError'}
)
# Explicitly handle the bugsplat ourselves
splat = interaction.client.tree.bugsplat(interaction, error)
await interaction.client.tree.error_reply(interaction, splat)
def error_handler_for(exc):
def wrapper(coro):
coro._ui_error_handler_for_ = exc
return coro
return wrapper

329
src/utils/ui/micros.py Normal file
View File

@@ -0,0 +1,329 @@
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
import functools
import asyncio
import discord
from discord.ui import TextInput
from discord.ui.button import button
from meta.logger import logging_context
from meta.errors import ResponseTimedOut
from .leo import LeoModal, LeoUI
__all__ = (
'FastModal',
'ModalRetryUI',
'Confirm',
'input',
)
class FastModal(LeoModal):
__class_error_handlers__ = []
def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
error_handlers = {}
for base in reversed(cls.__mro__):
for name, member in base.__dict__.items():
if hasattr(member, '_ui_error_handler_for_'):
error_handlers[name] = member
cls.__class_error_handlers__ = list(error_handlers.values())
def __init__error_handlers__(self):
handlers = {}
for handler in self.__class_error_handlers__:
handlers[handler._ui_error_handler_for_] = functools.partial(handler, self)
return handlers
def __init__(self, *items: TextInput, **kwargs):
super().__init__(**kwargs)
for item in items:
self.add_item(item)
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
self._waiters: List[Callable[[discord.Interaction], Coroutine]] = []
self._error_handlers = self.__init__error_handlers__()
def error_handler(self, exception):
def wrapper(coro):
self._error_handlers[exception] = coro
return coro
return wrapper
async def wait_for(self, check=None, timeout=None):
# Wait for _result or timeout
# If we timeout, or the view times out, raise TimeoutError
# Otherwise, return the Interaction
# This allows multiple listeners and callbacks to wait on
while True:
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
if check is not None:
if not check(result):
continue
return result
async def on_timeout(self):
self._result.set_exception(asyncio.TimeoutError)
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
def wrapper(coro):
async def wrapped_callback(interaction):
with logging_context(action=coro.__name__):
if check is not None:
if not check(interaction):
return
try:
await coro(interaction, *pass_args, **pass_kwargs)
except Exception:
raise
finally:
if once:
self._waiters.remove(wrapped_callback)
self._waiters.append(wrapped_callback)
return wrapper
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
try:
# First let our error handlers have a go
# If there is no handler for this error, or the handlers themselves error,
# drop to the superclass error handler implementation.
try:
raise error
except tuple(self._error_handlers.keys()) as e:
# If an error handler is registered for this exception, run it.
for cls, handler in self._error_handlers.items():
if isinstance(e, cls):
await handler(interaction, e)
except Exception as error:
await super().on_error(interaction, error)
async def on_submit(self, interaction):
print("On submit")
old_result = self._result
self._result = asyncio.get_event_loop().create_future()
old_result.set_result(interaction)
tasks = []
for waiter in self._waiters:
task = asyncio.create_task(
waiter(interaction),
name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}"
)
tasks.append(task)
if tasks:
await asyncio.gather(*tasks)
async def input(
interaction: discord.Interaction,
title: str,
question: Optional[str] = None,
field: Optional[TextInput] = None,
timeout=180,
**kwargs,
) -> tuple[discord.Interaction, str]:
"""
Spawn a modal to accept input.
Returns an (interaction, value) pair, with interaction not yet responded to.
May raise asyncio.TimeoutError if the view times out.
"""
if field is None:
field = TextInput(
label=kwargs.get('label', question),
**kwargs
)
modal = FastModal(
field,
title=title,
timeout=timeout
)
await interaction.response.send_modal(modal)
interaction = await modal.wait_for()
return (interaction, field.value)
class ModalRetryUI(LeoUI):
def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.modal = modal
self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)}
self.message = message
self._interaction = None
if label is not None:
self.retry_button.label = label
@property
def embed(self):
return discord.Embed(
title="Uh-Oh!",
description=self.message,
colour=discord.Colour.red()
)
async def respond_to(self, interaction):
self._interaction = interaction
if interaction.response.is_done():
await interaction.followup.send(embed=self.embed, ephemeral=True, view=self)
else:
await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self)
@button(label="Retry")
async def retry_button(self, interaction, butt):
# Setting these here so they don't update in the meantime
for item, value in self.item_values.items():
item.default = value
if self._interaction is not None:
await self._interaction.delete_original_response()
self._interaction = None
await interaction.response.send_modal(self.modal)
await self.close()
class Confirm(LeoUI):
"""
Micro UI class implementing a confirmation question.
Parameters
----------
confirm_msg: str
The confirmation question to ask from the user.
This is set as the description of the `embed` property.
The `embed` may be further modified if required.
permitted_id: Optional[int]
The user id allowed to access this interaction.
Other users will recieve an access denied error message.
defer: bool
Whether to defer the interaction response while handling the button.
It may be useful to set this to `False` to obtain manual control
over the interaction response flow (e.g. to send a modal or ephemeral message).
The button press interaction may be accessed through `Confirm.interaction`.
Default: True
Example
-------
```
confirm = Confirm("Are you sure?", ctx.author.id)
confirm.embed.colour = discord.Colour.red()
confirm.confirm_button.label = "Yes I am sure"
confirm.cancel_button.label = "No I am not sure"
try:
result = await confirm.ask(ctx.interaction, ephemeral=True)
except ResultTimedOut:
return
```
"""
def __init__(
self,
confirm_msg: str,
permitted_id: Optional[int] = None,
defer: bool = True,
**kwargs
):
super().__init__(**kwargs)
self.confirm_msg = confirm_msg
self.permitted_id = permitted_id
self.defer = defer
self._embed: Optional[discord.Embed] = None
self._result: asyncio.Future[bool] = asyncio.Future()
# Indicates whether we should delete the message or the interaction response
self._is_followup: bool = False
self._original: Optional[discord.Interaction] = None
self._message: Optional[discord.Message] = None
async def interaction_check(self, interaction: discord.Interaction):
return (self.permitted_id is None) or interaction.user.id == self.permitted_id
async def on_timeout(self):
# Propagate timeout to result Future
self._result.set_exception(ResponseTimedOut)
await self.cleanup()
async def cleanup(self):
"""
Cleanup the confirmation prompt by deleting it, if possible.
Ignores any Discord errors that occur during the process.
"""
try:
if self._is_followup and self._message:
await self._message.delete()
elif not self._is_followup and self._original and not self._original.is_expired():
await self._original.delete_original_response()
except discord.HTTPException:
# A user probably already deleted the message
# Anything could have happened, just ignore.
pass
@button(label="Confirm")
async def confirm_button(self, interaction: discord.Interaction, press):
if self.defer:
await interaction.response.defer()
self._result.set_result(True)
await self.close()
@button(label="Cancel")
async def cancel_button(self, interaction: discord.Interaction, press):
if self.defer:
await interaction.response.defer()
self._result.set_result(False)
await self.close()
@property
def embed(self):
"""
Confirmation embed shown to the user.
This is cached, and may be modifed directly through the usual EmbedProxy API,
or explicitly overwritten.
"""
if self._embed is None:
self._embed = discord.Embed(
colour=discord.Colour.orange(),
description=self.confirm_msg
)
return self._embed
@embed.setter
def embed(self, value):
self._embed = value
async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs):
"""
Send this confirmation prompt in response to the provided interaction.
Extra keyword arguments are passed to `Interaction.response.send_message`
or `Interaction.send_followup`, depending on whether
the provided interaction has already been responded to.
The `epehemeral` argument is handled specially,
since the question message can only be deleted through `Interaction.delete_original_response`.
Waits on and returns the internal `result` Future.
Returns: bool
True if the user pressed the confirm button.
False if the user pressed the cancel button.
Raises:
ResponseTimedOut:
If the user does not respond before the UI times out.
"""
self._original = interaction
if interaction.response.is_done():
# Interaction already responded to, send a follow up
if ephemeral:
raise ValueError("Cannot send an ephemeral response to a used interaction.")
self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self)
self._is_followup = True
else:
await interaction.response.send_message(
embed=self.embed, ephemeral=ephemeral, **kwargs, view=self
)
self._is_followup = False
return await self._result
# TODO: Selector MicroUI for displaying options (<= 25)