rewrite: Reminders system.

This commit is contained in:
2022-11-24 23:12:20 +02:00
parent 0d5e801945
commit dd8609fac0
23 changed files with 1268 additions and 545 deletions

View File

@@ -0,0 +1,3 @@
from babel.translator import LocalBabel
util_babel = LocalBabel('utils')

View File

@@ -8,7 +8,11 @@ import discord
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
from discord.ui import View
# from cmdClient.lib import SafeCancellation
from babel.translator import ctx_translator
from . import util_babel
_, _p = util_babel._, util_babel._p
multiselect_regex = re.compile(
@@ -320,7 +324,7 @@ def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -
return "".join(reply_msg)
def parse_dur(time_str: str) -> int:
def _parse_dur(time_str: str) -> int:
"""
Parses a user provided time duration string into a timedelta object.
@@ -642,7 +646,7 @@ def parse_ids(idstr: str) -> List[int]:
def error_embed(error, **kwargs) -> discord.Embed:
embed = discord.Embed(
colour=discord.Colour.red(),
colour=discord.Colour.brand_red(),
description=error,
timestamp=utc_now()
)
@@ -653,3 +657,52 @@ class DotDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
parse_dur_exps = [
(
_p(
'util:parse_dur|regex:day',
r"(?P<value>\d+)\s*(?:(d)|(day))"
),
60 * 60 * 24
),
(
_p(
'util:parse_dur|regex:hour',
r"(?P<value>\d+)\s*(?:(h)|(hour))"
),
60 * 60
),
(
_p(
'util:parse_dur|regex:minute',
r"(?P<value>\d+)\s*(?:(m)|(min))"
),
60
),
(
_p(
'util:parse_dur|regex:second',
r"(?P<value>\d+)\s*(?:(s)|(sec))"
),
1
)
]
def parse_duration(string: str) -> Optional[int]:
translator = ctx_translator.get()
if translator is None:
raise ValueError("Cannot parse duration without a translator.")
t = translator.t
seconds = 0
found = False
for expr, multiplier in parse_dur_exps:
match = re.search(t(expr), string, flags=re.IGNORECASE)
if match:
found = True
seconds += int(match.group('value')) * multiplier
return seconds if found else None

178
bot/utils/monitor.py Normal file
View File

@@ -0,0 +1,178 @@
import asyncio
import bisect
import logging
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
from .lib import utc_now
from .ratelimits import Bucket
logger = logging.getLogger(__name__)
Taskid = TypeVar('Taskid')
class TaskMonitor(Generic[Taskid]):
"""
Base class for a task monitor.
Stores tasks as a time-sorted list of taskids.
Subclasses may override `run_task` to implement an executor.
Adding or removing a single task has O(n) performance.
To bulk update tasks, instead use `schedule_tasks`.
Each taskid must be unique and hashable.
"""
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
# Ratelimit bucket to enforce maximum execution rate
self._bucket = bucket
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
self._wakeup: asyncio.Event = asyncio.Event()
self._monitor_task: Optional[self.Task] = None
# Task data
self._tasklist: list[Taskid] = []
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
# Running map ensures we keep a reference to the running task
# And allows simpler external cancellation if required
self._running: dict[Taskid, asyncio.Future] = {}
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Similar to `schedule_tasks`, but wipe and reset the tasklist.
"""
self._taskmap = {tid: time for tid, time in tasks}
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * tid * self._taskmap[tid])
self._wakeup.set()
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
"""
Schedule the given tasks.
Rather than repeatedly inserting tasks,
where the O(log n) insort is dominated by the O(n) list insertion,
we build an entirely new list, and always wake up the loop.
"""
self._taskmap |= {tid: time for tid, time in tasks}
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])
self._wakeup.set()
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
"""
Insert the provided task into the tasklist.
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
"""
if self._tasklist:
nextid = self._tasklist[-1]
wake = self._taskmap[nextid] >= timestamp
wake = wake or taskid == nextid
else:
wake = False
if taskid in self._taskmap:
self._tasklist.remove(taskid)
self._taskmap[taskid] = timestamp
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
if wake:
self._wakeup.set()
def cancel_tasks(self, *taskids: Taskid) -> None:
"""
Remove all tasks with the given taskids from the tasklist.
If the next task has this taskid, wake up the monitor loop.
"""
taskids = set(taskids)
wake = (self._tasklist and self._tasklist[-1] in taskids)
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
for tid in taskids:
self._taskmap.pop(tid, None)
if wake:
self._wakeup.set()
def start(self):
if self._monitor_task and not self._monitor_task.done():
self._monitor_task.cancel()
# Start the monitor
self._monitor_task = asyncio.create_task(self.monitor())
return self._monitor_task
async def monitor(self):
"""
Start the monitor.
Executes the tasks in `self.tasks` at the specified time.
This will shield task execution from cancellation
to avoid partial states.
"""
try:
while True:
self._wakeup.clear()
if not self._tasklist:
# No tasks left, just sleep until wakeup
await self._wakeup.wait()
else:
# Get the next task, sleep until wakeup or it is ready to run
nextid = self._tasklist[-1]
nexttime = self._taskmap[nextid]
sleep_for = nexttime - utc_now().timestamp()
try:
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
except asyncio.TimeoutError:
# Ready to run the task
self._tasklist.pop()
self._taskmap.pop(nextid, None)
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
else:
# Wakeup task fired, loop again
continue
except asyncio.CancelledError:
# Log closure and wait for remaining tasks
# A second cancellation will also cancel the tasks
logger.debug(
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
f"Waiting for {len(self._running)} running tasks to complete."
)
await asyncio.gather(*self._running.values(), return_exceptions=True)
async def _run(self, taskid: Taskid) -> None:
# Execute the task, respecting the ratelimit bucket
if self._bucket is not None:
# IMPLEMENTATION NOTE:
# Bucket.wait() should guarantee not more than n tasks/second are run
# and that a request directly afterwards will _not_ raise BucketFull
# Make sure that only one waiter is actually waiting on its sleep task
# The other waiters should be waiting on a lock around the sleep task
# Waiters are executed in wait-order, so if we only let a single waiter in
# we shouldn't get collisions.
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
# Or we will lose thread-safety for BucketFull
await self._bucket.wait()
fut = asyncio.create_task(self.run_task(taskid))
try:
await asyncio.shield(fut)
except asyncio.CancelledError:
raise
except Exception:
# Protect the monitor loop from any other exceptions
logger.exception(
f"Ignoring exception in task monitor {self.__class__.__name__} while "
f"executing <taskid: {taskid}>"
)
finally:
self._running.pop(taskid)
async def run_task(self, taskid: Taskid):
"""
Execute the task with the given taskid.
Default implementation executes `self.executor` if it exists,
otherwise raises NotImplementedError.
"""
if self.executor is not None:
await self.executor(taskid)
else:
raise NotImplementedError

125
bot/utils/ratelimits.py Normal file
View File

@@ -0,0 +1,125 @@
import asyncio
import time
from meta.errors import SafeCancellation
from cachetools import TTLCache
class BucketFull(Exception):
"""
Throw when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Throw when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.monotonic()
self._last_full = False
self._wait_lock = asyncio.Lock()
@property
def full(self) -> bool:
"""
Return whether the bucket is 'full',
that is, whether an immediate request against the bucket will raise `BucketFull`.
"""
self._leak()
return self._level + 1 > self.max_level
@property
def overfull(self):
self._leak()
return self._level > self.max_level
@property
def delay(self):
self._leak()
if self._level + 1 > self.max_level:
return (self._level + 1 - self.max_level) * self.leak_rate
def _leak(self):
if self._level:
elapsed = time.monotonic() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.monotonic()
def request(self):
self._leak()
if self._level + 1 > self.max_level + 1:
raise BucketOverFull
elif self._level + 1 > self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
async def wait(self):
"""
Wait until the bucket has room.
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
"""
# Wrapped in a lock so that waiters are correctly handled in wait-order
# Otherwise multiple waiters will have the same delay,
# and race for the wakeup after sleep.
async with self._wait_lock:
# We do this in a loop in case asyncio.sleep throws us out early,
# or a synchronous request overflows the bucket while we are waiting.
while self.full:
await asyncio.sleep(self.delay)
class RateLimit:
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
self.max_level = max_level
self.empty_time = empty_time
self.error = error or "Too many requests, please slow down!"
self.buckets = cache
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
try:
bucket.request()
except BucketOverFull:
raise SafeCancellation(details="Bucket overflow")
except BucketFull:
raise SafeCancellation(self.error, details="Bucket full")
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator

77
bot/utils/transformers.py Normal file
View File

@@ -0,0 +1,77 @@
import discord
from discord import app_commands as appcmds
from discord.app_commands import Transformer
from discord.enums import AppCommandOptionType
from meta.errors import UserInputError
from babel.translator import ctx_translator
from .lib import parse_duration, strfdur
from . import util_babel
_, _p = util_babel._, util_babel._p
class DurationTransformer(Transformer):
"""
Duration parameter, with included autocompletion.
"""
def __init__(self, multiplier=1):
# Multiplier used for a raw integer value
self.multiplier = multiplier
@property
def type(self):
return AppCommandOptionType.string
async def transform(self, interaction: discord.Interaction, value: str) -> int:
"""
Returns the number of seconds in the parsed duration.
Raises UserInputError if the duration cannot be parsed.
"""
translator = ctx_translator.get()
t = translator.t
if value.isdigit():
return int(value) * self.multiplier
duration = parse_duration(value)
if duration is None:
raise UserInputError(
t(_p('utils:parse_dur|error', "Cannot parse `{value}` as a duration.")).format(
value=value
)
)
return duration or 0
async def autocomplete(self, interaction: discord.Interaction, partial: str):
"""
Default autocomplete for Duration parameters.
Attempts to parse the partial value as a duration, and reformat it as an autocomplete choice.
If not possible, displays an error message.
"""
translator = ctx_translator.get()
t = translator.t
if partial.isdigit():
duration = int(partial) * self.multiplier
else:
duration = parse_duration(partial)
if duration is None:
choice = appcmds.Choice(
name=t(_p(
'util:Duration|acmpl|error',
"Cannot extract duration from \"{partial}\""
)).format(partial=partial),
value=partial
)
else:
choice = appcmds.Choice(
name=strfdur(duration, short=False, show_days=True),
value=partial
)
return [choice]

View File

@@ -128,7 +128,7 @@ class LeoUI(View):
slave.stop()
super().stop()
async def close(self):
async def close(self, msg=None):
self.stop()
await self.cleanup()