Initial Template
This commit is contained in:
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "src/data"]
|
||||
path = src/data
|
||||
url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git
|
||||
15
config/example-bot.conf
Normal file
15
config/example-bot.conf
Normal file
@@ -0,0 +1,15 @@
|
||||
[BOT]
|
||||
prefix = ?
|
||||
bot_id =
|
||||
|
||||
ALSO_READ = config/secrets.conf
|
||||
|
||||
wshost = localhost
|
||||
wsport = 4343
|
||||
wsdomain = localhost:4343
|
||||
|
||||
[LOGGING]
|
||||
general_log =
|
||||
warning_log =
|
||||
error_log =
|
||||
critical_log =
|
||||
7
config/example-secrets.conf
Normal file
7
config/example-secrets.conf
Normal file
@@ -0,0 +1,7 @@
|
||||
[CROCBOT]
|
||||
client_id =
|
||||
client_secret =
|
||||
|
||||
[DATA]
|
||||
args =
|
||||
appid =
|
||||
0
data/.gitignore
vendored
Normal file
0
data/.gitignore
vendored
Normal file
63
data/schema.sql
Normal file
63
data/schema.sql
Normal file
@@ -0,0 +1,63 @@
|
||||
-- Metadata {{{
|
||||
CREATE TABLE version_history(
|
||||
component TEXT NOT NULL,
|
||||
from_version INTEGER NOT NULL,
|
||||
to_version INTEGER NOT NULL,
|
||||
author TEXT NOT NULL,
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
);
|
||||
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation');
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION update_timestamp_column()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW._timestamp = now();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ language 'plpgsql';
|
||||
|
||||
CREATE TABLE app_config(
|
||||
appname TEXT PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
-- }}}
|
||||
|
||||
-- Twitch Auth {{{
|
||||
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('TWITCH_AUTH', 0, 1, 'Initial Creation');
|
||||
|
||||
-- Authorisation tokens allowing us to take actions on behalf of certain users or channels.
|
||||
-- For example, channels we have joined will need to be authorised with a 'channel:bot' scope.
|
||||
CREATE TABLE user_auth(
|
||||
userid TEXT PRIMARY KEY,
|
||||
token TEXT NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE TRIGGER user_auth_timestamp BEFORE UPDATE ON user_auth
|
||||
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
|
||||
|
||||
CREATE TABLE user_auth_scopes(
|
||||
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
scope TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- Which joins will be joined at startup,
|
||||
-- and any configurable choices needed when joining the channel
|
||||
CREATE TABLE bot_channels(
|
||||
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
autojoin BOOLEAN DEFAULT true,
|
||||
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE TRIGGER bot_channels_timestamp BEFORE UPDATE ON bot_channels
|
||||
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
|
||||
|
||||
|
||||
-- }}}
|
||||
|
||||
|
||||
|
||||
-- vim: set fdm=marker:
|
||||
4
requirements.txt
Normal file
4
requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
twitchio
|
||||
psycopg[pool]
|
||||
cachetools
|
||||
discord.py
|
||||
12
scripts/start_bot.py
Normal file
12
scripts/start_bot.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# !/bin/python3
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from bot import _main
|
||||
_main()
|
||||
42
src/bot.py
Normal file
42
src/bot.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import websockets
|
||||
|
||||
from twitchio.web import AiohttpAdapter
|
||||
|
||||
from meta import CrocBot, conf, setup_main_logger, args
|
||||
from data import Database
|
||||
|
||||
from modules import twitch_setup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def main():
|
||||
db = Database(conf.data['args'])
|
||||
|
||||
async with db.open():
|
||||
adapter = AiohttpAdapter(
|
||||
host=conf.bot.get('wshost', None),
|
||||
port=conf.bot.getint('wsport', None),
|
||||
domain=conf.bot.get('wsdomain', None),
|
||||
eventsub_secret=conf.bot.get('eventsub_secret', None)
|
||||
)
|
||||
|
||||
bot = CrocBot(
|
||||
config=conf,
|
||||
dbconn=db,
|
||||
adapter=adapter,
|
||||
setup=twitch_setup,
|
||||
)
|
||||
|
||||
try:
|
||||
await bot.start()
|
||||
finally:
|
||||
await bot.close()
|
||||
|
||||
|
||||
def _main():
|
||||
setup_main_logger()
|
||||
|
||||
asyncio.run(main())
|
||||
82
src/botdata.py
Normal file
82
src/botdata.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from data import Registry, RowModel, Table
|
||||
from data.columns import String, Timestamp, Integer, Bool
|
||||
|
||||
|
||||
class UserAuth(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE user_auth(
|
||||
userid TEXT PRIMARY KEY,
|
||||
token TEXT NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'user_auth'
|
||||
_cache_ = {}
|
||||
|
||||
userid = String(primary=True)
|
||||
token = String()
|
||||
refresh_token = String()
|
||||
created_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class BotChannel(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE bot_channels(
|
||||
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
autojoin BOOLEAN DEFAULT true,
|
||||
listen_redeems BOOLEAN,
|
||||
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'bot_channels'
|
||||
_cache_ = {}
|
||||
|
||||
userid = String(primary=True)
|
||||
autojoin = Bool()
|
||||
listen_redeems = Bool()
|
||||
joined_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class VersionHistory(RowModel):
|
||||
"""
|
||||
CREATE TABLE version_history(
|
||||
component TEXT NOT NULL,
|
||||
from_version INTEGER NOT NULL,
|
||||
to_version INTEGER NOT NULL,
|
||||
author TEXT NOT NULL,
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'version_history'
|
||||
_cache_ = {}
|
||||
|
||||
component = String()
|
||||
from_version = Integer()
|
||||
to_version = Integer()
|
||||
author = String()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class BotData(Registry):
|
||||
version_history = VersionHistory.table
|
||||
|
||||
user_auth = UserAuth.table
|
||||
bot_channels = BotChannel.table
|
||||
|
||||
"""
|
||||
CREATE TABLE user_auth_scopes(
|
||||
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
scope TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
user_auth_scopes = Table('user_auth_scopes')
|
||||
|
||||
22
src/constants.py
Normal file
22
src/constants.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from twitchio import Scopes
|
||||
|
||||
|
||||
CONFIG_FILE = 'config/bot.conf'
|
||||
|
||||
SCHEMA_VERSIONS = {
|
||||
'ROOT': 1,
|
||||
'TWITCH_AUTH': 1
|
||||
}
|
||||
|
||||
# Requested scopes for the bots own twitch user
|
||||
BOTUSER_SCOPES = Scopes((
|
||||
Scopes.user_read_chat,
|
||||
Scopes.user_write_chat,
|
||||
Scopes.user_bot,
|
||||
Scopes.channel_bot,
|
||||
))
|
||||
|
||||
# Default requested scopes for joining a channel
|
||||
CHANNEL_SCOPES = Scopes((
|
||||
Scopes.channel_bot,
|
||||
))
|
||||
1
src/data
Submodule
1
src/data
Submodule
Submodule src/data added at cfdfe0eb50
4
src/meta/__init__.py
Normal file
4
src/meta/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .args import args
|
||||
from .bot import Bot
|
||||
from .config import Conf, conf
|
||||
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task
|
||||
28
src/meta/args.py
Normal file
28
src/meta/args.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import argparse
|
||||
|
||||
from constants import CONFIG_FILE
|
||||
|
||||
# ------------------------------
|
||||
# Parsed commandline arguments
|
||||
# ------------------------------
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--conf',
|
||||
dest='config',
|
||||
default=CONFIG_FILE,
|
||||
help="Path to configuration file."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--host',
|
||||
dest='host',
|
||||
default='127.0.0.1',
|
||||
help="IP address to run the websocket server on."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--port',
|
||||
dest='port',
|
||||
default='5001',
|
||||
help="Port to run the websocket server on."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
174
src/meta/bot.py
Normal file
174
src/meta/bot.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from twitchio.authentication import UserTokenPayload
|
||||
from twitchio.ext import commands
|
||||
from twitchio import Scopes, eventsub
|
||||
|
||||
from data import Database, ORDER
|
||||
from botdata import BotData, UserAuth, BotChannel, VersionHistory
|
||||
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
|
||||
|
||||
from .config import Conf
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Bot(commands.Bot):
|
||||
def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs):
|
||||
kwargs.setdefault('client_id', config.bot['client_id'])
|
||||
kwargs.setdefault('client_secret', config.bot['client_secret'])
|
||||
kwargs.setdefault('bot_id', config.bot['bot_id'])
|
||||
kwargs.setdefault('prefix', config.bot['prefix'])
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if config.bot.get('eventsub_secret', None):
|
||||
self.using_webhooks = True
|
||||
else:
|
||||
self.using_webhooks = False
|
||||
|
||||
self.config = config
|
||||
self.dbconn = dbconn
|
||||
self.data: BotData = dbconn.load_registry(BotData())
|
||||
self._setup_hook = setup
|
||||
|
||||
self.joined: dict[str, BotChannel] = {}
|
||||
|
||||
async def event_ready(self):
|
||||
# logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||
logger.info("Logged in as %s", self.bot_id)
|
||||
|
||||
async def version_check(self, component: str, req_version: int):
|
||||
# Query the database to confirm that the given component is listed with the given version.
|
||||
# Typically done upon loading a component
|
||||
rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
|
||||
|
||||
version = rows[0].to_version if rows else 0
|
||||
|
||||
if version != req_version:
|
||||
raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
|
||||
else:
|
||||
logger.debug(
|
||||
"Component %s passed version check with version %s",
|
||||
component,
|
||||
version
|
||||
)
|
||||
return True
|
||||
|
||||
async def setup_hook(self):
|
||||
await self.data.init()
|
||||
for component, req in SCHEMA_VERSIONS.items():
|
||||
await self.version_check(component, req)
|
||||
|
||||
if self._setup_hook is not None:
|
||||
await self._setup_hook(self)
|
||||
|
||||
# Get all current bot channels
|
||||
channels = await BotChannel.fetch_where(autojoin=True)
|
||||
|
||||
# Join the channels
|
||||
await self.join_channels(*channels)
|
||||
|
||||
# Build bot account's own url
|
||||
scopes = BOTUSER_SCOPES
|
||||
url = self.get_auth_url(scopes)
|
||||
logger.info("Bot account authorisation url: %s", url)
|
||||
|
||||
# Build everyone else's url
|
||||
scopes = CHANNEL_SCOPES
|
||||
url = self.get_auth_url(scopes)
|
||||
logger.info("User account authorisation url: %s", url)
|
||||
|
||||
logger.info("Finished setup")
|
||||
|
||||
def get_auth_url(self, scopes: Optional[Scopes] = None):
|
||||
if scopes is None:
|
||||
scopes = Scopes((Scopes.channel_bot,))
|
||||
|
||||
url = self._adapter.get_authorization_url(scopes=scopes)
|
||||
return url
|
||||
|
||||
async def join_channels(self, *channels: BotChannel):
|
||||
"""
|
||||
Register webhook subscriptions to the given channel(s).
|
||||
"""
|
||||
# TODO: If channels are already joined, unsubscribe
|
||||
for channel in channels:
|
||||
sub = None
|
||||
try:
|
||||
sub = eventsub.ChatMessageSubscription(
|
||||
broadcaster_user_id=channel.userid,
|
||||
user_id=self.bot_id,
|
||||
)
|
||||
if self.using_webhooks:
|
||||
resp = await self.subscribe_webhook(sub)
|
||||
else:
|
||||
resp = await self.subscribe_websocket(sub)
|
||||
logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
|
||||
self.joined[channel.userid] = channel
|
||||
self.safe_dispatch('channel_joined', payload=channel)
|
||||
except Exception:
|
||||
logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)
|
||||
|
||||
async def event_oauth_authorized(self, payload: UserTokenPayload):
|
||||
logger.debug("Oauth flow authorization with payload %s", repr(payload))
|
||||
# Save the token and scopes and update internal authorisations
|
||||
resp = await self.add_token(payload.access_token, payload.refresh_token)
|
||||
if resp.user_id is None:
|
||||
logger.warning(
|
||||
"Oauth flow recieved with no user_id. Payload was: %s",
|
||||
repr(payload)
|
||||
)
|
||||
return
|
||||
|
||||
# If the scopes authorised included channel:bot, ensure a BotChannel exists
|
||||
# And join it if needed
|
||||
if Scopes.channel_bot.value in resp.scopes:
|
||||
bot_channel = await BotChannel.fetch_or_create(
|
||||
resp.user_id,
|
||||
autojoin=True,
|
||||
)
|
||||
if bot_channel.autojoin:
|
||||
await self.join_channels(bot_channel)
|
||||
|
||||
logger.info("Oauth flow authorization complete for payload %s", repr(payload))
|
||||
|
||||
async def add_token(self, token: str, refresh: str):
|
||||
# Update the tokens in internal cache
|
||||
# This also validates the token
|
||||
# And hopefully gets the userid and scopes
|
||||
resp = await super().add_token(token, refresh)
|
||||
if resp.user_id is None:
|
||||
logger.warning(
|
||||
"Added a token with no user_id. Response was: %s",
|
||||
repr(resp)
|
||||
)
|
||||
return resp
|
||||
userid = resp.user_id
|
||||
new_scopes = resp.scopes
|
||||
|
||||
# Save the token and scopes to data
|
||||
# Wrap this in a transaction so if it fails halfway we rollback correctly
|
||||
async with self.dbconn.connection() as conn:
|
||||
self.dbconn.conn = conn
|
||||
async with conn.transaction():
|
||||
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
|
||||
if row.token != token or row.refresh_token != refresh:
|
||||
await row.update(token=token, refresh_token=refresh)
|
||||
await self.data.user_auth_scopes.delete_where(userid=userid)
|
||||
await self.data.user_auth_scopes.insert_many(
|
||||
('userid', 'scope'),
|
||||
*((userid, scope) for scope in new_scopes)
|
||||
)
|
||||
|
||||
logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
|
||||
return resp
|
||||
|
||||
async def load_tokens(self, path: str | None = None):
|
||||
for row in await UserAuth.fetch_where():
|
||||
try:
|
||||
await self.add_token(row.token, row.refresh_token)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add token for {row}")
|
||||
105
src/meta/config.py
Normal file
105
src/meta/config.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import configparser as cfgp
|
||||
|
||||
from .args import args
|
||||
|
||||
|
||||
class MapDotProxy:
|
||||
"""
|
||||
Allows dot access to an underlying Mappable object.
|
||||
"""
|
||||
__slots__ = ("_map", "_converter")
|
||||
|
||||
def __init__(self, mappable, converter=None):
|
||||
self._map = mappable
|
||||
self._converter = converter
|
||||
|
||||
def __getattribute__(self, key):
|
||||
_map = object.__getattribute__(self, '_map')
|
||||
if key == '_map':
|
||||
return _map
|
||||
if key in _map:
|
||||
_converter = object.__getattribute__(self, '_converter')
|
||||
if _converter:
|
||||
return _converter(_map[key])
|
||||
else:
|
||||
return _map[key]
|
||||
else:
|
||||
return object.__getattribute__(_map, key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._map.__getitem__(key)
|
||||
|
||||
|
||||
class ConfigParser(cfgp.ConfigParser):
|
||||
"""
|
||||
Extension of base ConfigParser allowing optional
|
||||
section option retrieval without defaults.
|
||||
"""
|
||||
def options(self, section, no_defaults=False, **kwargs):
|
||||
if no_defaults:
|
||||
try:
|
||||
return list(self._sections[section].keys())
|
||||
except KeyError:
|
||||
raise cfgp.NoSectionError(section)
|
||||
else:
|
||||
return super().options(section, **kwargs)
|
||||
|
||||
|
||||
class Conf:
|
||||
def __init__(self, configfile, section_name="DEFAULT"):
|
||||
self.configfile = configfile
|
||||
|
||||
self.config = ConfigParser(
|
||||
converters={
|
||||
"intlist": self._getintlist,
|
||||
"list": self._getlist,
|
||||
}
|
||||
)
|
||||
|
||||
with open(configfile) as conff:
|
||||
# Opening with read_file mainly to ensure the file exists
|
||||
self.config.read_file(conff)
|
||||
|
||||
self.section_name = section_name if section_name in self.config else 'DEFAULT'
|
||||
|
||||
self.default = self.config["DEFAULT"]
|
||||
self.section = MapDotProxy(self.config[self.section_name])
|
||||
self.bot = self.section
|
||||
|
||||
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
|
||||
more_to_read = self.section.getlist("ALSO_READ", [])
|
||||
read = set()
|
||||
while more_to_read:
|
||||
to_read = more_to_read.pop(0)
|
||||
read.add(to_read)
|
||||
self.config.read(to_read)
|
||||
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
|
||||
if path not in read and path not in more_to_read]
|
||||
more_to_read.extend(new_paths)
|
||||
|
||||
global conf
|
||||
conf = self
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.section[key].strip()
|
||||
|
||||
def __getattr__(self, section):
|
||||
name = section.upper()
|
||||
return self.config[name]
|
||||
|
||||
def get(self, name, fallback=None):
|
||||
result = self.section.get(name, fallback)
|
||||
return result.strip() if result else result
|
||||
|
||||
def _getintlist(self, value):
|
||||
return [int(item.strip()) for item in value.split(',')]
|
||||
|
||||
def _getlist(self, value):
|
||||
return [item.strip() for item in value.split(',')]
|
||||
|
||||
def write(self):
|
||||
with open(self.configfile, 'w') as conffile:
|
||||
self.config.write(conffile)
|
||||
|
||||
|
||||
conf = Conf(args.config, 'BOT')
|
||||
466
src/meta/logger.py
Normal file
466
src/meta/logger.py
Normal file
@@ -0,0 +1,466 @@
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from logging.handlers import QueueListener, QueueHandler
|
||||
import queue
|
||||
import multiprocessing
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from functools import wraps
|
||||
from contextvars import ContextVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .config import conf
|
||||
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
|
||||
|
||||
|
||||
log_logger = logging.getLogger(__name__)
|
||||
log_logger.propagate = False
|
||||
|
||||
|
||||
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
|
||||
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
|
||||
log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT")
|
||||
|
||||
context: ContextVar[Optional[str]] = ContextVar('context', default=None)
|
||||
|
||||
def set_logging_context(
|
||||
context: Optional[str] = None,
|
||||
action: Optional[str] = None,
|
||||
stack: Optional[tuple[str, ...]] = None
|
||||
):
|
||||
"""
|
||||
Statically set the logging context variables to the given values.
|
||||
|
||||
If `action` is given, pushes it onto the `log_action_stack`.
|
||||
"""
|
||||
if context is not None:
|
||||
log_context.set(context)
|
||||
if action is not None or stack is not None:
|
||||
astack = log_action_stack.get()
|
||||
newstack = stack if stack is not None else astack
|
||||
if action is not None:
|
||||
newstack = (*newstack, action)
|
||||
log_action_stack.set(newstack)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def logging_context(context=None, action=None, stack=None):
|
||||
"""
|
||||
Context manager for executing a block of code in a given logging context.
|
||||
|
||||
This context manager should only be used around synchronous code.
|
||||
This is because async code *may* get cancelled or externally garbage collected,
|
||||
in which case the finally block will be executed in the wrong context.
|
||||
See https://github.com/python/cpython/issues/93740
|
||||
This can be refactored nicely if this gets merged:
|
||||
https://github.com/python/cpython/pull/99634
|
||||
|
||||
(It will not necessarily break on async code,
|
||||
if the async code can be guaranteed to clean up in its own context.)
|
||||
"""
|
||||
if context is not None:
|
||||
oldcontext = log_context.get()
|
||||
log_context.set(context)
|
||||
if action is not None or stack is not None:
|
||||
astack = log_action_stack.get()
|
||||
newstack = stack if stack is not None else astack
|
||||
if action is not None:
|
||||
newstack = (*newstack, action)
|
||||
log_action_stack.set(newstack)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if context is not None:
|
||||
log_context.set(oldcontext)
|
||||
if stack is not None or action is not None:
|
||||
log_action_stack.set(astack)
|
||||
|
||||
|
||||
def with_log_ctx(isolate=True, **kwargs):
|
||||
"""
|
||||
Execute a coroutine inside a given logging context.
|
||||
|
||||
If `isolate` is true, ensures that context does not leak
|
||||
outside the coroutine.
|
||||
|
||||
If `isolate` is false, just statically set the context,
|
||||
which will leak unless the coroutine is
|
||||
called in an externally copied context.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapped(*w_args, **w_kwargs):
|
||||
if isolate:
|
||||
with logging_context(**kwargs):
|
||||
# Task creation will synchronously copy the context
|
||||
# This is gc safe
|
||||
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
|
||||
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
|
||||
return await task
|
||||
else:
|
||||
# This will leak context changes
|
||||
set_logging_context(**kwargs)
|
||||
return await func(*w_args, **w_kwargs)
|
||||
return wrapped
|
||||
return decorator
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
log_wrap = with_log_ctx
|
||||
|
||||
|
||||
def persist_task(task_collection: set):
|
||||
"""
|
||||
Coroutine decorator that ensures the coroutine is scheduled as a task
|
||||
and added to the given task_collection for strong reference
|
||||
when it is called.
|
||||
|
||||
This is just a hack to handle discord.py events potentially
|
||||
being unexpectedly garbage collected.
|
||||
|
||||
Since this also implicitly schedules the coroutine as a task when it is called,
|
||||
the coroutine will also be run inside an isolated context.
|
||||
"""
|
||||
def decorator(coro):
|
||||
@wraps(coro)
|
||||
async def wrapped(*w_args, **w_kwargs):
|
||||
name = f"persisted-{coro.__name__}"
|
||||
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
|
||||
task_collection.add(task)
|
||||
task.add_done_callback(lambda f: task_collection.discard(f))
|
||||
await task
|
||||
|
||||
|
||||
RESET_SEQ = "\033[0m"
|
||||
COLOR_SEQ = "\033[3%dm"
|
||||
BOLD_SEQ = "\033[1m"
|
||||
"]]]"
|
||||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
|
||||
|
||||
|
||||
def colour_escape(fmt: str) -> str:
|
||||
cmap = {
|
||||
'%(black)': COLOR_SEQ % BLACK,
|
||||
'%(red)': COLOR_SEQ % RED,
|
||||
'%(green)': COLOR_SEQ % GREEN,
|
||||
'%(yellow)': COLOR_SEQ % YELLOW,
|
||||
'%(blue)': COLOR_SEQ % BLUE,
|
||||
'%(magenta)': COLOR_SEQ % MAGENTA,
|
||||
'%(cyan)': COLOR_SEQ % CYAN,
|
||||
'%(white)': COLOR_SEQ % WHITE,
|
||||
'%(reset)': RESET_SEQ,
|
||||
'%(bold)': BOLD_SEQ,
|
||||
}
|
||||
for key, value in cmap.items():
|
||||
fmt = fmt.replace(key, value)
|
||||
return fmt
|
||||
|
||||
|
||||
log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
|
||||
'%(cyan)%(app)-15s%(reset)|' +
|
||||
'%(cyan)%(context)-24s%(reset)|' +
|
||||
'%(cyan)%(actionstr)-22s%(reset)|' +
|
||||
' %(bold)%(cyan)%(name)s:%(reset)' +
|
||||
' %(white)%(message)s%(ctxstr)s%(reset)')
|
||||
log_format = colour_escape(log_format)
|
||||
|
||||
|
||||
# Setup the logger
|
||||
logger = logging.getLogger()
|
||||
log_fmt = logging.Formatter(
|
||||
fmt=log_format,
|
||||
# datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class LessThanFilter(logging.Filter):
|
||||
def __init__(self, exclusive_maximum, name=""):
|
||||
super(LessThanFilter, self).__init__(name)
|
||||
self.max_level = exclusive_maximum
|
||||
|
||||
def filter(self, record):
|
||||
# non-zero return means we log this message
|
||||
return 1 if record.levelno < self.max_level else 0
|
||||
|
||||
class ExactLevelFilter(logging.Filter):
|
||||
def __init__(self, target_level, name=""):
|
||||
super().__init__(name)
|
||||
self.target_level = target_level
|
||||
|
||||
def filter(self, record):
|
||||
return (record.levelno == self.target_level)
|
||||
|
||||
|
||||
class ThreadFilter(logging.Filter):
|
||||
def __init__(self, thread_name):
|
||||
super().__init__("")
|
||||
self.thread = thread_name
|
||||
|
||||
def filter(self, record):
|
||||
# non-zero return means we log this message
|
||||
return 1 if record.threadName == self.thread else 0
|
||||
|
||||
|
||||
class ContextInjection(logging.Filter):
|
||||
def filter(self, record):
|
||||
# These guards are to allow override through _extra
|
||||
# And to ensure the injection is idempotent
|
||||
if not hasattr(record, 'context'):
|
||||
record.context = log_context.get()
|
||||
|
||||
if not hasattr(record, 'actionstr'):
|
||||
action_stack = log_action_stack.get()
|
||||
if hasattr(record, 'action'):
|
||||
action_stack = (*action_stack, record.action)
|
||||
if action_stack:
|
||||
record.actionstr = ' ➔ '.join(action_stack)
|
||||
else:
|
||||
record.actionstr = "Unknown Action"
|
||||
|
||||
if not hasattr(record, 'app'):
|
||||
record.app = log_app.get()
|
||||
|
||||
if not hasattr(record, 'ctx'):
|
||||
if ctx := context.get():
|
||||
record.ctx = repr(ctx)
|
||||
else:
|
||||
record.ctx = None
|
||||
|
||||
if getattr(record, 'with_ctx', False) and record.ctx:
|
||||
record.ctxstr = '\n' + record.ctx
|
||||
else:
|
||||
record.ctxstr = ""
|
||||
return True
|
||||
|
||||
|
||||
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||
logging_handler_out.setLevel(logging.DEBUG)
|
||||
logging_handler_out.setFormatter(log_fmt)
|
||||
logging_handler_out.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_out)
|
||||
log_logger.addHandler(logging_handler_out)
|
||||
|
||||
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||
logging_handler_err.setLevel(logging.WARNING)
|
||||
logging_handler_err.setFormatter(log_fmt)
|
||||
logging_handler_err.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_err)
|
||||
log_logger.addHandler(logging_handler_err)
|
||||
|
||||
|
||||
class LocalQueueHandler(QueueHandler):
|
||||
def _emit(self, record: logging.LogRecord) -> None:
|
||||
# Removed the call to self.prepare(), handle task cancellation
|
||||
try:
|
||||
self.enqueue(record)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class WebHookHandler(logging.StreamHandler):
|
||||
def __init__(self, webhook_url, prefix="", batch=True, loop=None):
|
||||
super().__init__()
|
||||
self.webhook_url = webhook_url
|
||||
self.prefix = prefix
|
||||
self.batched = ""
|
||||
self.batch = batch
|
||||
self.loop = loop
|
||||
self.batch_delay = 10
|
||||
self.batch_task = None
|
||||
self.last_batched = None
|
||||
self.waiting = []
|
||||
|
||||
self.bucket = Bucket(20, 40)
|
||||
self.ignored = 0
|
||||
|
||||
self.session = None
|
||||
self.webhook = None
|
||||
|
||||
def get_loop(self):
|
||||
if self.loop is None:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
return self.loop
|
||||
|
||||
def emit(self, record):
|
||||
self.format(record)
|
||||
self.get_loop().call_soon_threadsafe(self._post, record)
|
||||
|
||||
def _post(self, record):
|
||||
if self.session is None:
|
||||
self.setup()
|
||||
asyncio.create_task(self.post(record))
|
||||
|
||||
def setup(self):
|
||||
from discord import Webhook
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
|
||||
|
||||
async def post(self, record):
|
||||
if record.context == 'Webhook Logger':
|
||||
# Don't livelog livelog errors
|
||||
# Otherwise we recurse and Cloudflare hates us
|
||||
return
|
||||
log_context.set("Webhook Logger")
|
||||
log_action_stack.set(("Logging",))
|
||||
log_app.set(record.app)
|
||||
|
||||
try:
|
||||
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
|
||||
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
|
||||
context = f"\n# Context: {record.ctx}" if record.ctx else ""
|
||||
message = f"{header}\n{record.msg}{context}"
|
||||
|
||||
if len(message) > 1900:
|
||||
as_file = True
|
||||
else:
|
||||
as_file = False
|
||||
message = "```md\n{}\n```".format(message)
|
||||
|
||||
# Post the log message(s)
|
||||
if self.batch:
|
||||
if len(message) > 1500:
|
||||
await self._send_batched_now()
|
||||
await self._send(message, as_file=as_file)
|
||||
else:
|
||||
self.batched += message
|
||||
if len(self.batched) + len(message) > 1500:
|
||||
await self._send_batched_now()
|
||||
else:
|
||||
asyncio.create_task(self._schedule_batched())
|
||||
else:
|
||||
await self._send(message, as_file=as_file)
|
||||
except Exception as ex:
|
||||
print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)
|
||||
|
||||
async def _schedule_batched(self):
|
||||
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
|
||||
# noop, don't reschedule if it is already scheduled
|
||||
return
|
||||
try:
|
||||
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
|
||||
await self.batch_task
|
||||
await self._send_batched()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as ex:
|
||||
print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)
|
||||
|
||||
async def _send_batched_now(self):
|
||||
if self.batch_task is not None and not self.batch_task.done():
|
||||
self.batch_task.cancel()
|
||||
self.last_batched = None
|
||||
await self._send_batched()
|
||||
|
||||
async def _send_batched(self):
|
||||
if self.batched:
|
||||
batched = self.batched
|
||||
self.batched = ""
|
||||
await self._send(batched)
|
||||
|
||||
async def _send(self, message, as_file=False):
|
||||
try:
|
||||
self.bucket.request()
|
||||
except BucketOverFull:
|
||||
# Silently ignore
|
||||
self.ignored += 1
|
||||
return
|
||||
except BucketFull:
|
||||
logger.warning(
|
||||
"Can't keep up! "
|
||||
f"Ignoring records on live-logger {self.webhook.id}."
|
||||
)
|
||||
self.ignored += 1
|
||||
return
|
||||
else:
|
||||
if self.ignored > 0:
|
||||
logger.warning(
|
||||
"Can't keep up! "
|
||||
f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
|
||||
)
|
||||
self.ignored = 0
|
||||
|
||||
try:
|
||||
if as_file or len(message) > 1900:
|
||||
with StringIO(message) as fp:
|
||||
fp.seek(0)
|
||||
await self.webhook.send(
|
||||
f"{self.prefix}\n`{message.splitlines()[0]}`",
|
||||
file=File(fp, filename="logs.md"),
|
||||
username=log_app.get()
|
||||
)
|
||||
else:
|
||||
await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
|
||||
except discord.HTTPException:
|
||||
logger.exception(
|
||||
"Live logger errored. Slowing down live logger."
|
||||
)
|
||||
self.bucket.fill()
|
||||
|
||||
|
||||
handlers = []
|
||||
if webhook := conf.logging['general_log']:
|
||||
handler = WebHookHandler(webhook, batch=True)
|
||||
handlers.append(handler)
|
||||
|
||||
if webhook := conf.logging['warning_log']:
|
||||
handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True)
|
||||
handler.addFilter(ExactLevelFilter(logging.WARNING))
|
||||
handler.setLevel(logging.WARNING)
|
||||
handlers.append(handler)
|
||||
|
||||
if webhook := conf.logging['error_log']:
|
||||
handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True)
|
||||
handler.setLevel(logging.ERROR)
|
||||
handlers.append(handler)
|
||||
|
||||
if webhook := conf.logging['critical_log']:
|
||||
handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False)
|
||||
handler.setLevel(logging.CRITICAL)
|
||||
handlers.append(handler)
|
||||
|
||||
|
||||
def make_queue_handler(queue):
|
||||
qhandler = QueueHandler(queue)
|
||||
qhandler.setLevel(logging.INFO)
|
||||
qhandler.addFilter(ContextInjection())
|
||||
return qhandler
|
||||
|
||||
|
||||
def setup_main_logger(multiprocess=False):
|
||||
q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
|
||||
if handlers:
|
||||
# First create a separate loop to run the handlers on
|
||||
import threading
|
||||
|
||||
def run_loop(loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_forever()
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.close()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop_thread = threading.Thread(target=lambda: run_loop(loop))
|
||||
loop_thread.daemon = True
|
||||
loop_thread.start()
|
||||
|
||||
for handler in handlers:
|
||||
handler.loop = loop
|
||||
|
||||
qhandler = make_queue_handler(q)
|
||||
# qhandler.addFilter(ThreadFilter('MainThread'))
|
||||
logger.addHandler(qhandler)
|
||||
|
||||
listener = QueueListener(
|
||||
q, *handlers, respect_handler_level=True
|
||||
)
|
||||
listener.start()
|
||||
return q
|
||||
11
src/modules/__init__.py
Normal file
11
src/modules/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from meta import Bot
|
||||
|
||||
|
||||
async def twitch_setup(bot: 'Bot'):
|
||||
# Import and run setup methods from each module
|
||||
# from . import module
|
||||
# await module.twitch_setup(bot)
|
||||
pass
|
||||
88
src/utils/lib.py
Normal file
88
src/utils/lib.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import re
|
||||
import datetime as dt
|
||||
|
||||
|
||||
def strfdelta(delta: dt.timedelta, sec=False, minutes=True, short=False) -> str:
|
||||
"""
|
||||
Convert a datetime.timedelta object into an easily readable duration string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delta: datetime.timedelta
|
||||
The timedelta object to convert into a readable string.
|
||||
sec: bool
|
||||
Whether to include the seconds from the timedelta object in the string.
|
||||
minutes: bool
|
||||
Whether to include the minutes from the timedelta object in the string.
|
||||
short: bool
|
||||
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
|
||||
|
||||
Returns: str
|
||||
A string containing a time from the datetime.timedelta object, in a readable format.
|
||||
Time units will be abbreviated if short was set to True.
|
||||
"""
|
||||
output = [[delta.days, 'd' if short else ' day'],
|
||||
[delta.seconds // 3600, 'h' if short else ' hour']]
|
||||
if minutes:
|
||||
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
|
||||
if sec:
|
||||
output.append([delta.seconds % 60, 's' if short else ' second'])
|
||||
for i in range(len(output)):
|
||||
if output[i][0] != 1 and not short:
|
||||
output[i][1] += 's' # type: ignore
|
||||
reply_msg = []
|
||||
if output[0][0] != 0:
|
||||
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
|
||||
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
|
||||
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
|
||||
for i in range(2, len(output) - 1):
|
||||
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
|
||||
if not short and reply_msg:
|
||||
reply_msg.append("and ")
|
||||
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
|
||||
return "".join(reply_msg)
|
||||
|
||||
def utc_now() -> dt.datetime:
|
||||
"""
|
||||
Return the current timezone-aware utc timestamp.
|
||||
"""
|
||||
return dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)
|
||||
|
||||
def replace_multiple(format_string, mapping):
|
||||
"""
|
||||
Subsistutes the keys from the format_dict with their corresponding values.
|
||||
|
||||
Substitution is non-chained, and done in a single pass via regex.
|
||||
"""
|
||||
if not mapping:
|
||||
raise ValueError("Empty mapping passed.")
|
||||
|
||||
keys = list(mapping.keys())
|
||||
pattern = '|'.join(f"({key})" for key in keys)
|
||||
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
|
||||
return string
|
||||
|
||||
|
||||
def parse_dur(time_str):
|
||||
"""
|
||||
Parses a user provided time duration string into a number of seconds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_str: str
|
||||
The time string to parse. String can include days, hours, minutes, and seconds.
|
||||
|
||||
Returns: int
|
||||
The number of seconds the duration represents.
|
||||
"""
|
||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||
'h': lambda x: x * 60 * 60,
|
||||
'm': lambda x: x * 60,
|
||||
's': lambda x: x}
|
||||
time_str = time_str.strip(" ,")
|
||||
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
|
||||
seconds = 0
|
||||
for bit in found:
|
||||
if bit[1] in funcs:
|
||||
seconds += funcs[bit[1]](int(bit[0]))
|
||||
return seconds
|
||||
166
src/utils/ratelimits.py
Normal file
166
src/utils/ratelimits.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
|
||||
class BucketFull(Exception):
|
||||
"""
|
||||
Throw when a requested Bucket is already full
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BucketOverFull(BucketFull):
|
||||
"""
|
||||
Throw when a requested Bucket is overfull
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Bucket:
|
||||
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
|
||||
|
||||
def __init__(self, max_level, empty_time):
|
||||
self.max_level = max_level
|
||||
self.empty_time = empty_time
|
||||
self.leak_rate = max_level / empty_time
|
||||
|
||||
self._level = 0
|
||||
self._last_checked = time.monotonic()
|
||||
|
||||
self._last_full = False
|
||||
self._wait_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def full(self) -> bool:
|
||||
"""
|
||||
Return whether the bucket is 'full',
|
||||
that is, whether an immediate request against the bucket will raise `BucketFull`.
|
||||
"""
|
||||
self._leak()
|
||||
return self._level + 1 > self.max_level
|
||||
|
||||
@property
|
||||
def overfull(self):
|
||||
self._leak()
|
||||
return self._level > self.max_level
|
||||
|
||||
@property
|
||||
def delay(self):
|
||||
self._leak()
|
||||
if self._level + 1 > self.max_level:
|
||||
delay = (self._level + 1 - self.max_level) * self.leak_rate
|
||||
else:
|
||||
delay = 0
|
||||
return delay
|
||||
|
||||
def _leak(self):
|
||||
if self._level:
|
||||
elapsed = time.monotonic() - self._last_checked
|
||||
self._level = max(0, self._level - (elapsed * self.leak_rate))
|
||||
|
||||
self._last_checked = time.monotonic()
|
||||
|
||||
def request(self):
|
||||
self._leak()
|
||||
if self._level > self.max_level:
|
||||
raise BucketOverFull
|
||||
elif self._level == self.max_level:
|
||||
self._level += 1
|
||||
if self._last_full:
|
||||
raise BucketOverFull
|
||||
else:
|
||||
self._last_full = True
|
||||
raise BucketFull
|
||||
else:
|
||||
self._last_full = False
|
||||
self._level += 1
|
||||
|
||||
def fill(self):
|
||||
self._leak()
|
||||
self._level = max(self._level, self.max_level + 1)
|
||||
|
||||
async def wait(self):
|
||||
"""
|
||||
Wait until the bucket has room.
|
||||
|
||||
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
|
||||
"""
|
||||
# Wrapped in a lock so that waiters are correctly handled in wait-order
|
||||
# Otherwise multiple waiters will have the same delay,
|
||||
# and race for the wakeup after sleep.
|
||||
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
|
||||
async with self._wait_lock:
|
||||
# We do this in a loop in case asyncio.sleep throws us out early,
|
||||
# or a synchronous request overflows the bucket while we are waiting.
|
||||
while self.full:
|
||||
await asyncio.sleep(self.delay)
|
||||
|
||||
async def wrapped(self, coro):
|
||||
await self.wait()
|
||||
self.request()
|
||||
await coro
|
||||
|
||||
|
||||
class RateLimit:
|
||||
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
|
||||
self.max_level = max_level
|
||||
self.empty_time = empty_time
|
||||
|
||||
self.error = error or "Too many requests, please slow down!"
|
||||
self.buckets = cache
|
||||
|
||||
def request_for(self, key):
|
||||
if not (bucket := self.buckets.get(key, None)):
|
||||
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
|
||||
|
||||
bucket.request()
|
||||
|
||||
def ward(self, member=True, key=None):
|
||||
"""
|
||||
Command ratelimit decorator.
|
||||
"""
|
||||
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
|
||||
|
||||
def decorator(func):
|
||||
async def wrapper(ctx, *args, **kwargs):
|
||||
self.request_for(key(ctx))
|
||||
return await func(ctx, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
async def limit_concurrency(aws, limit):
|
||||
"""
|
||||
Run provided awaitables concurrently,
|
||||
ensuring that no more than `limit` are running at once.
|
||||
"""
|
||||
aws = iter(aws)
|
||||
aws_ended = False
|
||||
pending = set()
|
||||
count = 0
|
||||
logger.debug("Starting limited concurrency executor")
|
||||
|
||||
while pending or not aws_ended:
|
||||
while len(pending) < limit and not aws_ended:
|
||||
aw = next(aws, None)
|
||||
if aw is None:
|
||||
aws_ended = True
|
||||
else:
|
||||
pending.add(asyncio.create_task(aw))
|
||||
count += 1
|
||||
|
||||
if not pending:
|
||||
break
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
while done:
|
||||
yield done.pop()
|
||||
logger.debug(f"Completed {count} tasks")
|
||||
Reference in New Issue
Block a user