Initial Template.
This commit is contained in:
0
src/utils/__init__.py
Normal file
0
src/utils/__init__.py
Normal file
97
src/utils/ansi.py
Normal file
97
src/utils/ansi.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Minimal library for making Discord Ansi colour codes.
|
||||
"""
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
PREFIX = u'\u001b'
|
||||
|
||||
|
||||
class TextColour(StrEnum):
|
||||
Gray = '30'
|
||||
Red = '31'
|
||||
Green = '32'
|
||||
Yellow = '33'
|
||||
Blue = '34'
|
||||
Pink = '35'
|
||||
Cyan = '36'
|
||||
White = '37'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return AnsiColour(fg=self).as_str()
|
||||
|
||||
def __call__(self):
|
||||
return AnsiColour(fg=self)
|
||||
|
||||
|
||||
class BgColour(StrEnum):
|
||||
FireflyDarkBlue = '40'
|
||||
Orange = '41'
|
||||
MarbleBlue = '42'
|
||||
GrayTurq = '43'
|
||||
Gray = '44'
|
||||
Indigo = '45'
|
||||
LightGray = '46'
|
||||
White = '47'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return AnsiColour(bg=self).as_str()
|
||||
|
||||
def __call__(self):
|
||||
return AnsiColour(bg=self)
|
||||
|
||||
|
||||
class Format(StrEnum):
|
||||
NORMAL = '0'
|
||||
BOLD = '1'
|
||||
UNDERLINE = '4'
|
||||
NOOP = '9'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return AnsiColour(self).as_str()
|
||||
|
||||
def __call__(self):
|
||||
return AnsiColour(self)
|
||||
|
||||
|
||||
class AnsiColour:
|
||||
def __init__(self, *flags, fg=None, bg=None):
|
||||
self.text_colour = fg
|
||||
self.background_colour = bg
|
||||
self.reset = (Format.NORMAL in flags)
|
||||
self._flags = set(flags)
|
||||
self._flags.discard(Format.NORMAL)
|
||||
|
||||
@property
|
||||
def flags(self):
|
||||
return (*((Format.NORMAL,) if self.reset else ()), *self._flags)
|
||||
|
||||
def as_str(self):
|
||||
parts = []
|
||||
if self.reset:
|
||||
parts.append(Format.NORMAL)
|
||||
elif not self.flags:
|
||||
parts.append(Format.NOOP)
|
||||
|
||||
parts.extend(self._flags)
|
||||
|
||||
for c in (self.text_colour, self.background_colour):
|
||||
if c is not None:
|
||||
parts.append(c)
|
||||
|
||||
partstr = ';'.join(part.value for part in parts)
|
||||
return f"{PREFIX}[{partstr}m" # ]
|
||||
|
||||
def __str__(self):
|
||||
return self.as_str()
|
||||
|
||||
def __add__(self, obj: 'AnsiColour'):
|
||||
text_colour = obj.text_colour or self.text_colour
|
||||
background_colour = obj.background_colour or self.background_colour
|
||||
flags = (*self.flags, *obj.flags)
|
||||
return AnsiColour(*flags, fg=text_colour, bg=background_colour)
|
||||
|
||||
|
||||
RESET = AnsiColour(Format.NORMAL)
|
||||
BOLD = AnsiColour(Format.BOLD)
|
||||
UNDERLINE = AnsiColour(Format.UNDERLINE)
|
||||
165
src/utils/data.py
Normal file
165
src/utils/data.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Some useful pre-built Conditions for data queries.
|
||||
"""
|
||||
from typing import Optional, Any
|
||||
from itertools import chain
|
||||
|
||||
from psycopg import sql
|
||||
from data.conditions import Condition, Joiner
|
||||
from data.columns import ColumnExpr
|
||||
from data.base import Expression
|
||||
from constants import MAX_COINS
|
||||
|
||||
|
||||
def MULTIVALUE_IN(columns: tuple[str, ...], *data: tuple[Any, ...]) -> Condition:
|
||||
"""
|
||||
Condition constructor for filtering by multiple column equalities.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
Query.where(MULTIVALUE_IN(('guildid', 'userid'), (1, 2), (3, 4)))
|
||||
"""
|
||||
if not data:
|
||||
raise ValueError("Cannot create empty multivalue condition.")
|
||||
left = sql.SQL("({})").format(
|
||||
sql.SQL(', ').join(
|
||||
sql.Identifier(key)
|
||||
for key in columns
|
||||
)
|
||||
)
|
||||
right_item = sql.SQL('({})').format(
|
||||
sql.SQL(', ').join(
|
||||
sql.Placeholder()
|
||||
for _ in columns
|
||||
)
|
||||
)
|
||||
right = sql.SQL("({})").format(
|
||||
sql.SQL(', ').join(
|
||||
right_item
|
||||
for _ in data
|
||||
)
|
||||
)
|
||||
return Condition(
|
||||
left,
|
||||
Joiner.IN,
|
||||
right,
|
||||
chain(*data)
|
||||
)
|
||||
|
||||
|
||||
def MEMBERS(*memberids: tuple[int, int], guild_column='guildid', user_column='userid') -> Condition:
|
||||
"""
|
||||
Condition constructor for filtering member tables by guild and user id simultaneously.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
Query.where(MEMBERS((1234,12), (5678,34)))
|
||||
"""
|
||||
if not memberids:
|
||||
raise ValueError("Cannot create a condition with no members")
|
||||
return Condition(
|
||||
sql.SQL("({guildid}, {userid})").format(
|
||||
guildid=sql.Identifier(guild_column),
|
||||
userid=sql.Identifier(user_column)
|
||||
),
|
||||
Joiner.IN,
|
||||
sql.SQL("({})").format(
|
||||
sql.SQL(', ').join(
|
||||
sql.SQL("({}, {})").format(
|
||||
sql.Placeholder(),
|
||||
sql.Placeholder()
|
||||
) for _ in memberids
|
||||
)
|
||||
),
|
||||
chain(*memberids)
|
||||
)
|
||||
|
||||
|
||||
def as_duration(expr: Expression) -> ColumnExpr:
|
||||
"""
|
||||
Convert an integer expression into a duration expression.
|
||||
"""
|
||||
expr_expr, expr_values = expr.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} * interval '1 second')").format(expr_expr),
|
||||
expr_values
|
||||
)
|
||||
|
||||
|
||||
class TemporaryTable(Expression):
|
||||
"""
|
||||
Create a temporary table expression to be used in From or With clauses.
|
||||
|
||||
Example
|
||||
-------
|
||||
```
|
||||
tmp_table = TemporaryTable('_col1', '_col2', name='data')
|
||||
tmp_table.values((1, 2), (3, 4))
|
||||
|
||||
real_table.update_where(col1=tmp_table['_col1']).set(col2=tmp_table['_col2']).from_(tmp_table)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, *columns: str, name: str = '_t', types: Optional[tuple[str, ...]] = None):
|
||||
self.name = name
|
||||
self.columns = columns
|
||||
self.types = types
|
||||
if types and len(types) != len(columns):
|
||||
raise ValueError("Number of types does not much number of columns!")
|
||||
|
||||
self._table_columns = {
|
||||
col: ColumnExpr(sql.Identifier(name, col))
|
||||
for col in columns
|
||||
}
|
||||
|
||||
self.values = []
|
||||
|
||||
def __getitem__(self, key) -> sql.Identifier:
|
||||
return self._table_columns[key]
|
||||
|
||||
def as_tuple(self):
|
||||
"""
|
||||
(VALUES {})
|
||||
AS
|
||||
name (col1, col2)
|
||||
"""
|
||||
if not self.values:
|
||||
raise ValueError("Cannot flatten CTE with no values.")
|
||||
|
||||
single_value = sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() for _ in self.columns))
|
||||
if self.types:
|
||||
first_value = sql.SQL("({})").format(
|
||||
sql.SQL(", ").join(
|
||||
sql.SQL("{}::{}").format(sql.Placeholder(), sql.SQL(cast))
|
||||
for cast in self.types
|
||||
)
|
||||
)
|
||||
else:
|
||||
first_value = single_value
|
||||
|
||||
value_placeholder = sql.SQL("(VALUES {})").format(
|
||||
sql.SQL(", ").join(
|
||||
(first_value, *(single_value for _ in self.values[1:]))
|
||||
)
|
||||
)
|
||||
expr = sql.SQL("{values} AS {name} ({columns})").format(
|
||||
values=value_placeholder,
|
||||
name=sql.Identifier(self.name),
|
||||
columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.columns)
|
||||
)
|
||||
values = chain(*self.values)
|
||||
return (expr, values)
|
||||
|
||||
def set_values(self, *data):
|
||||
self.values = data
|
||||
|
||||
|
||||
def SAFECOINS(expr: Expression) -> Expression:
|
||||
expr_expr, expr_values = expr.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("LEAST({}, {})").format(
|
||||
expr_expr,
|
||||
sql.Literal(MAX_COINS)
|
||||
),
|
||||
expr_values
|
||||
)
|
||||
847
src/utils/lib.py
Normal file
847
src/utils/lib.py
Normal file
@@ -0,0 +1,847 @@
|
||||
from io import StringIO
|
||||
from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any
|
||||
import collections
|
||||
import datetime
|
||||
import datetime as dt
|
||||
import iso8601 # type: ignore
|
||||
import pytz
|
||||
import re
|
||||
import json
|
||||
from contextvars import Context
|
||||
|
||||
import discord
|
||||
from discord.partial_emoji import _EmojiTag
|
||||
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
|
||||
from discord.ui import View
|
||||
|
||||
from meta.errors import UserInputError
|
||||
|
||||
|
||||
multiselect_regex = re.compile(
|
||||
r"^([0-9, -]+)$",
|
||||
re.DOTALL | re.IGNORECASE | re.VERBOSE
|
||||
)
|
||||
tick = '✅'
|
||||
cross = '❌'
|
||||
|
||||
MISSING = object()
|
||||
|
||||
|
||||
class MessageArgs:
|
||||
"""
|
||||
Utility class for storing message creation and editing arguments.
|
||||
"""
|
||||
# TODO: Overrides for mutually exclusive arguments, see Messageable.send
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[str] = ...,
|
||||
*,
|
||||
tts: bool = ...,
|
||||
embed: Embed = ...,
|
||||
file: File = ...,
|
||||
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||
delete_after: float = ...,
|
||||
nonce: Union[str, int] = ...,
|
||||
allowed_mentions: AllowedMentions = ...,
|
||||
reference: Union[Message, MessageReference, PartialMessage] = ...,
|
||||
mention_author: bool = ...,
|
||||
view: View = ...,
|
||||
suppress_embeds: bool = ...,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[str] = ...,
|
||||
*,
|
||||
tts: bool = ...,
|
||||
embed: Embed = ...,
|
||||
files: Sequence[File] = ...,
|
||||
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||
delete_after: float = ...,
|
||||
nonce: Union[str, int] = ...,
|
||||
allowed_mentions: AllowedMentions = ...,
|
||||
reference: Union[Message, MessageReference, PartialMessage] = ...,
|
||||
mention_author: bool = ...,
|
||||
view: View = ...,
|
||||
suppress_embeds: bool = ...,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[str] = ...,
|
||||
*,
|
||||
tts: bool = ...,
|
||||
embeds: Sequence[Embed] = ...,
|
||||
file: File = ...,
|
||||
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||
delete_after: float = ...,
|
||||
nonce: Union[str, int] = ...,
|
||||
allowed_mentions: AllowedMentions = ...,
|
||||
reference: Union[Message, MessageReference, PartialMessage] = ...,
|
||||
mention_author: bool = ...,
|
||||
view: View = ...,
|
||||
suppress_embeds: bool = ...,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[str] = ...,
|
||||
*,
|
||||
tts: bool = ...,
|
||||
embeds: Sequence[Embed] = ...,
|
||||
files: Sequence[File] = ...,
|
||||
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||
delete_after: float = ...,
|
||||
nonce: Union[str, int] = ...,
|
||||
allowed_mentions: AllowedMentions = ...,
|
||||
reference: Union[Message, MessageReference, PartialMessage] = ...,
|
||||
mention_author: bool = ...,
|
||||
view: View = ...,
|
||||
suppress_embeds: bool = ...,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
@property
|
||||
def send_args(self) -> dict:
|
||||
if self.kwargs.get('view', MISSING) is None:
|
||||
kwargs = self.kwargs.copy()
|
||||
kwargs.pop('view')
|
||||
else:
|
||||
kwargs = self.kwargs
|
||||
|
||||
return kwargs
|
||||
|
||||
@property
|
||||
def edit_args(self) -> dict:
|
||||
args = {}
|
||||
kept = (
|
||||
'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view'
|
||||
)
|
||||
for k in kept:
|
||||
if k in self.kwargs:
|
||||
args[k] = self.kwargs[k]
|
||||
|
||||
if 'file' in self.kwargs:
|
||||
args['attachments'] = [self.kwargs['file']]
|
||||
|
||||
if 'files' in self.kwargs:
|
||||
args['attachments'] = self.kwargs['files']
|
||||
|
||||
if 'suppress_embeds' in self.kwargs:
|
||||
args['suppress'] = self.kwargs['suppress_embeds']
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def tabulate(
|
||||
*fields: tuple[str, str],
|
||||
row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}",
|
||||
sub_format: str = "`{invis:<{pad}}{colon}`\t{value}",
|
||||
colon: str = ':',
|
||||
invis: str = "",
|
||||
**args
|
||||
) -> list[str]:
|
||||
"""
|
||||
Turns a list of (property, value) pairs into
|
||||
a pretty string with one `prop: value` pair each line,
|
||||
padded so that the colons in each line are lined up.
|
||||
Use `\\r\\n` in a value to break the line with padding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fields: List[tuple[str, str]]
|
||||
List of (key, value) pairs.
|
||||
row_format: str
|
||||
The format string used to format each row.
|
||||
sub_format: str
|
||||
The format string used to format each subline in a row.
|
||||
colon: str
|
||||
The colon character used.
|
||||
invis: str
|
||||
The invisible character used (to avoid Discord stripping the string).
|
||||
|
||||
Returns: List[str]
|
||||
The list of resulting table rows.
|
||||
Each row corresponds to one (key, value) pair from fields.
|
||||
"""
|
||||
max_len = max(len(field[0]) for field in fields)
|
||||
|
||||
rows = []
|
||||
for field in fields:
|
||||
key = field[0]
|
||||
value = field[1]
|
||||
lines = value.split('\r\n')
|
||||
|
||||
row_line = row_format.format(
|
||||
invis=invis,
|
||||
key=key,
|
||||
pad=max_len,
|
||||
colon=colon,
|
||||
value=lines[0],
|
||||
field=field,
|
||||
**args
|
||||
)
|
||||
if len(lines) > 1:
|
||||
row_lines = [row_line]
|
||||
for line in lines[1:]:
|
||||
sub_line = sub_format.format(
|
||||
invis=invis,
|
||||
pad=max_len + len(colon),
|
||||
colon=colon,
|
||||
value=line,
|
||||
**args
|
||||
)
|
||||
row_lines.append(sub_line)
|
||||
row_line = '\n'.join(row_lines)
|
||||
rows.append(row_line)
|
||||
return rows
|
||||
|
||||
|
||||
def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]:
|
||||
"""
|
||||
Create pretty codeblock pages from a list of strings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
item_list: List[str]
|
||||
List of strings to paginate.
|
||||
block_length: int
|
||||
Maximum number of strings per page.
|
||||
style: str
|
||||
Codeblock style to use.
|
||||
Title formatting assumes the `markdown` style, and numbered lists work well with this.
|
||||
However, `markdown` sometimes messes up formatting in the list.
|
||||
title: str
|
||||
Optional title to add to the top of each page.
|
||||
|
||||
Returns: List[str]
|
||||
List of pages, each formatted into a codeblock,
|
||||
and containing at most `block_length` of the provided strings.
|
||||
"""
|
||||
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
|
||||
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
|
||||
pages = []
|
||||
for i, block in enumerate(page_blocks):
|
||||
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
|
||||
if title:
|
||||
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
|
||||
else:
|
||||
header = pagenum
|
||||
header_line = "=" * len(header)
|
||||
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
|
||||
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
|
||||
return pages
|
||||
|
||||
|
||||
def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]:
|
||||
"""
|
||||
Break the text into blocks of maximum length blocksize
|
||||
If possible, break across nearby newlines. Otherwise just break at blocksize chars
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
Text to break into blocks.
|
||||
blocksize: int
|
||||
Maximum character length for each block.
|
||||
code: bool
|
||||
Whether to wrap each block in codeblocks (these are counted in the blocksize).
|
||||
syntax: str
|
||||
The markdown formatting language to use for the codeblocks, if applicable.
|
||||
maxheight: int
|
||||
The maximum number of lines in each block
|
||||
|
||||
Returns: List[str]
|
||||
List of blocks,
|
||||
each containing at most `block_size` characters,
|
||||
of height at most `maxheight`.
|
||||
"""
|
||||
# Adjust blocksize to account for the codeblocks if required
|
||||
blocksize = blocksize - 8 - len(syntax) if code else blocksize
|
||||
|
||||
# Build the blocks
|
||||
blocks = []
|
||||
while True:
|
||||
# If the remaining text is already small enough, append it
|
||||
if len(text) <= blocksize:
|
||||
blocks.append(text)
|
||||
break
|
||||
text = text.strip('\n')
|
||||
|
||||
# Find the last newline in the prototype block
|
||||
split_on = text[0:blocksize].rfind('\n')
|
||||
split_on = blocksize if split_on < blocksize // 5 else split_on
|
||||
|
||||
# Add the block and truncate the text
|
||||
blocks.append(text[0:split_on])
|
||||
text = text[split_on:]
|
||||
|
||||
# Add the codeblock ticks and the code syntax header, if required
|
||||
if code:
|
||||
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str:
|
||||
"""
|
||||
Convert a datetime.timedelta object into an easily readable duration string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delta: datetime.timedelta
|
||||
The timedelta object to convert into a readable string.
|
||||
sec: bool
|
||||
Whether to include the seconds from the timedelta object in the string.
|
||||
minutes: bool
|
||||
Whether to include the minutes from the timedelta object in the string.
|
||||
short: bool
|
||||
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
|
||||
|
||||
Returns: str
|
||||
A string containing a time from the datetime.timedelta object, in a readable format.
|
||||
Time units will be abbreviated if short was set to True.
|
||||
"""
|
||||
output = [[delta.days, 'd' if short else ' day'],
|
||||
[delta.seconds // 3600, 'h' if short else ' hour']]
|
||||
if minutes:
|
||||
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
|
||||
if sec:
|
||||
output.append([delta.seconds % 60, 's' if short else ' second'])
|
||||
for i in range(len(output)):
|
||||
if output[i][0] != 1 and not short:
|
||||
output[i][1] += 's' # type: ignore
|
||||
reply_msg = []
|
||||
if output[0][0] != 0:
|
||||
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
|
||||
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
|
||||
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
|
||||
for i in range(2, len(output) - 1):
|
||||
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
|
||||
if not short and reply_msg:
|
||||
reply_msg.append("and ")
|
||||
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
|
||||
return "".join(reply_msg)
|
||||
|
||||
|
||||
def _parse_dur(time_str: str) -> int:
|
||||
"""
|
||||
Parses a user provided time duration string into a timedelta object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_str: str
|
||||
The time string to parse. String can include days, hours, minutes, and seconds.
|
||||
|
||||
Returns: int
|
||||
The number of seconds the duration represents.
|
||||
"""
|
||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||
'h': lambda x: x * 60 * 60,
|
||||
'm': lambda x: x * 60,
|
||||
's': lambda x: x}
|
||||
time_str = time_str.strip(" ,")
|
||||
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
|
||||
seconds = 0
|
||||
for bit in found:
|
||||
if bit[1] in funcs:
|
||||
seconds += funcs[bit[1]](int(bit[0]))
|
||||
return seconds
|
||||
|
||||
|
||||
def strfdur(duration: int, short=True, show_days=False) -> str:
|
||||
"""
|
||||
Convert a duration given in seconds to a number of hours, minutes, and seconds.
|
||||
"""
|
||||
days = duration // (3600 * 24) if show_days else 0
|
||||
hours = duration // 3600
|
||||
if days:
|
||||
hours %= 24
|
||||
minutes = duration // 60 % 60
|
||||
seconds = duration % 60
|
||||
|
||||
parts = []
|
||||
if days:
|
||||
unit = 'd' if short else (' days' if days != 1 else ' day')
|
||||
parts.append('{}{}'.format(days, unit))
|
||||
if hours:
|
||||
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
|
||||
parts.append('{}{}'.format(hours, unit))
|
||||
if minutes:
|
||||
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
|
||||
parts.append('{}{}'.format(minutes, unit))
|
||||
if seconds or duration == 0:
|
||||
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
|
||||
parts.append('{}{}'.format(seconds, unit))
|
||||
|
||||
if short:
|
||||
return ' '.join(parts)
|
||||
else:
|
||||
return ', '.join(parts)
|
||||
|
||||
|
||||
def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str:
|
||||
"""
|
||||
Substitutes a user provided list of numbers and ranges,
|
||||
and replaces the ranges by the corresponding list of numbers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ranges_str: str
|
||||
The string to ranges in.
|
||||
max_match: int
|
||||
The maximum number of ranges to replace.
|
||||
Any ranges exceeding this will be ignored.
|
||||
max_range: int
|
||||
The maximum length of range to replace.
|
||||
Attempting to replace a range longer than this will raise a `ValueError`.
|
||||
"""
|
||||
def _repl(match):
|
||||
n1 = int(match.group(1))
|
||||
n2 = int(match.group(2))
|
||||
if n2 - n1 > max_range:
|
||||
# TODO: Upgrade to SafeCancellation
|
||||
raise ValueError("Provided range is too large!")
|
||||
return separator.join(str(i) for i in range(n1, n2 + 1))
|
||||
|
||||
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
|
||||
|
||||
|
||||
def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]:
|
||||
"""
|
||||
Parses a user provided range string into a list of numbers.
|
||||
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
|
||||
"""
|
||||
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
|
||||
_numbers = (item.strip() for item in substituted.split(','))
|
||||
numbers = [item for item in _numbers if item]
|
||||
integers = [int(item) for item in numbers if item.isdigit()]
|
||||
|
||||
if not ignore_errors and len(integers) != len(numbers):
|
||||
# TODO: Upgrade to SafeCancellation
|
||||
raise ValueError(
|
||||
"Couldn't parse the provided selection!\n"
|
||||
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
|
||||
)
|
||||
|
||||
return integers
|
||||
|
||||
|
||||
def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str:
|
||||
"""
|
||||
Format a message into a string with various information, such as:
|
||||
the timestamp of the message, author, message content, and attachments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg: Message
|
||||
The message to format.
|
||||
mask_link: bool
|
||||
Whether to mask the URLs of any attachments.
|
||||
line_break: bool
|
||||
Whether a line break should be used in the string.
|
||||
tz: Timezone
|
||||
The timezone to use in the formatted message.
|
||||
clean: bool
|
||||
Whether to use the clean content of the original message.
|
||||
|
||||
Returns: str
|
||||
A formatted string containing various information:
|
||||
User timezone, message author, message content, attachments
|
||||
"""
|
||||
timestr = "%I:%M %p, %d/%m/%Y"
|
||||
if tz:
|
||||
time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr)
|
||||
else:
|
||||
time = msg.created_at.strftime(timestr)
|
||||
user = str(msg.author)
|
||||
attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url]
|
||||
if mask_link:
|
||||
attach_list = ["[Link]({})".format(url) for url in attach_list]
|
||||
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
|
||||
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
|
||||
time=time,
|
||||
user=user,
|
||||
line_break="\n" if line_break else "",
|
||||
message=msg.clean_content if clean else msg.content,
|
||||
attachments=attachments
|
||||
)
|
||||
|
||||
|
||||
def convdatestring(datestring: str) -> datetime.timedelta:
|
||||
"""
|
||||
Convert a date string into a datetime.timedelta object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
datestring: str
|
||||
The string to convert to a datetime.timedelta object.
|
||||
|
||||
Returns: datetime.timedelta
|
||||
A datetime.timedelta object formed from the string provided.
|
||||
"""
|
||||
datestring = datestring.strip(' ,')
|
||||
datearray = []
|
||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||
'h': lambda x: x * 60 * 60,
|
||||
'm': lambda x: x * 60,
|
||||
's': lambda x: x}
|
||||
currentnumber = ''
|
||||
for char in datestring:
|
||||
if char.isdigit():
|
||||
currentnumber += char
|
||||
else:
|
||||
if currentnumber == '':
|
||||
continue
|
||||
datearray.append((int(currentnumber), char))
|
||||
currentnumber = ''
|
||||
seconds = 0
|
||||
if currentnumber:
|
||||
seconds += int(currentnumber)
|
||||
for i in datearray:
|
||||
if i[1] in funcs:
|
||||
seconds += funcs[i[1]](i[0])
|
||||
return datetime.timedelta(seconds=seconds)
|
||||
|
||||
|
||||
class _rawChannel(discord.abc.Messageable):
|
||||
"""
|
||||
Raw messageable class representing an arbitrary channel,
|
||||
not necessarially seen by the gateway.
|
||||
"""
|
||||
def __init__(self, state, id):
|
||||
self._state = state
|
||||
self.id = id
|
||||
|
||||
async def _get_channel(self):
|
||||
return discord.Object(self.id)
|
||||
|
||||
|
||||
async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message:
|
||||
"""
|
||||
Mails a message to a channelid which may be invisible to the gateway.
|
||||
|
||||
Parameters:
|
||||
client: discord.Client
|
||||
The client to use for mailing.
|
||||
Must at least have static authentication and have a valid `_connection`.
|
||||
channelid: int
|
||||
The channel id to mail to.
|
||||
msg_args: Any
|
||||
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
|
||||
"""
|
||||
# Create the raw channel
|
||||
channel = _rawChannel(client._connection, channelid)
|
||||
return await channel.send(**msg_args)
|
||||
|
||||
|
||||
class EmbedField(NamedTuple):
|
||||
name: str
|
||||
value: str
|
||||
inline: Optional[bool] = True
|
||||
|
||||
|
||||
def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]):
|
||||
"""
|
||||
Append embed fields to an embed.
|
||||
Parameters
|
||||
----------
|
||||
embed: discord.Embed
|
||||
The embed to add the field to.
|
||||
emb_fields: tuple
|
||||
The values to add to a field.
|
||||
name: str
|
||||
The name of the field.
|
||||
value: str
|
||||
The value of the field.
|
||||
inline: bool
|
||||
Whether the embed field should be inline or not.
|
||||
"""
|
||||
for field in emb_fields:
|
||||
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
|
||||
|
||||
|
||||
def join_list(string: list[str], nfs=False) -> str:
|
||||
"""
|
||||
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
|
||||
Parameters
|
||||
----------
|
||||
string: list
|
||||
The list to join together.
|
||||
nfs: bool
|
||||
(no fullstops)
|
||||
Whether to exclude fullstops/periods from the output messages.
|
||||
If not provided, fullstops will be appended to the output.
|
||||
"""
|
||||
# TODO: Probably not useful with localisation
|
||||
if len(string) > 1:
|
||||
return "{}{} and {}{}".format((", ").join(string[:-1]),
|
||||
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
|
||||
else:
|
||||
return "{}{}".format("".join(string), "" if nfs else ".")
|
||||
|
||||
|
||||
def shard_of(shard_count: int, guildid: int) -> int:
|
||||
"""
|
||||
Calculate the shard number of a given guild.
|
||||
"""
|
||||
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
|
||||
|
||||
|
||||
def jumpto(guildid: int, channeldid: int, messageid: int) -> str:
|
||||
"""
|
||||
Build a jump link for a message given its location.
|
||||
"""
|
||||
return 'https://discord.com/channels/{}/{}/{}'.format(
|
||||
guildid,
|
||||
channeldid,
|
||||
messageid
|
||||
)
|
||||
|
||||
|
||||
def utc_now() -> datetime.datetime:
|
||||
"""
|
||||
Return the current timezone-aware utc timestamp.
|
||||
"""
|
||||
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def multiple_replace(string: str, rep_dict: dict[str, str]) -> str:
|
||||
if rep_dict:
|
||||
pattern = re.compile(
|
||||
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
|
||||
flags=re.DOTALL
|
||||
)
|
||||
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def recover_context(context: Context):
|
||||
for var in context:
|
||||
var.set(context[var])
|
||||
|
||||
|
||||
def parse_ids(idstr: str) -> List[int]:
|
||||
"""
|
||||
Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids.
|
||||
|
||||
Object agnostic, so all mention tokens are stripped.
|
||||
Raises UserInputError if an id is invalid,
|
||||
setting `orig` and `item` info fields.
|
||||
"""
|
||||
from meta.errors import UserInputError
|
||||
|
||||
# Extract ids from string
|
||||
splititer = (split.strip('<@!#&>, ') for split in idstr.split(','))
|
||||
splits = [split for split in splititer if split]
|
||||
|
||||
# Check they are integers
|
||||
if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None:
|
||||
raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id})
|
||||
|
||||
# Cast to integer and return
|
||||
return list(map(int, splits))
|
||||
|
||||
|
||||
def error_embed(error, **kwargs) -> discord.Embed:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_red(),
|
||||
description=error,
|
||||
timestamp=utc_now()
|
||||
)
|
||||
return embed
|
||||
|
||||
|
||||
class DotDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
class Timezoned:
|
||||
"""
|
||||
ABC mixin for objects with a set timezone.
|
||||
|
||||
Provides several useful localised properties.
|
||||
"""
|
||||
__slots__ = ()
|
||||
|
||||
@property
|
||||
def timezone(self) -> pytz.timezone:
|
||||
"""
|
||||
Must be implemented by the deriving class!
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def now(self):
|
||||
"""
|
||||
Return the current time localised to the object's timezone.
|
||||
"""
|
||||
return datetime.datetime.now(tz=self.timezone)
|
||||
|
||||
@property
|
||||
def today(self):
|
||||
"""
|
||||
Return the start of the day localised to the object's timezone.
|
||||
"""
|
||||
now = self.now
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
@property
|
||||
def week_start(self):
|
||||
"""
|
||||
Return the start of the week in the object's timezone
|
||||
"""
|
||||
today = self.today
|
||||
return today - datetime.timedelta(days=today.weekday())
|
||||
|
||||
@property
|
||||
def month_start(self):
|
||||
"""
|
||||
Return the start of the current month in the object's timezone
|
||||
"""
|
||||
today = self.today
|
||||
return today.replace(day=1)
|
||||
|
||||
|
||||
def replace_multiple(format_string, mapping):
|
||||
"""
|
||||
Subsistutes the keys from the format_dict with their corresponding values.
|
||||
|
||||
Substitution is non-chained, and done in a single pass via regex.
|
||||
"""
|
||||
if not mapping:
|
||||
raise ValueError("Empty mapping passed.")
|
||||
|
||||
keys = list(mapping.keys())
|
||||
pattern = '|'.join(f"({key})" for key in keys)
|
||||
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
|
||||
return string
|
||||
|
||||
|
||||
def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str):
|
||||
"""
|
||||
Produces a distinguishing key for an Emoji or PartialEmoji.
|
||||
|
||||
Equality checks using this key should act as expected.
|
||||
"""
|
||||
if isinstance(emoji, _EmojiTag):
|
||||
if emoji.id:
|
||||
key = str(emoji.id)
|
||||
else:
|
||||
key = str(emoji.name)
|
||||
else:
|
||||
key = str(emoji)
|
||||
|
||||
return key
|
||||
|
||||
def recurse_map(func, obj, loc=[]):
|
||||
if isinstance(obj, dict):
|
||||
for k, v in obj.items():
|
||||
loc.append(k)
|
||||
obj[k] = recurse_map(func, v, loc)
|
||||
loc.pop()
|
||||
elif isinstance(obj, list):
|
||||
for i, item in enumerate(obj):
|
||||
loc.append(i)
|
||||
obj[i] = recurse_map(func, item)
|
||||
loc.pop()
|
||||
else:
|
||||
obj = func(loc, obj)
|
||||
return obj
|
||||
|
||||
async def check_dm(user: discord.User | discord.Member) -> bool:
|
||||
"""
|
||||
Check whether we can direct message the given user.
|
||||
|
||||
Assumes the client is initialised.
|
||||
This uses an always-failing HTTP request,
|
||||
so we need to be very very very careful that this is not used frequently.
|
||||
Optimally only at the explicit behest of the user
|
||||
(i.e. during a user instigated interaction).
|
||||
"""
|
||||
try:
|
||||
await user.send('')
|
||||
except discord.Forbidden:
|
||||
return False
|
||||
except discord.HTTPException:
|
||||
return True
|
||||
|
||||
|
||||
async def command_lengths(tree) -> dict[str, int]:
|
||||
cmds = tree.get_commands()
|
||||
payloads = [
|
||||
await cmd.get_translated_payload(tree.translator)
|
||||
for cmd in cmds
|
||||
]
|
||||
lens = {}
|
||||
for command in payloads:
|
||||
name = command['name']
|
||||
crumbs = {}
|
||||
cmd_len = lens[name] = _recurse_length(command, crumbs, (name,))
|
||||
if name == 'configure' or cmd_len > 4000:
|
||||
print(f"'{name}' over 4000. Breadcrumb Trail follows:")
|
||||
lines = []
|
||||
for loc, val in crumbs.items():
|
||||
locstr = '.'.join(loc)
|
||||
lines.append(f"{locstr}: {val}")
|
||||
print('\n'.join(lines))
|
||||
print(json.dumps(command, indent=2))
|
||||
return lens
|
||||
|
||||
def _recurse_length(payload, breadcrumbs={}, header=()) -> int:
|
||||
total = 0
|
||||
total_header = (*header, '')
|
||||
breadcrumbs[total_header] = 0
|
||||
|
||||
if isinstance(payload, dict):
|
||||
# Read strings that count towards command length
|
||||
# String length is length of longest localisation, including default.
|
||||
for key in ('name', 'description', 'value'):
|
||||
if key in payload:
|
||||
value = payload[key]
|
||||
if isinstance(value, str):
|
||||
values = (value, *payload.get(key + '_localizations', {}).values())
|
||||
maxlen = max(map(len, values))
|
||||
total += maxlen
|
||||
breadcrumbs[(*header, key)] = maxlen
|
||||
|
||||
for key, value in payload.items():
|
||||
loc = (*header, key)
|
||||
total += _recurse_length(value, breadcrumbs, loc)
|
||||
elif isinstance(payload, list):
|
||||
for i, item in enumerate(payload):
|
||||
if isinstance(item, dict) and 'name' in item:
|
||||
loc = (*header, f"{i}<{item['name']}>")
|
||||
else:
|
||||
loc = (*header, str(i))
|
||||
total += _recurse_length(item, breadcrumbs, loc)
|
||||
|
||||
if total:
|
||||
breadcrumbs[total_header] = total
|
||||
else:
|
||||
breadcrumbs.pop(total_header)
|
||||
|
||||
return total
|
||||
|
||||
def write_records(records: list[dict[str, Any]], stream: StringIO):
|
||||
if records:
|
||||
keys = records[0].keys()
|
||||
stream.write(','.join(keys))
|
||||
stream.write('\n')
|
||||
for record in records:
|
||||
stream.write(','.join(map(str, record.values())))
|
||||
stream.write('\n')
|
||||
191
src/utils/monitor.py
Normal file
191
src/utils/monitor.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import logging
|
||||
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
|
||||
|
||||
from .lib import utc_now
|
||||
from .ratelimits import Bucket
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Taskid = TypeVar('Taskid')
|
||||
|
||||
|
||||
class TaskMonitor(Generic[Taskid]):
|
||||
"""
|
||||
Base class for a task monitor.
|
||||
|
||||
Stores tasks as a time-sorted list of taskids.
|
||||
Subclasses may override `run_task` to implement an executor.
|
||||
|
||||
Adding or removing a single task has O(n) performance.
|
||||
To bulk update tasks, instead use `schedule_tasks`.
|
||||
|
||||
Each taskid must be unique and hashable.
|
||||
"""
|
||||
|
||||
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
|
||||
# Ratelimit bucket to enforce maximum execution rate
|
||||
self._bucket = bucket
|
||||
|
||||
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
|
||||
|
||||
self._wakeup: asyncio.Event = asyncio.Event()
|
||||
self._monitor_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Task data
|
||||
self._tasklist: list[Taskid] = []
|
||||
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
|
||||
|
||||
# Running map ensures we keep a reference to the running task
|
||||
# And allows simpler external cancellation if required
|
||||
self._running: dict[Taskid, asyncio.Future] = {}
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"<"
|
||||
f"{self.__class__.__name__}"
|
||||
f" tasklist={len(self._tasklist)}"
|
||||
f" taskmap={len(self._taskmap)}"
|
||||
f" wakeup={self._wakeup.is_set()}"
|
||||
f" bucket={self._bucket}"
|
||||
f" running={len(self._running)}"
|
||||
f" task={self._monitor_task}"
|
||||
f">"
|
||||
)
|
||||
|
||||
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
|
||||
"""
|
||||
Similar to `schedule_tasks`, but wipe and reset the tasklist.
|
||||
"""
|
||||
self._taskmap = {tid: time for tid, time in tasks}
|
||||
self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid]))
|
||||
self._wakeup.set()
|
||||
|
||||
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
|
||||
"""
|
||||
Schedule the given tasks.
|
||||
|
||||
Rather than repeatedly inserting tasks,
|
||||
where the O(log n) insort is dominated by the O(n) list insertion,
|
||||
we build an entirely new list, and always wake up the loop.
|
||||
"""
|
||||
self._taskmap |= {tid: time for tid, time in tasks}
|
||||
self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid]))
|
||||
self._wakeup.set()
|
||||
|
||||
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
|
||||
"""
|
||||
Insert the provided task into the tasklist.
|
||||
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
|
||||
"""
|
||||
if self._tasklist:
|
||||
nextid = self._tasklist[-1]
|
||||
wake = self._taskmap[nextid] >= timestamp
|
||||
wake = wake or taskid == nextid
|
||||
else:
|
||||
wake = True
|
||||
if taskid in self._taskmap:
|
||||
self._tasklist.remove(taskid)
|
||||
self._taskmap[taskid] = timestamp
|
||||
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
|
||||
if wake:
|
||||
self._wakeup.set()
|
||||
|
||||
def cancel_tasks(self, *taskids: Taskid) -> None:
|
||||
"""
|
||||
Remove all tasks with the given taskids from the tasklist.
|
||||
If the next task has this taskid, wake up the monitor loop.
|
||||
"""
|
||||
taskids = set(taskids)
|
||||
wake = (self._tasklist and self._tasklist[-1] in taskids)
|
||||
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
|
||||
for tid in taskids:
|
||||
self._taskmap.pop(tid, None)
|
||||
if wake:
|
||||
self._wakeup.set()
|
||||
|
||||
def start(self):
|
||||
if self._monitor_task and not self._monitor_task.done():
|
||||
self._monitor_task.cancel()
|
||||
# Start the monitor
|
||||
self._monitor_task = asyncio.create_task(self.monitor())
|
||||
return self._monitor_task
|
||||
|
||||
async def monitor(self):
|
||||
"""
|
||||
Start the monitor.
|
||||
Executes the tasks in `self.tasks` at the specified time.
|
||||
|
||||
This will shield task execution from cancellation
|
||||
to avoid partial states.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
self._wakeup.clear()
|
||||
if not self._tasklist:
|
||||
# No tasks left, just sleep until wakeup
|
||||
await self._wakeup.wait()
|
||||
else:
|
||||
# Get the next task, sleep until wakeup or it is ready to run
|
||||
nextid = self._tasklist[-1]
|
||||
nexttime = self._taskmap[nextid]
|
||||
sleep_for = nexttime - utc_now().timestamp()
|
||||
try:
|
||||
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
|
||||
except asyncio.TimeoutError:
|
||||
# Ready to run the task
|
||||
self._tasklist.pop()
|
||||
self._taskmap.pop(nextid, None)
|
||||
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
|
||||
else:
|
||||
# Wakeup task fired, loop again
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
# Log closure and wait for remaining tasks
|
||||
# A second cancellation will also cancel the tasks
|
||||
logger.debug(
|
||||
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
|
||||
f"Waiting for {len(self._running)} running tasks to complete."
|
||||
)
|
||||
await asyncio.gather(*self._running.values(), return_exceptions=True)
|
||||
|
||||
async def _run(self, taskid: Taskid) -> None:
|
||||
# Execute the task, respecting the ratelimit bucket
|
||||
if self._bucket is not None:
|
||||
# IMPLEMENTATION NOTE:
|
||||
# Bucket.wait() should guarantee not more than n tasks/second are run
|
||||
# and that a request directly afterwards will _not_ raise BucketFull
|
||||
# Make sure that only one waiter is actually waiting on its sleep task
|
||||
# The other waiters should be waiting on a lock around the sleep task
|
||||
# Waiters are executed in wait-order, so if we only let a single waiter in
|
||||
# we shouldn't get collisions.
|
||||
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
|
||||
# Or we will lose thread-safety for BucketFull
|
||||
await self._bucket.wait()
|
||||
fut = asyncio.create_task(self.run_task(taskid))
|
||||
try:
|
||||
await asyncio.shield(fut)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# Protect the monitor loop from any other exceptions
|
||||
logger.exception(
|
||||
f"Ignoring exception in task monitor {self.__class__.__name__} while "
|
||||
f"executing <taskid: {taskid}>"
|
||||
)
|
||||
finally:
|
||||
self._running.pop(taskid)
|
||||
|
||||
async def run_task(self, taskid: Taskid):
|
||||
"""
|
||||
Execute the task with the given taskid.
|
||||
|
||||
Default implementation executes `self.executor` if it exists,
|
||||
otherwise raises NotImplementedError.
|
||||
"""
|
||||
if self.executor is not None:
|
||||
await self.executor(taskid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
173
src/utils/ratelimits.py
Normal file
173
src/utils/ratelimits.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
|
||||
from meta.errors import SafeCancellation
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
|
||||
class BucketFull(Exception):
|
||||
"""
|
||||
Throw when a requested Bucket is already full
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BucketOverFull(BucketFull):
|
||||
"""
|
||||
Throw when a requested Bucket is overfull
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Bucket:
|
||||
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
|
||||
|
||||
def __init__(self, max_level, empty_time):
|
||||
self.max_level = max_level
|
||||
self.empty_time = empty_time
|
||||
self.leak_rate = max_level / empty_time
|
||||
|
||||
self._level = 0
|
||||
self._last_checked = time.monotonic()
|
||||
|
||||
self._last_full = False
|
||||
self._wait_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def full(self) -> bool:
|
||||
"""
|
||||
Return whether the bucket is 'full',
|
||||
that is, whether an immediate request against the bucket will raise `BucketFull`.
|
||||
"""
|
||||
self._leak()
|
||||
return self._level + 1 > self.max_level
|
||||
|
||||
@property
|
||||
def overfull(self):
|
||||
self._leak()
|
||||
return self._level > self.max_level
|
||||
|
||||
@property
|
||||
def delay(self):
|
||||
self._leak()
|
||||
if self._level + 1 > self.max_level:
|
||||
delay = (self._level + 1 - self.max_level) * self.leak_rate
|
||||
else:
|
||||
delay = 0
|
||||
return delay
|
||||
|
||||
def _leak(self):
|
||||
if self._level:
|
||||
elapsed = time.monotonic() - self._last_checked
|
||||
self._level = max(0, self._level - (elapsed * self.leak_rate))
|
||||
|
||||
self._last_checked = time.monotonic()
|
||||
|
||||
def request(self):
|
||||
self._leak()
|
||||
if self._level > self.max_level:
|
||||
raise BucketOverFull
|
||||
elif self._level == self.max_level:
|
||||
self._level += 1
|
||||
if self._last_full:
|
||||
raise BucketOverFull
|
||||
else:
|
||||
self._last_full = True
|
||||
raise BucketFull
|
||||
else:
|
||||
self._last_full = False
|
||||
self._level += 1
|
||||
|
||||
def fill(self):
|
||||
self._leak()
|
||||
self._level = max(self._level, self.max_level + 1)
|
||||
|
||||
async def wait(self):
|
||||
"""
|
||||
Wait until the bucket has room.
|
||||
|
||||
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
|
||||
"""
|
||||
# Wrapped in a lock so that waiters are correctly handled in wait-order
|
||||
# Otherwise multiple waiters will have the same delay,
|
||||
# and race for the wakeup after sleep.
|
||||
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
|
||||
async with self._wait_lock:
|
||||
# We do this in a loop in case asyncio.sleep throws us out early,
|
||||
# or a synchronous request overflows the bucket while we are waiting.
|
||||
while self.full:
|
||||
await asyncio.sleep(self.delay)
|
||||
|
||||
async def wrapped(self, coro):
|
||||
await self.wait()
|
||||
self.request()
|
||||
await coro
|
||||
|
||||
|
||||
class RateLimit:
|
||||
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
|
||||
self.max_level = max_level
|
||||
self.empty_time = empty_time
|
||||
|
||||
self.error = error or "Too many requests, please slow down!"
|
||||
self.buckets = cache
|
||||
|
||||
def request_for(self, key):
|
||||
if not (bucket := self.buckets.get(key, None)):
|
||||
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
|
||||
|
||||
try:
|
||||
bucket.request()
|
||||
except BucketOverFull:
|
||||
raise SafeCancellation(details="Bucket overflow")
|
||||
except BucketFull:
|
||||
raise SafeCancellation(self.error, details="Bucket full")
|
||||
|
||||
def ward(self, member=True, key=None):
|
||||
"""
|
||||
Command ratelimit decorator.
|
||||
"""
|
||||
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
|
||||
|
||||
def decorator(func):
|
||||
async def wrapper(ctx, *args, **kwargs):
|
||||
self.request_for(key(ctx))
|
||||
return await func(ctx, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
async def limit_concurrency(aws, limit):
|
||||
"""
|
||||
Run provided awaitables concurrently,
|
||||
ensuring that no more than `limit` are running at once.
|
||||
"""
|
||||
aws = iter(aws)
|
||||
aws_ended = False
|
||||
pending = set()
|
||||
count = 0
|
||||
logger.debug("Starting limited concurrency executor")
|
||||
|
||||
while pending or not aws_ended:
|
||||
while len(pending) < limit and not aws_ended:
|
||||
aw = next(aws, None)
|
||||
if aw is None:
|
||||
aws_ended = True
|
||||
else:
|
||||
pending.add(asyncio.create_task(aw))
|
||||
count += 1
|
||||
|
||||
if not pending:
|
||||
break
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
while done:
|
||||
yield done.pop()
|
||||
logger.debug(f"Completed {count} tasks")
|
||||
8
src/utils/ui/__init__.py
Normal file
8
src/utils/ui/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .hooked import *
|
||||
from .leo import *
|
||||
from .micros import *
|
||||
59
src/utils/ui/hooked.py
Normal file
59
src/utils/ui/hooked.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import time
|
||||
|
||||
import discord
|
||||
from discord.ui.item import Item
|
||||
from discord.ui.button import Button
|
||||
|
||||
from .leo import LeoUI
|
||||
|
||||
__all__ = (
|
||||
'HookedItem',
|
||||
'AButton',
|
||||
'AsComponents'
|
||||
)
|
||||
|
||||
|
||||
class HookedItem:
|
||||
"""
|
||||
Mixin for Item classes allowing an instance to be used as a callback decorator.
|
||||
"""
|
||||
def __init__(self, *args, pass_kwargs={}, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.pass_kwargs = pass_kwargs
|
||||
|
||||
def __call__(self, coro):
|
||||
async def wrapped(interaction, **kwargs):
|
||||
return await coro(interaction, self, **(self.pass_kwargs | kwargs))
|
||||
self.callback = wrapped
|
||||
return self
|
||||
|
||||
|
||||
class AButton(HookedItem, Button):
|
||||
...
|
||||
|
||||
|
||||
class AsComponents(LeoUI):
|
||||
"""
|
||||
Simple container class to accept a number of Items and turn them into an attachable View.
|
||||
"""
|
||||
def __init__(self, *items, pass_kwargs={}, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pass_kwargs = pass_kwargs
|
||||
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
|
||||
async def _scheduled_task(self, item: Item, interaction: discord.Interaction):
|
||||
try:
|
||||
item._refresh_state(interaction, interaction.data) # type: ignore
|
||||
|
||||
allow = await self.interaction_check(interaction)
|
||||
if not allow:
|
||||
return
|
||||
|
||||
if self.timeout:
|
||||
self.__timeout_expiry = time.monotonic() + self.timeout
|
||||
|
||||
await item.callback(interaction, **self.pass_kwargs)
|
||||
except Exception as e:
|
||||
return await self.on_error(interaction, e, item)
|
||||
485
src/utils/ui/leo.py
Normal file
485
src/utils/ui/leo.py
Normal file
@@ -0,0 +1,485 @@
|
||||
from typing import List, Optional, Any, Dict
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from contextvars import copy_context, Context
|
||||
|
||||
import discord
|
||||
from discord.ui import Modal, View, Item
|
||||
|
||||
from meta.logger import log_action_stack, logging_context
|
||||
from meta.errors import SafeCancellation
|
||||
|
||||
from . import logger
|
||||
from ..lib import MessageArgs, error_embed
|
||||
|
||||
__all__ = (
|
||||
'LeoUI',
|
||||
'MessageUI',
|
||||
'LeoModal',
|
||||
'error_handler_for'
|
||||
)
|
||||
|
||||
|
||||
class LeoUI(View):
|
||||
"""
|
||||
View subclass for small-scale user interfaces.
|
||||
|
||||
While a 'View' provides an interface for managing a collection of components,
|
||||
a `LeoUI` may also manage a message, and potentially slave Views or UIs.
|
||||
The `LeoUI` also exposes more advanced cleanup and timeout methods,
|
||||
and preserves the context.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if context is None:
|
||||
self._context = copy_context()
|
||||
else:
|
||||
self._context = context
|
||||
|
||||
self._name = ui_name or self.__class__.__name__
|
||||
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name])
|
||||
|
||||
# List of slaved views to stop when this view stops
|
||||
self._slaves: List[View] = []
|
||||
|
||||
# TODO: Replace this with a substitutable ViewLayout class
|
||||
self._layout: Optional[tuple[tuple[Item, ...], ...]] = None
|
||||
|
||||
@property
|
||||
def _stopped(self) -> asyncio.Future:
|
||||
"""
|
||||
Return an future indicating whether the View has finished interacting.
|
||||
|
||||
Currently exposes a hidden attribute of the underlying View.
|
||||
May be reimplemented in future.
|
||||
"""
|
||||
return self._View__stopped
|
||||
|
||||
def to_components(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extending component generator to apply the set _layout, if it exists.
|
||||
"""
|
||||
if self._layout is not None:
|
||||
# Alternative rendering using layout
|
||||
components = []
|
||||
for i, row in enumerate(self._layout):
|
||||
# Skip empty rows
|
||||
if not row:
|
||||
continue
|
||||
|
||||
# Since we aren't relying on ViewWeights, manually check width here
|
||||
if sum(item.width for item in row) > 5:
|
||||
raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!")
|
||||
|
||||
# Create the component dict for this row
|
||||
components.append({
|
||||
'type': 1,
|
||||
'components': [item.to_component_dict() for item in row]
|
||||
})
|
||||
else:
|
||||
components = super().to_components()
|
||||
|
||||
return components
|
||||
|
||||
def set_layout(self, *rows: tuple[Item, ...]) -> None:
|
||||
"""
|
||||
Set the layout of the rendered View as a matrix of items,
|
||||
or more precisely, a list of action rows.
|
||||
|
||||
This acts independently of the existing sorting with `_ViewWeights`,
|
||||
and overrides the sorting if applied.
|
||||
"""
|
||||
self._layout = rows
|
||||
|
||||
async def cleanup(self):
|
||||
"""
|
||||
Coroutine to run when timeing out, stopping, or cancelling.
|
||||
Generally cleans up any open resources, and removes any leftover components.
|
||||
"""
|
||||
logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'})
|
||||
return None
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Extends View.stop() to also stop all the slave views.
|
||||
Note that stopping is idempotent, so it is okay if close() also calls stop().
|
||||
"""
|
||||
for slave in self._slaves:
|
||||
slave.stop()
|
||||
super().stop()
|
||||
|
||||
async def close(self, msg=None):
|
||||
self.stop()
|
||||
await self.cleanup()
|
||||
|
||||
async def pre_timeout(self):
|
||||
"""
|
||||
Task to execute before actually timing out.
|
||||
This may cancel the timeout by refreshing or rescheduling it.
|
||||
(E.g. to ask the user whether they want to keep going.)
|
||||
|
||||
Default implementation does nothing.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def on_timeout(self):
|
||||
"""
|
||||
Task to execute after timeout is complete.
|
||||
Default implementation calls cleanup.
|
||||
"""
|
||||
await self.cleanup()
|
||||
|
||||
async def __dispatch_timeout(self):
|
||||
"""
|
||||
This essentially extends View._dispatch_timeout,
|
||||
to include a pre_timeout task
|
||||
which may optionally refresh and hence cancel the timeout.
|
||||
"""
|
||||
if self._View__stopped.done():
|
||||
# We are already stopped, nothing to do
|
||||
return
|
||||
|
||||
with logging_context(action='Timeout'):
|
||||
try:
|
||||
await self.pre_timeout()
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Unhandled error caught while dispatching timeout for {self!r}.",
|
||||
extra={'with_ctx': True, 'action': 'Error'}
|
||||
)
|
||||
|
||||
# Check if we still need to timeout
|
||||
if self.timeout is None:
|
||||
# The timeout was removed entirely, silently walk away
|
||||
return
|
||||
|
||||
if self._View__stopped.done():
|
||||
# We stopped while waiting for the pre timeout.
|
||||
# Or maybe another thread timed us out
|
||||
# Either way, we are done here
|
||||
return
|
||||
|
||||
now = time.monotonic()
|
||||
if self._View__timeout_expiry is not None and now < self._View__timeout_expiry:
|
||||
# The timeout was extended, make sure the timeout task is running then fade away
|
||||
if self._View__timeout_task is None or self._View__timeout_task.done():
|
||||
self._View__timeout_task = asyncio.create_task(self._View__timeout_task_impl())
|
||||
else:
|
||||
# Actually timeout, and call the post-timeout task for cleanup.
|
||||
self._really_timeout()
|
||||
await self.on_timeout()
|
||||
|
||||
def _dispatch_timeout(self):
|
||||
"""
|
||||
Overriding timeout method completely, to support interactive flow during timeout,
|
||||
and optional refreshing of the timeout.
|
||||
"""
|
||||
return self._context.run(asyncio.create_task, self.__dispatch_timeout())
|
||||
|
||||
def _really_timeout(self):
|
||||
"""
|
||||
Actuallly times out the View.
|
||||
This copies View._dispatch_timeout, apart from the `on_timeout` dispatch,
|
||||
which is now handled by `__dispatch_timeout`.
|
||||
"""
|
||||
if self._View__stopped.done():
|
||||
return
|
||||
|
||||
if self._View__cancel_callback:
|
||||
self._View__cancel_callback(self)
|
||||
self._View__cancel_callback = None
|
||||
|
||||
self._View__stopped.set_result(True)
|
||||
|
||||
def _dispatch_item(self, *args, **kwargs):
|
||||
"""Extending event dispatch to run in the instantiation context."""
|
||||
return self._context.run(super()._dispatch_item, *args, **kwargs)
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item):
|
||||
"""
|
||||
Default LeoUI error handle.
|
||||
This may be tail extended by subclasses to preserve the exception stack.
|
||||
"""
|
||||
try:
|
||||
raise error
|
||||
except SafeCancellation as e:
|
||||
if e.msg and not interaction.is_expired():
|
||||
try:
|
||||
if interaction.response.is_done():
|
||||
await interaction.followup.send(
|
||||
embed=error_embed(e.msg),
|
||||
ephemeral=True
|
||||
)
|
||||
else:
|
||||
await interaction.response.send_message(
|
||||
embed=error_embed(e.msg),
|
||||
ephemeral=True
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
logger.debug(
|
||||
f"Caught a safe cancellation from LeoUI: {e.details}",
|
||||
extra={'action': 'Cancel'}
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r} from interaction: "
|
||||
f"{interaction.data}",
|
||||
extra={'with_ctx': True, 'action': 'UIError'}
|
||||
)
|
||||
# Explicitly handle the bugsplat ourselves
|
||||
splat = interaction.client.tree.bugsplat(interaction, error)
|
||||
await interaction.client.tree.error_reply(interaction, splat)
|
||||
|
||||
|
||||
class MessageUI(LeoUI):
|
||||
"""
|
||||
Simple single-message LeoUI, intended as a framework for UIs
|
||||
attached to a single interaction response.
|
||||
|
||||
UIs may also be sent as regular messages by using `send(channel)` instead of `run(interaction)`.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, callerid: Optional[int] = None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# ----- UI state -----
|
||||
# User ID of the original caller (e.g. command author).
|
||||
# Mainly used for interaction usage checks and logging
|
||||
self._callerid = callerid
|
||||
|
||||
# Original interaction, if this UI is sent as an interaction response
|
||||
self._original: discord.Interaction = None
|
||||
|
||||
# Message holding the UI, when the UI is sent attached to a followup
|
||||
self._message: discord.Message = None
|
||||
|
||||
# Refresh lock, to avoid cache collisions on refresh
|
||||
self._refresh_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def channel(self):
|
||||
if self._original is not None:
|
||||
return self._original.channel
|
||||
else:
|
||||
return self._message.channel
|
||||
|
||||
# ----- UI API -----
|
||||
async def run(self, interaction: discord.Interaction, **kwargs):
|
||||
"""
|
||||
Run the UI as a response or followup to the given interaction.
|
||||
|
||||
Should be extended if more complex run mechanics are needed
|
||||
(e.g. registering listeners or setting up caches).
|
||||
"""
|
||||
await self.draw(interaction, **kwargs)
|
||||
|
||||
async def refresh(self, *args, thinking: Optional[discord.Interaction] = None, **kwargs):
|
||||
"""
|
||||
Reload and redraw this UI.
|
||||
|
||||
Primarily a hook-method for use by parents and other controllers.
|
||||
Performs a full data and reload and refresh (maintaining UI state, e.g. page n).
|
||||
"""
|
||||
async with self._refresh_lock:
|
||||
# Reload data
|
||||
await self.reload()
|
||||
# Redraw UI message
|
||||
await self.redraw(thinking=thinking)
|
||||
|
||||
async def quit(self):
|
||||
"""
|
||||
Quit the UI.
|
||||
|
||||
This usually involves removing the original message,
|
||||
and stopping or closing the underlying View.
|
||||
"""
|
||||
for child in self._slaves:
|
||||
# TODO: Better to use duck typing or interface typing
|
||||
if isinstance(child, MessageUI) and not child.is_finished():
|
||||
asyncio.create_task(child.quit())
|
||||
try:
|
||||
if self._original is not None and not self._original.is_expired():
|
||||
await self._original.delete_original_response()
|
||||
self._original = None
|
||||
if self._message is not None:
|
||||
await self._message.delete()
|
||||
self._message = None
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Note close() also runs cleanup and stop
|
||||
await self.close()
|
||||
|
||||
# ----- UI Flow -----
|
||||
async def interaction_check(self, interaction: discord.Interaction):
|
||||
"""
|
||||
Check the given interaction is authorised to use this UI.
|
||||
|
||||
Default implementation simply checks that the interaction is
|
||||
from the original caller.
|
||||
Extend for more complex logic.
|
||||
"""
|
||||
return interaction.user.id == self._callerid
|
||||
|
||||
async def make_message(self) -> MessageArgs:
|
||||
"""
|
||||
Create the UI message body, depening on the current state.
|
||||
|
||||
Called upon each redraw.
|
||||
Should handle caching if message construction is for some reason intensive.
|
||||
|
||||
Must be implemented by concrete UI subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def refresh_layout(self):
|
||||
"""
|
||||
Asynchronously refresh the message components,
|
||||
and explicitly set the message component layout.
|
||||
|
||||
Called just before redrawing, before `make_message`.
|
||||
|
||||
Must be implemented by concrete UI subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reload(self):
|
||||
"""
|
||||
Reload and recompute the underlying data for this UI.
|
||||
|
||||
Must be implemented by concrete UI subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def draw(self, interaction, force_followup=False, **kwargs):
|
||||
"""
|
||||
Send the UI as a response or followup to the given interaction.
|
||||
|
||||
If the interaction has been responded to, or `force_followup` is set,
|
||||
creates a followup message instead of a response to the interaction.
|
||||
"""
|
||||
# Initial data loading
|
||||
await self.reload()
|
||||
# Set the UI layout
|
||||
await self.refresh_layout()
|
||||
# Fetch message arguments
|
||||
args = await self.make_message()
|
||||
|
||||
as_followup = force_followup or interaction.response.is_done()
|
||||
if as_followup:
|
||||
self._message = await interaction.followup.send(**args.send_args, **kwargs, view=self)
|
||||
else:
|
||||
self._original = interaction
|
||||
await interaction.response.send_message(**args.send_args, **kwargs, view=self)
|
||||
|
||||
async def send(self, channel: discord.abc.Messageable, **kwargs):
|
||||
"""
|
||||
Alternative to draw() which uses a discord.abc.Messageable.
|
||||
"""
|
||||
await self.reload()
|
||||
await self.refresh_layout()
|
||||
args = await self.make_message()
|
||||
self._message = await channel.send(**args.send_args, view=self)
|
||||
|
||||
async def _redraw(self, args):
|
||||
if self._original and not self._original.is_expired():
|
||||
await self._original.edit_original_response(**args.edit_args, view=self)
|
||||
elif self._message:
|
||||
await self._message.edit(**args.edit_args, view=self)
|
||||
else:
|
||||
# Interaction expired or already closed. Quietly cleanup.
|
||||
await self.close()
|
||||
|
||||
async def redraw(self, thinking: Optional[discord.Interaction] = None):
|
||||
"""
|
||||
Update the output message for this UI.
|
||||
|
||||
If a thinking interaction is provided, deletes the response while redrawing.
|
||||
"""
|
||||
await self.refresh_layout()
|
||||
args = await self.make_message()
|
||||
|
||||
if thinking is not None and not thinking.is_expired() and thinking.response.is_done():
|
||||
asyncio.create_task(thinking.delete_original_response())
|
||||
|
||||
try:
|
||||
await self._redraw(args)
|
||||
except discord.HTTPException as e:
|
||||
# Unknown communication error, nothing we can reliably do. Exit quietly.
|
||||
logger.warning(
|
||||
f"Unexpected UI redraw failure occurred in {self}: {repr(e)}",
|
||||
)
|
||||
await self.close()
|
||||
|
||||
async def cleanup(self):
|
||||
"""
|
||||
Remove message components from interaction response, if possible.
|
||||
|
||||
Extend to remove listeners or clean up caches.
|
||||
`cleanup` is always called when the UI is exiting,
|
||||
through timeout or user-driven closure.
|
||||
"""
|
||||
try:
|
||||
if self._original is not None and not self._original.is_expired():
|
||||
await self._original.edit_original_response(view=None)
|
||||
self._original = None
|
||||
if self._message is not None:
|
||||
await self._message.edit(view=None)
|
||||
self._message = None
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
|
||||
class LeoModal(Modal):
|
||||
"""
|
||||
Context-aware Modal class.
|
||||
"""
|
||||
def __init__(self, *args, context: Optional[Context] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if context is None:
|
||||
self._context = copy_context()
|
||||
else:
|
||||
self._context = context
|
||||
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__])
|
||||
|
||||
def _dispatch_submit(self, *args, **kwargs):
|
||||
"""
|
||||
Extending event dispatch to run in the instantiation context.
|
||||
"""
|
||||
return self._context.run(super()._dispatch_submit, *args, **kwargs)
|
||||
|
||||
def _dispatch_item(self, *args, **kwargs):
|
||||
"""Extending event dispatch to run in the instantiation context."""
|
||||
return self._context.run(super()._dispatch_item, *args, **kwargs)
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
|
||||
"""
|
||||
Default LeoModal error handle.
|
||||
This may be tail extended by subclasses to preserve the exception stack.
|
||||
"""
|
||||
try:
|
||||
raise error
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unhandled interaction exception occurred in {self!r}. Interaction: {interaction.data}",
|
||||
extra={'with_ctx': True, 'action': 'ModalError'}
|
||||
)
|
||||
# Explicitly handle the bugsplat ourselves
|
||||
splat = interaction.client.tree.bugsplat(interaction, error)
|
||||
await interaction.client.tree.error_reply(interaction, splat)
|
||||
|
||||
|
||||
def error_handler_for(exc):
|
||||
def wrapper(coro):
|
||||
coro._ui_error_handler_for_ = exc
|
||||
return coro
|
||||
return wrapper
|
||||
329
src/utils/ui/micros.py
Normal file
329
src/utils/ui/micros.py
Normal file
@@ -0,0 +1,329 @@
|
||||
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
|
||||
import functools
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord.ui import TextInput
|
||||
from discord.ui.button import button
|
||||
|
||||
from meta.logger import logging_context
|
||||
from meta.errors import ResponseTimedOut
|
||||
|
||||
from .leo import LeoModal, LeoUI
|
||||
|
||||
__all__ = (
|
||||
'FastModal',
|
||||
'ModalRetryUI',
|
||||
'Confirm',
|
||||
'input',
|
||||
)
|
||||
|
||||
|
||||
class FastModal(LeoModal):
|
||||
__class_error_handlers__ = []
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
error_handlers = {}
|
||||
for base in reversed(cls.__mro__):
|
||||
for name, member in base.__dict__.items():
|
||||
if hasattr(member, '_ui_error_handler_for_'):
|
||||
error_handlers[name] = member
|
||||
|
||||
cls.__class_error_handlers__ = list(error_handlers.values())
|
||||
|
||||
def __init__error_handlers__(self):
|
||||
handlers = {}
|
||||
for handler in self.__class_error_handlers__:
|
||||
handlers[handler._ui_error_handler_for_] = functools.partial(handler, self)
|
||||
return handlers
|
||||
|
||||
def __init__(self, *items: TextInput, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
|
||||
self._waiters: List[Callable[[discord.Interaction], Coroutine]] = []
|
||||
self._error_handlers = self.__init__error_handlers__()
|
||||
|
||||
def error_handler(self, exception):
|
||||
def wrapper(coro):
|
||||
self._error_handlers[exception] = coro
|
||||
return coro
|
||||
return wrapper
|
||||
|
||||
async def wait_for(self, check=None, timeout=None):
|
||||
# Wait for _result or timeout
|
||||
# If we timeout, or the view times out, raise TimeoutError
|
||||
# Otherwise, return the Interaction
|
||||
# This allows multiple listeners and callbacks to wait on
|
||||
while True:
|
||||
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
|
||||
if check is not None:
|
||||
if not check(result):
|
||||
continue
|
||||
return result
|
||||
|
||||
async def on_timeout(self):
|
||||
self._result.set_exception(asyncio.TimeoutError)
|
||||
|
||||
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
|
||||
def wrapper(coro):
|
||||
async def wrapped_callback(interaction):
|
||||
with logging_context(action=coro.__name__):
|
||||
if check is not None:
|
||||
if not check(interaction):
|
||||
return
|
||||
try:
|
||||
await coro(interaction, *pass_args, **pass_kwargs)
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
if once:
|
||||
self._waiters.remove(wrapped_callback)
|
||||
self._waiters.append(wrapped_callback)
|
||||
return wrapper
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
|
||||
try:
|
||||
# First let our error handlers have a go
|
||||
# If there is no handler for this error, or the handlers themselves error,
|
||||
# drop to the superclass error handler implementation.
|
||||
try:
|
||||
raise error
|
||||
except tuple(self._error_handlers.keys()) as e:
|
||||
# If an error handler is registered for this exception, run it.
|
||||
for cls, handler in self._error_handlers.items():
|
||||
if isinstance(e, cls):
|
||||
await handler(interaction, e)
|
||||
except Exception as error:
|
||||
await super().on_error(interaction, error)
|
||||
|
||||
async def on_submit(self, interaction):
|
||||
print("On submit")
|
||||
old_result = self._result
|
||||
self._result = asyncio.get_event_loop().create_future()
|
||||
old_result.set_result(interaction)
|
||||
|
||||
tasks = []
|
||||
for waiter in self._waiters:
|
||||
task = asyncio.create_task(
|
||||
waiter(interaction),
|
||||
name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}"
|
||||
)
|
||||
tasks.append(task)
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
async def input(
|
||||
interaction: discord.Interaction,
|
||||
title: str,
|
||||
question: Optional[str] = None,
|
||||
field: Optional[TextInput] = None,
|
||||
timeout=180,
|
||||
**kwargs,
|
||||
) -> tuple[discord.Interaction, str]:
|
||||
"""
|
||||
Spawn a modal to accept input.
|
||||
Returns an (interaction, value) pair, with interaction not yet responded to.
|
||||
May raise asyncio.TimeoutError if the view times out.
|
||||
"""
|
||||
if field is None:
|
||||
field = TextInput(
|
||||
label=kwargs.get('label', question),
|
||||
**kwargs
|
||||
)
|
||||
modal = FastModal(
|
||||
field,
|
||||
title=title,
|
||||
timeout=timeout
|
||||
)
|
||||
await interaction.response.send_modal(modal)
|
||||
interaction = await modal.wait_for()
|
||||
return (interaction, field.value)
|
||||
|
||||
|
||||
class ModalRetryUI(LeoUI):
|
||||
def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.modal = modal
|
||||
self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)}
|
||||
|
||||
self.message = message
|
||||
|
||||
self._interaction = None
|
||||
|
||||
if label is not None:
|
||||
self.retry_button.label = label
|
||||
|
||||
@property
|
||||
def embed(self):
|
||||
return discord.Embed(
|
||||
title="Uh-Oh!",
|
||||
description=self.message,
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
|
||||
async def respond_to(self, interaction):
|
||||
self._interaction = interaction
|
||||
if interaction.response.is_done():
|
||||
await interaction.followup.send(embed=self.embed, ephemeral=True, view=self)
|
||||
else:
|
||||
await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self)
|
||||
|
||||
@button(label="Retry")
|
||||
async def retry_button(self, interaction, butt):
|
||||
# Setting these here so they don't update in the meantime
|
||||
for item, value in self.item_values.items():
|
||||
item.default = value
|
||||
if self._interaction is not None:
|
||||
await self._interaction.delete_original_response()
|
||||
self._interaction = None
|
||||
await interaction.response.send_modal(self.modal)
|
||||
await self.close()
|
||||
|
||||
|
||||
class Confirm(LeoUI):
|
||||
"""
|
||||
Micro UI class implementing a confirmation question.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
confirm_msg: str
|
||||
The confirmation question to ask from the user.
|
||||
This is set as the description of the `embed` property.
|
||||
The `embed` may be further modified if required.
|
||||
permitted_id: Optional[int]
|
||||
The user id allowed to access this interaction.
|
||||
Other users will recieve an access denied error message.
|
||||
defer: bool
|
||||
Whether to defer the interaction response while handling the button.
|
||||
It may be useful to set this to `False` to obtain manual control
|
||||
over the interaction response flow (e.g. to send a modal or ephemeral message).
|
||||
The button press interaction may be accessed through `Confirm.interaction`.
|
||||
Default: True
|
||||
|
||||
Example
|
||||
-------
|
||||
```
|
||||
confirm = Confirm("Are you sure?", ctx.author.id)
|
||||
confirm.embed.colour = discord.Colour.red()
|
||||
confirm.confirm_button.label = "Yes I am sure"
|
||||
confirm.cancel_button.label = "No I am not sure"
|
||||
|
||||
try:
|
||||
result = await confirm.ask(ctx.interaction, ephemeral=True)
|
||||
except ResultTimedOut:
|
||||
return
|
||||
```
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
confirm_msg: str,
|
||||
permitted_id: Optional[int] = None,
|
||||
defer: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.confirm_msg = confirm_msg
|
||||
self.permitted_id = permitted_id
|
||||
self.defer = defer
|
||||
|
||||
self._embed: Optional[discord.Embed] = None
|
||||
self._result: asyncio.Future[bool] = asyncio.Future()
|
||||
|
||||
# Indicates whether we should delete the message or the interaction response
|
||||
self._is_followup: bool = False
|
||||
self._original: Optional[discord.Interaction] = None
|
||||
self._message: Optional[discord.Message] = None
|
||||
|
||||
async def interaction_check(self, interaction: discord.Interaction):
|
||||
return (self.permitted_id is None) or interaction.user.id == self.permitted_id
|
||||
|
||||
async def on_timeout(self):
|
||||
# Propagate timeout to result Future
|
||||
self._result.set_exception(ResponseTimedOut)
|
||||
await self.cleanup()
|
||||
|
||||
async def cleanup(self):
|
||||
"""
|
||||
Cleanup the confirmation prompt by deleting it, if possible.
|
||||
Ignores any Discord errors that occur during the process.
|
||||
"""
|
||||
try:
|
||||
if self._is_followup and self._message:
|
||||
await self._message.delete()
|
||||
elif not self._is_followup and self._original and not self._original.is_expired():
|
||||
await self._original.delete_original_response()
|
||||
except discord.HTTPException:
|
||||
# A user probably already deleted the message
|
||||
# Anything could have happened, just ignore.
|
||||
pass
|
||||
|
||||
@button(label="Confirm")
|
||||
async def confirm_button(self, interaction: discord.Interaction, press):
|
||||
if self.defer:
|
||||
await interaction.response.defer()
|
||||
self._result.set_result(True)
|
||||
await self.close()
|
||||
|
||||
@button(label="Cancel")
|
||||
async def cancel_button(self, interaction: discord.Interaction, press):
|
||||
if self.defer:
|
||||
await interaction.response.defer()
|
||||
self._result.set_result(False)
|
||||
await self.close()
|
||||
|
||||
@property
|
||||
def embed(self):
|
||||
"""
|
||||
Confirmation embed shown to the user.
|
||||
This is cached, and may be modifed directly through the usual EmbedProxy API,
|
||||
or explicitly overwritten.
|
||||
"""
|
||||
if self._embed is None:
|
||||
self._embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=self.confirm_msg
|
||||
)
|
||||
return self._embed
|
||||
|
||||
@embed.setter
|
||||
def embed(self, value):
|
||||
self._embed = value
|
||||
|
||||
async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs):
|
||||
"""
|
||||
Send this confirmation prompt in response to the provided interaction.
|
||||
Extra keyword arguments are passed to `Interaction.response.send_message`
|
||||
or `Interaction.send_followup`, depending on whether
|
||||
the provided interaction has already been responded to.
|
||||
|
||||
The `epehemeral` argument is handled specially,
|
||||
since the question message can only be deleted through `Interaction.delete_original_response`.
|
||||
|
||||
Waits on and returns the internal `result` Future.
|
||||
|
||||
Returns: bool
|
||||
True if the user pressed the confirm button.
|
||||
False if the user pressed the cancel button.
|
||||
Raises:
|
||||
ResponseTimedOut:
|
||||
If the user does not respond before the UI times out.
|
||||
"""
|
||||
self._original = interaction
|
||||
if interaction.response.is_done():
|
||||
# Interaction already responded to, send a follow up
|
||||
if ephemeral:
|
||||
raise ValueError("Cannot send an ephemeral response to a used interaction.")
|
||||
self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self)
|
||||
self._is_followup = True
|
||||
else:
|
||||
await interaction.response.send_message(
|
||||
embed=self.embed, ephemeral=ephemeral, **kwargs, view=self
|
||||
)
|
||||
self._is_followup = False
|
||||
return await self._result
|
||||
|
||||
# TODO: Selector MicroUI for displaying options (<= 25)
|
||||
Reference in New Issue
Block a user