feature (utils): Add a ratelimit implementation.
This commit is contained in:
92
bot/utils/ratelimits.py
Normal file
92
bot/utils/ratelimits.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import time
|
||||||
|
from cmdClient.lib import SafeCancellation
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
|
||||||
|
|
||||||
|
class BucketFull(Exception):
|
||||||
|
"""
|
||||||
|
Throw when a requested Bucket is already full
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BucketOverFull(BucketFull):
|
||||||
|
"""
|
||||||
|
Throw when a requested Bucket is overfull
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Bucket:
|
||||||
|
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full')
|
||||||
|
|
||||||
|
def __init__(self, max_level, empty_time):
|
||||||
|
self.max_level = max_level
|
||||||
|
self.empty_time = empty_time
|
||||||
|
self.leak_rate = max_level / empty_time
|
||||||
|
|
||||||
|
self._level = 0
|
||||||
|
self._last_checked = time.time()
|
||||||
|
|
||||||
|
self._last_full = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def overfull(self):
|
||||||
|
self._leak()
|
||||||
|
return self._level > self.max_level
|
||||||
|
|
||||||
|
def _leak(self):
|
||||||
|
if self._level:
|
||||||
|
elapsed = time.time() - self._last_checked
|
||||||
|
self._level = max(0, self._level - (elapsed * self.leak_rate))
|
||||||
|
|
||||||
|
self._last_checked = time.time()
|
||||||
|
|
||||||
|
def request(self):
|
||||||
|
self._leak()
|
||||||
|
if self._level + 1 > self.max_level + 1:
|
||||||
|
raise BucketOverFull
|
||||||
|
elif self._level + 1 > self.max_level:
|
||||||
|
self._level += 1
|
||||||
|
if self._last_full:
|
||||||
|
raise BucketOverFull
|
||||||
|
else:
|
||||||
|
self._last_full = True
|
||||||
|
raise BucketFull
|
||||||
|
else:
|
||||||
|
self._last_full = False
|
||||||
|
self._level += 1
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimit:
|
||||||
|
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
|
||||||
|
self.max_level = max_level
|
||||||
|
self.empty_time = empty_time
|
||||||
|
|
||||||
|
self.error = error or "Too many requests, please slow down!"
|
||||||
|
self.buckets = cache
|
||||||
|
|
||||||
|
def request_for(self, key):
|
||||||
|
if not (bucket := self.buckets.get(key, None)):
|
||||||
|
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
|
||||||
|
|
||||||
|
try:
|
||||||
|
bucket.request()
|
||||||
|
except BucketOverFull:
|
||||||
|
raise SafeCancellation(details="Bucket overflow")
|
||||||
|
except BucketFull:
|
||||||
|
raise SafeCancellation(self.error, details="Bucket full")
|
||||||
|
|
||||||
|
def ward(self, member=True, key=None):
|
||||||
|
"""
|
||||||
|
Command ratelimit decorator.
|
||||||
|
"""
|
||||||
|
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
async def wrapper(ctx, *args, **kwargs):
|
||||||
|
self.request_for(key(ctx))
|
||||||
|
return await func(ctx, *args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
Reference in New Issue
Block a user