rewrite: Reminders system.
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from babel.translator import LocalBabel
|
||||
|
||||
util_babel = LocalBabel('utils')
|
||||
|
||||
@@ -8,7 +8,11 @@ import discord
|
||||
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
|
||||
from discord.ui import View
|
||||
|
||||
# from cmdClient.lib import SafeCancellation
|
||||
from babel.translator import ctx_translator
|
||||
|
||||
from . import util_babel
|
||||
|
||||
_, _p = util_babel._, util_babel._p
|
||||
|
||||
|
||||
multiselect_regex = re.compile(
|
||||
@@ -320,7 +324,7 @@ def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -
|
||||
return "".join(reply_msg)
|
||||
|
||||
|
||||
def parse_dur(time_str: str) -> int:
|
||||
def _parse_dur(time_str: str) -> int:
|
||||
"""
|
||||
Parses a user provided time duration string into a timedelta object.
|
||||
|
||||
@@ -642,7 +646,7 @@ def parse_ids(idstr: str) -> List[int]:
|
||||
|
||||
def error_embed(error, **kwargs) -> discord.Embed:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.red(),
|
||||
colour=discord.Colour.brand_red(),
|
||||
description=error,
|
||||
timestamp=utc_now()
|
||||
)
|
||||
@@ -653,3 +657,52 @@ class DotDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
parse_dur_exps = [
|
||||
(
|
||||
_p(
|
||||
'util:parse_dur|regex:day',
|
||||
r"(?P<value>\d+)\s*(?:(d)|(day))"
|
||||
),
|
||||
60 * 60 * 24
|
||||
),
|
||||
(
|
||||
_p(
|
||||
'util:parse_dur|regex:hour',
|
||||
r"(?P<value>\d+)\s*(?:(h)|(hour))"
|
||||
),
|
||||
60 * 60
|
||||
),
|
||||
(
|
||||
_p(
|
||||
'util:parse_dur|regex:minute',
|
||||
r"(?P<value>\d+)\s*(?:(m)|(min))"
|
||||
),
|
||||
60
|
||||
),
|
||||
(
|
||||
_p(
|
||||
'util:parse_dur|regex:second',
|
||||
r"(?P<value>\d+)\s*(?:(s)|(sec))"
|
||||
),
|
||||
1
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def parse_duration(string: str) -> Optional[int]:
|
||||
translator = ctx_translator.get()
|
||||
if translator is None:
|
||||
raise ValueError("Cannot parse duration without a translator.")
|
||||
t = translator.t
|
||||
|
||||
seconds = 0
|
||||
found = False
|
||||
for expr, multiplier in parse_dur_exps:
|
||||
match = re.search(t(expr), string, flags=re.IGNORECASE)
|
||||
if match:
|
||||
found = True
|
||||
seconds += int(match.group('value')) * multiplier
|
||||
|
||||
return seconds if found else None
|
||||
|
||||
178
bot/utils/monitor.py
Normal file
178
bot/utils/monitor.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import logging
|
||||
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
|
||||
|
||||
from .lib import utc_now
|
||||
from .ratelimits import Bucket
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Taskid = TypeVar('Taskid')
|
||||
|
||||
|
||||
class TaskMonitor(Generic[Taskid]):
|
||||
"""
|
||||
Base class for a task monitor.
|
||||
|
||||
Stores tasks as a time-sorted list of taskids.
|
||||
Subclasses may override `run_task` to implement an executor.
|
||||
|
||||
Adding or removing a single task has O(n) performance.
|
||||
To bulk update tasks, instead use `schedule_tasks`.
|
||||
|
||||
Each taskid must be unique and hashable.
|
||||
"""
|
||||
|
||||
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
|
||||
# Ratelimit bucket to enforce maximum execution rate
|
||||
self._bucket = bucket
|
||||
|
||||
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
|
||||
|
||||
self._wakeup: asyncio.Event = asyncio.Event()
|
||||
self._monitor_task: Optional[self.Task] = None
|
||||
|
||||
# Task data
|
||||
self._tasklist: list[Taskid] = []
|
||||
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
|
||||
|
||||
# Running map ensures we keep a reference to the running task
|
||||
# And allows simpler external cancellation if required
|
||||
self._running: dict[Taskid, asyncio.Future] = {}
|
||||
|
||||
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
|
||||
"""
|
||||
Similar to `schedule_tasks`, but wipe and reset the tasklist.
|
||||
"""
|
||||
self._taskmap = {tid: time for tid, time in tasks}
|
||||
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * tid * self._taskmap[tid])
|
||||
self._wakeup.set()
|
||||
|
||||
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
|
||||
"""
|
||||
Schedule the given tasks.
|
||||
|
||||
Rather than repeatedly inserting tasks,
|
||||
where the O(log n) insort is dominated by the O(n) list insertion,
|
||||
we build an entirely new list, and always wake up the loop.
|
||||
"""
|
||||
self._taskmap |= {tid: time for tid, time in tasks}
|
||||
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])
|
||||
self._wakeup.set()
|
||||
|
||||
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
|
||||
"""
|
||||
Insert the provided task into the tasklist.
|
||||
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
|
||||
"""
|
||||
if self._tasklist:
|
||||
nextid = self._tasklist[-1]
|
||||
wake = self._taskmap[nextid] >= timestamp
|
||||
wake = wake or taskid == nextid
|
||||
else:
|
||||
wake = False
|
||||
if taskid in self._taskmap:
|
||||
self._tasklist.remove(taskid)
|
||||
self._taskmap[taskid] = timestamp
|
||||
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
|
||||
if wake:
|
||||
self._wakeup.set()
|
||||
|
||||
def cancel_tasks(self, *taskids: Taskid) -> None:
|
||||
"""
|
||||
Remove all tasks with the given taskids from the tasklist.
|
||||
If the next task has this taskid, wake up the monitor loop.
|
||||
"""
|
||||
taskids = set(taskids)
|
||||
wake = (self._tasklist and self._tasklist[-1] in taskids)
|
||||
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
|
||||
for tid in taskids:
|
||||
self._taskmap.pop(tid, None)
|
||||
if wake:
|
||||
self._wakeup.set()
|
||||
|
||||
def start(self):
|
||||
if self._monitor_task and not self._monitor_task.done():
|
||||
self._monitor_task.cancel()
|
||||
# Start the monitor
|
||||
self._monitor_task = asyncio.create_task(self.monitor())
|
||||
return self._monitor_task
|
||||
|
||||
async def monitor(self):
|
||||
"""
|
||||
Start the monitor.
|
||||
Executes the tasks in `self.tasks` at the specified time.
|
||||
|
||||
This will shield task execution from cancellation
|
||||
to avoid partial states.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
self._wakeup.clear()
|
||||
if not self._tasklist:
|
||||
# No tasks left, just sleep until wakeup
|
||||
await self._wakeup.wait()
|
||||
else:
|
||||
# Get the next task, sleep until wakeup or it is ready to run
|
||||
nextid = self._tasklist[-1]
|
||||
nexttime = self._taskmap[nextid]
|
||||
sleep_for = nexttime - utc_now().timestamp()
|
||||
try:
|
||||
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
|
||||
except asyncio.TimeoutError:
|
||||
# Ready to run the task
|
||||
self._tasklist.pop()
|
||||
self._taskmap.pop(nextid, None)
|
||||
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
|
||||
else:
|
||||
# Wakeup task fired, loop again
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
# Log closure and wait for remaining tasks
|
||||
# A second cancellation will also cancel the tasks
|
||||
logger.debug(
|
||||
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
|
||||
f"Waiting for {len(self._running)} running tasks to complete."
|
||||
)
|
||||
await asyncio.gather(*self._running.values(), return_exceptions=True)
|
||||
|
||||
async def _run(self, taskid: Taskid) -> None:
|
||||
# Execute the task, respecting the ratelimit bucket
|
||||
if self._bucket is not None:
|
||||
# IMPLEMENTATION NOTE:
|
||||
# Bucket.wait() should guarantee not more than n tasks/second are run
|
||||
# and that a request directly afterwards will _not_ raise BucketFull
|
||||
# Make sure that only one waiter is actually waiting on its sleep task
|
||||
# The other waiters should be waiting on a lock around the sleep task
|
||||
# Waiters are executed in wait-order, so if we only let a single waiter in
|
||||
# we shouldn't get collisions.
|
||||
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
|
||||
# Or we will lose thread-safety for BucketFull
|
||||
await self._bucket.wait()
|
||||
fut = asyncio.create_task(self.run_task(taskid))
|
||||
try:
|
||||
await asyncio.shield(fut)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# Protect the monitor loop from any other exceptions
|
||||
logger.exception(
|
||||
f"Ignoring exception in task monitor {self.__class__.__name__} while "
|
||||
f"executing <taskid: {taskid}>"
|
||||
)
|
||||
finally:
|
||||
self._running.pop(taskid)
|
||||
|
||||
async def run_task(self, taskid: Taskid):
|
||||
"""
|
||||
Execute the task with the given taskid.
|
||||
|
||||
Default implementation executes `self.executor` if it exists,
|
||||
otherwise raises NotImplementedError.
|
||||
"""
|
||||
if self.executor is not None:
|
||||
await self.executor(taskid)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
125
bot/utils/ratelimits.py
Normal file
125
bot/utils/ratelimits.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from meta.errors import SafeCancellation
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
|
||||
class BucketFull(Exception):
|
||||
"""
|
||||
Throw when a requested Bucket is already full
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BucketOverFull(BucketFull):
|
||||
"""
|
||||
Throw when a requested Bucket is overfull
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Bucket:
|
||||
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
|
||||
|
||||
def __init__(self, max_level, empty_time):
|
||||
self.max_level = max_level
|
||||
self.empty_time = empty_time
|
||||
self.leak_rate = max_level / empty_time
|
||||
|
||||
self._level = 0
|
||||
self._last_checked = time.monotonic()
|
||||
|
||||
self._last_full = False
|
||||
self._wait_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def full(self) -> bool:
|
||||
"""
|
||||
Return whether the bucket is 'full',
|
||||
that is, whether an immediate request against the bucket will raise `BucketFull`.
|
||||
"""
|
||||
self._leak()
|
||||
return self._level + 1 > self.max_level
|
||||
|
||||
@property
|
||||
def overfull(self):
|
||||
self._leak()
|
||||
return self._level > self.max_level
|
||||
|
||||
@property
|
||||
def delay(self):
|
||||
self._leak()
|
||||
if self._level + 1 > self.max_level:
|
||||
return (self._level + 1 - self.max_level) * self.leak_rate
|
||||
|
||||
def _leak(self):
|
||||
if self._level:
|
||||
elapsed = time.monotonic() - self._last_checked
|
||||
self._level = max(0, self._level - (elapsed * self.leak_rate))
|
||||
|
||||
self._last_checked = time.monotonic()
|
||||
|
||||
def request(self):
|
||||
self._leak()
|
||||
if self._level + 1 > self.max_level + 1:
|
||||
raise BucketOverFull
|
||||
elif self._level + 1 > self.max_level:
|
||||
self._level += 1
|
||||
if self._last_full:
|
||||
raise BucketOverFull
|
||||
else:
|
||||
self._last_full = True
|
||||
raise BucketFull
|
||||
else:
|
||||
self._last_full = False
|
||||
self._level += 1
|
||||
|
||||
async def wait(self):
|
||||
"""
|
||||
Wait until the bucket has room.
|
||||
|
||||
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
|
||||
"""
|
||||
# Wrapped in a lock so that waiters are correctly handled in wait-order
|
||||
# Otherwise multiple waiters will have the same delay,
|
||||
# and race for the wakeup after sleep.
|
||||
async with self._wait_lock:
|
||||
# We do this in a loop in case asyncio.sleep throws us out early,
|
||||
# or a synchronous request overflows the bucket while we are waiting.
|
||||
while self.full:
|
||||
await asyncio.sleep(self.delay)
|
||||
|
||||
|
||||
class RateLimit:
|
||||
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
|
||||
self.max_level = max_level
|
||||
self.empty_time = empty_time
|
||||
|
||||
self.error = error or "Too many requests, please slow down!"
|
||||
self.buckets = cache
|
||||
|
||||
def request_for(self, key):
|
||||
if not (bucket := self.buckets.get(key, None)):
|
||||
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
|
||||
|
||||
try:
|
||||
bucket.request()
|
||||
except BucketOverFull:
|
||||
raise SafeCancellation(details="Bucket overflow")
|
||||
except BucketFull:
|
||||
raise SafeCancellation(self.error, details="Bucket full")
|
||||
|
||||
def ward(self, member=True, key=None):
|
||||
"""
|
||||
Command ratelimit decorator.
|
||||
"""
|
||||
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
|
||||
|
||||
def decorator(func):
|
||||
async def wrapper(ctx, *args, **kwargs):
|
||||
self.request_for(key(ctx))
|
||||
return await func(ctx, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
77
bot/utils/transformers.py
Normal file
77
bot/utils/transformers.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import discord
|
||||
from discord import app_commands as appcmds
|
||||
from discord.app_commands import Transformer
|
||||
from discord.enums import AppCommandOptionType
|
||||
|
||||
from meta.errors import UserInputError
|
||||
|
||||
from babel.translator import ctx_translator
|
||||
|
||||
from .lib import parse_duration, strfdur
|
||||
from . import util_babel
|
||||
|
||||
|
||||
_, _p = util_babel._, util_babel._p
|
||||
|
||||
|
||||
class DurationTransformer(Transformer):
|
||||
"""
|
||||
Duration parameter, with included autocompletion.
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier=1):
|
||||
# Multiplier used for a raw integer value
|
||||
self.multiplier = multiplier
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return AppCommandOptionType.string
|
||||
|
||||
async def transform(self, interaction: discord.Interaction, value: str) -> int:
|
||||
"""
|
||||
Returns the number of seconds in the parsed duration.
|
||||
Raises UserInputError if the duration cannot be parsed.
|
||||
"""
|
||||
translator = ctx_translator.get()
|
||||
t = translator.t
|
||||
|
||||
if value.isdigit():
|
||||
return int(value) * self.multiplier
|
||||
duration = parse_duration(value)
|
||||
if duration is None:
|
||||
raise UserInputError(
|
||||
t(_p('utils:parse_dur|error', "Cannot parse `{value}` as a duration.")).format(
|
||||
value=value
|
||||
)
|
||||
)
|
||||
return duration or 0
|
||||
|
||||
async def autocomplete(self, interaction: discord.Interaction, partial: str):
|
||||
"""
|
||||
Default autocomplete for Duration parameters.
|
||||
|
||||
Attempts to parse the partial value as a duration, and reformat it as an autocomplete choice.
|
||||
If not possible, displays an error message.
|
||||
"""
|
||||
translator = ctx_translator.get()
|
||||
t = translator.t
|
||||
|
||||
if partial.isdigit():
|
||||
duration = int(partial) * self.multiplier
|
||||
else:
|
||||
duration = parse_duration(partial)
|
||||
|
||||
if duration is None:
|
||||
choice = appcmds.Choice(
|
||||
name=t(_p(
|
||||
'util:Duration|acmpl|error',
|
||||
"Cannot extract duration from \"{partial}\""
|
||||
)).format(partial=partial),
|
||||
value=partial
|
||||
)
|
||||
else:
|
||||
choice = appcmds.Choice(
|
||||
name=strfdur(duration, short=False, show_days=True),
|
||||
value=partial
|
||||
)
|
||||
return [choice]
|
||||
@@ -128,7 +128,7 @@ class LeoUI(View):
|
||||
slave.stop()
|
||||
super().stop()
|
||||
|
||||
async def close(self):
|
||||
async def close(self, msg=None):
|
||||
self.stop()
|
||||
await self.cleanup()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user