Initial commit

HoloTech
2025-09-01 20:26:45 +10:00
commit 1f60610947
19 changed files with 1293 additions and 0 deletions

3
.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "src/data"]
path = src/data
url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git

0
data/.gitignore vendored Normal file

63
data/schema.sql Normal file

@@ -0,0 +1,63 @@
-- Metadata {{{
CREATE TABLE version_history(
component TEXT NOT NULL,
from_version INTEGER NOT NULL,
to_version INTEGER NOT NULL,
author TEXT NOT NULL,
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation');
CREATE OR REPLACE FUNCTION update_timestamp_column()
RETURNS TRIGGER AS $$
BEGIN
NEW._timestamp = now();
RETURN NEW;
END;
$$ language 'plpgsql';
CREATE TABLE app_config(
appname TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- }}}
-- Twitch Auth {{{
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('TWITCH_AUTH', 0, 1, 'Initial Creation');
-- Authorisation tokens allowing us to take actions on behalf of certain users or channels.
-- For example, channels we have joined will need to be authorised with a 'channel:bot' scope.
CREATE TABLE user_auth(
userid TEXT PRIMARY KEY,
token TEXT NOT NULL,
refresh_token TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER user_auth_timestamp BEFORE UPDATE ON user_auth
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
CREATE TABLE user_auth_scopes(
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
scope TEXT NOT NULL
);
-- Which channels will be joined at startup,
-- and any configurable choices needed when joining the channel
CREATE TABLE bot_channels(
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
autojoin BOOLEAN DEFAULT true,
listen_redeems BOOLEAN,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER bot_channels_timestamp BEFORE UPDATE ON bot_channels
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
-- }}}
-- vim: set fdm=marker:

15
example-config/bot.conf Normal file

@@ -0,0 +1,15 @@
[BOT]
prefix = ?
bot_id =
ALSO_READ = config/secrets.conf
wshost = localhost
wsport = 4343
wsdomain = localhost:4343
[LOGGING]
general_log =
warning_log =
error_log =
critical_log =


@@ -0,0 +1,7 @@
[CROCBOT]
client_id =
client_secret =
[DATA]
args =
appid =

4
requirements.txt Normal file

@@ -0,0 +1,4 @@
twitchio
psycopg[pool]
cachetools
discord.py

12
scripts/start_bot.py Normal file

@@ -0,0 +1,12 @@
#!/bin/python3
import sys
import os
sys.path.insert(0, os.getcwd())
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
if __name__ == '__main__':
from bot import _main
_main()

42
src/bot.py Normal file

@@ -0,0 +1,42 @@
import asyncio
import logging
import websockets
from twitchio.web import AiohttpAdapter
from meta import Bot, conf, setup_main_logger, args
from data import Database
from modules import twitch_setup
logger = logging.getLogger(__name__)
async def main():
db = Database(conf.data['args'])
async with db.open():
adapter = AiohttpAdapter(
host=conf.bot.get('wshost', None),
port=conf.bot.getint('wsport', None),
domain=conf.bot.get('wsdomain', None),
eventsub_secret=conf.bot.get('eventsub_secret', None)
)
bot = Bot(
config=conf,
dbconn=db,
adapter=adapter,
setup=twitch_setup,
)
try:
await bot.start()
finally:
await bot.close()
def _main():
setup_main_logger()
asyncio.run(main())

82
src/botdata.py Normal file

@@ -0,0 +1,82 @@
from data import Registry, RowModel, Table
from data.columns import String, Timestamp, Integer, Bool
class UserAuth(RowModel):
"""
Schema
======
CREATE TABLE user_auth(
userid TEXT PRIMARY KEY,
token TEXT NOT NULL,
refresh_token TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_auth'
_cache_ = {}
userid = String(primary=True)
token = String()
refresh_token = String()
created_at = Timestamp()
_timestamp = Timestamp()
class BotChannel(RowModel):
"""
Schema
======
CREATE TABLE bot_channels(
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
autojoin BOOLEAN DEFAULT true,
listen_redeems BOOLEAN,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'bot_channels'
_cache_ = {}
userid = String(primary=True)
autojoin = Bool()
listen_redeems = Bool()
joined_at = Timestamp()
_timestamp = Timestamp()
class VersionHistory(RowModel):
"""
CREATE TABLE version_history(
component TEXT NOT NULL,
from_version INTEGER NOT NULL,
to_version INTEGER NOT NULL,
author TEXT NOT NULL,
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'version_history'
_cache_ = {}
component = String()
from_version = Integer()
to_version = Integer()
author = String()
_timestamp = Timestamp()
class BotData(Registry):
version_history = VersionHistory.table
user_auth = UserAuth.table
bot_channels = BotChannel.table
"""
CREATE TABLE user_auth_scopes(
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
scope TEXT NOT NULL
);
"""
user_auth_scopes = Table('user_auth_scopes')
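The row models above are thin Python mirrors of the tables in data/schema.sql. A minimal usage sketch, assuming only the RowModel methods that src/meta/bot.py below actually relies on (fetch_or_create, fetch_where, update); the userid and token values are placeholders:
async def example_usage():
    # Upsert-style fetch of an auth row for a user (same call pattern as Bot.add_token)
    row = await UserAuth.fetch_or_create('12345', token='abc', refresh_token='def')
    if row.token != 'abc':
        await row.update(token='abc', refresh_token='def')
    # Fetch all channels marked for autojoin (same call pattern as Bot.setup_hook)
    channels = await BotChannel.fetch_where(autojoin=True)
    return channels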

22
src/constants.py Normal file

@@ -0,0 +1,22 @@
from twitchio import Scopes
CONFIG_FILE = 'config/bot.conf'
SCHEMA_VERSIONS = {
'ROOT': 1,
'TWITCH_AUTH': 1
}
# Requested scopes for the bot's own Twitch user
BOTUSER_SCOPES = Scopes((
Scopes.user_read_chat,
Scopes.user_write_chat,
Scopes.user_bot,
Scopes.channel_bot,
))
# Default requested scopes for joining a channel
CHANNEL_SCOPES = Scopes((
Scopes.channel_bot,
))

1
src/data Submodule

Submodule src/data added at 334b5f5892

4
src/meta/__init__.py Normal file

@@ -0,0 +1,4 @@
from .args import args
from .bot import Bot
from .config import Conf, conf
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task

28
src/meta/args.py Normal file

@@ -0,0 +1,28 @@
import argparse
from constants import CONFIG_FILE
# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
'--conf',
dest='config',
default=CONFIG_FILE,
help="Path to configuration file."
)
parser.add_argument(
'--host',
dest='host',
default='127.0.0.1',
help="IP address to run the websocket server on."
)
parser.add_argument(
'--port',
dest='port',
default='5001',
help="Port to run the websocket server on."
)
args = parser.parse_args()

174
src/meta/bot.py Normal file

@@ -0,0 +1,174 @@
import logging
from typing import Optional
from twitchio.authentication import UserTokenPayload
from twitchio.ext import commands
from twitchio import Scopes, eventsub
from data import Database, ORDER
from botdata import BotData, UserAuth, BotChannel, VersionHistory
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
from .config import Conf
logger = logging.getLogger(__name__)
class Bot(commands.Bot):
def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs):
kwargs.setdefault('client_id', config.bot['client_id'])
kwargs.setdefault('client_secret', config.bot['client_secret'])
kwargs.setdefault('bot_id', config.bot['bot_id'])
kwargs.setdefault('prefix', config.bot['prefix'])
super().__init__(*args, **kwargs)
if config.bot.get('eventsub_secret', None):
self.using_webhooks = True
else:
self.using_webhooks = False
self.config = config
self.dbconn = dbconn
self.data: BotData = dbconn.load_registry(BotData())
self._setup_hook = setup
self.joined: dict[str, BotChannel] = {}
async def event_ready(self):
# logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
logger.info("Logged in as %s", self.bot_id)
async def version_check(self, component: str, req_version: int):
# Query the database to confirm that the given component is listed with the given version.
# Typically done upon loading a component
rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
version = rows[0].to_version if rows else 0
if version != req_version:
raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
else:
logger.debug(
"Component %s passed version check with version %s",
component,
version
)
return True
async def setup_hook(self):
await self.data.init()
for component, req in SCHEMA_VERSIONS.items():
await self.version_check(component, req)
if self._setup_hook is not None:
await self._setup_hook(self)
# Get all current bot channels
channels = await BotChannel.fetch_where(autojoin=True)
# Join the channels
await self.join_channels(*channels)
# Build bot account's own url
scopes = BOTUSER_SCOPES
url = self.get_auth_url(scopes)
logger.info("Bot account authorisation url: %s", url)
# Build everyone else's url
scopes = CHANNEL_SCOPES
url = self.get_auth_url(scopes)
logger.info("User account authorisation url: %s", url)
logger.info("Finished setup")
def get_auth_url(self, scopes: Optional[Scopes] = None):
if scopes is None:
scopes = Scopes((Scopes.channel_bot,))
url = self._adapter.get_authorization_url(scopes=scopes)
return url
async def join_channels(self, *channels: BotChannel):
"""
Register webhook subscriptions to the given channel(s).
"""
# TODO: If channels are already joined, unsubscribe
for channel in channels:
sub = None
try:
sub = eventsub.ChatMessageSubscription(
broadcaster_user_id=channel.userid,
user_id=self.bot_id,
)
if self.using_webhooks:
resp = await self.subscribe_webhook(sub)
else:
resp = await self.subscribe_websocket(sub)
logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
self.joined[channel.userid] = channel
self.safe_dispatch('channel_joined', payload=channel)
except Exception:
logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)
async def event_oauth_authorized(self, payload: UserTokenPayload):
logger.debug("Oauth flow authorization with payload %s", repr(payload))
# Save the token and scopes and update internal authorisations
resp = await self.add_token(payload.access_token, payload.refresh_token)
if resp.user_id is None:
logger.warning(
"Oauth flow recieved with no user_id. Payload was: %s",
repr(payload)
)
return
# If the scopes authorised included channel:bot, ensure a BotChannel exists
# And join it if needed
if Scopes.channel_bot.value in resp.scopes:
bot_channel = await BotChannel.fetch_or_create(
resp.user_id,
autojoin=True,
)
if bot_channel.autojoin:
await self.join_channels(bot_channel)
logger.info("Oauth flow authorization complete for payload %s", repr(payload))
async def add_token(self, token: str, refresh: str):
# Update the tokens in internal cache
# This also validates the token
# And hopefully gets the userid and scopes
resp = await super().add_token(token, refresh)
if resp.user_id is None:
logger.warning(
"Added a token with no user_id. Response was: %s",
repr(resp)
)
return resp
userid = resp.user_id
new_scopes = resp.scopes
# Save the token and scopes to data
# Wrap this in a transaction so if it fails halfway we rollback correctly
async with self.dbconn.connection() as conn:
self.dbconn.conn = conn
async with conn.transaction():
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
if row.token != token or row.refresh_token != refresh:
await row.update(token=token, refresh_token=refresh)
await self.data.user_auth_scopes.delete_where(userid=userid)
await self.data.user_auth_scopes.insert_many(
('userid', 'scope'),
*((userid, scope) for scope in new_scopes)
)
logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
return resp
async def load_tokens(self, path: str | None = None):
for row in await UserAuth.fetch_where():
try:
await self.add_token(row.token, row.refresh_token)
except Exception:
logger.exception(f"Failed to add token for {row}")

105
src/meta/config.py Normal file

@@ -0,0 +1,105 @@
import configparser as cfgp
from .args import args
class MapDotProxy:
"""
Allows dot access to an underlying Mappable object.
"""
__slots__ = ("_map", "_converter")
def __init__(self, mappable, converter=None):
self._map = mappable
self._converter = converter
def __getattribute__(self, key):
_map = object.__getattribute__(self, '_map')
if key == '_map':
return _map
if key in _map:
_converter = object.__getattribute__(self, '_converter')
if _converter:
return _converter(_map[key])
else:
return _map[key]
else:
return object.__getattribute__(_map, key)
def __getitem__(self, key):
return self._map.__getitem__(key)
class ConfigParser(cfgp.ConfigParser):
"""
Extension of the base ConfigParser allowing a section's options
to be retrieved without the DEFAULT fallbacks.
"""
def options(self, section, no_defaults=False, **kwargs):
if no_defaults:
try:
return list(self._sections[section].keys())
except KeyError:
raise cfgp.NoSectionError(section)
else:
return super().options(section, **kwargs)
class Conf:
def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile
self.config = ConfigParser(
converters={
"intlist": self._getintlist,
"list": self._getlist,
}
)
with open(configfile) as conff:
# Opening with read_file mainly to ensure the file exists
self.config.read_file(conff)
self.section_name = section_name if section_name in self.config else 'DEFAULT'
self.default = self.config["DEFAULT"]
self.section = MapDotProxy(self.config[self.section_name])
self.bot = self.section
# Config file recursion: read in the configuration files specified by every "ALSO_READ" key.
more_to_read = self.section.getlist("ALSO_READ", [])
read = set()
while more_to_read:
to_read = more_to_read.pop(0)
read.add(to_read)
self.config.read(to_read)
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
if path not in read and path not in more_to_read]
more_to_read.extend(new_paths)
global conf
conf = self
def __getitem__(self, key):
return self.section[key].strip()
def __getattr__(self, section):
name = section.upper()
return self.config[name]
def get(self, name, fallback=None):
result = self.section.get(name, fallback)
return result.strip() if result else result
def _getintlist(self, value):
return [int(item.strip()) for item in value.split(',')]
def _getlist(self, value):
return [item.strip() for item in value.split(',')]
def write(self):
with open(self.configfile, 'w') as conffile:
self.config.write(conffile)
conf = Conf(args.config, 'BOT')
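A short sketch of how this configuration layer is consumed elsewhere in the repository; the keys shown are the ones defined in example-config/bot.conf and read by src/meta/bot.py and src/meta/logger.py:
# conf is the module-level Conf(args.config, 'BOT') instance created above.
prefix = conf.bot.prefix                     # dot access to the [BOT] section via MapDotProxy
host = conf.bot.get('wshost', None)          # unknown attributes fall through to the SectionProxy
port = conf.bot.getint('wsport', None)       # ConfigParser converters remain available
general_log = conf.logging['general_log']    # __getattr__ upper-cases 'logging' to the [LOGGING] section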

466
src/meta/logger.py Normal file

@@ -0,0 +1,466 @@
import sys
import logging
import asyncio
from typing import List, Optional
from logging.handlers import QueueListener, QueueHandler
import queue
import multiprocessing
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar
import aiohttp
import discord
from discord import File
from .config import conf
from utils.lib import utc_now
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
log_logger = logging.getLogger(__name__)
log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT")
context: ContextVar[Optional[str]] = ContextVar('context', default=None)
def set_logging_context(
context: Optional[str] = None,
action: Optional[str] = None,
stack: Optional[tuple[str, ...]] = None
):
"""
Statically set the logging context variables to the given values.
If `action` is given, pushes it onto the `log_action_stack`.
"""
if context is not None:
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
@contextmanager
def logging_context(context=None, action=None, stack=None):
"""
Context manager for executing a block of code in a given logging context.
This context manager should only be used around synchronous code.
This is because async code *may* get cancelled or externally garbage collected,
in which case the finally block will be executed in the wrong context.
See https://github.com/python/cpython/issues/93740
This can be refactored nicely if this gets merged:
https://github.com/python/cpython/pull/99634
(It will not necessarily break on async code,
if the async code can be guaranteed to clean up in its own context.)
"""
if context is not None:
oldcontext = log_context.get()
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
try:
yield
finally:
if context is not None:
log_context.set(oldcontext)
if stack is not None or action is not None:
log_action_stack.set(astack)
def with_log_ctx(isolate=True, **kwargs):
"""
Execute a coroutine inside a given logging context.
If `isolate` is true, ensures that context does not leak
outside the coroutine.
If `isolate` is false, just statically set the context,
which will leak unless the coroutine is
called in an externally copied context.
"""
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
if isolate:
with logging_context(**kwargs):
# Task creation will synchronously copy the context
# This is gc safe
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
return await task
else:
# This will leak context changes
set_logging_context(**kwargs)
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
# For backwards compatibility
log_wrap = with_log_ctx
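# A minimal usage sketch of the context helpers above (the handler below is
# illustrative, not part of this module):
#
#     @with_log_ctx(action="JoinChannel")
#     async def handle_join(channel_id: str):
#         set_logging_context(context=f"cid: {channel_id}")
#         logger.info("Joining channel")  # app/context/action are injected by ContextInjection below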
def persist_task(task_collection: set):
"""
Coroutine decorator that ensures the coroutine is scheduled as a task
and added to the given task_collection for strong reference
when it is called.
This is just a hack to handle discord.py events potentially
being unexpectedly garbage collected.
Since this also implicitly schedules the coroutine as a task when it is called,
the coroutine will also be run inside an isolated context.
"""
def decorator(coro):
@wraps(coro)
async def wrapped(*w_args, **w_kwargs):
name = f"persisted-{coro.__name__}"
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
task_collection.add(task)
task.add_done_callback(lambda f: task_collection.discard(f))
await task
return wrapped
return decorator
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str:
cmap = {
'%(black)': COLOR_SEQ % BLACK,
'%(red)': COLOR_SEQ % RED,
'%(green)': COLOR_SEQ % GREEN,
'%(yellow)': COLOR_SEQ % YELLOW,
'%(blue)': COLOR_SEQ % BLUE,
'%(magenta)': COLOR_SEQ % MAGENTA,
'%(cyan)': COLOR_SEQ % CYAN,
'%(white)': COLOR_SEQ % WHITE,
'%(reset)': RESET_SEQ,
'%(bold)': BOLD_SEQ,
}
for key, value in cmap.items():
fmt = fmt.replace(key, value)
return fmt
log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
'%(cyan)%(app)-15s%(reset)|' +
'%(cyan)%(context)-24s%(reset)|' +
'%(cyan)%(actionstr)-22s%(reset)|' +
' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format)
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=log_format,
# datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
class ExactLevelFilter(logging.Filter):
def __init__(self, target_level, name=""):
super().__init__(name)
self.target_level = target_level
def filter(self, record):
return (record.levelno == self.target_level)
class ThreadFilter(logging.Filter):
def __init__(self, thread_name):
super().__init__("")
self.thread = thread_name
def filter(self, record):
# non-zero return means we log this message
return 1 if record.threadName == self.thread else 0
class ContextInjection(logging.Filter):
def filter(self, record):
# These guards are to allow override through _extra
# And to ensure the injection is idempotent
if not hasattr(record, 'context'):
record.context = log_context.get()
if not hasattr(record, 'actionstr'):
action_stack = log_action_stack.get()
if hasattr(record, 'action'):
action_stack = (*action_stack, record.action)
if action_stack:
record.actionstr = ''.join(action_stack)
else:
record.actionstr = "Unknown Action"
if not hasattr(record, 'app'):
record.app = log_app.get()
if not hasattr(record, 'ctx'):
if ctx := context.get():
record.ctx = repr(ctx)
else:
record.ctx = None
if getattr(record, 'with_ctx', False) and record.ctx:
record.ctxstr = '\n' + record.ctx
else:
record.ctxstr = ""
return True
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)
class LocalQueueHandler(QueueHandler):
def _emit(self, record: logging.LogRecord) -> None:
# Removed the call to self.prepare(), handle task cancellation
try:
self.enqueue(record)
except asyncio.CancelledError:
raise
except Exception:
self.handleError(record)
class WebHookHandler(logging.StreamHandler):
def __init__(self, webhook_url, prefix="", batch=True, loop=None):
super().__init__()
self.webhook_url = webhook_url
self.prefix = prefix
self.batched = ""
self.batch = batch
self.loop = loop
self.batch_delay = 10
self.batch_task = None
self.last_batched = None
self.waiting = []
self.bucket = Bucket(20, 40)
self.ignored = 0
self.session = None
self.webhook = None
def get_loop(self):
if self.loop is None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
return self.loop
def emit(self, record):
self.format(record)
self.get_loop().call_soon_threadsafe(self._post, record)
def _post(self, record):
if self.session is None:
self.setup()
asyncio.create_task(self.post(record))
def setup(self):
from discord import Webhook
self.session = aiohttp.ClientSession()
self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
async def post(self, record):
if record.context == 'Webhook Logger':
# Don't livelog livelog errors
# Otherwise we recurse and Cloudflare hates us
return
log_context.set("Webhook Logger")
log_action_stack.set(("Logging",))
log_app.set(record.app)
try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
context = f"\n# Context: {record.ctx}" if record.ctx else ""
message = f"{header}\n{record.msg}{context}"
if len(message) > 1900:
as_file = True
else:
as_file = False
message = "```md\n{}\n```".format(message)
# Post the log message(s)
if self.batch:
if len(message) > 1500:
await self._send_batched_now()
await self._send(message, as_file=as_file)
else:
self.batched += message
if len(self.batched) + len(message) > 1500:
await self._send_batched_now()
else:
asyncio.create_task(self._schedule_batched())
else:
await self._send(message, as_file=as_file)
except Exception as ex:
print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)
async def _schedule_batched(self):
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
# noop, don't reschedule if it is already scheduled
return
try:
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
await self.batch_task
await self._send_batched()
except asyncio.CancelledError:
return
except Exception as ex:
print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)
async def _send_batched_now(self):
if self.batch_task is not None and not self.batch_task.done():
self.batch_task.cancel()
self.last_batched = None
await self._send_batched()
async def _send_batched(self):
if self.batched:
batched = self.batched
self.batched = ""
await self._send(batched)
async def _send(self, message, as_file=False):
try:
self.bucket.request()
except BucketOverFull:
# Silently ignore
self.ignored += 1
return
except BucketFull:
logger.warning(
"Can't keep up! "
f"Ignoring records on live-logger {self.webhook.id}."
)
self.ignored += 1
return
else:
if self.ignored > 0:
logger.warning(
"Can't keep up! "
f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
)
self.ignored = 0
try:
if as_file or len(message) > 1900:
with StringIO(message) as fp:
fp.seek(0)
await self.webhook.send(
f"{self.prefix}\n`{message.splitlines()[0]}`",
file=File(fp, filename="logs.md"),
username=log_app.get()
)
else:
await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
except discord.HTTPException:
logger.exception(
"Live logger errored. Slowing down live logger."
)
self.bucket.fill()
handlers = []
if webhook := conf.logging['general_log']:
handler = WebHookHandler(webhook, batch=True)
handlers.append(handler)
if webhook := conf.logging['warning_log']:
handler = WebHookHandler(webhook, prefix=conf.logging.get('warning_prefix', ''), batch=True)
handler.addFilter(ExactLevelFilter(logging.WARNING))
handler.setLevel(logging.WARNING)
handlers.append(handler)
if webhook := conf.logging['error_log']:
handler = WebHookHandler(webhook, prefix=conf.logging.get('error_prefix', ''), batch=True)
handler.setLevel(logging.ERROR)
handlers.append(handler)
if webhook := conf.logging['critical_log']:
handler = WebHookHandler(webhook, prefix=conf.logging.get('critical_prefix', ''), batch=False)
handler.setLevel(logging.CRITICAL)
handlers.append(handler)
def make_queue_handler(queue):
qhandler = QueueHandler(queue)
qhandler.setLevel(logging.INFO)
qhandler.addFilter(ContextInjection())
return qhandler
def setup_main_logger(multiprocess=False):
q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
if handlers:
# First create a separate loop to run the handlers on
import threading
def run_loop(loop):
asyncio.set_event_loop(loop)
try:
loop.run_forever()
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=lambda: run_loop(loop))
loop_thread.daemon = True
loop_thread.start()
for handler in handlers:
handler.loop = loop
qhandler = make_queue_handler(q)
# qhandler.addFilter(ThreadFilter('MainThread'))
logger.addHandler(qhandler)
listener = QueueListener(
q, *handlers, respect_handler_level=True
)
listener.start()
return q

11
src/modules/__init__.py Normal file

@@ -0,0 +1,11 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from meta import Bot
async def twitch_setup(bot: 'Bot'):
# Import and run setup methods from each module
# from . import module
# await module.twitch_setup(bot)
pass

88
src/utils/lib.py Normal file

@@ -0,0 +1,88 @@
import re
import datetime as dt
def strfdelta(delta: dt.timedelta, sec=False, minutes=True, short=False) -> str:
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's' # type: ignore
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def utc_now() -> dt.datetime:
"""
Return the current timezone-aware utc timestamp.
"""
return dt.datetime.now(dt.timezone.utc)
def replace_multiple(format_string, mapping):
"""
Substitutes the keys from the mapping with their corresponding values.
Substitution is non-chained, and done in a single pass via regex.
"""
if not mapping:
raise ValueError("Empty mapping passed.")
keys = list(mapping.keys())
pattern = '|'.join(f"({key})" for key in keys)
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
return string
def parse_dur(time_str):
"""
Parses a user provided time duration string into a number of seconds.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
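Worked examples for the helpers above; the expected values are derived by reading the implementation, not from recorded output:
# strfdelta(dt.timedelta(hours=2, minutes=5), short=True)   -> "2h 5m"
# strfdelta(dt.timedelta(hours=2, minutes=5))               -> "2 hours and 5 minutes"
# parse_dur("1h 30m")                                       -> 5400   (3600 + 30 * 60)
# replace_multiple("hello NAME", {"NAME": "world"})         -> "hello world"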

166
src/utils/ratelimits.py Normal file

@@ -0,0 +1,166 @@
import asyncio
import time
import logging
from cachetools import TTLCache
logger = logging.getLogger()
class BucketFull(Exception):
"""
Thrown when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Thrown when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.monotonic()
self._last_full = False
self._wait_lock = asyncio.Lock()
@property
def full(self) -> bool:
"""
Return whether the bucket is 'full',
that is, whether an immediate request against the bucket will raise `BucketFull`.
"""
self._leak()
return self._level + 1 > self.max_level
@property
def overfull(self):
self._leak()
return self._level > self.max_level
@property
def delay(self):
self._leak()
if self._level + 1 > self.max_level:
delay = (self._level + 1 - self.max_level) / self.leak_rate
else:
delay = 0
return delay
def _leak(self):
if self._level:
elapsed = time.monotonic() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.monotonic()
def request(self):
self._leak()
if self._level > self.max_level:
raise BucketOverFull
elif self._level == self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
def fill(self):
self._leak()
self._level = max(self._level, self.max_level + 1)
async def wait(self):
"""
Wait until the bucket has room.
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
"""
# Wrapped in a lock so that waiters are correctly handled in wait-order
# Otherwise multiple waiters will have the same delay,
# and race for the wakeup after sleep.
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
async with self._wait_lock:
# We do this in a loop in case asyncio.sleep throws us out early,
# or a synchronous request overflows the bucket while we are waiting.
while self.full:
await asyncio.sleep(self.delay)
async def wrapped(self, coro):
await self.wait()
self.request()
await coro
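# A minimal sketch of guarding a bursty call site with the bucket above
# (post() is a hypothetical sender; WebHookHandler in src/meta/logger.py drives
# the same request()/fill() interface synchronously):
#
#     bucket = Bucket(max_level=20, empty_time=40)
#     async def send_guarded(msg):
#         await bucket.wait()   # sleeps until a subsequent request() cannot raise BucketFull
#         bucket.request()
#         await post(msg)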
class RateLimit:
def __init__(self, max_level, empty_time, error=None, cache=None):
self.max_level = max_level
self.empty_time = empty_time
self.error = error or "Too many requests, please slow down!"
self.buckets = cache if cache is not None else TTLCache(1000, 60 * 60)
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
bucket.request()
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator
async def limit_concurrency(aws, limit):
"""
Run provided awaitables concurrently,
ensuring that no more than `limit` are running at once.
"""
aws = iter(aws)
aws_ended = False
pending = set()
count = 0
logger.debug("Starting limited concurrency executor")
while pending or not aws_ended:
while len(pending) < limit and not aws_ended:
aw = next(aws, None)
if aw is None:
aws_ended = True
else:
pending.add(asyncio.create_task(aw))
count += 1
if not pending:
break
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
while done:
yield done.pop()
logger.debug(f"Completed {count} tasks")