Initial commit

This commit is contained in:
Foxfire
2025-09-04 03:03:18 +10:00
commit 1657089957
58 changed files with 9082 additions and 0 deletions

879
src/utils/lib.py Normal file
View File

@@ -0,0 +1,879 @@
from io import StringIO
from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any
import collections
import datetime
import datetime as dt
import iso8601 # type: ignore
import pytz
import re
import json
from contextvars import Context
import discord
from discord.partial_emoji import _EmojiTag
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
from discord.ui import View
from meta.errors import UserInputError
multiselect_regex = re.compile(
r"^([0-9, -]+)$",
re.DOTALL | re.IGNORECASE | re.VERBOSE
)
tick = ''
cross = ''
MISSING = object()
class MessageArgs:
"""
Utility class for storing message creation and editing arguments.
"""
# TODO: Overrides for mutually exclusive arguments, see Messageable.send
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
@overload
def __init__(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: Sequence[Embed] = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
def __init__(self, **kwargs):
self.kwargs = kwargs
@property
def send_args(self) -> dict:
if self.kwargs.get('view', MISSING) is None:
kwargs = self.kwargs.copy()
kwargs.pop('view')
else:
kwargs = self.kwargs
return kwargs
@property
def edit_args(self) -> dict:
args = {}
kept = (
'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view'
)
for k in kept:
if k in self.kwargs:
args[k] = self.kwargs[k]
if 'file' in self.kwargs:
args['attachments'] = [self.kwargs['file']]
if 'files' in self.kwargs:
args['attachments'] = self.kwargs['files']
if 'suppress_embeds' in self.kwargs:
args['suppress'] = self.kwargs['suppress_embeds']
return args
def tabulate(
*fields: tuple[str, str],
row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}",
sub_format: str = "`{invis:<{pad}}{colon}`\t{value}",
colon: str = ':',
invis: str = "",
**args
) -> list[str]:
"""
Turns a list of (property, value) pairs into
a pretty string with one `prop: value` pair each line,
padded so that the colons in each line are lined up.
Use `\\r\\n` in a value to break the line with padding.
Parameters
----------
fields: List[tuple[str, str]]
List of (key, value) pairs.
row_format: str
The format string used to format each row.
sub_format: str
The format string used to format each subline in a row.
colon: str
The colon character used.
invis: str
The invisible character used (to avoid Discord stripping the string).
Returns: List[str]
The list of resulting table rows.
Each row corresponds to one (key, value) pair from fields.
"""
max_len = max(len(field[0]) for field in fields)
rows = []
for field in fields:
key = field[0]
value = field[1]
lines = value.split('\r\n')
row_line = row_format.format(
invis=invis,
key=key,
pad=max_len,
colon=colon,
value=lines[0],
field=field,
**args
)
if len(lines) > 1:
row_lines = [row_line]
for line in lines[1:]:
sub_line = sub_format.format(
invis=invis,
pad=max_len + len(colon),
colon=colon,
value=line,
**args
)
row_lines.append(sub_line)
row_line = '\n'.join(row_lines)
rows.append(row_line)
return rows
def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]:
"""
Create pretty codeblock pages from a list of strings.
Parameters
----------
item_list: List[str]
List of strings to paginate.
block_length: int
Maximum number of strings per page.
style: str
Codeblock style to use.
Title formatting assumes the `markdown` style, and numbered lists work well with this.
However, `markdown` sometimes messes up formatting in the list.
title: str
Optional title to add to the top of each page.
Returns: List[str]
List of pages, each formatted into a codeblock,
and containing at most `block_length` of the provided strings.
"""
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
pages = []
for i, block in enumerate(page_blocks):
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
if title:
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
else:
header = pagenum
header_line = "=" * len(header)
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
return pages
def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]:
"""
Break the text into blocks of maximum length blocksize
If possible, break across nearby newlines. Otherwise just break at blocksize chars
Parameters
----------
text: str
Text to break into blocks.
blocksize: int
Maximum character length for each block.
code: bool
Whether to wrap each block in codeblocks (these are counted in the blocksize).
syntax: str
The markdown formatting language to use for the codeblocks, if applicable.
maxheight: int
The maximum number of lines in each block
Returns: List[str]
List of blocks,
each containing at most `block_size` characters,
of height at most `maxheight`.
"""
# Adjust blocksize to account for the codeblocks if required
blocksize = blocksize - 8 - len(syntax) if code else blocksize
# Build the blocks
blocks = []
while True:
# If the remaining text is already small enough, append it
if len(text) <= blocksize:
blocks.append(text)
break
text = text.strip('\n')
# Find the last newline in the prototype block
split_on = text[0:blocksize].rfind('\n')
split_on = blocksize if split_on < blocksize // 5 else split_on
# Add the block and truncate the text
blocks.append(text[0:split_on])
text = text[split_on:]
# Add the codeblock ticks and the code syntax header, if required
if code:
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
return blocks
def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str:
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's' # type: ignore
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def _parse_dur(time_str: str) -> int:
"""
Parses a user provided time duration string into a timedelta object.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
def strfdur(duration: int, short=True, show_days=False) -> str:
"""
Convert a duration given in seconds to a number of hours, minutes, and seconds.
"""
days = duration // (3600 * 24) if show_days else 0
hours = duration // 3600
if days:
hours %= 24
minutes = duration // 60 % 60
seconds = duration % 60
parts = []
if days:
unit = 'd' if short else (' days' if days != 1 else ' day')
parts.append('{}{}'.format(days, unit))
if hours:
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
parts.append('{}{}'.format(hours, unit))
if minutes:
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
parts.append('{}{}'.format(minutes, unit))
if seconds or duration == 0:
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
parts.append('{}{}'.format(seconds, unit))
if short:
return ' '.join(parts)
else:
return ', '.join(parts)
def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str:
"""
Substitutes a user provided list of numbers and ranges,
and replaces the ranges by the corresponding list of numbers.
Parameters
----------
ranges_str: str
The string to ranges in.
max_match: int
The maximum number of ranges to replace.
Any ranges exceeding this will be ignored.
max_range: int
The maximum length of range to replace.
Attempting to replace a range longer than this will raise a `ValueError`.
"""
def _repl(match):
n1 = int(match.group(1))
n2 = int(match.group(2))
if n2 - n1 > max_range:
# TODO: Upgrade to SafeCancellation
raise ValueError("Provided range is too large!")
return separator.join(str(i) for i in range(n1, n2 + 1))
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]:
"""
Parses a user provided range string into a list of numbers.
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
"""
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
_numbers = (item.strip() for item in substituted.split(','))
numbers = [item for item in _numbers if item]
integers = [int(item) for item in numbers if item.isdigit()]
if not ignore_errors and len(integers) != len(numbers):
# TODO: Upgrade to SafeCancellation
raise ValueError(
"Couldn't parse the provided selection!\n"
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
)
return integers
def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str:
"""
Format a message into a string with various information, such as:
the timestamp of the message, author, message content, and attachments.
Parameters
----------
msg: Message
The message to format.
mask_link: bool
Whether to mask the URLs of any attachments.
line_break: bool
Whether a line break should be used in the string.
tz: Timezone
The timezone to use in the formatted message.
clean: bool
Whether to use the clean content of the original message.
Returns: str
A formatted string containing various information:
User timezone, message author, message content, attachments
"""
timestr = "%I:%M %p, %d/%m/%Y"
if tz:
time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr)
else:
time = msg.created_at.strftime(timestr)
user = str(msg.author)
attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url]
if mask_link:
attach_list = ["[Link]({})".format(url) for url in attach_list]
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
time=time,
user=user,
line_break="\n" if line_break else "",
message=msg.clean_content if clean else msg.content,
attachments=attachments
)
def convdatestring(datestring: str) -> datetime.timedelta:
"""
Convert a date string into a datetime.timedelta object.
Parameters
----------
datestring: str
The string to convert to a datetime.timedelta object.
Returns: datetime.timedelta
A datetime.timedelta object formed from the string provided.
"""
datestring = datestring.strip(' ,')
datearray = []
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
currentnumber = ''
for char in datestring:
if char.isdigit():
currentnumber += char
else:
if currentnumber == '':
continue
datearray.append((int(currentnumber), char))
currentnumber = ''
seconds = 0
if currentnumber:
seconds += int(currentnumber)
for i in datearray:
if i[1] in funcs:
seconds += funcs[i[1]](i[0])
return datetime.timedelta(seconds=seconds)
class _rawChannel(discord.abc.Messageable):
"""
Raw messageable class representing an arbitrary channel,
not necessarially seen by the gateway.
"""
def __init__(self, state, id):
self._state = state
self.id = id
async def _get_channel(self):
return discord.Object(self.id)
async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message:
"""
Mails a message to a channelid which may be invisible to the gateway.
Parameters:
client: discord.Client
The client to use for mailing.
Must at least have static authentication and have a valid `_connection`.
channelid: int
The channel id to mail to.
msg_args: Any
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
"""
# Create the raw channel
channel = _rawChannel(client._connection, channelid)
return await channel.send(**msg_args)
class EmbedField(NamedTuple):
name: str
value: str
inline: Optional[bool] = True
def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]):
"""
Append embed fields to an embed.
Parameters
----------
embed: discord.Embed
The embed to add the field to.
emb_fields: tuple
The values to add to a field.
name: str
The name of the field.
value: str
The value of the field.
inline: bool
Whether the embed field should be inline or not.
"""
for field in emb_fields:
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
def join_list(string: list[str], nfs=False) -> str:
"""
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
Parameters
----------
string: list
The list to join together.
nfs: bool
(no fullstops)
Whether to exclude fullstops/periods from the output messages.
If not provided, fullstops will be appended to the output.
"""
# TODO: Probably not useful with localisation
if len(string) > 1:
return "{}{} and {}{}".format((", ").join(string[:-1]),
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
else:
return "{}{}".format("".join(string), "" if nfs else ".")
def shard_of(shard_count: int, guildid: int) -> int:
"""
Calculate the shard number of a given guild.
"""
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
def jumpto(guildid: int, channeldid: int, messageid: int) -> str:
"""
Build a jump link for a message given its location.
"""
return 'https://discord.com/channels/{}/{}/{}'.format(
guildid,
channeldid,
messageid
)
def utc_now() -> datetime.datetime:
"""
Return the current timezone-aware utc timestamp.
"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
def multiple_replace(string: str, rep_dict: dict[str, str]) -> str:
if rep_dict:
pattern = re.compile(
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
flags=re.DOTALL
)
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
else:
return string
def recover_context(context: Context):
for var in context:
var.set(context[var])
def parse_ids(idstr: str) -> List[int]:
"""
Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids.
Object agnostic, so all mention tokens are stripped.
Raises UserInputError if an id is invalid,
setting `orig` and `item` info fields.
"""
from meta.errors import UserInputError
# Extract ids from string
splititer = (split.strip('<@!#&>, ') for split in idstr.split(','))
splits = [split for split in splititer if split]
# Check they are integers
if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None:
raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id})
# Cast to integer and return
return list(map(int, splits))
def error_embed(error, **kwargs) -> discord.Embed:
embed = discord.Embed(
colour=discord.Colour.brand_red(),
description=error,
timestamp=utc_now()
)
return embed
class DotDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
class Timezoned:
"""
ABC mixin for objects with a set timezone.
Provides several useful localised properties.
"""
__slots__ = ()
@property
def timezone(self) -> pytz.timezone:
"""
Must be implemented by the deriving class!
"""
raise NotImplementedError
@property
def now(self):
"""
Return the current time localised to the object's timezone.
"""
return datetime.datetime.now(tz=self.timezone)
@property
def today(self):
"""
Return the start of the day localised to the object's timezone.
"""
now = self.now
return now.replace(hour=0, minute=0, second=0, microsecond=0)
@property
def week_start(self):
"""
Return the start of the week in the object's timezone
"""
today = self.today
return today - datetime.timedelta(days=today.weekday())
@property
def month_start(self):
"""
Return the start of the current month in the object's timezone
"""
today = self.today
return today.replace(day=1)
def replace_multiple(format_string, mapping):
"""
Subsistutes the keys from the format_dict with their corresponding values.
Substitution is non-chained, and done in a single pass via regex.
"""
if not mapping:
raise ValueError("Empty mapping passed.")
keys = list(mapping.keys())
pattern = '|'.join(f"({key})" for key in keys)
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
return string
def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str):
"""
Produces a distinguishing key for an Emoji or PartialEmoji.
Equality checks using this key should act as expected.
"""
if isinstance(emoji, _EmojiTag):
if emoji.id:
key = str(emoji.id)
else:
key = str(emoji.name)
else:
key = str(emoji)
return key
def recurse_map(func, obj, loc=[]):
if isinstance(obj, dict):
for k, v in obj.items():
loc.append(k)
obj[k] = recurse_map(func, v, loc)
loc.pop()
elif isinstance(obj, list):
for i, item in enumerate(obj):
loc.append(i)
obj[i] = recurse_map(func, item)
loc.pop()
else:
obj = func(loc, obj)
return obj
async def check_dm(user: discord.User | discord.Member) -> bool:
"""
Check whether we can direct message the given user.
Assumes the client is initialised.
This uses an always-failing HTTP request,
so we need to be very very very careful that this is not used frequently.
Optimally only at the explicit behest of the user
(i.e. during a user instigated interaction).
"""
try:
await user.send('')
except discord.Forbidden:
return False
except discord.HTTPException:
return True
async def command_lengths(tree) -> dict[str, int]:
cmds = tree.get_commands()
payloads = [
await cmd.get_translated_payload(tree.translator)
for cmd in cmds
]
lens = {}
for command in payloads:
name = command['name']
crumbs = {}
cmd_len = lens[name] = _recurse_length(command, crumbs, (name,))
if name == 'configure' or cmd_len > 4000:
print(f"'{name}' over 4000. Breadcrumb Trail follows:")
lines = []
for loc, val in crumbs.items():
locstr = '.'.join(loc)
lines.append(f"{locstr}: {val}")
print('\n'.join(lines))
print(json.dumps(command, indent=2))
return lens
def _recurse_length(payload, breadcrumbs={}, header=()) -> int:
total = 0
total_header = (*header, '')
breadcrumbs[total_header] = 0
if isinstance(payload, dict):
# Read strings that count towards command length
# String length is length of longest localisation, including default.
for key in ('name', 'description', 'value'):
if key in payload:
value = payload[key]
if isinstance(value, str):
values = (value, *payload.get(key + '_localizations', {}).values())
maxlen = max(map(len, values))
total += maxlen
breadcrumbs[(*header, key)] = maxlen
for key, value in payload.items():
loc = (*header, key)
total += _recurse_length(value, breadcrumbs, loc)
elif isinstance(payload, list):
for i, item in enumerate(payload):
if isinstance(item, dict) and 'name' in item:
loc = (*header, f"{i}<{item['name']}>")
else:
loc = (*header, str(i))
total += _recurse_length(item, breadcrumbs, loc)
if total:
breadcrumbs[total_header] = total
else:
breadcrumbs.pop(total_header)
return total
def write_records(records: list[dict[str, Any]], stream: StringIO):
if records:
keys = records[0].keys()
stream.write(','.join(keys))
stream.write('\n')
for record in records:
stream.write(','.join(map(str, record.values())))
stream.write('\n')
parse_dur_exps = [
(
r"(?P<value>\d+)\s*(?:(d)|(day))",
60 * 60 * 24,
),
(
r"(?P<value>\d+)\s*(?:(h)|(hour))",
60 * 60
),
(
r"(?P<value>\d+)\s*(?:(m)|(min))",
60
),
(
r"(?P<value>\d+)\s*(?:(s)|(sec))",
1
)
]
def parse_duration(string: str) -> Optional[int]:
seconds = 0
found = False
for expr, multiplier in parse_dur_exps:
match = re.search(expr, string, flags=re.IGNORECASE)
if match:
found = True
seconds += int(match.group('value')) * multiplier
return seconds if found else None