from io import StringIO from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any import collections import datetime import datetime as dt import iso8601 # type: ignore import pytz import re import json from contextvars import Context import discord from discord.partial_emoji import _EmojiTag from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage from discord.ui import View from meta.errors import UserInputError multiselect_regex = re.compile( r"^([0-9, -]+)$", re.DOTALL | re.IGNORECASE | re.VERBOSE ) tick = '✅' cross = '❌' MISSING = object() class MessageArgs: """ Utility class for storing message creation and editing arguments. """ # TODO: Overrides for mutually exclusive arguments, see Messageable.send @overload def __init__( self, content: Optional[str] = ..., *, tts: bool = ..., embed: Embed = ..., file: File = ..., stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., delete_after: float = ..., nonce: Union[str, int] = ..., allowed_mentions: AllowedMentions = ..., reference: Union[Message, MessageReference, PartialMessage] = ..., mention_author: bool = ..., view: View = ..., suppress_embeds: bool = ..., ) -> None: ... @overload def __init__( self, content: Optional[str] = ..., *, tts: bool = ..., embed: Embed = ..., files: Sequence[File] = ..., stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., delete_after: float = ..., nonce: Union[str, int] = ..., allowed_mentions: AllowedMentions = ..., reference: Union[Message, MessageReference, PartialMessage] = ..., mention_author: bool = ..., view: View = ..., suppress_embeds: bool = ..., ) -> None: ... @overload def __init__( self, content: Optional[str] = ..., *, tts: bool = ..., embeds: Sequence[Embed] = ..., file: File = ..., stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., delete_after: float = ..., nonce: Union[str, int] = ..., allowed_mentions: AllowedMentions = ..., reference: Union[Message, MessageReference, PartialMessage] = ..., mention_author: bool = ..., view: View = ..., suppress_embeds: bool = ..., ) -> None: ... @overload def __init__( self, content: Optional[str] = ..., *, tts: bool = ..., embeds: Sequence[Embed] = ..., files: Sequence[File] = ..., stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., delete_after: float = ..., nonce: Union[str, int] = ..., allowed_mentions: AllowedMentions = ..., reference: Union[Message, MessageReference, PartialMessage] = ..., mention_author: bool = ..., view: View = ..., suppress_embeds: bool = ..., ) -> None: ... def __init__(self, **kwargs): self.kwargs = kwargs @property def send_args(self) -> dict: if self.kwargs.get('view', MISSING) is None: kwargs = self.kwargs.copy() kwargs.pop('view') else: kwargs = self.kwargs return kwargs @property def edit_args(self) -> dict: args = {} kept = ( 'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view' ) for k in kept: if k in self.kwargs: args[k] = self.kwargs[k] if 'file' in self.kwargs: args['attachments'] = [self.kwargs['file']] if 'files' in self.kwargs: args['attachments'] = self.kwargs['files'] if 'suppress_embeds' in self.kwargs: args['suppress'] = self.kwargs['suppress_embeds'] return args def tabulate( *fields: tuple[str, str], row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}", sub_format: str = "`{invis:<{pad}}{colon}`\t{value}", colon: str = ':', invis: str = "​", **args ) -> list[str]: """ Turns a list of (property, value) pairs into a pretty string with one `prop: value` pair each line, padded so that the colons in each line are lined up. Use `\\r\\n` in a value to break the line with padding. Parameters ---------- fields: List[tuple[str, str]] List of (key, value) pairs. row_format: str The format string used to format each row. sub_format: str The format string used to format each subline in a row. colon: str The colon character used. invis: str The invisible character used (to avoid Discord stripping the string). Returns: List[str] The list of resulting table rows. Each row corresponds to one (key, value) pair from fields. """ max_len = max(len(field[0]) for field in fields) rows = [] for field in fields: key = field[0] value = field[1] lines = value.split('\r\n') row_line = row_format.format( invis=invis, key=key, pad=max_len, colon=colon, value=lines[0], field=field, **args ) if len(lines) > 1: row_lines = [row_line] for line in lines[1:]: sub_line = sub_format.format( invis=invis, pad=max_len + len(colon), colon=colon, value=line, **args ) row_lines.append(sub_line) row_line = '\n'.join(row_lines) rows.append(row_line) return rows def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]: """ Create pretty codeblock pages from a list of strings. Parameters ---------- item_list: List[str] List of strings to paginate. block_length: int Maximum number of strings per page. style: str Codeblock style to use. Title formatting assumes the `markdown` style, and numbered lists work well with this. However, `markdown` sometimes messes up formatting in the list. title: str Optional title to add to the top of each page. Returns: List[str] List of pages, each formatted into a codeblock, and containing at most `block_length` of the provided strings. """ lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)] page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)] pages = [] for i, block in enumerate(page_blocks): pagenum = "Page {}/{}".format(i + 1, len(page_blocks)) if title: header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title else: header = pagenum header_line = "=" * len(header) full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else "" pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block))) return pages def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]: """ Break the text into blocks of maximum length blocksize If possible, break across nearby newlines. Otherwise just break at blocksize chars Parameters ---------- text: str Text to break into blocks. blocksize: int Maximum character length for each block. code: bool Whether to wrap each block in codeblocks (these are counted in the blocksize). syntax: str The markdown formatting language to use for the codeblocks, if applicable. maxheight: int The maximum number of lines in each block Returns: List[str] List of blocks, each containing at most `block_size` characters, of height at most `maxheight`. """ # Adjust blocksize to account for the codeblocks if required blocksize = blocksize - 8 - len(syntax) if code else blocksize # Build the blocks blocks = [] while True: # If the remaining text is already small enough, append it if len(text) <= blocksize: blocks.append(text) break text = text.strip('\n') # Find the last newline in the prototype block split_on = text[0:blocksize].rfind('\n') split_on = blocksize if split_on < blocksize // 5 else split_on # Add the block and truncate the text blocks.append(text[0:split_on]) text = text[split_on:] # Add the codeblock ticks and the code syntax header, if required if code: blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks] return blocks def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str: """ Convert a datetime.timedelta object into an easily readable duration string. Parameters ---------- delta: datetime.timedelta The timedelta object to convert into a readable string. sec: bool Whether to include the seconds from the timedelta object in the string. minutes: bool Whether to include the minutes from the timedelta object in the string. short: bool Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s"). Returns: str A string containing a time from the datetime.timedelta object, in a readable format. Time units will be abbreviated if short was set to True. """ output = [[delta.days, 'd' if short else ' day'], [delta.seconds // 3600, 'h' if short else ' hour']] if minutes: output.append([delta.seconds // 60 % 60, 'm' if short else ' minute']) if sec: output.append([delta.seconds % 60, 's' if short else ' second']) for i in range(len(output)): if output[i][0] != 1 and not short: output[i][1] += 's' # type: ignore reply_msg = [] if output[0][0] != 0: reply_msg.append("{}{} ".format(output[0][0], output[0][1])) if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2: reply_msg.append("{}{} ".format(output[1][0], output[1][1])) for i in range(2, len(output) - 1): reply_msg.append("{}{} ".format(output[i][0], output[i][1])) if not short and reply_msg: reply_msg.append("and ") reply_msg.append("{}{}".format(output[-1][0], output[-1][1])) return "".join(reply_msg) def _parse_dur(time_str: str) -> int: """ Parses a user provided time duration string into a timedelta object. Parameters ---------- time_str: str The time string to parse. String can include days, hours, minutes, and seconds. Returns: int The number of seconds the duration represents. """ funcs = {'d': lambda x: x * 24 * 60 * 60, 'h': lambda x: x * 60 * 60, 'm': lambda x: x * 60, 's': lambda x: x} time_str = time_str.strip(" ,") found = re.findall(r'(\d+)\s?(\w+?)', time_str) seconds = 0 for bit in found: if bit[1] in funcs: seconds += funcs[bit[1]](int(bit[0])) return seconds def strfdur(duration: int, short=True, show_days=False) -> str: """ Convert a duration given in seconds to a number of hours, minutes, and seconds. """ days = duration // (3600 * 24) if show_days else 0 hours = duration // 3600 if days: hours %= 24 minutes = duration // 60 % 60 seconds = duration % 60 parts = [] if days: unit = 'd' if short else (' days' if days != 1 else ' day') parts.append('{}{}'.format(days, unit)) if hours: unit = 'h' if short else (' hours' if hours != 1 else ' hour') parts.append('{}{}'.format(hours, unit)) if minutes: unit = 'm' if short else (' minutes' if minutes != 1 else ' minute') parts.append('{}{}'.format(minutes, unit)) if seconds or duration == 0: unit = 's' if short else (' seconds' if seconds != 1 else ' second') parts.append('{}{}'.format(seconds, unit)) if short: return ' '.join(parts) else: return ', '.join(parts) def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str: """ Substitutes a user provided list of numbers and ranges, and replaces the ranges by the corresponding list of numbers. Parameters ---------- ranges_str: str The string to ranges in. max_match: int The maximum number of ranges to replace. Any ranges exceeding this will be ignored. max_range: int The maximum length of range to replace. Attempting to replace a range longer than this will raise a `ValueError`. """ def _repl(match): n1 = int(match.group(1)) n2 = int(match.group(2)) if n2 - n1 > max_range: # TODO: Upgrade to SafeCancellation raise ValueError("Provided range is too large!") return separator.join(str(i) for i in range(n1, n2 + 1)) return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match) def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]: """ Parses a user provided range string into a list of numbers. Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`. """ substituted = substitute_ranges(ranges_str, separator=separator, **kwargs) _numbers = (item.strip() for item in substituted.split(',')) numbers = [item for item in _numbers if item] integers = [int(item) for item in numbers if item.isdigit()] if not ignore_errors and len(integers) != len(numbers): # TODO: Upgrade to SafeCancellation raise ValueError( "Couldn't parse the provided selection!\n" "Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`." ) return integers def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str: """ Format a message into a string with various information, such as: the timestamp of the message, author, message content, and attachments. Parameters ---------- msg: Message The message to format. mask_link: bool Whether to mask the URLs of any attachments. line_break: bool Whether a line break should be used in the string. tz: Timezone The timezone to use in the formatted message. clean: bool Whether to use the clean content of the original message. Returns: str A formatted string containing various information: User timezone, message author, message content, attachments """ timestr = "%I:%M %p, %d/%m/%Y" if tz: time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr) else: time = msg.created_at.strftime(timestr) user = str(msg.author) attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url] if mask_link: attach_list = ["[Link]({})".format(url) for url in attach_list] attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else "" return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format( time=time, user=user, line_break="\n" if line_break else "", message=msg.clean_content if clean else msg.content, attachments=attachments ) def convdatestring(datestring: str) -> datetime.timedelta: """ Convert a date string into a datetime.timedelta object. Parameters ---------- datestring: str The string to convert to a datetime.timedelta object. Returns: datetime.timedelta A datetime.timedelta object formed from the string provided. """ datestring = datestring.strip(' ,') datearray = [] funcs = {'d': lambda x: x * 24 * 60 * 60, 'h': lambda x: x * 60 * 60, 'm': lambda x: x * 60, 's': lambda x: x} currentnumber = '' for char in datestring: if char.isdigit(): currentnumber += char else: if currentnumber == '': continue datearray.append((int(currentnumber), char)) currentnumber = '' seconds = 0 if currentnumber: seconds += int(currentnumber) for i in datearray: if i[1] in funcs: seconds += funcs[i[1]](i[0]) return datetime.timedelta(seconds=seconds) class _rawChannel(discord.abc.Messageable): """ Raw messageable class representing an arbitrary channel, not necessarially seen by the gateway. """ def __init__(self, state, id): self._state = state self.id = id async def _get_channel(self): return discord.Object(self.id) async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message: """ Mails a message to a channelid which may be invisible to the gateway. Parameters: client: discord.Client The client to use for mailing. Must at least have static authentication and have a valid `_connection`. channelid: int The channel id to mail to. msg_args: Any Message keyword arguments which are passed transparently to `_rawChannel.send(...)`. """ # Create the raw channel channel = _rawChannel(client._connection, channelid) return await channel.send(**msg_args) class EmbedField(NamedTuple): name: str value: str inline: Optional[bool] = True def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]): """ Append embed fields to an embed. Parameters ---------- embed: discord.Embed The embed to add the field to. emb_fields: tuple The values to add to a field. name: str The name of the field. value: str The value of the field. inline: bool Whether the embed field should be inline or not. """ for field in emb_fields: embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2])) def join_list(string: list[str], nfs=False) -> str: """ Join a list together, separated with commas, plus add "and" to the beginning of the last value. Parameters ---------- string: list The list to join together. nfs: bool (no fullstops) Whether to exclude fullstops/periods from the output messages. If not provided, fullstops will be appended to the output. """ # TODO: Probably not useful with localisation if len(string) > 1: return "{}{} and {}{}".format((", ").join(string[:-1]), "," if len(string) > 2 else "", string[-1], "" if nfs else ".") else: return "{}{}".format("".join(string), "" if nfs else ".") def shard_of(shard_count: int, guildid: int) -> int: """ Calculate the shard number of a given guild. """ return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0 def jumpto(guildid: int, channeldid: int, messageid: int) -> str: """ Build a jump link for a message given its location. """ return 'https://discord.com/channels/{}/{}/{}'.format( guildid, channeldid, messageid ) def utc_now() -> datetime.datetime: """ Return the current timezone-aware utc timestamp. """ return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) def multiple_replace(string: str, rep_dict: dict[str, str]) -> str: if rep_dict: pattern = re.compile( "|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]), flags=re.DOTALL ) return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string) else: return string def recover_context(context: Context): for var in context: var.set(context[var]) def parse_ids(idstr: str) -> List[int]: """ Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids. Object agnostic, so all mention tokens are stripped. Raises UserInputError if an id is invalid, setting `orig` and `item` info fields. """ from meta.errors import UserInputError # Extract ids from string splititer = (split.strip('<@!#&>, ') for split in idstr.split(',')) splits = [split for split in splititer if split] # Check they are integers if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None: raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id}) # Cast to integer and return return list(map(int, splits)) def error_embed(error, **kwargs) -> discord.Embed: embed = discord.Embed( colour=discord.Colour.brand_red(), description=error, timestamp=utc_now() ) return embed class DotDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__dict__ = self class Timezoned: """ ABC mixin for objects with a set timezone. Provides several useful localised properties. """ __slots__ = () @property def timezone(self) -> pytz.timezone: """ Must be implemented by the deriving class! """ raise NotImplementedError @property def now(self): """ Return the current time localised to the object's timezone. """ return datetime.datetime.now(tz=self.timezone) @property def today(self): """ Return the start of the day localised to the object's timezone. """ now = self.now return now.replace(hour=0, minute=0, second=0, microsecond=0) @property def week_start(self): """ Return the start of the week in the object's timezone """ today = self.today return today - datetime.timedelta(days=today.weekday()) @property def month_start(self): """ Return the start of the current month in the object's timezone """ today = self.today return today.replace(day=1) def replace_multiple(format_string, mapping): """ Subsistutes the keys from the format_dict with their corresponding values. Substitution is non-chained, and done in a single pass via regex. """ if not mapping: raise ValueError("Empty mapping passed.") keys = list(mapping.keys()) pattern = '|'.join(f"({key})" for key in keys) string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string) return string def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str): """ Produces a distinguishing key for an Emoji or PartialEmoji. Equality checks using this key should act as expected. """ if isinstance(emoji, _EmojiTag): if emoji.id: key = str(emoji.id) else: key = str(emoji.name) else: key = str(emoji) return key def recurse_map(func, obj, loc=[]): if isinstance(obj, dict): for k, v in obj.items(): loc.append(k) obj[k] = recurse_map(func, v, loc) loc.pop() elif isinstance(obj, list): for i, item in enumerate(obj): loc.append(i) obj[i] = recurse_map(func, item) loc.pop() else: obj = func(loc, obj) return obj async def check_dm(user: discord.User | discord.Member) -> bool: """ Check whether we can direct message the given user. Assumes the client is initialised. This uses an always-failing HTTP request, so we need to be very very very careful that this is not used frequently. Optimally only at the explicit behest of the user (i.e. during a user instigated interaction). """ try: await user.send('') except discord.Forbidden: return False except discord.HTTPException: return True async def command_lengths(tree) -> dict[str, int]: cmds = tree.get_commands() payloads = [ await cmd.get_translated_payload(tree.translator) for cmd in cmds ] lens = {} for command in payloads: name = command['name'] crumbs = {} cmd_len = lens[name] = _recurse_length(command, crumbs, (name,)) if name == 'configure' or cmd_len > 4000: print(f"'{name}' over 4000. Breadcrumb Trail follows:") lines = [] for loc, val in crumbs.items(): locstr = '.'.join(loc) lines.append(f"{locstr}: {val}") print('\n'.join(lines)) print(json.dumps(command, indent=2)) return lens def _recurse_length(payload, breadcrumbs={}, header=()) -> int: total = 0 total_header = (*header, '') breadcrumbs[total_header] = 0 if isinstance(payload, dict): # Read strings that count towards command length # String length is length of longest localisation, including default. for key in ('name', 'description', 'value'): if key in payload: value = payload[key] if isinstance(value, str): values = (value, *payload.get(key + '_localizations', {}).values()) maxlen = max(map(len, values)) total += maxlen breadcrumbs[(*header, key)] = maxlen for key, value in payload.items(): loc = (*header, key) total += _recurse_length(value, breadcrumbs, loc) elif isinstance(payload, list): for i, item in enumerate(payload): if isinstance(item, dict) and 'name' in item: loc = (*header, f"{i}<{item['name']}>") else: loc = (*header, str(i)) total += _recurse_length(item, breadcrumbs, loc) if total: breadcrumbs[total_header] = total else: breadcrumbs.pop(total_header) return total def write_records(records: list[dict[str, Any]], stream: StringIO): if records: keys = records[0].keys() stream.write(','.join(keys)) stream.write('\n') for record in records: stream.write(','.join(map(str, record.values()))) stream.write('\n') parse_dur_exps = [ ( r"(?P\d+)\s*(?:(d)|(day))", 60 * 60 * 24, ), ( r"(?P\d+)\s*(?:(h)|(hour))", 60 * 60 ), ( r"(?P\d+)\s*(?:(m)|(min))", 60 ), ( r"(?P\d+)\s*(?:(s)|(sec))", 1 ) ] def parse_duration(string: str) -> Optional[int]: seconds = 0 found = False for expr, multiplier in parse_dur_exps: match = re.search(expr, string, flags=re.IGNORECASE) if match: found = True seconds += int(match.group('value')) * multiplier return seconds if found else None