126 lines
3.7 KiB
Python
126 lines
3.7 KiB
Python
import asyncio
|
|
import time
|
|
|
|
from meta.errors import SafeCancellation
|
|
|
|
from cachetools import TTLCache
|
|
|
|
|
|
class BucketFull(Exception):
    """
    Raised when a request is made against a bucket that has no room left.
    """
|
|
|
|
|
|
class BucketOverFull(BucketFull):
    """
    Raised when a request is made against a bucket that is already past capacity.

    As a subclass of `BucketFull`, it is also caught by `except BucketFull` handlers.
    """
|
|
|
|
|
|
class Bucket:
    """
    Leaky-bucket counter used for rate limiting.

    Each `request` that is not rejected outright adds one unit to the bucket's level.
    The level drains continuously at `leak_rate = max_level / empty_time` units per
    second, so a bucket filled to `max_level` empties in `empty_time` seconds.
    """
    __slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')

    def __init__(self, max_level, empty_time):
        self.max_level = max_level
        self.empty_time = empty_time
        # Units drained per second.
        self.leak_rate = max_level / empty_time

        self._level = 0
        self._last_checked = time.monotonic()

        # True when the previous `request` overflowed the bucket. `request` uses it to
        # raise `BucketFull` on the first overflow and `BucketOverFull` on repeats.
        self._last_full = False
        # Serialises `wait` callers so they wake in the order they started waiting.
        self._wait_lock = asyncio.Lock()

    @property
    def full(self) -> bool:
        """
        Return whether the bucket is 'full',
        that is, whether an immediate request against the bucket will raise `BucketFull`.
        """
        self._leak()
        return self._level + 1 > self.max_level

    @property
    def overfull(self):
        """
        Return whether the current level already exceeds `max_level`.
        """
        self._leak()
        return self._level > self.max_level

    @property
    def delay(self):
        """
        Seconds until one more request would fit, or None if it already fits.
        """
        self._leak()
        if self._level + 1 > self.max_level:
            # Draining `excess` units at `leak_rate` units/second takes excess / leak_rate seconds.
            return (self._level + 1 - self.max_level) / self.leak_rate

    def _leak(self):
        """
        Drain the level by the amount leaked since the last check, never below zero.
        """
        # Read the clock once, so that no time slips between measuring and recording.
        now = time.monotonic()
        if self._level:
            elapsed = now - self._last_checked
            self._level = max(0, self._level - (elapsed * self.leak_rate))
        self._last_checked = now

    def request(self):
        """
        Register one request against the bucket.

        Raises
        ------
        BucketOverFull
            If the level already exceeds `max_level` (the level is left unchanged), or if
            this request overflows the bucket right after a previous overflowing request.
        BucketFull
            If this request is the first to overflow the bucket (the level is still incremented).
        """
        self._leak()
        if self._level + 1 > self.max_level + 1:
            raise BucketOverFull
        elif self._level + 1 > self.max_level:
            self._level += 1
            if self._last_full:
                raise BucketOverFull
            else:
                self._last_full = True
                raise BucketFull
        else:
            self._last_full = False
            self._level += 1

    async def wait(self):
        """
        Wait until the bucket has room.

        Guarantees that a `request` directly afterwards will not raise `BucketFull`.
        """
        # Wrapped in a lock so that waiters are correctly handled in wait-order
        # Otherwise multiple waiters will have the same delay,
        # and race for the wakeup after sleep.
        async with self._wait_lock:
            # We do this in a loop in case asyncio.sleep throws us out early,
            # or a synchronous request overflows the bucket while we are waiting.
            while self.full:
                # `delay` leaks again, so the level may have just dropped enough that it
                # returns None. In that case yield once and re-check instead of passing
                # None to asyncio.sleep.
                await asyncio.sleep(self.delay or 0)
|
|
|
|
|
|
class RateLimit:
    """
    Per-key rate limiter keeping one `Bucket` per key in `self.buckets`.
    """

    def __init__(self, max_level, empty_time, error=None, cache=None):
        """
        Parameters
        ----------
        max_level
            Passed to each `Bucket` as its capacity.
        empty_time
            Passed to each `Bucket` as the seconds needed to drain a full bucket.
        error: Optional[str]
            Message attached to the `SafeCancellation` raised when a bucket fills.
            Defaults to "Too many requests, please slow down!".
        cache: Optional[MutableMapping]
            Mapping used to store buckets by key. Defaults to a new
            `TTLCache(1000, 60 * 60)` created for this instance.
        """
        self.max_level = max_level
        self.empty_time = empty_time

        self.error = error or "Too many requests, please slow down!"
        # Build the default cache per instance. A default argument value would be
        # evaluated once and shared by every RateLimit, so different limiters would hand
        # out each other's buckets for the same key.
        self.buckets = cache if cache is not None else TTLCache(1000, 60 * 60)

    def request_for(self, key):
        """
        Register one request against the bucket for `key`, creating the bucket on first use.

        Raises
        ------
        SafeCancellation
            With details "Bucket overflow" (and no custom message) when the bucket raises
            `BucketOverFull`, or with `self.error` and details "Bucket full" when it raises
            `BucketFull`.
        """
        bucket = self.buckets.get(key, None)
        if bucket is None:
            bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)

        try:
            bucket.request()
        # BucketOverFull subclasses BucketFull, so it must be handled first.
        except BucketOverFull:
            raise SafeCancellation(details="Bucket overflow")
        except BucketFull:
            raise SafeCancellation(self.error, details="Bucket full")

    def ward(self, member=True, key=None):
        """
        Command ratelimit decorator.

        Parameters
        ----------
        member: bool
            Only used when `key` is not given. If True, bucket by
            `(ctx.guild.id, ctx.author.id)`; otherwise bucket by `ctx.author.id`.
        key: Optional[Callable]
            Function mapping the first argument of the decorated coroutine (`ctx`) to a
            bucket key.

        Before running the decorated coroutine, the wrapper calls `request_for`, which
        raises `SafeCancellation` instead when the bucket is full.
        """
        from functools import wraps

        if not key:
            if member:
                def key(ctx):
                    # NOTE(review): ctx.guild is presumably None outside a guild (e.g. DMs);
                    # fall back to a None guild component instead of raising AttributeError.
                    guild = ctx.guild
                    return (guild.id if guild is not None else None, ctx.author.id)
            else:
                def key(ctx):
                    return ctx.author.id

        def decorator(func):
            # wraps() keeps the name/docstring and sets __wrapped__, so signature
            # introspection still sees the original command's parameters.
            @wraps(func)
            async def wrapper(ctx, *args, **kwargs):
                self.request_for(key(ctx))
                return await func(ctx, *args, **kwargs)
            return wrapper
        return decorator
|