diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 4e75e6d..225ef1a 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -1,6 +1,6 @@ this_package = "modules" -active = [".sysadmin", ".voicefix", ".messagelogger", ".voicelog"] +active = [".sysadmin", ".voicefix", ".messagelogger", ".voicelog", ".yarn"] async def setup(bot): diff --git a/src/modules/yarn/__init__.py b/src/modules/yarn/__init__.py new file mode 100644 index 0000000..89bb98c --- /dev/null +++ b/src/modules/yarn/__init__.py @@ -0,0 +1,9 @@ +import logging + +logger = logging.getLogger(__name__) + + +async def setup(bot): + from .cog import YarnCog + + await bot.add_cog(YarnCog(bot)) diff --git a/src/modules/yarn/cog.py b/src/modules/yarn/cog.py new file mode 100644 index 0000000..4ba16eb --- /dev/null +++ b/src/modules/yarn/cog.py @@ -0,0 +1,129 @@ +from typing import Literal +from collections import defaultdict +import datetime as dt +from datetime import datetime, timedelta, UTC + +from data.queries import ORDER +import discord +from discord.ext import commands as cmds +from discord import app_commands as appcmds + +from meta import LionBot, LionCog, LionContext +from meta.logger import log_wrap +from utils.lib import strfdur, utc_now, strfdur, paginate_list, pager + +from modules.voicelog.plugin.data import VoiceLogSession + + +class YarnCog(LionCog): + """ + Assorted toys for Lilac + """ + + def __init__(self, bot: LionBot): + self.bot = bot + self.desired_voice = defaultdict(dict) + + @LionCog.listener("on_voice_state_update") + async def voicestate_muter(self, member, before, after): + if not after.channel: + return + target_state = self.desired_voice[member.guild.id].pop(member.id, None) + if target_state is None: + return + await member.edit(mute=target_state) + # TODO: Log using voicelog webhook + + @LionCog.listener("on_reaction_add") + async def lilac_confirms(self, reaction: discord.Reaction, user: discord.User): + if not reaction.me: + await reaction.message.add_reaction(reaction.emoji) + + @LionCog.listener("on_reaction_remove") + async def lilac_unconfirms(self, reaction: discord.Reaction, user: discord.User): + if reaction.me and reaction.count == 1: + await reaction.remove(self.bot.user) + + @cmds.hybrid_command(name="voicestate") + @cmds.has_guild_permissions(mute_members=True) + async def voicestate_cmd( + self, ctx, user: discord.Member, state: Literal["muted", "unmuted", "clear"] + ): + self.desired_voice[ctx.guild.id].pop(user.id, None) + + if state == "clear": + # We've already removed the saved state, don't do anything else. + ack = f"{user.mention} target voice state cleared!" + elif user.voice: + # If user is currently in channel, apply the state + if state == "muted": + await user.edit(mute=True) + ack = f"{user.mention} muted!" + elif state == "unmuted": + await user.edit(mute=False) + ack = f"{user.mention} unmuted!" + else: + # If user is not currently in channel, save the state + if state == "muted": + self.desired_voice[ctx.guild.id][user.id] = True + ack = f"{user.mention} will be muted!" + elif state == "unmuted": + self.desired_voice[ctx.guild.id][user.id] = False + ack = f"{user.mention} will be unmuted!" + await ctx.reply( + embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack) + ) + + @cmds.hybrid_command(name="topvoice") + async def topvoice_cmd(self, ctx): + """ + Show top voice members by total time. + """ + target_channelid = 1383707078740279366 + since_stamp = 1769832959 + + voicelogger = ctx.bot.get_cog("VoiceLogCog") + session_data = voicelogger.data.voicelog_sessions + + query = ( + session_data.select_where( + VoiceLogSession.joined_at + >= datetime.fromtimestamp(since_stamp, tz=UTC), + guildid=ctx.guild.id, + channelid=target_channelid, + ) + .select( + userid="userid", + total_time="SUM(COALESCE(duration, EXTRACT(EPOCH FROM (NOW() - joined_at))))", + ) + .order_by("total_time", ORDER.DESC) + .with_no_adapter() + ) + leaderboard = [(row["userid"], row["total_time"]) for row in await query] + + # Format for display and pager + # First collect names + names = {} + for uid, _ in leaderboard: + user = ctx.guild.get_member(uid) + if user is None: + try: + user = await ctx.bot.fetch_member(uid) + except discord.NotFound: + user = None + names[uid] = user.display_name if user else str(uid) + + lb_strings = [] + max_name_len = min((30, max(len(name) for name in names.values()))) + for i, (uid, total) in enumerate(leaderboard): + lb_strings.append( + "{:<{}}\t{:<9}".format( + names[uid], max_name_len, strfdur(total, short=False) + ) + ) + + page_len = 20 + title = "Voice Leaderboard" + pages = paginate_list(lb_strings, block_length=page_len, title=title) + + await ctx.pager(pages) diff --git a/src/utils/lib.py b/src/utils/lib.py index d7796fc..2d3defc 100644 --- a/src/utils/lib.py +++ b/src/utils/lib.py @@ -1,5 +1,6 @@ from io import StringIO from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any +import asyncio import collections import datetime import datetime as dt @@ -11,18 +12,24 @@ from contextvars import Context import discord from discord.partial_emoji import _EmojiTag -from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage +from discord import ( + Embed, + File, + GuildSticker, + StickerItem, + AllowedMentions, + Message, + MessageReference, + PartialMessage, +) from discord.ui import View from meta.errors import UserInputError -multiselect_regex = re.compile( - r"^([0-9, -]+)$", - re.DOTALL | re.IGNORECASE | re.VERBOSE -) -tick = '✅' -cross = '❌' +multiselect_regex = re.compile(r"^([0-9, -]+)$", re.DOTALL | re.IGNORECASE | re.VERBOSE) +tick = "✅" +cross = "❌" MISSING = object() @@ -31,6 +38,7 @@ class MessageArgs: """ Utility class for storing message creation and editing arguments. """ + # TODO: Overrides for mutually exclusive arguments, see Messageable.send @overload @@ -49,8 +57,7 @@ class MessageArgs: mention_author: bool = ..., view: View = ..., suppress_embeds: bool = ..., - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -68,8 +75,7 @@ class MessageArgs: mention_author: bool = ..., view: View = ..., suppress_embeds: bool = ..., - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -87,8 +93,7 @@ class MessageArgs: mention_author: bool = ..., view: View = ..., suppress_embeds: bool = ..., - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -106,17 +111,16 @@ class MessageArgs: mention_author: bool = ..., view: View = ..., suppress_embeds: bool = ..., - ) -> None: - ... + ) -> None: ... def __init__(self, **kwargs): self.kwargs = kwargs @property def send_args(self) -> dict: - if self.kwargs.get('view', MISSING) is None: + if self.kwargs.get("view", MISSING) is None: kwargs = self.kwargs.copy() - kwargs.pop('view') + kwargs.pop("view") else: kwargs = self.kwargs @@ -126,20 +130,25 @@ class MessageArgs: def edit_args(self) -> dict: args = {} kept = ( - 'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view' + "content", + "embed", + "embeds", + "delete_after", + "allowed_mentions", + "view", ) for k in kept: if k in self.kwargs: args[k] = self.kwargs[k] - if 'file' in self.kwargs: - args['attachments'] = [self.kwargs['file']] + if "file" in self.kwargs: + args["attachments"] = [self.kwargs["file"]] - if 'files' in self.kwargs: - args['attachments'] = self.kwargs['files'] + if "files" in self.kwargs: + args["attachments"] = self.kwargs["files"] - if 'suppress_embeds' in self.kwargs: - args['suppress'] = self.kwargs['suppress_embeds'] + if "suppress_embeds" in self.kwargs: + args["suppress"] = self.kwargs["suppress_embeds"] return args @@ -148,9 +157,9 @@ def tabulate( *fields: tuple[str, str], row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}", sub_format: str = "`{invis:<{pad}}{colon}`\t{value}", - colon: str = ':', + colon: str = ":", invis: str = "​", - **args + **args, ) -> list[str]: """ Turns a list of (property, value) pairs into @@ -181,7 +190,7 @@ def tabulate( for field in fields: key = field[0] value = field[1] - lines = value.split('\r\n') + lines = value.split("\r\n") row_line = row_format.format( invis=invis, @@ -190,7 +199,7 @@ def tabulate( colon=colon, value=lines[0], field=field, - **args + **args, ) if len(lines) > 1: row_lines = [row_line] @@ -200,15 +209,17 @@ def tabulate( pad=max_len + len(colon), colon=colon, value=line, - **args + **args, ) row_lines.append(sub_line) - row_line = '\n'.join(row_lines) + row_line = "\n".join(row_lines) rows.append(row_line) return rows -def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]: +def paginate_list( + item_list: list[str], block_length=20, style="markdown", title=None +) -> list[str]: """ Create pretty codeblock pages from a list of strings. @@ -229,8 +240,13 @@ def paginate_list(item_list: list[str], block_length=20, style="markdown", title List of pages, each formatted into a codeblock, and containing at most `block_length` of the provided strings. """ - lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)] - page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)] + lines = [ + "{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) + for i, line in enumerate(item_list) + ] + page_blocks = [ + lines[i : i + block_length] for i in range(0, len(lines), block_length) + ] pages = [] for i, block in enumerate(page_blocks): pagenum = "Page {}/{}".format(i + 1, len(page_blocks)) @@ -239,12 +255,18 @@ def paginate_list(item_list: list[str], block_length=20, style="markdown", title else: header = pagenum header_line = "=" * len(header) - full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else "" + full_header = ( + "{}\n{}\n".format(header, header_line) + if len(page_blocks) > 1 or title + else "" + ) pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block))) return pages -def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]: +def split_text( + text: str, blocksize=2000, code=True, syntax="", maxheight=50 +) -> list[str]: """ Break the text into blocks of maximum length blocksize If possible, break across nearby newlines. Otherwise just break at blocksize chars @@ -277,10 +299,10 @@ def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> if len(text) <= blocksize: blocks.append(text) break - text = text.strip('\n') + text = text.strip("\n") # Find the last newline in the prototype block - split_on = text[0:blocksize].rfind('\n') + split_on = text[0:blocksize].rfind("\n") split_on = blocksize if split_on < blocksize // 5 else split_on # Add the block and truncate the text @@ -313,15 +335,17 @@ def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) - A string containing a time from the datetime.timedelta object, in a readable format. Time units will be abbreviated if short was set to True. """ - output = [[delta.days, 'd' if short else ' day'], - [delta.seconds // 3600, 'h' if short else ' hour']] + output = [ + [delta.days, "d" if short else " day"], + [delta.seconds // 3600, "h" if short else " hour"], + ] if minutes: - output.append([delta.seconds // 60 % 60, 'm' if short else ' minute']) + output.append([delta.seconds // 60 % 60, "m" if short else " minute"]) if sec: - output.append([delta.seconds % 60, 's' if short else ' second']) + output.append([delta.seconds % 60, "s" if short else " second"]) for i in range(len(output)): if output[i][0] != 1 and not short: - output[i][1] += 's' # type: ignore + output[i][1] += "s" # type: ignore reply_msg = [] if output[0][0] != 0: reply_msg.append("{}{} ".format(output[0][0], output[0][1])) @@ -347,12 +371,14 @@ def _parse_dur(time_str: str) -> int: Returns: int The number of seconds the duration represents. """ - funcs = {'d': lambda x: x * 24 * 60 * 60, - 'h': lambda x: x * 60 * 60, - 'm': lambda x: x * 60, - 's': lambda x: x} + funcs = { + "d": lambda x: x * 24 * 60 * 60, + "h": lambda x: x * 60 * 60, + "m": lambda x: x * 60, + "s": lambda x: x, + } time_str = time_str.strip(" ,") - found = re.findall(r'(\d+)\s?(\w+?)', time_str) + found = re.findall(r"(\d+)\s?(\w+?)", time_str) seconds = 0 for bit in found: if bit[1] in funcs: @@ -373,25 +399,27 @@ def strfdur(duration: int, short=True, show_days=False) -> str: parts = [] if days: - unit = 'd' if short else (' days' if days != 1 else ' day') - parts.append('{}{}'.format(days, unit)) + unit = "d" if short else (" days" if days != 1 else " day") + parts.append("{}{}".format(days, unit)) if hours: - unit = 'h' if short else (' hours' if hours != 1 else ' hour') - parts.append('{}{}'.format(hours, unit)) + unit = "h" if short else (" hours" if hours != 1 else " hour") + parts.append("{}{}".format(hours, unit)) if minutes: - unit = 'm' if short else (' minutes' if minutes != 1 else ' minute') - parts.append('{}{}'.format(minutes, unit)) + unit = "m" if short else (" minutes" if minutes != 1 else " minute") + parts.append("{}{}".format(minutes, unit)) if seconds or duration == 0: - unit = 's' if short else (' seconds' if seconds != 1 else ' second') - parts.append('{}{}'.format(seconds, unit)) + unit = "s" if short else (" seconds" if seconds != 1 else " second") + parts.append("{}{}".format(seconds, unit)) if short: - return ' '.join(parts) + return " ".join(parts) else: - return ', '.join(parts) + return ", ".join(parts) -def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str: +def substitute_ranges( + ranges_str: str, max_match=20, max_range=1000, separator="," +) -> str: """ Substitutes a user provided list of numbers and ranges, and replaces the ranges by the corresponding list of numbers. @@ -407,6 +435,7 @@ def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=' The maximum length of range to replace. Attempting to replace a range longer than this will raise a `ValueError`. """ + def _repl(match): n1 = int(match.group(1)) n2 = int(match.group(2)) @@ -415,16 +444,18 @@ def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=' raise ValueError("Provided range is too large!") return separator.join(str(i) for i in range(n1, n2 + 1)) - return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match) + return re.sub(r"(\d+)\s*-\s*(\d+)", _repl, ranges_str, max_match) -def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]: +def parse_ranges( + ranges_str: str, ignore_errors=False, separator=",", **kwargs +) -> list[int]: """ Parses a user provided range string into a list of numbers. Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`. """ substituted = substitute_ranges(ranges_str, separator=separator, **kwargs) - _numbers = (item.strip() for item in substituted.split(',')) + _numbers = (item.strip() for item in substituted.split(",")) numbers = [item for item in _numbers if item] integers = [int(item) for item in numbers if item.isdigit()] @@ -438,7 +469,9 @@ def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) return integers -def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str: +def msg_string( + msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True +) -> str: """ Format a message into a string with various information, such as: the timestamp of the message, author, message content, and attachments. @@ -462,20 +495,26 @@ def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, """ timestr = "%I:%M %p, %d/%m/%Y" if tz: - time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr) + time = ( + iso8601.parse_date(msg.created_at.isoformat()) + .astimezone(tz) + .strftime(timestr) + ) else: time = msg.created_at.strftime(timestr) user = str(msg.author) attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url] if mask_link: attach_list = ["[Link]({})".format(url) for url in attach_list] - attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else "" + attachments = ( + "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else "" + ) return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format( time=time, user=user, line_break="\n" if line_break else "", message=msg.clean_content if clean else msg.content, - attachments=attachments + attachments=attachments, ) @@ -491,21 +530,23 @@ def convdatestring(datestring: str) -> datetime.timedelta: Returns: datetime.timedelta A datetime.timedelta object formed from the string provided. """ - datestring = datestring.strip(' ,') + datestring = datestring.strip(" ,") datearray = [] - funcs = {'d': lambda x: x * 24 * 60 * 60, - 'h': lambda x: x * 60 * 60, - 'm': lambda x: x * 60, - 's': lambda x: x} - currentnumber = '' + funcs = { + "d": lambda x: x * 24 * 60 * 60, + "h": lambda x: x * 60 * 60, + "m": lambda x: x * 60, + "s": lambda x: x, + } + currentnumber = "" for char in datestring: if char.isdigit(): currentnumber += char else: - if currentnumber == '': + if currentnumber == "": continue datearray.append((int(currentnumber), char)) - currentnumber = '' + currentnumber = "" seconds = 0 if currentnumber: seconds += int(currentnumber) @@ -520,6 +561,7 @@ class _rawChannel(discord.abc.Messageable): Raw messageable class representing an arbitrary channel, not necessarially seen by the gateway. """ + def __init__(self, state, id): self._state = state self.id = id @@ -586,8 +628,12 @@ def join_list(string: list[str], nfs=False) -> str: """ # TODO: Probably not useful with localisation if len(string) > 1: - return "{}{} and {}{}".format((", ").join(string[:-1]), - "," if len(string) > 2 else "", string[-1], "" if nfs else ".") + return "{}{} and {}{}".format( + (", ").join(string[:-1]), + "," if len(string) > 2 else "", + string[-1], + "" if nfs else ".", + ) else: return "{}{}".format("".join(string), "" if nfs else ".") @@ -603,10 +649,8 @@ def jumpto(guildid: int, channeldid: int, messageid: int) -> str: """ Build a jump link for a message given its location. """ - return 'https://discord.com/channels/{}/{}/{}'.format( - guildid, - channeldid, - messageid + return "https://discord.com/channels/{}/{}/{}".format( + guildid, channeldid, messageid ) @@ -621,7 +665,7 @@ def multiple_replace(string: str, rep_dict: dict[str, str]) -> str: if rep_dict: pattern = re.compile( "|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]), - flags=re.DOTALL + flags=re.DOTALL, ) return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string) else: @@ -644,12 +688,16 @@ def parse_ids(idstr: str) -> List[int]: from meta.errors import UserInputError # Extract ids from string - splititer = (split.strip('<@!#&>, ') for split in idstr.split(',')) + splititer = (split.strip("<@!#&>, ") for split in idstr.split(",")) splits = [split for split in splititer if split] # Check they are integers - if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None: - raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id}) + if ( + not_id := next((split for split in splits if not split.isdigit()), None) + ) is not None: + raise UserInputError( + "Could not extract an id from `$item`!", {"orig": idstr, "item": not_id} + ) # Cast to integer and return return list(map(int, splits)) @@ -657,9 +705,7 @@ def parse_ids(idstr: str) -> List[int]: def error_embed(error, **kwargs) -> discord.Embed: embed = discord.Embed( - colour=discord.Colour.brand_red(), - description=error, - timestamp=utc_now() + colour=discord.Colour.brand_red(), description=error, timestamp=utc_now() ) return embed @@ -676,6 +722,7 @@ class Timezoned: Provides several useful localised properties. """ + __slots__ = () @property @@ -727,8 +774,10 @@ def replace_multiple(format_string, mapping): raise ValueError("Empty mapping passed.") keys = list(mapping.keys()) - pattern = '|'.join(f"({key})" for key in keys) - string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string) + pattern = "|".join(f"({key})" for key in keys) + string = re.sub( + pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string + ) return string @@ -748,6 +797,7 @@ def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str): return key + def recurse_map(func, obj, loc=[]): if isinstance(obj, dict): for k, v in obj.items(): @@ -761,7 +811,8 @@ def recurse_map(func, obj, loc=[]): loc.pop() else: obj = func(loc, obj) - return obj + return obj + async def check_dm(user: discord.User | discord.Member) -> bool: """ @@ -774,7 +825,7 @@ async def check_dm(user: discord.User | discord.Member) -> bool: (i.e. during a user instigated interaction). """ try: - await user.send('') + await user.send("") except discord.Forbidden: return False except discord.HTTPException: @@ -783,38 +834,36 @@ async def check_dm(user: discord.User | discord.Member) -> bool: async def command_lengths(tree) -> dict[str, int]: cmds = tree.get_commands() - payloads = [ - await cmd.get_translated_payload(tree.translator) - for cmd in cmds - ] + payloads = [await cmd.get_translated_payload(tree.translator) for cmd in cmds] lens = {} for command in payloads: - name = command['name'] + name = command["name"] crumbs = {} cmd_len = lens[name] = _recurse_length(command, crumbs, (name,)) - if name == 'configure' or cmd_len > 4000: + if name == "configure" or cmd_len > 4000: print(f"'{name}' over 4000. Breadcrumb Trail follows:") lines = [] for loc, val in crumbs.items(): - locstr = '.'.join(loc) + locstr = ".".join(loc) lines.append(f"{locstr}: {val}") - print('\n'.join(lines)) + print("\n".join(lines)) print(json.dumps(command, indent=2)) return lens + def _recurse_length(payload, breadcrumbs={}, header=()) -> int: total = 0 - total_header = (*header, '') + total_header = (*header, "") breadcrumbs[total_header] = 0 if isinstance(payload, dict): # Read strings that count towards command length # String length is length of longest localisation, including default. - for key in ('name', 'description', 'value'): + for key in ("name", "description", "value"): if key in payload: value = payload[key] if isinstance(value, str): - values = (value, *payload.get(key + '_localizations', {}).values()) + values = (value, *payload.get(key + "_localizations", {}).values()) maxlen = max(map(len, values)) total += maxlen breadcrumbs[(*header, key)] = maxlen @@ -824,7 +873,7 @@ def _recurse_length(payload, breadcrumbs={}, header=()) -> int: total += _recurse_length(value, breadcrumbs, loc) elif isinstance(payload, list): for i, item in enumerate(payload): - if isinstance(item, dict) and 'name' in item: + if isinstance(item, dict) and "name" in item: loc = (*header, f"{i}<{item['name']}>") else: loc = (*header, str(i)) @@ -837,33 +886,25 @@ def _recurse_length(payload, breadcrumbs={}, header=()) -> int: return total + def write_records(records: list[dict[str, Any]], stream: StringIO): if records: keys = records[0].keys() - stream.write(','.join(keys)) - stream.write('\n') + stream.write(",".join(keys)) + stream.write("\n") for record in records: - stream.write(','.join(map(str, record.values()))) - stream.write('\n') + stream.write(",".join(map(str, record.values()))) + stream.write("\n") parse_dur_exps = [ ( - r"(?P\d+)\s*(?:(d)|(day))", - 60 * 60 * 24, + r"(?P\d+)\s*(?:(d)|(day))", + 60 * 60 * 24, ), - ( - r"(?P\d+)\s*(?:(h)|(hour))", - 60 * 60 - ), - ( - r"(?P\d+)\s*(?:(m)|(min))", - 60 - ), - ( - r"(?P\d+)\s*(?:(s)|(sec))", - 1 - ) + (r"(?P\d+)\s*(?:(h)|(hour))", 60 * 60), + (r"(?P\d+)\s*(?:(m)|(min))", 60), + (r"(?P\d+)\s*(?:(s)|(sec))", 1), ] @@ -874,6 +915,131 @@ def parse_duration(string: str) -> Optional[int]: match = re.search(expr, string, flags=re.IGNORECASE) if match: found = True - seconds += int(match.group('value')) * multiplier + seconds += int(match.group("value")) * multiplier return seconds if found else None + + +async def pager(ctx, pages, locked=True, start_at=0, add_cancel=False, **kwargs): + """ + Shows the user each page from the provided list `pages` one at a time, + providing reactions to page back and forth between pages. + This is done asynchronously, and returns after displaying the first page. + + Parameters + ---------- + pages: List(Union(str, discord.Embed)) + A list of either strings or embeds to display as the pages. + locked: bool + Whether only the `ctx.author` should be able to use the paging reactions. + kwargs: ... + Remaining keyword arguments are transparently passed to the reply context method. + + Returns: discord.Message + This is the output message, returned for easy deletion. + """ + cancel_emoji = cross + # Handle broken input + if len(pages) == 0: + raise ValueError("Pager cannot page with no pages!") + + # Post first page. Method depends on whether the page is an embed or not. + if isinstance(pages[start_at], discord.Embed): + out_msg = await ctx.reply(embed=pages[start_at], **kwargs) + else: + out_msg = await ctx.reply(pages[start_at], **kwargs) + + # Run the paging loop if required + if len(pages) > 1: + task = asyncio.create_task( + _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs) + ) + # ctx.tasks.append(task) + elif add_cancel: + await out_msg.add_reaction(cancel_emoji) + + # Return the output message + return out_msg + + +async def _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs): + """ + Asynchronous initialiser and loop for the `pager` utility above. + """ + # Page number + page = start_at + + # Add reactions to the output message + next_emoji = "▶" + prev_emoji = "◀" + cancel_emoji = cross + + try: + await out_msg.add_reaction(prev_emoji) + if add_cancel: + await out_msg.add_reaction(cancel_emoji) + await out_msg.add_reaction(next_emoji) + except discord.Forbidden: + # We don't have permission to add paging emojis + # Die as gracefully as we can + if ctx.guild: + perms = ctx.channel.permissions_for(ctx.guild.me) + if not perms.add_reactions: + await ctx.error_reply( + "Cannot page results because I do not have the `add_reactions` permission!" + ) + elif not perms.read_message_history: + await ctx.error_reply( + "Cannot page results because I do not have the `read_message_history` permission!" + ) + else: + await ctx.error_reply( + "Cannot page results due to insufficient permissions!" + ) + else: + await ctx.error_reply("Cannot page results!") + return + + # Check function to determine whether a reaction is valid + def check(reaction, user): + result = reaction.message.id == out_msg.id + result = result and str(reaction.emoji) in [next_emoji, prev_emoji] + result = result and not (user.id == ctx.bot.user.id) + result = result and not (locked and user != ctx.author) + return result + + # Begin loop + while True: + # Wait for a valid reaction, break if we time out + try: + reaction, user = await ctx.bot.wait_for( + "reaction_add", check=check, timeout=300 + ) + except asyncio.TimeoutError: + break + + # Attempt to remove the user's reaction, silently ignore errors + asyncio.ensure_future(out_msg.remove_reaction(reaction.emoji, user)) + + # Change the page number + page += 1 if reaction.emoji == next_emoji else -1 + page %= len(pages) + + # Edit the message with the new page + active_page = pages[page] + if isinstance(active_page, discord.Embed): + await out_msg.edit(embed=active_page, **kwargs) + else: + await out_msg.edit(content=active_page, **kwargs) + + # Clean up by removing the reactions + try: + await out_msg.clear_reactions() + except discord.Forbidden: + try: + await out_msg.remove_reaction(next_emoji, ctx.client.user) + await out_msg.remove_reaction(prev_emoji, ctx.client.user) + except discord.NotFound: + pass + except discord.NotFound: + pass