generated from HoloTech/twitch-bot-template
Initial commit
6 .gitmodules vendored Normal file
@@ -0,0 +1,6 @@
[submodule "src/data"]
	path = src/data
	url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git
[submodule "src/modules/profiles"]
	path = src/modules/profiles
	url = https://git.thewisewolf.dev/HoloTech/profiles-plugin.git
0 data/.gitignore vendored Normal file
63 data/schema.sql Normal file
@@ -0,0 +1,63 @@
-- Metadata {{{
CREATE TABLE version_history(
  component TEXT NOT NULL,
  from_version INTEGER NOT NULL,
  to_version INTEGER NOT NULL,
  author TEXT NOT NULL,
  _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation');


CREATE OR REPLACE FUNCTION update_timestamp_column()
RETURNS TRIGGER AS $$
BEGIN
  NEW._timestamp = now();
  RETURN NEW;
END;
$$ language 'plpgsql';
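-- Shared trigger function: tables with a `_timestamp` column attach this as a
-- BEFORE UPDATE trigger to keep the column current (see user_auth and bot_channels below).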

CREATE TABLE app_config(
  appname TEXT PRIMARY KEY,
  created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);

-- }}}

-- Twitch Auth {{{
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('TWITCH_AUTH', 0, 1, 'Initial Creation');

-- Authorisation tokens allowing us to take actions on behalf of certain users or channels.
-- For example, channels we have joined will need to be authorised with a 'channel:bot' scope.
CREATE TABLE user_auth(
  userid TEXT PRIMARY KEY,
  token TEXT NOT NULL,
  refresh_token TEXT NOT NULL,
  created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
  _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER user_auth_timestamp BEFORE UPDATE ON user_auth
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();

CREATE TABLE user_auth_scopes(
  userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
  scope TEXT NOT NULL
);

-- Which channels will be joined at startup,
-- and any configurable choices needed when joining the channel
CREATE TABLE bot_channels(
  userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
  autojoin BOOLEAN DEFAULT true,
  listen_redeems BOOLEAN,
  joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
  _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER bot_channels_timestamp BEFORE UPDATE ON bot_channels
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();

-- }}}

-- vim: set fdm=marker:
21 example-config/bot.conf Normal file
@@ -0,0 +1,21 @@
[BOT]
prefix = ?
bot_id =

ALSO_READ = config/secrets.conf
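# Extra config files to merge in on load; here secrets.conf supplies
# client_id/client_secret and the eventsub secret (see its [BOT]/[TWITCH] sections).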

[TWITCH]
host =
port =
domain =
redirect_path =
oauth_path =
eventsub_path =

webhooks =

[LOGGING]
general_log =
warning_log =
error_log =
critical_log =
warning_prefix =
error_prefix =
critical_prefix =

[WSERVER]
port =
10 example-config/secrets.conf Normal file
@@ -0,0 +1,10 @@
[BOT]
client_id =
client_secret =

[TWITCH]
eventsub_secret =

[DATA]
args =
appid =
4 requirements.txt Normal file
@@ -0,0 +1,4 @@
twitchio
psycopg[pool]
cachetools
discord.py
websockets
12 scripts/start_bot.py Normal file
@@ -0,0 +1,12 @@
#!/usr/bin/env python3

import sys
import os

sys.path.insert(0, os.getcwd())
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
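# Both the repository root and src/ go on sys.path so the top-level modules
# (bot, meta, data, modules) import directly; run this from the repository root.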


if __name__ == '__main__':
    from bot import _main
    _main()
60 src/bot.py Normal file
@@ -0,0 +1,60 @@
import asyncio
import logging
import websockets

from twitchio.web import AiohttpAdapter

from meta import Bot, conf, setup_main_logger, args, sockets
from data import Database

from modules import twitch_setup

logger = logging.getLogger(__name__)


class ProxyAiohttpAdapter(AiohttpAdapter):
    """
    Overrides the computed AiohttpAdapter redirect_url
    to always use the provided domain.
    """
    def _find_redirect(self, request):
        return self.redirect_url


async def main():
    db = Database(conf.data['args'])

    async with db.open():
        adapter_keys = (
            'host', 'domain', 'port',
            'redirect_path', 'oauth_path', 'eventsub_path',
            'eventsub_secret',
        )
        adapter_args = {}
        for key in adapter_keys:
            value = conf.twitch.get(key, '').strip()
            if value:
                if key == 'port':
                    value = int(value)
                adapter_args[key] = value
        adapter = ProxyAiohttpAdapter(**adapter_args)

        bot = Bot(
            config=conf,
            dbconn=db,
            adapter=adapter,
            setup=twitch_setup,
            using_webhooks=conf.twitch.getboolean('webhooks', False)
        )

        async with websockets.serve(sockets.root_handler, '', conf.wserver.getint('port')):
            try:
                await bot.start()
            finally:
                await bot.close()


def _main():
    setup_main_logger()

    asyncio.run(main())
82 src/botdata.py Normal file
@@ -0,0 +1,82 @@
from data import Registry, RowModel, Table
from data.columns import String, Timestamp, Integer, Bool


class UserAuth(RowModel):
    """
    Schema
    ======
    CREATE TABLE user_auth(
      userid TEXT PRIMARY KEY,
      token TEXT NOT NULL,
      refresh_token TEXT NOT NULL,
      created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
      _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    """
    _tablename_ = 'user_auth'
    _cache_ = {}

    userid = String(primary=True)
    token = String()
    refresh_token = String()
    created_at = Timestamp()
    _timestamp = Timestamp()
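
    # Typical usage, as in meta/bot.py:
    #   row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
    #   await row.update(token=token, refresh_token=refresh)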


class BotChannel(RowModel):
    """
    Schema
    ======
    CREATE TABLE bot_channels(
      userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
      autojoin BOOLEAN DEFAULT true,
      listen_redeems BOOLEAN,
      joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
      _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    """
    _tablename_ = 'bot_channels'
    _cache_ = {}

    userid = String(primary=True)
    autojoin = Bool()
    listen_redeems = Bool()
    joined_at = Timestamp()
    _timestamp = Timestamp()


class VersionHistory(RowModel):
    """
    Schema
    ======
    CREATE TABLE version_history(
      component TEXT NOT NULL,
      from_version INTEGER NOT NULL,
      to_version INTEGER NOT NULL,
      author TEXT NOT NULL,
      _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    """
    _tablename_ = 'version_history'
    _cache_ = {}

    component = String()
    from_version = Integer()
    to_version = Integer()
    author = String()
    _timestamp = Timestamp()


class BotData(Registry):
    version_history = VersionHistory.table

    user_auth = UserAuth.table
    bot_channels = BotChannel.table

    """
    CREATE TABLE user_auth_scopes(
      userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
      scope TEXT NOT NULL
    );
    """
    user_auth_scopes = Table('user_auth_scopes')
22 src/constants.py Normal file
@@ -0,0 +1,22 @@
from twitchio import Scopes


CONFIG_FILE = 'config/bot.conf'

SCHEMA_VERSIONS = {
    'ROOT': 1,
    'TWITCH_AUTH': 1
}
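# These must match the latest `to_version` rows that data/schema.sql inserts
# into version_history; Bot.version_check raises ValueError on a mismatch
# during setup.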

# Requested scopes for the bot's own Twitch user
BOTUSER_SCOPES = Scopes((
    Scopes.user_read_chat,
    Scopes.user_write_chat,
    Scopes.user_bot,
    Scopes.channel_bot,
))

# Default requested scopes for joining a channel
CHANNEL_SCOPES = Scopes((
    Scopes.channel_bot,
))
1 src/data Submodule
Submodule src/data added at 334b5f5892
5 src/meta/__init__.py Normal file
@@ -0,0 +1,5 @@
from .args import args
from .bot import Bot
from .context import Context
from .config import Conf, conf
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task
28 src/meta/args.py Normal file
@@ -0,0 +1,28 @@
import argparse

from constants import CONFIG_FILE

# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    '--conf',
    dest='config',
    default=CONFIG_FILE,
    help="Path to configuration file."
)
parser.add_argument(
    '--host',
    dest='host',
    default='127.0.0.1',
    help="IP address to run the websocket server on."
)
parser.add_argument(
    '--port',
    dest='port',
    default='5001',
    help="Port to run the websocket server on."
)

args = parser.parse_args()
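
# Example invocation (all flags optional, falling back to the defaults above):
#   python scripts/start_bot.py --conf config/bot.conf --host 127.0.0.1 --port 5001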
196 src/meta/bot.py Normal file
@@ -0,0 +1,196 @@
import logging
from typing import TYPE_CHECKING, Any, Literal, Optional, overload

from twitchio.authentication import UserTokenPayload
from twitchio.ext import commands
from twitchio import Scopes, eventsub

from data import Database, ORDER
from botdata import BotData, UserAuth, BotChannel, VersionHistory
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS

from .config import Conf
from .context import Context

if TYPE_CHECKING:
    from modules.profiles.profiles.twitch.component import ProfilesComponent


logger = logging.getLogger(__name__)


class Bot(commands.Bot):
    def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs):
        kwargs.setdefault('client_id', config.bot['client_id'])
        kwargs.setdefault('client_secret', config.bot['client_secret'])
        kwargs.setdefault('bot_id', config.bot['bot_id'])
        kwargs.setdefault('prefix', config.bot['prefix'])

        # Whether we should do eventsub via webhooks or websockets
        # (popped here so it is not forwarded to the underlying commands.Bot)
        self.using_webhooks = kwargs.pop('using_webhooks', False)

        super().__init__(*args, **kwargs)

        self.config = config
        self.dbconn = dbconn
        self.data: BotData = dbconn.load_registry(BotData())
        self._setup_hook = setup

        self.joined: dict[str, BotChannel] = {}

    # Make the type checker happy about fetching components by name
    # TODO: Move to stubs

    @property
    def profiles(self):
        return self.get_component('ProfilesComponent')

    @overload
    def get_component(self, name: Literal['ProfilesComponent']) -> 'ProfilesComponent':
        ...

    @overload
    def get_component(self, name: str) -> Optional[commands.Component]:
        ...

    def get_component(self, name: str) -> Optional[commands.Component]:
        return super().get_component(name)

    def get_context(self, payload, *, cls: Any = None) -> Context:
        cls = cls or Context
        return cls(payload, bot=self)

    async def event_ready(self):
        # logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
        logger.info("Logged in as %s", self.bot_id)

    async def version_check(self, component: str, req_version: int):
        # Query the database to confirm that the given component is listed with the given version.
        # Typically done upon loading a component
        rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)

        version = rows[0].to_version if rows else 0

        if version != req_version:
            raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
        else:
            logger.debug(
                "Component %s passed version check with version %s",
                component,
                version
            )
        return True

    async def setup_hook(self):
        await self.data.init()
        for component, req in SCHEMA_VERSIONS.items():
            await self.version_check(component, req)

        if self._setup_hook is not None:
            await self._setup_hook(self)

        # Get all current bot channels
        channels = await BotChannel.fetch_where(autojoin=True)

        # Join the channels
        await self.join_channels(*channels)

        # Build bot account's own url
        scopes = BOTUSER_SCOPES
        url = self.get_auth_url(scopes)
        logger.info("Bot account authorisation url: %s", url)

        # Build everyone else's url
        scopes = CHANNEL_SCOPES
        url = self.get_auth_url(scopes)
        logger.info("User account authorisation url: %s", url)

        logger.info("Finished setup")

    def get_auth_url(self, scopes: Optional[Scopes] = None):
        if scopes is None:
            scopes = Scopes((Scopes.channel_bot,))

        url = self._adapter.get_authorization_url(scopes=scopes)
        return url

    async def join_channels(self, *channels: BotChannel):
        """
        Register eventsub subscriptions for the given channel(s),
        over webhooks or websockets depending on configuration.
        """
        # TODO: If channels are already joined, unsubscribe
        for channel in channels:
            sub = None
            try:
                sub = eventsub.ChatMessageSubscription(
                    broadcaster_user_id=channel.userid,
                    user_id=self.bot_id,
                )
                if self.using_webhooks:
                    resp = await self.subscribe_webhook(sub)
                else:
                    resp = await self.subscribe_websocket(sub)
                logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
                self.joined[channel.userid] = channel
                self.safe_dispatch('channel_joined', payload=channel)
            except Exception:
                logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)

    async def event_oauth_authorized(self, payload: UserTokenPayload):
        logger.debug("Oauth flow authorization with payload %s", repr(payload))
        # Save the token and scopes and update internal authorisations
        resp = await self.add_token(payload.access_token, payload.refresh_token)
        if resp.user_id is None:
            logger.warning(
                "Oauth flow received with no user_id. Payload was: %s",
                repr(payload)
            )
            return

        # If the scopes authorised included channel:bot, ensure a BotChannel exists
        # And join it if needed
        if Scopes.channel_bot.value in resp.scopes:
            bot_channel = await BotChannel.fetch_or_create(
                resp.user_id,
                autojoin=True,
            )
            if bot_channel.autojoin:
                await self.join_channels(bot_channel)

        logger.info("Oauth flow authorization complete for payload %s", repr(payload))

    async def add_token(self, token: str, refresh: str):
        # Update the tokens in internal cache
        # This also validates the token
        # And hopefully gets the userid and scopes
        resp = await super().add_token(token, refresh)
        if resp.user_id is None:
            logger.warning(
                "Added a token with no user_id. Response was: %s",
                repr(resp)
            )
            return resp
        userid = resp.user_id
        new_scopes = resp.scopes

        # Save the token and scopes to data
        # Wrap this in a transaction so if it fails halfway we rollback correctly
        # TODO
        row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
        if row.token != token or row.refresh_token != refresh:
            await row.update(token=token, refresh_token=refresh)
        await self.data.user_auth_scopes.delete_where(userid=userid)
        await self.data.user_auth_scopes.insert_many(
            ('userid', 'scope'),
            *((userid, scope) for scope in new_scopes)
        )

        logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
        return resp

    async def load_tokens(self, path: str | None = None):
        for row in await UserAuth.fetch_where():
            try:
                await self.add_token(row.token, row.refresh_token)
            except Exception:
                logger.exception("Failed to add token for %s", row)
105 src/meta/config.py Normal file
@@ -0,0 +1,105 @@
import configparser as cfgp

from .args import args


class MapDotProxy:
    """
    Allows dot access to an underlying mappable object.
    """
    __slots__ = ("_map", "_converter")

    def __init__(self, mappable, converter=None):
        self._map = mappable
        self._converter = converter

    def __getattribute__(self, key):
        _map = object.__getattribute__(self, '_map')
        if key == '_map':
            return _map
        if key in _map:
            _converter = object.__getattribute__(self, '_converter')
            if _converter:
                return _converter(_map[key])
            else:
                return _map[key]
        else:
            return object.__getattribute__(_map, key)

    def __getitem__(self, key):
        return self._map.__getitem__(key)


class ConfigParser(cfgp.ConfigParser):
    """
    Extension of the base ConfigParser allowing optional
    section option retrieval without defaults.
    """
    def options(self, section, no_defaults=False, **kwargs):
        if no_defaults:
            try:
                return list(self._sections[section].keys())
            except KeyError:
                raise cfgp.NoSectionError(section)
        else:
            return super().options(section, **kwargs)


class Conf:
    def __init__(self, configfile, section_name="DEFAULT"):
        self.configfile = configfile

        self.config = ConfigParser(
            converters={
                "intlist": self._getintlist,
                "list": self._getlist,
            }
        )

        with open(configfile) as conff:
            # Opening with read_file mainly to ensure the file exists
            self.config.read_file(conff)

        self.section_name = section_name if section_name in self.config else 'DEFAULT'

        self.default = self.config["DEFAULT"]
        self.section = MapDotProxy(self.config[self.section_name])
        self.bot = self.section

        # Config file recursion: read in the configuration files specified in every "ALSO_READ" key.
        more_to_read = self.section.getlist("ALSO_READ", [])
        read = set()
        while more_to_read:
            to_read = more_to_read.pop(0)
            read.add(to_read)
            self.config.read(to_read)
            new_paths = [path for path in self.section.getlist("ALSO_READ", [])
                         if path not in read and path not in more_to_read]
            more_to_read.extend(new_paths)

        global conf
        conf = self

    def __getitem__(self, key):
        return self.section[key].strip()

    def __getattr__(self, section):
        name = section.upper()
        return self.config[name]

    def get(self, name, fallback=None):
        result = self.section.get(name, fallback)
        return result.strip() if result else result

    def _getintlist(self, value):
        return [int(item.strip()) for item in value.split(',')]

    def _getlist(self, value):
        return [item.strip() for item in value.split(',')]

    def write(self):
        with open(self.configfile, 'w') as conffile:
            self.config.write(conffile)


conf = Conf(args.config, 'BOT')
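# Sections are reachable as attributes, e.g. conf.bot['prefix'],
# conf.twitch.getboolean('webhooks', False) and conf.logging['general_log'],
# as used elsewhere in this template.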
4 src/meta/context.py Normal file
@@ -0,0 +1,4 @@
from twitchio.ext import commands as cmds

class Context(cmds.Context):
    ...
469 src/meta/logger.py Normal file
@@ -0,0 +1,469 @@
import sys
import logging
import asyncio
from typing import List, Optional
from logging.handlers import QueueListener, QueueHandler
import queue
import multiprocessing
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar

import aiohttp

from .config import conf
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
from utils.lib import utc_now


log_logger = logging.getLogger(__name__)
log_logger.propagate = False


log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT")

context: ContextVar[Optional[str]] = ContextVar('context', default=None)


def set_logging_context(
    context: Optional[str] = None,
    action: Optional[str] = None,
    stack: Optional[tuple[str, ...]] = None
):
    """
    Statically set the logging context variables to the given values.

    If `action` is given, pushes it onto the `log_action_stack`.
    """
    if context is not None:
        log_context.set(context)
    if action is not None or stack is not None:
        astack = log_action_stack.get()
        newstack = stack if stack is not None else astack
        if action is not None:
            newstack = (*newstack, action)
        log_action_stack.set(newstack)


@contextmanager
def logging_context(context=None, action=None, stack=None):
    """
    Context manager for executing a block of code in a given logging context.

    This context manager should only be used around synchronous code.
    This is because async code *may* get cancelled or externally garbage collected,
    in which case the finally block will be executed in the wrong context.
    See https://github.com/python/cpython/issues/93740
    This can be refactored nicely if this gets merged:
    https://github.com/python/cpython/pull/99634

    (It will not necessarily break on async code,
    if the async code can be guaranteed to clean up in its own context.)
    """
    if context is not None:
        oldcontext = log_context.get()
        log_context.set(context)
    if action is not None or stack is not None:
        astack = log_action_stack.get()
        newstack = stack if stack is not None else astack
        if action is not None:
            newstack = (*newstack, action)
        log_action_stack.set(newstack)
    try:
        yield
    finally:
        if context is not None:
            log_context.set(oldcontext)
        if stack is not None or action is not None:
            log_action_stack.set(astack)


def with_log_ctx(isolate=True, **kwargs):
    """
    Execute a coroutine inside a given logging context.

    If `isolate` is true, ensures that context does not leak
    outside the coroutine.

    If `isolate` is false, just statically set the context,
    which will leak unless the coroutine is
    called in an externally copied context.
    """
    def decorator(func):
        @wraps(func)
        async def wrapped(*w_args, **w_kwargs):
            if isolate:
                with logging_context(**kwargs):
                    # Task creation will synchronously copy the context
                    # This is gc safe
                    name = kwargs.get('action', f"log-wrapped-{func.__name__}")
                    task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
                    return await task
            else:
                # This will leak context changes
                set_logging_context(**kwargs)
                return await func(*w_args, **w_kwargs)
        return wrapped
    return decorator


# For backwards compatibility
log_wrap = with_log_ctx


def persist_task(task_collection: set):
    """
    Coroutine decorator that ensures the coroutine is scheduled as a task
    and added to the given task_collection for strong reference
    when it is called.

    This is just a hack to handle discord.py events potentially
    being unexpectedly garbage collected.

    Since this also implicitly schedules the coroutine as a task when it is called,
    the coroutine will also be run inside an isolated context.
    """
    def decorator(coro):
        @wraps(coro)
        async def wrapped(*w_args, **w_kwargs):
            name = f"persisted-{coro.__name__}"
            task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
            task_collection.add(task)
            task.add_done_callback(lambda f: task_collection.discard(f))
            await task
        return wrapped
    return decorator


RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)


def colour_escape(fmt: str) -> str:
    cmap = {
        '%(black)': COLOR_SEQ % BLACK,
        '%(red)': COLOR_SEQ % RED,
        '%(green)': COLOR_SEQ % GREEN,
        '%(yellow)': COLOR_SEQ % YELLOW,
        '%(blue)': COLOR_SEQ % BLUE,
        '%(magenta)': COLOR_SEQ % MAGENTA,
        '%(cyan)': COLOR_SEQ % CYAN,
        '%(white)': COLOR_SEQ % WHITE,
        '%(reset)': RESET_SEQ,
        '%(bold)': BOLD_SEQ,
    }
    for key, value in cmap.items():
        fmt = fmt.replace(key, value)
    return fmt


log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
              '%(cyan)%(app)-15s%(reset)|' +
              '%(cyan)%(context)-24s%(reset)|' +
              '%(cyan)%(actionstr)-22s%(reset)|' +
              ' %(bold)%(cyan)%(name)s:%(reset)' +
              ' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format)


# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
    fmt=log_format,
    # datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)


class LessThanFilter(logging.Filter):
    def __init__(self, exclusive_maximum, name=""):
        super(LessThanFilter, self).__init__(name)
        self.max_level = exclusive_maximum

    def filter(self, record):
        # non-zero return means we log this message
        return 1 if record.levelno < self.max_level else 0


class ExactLevelFilter(logging.Filter):
    def __init__(self, target_level, name=""):
        super().__init__(name)
        self.target_level = target_level

    def filter(self, record):
        return (record.levelno == self.target_level)


class ThreadFilter(logging.Filter):
    def __init__(self, thread_name):
        super().__init__("")
        self.thread = thread_name

    def filter(self, record):
        # non-zero return means we log this message
        return 1 if record.threadName == self.thread else 0


class ContextInjection(logging.Filter):
    def filter(self, record):
        # These guards are to allow override through _extra
        # And to ensure the injection is idempotent
        if not hasattr(record, 'context'):
            record.context = log_context.get()

        if not hasattr(record, 'actionstr'):
            action_stack = log_action_stack.get()
            if hasattr(record, 'action'):
                action_stack = (*action_stack, record.action)
            if action_stack:
                record.actionstr = ' ➔ '.join(action_stack)
            else:
                record.actionstr = "Unknown Action"

        if not hasattr(record, 'app'):
            record.app = log_app.get()

        if not hasattr(record, 'ctx'):
            if ctx := context.get():
                record.ctx = repr(ctx)
            else:
                record.ctx = None

        if getattr(record, 'with_ctx', False) and record.ctx:
            record.ctxstr = '\n' + record.ctx
        else:
            record.ctxstr = ""
        return True


logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)

logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)


class LocalQueueHandler(QueueHandler):
    def _emit(self, record: logging.LogRecord) -> None:
        # Removed the call to self.prepare(), handle task cancellation
        try:
            self.enqueue(record)
        except asyncio.CancelledError:
            raise
        except Exception:
            self.handleError(record)


class WebHookHandler(logging.StreamHandler):
    def __init__(self, webhook_url, prefix="", batch=True, loop=None):
        super().__init__()
        self.webhook_url = webhook_url
        self.prefix = prefix
        self.batched = ""
        self.batch = batch
        self.loop = loop
        self.batch_delay = 10
        self.batch_task = None
        self.last_batched = None
        self.waiting = []

        self.bucket = Bucket(20, 40)
        self.ignored = 0

        self.session = None
        self.webhook = None

    def get_loop(self):
        if self.loop is None:
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
        return self.loop

    def emit(self, record):
        self.format(record)
        self.get_loop().call_soon_threadsafe(self._post, record)

    def _post(self, record):
        if self.session is None:
            self.setup()
        asyncio.create_task(self.post(record))

    def setup(self):
        from discord import Webhook
        self.session = aiohttp.ClientSession()
        self.webhook = Webhook.from_url(self.webhook_url, session=self.session)

    async def post(self, record):
        if record.context == 'Webhook Logger':
            # Don't livelog livelog errors
            # Otherwise we recurse and Cloudflare hates us
            return
        log_context.set("Webhook Logger")
        log_action_stack.set(("Logging",))
        log_app.set(record.app)

        try:
            timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
            header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
            context = f"\n# Context: {record.ctx}" if record.ctx else ""
            message = f"{header}\n{record.msg}{context}"

            if len(message) > 1900:
                as_file = True
            else:
                as_file = False
                message = "```md\n{}\n```".format(message)

            # Post the log message(s)
            if self.batch:
                if len(message) > 1500:
                    await self._send_batched_now()
                    await self._send(message, as_file=as_file)
                else:
                    self.batched += message
                    if len(self.batched) + len(message) > 1500:
                        await self._send_batched_now()
                    else:
                        asyncio.create_task(self._schedule_batched())
            else:
                await self._send(message, as_file=as_file)
        except Exception as ex:
            print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)

    async def _schedule_batched(self):
        if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
            # noop, don't reschedule if it is already scheduled
            return
        try:
            self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
            await self.batch_task
            await self._send_batched()
        except asyncio.CancelledError:
            return
        except Exception as ex:
            print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)

    async def _send_batched_now(self):
        if self.batch_task is not None and not self.batch_task.done():
            self.batch_task.cancel()
        self.last_batched = None
        await self._send_batched()

    async def _send_batched(self):
        if self.batched:
            batched = self.batched
            self.batched = ""
            await self._send(batched)

    async def _send(self, message, as_file=False):
        import discord
        from discord import File
        try:
            self.bucket.request()
        except BucketOverFull:
            # Silently ignore
            self.ignored += 1
            return
        except BucketFull:
            logger.warning(
                "Can't keep up! "
                f"Ignoring records on live-logger {self.webhook.id}."
            )
            self.ignored += 1
            return
        else:
            if self.ignored > 0:
                logger.warning(
                    "Can't keep up! "
                    f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
                )
                self.ignored = 0

        try:
            if as_file or len(message) > 1900:
                with StringIO(message) as fp:
                    fp.seek(0)
                    await self.webhook.send(
                        f"{self.prefix}\n`{message.splitlines()[0]}`",
                        file=File(fp, filename="logs.md"),
                        username=log_app.get()
                    )
            else:
                await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
        except discord.HTTPException:
            logger.exception(
                "Live logger errored. Slowing down live logger."
            )
            self.bucket.fill()


handlers = []
if webhook := conf.logging['general_log']:
    handler = WebHookHandler(webhook, batch=True)
    handlers.append(handler)

if webhook := conf.logging['warning_log']:
    handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True)
    handler.addFilter(ExactLevelFilter(logging.WARNING))
    handler.setLevel(logging.WARNING)
    handlers.append(handler)

if webhook := conf.logging['error_log']:
    handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True)
    handler.setLevel(logging.ERROR)
    handlers.append(handler)

if webhook := conf.logging['critical_log']:
    handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False)
    handler.setLevel(logging.CRITICAL)
    handlers.append(handler)


def make_queue_handler(queue):
    qhandler = QueueHandler(queue)
    qhandler.setLevel(logging.INFO)
    qhandler.addFilter(ContextInjection())
    return qhandler


def setup_main_logger(multiprocess=False):
    q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
    if handlers:
        # First create a separate loop to run the handlers on
        import threading

        def run_loop(loop):
            asyncio.set_event_loop(loop)
            try:
                loop.run_forever()
            finally:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.close()

        loop = asyncio.new_event_loop()
        loop_thread = threading.Thread(target=lambda: run_loop(loop))
        loop_thread.daemon = True
        loop_thread.start()

        for handler in handlers:
            handler.loop = loop

    qhandler = make_queue_handler(q)
    # qhandler.addFilter(ThreadFilter('MainThread'))
    logger.addHandler(qhandler)

    listener = QueueListener(
        q, *handlers, respect_handler_level=True
    )
    listener.start()
    return q
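# Records flow: logging call -> queue handler (cheap, from any thread)
# -> QueueListener thread -> WebHookHandler.emit, which posts to the
# Discord webhook from the background event loop started above.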
68 src/meta/sockets.py Normal file
@@ -0,0 +1,68 @@
from abc import ABC
from collections import defaultdict
import json
from typing import Any
import logging

import websockets


logger = logging.getLogger(__name__)


class Channel(ABC):
    """
    A channel is a stateful connection handler for a group of connected websockets.
    """
    name = "Root Channel"

    def __init__(self, **kwargs):
        self.connections = set()

    @property
    def empty(self):
        return not self.connections

    async def on_connection(self, websocket: websockets.WebSocketServerProtocol, event: dict[str, Any]):
        logger.info(f"Channel '{self.name}' attached new connection {websocket=} {event=}")
        self.connections.add(websocket)

    async def del_connection(self, websocket: websockets.WebSocketServerProtocol):
        logger.info(f"Channel '{self.name}' dropped connection {websocket=}")
        self.connections.discard(websocket)

    async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message):
        raise NotImplementedError

    async def send_event(self, event, websocket=None):
        message = json.dumps(event)
        if not websocket:
            for ws in self.connections:
                await ws.send(message)
        else:
            await websocket.send(message)


channels = {}

def register_channel(name, channel: Channel):
    channels[name] = channel
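# Clients open a connection by sending an init event naming a registered
# channel, e.g. (assuming a channel was registered under the name "overlay"):
#   {"type": "init", "channel": "overlay"}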


async def root_handler(websocket: websockets.WebSocketServerProtocol):
    message = await websocket.recv()
    event = json.loads(message)

    if event.get('type', None) != 'init':
        raise ValueError("Received Websocket connection with no init.")

    if (channel_name := event.get('channel', None)) not in channels:
        raise ValueError(f"Received Init for unhandled channel {channel_name=}")
    channel = channels[channel_name]

    try:
        await channel.on_connection(websocket, event)
        async for message in websocket:
            await channel.handle_message(websocket, message)
    finally:
        await channel.del_connection(websocket)
9 src/modules/__init__.py Normal file
@@ -0,0 +1,9 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from meta import Bot


async def twitch_setup(bot: 'Bot'):
    from . import profiles
    await profiles.twitch_setup(bot)
1 src/modules/profiles Submodule
Submodule src/modules/profiles added at 0363dc2bcd
88 src/utils/lib.py Normal file
@@ -0,0 +1,88 @@
import re
import datetime as dt


def strfdelta(delta: dt.timedelta, sec=False, minutes=True, short=False) -> str:
    """
    Convert a datetime.timedelta object into an easily readable duration string.

    Parameters
    ----------
    delta: datetime.timedelta
        The timedelta object to convert into a readable string.
    sec: bool
        Whether to include the seconds from the timedelta object in the string.
    minutes: bool
        Whether to include the minutes from the timedelta object in the string.
    short: bool
        Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").

    Returns: str
        A string containing a time from the datetime.timedelta object, in a readable format.
        Time units will be abbreviated if short was set to True.
    """
    output = [[delta.days, 'd' if short else ' day'],
              [delta.seconds // 3600, 'h' if short else ' hour']]
    if minutes:
        output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
    if sec:
        output.append([delta.seconds % 60, 's' if short else ' second'])
    for i in range(len(output)):
        if output[i][0] != 1 and not short:
            output[i][1] += 's'  # type: ignore
    reply_msg = []
    if output[0][0] != 0:
        reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
    if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
        reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
    for i in range(2, len(output) - 1):
        reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
    if not short and reply_msg:
        reply_msg.append("and ")
    reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
    return "".join(reply_msg)


def utc_now() -> dt.datetime:
    """
    Return the current timezone-aware utc timestamp.
    """
    return dt.datetime.now(dt.timezone.utc)


def replace_multiple(format_string, mapping):
    """
    Substitutes the keys from the mapping with their corresponding values.

    Substitution is non-chained, and done in a single pass via regex.
    """
    if not mapping:
        raise ValueError("Empty mapping passed.")

    keys = list(mapping.keys())
    pattern = '|'.join(f"({key})" for key in keys)
    string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
    return string


def parse_dur(time_str):
    """
    Parses a user provided time duration string into a number of seconds.

    Parameters
    ----------
    time_str: str
        The time string to parse. String can include days, hours, minutes, and seconds.

    Returns: int
        The number of seconds the duration represents.
    """
    funcs = {'d': lambda x: x * 24 * 60 * 60,
             'h': lambda x: x * 60 * 60,
             'm': lambda x: x * 60,
             's': lambda x: x}
    time_str = time_str.strip(" ,")
    found = re.findall(r'(\d+)\s?(\w+?)', time_str)
    seconds = 0
    for bit in found:
        if bit[1] in funcs:
            seconds += funcs[bit[1]](int(bit[0]))
    return seconds
166 src/utils/ratelimits.py Normal file
@@ -0,0 +1,166 @@
import asyncio
import time
import logging

from cachetools import TTLCache

logger = logging.getLogger(__name__)


class BucketFull(Exception):
    """
    Thrown when a requested Bucket is already full
    """
    pass


class BucketOverFull(BucketFull):
    """
    Thrown when a requested Bucket is overfull
    """
    pass


class Bucket:
    __slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')

    def __init__(self, max_level, empty_time):
        self.max_level = max_level
        self.empty_time = empty_time
        self.leak_rate = max_level / empty_time
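        # Leaky-bucket accounting: e.g. Bucket(20, 40) admits at most 20
        # requests at once and drains at 20/40 = 0.5 requests per second.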

        self._level = 0
        self._last_checked = time.monotonic()

        self._last_full = False
        self._wait_lock = asyncio.Lock()

    @property
    def full(self) -> bool:
        """
        Return whether the bucket is 'full',
        that is, whether an immediate request against the bucket will raise `BucketFull`.
        """
        self._leak()
        return self._level + 1 > self.max_level

    @property
    def overfull(self):
        self._leak()
        return self._level > self.max_level

    @property
    def delay(self):
        self._leak()
        if self._level + 1 > self.max_level:
            # Time needed to leak the excess, in seconds
            delay = (self._level + 1 - self.max_level) / self.leak_rate
        else:
            delay = 0
        return delay

    def _leak(self):
        if self._level:
            elapsed = time.monotonic() - self._last_checked
            self._level = max(0, self._level - (elapsed * self.leak_rate))

        self._last_checked = time.monotonic()

    def request(self):
        self._leak()
        if self._level > self.max_level:
            raise BucketOverFull
        elif self._level == self.max_level:
            self._level += 1
            if self._last_full:
                raise BucketOverFull
            else:
                self._last_full = True
                raise BucketFull
        else:
            self._last_full = False
            self._level += 1

    def fill(self):
        self._leak()
        self._level = max(self._level, self.max_level + 1)

    async def wait(self):
        """
        Wait until the bucket has room.

        Guarantees that a `request` directly afterwards will not raise `BucketFull`.
        """
        # Wrapped in a lock so that waiters are correctly handled in wait-order
        # Otherwise multiple waiters will have the same delay,
        # and race for the wakeup after sleep.
        # Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
        async with self._wait_lock:
            # We do this in a loop in case asyncio.sleep throws us out early,
            # or a synchronous request overflows the bucket while we are waiting.
            while self.full:
                await asyncio.sleep(self.delay)

    async def wrapped(self, coro):
        await self.wait()
        self.request()
        await coro


class RateLimit:
    def __init__(self, max_level, empty_time, error=None, cache=None):
        self.max_level = max_level
        self.empty_time = empty_time

        self.error = error or "Too many requests, please slow down!"
        # Avoid a mutable default argument: each RateLimit gets its own bucket cache
        self.buckets = cache if cache is not None else TTLCache(1000, 60 * 60)

    def request_for(self, key):
        if not (bucket := self.buckets.get(key, None)):
            bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)

        bucket.request()

    def ward(self, member=True, key=None):
        """
        Command ratelimit decorator.
        """
        key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))

        def decorator(func):
            async def wrapper(ctx, *args, **kwargs):
                self.request_for(key(ctx))
                return await func(ctx, *args, **kwargs)
            return wrapper
        return decorator


async def limit_concurrency(aws, limit):
    """
    Run provided awaitables concurrently,
    ensuring that no more than `limit` are running at once.
    """
    aws = iter(aws)
    aws_ended = False
    pending = set()
    count = 0
    logger.debug("Starting limited concurrency executor")

    while pending or not aws_ended:
        while len(pending) < limit and not aws_ended:
            aw = next(aws, None)
            if aw is None:
                aws_ended = True
            else:
                pending.add(asyncio.create_task(aw))
                count += 1

        if not pending:
            break

        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        while done:
            yield done.pop()
    logger.debug(f"Completed {count} tasks")