From 7457ae6ac1edb56782d048aebd824afe42609866 Mon Sep 17 00:00:00 2001
From: Interitio
Date: Fri, 1 Aug 2025 00:05:48 +1000
Subject: [PATCH] Initial Template

---
 .gitmodules                 |   3 +
 config/example-bot.conf     |  15 ++
 config/example-secrets.conf |   7 +
 data/.gitignore             |   0
 data/schema.sql             |  63 +++++
 requirements.txt            |   4 +
 scripts/start_bot.py        |  12 +
 src/bot.py                  |  42 ++++
 src/botdata.py              |  82 +++++++
 src/constants.py            |  22 ++
 src/data                    |   1 +
 src/meta/__init__.py        |   4 +
 src/meta/args.py            |  28 +++
 src/meta/bot.py             | 174 ++++++++++++++
 src/meta/config.py          | 105 ++++++++
 src/meta/logger.py          | 466 ++++++++++++++++++++++++++++++++++++
 src/modules/__init__.py     |  11 +
 src/utils/lib.py            |  88 +++++++
 src/utils/ratelimits.py     | 166 +++++++++++++
 19 files changed, 1293 insertions(+)
 create mode 100644 .gitmodules
 create mode 100644 config/example-bot.conf
 create mode 100644 config/example-secrets.conf
 create mode 100644 data/.gitignore
 create mode 100644 data/schema.sql
 create mode 100644 requirements.txt
 create mode 100644 scripts/start_bot.py
 create mode 100644 src/bot.py
 create mode 100644 src/botdata.py
 create mode 100644 src/constants.py
 create mode 160000 src/data
 create mode 100644 src/meta/__init__.py
 create mode 100644 src/meta/args.py
 create mode 100644 src/meta/bot.py
 create mode 100644 src/meta/config.py
 create mode 100644 src/meta/logger.py
 create mode 100644 src/modules/__init__.py
 create mode 100644 src/utils/lib.py
 create mode 100644 src/utils/ratelimits.py

diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..766cecf
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "src/data"]
+	path = src/data
+	url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git
diff --git a/config/example-bot.conf b/config/example-bot.conf
new file mode 100644
index 0000000..1afa4a8
--- /dev/null
+++ b/config/example-bot.conf
@@ -0,0 +1,15 @@
+[BOT]
+prefix = ?
+bot_id =
+
+ALSO_READ = config/secrets.conf
+
+wshost = localhost
+wsport = 4343
+wsdomain = localhost:4343
+
+[LOGGING]
+general_log =
+warning_log =
+error_log =
+critical_log =
diff --git a/config/example-secrets.conf b/config/example-secrets.conf
new file mode 100644
index 0000000..f931031
--- /dev/null
+++ b/config/example-secrets.conf
@@ -0,0 +1,7 @@
+[CROCBOT]
+client_id =
+client_secret =
+
+[DATA]
+args =
+appid =
diff --git a/data/.gitignore b/data/.gitignore
new file mode 100644
index 0000000..e69de29
diff --git a/data/schema.sql b/data/schema.sql
new file mode 100644
index 0000000..40f6d0a
--- /dev/null
+++ b/data/schema.sql
@@ -0,0 +1,63 @@
+-- Metadata {{{
+CREATE TABLE version_history(
+    component TEXT NOT NULL,
+    from_version INTEGER NOT NULL,
+    to_version INTEGER NOT NULL,
+    author TEXT NOT NULL,
+    _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation');
+
+
+CREATE OR REPLACE FUNCTION update_timestamp_column()
+RETURNS TRIGGER AS $$
+BEGIN
+    NEW._timestamp = now();
+    RETURN NEW;
+END;
+$$ language 'plpgsql';
+
+CREATE TABLE app_config(
+    appname TEXT PRIMARY KEY,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT now()
+);
+
+-- }}}
+
+-- Twitch Auth {{{
+INSERT INTO version_history (component, from_version, to_version, author) VALUES ('TWITCH_AUTH', 0, 1, 'Initial Creation');
+
+-- Authorisation tokens allowing us to take actions on behalf of certain users or channels.
+-- For example, channels we have joined will need to be authorised with a 'channel:bot' scope.
+CREATE TABLE user_auth(
+    userid TEXT PRIMARY KEY,
+    token TEXT NOT NULL,
+    refresh_token TEXT NOT NULL,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+CREATE TRIGGER user_auth_timestamp BEFORE UPDATE ON user_auth
+    FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
+
+CREATE TABLE user_auth_scopes(
+    userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
+    scope TEXT NOT NULL
+);
+
+-- Which channels will be joined at startup,
+-- and any configurable choices needed when joining the channel
+CREATE TABLE bot_channels(
+    userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
+    autojoin BOOLEAN DEFAULT true,
+    listen_redeems BOOLEAN,
+    joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+CREATE TRIGGER bot_channels_timestamp BEFORE UPDATE ON bot_channels
+    FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
+
+
+-- }}}
+
+
+
+-- vim: set fdm=marker:
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..48f9b57
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+twitchio
+psycopg[pool]
+cachetools
+discord.py
diff --git a/scripts/start_bot.py b/scripts/start_bot.py
new file mode 100644
index 0000000..49d6ad1
--- /dev/null
+++ b/scripts/start_bot.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.getcwd()))
+sys.path.insert(0, os.path.join(os.getcwd(), "src"))
+
+
+if __name__ == '__main__':
+    from bot import _main
+    _main()
diff --git a/src/bot.py b/src/bot.py
new file mode 100644
index 0000000..e99283e
--- /dev/null
+++ b/src/bot.py
@@ -0,0 +1,42 @@
+import asyncio
+import logging
+
+from twitchio.web import AiohttpAdapter
+
+from meta import Bot, conf, setup_main_logger, args
+from data import Database
+
+from modules import twitch_setup
+
+logger = logging.getLogger(__name__)
+
+
+async def main():
+    db = Database(conf.data['args'])
+
+    async with db.open():
+        adapter = AiohttpAdapter(
+            host=conf.bot.get('wshost', None),
+            port=conf.bot.getint('wsport', None),
+            domain=conf.bot.get('wsdomain', None),
+            eventsub_secret=conf.bot.get('eventsub_secret', None)
+        )
+
+        bot = Bot(
+            config=conf,
+            dbconn=db,
+            adapter=adapter,
+            setup=twitch_setup,
+        )
+
+        try:
+            await bot.start()
+        finally:
+            await bot.close()
+
+
+def _main():
+    setup_main_logger()
+
+    asyncio.run(main())
diff --git a/src/botdata.py b/src/botdata.py
new file mode 100644
index 0000000..8287b5f
--- /dev/null
+++ b/src/botdata.py
@@ -0,0 +1,82 @@
+from data import Registry, RowModel, Table
+from data.columns import String, Timestamp, Integer, Bool
+
+
+class UserAuth(RowModel):
+    """
+    Schema
+    ======
+    CREATE TABLE user_auth(
+        userid TEXT PRIMARY KEY,
+        token TEXT NOT NULL,
+        refresh_token TEXT NOT NULL,
+        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+        _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
+    );
+    """
+    _tablename_ = 'user_auth'
+    _cache_ = {}
+
+    userid = String(primary=True)
+    token = String()
+    refresh_token = String()
+    created_at = Timestamp()
+    _timestamp = Timestamp()
+
+
+class BotChannel(RowModel):
+    """
+    Schema
+    ======
+    CREATE TABLE bot_channels(
+        userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
+        autojoin BOOLEAN DEFAULT true,
+        listen_redeems BOOLEAN,
+        joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+        _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
+    );
+    """
+    _tablename_ = 'bot_channels'
+    _cache_ = {}
+
+    userid = String(primary=True)
+    autojoin = Bool()
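+    # Whether to react to channel point redemptions in this channel
+    # (assumed meaning; nothing else in this template reads the flag yet)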
listen_redeems = Bool() + joined_at = Timestamp() + _timestamp = Timestamp() + + +class VersionHistory(RowModel): + """ + CREATE TABLE version_history( + component TEXT NOT NULL, + from_version INTEGER NOT NULL, + to_version INTEGER NOT NULL, + author TEXT NOT NULL, + _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ); + """ + _tablename_ = 'version_history' + _cache_ = {} + + component = String() + from_version = Integer() + to_version = Integer() + author = String() + _timestamp = Timestamp() + + +class BotData(Registry): + version_history = VersionHistory.table + + user_auth = UserAuth.table + bot_channels = BotChannel.table + + """ + CREATE TABLE user_auth_scopes( + userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE, + scope TEXT NOT NULL + ); + """ + user_auth_scopes = Table('user_auth_scopes') + diff --git a/src/constants.py b/src/constants.py new file mode 100644 index 0000000..26ae468 --- /dev/null +++ b/src/constants.py @@ -0,0 +1,22 @@ +from twitchio import Scopes + + +CONFIG_FILE = 'config/bot.conf' + +SCHEMA_VERSIONS = { + 'ROOT': 1, + 'TWITCH_AUTH': 1 +} + +# Requested scopes for the bots own twitch user +BOTUSER_SCOPES = Scopes(( + Scopes.user_read_chat, + Scopes.user_write_chat, + Scopes.user_bot, + Scopes.channel_bot, +)) + +# Default requested scopes for joining a channel +CHANNEL_SCOPES = Scopes(( + Scopes.channel_bot, +)) diff --git a/src/data b/src/data new file mode 160000 index 0000000..cfdfe0e --- /dev/null +++ b/src/data @@ -0,0 +1 @@ +Subproject commit cfdfe0eb50034d54a08c8449e8a62a5b8854e259 diff --git a/src/meta/__init__.py b/src/meta/__init__.py new file mode 100644 index 0000000..c01addd --- /dev/null +++ b/src/meta/__init__.py @@ -0,0 +1,4 @@ +from .args import args +from .bot import Bot +from .config import Conf, conf +from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task diff --git a/src/meta/args.py b/src/meta/args.py new file mode 100644 index 0000000..ea4978f --- /dev/null +++ b/src/meta/args.py @@ -0,0 +1,28 @@ +import argparse + +from constants import CONFIG_FILE + +# ------------------------------ +# Parsed commandline arguments +# ------------------------------ +parser = argparse.ArgumentParser() +parser.add_argument( + '--conf', + dest='config', + default=CONFIG_FILE, + help="Path to configuration file." +) +parser.add_argument( + '--host', + dest='host', + default='127.0.0.1', + help="IP address to run the websocket server on." +) +parser.add_argument( + '--port', + dest='port', + default='5001', + help="Port to run the websocket server on." 
+) + +args = parser.parse_args() diff --git a/src/meta/bot.py b/src/meta/bot.py new file mode 100644 index 0000000..79cf7f4 --- /dev/null +++ b/src/meta/bot.py @@ -0,0 +1,174 @@ +import logging +from typing import Optional + +from twitchio.authentication import UserTokenPayload +from twitchio.ext import commands +from twitchio import Scopes, eventsub + +from data import Database, ORDER +from botdata import BotData, UserAuth, BotChannel, VersionHistory +from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS + +from .config import Conf + + +logger = logging.getLogger(__name__) + + +class Bot(commands.Bot): + def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs): + kwargs.setdefault('client_id', config.bot['client_id']) + kwargs.setdefault('client_secret', config.bot['client_secret']) + kwargs.setdefault('bot_id', config.bot['bot_id']) + kwargs.setdefault('prefix', config.bot['prefix']) + + super().__init__(*args, **kwargs) + + if config.bot.get('eventsub_secret', None): + self.using_webhooks = True + else: + self.using_webhooks = False + + self.config = config + self.dbconn = dbconn + self.data: BotData = dbconn.load_registry(BotData()) + self._setup_hook = setup + + self.joined: dict[str, BotChannel] = {} + + async def event_ready(self): + # logger.info(f"Logged in as {self.nick}. User id is {self.user_id}") + logger.info("Logged in as %s", self.bot_id) + + async def version_check(self, component: str, req_version: int): + # Query the database to confirm that the given component is listed with the given version. + # Typically done upon loading a component + rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1) + + version = rows[0].to_version if rows else 0 + + if version != req_version: + raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'") + else: + logger.debug( + "Component %s passed version check with version %s", + component, + version + ) + return True + + async def setup_hook(self): + await self.data.init() + for component, req in SCHEMA_VERSIONS.items(): + await self.version_check(component, req) + + if self._setup_hook is not None: + await self._setup_hook(self) + + # Get all current bot channels + channels = await BotChannel.fetch_where(autojoin=True) + + # Join the channels + await self.join_channels(*channels) + + # Build bot account's own url + scopes = BOTUSER_SCOPES + url = self.get_auth_url(scopes) + logger.info("Bot account authorisation url: %s", url) + + # Build everyone else's url + scopes = CHANNEL_SCOPES + url = self.get_auth_url(scopes) + logger.info("User account authorisation url: %s", url) + + logger.info("Finished setup") + + def get_auth_url(self, scopes: Optional[Scopes] = None): + if scopes is None: + scopes = Scopes((Scopes.channel_bot,)) + + url = self._adapter.get_authorization_url(scopes=scopes) + return url + + async def join_channels(self, *channels: BotChannel): + """ + Register webhook subscriptions to the given channel(s). 
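+
+        Uses a webhook subscription when an eventsub secret is configured
+        (i.e. when `self.using_webhooks` is set), otherwise falls back to a
+        websocket subscription.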
+ """ + # TODO: If channels are already joined, unsubscribe + for channel in channels: + sub = None + try: + sub = eventsub.ChatMessageSubscription( + broadcaster_user_id=channel.userid, + user_id=self.bot_id, + ) + if self.using_webhooks: + resp = await self.subscribe_webhook(sub) + else: + resp = await self.subscribe_websocket(sub) + logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp) + self.joined[channel.userid] = channel + self.safe_dispatch('channel_joined', payload=channel) + except Exception: + logger.exception("Failed to subscribe to %s with %s", channel.userid, sub) + + async def event_oauth_authorized(self, payload: UserTokenPayload): + logger.debug("Oauth flow authorization with payload %s", repr(payload)) + # Save the token and scopes and update internal authorisations + resp = await self.add_token(payload.access_token, payload.refresh_token) + if resp.user_id is None: + logger.warning( + "Oauth flow recieved with no user_id. Payload was: %s", + repr(payload) + ) + return + + # If the scopes authorised included channel:bot, ensure a BotChannel exists + # And join it if needed + if Scopes.channel_bot.value in resp.scopes: + bot_channel = await BotChannel.fetch_or_create( + resp.user_id, + autojoin=True, + ) + if bot_channel.autojoin: + await self.join_channels(bot_channel) + + logger.info("Oauth flow authorization complete for payload %s", repr(payload)) + + async def add_token(self, token: str, refresh: str): + # Update the tokens in internal cache + # This also validates the token + # And hopefully gets the userid and scopes + resp = await super().add_token(token, refresh) + if resp.user_id is None: + logger.warning( + "Added a token with no user_id. Response was: %s", + repr(resp) + ) + return resp + userid = resp.user_id + new_scopes = resp.scopes + + # Save the token and scopes to data + # Wrap this in a transaction so if it fails halfway we rollback correctly + async with self.dbconn.connection() as conn: + self.dbconn.conn = conn + async with conn.transaction(): + row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh) + if row.token != token or row.refresh_token != refresh: + await row.update(token=token, refresh_token=refresh) + await self.data.user_auth_scopes.delete_where(userid=userid) + await self.data.user_auth_scopes.insert_many( + ('userid', 'scope'), + *((userid, scope) for scope in new_scopes) + ) + + logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes)) + return resp + + async def load_tokens(self, path: str | None = None): + for row in await UserAuth.fetch_where(): + try: + await self.add_token(row.token, row.refresh_token) + except Exception: + logger.exception(f"Failed to add token for {row}") diff --git a/src/meta/config.py b/src/meta/config.py new file mode 100644 index 0000000..b842fc5 --- /dev/null +++ b/src/meta/config.py @@ -0,0 +1,105 @@ +import configparser as cfgp + +from .args import args + + +class MapDotProxy: + """ + Allows dot access to an underlying Mappable object. 
+ """ + __slots__ = ("_map", "_converter") + + def __init__(self, mappable, converter=None): + self._map = mappable + self._converter = converter + + def __getattribute__(self, key): + _map = object.__getattribute__(self, '_map') + if key == '_map': + return _map + if key in _map: + _converter = object.__getattribute__(self, '_converter') + if _converter: + return _converter(_map[key]) + else: + return _map[key] + else: + return object.__getattribute__(_map, key) + + def __getitem__(self, key): + return self._map.__getitem__(key) + + +class ConfigParser(cfgp.ConfigParser): + """ + Extension of base ConfigParser allowing optional + section option retrieval without defaults. + """ + def options(self, section, no_defaults=False, **kwargs): + if no_defaults: + try: + return list(self._sections[section].keys()) + except KeyError: + raise cfgp.NoSectionError(section) + else: + return super().options(section, **kwargs) + + +class Conf: + def __init__(self, configfile, section_name="DEFAULT"): + self.configfile = configfile + + self.config = ConfigParser( + converters={ + "intlist": self._getintlist, + "list": self._getlist, + } + ) + + with open(configfile) as conff: + # Opening with read_file mainly to ensure the file exists + self.config.read_file(conff) + + self.section_name = section_name if section_name in self.config else 'DEFAULT' + + self.default = self.config["DEFAULT"] + self.section = MapDotProxy(self.config[self.section_name]) + self.bot = self.section + + # Config file recursion, read in configuration files specified in every "ALSO_READ" key. + more_to_read = self.section.getlist("ALSO_READ", []) + read = set() + while more_to_read: + to_read = more_to_read.pop(0) + read.add(to_read) + self.config.read(to_read) + new_paths = [path for path in self.section.getlist("ALSO_READ", []) + if path not in read and path not in more_to_read] + more_to_read.extend(new_paths) + + global conf + conf = self + + def __getitem__(self, key): + return self.section[key].strip() + + def __getattr__(self, section): + name = section.upper() + return self.config[name] + + def get(self, name, fallback=None): + result = self.section.get(name, fallback) + return result.strip() if result else result + + def _getintlist(self, value): + return [int(item.strip()) for item in value.split(',')] + + def _getlist(self, value): + return [item.strip() for item in value.split(',')] + + def write(self): + with open(self.configfile, 'w') as conffile: + self.config.write(conffile) + + +conf = Conf(args.config, 'BOT') diff --git a/src/meta/logger.py b/src/meta/logger.py new file mode 100644 index 0000000..d2e7cb3 --- /dev/null +++ b/src/meta/logger.py @@ -0,0 +1,466 @@ +import sys +import logging +import asyncio +from typing import List, Optional +from logging.handlers import QueueListener, QueueHandler +import queue +import multiprocessing +from contextlib import contextmanager +from io import StringIO +from functools import wraps +from contextvars import ContextVar + +import aiohttp + +from .config import conf +from utils.ratelimits import Bucket, BucketOverFull, BucketFull + + +log_logger = logging.getLogger(__name__) +log_logger.propagate = False + + +log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') +log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=()) +log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT") + +context: ContextVar[Optional[str]] = ContextVar('context', default=None) + +def set_logging_context( + context: 
Optional[str] = None, + action: Optional[str] = None, + stack: Optional[tuple[str, ...]] = None +): + """ + Statically set the logging context variables to the given values. + + If `action` is given, pushes it onto the `log_action_stack`. + """ + if context is not None: + log_context.set(context) + if action is not None or stack is not None: + astack = log_action_stack.get() + newstack = stack if stack is not None else astack + if action is not None: + newstack = (*newstack, action) + log_action_stack.set(newstack) + + +@contextmanager +def logging_context(context=None, action=None, stack=None): + """ + Context manager for executing a block of code in a given logging context. + + This context manager should only be used around synchronous code. + This is because async code *may* get cancelled or externally garbage collected, + in which case the finally block will be executed in the wrong context. + See https://github.com/python/cpython/issues/93740 + This can be refactored nicely if this gets merged: + https://github.com/python/cpython/pull/99634 + + (It will not necessarily break on async code, + if the async code can be guaranteed to clean up in its own context.) + """ + if context is not None: + oldcontext = log_context.get() + log_context.set(context) + if action is not None or stack is not None: + astack = log_action_stack.get() + newstack = stack if stack is not None else astack + if action is not None: + newstack = (*newstack, action) + log_action_stack.set(newstack) + try: + yield + finally: + if context is not None: + log_context.set(oldcontext) + if stack is not None or action is not None: + log_action_stack.set(astack) + + +def with_log_ctx(isolate=True, **kwargs): + """ + Execute a coroutine inside a given logging context. + + If `isolate` is true, ensures that context does not leak + outside the coroutine. + + If `isolate` is false, just statically set the context, + which will leak unless the coroutine is + called in an externally copied context. + """ + def decorator(func): + @wraps(func) + async def wrapped(*w_args, **w_kwargs): + if isolate: + with logging_context(**kwargs): + # Task creation will synchronously copy the context + # This is gc safe + name = kwargs.get('action', f"log-wrapped-{func.__name__}") + task = asyncio.create_task(func(*w_args, **w_kwargs), name=name) + return await task + else: + # This will leak context changes + set_logging_context(**kwargs) + return await func(*w_args, **w_kwargs) + return wrapped + return decorator + + +# For backwards compatibility +log_wrap = with_log_ctx + + +def persist_task(task_collection: set): + """ + Coroutine decorator that ensures the coroutine is scheduled as a task + and added to the given task_collection for strong reference + when it is called. + + This is just a hack to handle discord.py events potentially + being unexpectedly garbage collected. + + Since this also implicitly schedules the coroutine as a task when it is called, + the coroutine will also be run inside an isolated context. 
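+
+    A minimal usage sketch (names here are illustrative, not part of this
+    template):
+
+        _event_tasks: set = set()
+
+        @persist_task(_event_tasks)
+        async def on_some_event(payload):
+            ...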
+ """ + def decorator(coro): + @wraps(coro) + async def wrapped(*w_args, **w_kwargs): + name = f"persisted-{coro.__name__}" + task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name) + task_collection.add(task) + task.add_done_callback(lambda f: task_collection.discard(f)) + await task + + +RESET_SEQ = "\033[0m" +COLOR_SEQ = "\033[3%dm" +BOLD_SEQ = "\033[1m" +"]]]" +BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) + + +def colour_escape(fmt: str) -> str: + cmap = { + '%(black)': COLOR_SEQ % BLACK, + '%(red)': COLOR_SEQ % RED, + '%(green)': COLOR_SEQ % GREEN, + '%(yellow)': COLOR_SEQ % YELLOW, + '%(blue)': COLOR_SEQ % BLUE, + '%(magenta)': COLOR_SEQ % MAGENTA, + '%(cyan)': COLOR_SEQ % CYAN, + '%(white)': COLOR_SEQ % WHITE, + '%(reset)': RESET_SEQ, + '%(bold)': BOLD_SEQ, + } + for key, value in cmap.items(): + fmt = fmt.replace(key, value) + return fmt + + +log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' + + '%(cyan)%(app)-15s%(reset)|' + + '%(cyan)%(context)-24s%(reset)|' + + '%(cyan)%(actionstr)-22s%(reset)|' + + ' %(bold)%(cyan)%(name)s:%(reset)' + + ' %(white)%(message)s%(ctxstr)s%(reset)') +log_format = colour_escape(log_format) + + +# Setup the logger +logger = logging.getLogger() +log_fmt = logging.Formatter( + fmt=log_format, + # datefmt='%Y-%m-%d %H:%M:%S' +) +logger.setLevel(logging.NOTSET) + + +class LessThanFilter(logging.Filter): + def __init__(self, exclusive_maximum, name=""): + super(LessThanFilter, self).__init__(name) + self.max_level = exclusive_maximum + + def filter(self, record): + # non-zero return means we log this message + return 1 if record.levelno < self.max_level else 0 + +class ExactLevelFilter(logging.Filter): + def __init__(self, target_level, name=""): + super().__init__(name) + self.target_level = target_level + + def filter(self, record): + return (record.levelno == self.target_level) + + +class ThreadFilter(logging.Filter): + def __init__(self, thread_name): + super().__init__("") + self.thread = thread_name + + def filter(self, record): + # non-zero return means we log this message + return 1 if record.threadName == self.thread else 0 + + +class ContextInjection(logging.Filter): + def filter(self, record): + # These guards are to allow override through _extra + # And to ensure the injection is idempotent + if not hasattr(record, 'context'): + record.context = log_context.get() + + if not hasattr(record, 'actionstr'): + action_stack = log_action_stack.get() + if hasattr(record, 'action'): + action_stack = (*action_stack, record.action) + if action_stack: + record.actionstr = ' ➔ '.join(action_stack) + else: + record.actionstr = "Unknown Action" + + if not hasattr(record, 'app'): + record.app = log_app.get() + + if not hasattr(record, 'ctx'): + if ctx := context.get(): + record.ctx = repr(ctx) + else: + record.ctx = None + + if getattr(record, 'with_ctx', False) and record.ctx: + record.ctxstr = '\n' + record.ctx + else: + record.ctxstr = "" + return True + + +logging_handler_out = logging.StreamHandler(sys.stdout) +logging_handler_out.setLevel(logging.DEBUG) +logging_handler_out.setFormatter(log_fmt) +logging_handler_out.addFilter(ContextInjection()) +logger.addHandler(logging_handler_out) +log_logger.addHandler(logging_handler_out) + +logging_handler_err = logging.StreamHandler(sys.stderr) +logging_handler_err.setLevel(logging.WARNING) +logging_handler_err.setFormatter(log_fmt) +logging_handler_err.addFilter(ContextInjection()) +logger.addHandler(logging_handler_err) 
+log_logger.addHandler(logging_handler_err) + + +class LocalQueueHandler(QueueHandler): + def _emit(self, record: logging.LogRecord) -> None: + # Removed the call to self.prepare(), handle task cancellation + try: + self.enqueue(record) + except asyncio.CancelledError: + raise + except Exception: + self.handleError(record) + + +class WebHookHandler(logging.StreamHandler): + def __init__(self, webhook_url, prefix="", batch=True, loop=None): + super().__init__() + self.webhook_url = webhook_url + self.prefix = prefix + self.batched = "" + self.batch = batch + self.loop = loop + self.batch_delay = 10 + self.batch_task = None + self.last_batched = None + self.waiting = [] + + self.bucket = Bucket(20, 40) + self.ignored = 0 + + self.session = None + self.webhook = None + + def get_loop(self): + if self.loop is None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + return self.loop + + def emit(self, record): + self.format(record) + self.get_loop().call_soon_threadsafe(self._post, record) + + def _post(self, record): + if self.session is None: + self.setup() + asyncio.create_task(self.post(record)) + + def setup(self): + from discord import Webhook + self.session = aiohttp.ClientSession() + self.webhook = Webhook.from_url(self.webhook_url, session=self.session) + + async def post(self, record): + if record.context == 'Webhook Logger': + # Don't livelog livelog errors + # Otherwise we recurse and Cloudflare hates us + return + log_context.set("Webhook Logger") + log_action_stack.set(("Logging",)) + log_app.set(record.app) + + try: + timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S") + header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>" + context = f"\n# Context: {record.ctx}" if record.ctx else "" + message = f"{header}\n{record.msg}{context}" + + if len(message) > 1900: + as_file = True + else: + as_file = False + message = "```md\n{}\n```".format(message) + + # Post the log message(s) + if self.batch: + if len(message) > 1500: + await self._send_batched_now() + await self._send(message, as_file=as_file) + else: + self.batched += message + if len(self.batched) + len(message) > 1500: + await self._send_batched_now() + else: + asyncio.create_task(self._schedule_batched()) + else: + await self._send(message, as_file=as_file) + except Exception as ex: + print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr) + + async def _schedule_batched(self): + if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()): + # noop, don't reschedule if it is already scheduled + return + try: + self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay)) + await self.batch_task + await self._send_batched() + except asyncio.CancelledError: + return + except Exception as ex: + print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr) + + async def _send_batched_now(self): + if self.batch_task is not None and not self.batch_task.done(): + self.batch_task.cancel() + self.last_batched = None + await self._send_batched() + + async def _send_batched(self): + if self.batched: + batched = self.batched + self.batched = "" + await self._send(batched) + + async def _send(self, message, as_file=False): + try: + self.bucket.request() + except BucketOverFull: + # Silently ignore + self.ignored += 1 + return + except BucketFull: + logger.warning( + "Can't keep up! " + f"Ignoring records on live-logger {self.webhook.id}." 
+ ) + self.ignored += 1 + return + else: + if self.ignored > 0: + logger.warning( + "Can't keep up! " + f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing." + ) + self.ignored = 0 + + try: + if as_file or len(message) > 1900: + with StringIO(message) as fp: + fp.seek(0) + await self.webhook.send( + f"{self.prefix}\n`{message.splitlines()[0]}`", + file=File(fp, filename="logs.md"), + username=log_app.get() + ) + else: + await self.webhook.send(self.prefix + '\n' + message, username=log_app.get()) + except discord.HTTPException: + logger.exception( + "Live logger errored. Slowing down live logger." + ) + self.bucket.fill() + + +handlers = [] +if webhook := conf.logging['general_log']: + handler = WebHookHandler(webhook, batch=True) + handlers.append(handler) + +if webhook := conf.logging['warning_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True) + handler.addFilter(ExactLevelFilter(logging.WARNING)) + handler.setLevel(logging.WARNING) + handlers.append(handler) + +if webhook := conf.logging['error_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True) + handler.setLevel(logging.ERROR) + handlers.append(handler) + +if webhook := conf.logging['critical_log']: + handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False) + handler.setLevel(logging.CRITICAL) + handlers.append(handler) + + +def make_queue_handler(queue): + qhandler = QueueHandler(queue) + qhandler.setLevel(logging.INFO) + qhandler.addFilter(ContextInjection()) + return qhandler + + +def setup_main_logger(multiprocess=False): + q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue() + if handlers: + # First create a separate loop to run the handlers on + import threading + + def run_loop(loop): + asyncio.set_event_loop(loop) + try: + loop.run_forever() + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + loop = asyncio.new_event_loop() + loop_thread = threading.Thread(target=lambda: run_loop(loop)) + loop_thread.daemon = True + loop_thread.start() + + for handler in handlers: + handler.loop = loop + + qhandler = make_queue_handler(q) + # qhandler.addFilter(ThreadFilter('MainThread')) + logger.addHandler(qhandler) + + listener = QueueListener( + q, *handlers, respect_handler_level=True + ) + listener.start() + return q diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..96ee6bc --- /dev/null +++ b/src/modules/__init__.py @@ -0,0 +1,11 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from meta import Bot + + +async def twitch_setup(bot: 'Bot'): + # Import and run setup methods from each module + # from . import module + # await module.twitch_setup(bot) + pass diff --git a/src/utils/lib.py b/src/utils/lib.py new file mode 100644 index 0000000..6eb2a79 --- /dev/null +++ b/src/utils/lib.py @@ -0,0 +1,88 @@ +import re +import datetime as dt + + +def strfdelta(delta: dt.timedelta, sec=False, minutes=True, short=False) -> str: + """ + Convert a datetime.timedelta object into an easily readable duration string. + + Parameters + ---------- + delta: datetime.timedelta + The timedelta object to convert into a readable string. + sec: bool + Whether to include the seconds from the timedelta object in the string. + minutes: bool + Whether to include the minutes from the timedelta object in the string. 
+ short: bool + Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s"). + + Returns: str + A string containing a time from the datetime.timedelta object, in a readable format. + Time units will be abbreviated if short was set to True. + """ + output = [[delta.days, 'd' if short else ' day'], + [delta.seconds // 3600, 'h' if short else ' hour']] + if minutes: + output.append([delta.seconds // 60 % 60, 'm' if short else ' minute']) + if sec: + output.append([delta.seconds % 60, 's' if short else ' second']) + for i in range(len(output)): + if output[i][0] != 1 and not short: + output[i][1] += 's' # type: ignore + reply_msg = [] + if output[0][0] != 0: + reply_msg.append("{}{} ".format(output[0][0], output[0][1])) + if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2: + reply_msg.append("{}{} ".format(output[1][0], output[1][1])) + for i in range(2, len(output) - 1): + reply_msg.append("{}{} ".format(output[i][0], output[i][1])) + if not short and reply_msg: + reply_msg.append("and ") + reply_msg.append("{}{}".format(output[-1][0], output[-1][1])) + return "".join(reply_msg) + +def utc_now() -> dt.datetime: + """ + Return the current timezone-aware utc timestamp. + """ + return dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc) + +def replace_multiple(format_string, mapping): + """ + Subsistutes the keys from the format_dict with their corresponding values. + + Substitution is non-chained, and done in a single pass via regex. + """ + if not mapping: + raise ValueError("Empty mapping passed.") + + keys = list(mapping.keys()) + pattern = '|'.join(f"({key})" for key in keys) + string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string) + return string + + +def parse_dur(time_str): + """ + Parses a user provided time duration string into a number of seconds. + + Parameters + ---------- + time_str: str + The time string to parse. String can include days, hours, minutes, and seconds. + + Returns: int + The number of seconds the duration represents. + """ + funcs = {'d': lambda x: x * 24 * 60 * 60, + 'h': lambda x: x * 60 * 60, + 'm': lambda x: x * 60, + 's': lambda x: x} + time_str = time_str.strip(" ,") + found = re.findall(r'(\d+)\s?(\w+?)', time_str) + seconds = 0 + for bit in found: + if bit[1] in funcs: + seconds += funcs[bit[1]](int(bit[0])) + return seconds diff --git a/src/utils/ratelimits.py b/src/utils/ratelimits.py new file mode 100644 index 0000000..b7d7bf2 --- /dev/null +++ b/src/utils/ratelimits.py @@ -0,0 +1,166 @@ +import asyncio +import time +import logging + +from cachetools import TTLCache + +logger = logging.getLogger() + + + +class BucketFull(Exception): + """ + Throw when a requested Bucket is already full + """ + pass + + +class BucketOverFull(BucketFull): + """ + Throw when a requested Bucket is overfull + """ + pass + + +class Bucket: + __slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock') + + def __init__(self, max_level, empty_time): + self.max_level = max_level + self.empty_time = empty_time + self.leak_rate = max_level / empty_time + + self._level = 0 + self._last_checked = time.monotonic() + + self._last_full = False + self._wait_lock = asyncio.Lock() + + @property + def full(self) -> bool: + """ + Return whether the bucket is 'full', + that is, whether an immediate request against the bucket will raise `BucketFull`. 
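+
+        The level decays at `leak_rate = max_level / empty_time` units per
+        second, so a completely full bucket drains after `empty_time` seconds.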
+ """ + self._leak() + return self._level + 1 > self.max_level + + @property + def overfull(self): + self._leak() + return self._level > self.max_level + + @property + def delay(self): + self._leak() + if self._level + 1 > self.max_level: + delay = (self._level + 1 - self.max_level) * self.leak_rate + else: + delay = 0 + return delay + + def _leak(self): + if self._level: + elapsed = time.monotonic() - self._last_checked + self._level = max(0, self._level - (elapsed * self.leak_rate)) + + self._last_checked = time.monotonic() + + def request(self): + self._leak() + if self._level > self.max_level: + raise BucketOverFull + elif self._level == self.max_level: + self._level += 1 + if self._last_full: + raise BucketOverFull + else: + self._last_full = True + raise BucketFull + else: + self._last_full = False + self._level += 1 + + def fill(self): + self._leak() + self._level = max(self._level, self.max_level + 1) + + async def wait(self): + """ + Wait until the bucket has room. + + Guarantees that a `request` directly afterwards will not raise `BucketFull`. + """ + # Wrapped in a lock so that waiters are correctly handled in wait-order + # Otherwise multiple waiters will have the same delay, + # and race for the wakeup after sleep. + # Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order + async with self._wait_lock: + # We do this in a loop in case asyncio.sleep throws us out early, + # or a synchronous request overflows the bucket while we are waiting. + while self.full: + await asyncio.sleep(self.delay) + + async def wrapped(self, coro): + await self.wait() + self.request() + await coro + + +class RateLimit: + def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)): + self.max_level = max_level + self.empty_time = empty_time + + self.error = error or "Too many requests, please slow down!" + self.buckets = cache + + def request_for(self, key): + if not (bucket := self.buckets.get(key, None)): + bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time) + + bucket.request() + + def ward(self, member=True, key=None): + """ + Command ratelimit decorator. + """ + key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id)) + + def decorator(func): + async def wrapper(ctx, *args, **kwargs): + self.request_for(key(ctx)) + return await func(ctx, *args, **kwargs) + return wrapper + return decorator + + +async def limit_concurrency(aws, limit): + """ + Run provided awaitables concurrently, + ensuring that no more than `limit` are running at once. + """ + aws = iter(aws) + aws_ended = False + pending = set() + count = 0 + logger.debug("Starting limited concurrency executor") + + while pending or not aws_ended: + while len(pending) < limit and not aws_ended: + aw = next(aws, None) + if aw is None: + aws_ended = True + else: + pending.add(asyncio.create_task(aw)) + count += 1 + + if not pending: + break + + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + while done: + yield done.pop() + logger.debug(f"Completed {count} tasks")