feat: Add yarn module for Lilac Misc.

Adds auto-reaction to emotes.
Adds 'voicestate' command for offline (un)muting.
Adds 'topvoice' command for current voice channel challenge.
This commit is contained in:
2026-02-25 18:45:35 +10:00
parent f432aa2b9d
commit 8339921f93
4 changed files with 426 additions and 122 deletions

View File

@@ -1,6 +1,6 @@
this_package = "modules"
active = [".sysadmin", ".voicefix", ".messagelogger", ".voicelog"]
active = [".sysadmin", ".voicefix", ".messagelogger", ".voicelog", ".yarn"]
async def setup(bot):

View File

@@ -0,0 +1,9 @@
import logging
logger = logging.getLogger(__name__)
async def setup(bot):
from .cog import YarnCog
await bot.add_cog(YarnCog(bot))

129
src/modules/yarn/cog.py Normal file
View File

@@ -0,0 +1,129 @@
from typing import Literal
from collections import defaultdict
import datetime as dt
from datetime import datetime, timedelta, UTC
from data.queries import ORDER
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.logger import log_wrap
from utils.lib import strfdur, utc_now, strfdur, paginate_list, pager
from modules.voicelog.plugin.data import VoiceLogSession
class YarnCog(LionCog):
"""
Assorted toys for Lilac
"""
def __init__(self, bot: LionBot):
self.bot = bot
self.desired_voice = defaultdict(dict)
@LionCog.listener("on_voice_state_update")
async def voicestate_muter(self, member, before, after):
if not after.channel:
return
target_state = self.desired_voice[member.guild.id].pop(member.id, None)
if target_state is None:
return
await member.edit(mute=target_state)
# TODO: Log using voicelog webhook
@LionCog.listener("on_reaction_add")
async def lilac_confirms(self, reaction: discord.Reaction, user: discord.User):
if not reaction.me:
await reaction.message.add_reaction(reaction.emoji)
@LionCog.listener("on_reaction_remove")
async def lilac_unconfirms(self, reaction: discord.Reaction, user: discord.User):
if reaction.me and reaction.count == 1:
await reaction.remove(self.bot.user)
@cmds.hybrid_command(name="voicestate")
@cmds.has_guild_permissions(mute_members=True)
async def voicestate_cmd(
self, ctx, user: discord.Member, state: Literal["muted", "unmuted", "clear"]
):
self.desired_voice[ctx.guild.id].pop(user.id, None)
if state == "clear":
# We've already removed the saved state, don't do anything else.
ack = f"{user.mention} target voice state cleared!"
elif user.voice:
# If user is currently in channel, apply the state
if state == "muted":
await user.edit(mute=True)
ack = f"{user.mention} muted!"
elif state == "unmuted":
await user.edit(mute=False)
ack = f"{user.mention} unmuted!"
else:
# If user is not currently in channel, save the state
if state == "muted":
self.desired_voice[ctx.guild.id][user.id] = True
ack = f"{user.mention} will be muted!"
elif state == "unmuted":
self.desired_voice[ctx.guild.id][user.id] = False
ack = f"{user.mention} will be unmuted!"
await ctx.reply(
embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack)
)
@cmds.hybrid_command(name="topvoice")
async def topvoice_cmd(self, ctx):
"""
Show top voice members by total time.
"""
target_channelid = 1383707078740279366
since_stamp = 1769832959
voicelogger = ctx.bot.get_cog("VoiceLogCog")
session_data = voicelogger.data.voicelog_sessions
query = (
session_data.select_where(
VoiceLogSession.joined_at
>= datetime.fromtimestamp(since_stamp, tz=UTC),
guildid=ctx.guild.id,
channelid=target_channelid,
)
.select(
userid="userid",
total_time="SUM(COALESCE(duration, EXTRACT(EPOCH FROM (NOW() - joined_at))))",
)
.order_by("total_time", ORDER.DESC)
.with_no_adapter()
)
leaderboard = [(row["userid"], row["total_time"]) for row in await query]
# Format for display and pager
# First collect names
names = {}
for uid, _ in leaderboard:
user = ctx.guild.get_member(uid)
if user is None:
try:
user = await ctx.bot.fetch_member(uid)
except discord.NotFound:
user = None
names[uid] = user.display_name if user else str(uid)
lb_strings = []
max_name_len = min((30, max(len(name) for name in names.values())))
for i, (uid, total) in enumerate(leaderboard):
lb_strings.append(
"{:<{}}\t{:<9}".format(
names[uid], max_name_len, strfdur(total, short=False)
)
)
page_len = 20
title = "Voice Leaderboard"
pages = paginate_list(lb_strings, block_length=page_len, title=title)
await ctx.pager(pages)

View File

@@ -1,5 +1,6 @@
from io import StringIO
from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any
import asyncio
import collections
import datetime
import datetime as dt
@@ -11,18 +12,24 @@ from contextvars import Context
import discord
from discord.partial_emoji import _EmojiTag
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
from discord import (
Embed,
File,
GuildSticker,
StickerItem,
AllowedMentions,
Message,
MessageReference,
PartialMessage,
)
from discord.ui import View
from meta.errors import UserInputError
multiselect_regex = re.compile(
r"^([0-9, -]+)$",
re.DOTALL | re.IGNORECASE | re.VERBOSE
)
tick = ''
cross = ''
multiselect_regex = re.compile(r"^([0-9, -]+)$", re.DOTALL | re.IGNORECASE | re.VERBOSE)
tick = ""
cross = ""
MISSING = object()
@@ -31,6 +38,7 @@ class MessageArgs:
"""
Utility class for storing message creation and editing arguments.
"""
# TODO: Overrides for mutually exclusive arguments, see Messageable.send
@overload
@@ -49,8 +57,7 @@ class MessageArgs:
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
) -> None: ...
@overload
def __init__(
@@ -68,8 +75,7 @@ class MessageArgs:
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
) -> None: ...
@overload
def __init__(
@@ -87,8 +93,7 @@ class MessageArgs:
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
) -> None: ...
@overload
def __init__(
@@ -106,17 +111,16 @@ class MessageArgs:
mention_author: bool = ...,
view: View = ...,
suppress_embeds: bool = ...,
) -> None:
...
) -> None: ...
def __init__(self, **kwargs):
self.kwargs = kwargs
@property
def send_args(self) -> dict:
if self.kwargs.get('view', MISSING) is None:
if self.kwargs.get("view", MISSING) is None:
kwargs = self.kwargs.copy()
kwargs.pop('view')
kwargs.pop("view")
else:
kwargs = self.kwargs
@@ -126,20 +130,25 @@ class MessageArgs:
def edit_args(self) -> dict:
args = {}
kept = (
'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view'
"content",
"embed",
"embeds",
"delete_after",
"allowed_mentions",
"view",
)
for k in kept:
if k in self.kwargs:
args[k] = self.kwargs[k]
if 'file' in self.kwargs:
args['attachments'] = [self.kwargs['file']]
if "file" in self.kwargs:
args["attachments"] = [self.kwargs["file"]]
if 'files' in self.kwargs:
args['attachments'] = self.kwargs['files']
if "files" in self.kwargs:
args["attachments"] = self.kwargs["files"]
if 'suppress_embeds' in self.kwargs:
args['suppress'] = self.kwargs['suppress_embeds']
if "suppress_embeds" in self.kwargs:
args["suppress"] = self.kwargs["suppress_embeds"]
return args
@@ -148,9 +157,9 @@ def tabulate(
*fields: tuple[str, str],
row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}",
sub_format: str = "`{invis:<{pad}}{colon}`\t{value}",
colon: str = ':',
colon: str = ":",
invis: str = "",
**args
**args,
) -> list[str]:
"""
Turns a list of (property, value) pairs into
@@ -181,7 +190,7 @@ def tabulate(
for field in fields:
key = field[0]
value = field[1]
lines = value.split('\r\n')
lines = value.split("\r\n")
row_line = row_format.format(
invis=invis,
@@ -190,7 +199,7 @@ def tabulate(
colon=colon,
value=lines[0],
field=field,
**args
**args,
)
if len(lines) > 1:
row_lines = [row_line]
@@ -200,15 +209,17 @@ def tabulate(
pad=max_len + len(colon),
colon=colon,
value=line,
**args
**args,
)
row_lines.append(sub_line)
row_line = '\n'.join(row_lines)
row_line = "\n".join(row_lines)
rows.append(row_line)
return rows
def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]:
def paginate_list(
item_list: list[str], block_length=20, style="markdown", title=None
) -> list[str]:
"""
Create pretty codeblock pages from a list of strings.
@@ -229,8 +240,13 @@ def paginate_list(item_list: list[str], block_length=20, style="markdown", title
List of pages, each formatted into a codeblock,
and containing at most `block_length` of the provided strings.
"""
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
lines = [
"{0:<5}{1:<5}".format("{}.".format(i + 1), str(line))
for i, line in enumerate(item_list)
]
page_blocks = [
lines[i : i + block_length] for i in range(0, len(lines), block_length)
]
pages = []
for i, block in enumerate(page_blocks):
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
@@ -239,12 +255,18 @@ def paginate_list(item_list: list[str], block_length=20, style="markdown", title
else:
header = pagenum
header_line = "=" * len(header)
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
full_header = (
"{}\n{}\n".format(header, header_line)
if len(page_blocks) > 1 or title
else ""
)
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
return pages
def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]:
def split_text(
text: str, blocksize=2000, code=True, syntax="", maxheight=50
) -> list[str]:
"""
Break the text into blocks of maximum length blocksize
If possible, break across nearby newlines. Otherwise just break at blocksize chars
@@ -277,10 +299,10 @@ def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) ->
if len(text) <= blocksize:
blocks.append(text)
break
text = text.strip('\n')
text = text.strip("\n")
# Find the last newline in the prototype block
split_on = text[0:blocksize].rfind('\n')
split_on = text[0:blocksize].rfind("\n")
split_on = blocksize if split_on < blocksize // 5 else split_on
# Add the block and truncate the text
@@ -313,15 +335,17 @@ def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
output = [
[delta.days, "d" if short else " day"],
[delta.seconds // 3600, "h" if short else " hour"],
]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
output.append([delta.seconds // 60 % 60, "m" if short else " minute"])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
output.append([delta.seconds % 60, "s" if short else " second"])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's' # type: ignore
output[i][1] += "s" # type: ignore
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
@@ -347,12 +371,14 @@ def _parse_dur(time_str: str) -> int:
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
funcs = {
"d": lambda x: x * 24 * 60 * 60,
"h": lambda x: x * 60 * 60,
"m": lambda x: x * 60,
"s": lambda x: x,
}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
found = re.findall(r"(\d+)\s?(\w+?)", time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
@@ -373,25 +399,27 @@ def strfdur(duration: int, short=True, show_days=False) -> str:
parts = []
if days:
unit = 'd' if short else (' days' if days != 1 else ' day')
parts.append('{}{}'.format(days, unit))
unit = "d" if short else (" days" if days != 1 else " day")
parts.append("{}{}".format(days, unit))
if hours:
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
parts.append('{}{}'.format(hours, unit))
unit = "h" if short else (" hours" if hours != 1 else " hour")
parts.append("{}{}".format(hours, unit))
if minutes:
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
parts.append('{}{}'.format(minutes, unit))
unit = "m" if short else (" minutes" if minutes != 1 else " minute")
parts.append("{}{}".format(minutes, unit))
if seconds or duration == 0:
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
parts.append('{}{}'.format(seconds, unit))
unit = "s" if short else (" seconds" if seconds != 1 else " second")
parts.append("{}{}".format(seconds, unit))
if short:
return ' '.join(parts)
return " ".join(parts)
else:
return ', '.join(parts)
return ", ".join(parts)
def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str:
def substitute_ranges(
ranges_str: str, max_match=20, max_range=1000, separator=","
) -> str:
"""
Substitutes a user provided list of numbers and ranges,
and replaces the ranges by the corresponding list of numbers.
@@ -407,6 +435,7 @@ def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator='
The maximum length of range to replace.
Attempting to replace a range longer than this will raise a `ValueError`.
"""
def _repl(match):
n1 = int(match.group(1))
n2 = int(match.group(2))
@@ -415,16 +444,18 @@ def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator='
raise ValueError("Provided range is too large!")
return separator.join(str(i) for i in range(n1, n2 + 1))
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
return re.sub(r"(\d+)\s*-\s*(\d+)", _repl, ranges_str, max_match)
def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]:
def parse_ranges(
ranges_str: str, ignore_errors=False, separator=",", **kwargs
) -> list[int]:
"""
Parses a user provided range string into a list of numbers.
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
"""
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
_numbers = (item.strip() for item in substituted.split(','))
_numbers = (item.strip() for item in substituted.split(","))
numbers = [item for item in _numbers if item]
integers = [int(item) for item in numbers if item.isdigit()]
@@ -438,7 +469,9 @@ def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs)
return integers
def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str:
def msg_string(
msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True
) -> str:
"""
Format a message into a string with various information, such as:
the timestamp of the message, author, message content, and attachments.
@@ -462,20 +495,26 @@ def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None,
"""
timestr = "%I:%M %p, %d/%m/%Y"
if tz:
time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr)
time = (
iso8601.parse_date(msg.created_at.isoformat())
.astimezone(tz)
.strftime(timestr)
)
else:
time = msg.created_at.strftime(timestr)
user = str(msg.author)
attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url]
if mask_link:
attach_list = ["[Link]({})".format(url) for url in attach_list]
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
attachments = (
"\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
)
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
time=time,
user=user,
line_break="\n" if line_break else "",
message=msg.clean_content if clean else msg.content,
attachments=attachments
attachments=attachments,
)
@@ -491,21 +530,23 @@ def convdatestring(datestring: str) -> datetime.timedelta:
Returns: datetime.timedelta
A datetime.timedelta object formed from the string provided.
"""
datestring = datestring.strip(' ,')
datestring = datestring.strip(" ,")
datearray = []
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
currentnumber = ''
funcs = {
"d": lambda x: x * 24 * 60 * 60,
"h": lambda x: x * 60 * 60,
"m": lambda x: x * 60,
"s": lambda x: x,
}
currentnumber = ""
for char in datestring:
if char.isdigit():
currentnumber += char
else:
if currentnumber == '':
if currentnumber == "":
continue
datearray.append((int(currentnumber), char))
currentnumber = ''
currentnumber = ""
seconds = 0
if currentnumber:
seconds += int(currentnumber)
@@ -520,6 +561,7 @@ class _rawChannel(discord.abc.Messageable):
Raw messageable class representing an arbitrary channel,
not necessarially seen by the gateway.
"""
def __init__(self, state, id):
self._state = state
self.id = id
@@ -586,8 +628,12 @@ def join_list(string: list[str], nfs=False) -> str:
"""
# TODO: Probably not useful with localisation
if len(string) > 1:
return "{}{} and {}{}".format((", ").join(string[:-1]),
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
return "{}{} and {}{}".format(
(", ").join(string[:-1]),
"," if len(string) > 2 else "",
string[-1],
"" if nfs else ".",
)
else:
return "{}{}".format("".join(string), "" if nfs else ".")
@@ -603,10 +649,8 @@ def jumpto(guildid: int, channeldid: int, messageid: int) -> str:
"""
Build a jump link for a message given its location.
"""
return 'https://discord.com/channels/{}/{}/{}'.format(
guildid,
channeldid,
messageid
return "https://discord.com/channels/{}/{}/{}".format(
guildid, channeldid, messageid
)
@@ -621,7 +665,7 @@ def multiple_replace(string: str, rep_dict: dict[str, str]) -> str:
if rep_dict:
pattern = re.compile(
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
flags=re.DOTALL
flags=re.DOTALL,
)
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
else:
@@ -644,12 +688,16 @@ def parse_ids(idstr: str) -> List[int]:
from meta.errors import UserInputError
# Extract ids from string
splititer = (split.strip('<@!#&>, ') for split in idstr.split(','))
splititer = (split.strip("<@!#&>, ") for split in idstr.split(","))
splits = [split for split in splititer if split]
# Check they are integers
if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None:
raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id})
if (
not_id := next((split for split in splits if not split.isdigit()), None)
) is not None:
raise UserInputError(
"Could not extract an id from `$item`!", {"orig": idstr, "item": not_id}
)
# Cast to integer and return
return list(map(int, splits))
@@ -657,9 +705,7 @@ def parse_ids(idstr: str) -> List[int]:
def error_embed(error, **kwargs) -> discord.Embed:
embed = discord.Embed(
colour=discord.Colour.brand_red(),
description=error,
timestamp=utc_now()
colour=discord.Colour.brand_red(), description=error, timestamp=utc_now()
)
return embed
@@ -676,6 +722,7 @@ class Timezoned:
Provides several useful localised properties.
"""
__slots__ = ()
@property
@@ -727,8 +774,10 @@ def replace_multiple(format_string, mapping):
raise ValueError("Empty mapping passed.")
keys = list(mapping.keys())
pattern = '|'.join(f"({key})" for key in keys)
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
pattern = "|".join(f"({key})" for key in keys)
string = re.sub(
pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string
)
return string
@@ -748,6 +797,7 @@ def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str):
return key
def recurse_map(func, obj, loc=[]):
if isinstance(obj, dict):
for k, v in obj.items():
@@ -763,6 +813,7 @@ def recurse_map(func, obj, loc=[]):
obj = func(loc, obj)
return obj
async def check_dm(user: discord.User | discord.Member) -> bool:
"""
Check whether we can direct message the given user.
@@ -774,7 +825,7 @@ async def check_dm(user: discord.User | discord.Member) -> bool:
(i.e. during a user instigated interaction).
"""
try:
await user.send('')
await user.send("")
except discord.Forbidden:
return False
except discord.HTTPException:
@@ -783,38 +834,36 @@ async def check_dm(user: discord.User | discord.Member) -> bool:
async def command_lengths(tree) -> dict[str, int]:
cmds = tree.get_commands()
payloads = [
await cmd.get_translated_payload(tree.translator)
for cmd in cmds
]
payloads = [await cmd.get_translated_payload(tree.translator) for cmd in cmds]
lens = {}
for command in payloads:
name = command['name']
name = command["name"]
crumbs = {}
cmd_len = lens[name] = _recurse_length(command, crumbs, (name,))
if name == 'configure' or cmd_len > 4000:
if name == "configure" or cmd_len > 4000:
print(f"'{name}' over 4000. Breadcrumb Trail follows:")
lines = []
for loc, val in crumbs.items():
locstr = '.'.join(loc)
locstr = ".".join(loc)
lines.append(f"{locstr}: {val}")
print('\n'.join(lines))
print("\n".join(lines))
print(json.dumps(command, indent=2))
return lens
def _recurse_length(payload, breadcrumbs={}, header=()) -> int:
total = 0
total_header = (*header, '')
total_header = (*header, "")
breadcrumbs[total_header] = 0
if isinstance(payload, dict):
# Read strings that count towards command length
# String length is length of longest localisation, including default.
for key in ('name', 'description', 'value'):
for key in ("name", "description", "value"):
if key in payload:
value = payload[key]
if isinstance(value, str):
values = (value, *payload.get(key + '_localizations', {}).values())
values = (value, *payload.get(key + "_localizations", {}).values())
maxlen = max(map(len, values))
total += maxlen
breadcrumbs[(*header, key)] = maxlen
@@ -824,7 +873,7 @@ def _recurse_length(payload, breadcrumbs={}, header=()) -> int:
total += _recurse_length(value, breadcrumbs, loc)
elif isinstance(payload, list):
for i, item in enumerate(payload):
if isinstance(item, dict) and 'name' in item:
if isinstance(item, dict) and "name" in item:
loc = (*header, f"{i}<{item['name']}>")
else:
loc = (*header, str(i))
@@ -837,14 +886,15 @@ def _recurse_length(payload, breadcrumbs={}, header=()) -> int:
return total
def write_records(records: list[dict[str, Any]], stream: StringIO):
if records:
keys = records[0].keys()
stream.write(','.join(keys))
stream.write('\n')
stream.write(",".join(keys))
stream.write("\n")
for record in records:
stream.write(','.join(map(str, record.values())))
stream.write('\n')
stream.write(",".join(map(str, record.values())))
stream.write("\n")
parse_dur_exps = [
@@ -852,18 +902,9 @@ parse_dur_exps = [
r"(?P<value>\d+)\s*(?:(d)|(day))",
60 * 60 * 24,
),
(
r"(?P<value>\d+)\s*(?:(h)|(hour))",
60 * 60
),
(
r"(?P<value>\d+)\s*(?:(m)|(min))",
60
),
(
r"(?P<value>\d+)\s*(?:(s)|(sec))",
1
)
(r"(?P<value>\d+)\s*(?:(h)|(hour))", 60 * 60),
(r"(?P<value>\d+)\s*(?:(m)|(min))", 60),
(r"(?P<value>\d+)\s*(?:(s)|(sec))", 1),
]
@@ -874,6 +915,131 @@ def parse_duration(string: str) -> Optional[int]:
match = re.search(expr, string, flags=re.IGNORECASE)
if match:
found = True
seconds += int(match.group('value')) * multiplier
seconds += int(match.group("value")) * multiplier
return seconds if found else None
async def pager(ctx, pages, locked=True, start_at=0, add_cancel=False, **kwargs):
"""
Shows the user each page from the provided list `pages` one at a time,
providing reactions to page back and forth between pages.
This is done asynchronously, and returns after displaying the first page.
Parameters
----------
pages: List(Union(str, discord.Embed))
A list of either strings or embeds to display as the pages.
locked: bool
Whether only the `ctx.author` should be able to use the paging reactions.
kwargs: ...
Remaining keyword arguments are transparently passed to the reply context method.
Returns: discord.Message
This is the output message, returned for easy deletion.
"""
cancel_emoji = cross
# Handle broken input
if len(pages) == 0:
raise ValueError("Pager cannot page with no pages!")
# Post first page. Method depends on whether the page is an embed or not.
if isinstance(pages[start_at], discord.Embed):
out_msg = await ctx.reply(embed=pages[start_at], **kwargs)
else:
out_msg = await ctx.reply(pages[start_at], **kwargs)
# Run the paging loop if required
if len(pages) > 1:
task = asyncio.create_task(
_pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs)
)
# ctx.tasks.append(task)
elif add_cancel:
await out_msg.add_reaction(cancel_emoji)
# Return the output message
return out_msg
async def _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs):
"""
Asynchronous initialiser and loop for the `pager` utility above.
"""
# Page number
page = start_at
# Add reactions to the output message
next_emoji = ""
prev_emoji = ""
cancel_emoji = cross
try:
await out_msg.add_reaction(prev_emoji)
if add_cancel:
await out_msg.add_reaction(cancel_emoji)
await out_msg.add_reaction(next_emoji)
except discord.Forbidden:
# We don't have permission to add paging emojis
# Die as gracefully as we can
if ctx.guild:
perms = ctx.channel.permissions_for(ctx.guild.me)
if not perms.add_reactions:
await ctx.error_reply(
"Cannot page results because I do not have the `add_reactions` permission!"
)
elif not perms.read_message_history:
await ctx.error_reply(
"Cannot page results because I do not have the `read_message_history` permission!"
)
else:
await ctx.error_reply(
"Cannot page results due to insufficient permissions!"
)
else:
await ctx.error_reply("Cannot page results!")
return
# Check function to determine whether a reaction is valid
def check(reaction, user):
result = reaction.message.id == out_msg.id
result = result and str(reaction.emoji) in [next_emoji, prev_emoji]
result = result and not (user.id == ctx.bot.user.id)
result = result and not (locked and user != ctx.author)
return result
# Begin loop
while True:
# Wait for a valid reaction, break if we time out
try:
reaction, user = await ctx.bot.wait_for(
"reaction_add", check=check, timeout=300
)
except asyncio.TimeoutError:
break
# Attempt to remove the user's reaction, silently ignore errors
asyncio.ensure_future(out_msg.remove_reaction(reaction.emoji, user))
# Change the page number
page += 1 if reaction.emoji == next_emoji else -1
page %= len(pages)
# Edit the message with the new page
active_page = pages[page]
if isinstance(active_page, discord.Embed):
await out_msg.edit(embed=active_page, **kwargs)
else:
await out_msg.edit(content=active_page, **kwargs)
# Clean up by removing the reactions
try:
await out_msg.clear_reactions()
except discord.Forbidden:
try:
await out_msg.remove_reaction(next_emoji, ctx.client.user)
await out_msg.remove_reaction(prev_emoji, ctx.client.user)
except discord.NotFound:
pass
except discord.NotFound:
pass