rewrite: Reminders system.

This commit is contained in:
2022-11-24 23:12:20 +02:00
parent 0d5e801945
commit dd8609fac0
23 changed files with 1268 additions and 545 deletions

View File

@@ -157,6 +157,16 @@ class BabelCog(LionCog):
async def cog_unload(self):
pass
async def get_user_locale(self, userid):
"""
Fetch the best locale we can guess for this userid.
"""
data = await self.bot.core.data.User.fetch(userid)
if data:
return data.locale or data.locale_hint or SOURCE_LOCALE
else:
return SOURCE_LOCALE
async def bot_check_once(self, ctx: LionContext): # type: ignore # Type checker doesn't understand coro checks
"""
Calculate and inject the current locale before the command begins.

View File

@@ -307,3 +307,10 @@ class RowModel:
return None
else:
return data[0]
async def delete(self: RowT) -> Optional[RowT]:
"""
Delete this Row.
"""
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
return data[0] if data is not None else None

View File

@@ -12,7 +12,7 @@ from meta.context import ctx_bot
from data import Database
from babel.translator import LeoBabel
from babel.translator import LeoBabel, ctx_translator
from constants import DATA_VERSION
@@ -40,6 +40,9 @@ async def main():
logger.critical(error)
raise RuntimeError(error)
translator = LeoBabel()
ctx_translator.set(translator)
async with aiohttp.ClientSession() as session:
async with LionBot(
command_prefix=commands.when_mentioned,
@@ -54,7 +57,7 @@ async def main():
testing_guilds=conf.bot.getintlist('admin_guilds'),
shard_id=sharding.shard_number,
shard_count=sharding.shard_count,
translator=LeoBabel()
translator=translator
) as lionbot:
ctx_bot.set(lionbot)
try:

View File

@@ -124,7 +124,7 @@ class LionContext(Context['LionBot']):
except Exception:
logger.exception(
"Unknown exception in 'error_reply'.",
extra={'action': 'error_reply', 'ctx': self, 'with_ctx': True}
extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
)

View File

@@ -2,7 +2,8 @@ this_package = 'modules'
active = [
'.sysadmin',
'.test'
'.test',
'.reminders'
]

View File

@@ -1,5 +0,0 @@
from .module import module
from . import commands
from . import data
from . import reminder

View File

@@ -1,264 +0,0 @@
import re
import asyncio
import datetime
import discord
from meta import sharding
from utils.lib import parse_dur, parse_ranges, multiselect_regex
from .module import module
from .data import reminders
from .reminder import Reminder
reminder_regex = re.compile(
r"""
(^)?(?P<type> (?: \b in) | (?: every))
\s*(?P<duration> (?: day| hour| (?:\d+\s*(?:(?:d|h|m|s)[a-zA-Z]*)?(?:\s|and)*)+))
(?:(?(1) (?:, | ; | : | \. | to)? | $))
""",
re.IGNORECASE | re.VERBOSE | re.DOTALL
)
reminder_limit = 20
@module.cmd(
name="remindme",
desc="Ask me to remind you about important tasks.",
group="Productivity",
aliases=('reminders', 'reminder'),
flags=('remove', 'clear')
)
async def cmd_remindme(ctx, flags):
"""
Usage``:
{prefix}remindme in <duration> to <task>
{prefix}remindme every <duration> to <task>
{prefix}reminders
{prefix}reminders --clear
{prefix}reminders --remove
Description:
Ask {ctx.client.user.name} to remind you about important tasks.
Examples``:
{prefix}remindme in 2h 20m, Revise chapter 1
{prefix}remindme every hour, Drink water!
{prefix}remindme Anatomy class in 8h 20m
"""
# TODO: (FUTURE) every day at 9:00
if flags['remove']:
# Do removal stuff
rows = reminders.fetch_rows_where(
userid=ctx.author.id,
_extra="ORDER BY remind_at ASC"
)
if not rows:
return await ctx.reply("You have no reminders to remove!")
live = [Reminder(row.reminderid) for row in rows]
if not ctx.args:
lines = []
num_field = len(str(len(live) - 1))
for i, reminder in enumerate(live):
lines.append(
"`[{:{}}]` | {}".format(
i,
num_field,
reminder.formatted
)
)
description = '\n'.join(lines)
description += (
"\n\nPlease select the reminders to remove, or type `c` to cancel.\n"
"(For example, respond with `1, 2, 3` or `1-3`.)"
)
embed = discord.Embed(
description=description,
colour=discord.Colour.orange(),
timestamp=datetime.datetime.utcnow()
).set_author(
name="Reminders for {}".format(ctx.author.display_name),
icon_url=ctx.author.avatar_url
)
out_msg = await ctx.reply(embed=embed)
def check(msg):
valid = msg.channel == ctx.ch and msg.author == ctx.author
valid = valid and (re.search(multiselect_regex, msg.content) or msg.content.lower() == 'c')
return valid
try:
message = await ctx.client.wait_for('message', check=check, timeout=60)
except asyncio.TimeoutError:
await out_msg.delete()
await ctx.error_reply("Session timed out. No reminders were deleted.")
return
try:
await out_msg.delete()
await message.delete()
except discord.HTTPException:
pass
if message.content.lower() == 'c':
return
to_delete = [
live[index].reminderid
for index in parse_ranges(message.content) if index < len(live)
]
else:
to_delete = [
live[index].reminderid
for index in parse_ranges(ctx.args) if index < len(live)
]
if not to_delete:
return await ctx.error_reply("Nothing to delete!")
# Delete the selected reminders
Reminder.delete(*to_delete)
# Ack
await ctx.embed_reply(
"{tick} Reminder{plural} deleted.".format(
tick='',
plural='s' if len(to_delete) > 1 else ''
)
)
elif flags['clear']:
# Do clear stuff
rows = reminders.fetch_rows_where(
userid=ctx.author.id,
)
if not rows:
return await ctx.reply("You have no reminders to remove!")
Reminder.delete(*(row.reminderid for row in rows))
await ctx.embed_reply(
"{tick} Reminders cleared.".format(
tick='',
)
)
elif ctx.args:
# Add a new reminder
content = None
duration = None
repeating = None
# First parse it
match = re.search(reminder_regex, ctx.args)
if match:
repeating = match.group('type').lower() == 'every'
duration_str = match.group('duration').lower()
if duration_str.isdigit():
duration = int(duration_str)
elif duration_str == 'day':
duration = 24 * 60 * 60
elif duration_str == 'hour':
duration = 60 * 60
else:
duration = parse_dur(duration_str)
content = (ctx.args[:match.start()] + ctx.args[match.end():]).strip()
if content.startswith('to '):
content = content[3:].strip()
else:
# Legacy parsing, without requiring "in" at the front
splits = ctx.args.split(maxsplit=1)
if len(splits) == 2 and splits[0].isdigit():
repeating = False
duration = int(splits[0]) * 60
content = splits[1].strip()
# Sanity checking
if not duration or not content:
return await ctx.error_reply(
"Sorry, I didn't understand your reminder!\n"
"See `{prefix}help remindme` for usage and examples.".format(prefix=ctx.best_prefix)
)
# Don't allow rapid repeating reminders
if repeating and duration < 10 * 60:
return await ctx.error_reply(
"You can't have a repeating reminder shorter than `10` minutes!"
)
# Check the user doesn't have too many reminders already
count = reminders.select_one_where(
userid=ctx.author.id,
select_columns=("COUNT(*)",)
)[0]
if count > reminder_limit:
return await ctx.error_reply(
"Sorry, you have reached your maximum of `{}` reminders!".format(reminder_limit)
)
# Create reminder
reminder = Reminder.create(
userid=ctx.author.id,
content=content,
message_link=ctx.msg.jump_url,
interval=duration if repeating else None,
remind_at=datetime.datetime.utcnow() + datetime.timedelta(seconds=duration)
)
# Schedule reminder
if sharding.shard_number == 0:
reminder.schedule()
# Ack
embed = discord.Embed(
title="Reminder Created!",
colour=discord.Colour.orange(),
description="Got it! I will remind you <t:{}:R>.".format(reminder.timestamp),
timestamp=datetime.datetime.utcnow()
)
await ctx.reply(embed=embed)
elif ctx.alias.lower() == 'remindme':
# Show hints about adding reminders
...
else:
# Show formatted list of reminders
rows = reminders.fetch_rows_where(
userid=ctx.author.id,
_extra="ORDER BY remind_at ASC"
)
if not rows:
return await ctx.reply("You have no reminders!")
live = [Reminder(row.reminderid) for row in rows]
lines = []
num_field = len(str(len(live) - 1))
for i, reminder in enumerate(live):
lines.append(
"`[{:{}}]` | {}".format(
i,
num_field,
reminder.formatted
)
)
description = '\n'.join(lines)
embed = discord.Embed(
description=description,
colour=discord.Colour.orange(),
timestamp=datetime.datetime.utcnow()
).set_author(
name="{}'s reminders".format(ctx.author.display_name),
icon_url=ctx.author.avatar_url
).set_footer(
text=(
"Click a reminder twice to jump to the context!\n"
"For more usage and examples see {}help reminders"
).format(ctx.best_prefix)
)
await ctx.reply(embed=embed)

View File

@@ -1,8 +0,0 @@
from data.interfaces import RowTable
reminders = RowTable(
'reminders',
('reminderid', 'userid', 'remind_at', 'content', 'message_link', 'interval', 'created_at', 'title', 'footer'),
'reminderid'
)

View File

@@ -1,4 +0,0 @@
from LionModule import LionModule
module = LionModule("Reminders")

View File

@@ -1,234 +0,0 @@
import asyncio
import datetime
import logging
import discord
from meta import client, sharding
from utils.lib import strfdur
from .data import reminders
from .module import module
class Reminder:
__slots__ = ('reminderid', '_task')
_live_reminders = {} # map reminderid -> Reminder
def __init__(self, reminderid):
self.reminderid = reminderid
self._task = None
@classmethod
def create(cls, **kwargs):
row = reminders.create_row(**kwargs)
return cls(row.reminderid)
@classmethod
def fetch(cls, *reminderids):
"""
Fetch an live reminders associated to the given reminderids.
"""
return [
cls._live_reminders[reminderid]
for reminderid in reminderids
if reminderid in cls._live_reminders
]
@classmethod
def delete(cls, *reminderids):
"""
Cancel and delete the given reminders in an idempotent fashion.
"""
# Cancel the rmeinders
for reminderid in reminderids:
if reminderid in cls._live_reminders:
cls._live_reminders[reminderid].cancel()
# Remove from data
if reminderids:
return reminders.delete_where(reminderid=reminderids)
else:
return []
@property
def data(self):
return reminders.fetch(self.reminderid)
@property
def timestamp(self):
"""
True unix timestamp for (next) reminder time.
"""
return int(self.data.remind_at.replace(tzinfo=datetime.timezone.utc).timestamp())
@property
def user(self):
"""
The discord.User that owns this reminder, if we can find them.
"""
return client.get_user(self.data.userid)
@property
def formatted(self):
"""
Single-line string format for the reminder, intended for an embed.
"""
content = self.data.content
trunc_content = content[:50] + '...' * (len(content) > 50)
if self.data.interval:
interval = self.data.interval
if interval == 24 * 60 * 60:
interval_str = "day"
elif interval == 60 * 60:
interval_str = "hour"
elif interval % (24 * 60 * 60) == 0:
interval_str = "`{}` days".format(interval // (24 * 60 * 60))
elif interval % (60 * 60) == 0:
interval_str = "`{}` hours".format(interval // (60 * 60))
else:
interval_str = "`{}`".format(strfdur(interval))
repeat = "(Every {})".format(interval_str)
else:
repeat = ""
return "<t:{timestamp}:R>, [{content}]({jump_link}) {repeat}".format(
jump_link=self.data.message_link,
content=trunc_content,
timestamp=self.timestamp,
repeat=repeat
)
def cancel(self):
"""
Cancel the live reminder waiting task, if it exists.
Does not remove the reminder from data. Use `Reminder.delete` for this.
"""
if self._task and not self._task.done():
self._task.cancel()
self._live_reminders.pop(self.reminderid, None)
def schedule(self):
"""
Schedule this reminder to be executed.
"""
asyncio.create_task(self._schedule())
self._live_reminders[self.reminderid] = self
async def _schedule(self):
"""
Execute this reminder after a sleep.
Accepts cancellation by aborting the scheduled execute.
"""
# Calculate time left
remaining = (self.data.remind_at - datetime.datetime.utcnow()).total_seconds()
# Create the waiting task and wait for it, accepting cancellation
self._task = asyncio.create_task(asyncio.sleep(remaining))
try:
await self._task
except asyncio.CancelledError:
return
await self._execute()
async def _execute(self):
"""
Execute the reminder.
"""
if not self.data:
# Reminder deleted elsewhere
return
if self.data.userid in client.user_blacklist():
self.delete(self.reminderid)
return
userid = self.data.userid
# Build the message embed
embed = discord.Embed(
title="You asked me to remind you!" if self.data.title is None else self.data.title,
colour=discord.Colour.orange(),
description=self.data.content,
timestamp=datetime.datetime.utcnow()
)
if self.data.message_link:
embed.add_field(name="Context?", value="[Click here]({})".format(self.data.message_link))
if self.data.interval:
embed.add_field(
name="Next reminder",
value="<t:{}:R>".format(
self.timestamp + self.data.interval
)
)
if self.data.footer:
embed.set_footer(text=self.data.footer)
# Update the reminder data, and reschedule if required
if self.data.interval:
next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval)
rows = reminders.update_where(
{'remind_at': next_time},
reminderid=self.reminderid
)
self.schedule()
else:
rows = self.delete(self.reminderid)
if not rows:
# Reminder deleted elsewhere
return
# Send the message, if possible
if not (user := client.get_user(userid)):
try:
user = await client.fetch_user(userid)
except discord.HTTPException:
pass
if user:
try:
await user.send(embed=embed)
except discord.HTTPException:
# Nothing we can really do here. Maybe tell the user about their reminder next time?
pass
async def reminder_poll(client):
"""
One client/shard must continually poll for new or deleted reminders.
"""
# TODO: Clean this up with database signals or IPC
while True:
await asyncio.sleep(60)
client.log(
"Running new reminder poll.",
context="REMINDERS",
level=logging.DEBUG
)
rids = {row.reminderid for row in reminders.fetch_rows_where()}
to_delete = (rid for rid in Reminder._live_reminders if rid not in rids)
Reminder.delete(*to_delete)
[Reminder(rid).schedule() for rid in rids if rid not in Reminder._live_reminders]
@module.launch_task
async def schedule_reminders(client):
if sharding.shard_number == 0:
rows = reminders.fetch_rows_where()
for row in rows:
Reminder(row.reminderid).schedule()
client.log(
"Scheduled {} reminders.".format(len(rows)),
context="LAUNCH_REMINDERS"
)
if sharding.sharded:
asyncio.create_task(reminder_poll(client))

View File

@@ -0,0 +1,13 @@
from babel.translator import LocalBabel
babel = LocalBabel('reminders')
import logging
logger = logging.getLogger(__name__)
logger.debug("Loaded reminders")
from .cog import Reminders
async def setup(bot):
await bot.add_cog(Reminders(bot))

View File

@@ -0,0 +1,839 @@
"""
Max 25 reminders (propagating Discord restriction)
/reminders show
-- Widget which displays and allows removing reminders.
-- Points to /remindme for setting
/reminders clear
/reminders remove <reminder: acmpl>
-- Can we autocomplete an integer field?
/remindme at <time: time> <repeat every: acmpl str> <reminder: str>
/remindme in <days: int> <hours: int> <minutes: int> <repeat every: acmpl str> <reminder: str>
"""
from typing import Optional
import datetime as dt
from cachetools import TTLCache
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from discord.app_commands import Transform
from discord.ui.select import select, SelectOption
from data import RowModel, Registry
from data.queries import ORDER
from data.columns import Integer, String, Timestamp, Bool
from meta import LionBot, LionCog, LionContext
from meta.app import shard_talk, appname_from_shard
from meta.logger import log_wrap, logging_context
from babel import ctx_translator, ctx_locale
from utils.lib import parse_duration, utc_now, strfdur, error_embed
from utils.monitor import TaskMonitor
from utils.transformers import DurationTransformer
from utils.ui import LeoUI, AButton, AsComponents
from . import babel, logger
_, _p, _np = babel._, babel._p, babel._np
class ReminderData(Registry, name='reminders'):
class Reminder(RowModel):
"""
Model representing a single reminder.
Since reminders are likely to change across shards,
does not use an explicit reference cache.
Schema
------
CREATE TABLE reminders(
reminderid SERIAL PRIMARY KEY,
userid BIGINT NOT NULL REFERENCES user_config(userid) ON DELETE CASCADE,
remind_at TIMESTAMP NOT NULL,
content TEXT NOT NULL,
message_link TEXT,
interval INTEGER,
created_at TIMESTAMP DEFAULT (now() at time zone 'utc'),
title TEXT,
footer TEXT
);
CREATE INDEX reminder_users ON reminders (userid);
"""
_tablename_ = 'reminders'
reminderid = Integer(primary=True)
userid = Integer() # User which created the reminder
remind_at = Timestamp() # Time when the reminder should be executed
content = String() # Content the user gave us to remind them
message_link = String() # Link to original confirmation message, for context
interval = Integer() # Repeat interval, if applicable
created_at = Timestamp() # Time when this reminder was originally created
title = String() # Title of the final reminder embed, only set in automated reminders
footer = String() # Footer of the final reminder embed, only set in automated reminders
failed = Bool() # Whether the reminder was already attempted and failed
@property
def timestamp(self) -> int:
"""
Time when this reminder should be executed (next) as an integer timestamp.
"""
return int(self.remind_at.timestamp())
@property
def embed(self) -> discord.Embed:
t = ctx_translator.get().t
embed = discord.Embed(
title=self.title or t(_p('reminder|embed', "You asked me to remind you!")),
colour=discord.Colour.orange(),
description=self.content,
timestamp=self.remind_at
)
if self.message_link:
embed.add_field(
name=t(_p('reminder|embed', "Context?")),
value="[{click}]({link})".format(
click=t(_p('reminder|embed', "Click Here")),
link=self.message_link
)
)
if self.interval:
embed.add_field(
name=t(_p('reminder|embed', "Next reminder")),
value=f"<t:{self.timestamp + self.interval}:R>"
)
if self.footer:
embed.set_footer(text=self.footer)
return embed
@property
def formatted(self):
"""
Single-line string format for the reminder, intended for an embed.
"""
t = ctx_translator.get().t
content = self.content
trunc_content = content[:50] + '...' * (len(content) > 50)
if interval := self.interval:
if not interval % (24 * 60 * 60):
# Exact day case
days = interval // (24 * 60 * 60)
repeat = t(_np(
'reminder|formatted|interval',
"Every day",
"Every `{days}` days",
days
)).format(days=days)
elif not interval % (60 * 60):
# Exact hour case
hours = interval // (60 * 60)
repeat = t(_np(
'reminder|formatted|interval',
"Every hour",
"Every `{hours}` hours",
hours
)).format(hours=hours)
else:
# Inexact interval, e.g 10m or 1h 10m.
# Use short duration format
repeat = t(_p(
'reminder|formatted|interval',
"Every `{duration}`"
)).format(duration=strfdur(interval))
repeat = f"({repeat})"
else:
repeat = ""
return "<t:{timestamp}:R>, [{content}]({jump_link}) {repeat}".format(
jump_link=self.message_link,
content=trunc_content,
timestamp=self.timestamp,
repeat=repeat
)
class ReminderMonitor(TaskMonitor[int]):
...
class Reminders(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(ReminderData())
# Whether this process should handle reminder execution
self.executor = (self.bot.shard_id == 0)
self.executor_name = appname_from_shard(0)
if self.executor:
self.monitor: Optional[ReminderMonitor] = ReminderMonitor(executor=self.execute_reminder)
else:
self.monitor = None
self.talk_reload = shard_talk.register_route('reload_reminders')(self.reload_reminders)
self.talk_schedule = shard_talk.register_route('schedule_reminders')(self.schedule_reminders)
self.talk_cancel = shard_talk.register_route('cancel_reminders')(self.cancel_reminders)
# Short term userid -> list[Reminder] cache, mainly for autocomplete
self._user_reminder_cache: TTLCache[int, list[ReminderData.Reminder]] = TTLCache(1000, ttl=60)
self._active_reminderlists: dict[int, ReminderListUI] = {}
async def cog_load(self):
await self.data.init()
if self.executor:
# Attach and populate the reminder monitor
self.monitor = ReminderMonitor(executor=self.execute_reminder)
await self.reload_reminders()
if self.bot.is_ready:
self.monitor.start()
@LionCog.listener()
async def on_ready(self):
if self.executor and not self.monitor._monitor_task:
# Start firing reminders
self.monitor.start()
async def get_reminders_for(self, userid: int):
"""
Retrieve a list of reminders for the given userid, using the cache.
"""
reminders = self._user_reminder_cache.get(userid, None)
if reminders is None:
reminders = await self.data.Reminder.fetch_where(
userid=userid
).order_by(self.data.Reminder.created_at, ORDER.ASC)
self._user_reminder_cache[userid] = reminders
return reminders
async def dispatch_update_for(self, userid: int):
"""
Announce that the given user's reminders have changed.
This triggers update of the cog reminder cache, and a reload of any active reminder list UIs.
"""
self._user_reminder_cache.pop(userid, None)
if userid in self._active_reminderlists:
await self._active_reminderlists[userid].refresh()
async def reload_reminders(self):
"""
Refresh reminder data and reminder tasks.
"""
if not self.executor:
raise ValueError("Only the executor shard can reload reminders!")
# Load all reminder tasks
reminders = await self.data.Reminder.fetch_where(failed=None)
tasks = [(r.reminderid, r.timestamp) for r in reminders]
self.monitor.set_tasks(*tasks)
logger.info(
f"Reloaded ReminderMonitor with {len(tasks)} active reminders."
)
async def cancel_reminders(self, *reminderids):
"""
ShardTalk Route.
Cancel the given reminderids in the ReminderMonitor.
"""
if not self.executor:
raise ValueError("Only the executor shard can cancel scheduled reminders!")
# If we are the executor shard, we know the monitor is loaded
# If reminders have not yet been loaded, cancelling is a no-op
# Since reminder loading is synchronous, we cannot get in a race state with loading
self.monitor.cancel_tasks(*reminderids)
logger.debug(
f"Cancelled reminders: {reminderids}",
)
async def schedule_reminders(self, *reminderids):
"""
ShardTalk Route.
Schedule the given new reminderids in the ReminderMonitor.
"""
if not self.executor:
raise ValueError("Only the executor shard can schedule reminders!")
# We refetch here to make sure the reminders actually exist
reminders = await self.data.Reminder.fetch_where(reminderid=reminderids)
self.monitor.schedule_tasks(*((reminder.reminderid, reminder.timestamp) for reminder in reminders))
logger.debug(
f"Scheduled new reminders: {tuple(reminder.reminderid for reminder in reminders)}",
)
async def execute_reminder(self, reminderid):
"""
Send the reminder with the given reminderid.
This should in general only be executed from the executor shard,
through a ReminderMonitor instance.
"""
with logging_context(action='Send Reminder', context=f"rid: {reminderid}"):
reminder = await self.data.Reminder.fetch(reminderid)
if reminder is None:
logger.warning(
f"Attempted to execute a reminder <rid: {reminderid}> that no longer exists!"
)
return
try:
# Try and find the user
userid = reminder.userid
if not (user := self.bot.get_user(userid)):
user = await self.bot.fetch_user(userid)
# Set the locale variables
locale = await self.bot.get_cog('BabelCog').get_user_locale(userid)
ctx_locale.set(locale)
ctx_translator.set(self.bot.translator)
# Build the embed
embed = reminder.embed
# Attempt to send to user
# TODO: Consider adding a View to this, for cancelling a repeated reminder or showing reminders
await user.send(embed=embed)
# Update the data as required
if reminder.interval:
now = utc_now()
# Use original reminder time to calculate repeat, avoiding drift
next_time = reminder.remind_at + dt.timedelta(seconds=reminder.interval)
# Skip any expired repeats, to avoid spamming requests after downtime
# TODO: Is this actually dst safe?
while next_time.timestamp() <= now.timestamp():
next_time + dt.timedelta(seconds=reminder.interval)
await reminder.update(remind_at=next_time)
self.monitor.schedule_task(reminder.reminderid, reminder.timestamp)
logger.debug(
f"Executed reminder <rid: {reminder.reminderid}> and scheduled repeat at {next_time}."
)
else:
await reminder.delete()
logger.debug(
f"Executed reminder <rid: {reminder.reminderid}>."
)
except discord.HTTPException:
await reminder.update(failed=True)
logger.debug(
f"Reminder <rid: {reminder.reminderid}> could not be sent.",
exc_info=True
)
except Exception:
await reminder.update(failed=True)
logger.exception(
f"Reminder <rid: {reminder.reminderid}> failed for an unknown reason!"
)
finally:
# Dispatch for analytics
self.bot.dispatch('reminder_sent', reminder)
@cmds.hybrid_group(
name=_p('cmd:reminders', "reminders")
)
async def reminders_group(self, ctx: LionContext):
pass
@reminders_group.command(
# No help string
name=_p('cmd:reminders_show', "show"),
description=_p(
'cmd:reminders_show|desc',
"Display your current reminders."
)
)
async def cmd_reminders_show(self, ctx: LionContext):
# No help string
"""
Display the reminder widget for this user.
"""
t = self.bot.translator.t
if not ctx.interaction:
return
if ctx.author.id in self._active_reminderlists:
await self._active_reminderlists[ctx.author.id].close(
msg=t(_p(
'cmd:reminders_show|close_elsewhere',
"Closing since the list was opened elsewhere."
))
)
ui = ReminderListUI(self.bot, ctx.author)
try:
self._active_reminderlists[ctx.author.id] = ui
await ui.run(ctx.interaction)
await ui.wait()
finally:
self._active_reminderlists.pop(ctx.author.id, None)
@reminders_group.command(
name=_p('cmd:reminders_clear', "clear"),
description=_p(
'cmd:reminders_clear|desc',
"Clear your reminder list."
)
)
async def cmd_reminders_clear(self, ctx: LionContext):
# No help string
"""
Confirm and then clear all the reminders for this user.
"""
if not ctx.interaction:
return
t = self.bot.translator.t
reminders = await self.data.Reminder.fetch_where(userid=ctx.author.id)
if not reminders:
await ctx.reply(
embed=discord.Embed(
description=t(_p(
'cmd:reminders_clear|error:no_reminders',
"You have no reminders to clear!"
)),
colour=discord.Colour.brand_red()
),
ephemeral=True
)
return
embed = discord.Embed(
title=t(_p('cmd:reminders_clear|confirm|title', "Are You Sure?")),
description=t(_np(
'cmd:reminders_clear|confirm|desc',
"Are you sure you want to delete your `{count}` reminder?",
"Are you sure you want to clear your `{count}` reminders?",
len(reminders)
)).format(count=len(reminders))
)
@AButton(label=t(_p('cmd:reminders_clear|confirm|button:yes', "Yes, clear my reminders")))
async def confirm(view, interaction, press):
await interaction.response.defer()
reminders = await self.data.Reminder.table.delete_where(userid=ctx.author.id)
await self.talk_cancel(*(r['reminderid'] for r in reminders)).send(self.executor_name, wait_for_reply=False)
await ctx.interaction.edit_original_response(
embed=discord.Embed(
description=t(_p(
'cmd:reminders_clear|success|desc',
"Your reminders have been cleared!"
)),
colour=discord.Colour.brand_green()
),
view=None
)
await view.close()
await self.dispatch_update_for(ctx.author.id)
@AButton(label=t(_p('cmd:reminders_clear|confirm|button:cancel', "Cancel")))
async def deny(view, interaction, press):
await interaction.response.defer()
await ctx.interaction.delete_original_response()
await view.close()
components = AsComponents(confirm, deny)
await ctx.interaction.response.send_message(embed=embed, view=components, ephemeral=True)
@reminders_group.command(
name=_p('cmd:reminders_cancel', "cancel"),
description=_p(
'cmd:reminders_cancel|desc',
"Cancel a single reminder. Use the menu in \"reminder show\" to cancel multiple reminders."
)
)
@appcmds.rename(
reminder=_p('cmd:reminders_cancel|param:reminder', 'reminder')
)
@appcmds.describe(
reminder=_p(
'cmd:reminders_cancel|param:reminder|desc',
"Start typing, then select a reminder to cancel."
)
)
async def cmd_reminders_cancel(self, ctx: LionContext, reminder: str):
# No help string
"""
Cancel a previously scheduled reminder.
Autocomplete lets the user select their reminder by number or truncated content.
Need to handle the case where reminderid is that truncated content.
"""
t = self.bot.translator.t
reminders = await self.get_reminders_for(ctx.author.id)
# Guard against no reminders
if not reminders:
await ctx.error_reply(
t(_p(
'cmd:reminders_cancel|error:no_reminders',
"There are no reminders to cancel!"
))
)
return
# Now attempt to parse reminder input
if reminder.startswith('rid:') and reminder[4:].isdigit():
# Assume reminderid, probably selected through autocomplete
rid = int(reminder[4:])
rem = next((rem for rem in reminders if rem.reminderid == rid), None)
elif reminder.strip('[] ').isdigit():
# Assume user reminder index
# Not strictly threadsafe, but should be okay 90% of the time
lid = int(reminder)
rem = next((rem for i, rem in enumerate(reminders, start=1) if i == lid), None)
else:
# Assume partial string from a reminder
partial = reminder
rem = next((rem for rem in reminders if partial in rem.content), None)
if rem is None:
await ctx.error_reply(
t(_p(
'cmd:reminders_cancel|error:no_match',
"I am not sure which reminder you want to cancel. "
"Please try again, selecting a reminder from the list of choices."
))
)
return
# At this point we have a valid reminder to cancel
await rem.delete()
await self.talk_cancel(rem.reminderid).send(self.executor_name, wait_for_reply=False)
await ctx.reply(
embed=discord.Embed(
description=t(_p(
'cmd:reminders_cancel|embed:success|desc',
"Reminder successfully cancelled."
)),
colour=discord.Colour.brand_green()
),
ephemeral=True
)
await self.dispatch_update_for(ctx.author.id)
@cmd_reminders_cancel.autocomplete('reminder')
async def cmd_reminders_cancel_acmpl_reminderid(self, interaction: discord.Interaction, partial: str):
t = self.bot.translator.t
reminders = await self.get_reminders_for(interaction.user.id)
if not reminders:
# Nothing to cancel case
name = t(_p(
'cmd:reminders_cancel|acmpl:reminder|error:no_reminders',
"There are no reminders to cancel!"
))
value = 'None'
choices = [
appcmds.Choice(name=name, value=value)
]
else:
# Build list of reminder strings
strings = []
for pos, reminder in enumerate(reminders, start=1):
strings.append(
(f"[{pos}] {reminder.content}", reminder)
)
# Extract matches
matches = [string for string in strings if partial.lower() in string[0].lower()]
if matches:
# Build list of valid choices
choices = [
appcmds.Choice(
name=string[0],
value=f"rid:{string[1].reminderid}"
)
for string in matches
]
else:
choices = [
appcmds.Choice(
name=t(_p(
'cmd:reminders_cancel|acmpl:reminder|error:no_matches',
"You do not have any reminders matching \"{partial}\""
)).format(partial=partial),
value=partial
)
]
return choices
@cmds.hybrid_group(
name=_p('cmd:remindme', "remindme")
)
async def remindme_group(self, ctx: LionContext):
# Base command group for scheduling reminders.
pass
# TODO: Waiting until we have timezone data and user time methods.
# @remindme_group.command(
# name=_p('cmd:remindme_at', "at"),
# description=_p(
# 'cmd:remindme_at|desc',
# "Schedule a reminder for a particular time."
# )
# )
# @appcmds.rename(
# time=_p('cmd:remindme_at|param:time', "time"),
# reminder=_p('cmd:remindme_at|param:reminder', "reminder"),
# every=_p('cmd:remindme_at|param:every', "every"),
# )
# @appcmds.describe(
# time=_p('cmd:remindme_at|param:time|desc', "When you want to be reminded.."),
# reminder=_p('cmd:remindme_at|param:reminder|desc', "What should the reminder be?"),
# every=_p('cmd:remindme_at|param:every|desc', "How often to repeat this reminder.")
# )
# async def cmd_remindme_at(
# self,
# ctx: LionContext,
# time: str,
# reminder: str,
# every: Optional[Transform[int, DurationTransformer(60)]] = None
# ):
# ...
@remindme_group.command(
name=_p('cmd:remindme_in', "in"),
description=_p(
'cmd:remindme_in|desc',
"Schedule a reminder for a given amount of time in the future."
)
)
@appcmds.rename(
time=_p('cmd:remindme_in|param:time', "time"),
reminder=_p('cmd:remindme_in|param:reminder', "reminder"),
every=_p('cmd:remindme_in|param:every', "every"),
)
@appcmds.describe(
time=_p('cmd:remindme_in|param:time|desc', "How far into the future to set the reminder (e.g. 1 day 10h 5m)."),
reminder=_p('cmd:remindme_in|param:reminder|desc', "What should the reminder be?"),
every=_p('cmd:remindme_in|param:every|desc', "How often to repeat this reminder. (e.g. 1 day, or 2h)")
)
async def cmd_remindme_in(
self,
ctx: LionContext,
time: Transform[int, DurationTransformer(60)],
reminder: str, # TODO: Maximum length 1000?
every: Optional[Transform[int, DurationTransformer(60)]] = None
):
t = self.bot.translator.t
reminders = await self.data.Reminder.fetch_where(userid=ctx.author.id)
# Guard against too many reminders
if len(reminders) > 25:
await ctx.error_reply(
embed=error_embed(
t(_p(
'cmd_remindme_in|error:too_many|desc',
"Sorry, you have reached the maximum of `25` reminders!"
)),
title=t(_p(
'cmd_remindme_in|error:too_many|title',
"Could not create reminder!"
))
),
ephemeral=True
)
return
# Guard against too frequent reminders
if every is not None and every < 600:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd_remindme_in|error:too_fast|desc',
"You cannot set a repeating reminder with a period less than 10 minutes."
)),
title=t(_p(
'cmd_remindme_in|error:too_fast|title',
"Could not create reminder!"
))
),
ephemeral=True
)
return
# Everything seems to be in order
# Create the reminder
now = utc_now()
rem = await self.data.Reminder.create(
userid=ctx.author.id,
remind_at=now + dt.timedelta(seconds=time),
content=reminder,
message_link=ctx.message.jump_url,
interval=every,
created_at=now
)
# Reminder created, request scheduling from executor shard
await self.talk_schedule(rem.reminderid).send(self.executor_name, wait_for_reply=False)
# TODO Add repeat to description
embed = discord.Embed(
title=t(_p(
'cmd:remindme_in|success|title',
"Reminder Set {timestamp}"
)).format(timestamp=f"<t:{rem.timestamp}:R>"),
description=f"> {rem.content}"
)
await ctx.reply(
embed=embed,
ephemeral=True
)
await self.dispatch_update_for(ctx.author.id)
class ReminderListUI(LeoUI):
def __init__(self, bot: LionBot, user: discord.User, **kwargs):
super().__init__(**kwargs)
self.bot = bot
self.user = user
cog = bot.get_cog('Reminders')
if cog is None:
raise ValueError("Cannot create a ReminderUI without the Reminder cog!")
self.cog: Reminders = cog
self.userid = user.id
# Original interaction which sent the UI message
# Since this is an ephemeral UI, we need this to update and delete
self._interaction: Optional[discord.Interaction] = None
self._reminders = []
async def cleanup(self):
# Cleanup after an ephemeral UI
# Just close if possible
if self._interaction and not self._interaction.is_expired():
try:
await self._interaction.delete_original_response()
except discord.HTTPException:
pass
@select()
async def select_remove(self, interaction: discord.Interaction, selection):
"""
Select a number of reminders to delete.
"""
await interaction.response.defer()
# Hopefully this is a list of reminderids
values = selection.values
# Delete from data
await self.cog.data.Reminder.table.delete_where(reminderid=values)
# Send cancellation
await self.cog.talk_cancel(*values).send(self.cog.executor_name, wait_for_reply=False)
self.cog._user_reminder_cache.pop(self.userid, None)
await self.refresh()
async def refresh_select_remove(self):
"""
Refresh the select remove component from current state.
"""
t = self.bot.translator.t
self.select_remove.placeholder = t(_p(
'ui:reminderlist|select:remove|placeholder',
"Select to cancel."
))
self.select_remove.options = [
SelectOption(
label=f"[{i}] {reminder.content[:50] + '...' * (len(reminder.content) > 50)}",
value=reminder.reminderid,
emoji=self.bot.config.emojis.getemoji('clock')
)
for i, reminder in enumerate(self._reminders, start=1)
]
self.select_remove.min_values = 1
self.select_remove.max_values = len(self._reminders)
async def refresh_reminders(self):
self._reminders = await self.cog.get_reminders_for(self.userid)
async def refresh(self):
"""
Refresh the UI message and components.
"""
if not self._interaction:
raise ValueError("Cannot refresh ephemeral UI without an origin interaction!")
await self.refresh_reminders()
await self.refresh_select_remove()
embed = await self.build_embed()
if self._reminders:
self.set_layout((self.select_remove,))
else:
self.set_layout()
try:
if not self._interaction.response.is_done():
# Fresh message
await self._interaction.response.send_message(embed=embed, view=self, ephemeral=True)
else:
# Update existing message
await self._interaction.edit_original_response(embed=embed, view=self)
except discord.HTTPException:
await self.close()
async def run(self, interaction: discord.Interaction):
"""
Run the UI responding to the given interaction.
"""
self._interaction = interaction
await self.refresh()
async def build_embed(self):
"""
Build the reminder list embed.
"""
t = self.bot.translator.t
reminders = self._reminders
if reminders:
lines = []
num_len = len(str(len(reminders)))
for i, reminder in enumerate(reminders):
lines.append(
"`[{:<{}}]` | {}".format(
i+1,
num_len,
reminder.formatted
)
)
description = '\n'.join(lines)
embed = discord.Embed(
description=description,
colour=discord.Colour.orange(),
timestamp=utc_now()
).set_author(
name=t(_p(
'ui:reminderlist|embed:list|author',
"{name}'s reminders"
)).format(name=self.user.display_name),
icon_url=self.user.avatar
).set_footer(
text=t(_p(
'ui:reminderlist|embed:list|footer',
"Click a reminder twice to jump to the context!"
))
)
else:
embed = discord.Embed(
description=t(_p(
'ui:reminderlist|embed:no_reminders|desc',
"You have no reminders to display!\n"
"Use {remindme} to create a new reminder."
)).format(
remindme=self.bot.core.cmd_name_cache['remindme'].mention,
)
)
return embed

View File

@@ -29,11 +29,15 @@ from meta.LionBot import LionBot
from utils.ui import FastModal, input
from babel.translator import LocalBabel
from wards import sys_admin
logger = logging.getLogger(__name__)
_, _n, _p, _np = LocalBabel('exec').methods
class ExecModal(FastModal, title="Execute"):
code: TextInput = TextInput(
@@ -65,7 +69,7 @@ class ExecUI(View):
"""Only allow the original author to use this View"""
if interaction.user.id != self.ctx.author.id:
await interaction.response.send_message(
"You cannot use this interface!",
("You cannot use this interface!"),
ephemeral=True
)
return False
@@ -240,6 +244,7 @@ class Exec(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.t = bot.translator.t
self.talk_async = shard_talk.register_route('exec')(_async)
@@ -247,8 +252,8 @@ class Exec(LionCog):
return await sys_admin(ctx)
@commands.hybrid_command(
name='async',
description="Execute arbitrary code with Exec"
name=_('async'),
description=_("Execute arbitrary code with Exec")
)
@appcmd.describe(
string="Code to execute."
@@ -258,23 +263,23 @@ class Exec(LionCog):
await ExecUI(ctx, string, ExecStyle.EXEC).run()
@commands.hybrid_command(
name='eval',
description='Execute arbitrary code with Eval'
name=_p('command', 'eval'),
description=_p('command:eval', 'Execute arbitrary code with Eval')
)
@appcmd.describe(
string="Code to evaluate."
string=_p('command:eval|param:string', "Code to evaluate.")
)
@appcmd.guilds(*guild_ids)
async def eval_cmd(self, ctx: LionContext, *, string: str):
await ExecUI(ctx, string, ExecStyle.EVAL).run()
@commands.hybrid_command(
name='asyncall',
description="Execute arbitrary code on all shards."
name=_p('command', 'asyncall'),
description=_p('command:asyncall|desc', "Execute arbitrary code on all shards.")
)
@appcmd.describe(
string="Cross-shard code to execute. Cannot reference ctx!",
target="Target shard app name, see autocomplete for options."
string=_p("command:asyncall|param:string", "Cross-shard code to execute. Cannot reference ctx!"),
target=_p("command:asyncall|param:target", "Target shard app name, see autocomplete for options.")
)
@appcmd.guilds(*guild_ids)
async def asyncall_cmd(self, ctx: LionContext, string: Optional[str] = None, target: Optional[str] = None):
@@ -329,11 +334,11 @@ class Exec(LionCog):
return results
@commands.hybrid_command(
name='reload',
description="Reload a given LionBot extension. Launches an ExecUI."
name=_('reload'),
description=_("Reload a given LionBot extension. Launches an ExecUI.")
)
@appcmd.describe(
extension="Name of the extesion to reload. See autocomplete for options."
extension=_("Name of the extesion to reload. See autocomplete for options.")
)
@appcmd.guilds(*guild_ids)
async def reload_cmd(self, ctx: LionContext, extension: str):
@@ -365,8 +370,8 @@ class Exec(LionCog):
return results
@commands.hybrid_command(
name='shutdown',
description="Shutdown (or restart) the client."
name=_('shutdown'),
description=_("Shutdown (or restart) the client.")
)
@appcmd.guilds(*guild_ids)
async def shutdown_cmd(self, ctx: LionContext):

View File

@@ -107,6 +107,8 @@ class PresenceSettings(SettingGroup):
_title = "Presence Settings ({bot.core.cmd_name_cache[presence].mention})"
class PresenceStatus(ModelData, EnumSetting[str, AppStatus]):
setting_id = 'presence_status'
display_name = 'online_status'
desc = "Bot status indicator"
long_desc = "Whether the bot account displays as online, idle, dnd, or offline."
@@ -122,6 +124,8 @@ class PresenceSettings(SettingGroup):
_default = AppStatus.online
class PresenceType(ModelData, EnumSetting[str, AppActivityType]):
setting_id = 'presence_type'
display_name = 'activity_type'
desc = "Type of presence activity"
long_desc = "Whether the bot activity is shown as 'Listening', 'Playing', or 'Watching'."
@@ -137,6 +141,8 @@ class PresenceSettings(SettingGroup):
_default = AppActivityType.watching
class PresenceName(ModelData, StringSetting[str]):
setting_id = 'presence_name'
display_name = 'activity_name'
desc = "Name of the presence activity"
long_desc = "Presence activity name."

View File

@@ -9,7 +9,7 @@ from discord.ui.button import button, Button, ButtonStyle
from meta.context import context
from meta.errors import UserInputError
from utils.lib import strfdur, parse_dur
from utils.lib import strfdur, parse_duration
from babel import ctx_translator
from .base import ParentID

View File

@@ -0,0 +1,3 @@
from babel.translator import LocalBabel
util_babel = LocalBabel('utils')

View File

@@ -8,7 +8,11 @@ import discord
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
from discord.ui import View
# from cmdClient.lib import SafeCancellation
from babel.translator import ctx_translator
from . import util_babel
_, _p = util_babel._, util_babel._p
multiselect_regex = re.compile(
@@ -320,7 +324,7 @@ def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -
return "".join(reply_msg)
def parse_dur(time_str: str) -> int:
def _parse_dur(time_str: str) -> int:
"""
Parses a user provided time duration string into a timedelta object.
@@ -642,7 +646,7 @@ def parse_ids(idstr: str) -> List[int]:
def error_embed(error, **kwargs) -> discord.Embed:
embed = discord.Embed(
colour=discord.Colour.red(),
colour=discord.Colour.brand_red(),
description=error,
timestamp=utc_now()
)
@@ -653,3 +657,52 @@ class DotDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
parse_dur_exps = [
(
_p(
'util:parse_dur|regex:day',
r"(?P<value>\d+)\s*(?:(d)|(day))"
),
60 * 60 * 24
),
(
_p(
'util:parse_dur|regex:hour',
r"(?P<value>\d+)\s*(?:(h)|(hour))"
),
60 * 60
),
(
_p(
'util:parse_dur|regex:minute',
r"(?P<value>\d+)\s*(?:(m)|(min))"
),
60
),
(
_p(
'util:parse_dur|regex:second',
r"(?P<value>\d+)\s*(?:(s)|(sec))"
),
1
)
]
def parse_duration(string: str) -> Optional[int]:
translator = ctx_translator.get()
if translator is None:
raise ValueError("Cannot parse duration without a translator.")
t = translator.t
seconds = 0
found = False
for expr, multiplier in parse_dur_exps:
match = re.search(t(expr), string, flags=re.IGNORECASE)
if match:
found = True
seconds += int(match.group('value')) * multiplier
return seconds if found else None

178
bot/utils/monitor.py Normal file
View File

@@ -0,0 +1,178 @@
import asyncio
import bisect
import logging
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
from .lib import utc_now
from .ratelimits import Bucket
logger = logging.getLogger(__name__)
Taskid = TypeVar('Taskid')
class TaskMonitor(Generic[Taskid]):
"""
Base class for a task monitor.
Stores tasks as a time-sorted list of taskids.
Subclasses may override `run_task` to implement an executor.
Adding or removing a single task has O(n) performance.
To bulk update tasks, instead use `schedule_tasks`.
Each taskid must be unique and hashable.
"""
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
# Ratelimit bucket to enforce maximum execution rate
self._bucket = bucket
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
self._wakeup: asyncio.Event = asyncio.Event()
self._monitor_task: Optional[self.Task] = None
# Task data
self._tasklist: list[Taskid] = []
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
# Running map ensures we keep a reference to the running task
# And allows simpler external cancellation if required
self._running: dict[Taskid, asyncio.Future] = {}
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Similar to `schedule_tasks`, but wipe and reset the tasklist.
"""
self._taskmap = {tid: time for tid, time in tasks}
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * tid * self._taskmap[tid])
self._wakeup.set()
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Schedule the given tasks.
Rather than repeatedly inserting tasks,
where the O(log n) insort is dominated by the O(n) list insertion,
we build an entirely new list, and always wake up the loop.
"""
self._taskmap |= {tid: time for tid, time in tasks}
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])
self._wakeup.set()
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
"""
Insert the provided task into the tasklist.
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
"""
if self._tasklist:
nextid = self._tasklist[-1]
wake = self._taskmap[nextid] >= timestamp
wake = wake or taskid == nextid
else:
wake = False
if taskid in self._taskmap:
self._tasklist.remove(taskid)
self._taskmap[taskid] = timestamp
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
if wake:
self._wakeup.set()
def cancel_tasks(self, *taskids: Taskid) -> None:
"""
Remove all tasks with the given taskids from the tasklist.
If the next task has this taskid, wake up the monitor loop.
"""
taskids = set(taskids)
wake = (self._tasklist and self._tasklist[-1] in taskids)
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
for tid in taskids:
self._taskmap.pop(tid, None)
if wake:
self._wakeup.set()
def start(self):
if self._monitor_task and not self._monitor_task.done():
self._monitor_task.cancel()
# Start the monitor
self._monitor_task = asyncio.create_task(self.monitor())
return self._monitor_task
async def monitor(self):
"""
Start the monitor.
Executes the tasks in `self.tasks` at the specified time.
This will shield task execution from cancellation
to avoid partial states.
"""
try:
while True:
self._wakeup.clear()
if not self._tasklist:
# No tasks left, just sleep until wakeup
await self._wakeup.wait()
else:
# Get the next task, sleep until wakeup or it is ready to run
nextid = self._tasklist[-1]
nexttime = self._taskmap[nextid]
sleep_for = nexttime - utc_now().timestamp()
try:
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
except asyncio.TimeoutError:
# Ready to run the task
self._tasklist.pop()
self._taskmap.pop(nextid, None)
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
else:
# Wakeup task fired, loop again
continue
except asyncio.CancelledError:
# Log closure and wait for remaining tasks
# A second cancellation will also cancel the tasks
logger.debug(
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
f"Waiting for {len(self._running)} running tasks to complete."
)
await asyncio.gather(*self._running.values(), return_exceptions=True)
async def _run(self, taskid: Taskid) -> None:
# Execute the task, respecting the ratelimit bucket
if self._bucket is not None:
# IMPLEMENTATION NOTE:
# Bucket.wait() should guarantee not more than n tasks/second are run
# and that a request directly afterwards will _not_ raise BucketFull
# Make sure that only one waiter is actually waiting on its sleep task
# The other waiters should be waiting on a lock around the sleep task
# Waiters are executed in wait-order, so if we only let a single waiter in
# we shouldn't get collisions.
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
# Or we will lose thread-safety for BucketFull
await self._bucket.wait()
fut = asyncio.create_task(self.run_task(taskid))
try:
await asyncio.shield(fut)
except asyncio.CancelledError:
raise
except Exception:
# Protect the monitor loop from any other exceptions
logger.exception(
f"Ignoring exception in task monitor {self.__class__.__name__} while "
f"executing <taskid: {taskid}>"
)
finally:
self._running.pop(taskid)
async def run_task(self, taskid: Taskid):
"""
Execute the task with the given taskid.
Default implementation executes `self.executor` if it exists,
otherwise raises NotImplementedError.
"""
if self.executor is not None:
await self.executor(taskid)
else:
raise NotImplementedError

View File

@@ -1,5 +1,7 @@
import asyncio
import time
from cmdClient.lib import SafeCancellation
from meta.errors import SafeCancellation
from cachetools import TTLCache
@@ -19,7 +21,7 @@ class BucketOverFull(BucketFull):
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full')
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
def __init__(self, max_level, empty_time):
self.max_level = max_level
@@ -27,21 +29,37 @@ class Bucket:
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.time()
self._last_checked = time.monotonic()
self._last_full = False
self._wait_lock = asyncio.Lock()
@property
def full(self) -> bool:
"""
Return whether the bucket is 'full',
that is, whether an immediate request against the bucket will raise `BucketFull`.
"""
self._leak()
return self._level + 1 > self.max_level
@property
def overfull(self):
self._leak()
return self._level > self.max_level
@property
def delay(self):
self._leak()
if self._level + 1 > self.max_level:
return (self._level + 1 - self.max_level) * self.leak_rate
def _leak(self):
if self._level:
elapsed = time.time() - self._last_checked
elapsed = time.monotonic() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.time()
self._last_checked = time.monotonic()
def request(self):
self._leak()
@@ -58,6 +76,21 @@ class Bucket:
self._last_full = False
self._level += 1
async def wait(self):
"""
Wait until the bucket has room.
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
"""
# Wrapped in a lock so that waiters are correctly handled in wait-order
# Otherwise multiple waiters will have the same delay,
# and race for the wakeup after sleep.
async with self._wait_lock:
# We do this in a loop in case asyncio.sleep throws us out early,
# or a synchronous request overflows the bucket while we are waiting.
while self.full:
await asyncio.sleep(self.delay)
class RateLimit:
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):

77
bot/utils/transformers.py Normal file
View File

@@ -0,0 +1,77 @@
import discord
from discord import app_commands as appcmds
from discord.app_commands import Transformer
from discord.enums import AppCommandOptionType
from meta.errors import UserInputError
from babel.translator import ctx_translator
from .lib import parse_duration, strfdur
from . import util_babel
_, _p = util_babel._, util_babel._p
class DurationTransformer(Transformer):
"""
Duration parameter, with included autocompletion.
"""
def __init__(self, multiplier=1):
# Multiplier used for a raw integer value
self.multiplier = multiplier
@property
def type(self):
return AppCommandOptionType.string
async def transform(self, interaction: discord.Interaction, value: str) -> int:
"""
Returns the number of seconds in the parsed duration.
Raises UserInputError if the duration cannot be parsed.
"""
translator = ctx_translator.get()
t = translator.t
if value.isdigit():
return int(value) * self.multiplier
duration = parse_duration(value)
if duration is None:
raise UserInputError(
t(_p('utils:parse_dur|error', "Cannot parse `{value}` as a duration.")).format(
value=value
)
)
return duration or 0
async def autocomplete(self, interaction: discord.Interaction, partial: str):
"""
Default autocomplete for Duration parameters.
Attempts to parse the partial value as a duration, and reformat it as an autocomplete choice.
If not possible, displays an error message.
"""
translator = ctx_translator.get()
t = translator.t
if partial.isdigit():
duration = int(partial) * self.multiplier
else:
duration = parse_duration(partial)
if duration is None:
choice = appcmds.Choice(
name=t(_p(
'util:Duration|acmpl|error',
"Cannot extract duration from \"{partial}\""
)).format(partial=partial),
value=partial
)
else:
choice = appcmds.Choice(
name=strfdur(duration, short=False, show_days=True),
value=partial
)
return [choice]

View File

@@ -128,7 +128,7 @@ class LeoUI(View):
slave.stop()
super().stop()
async def close(self):
async def close(self, msg=None):
self.stop()
await self.cleanup()

View File

@@ -44,3 +44,4 @@ cancel = <:xbigger:975880828653568012>
refresh = <:cyclebigger:975880828611600404>
tick = :✅:
clock = :⏱️:

View File

@@ -121,7 +121,10 @@ CREATE TABLE analytics.gui_renders(
--- }}}
-- TODO: Correct foreign keys for member table
ALTER TABLE members
ADD CONSTRAINT fk_members_users FOREIGN KEY (userid) REFERENCES user_config (userid) ON DELETE CASCADE NOT VALID;
ALTER TABLE members
ADD CONSTRAINT fk_members_guilds FOREIGN KEY (guildid) REFERENCES guild_config (guildid) ON DELETE CASCADE NOT VALID;
-- Localisation data {{{
ALTER TABLE user_config ADD COLUMN locale_hint TEXT;
@@ -130,6 +133,12 @@ ALTER TABLE guild_config ADD COLUMN locale TEXT;
ALTER TABLE guild_config ADD COLUMN force_locale BOOLEAN;
--}}}
-- Reminder data {{{
ALTER TABLE reminders ADD COLUMN failed BOOLEAN;
ALTER TABLE reminders
ADD CONSTRAINT fk_reminders_users FOREIGN KEY (userid) REFERENCES user_config (userid) ON DELETE CASCADE NOT VALID;
-- }}}
INSERT INTO VersionHistory (version, author) VALUES (13, 'v12-v13 migration');
-- vim: set fdm=marker: