Initial framework

This commit is contained in:
2025-07-24 03:33:42 +10:00
commit 9c0ae404c8
31 changed files with 3435 additions and 0 deletions

88
src/utils/lib.py Normal file
View File

@@ -0,0 +1,88 @@
import re
import datetime as dt
def strfdelta(delta: dt.timedelta, sec=False, minutes=True, short=False) -> str:
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's' # type: ignore
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def utc_now() -> dt.datetime:
"""
Return the current timezone-aware utc timestamp.
"""
return dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
def replace_multiple(format_string, mapping):
"""
Subsistutes the keys from the format_dict with their corresponding values.
Substitution is non-chained, and done in a single pass via regex.
"""
if not mapping:
raise ValueError("Empty mapping passed.")
keys = list(mapping.keys())
pattern = '|'.join(f"({key})" for key in keys)
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
return string
def parse_dur(time_str):
"""
Parses a user provided time duration string into a number of seconds.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds

166
src/utils/ratelimits.py Normal file
View File

@@ -0,0 +1,166 @@
import asyncio
import time
import logging
from cachetools import TTLCache
logger = logging.getLogger()
class BucketFull(Exception):
"""
Throw when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Throw when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.monotonic()
self._last_full = False
self._wait_lock = asyncio.Lock()
@property
def full(self) -> bool:
"""
Return whether the bucket is 'full',
that is, whether an immediate request against the bucket will raise `BucketFull`.
"""
self._leak()
return self._level + 1 > self.max_level
@property
def overfull(self):
self._leak()
return self._level > self.max_level
@property
def delay(self):
self._leak()
if self._level + 1 > self.max_level:
delay = (self._level + 1 - self.max_level) * self.leak_rate
else:
delay = 0
return delay
def _leak(self):
if self._level:
elapsed = time.monotonic() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.monotonic()
def request(self):
self._leak()
if self._level > self.max_level:
raise BucketOverFull
elif self._level == self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
def fill(self):
self._leak()
self._level = max(self._level, self.max_level + 1)
async def wait(self):
"""
Wait until the bucket has room.
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
"""
# Wrapped in a lock so that waiters are correctly handled in wait-order
# Otherwise multiple waiters will have the same delay,
# and race for the wakeup after sleep.
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
async with self._wait_lock:
# We do this in a loop in case asyncio.sleep throws us out early,
# or a synchronous request overflows the bucket while we are waiting.
while self.full:
await asyncio.sleep(self.delay)
async def wrapped(self, coro):
await self.wait()
self.request()
await coro
class RateLimit:
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
self.max_level = max_level
self.empty_time = empty_time
self.error = error or "Too many requests, please slow down!"
self.buckets = cache
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
bucket.request()
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator
async def limit_concurrency(aws, limit):
"""
Run provided awaitables concurrently,
ensuring that no more than `limit` are running at once.
"""
aws = iter(aws)
aws_ended = False
pending = set()
count = 0
logger.debug("Starting limited concurrency executor")
while pending or not aws_ended:
while len(pending) < limit and not aws_ended:
aw = next(aws, None)
if aw is None:
aws_ended = True
else:
pending.add(asyncio.create_task(aw))
count += 1
if not pending:
break
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
while done:
yield done.pop()
logger.debug(f"Completed {count} tasks")