rewrite: Core framework.

This commit is contained in:
2022-11-11 08:04:23 +02:00
parent 2121749238
commit 7249e25975
14 changed files with 715 additions and 151 deletions

View File

@@ -5,8 +5,8 @@ import discord
from discord.ext import commands from discord.ext import commands
from meta import LionBot, conf, sharding, appname, shard_talk from meta import LionBot, conf, sharding, appname, shard_talk
from meta.logger import log_context, log_action from meta.logger import log_context, log_action_stack, logging_context
from meta.context import context from meta.context import ctx_bot
from data import Database from data import Database
@@ -16,7 +16,7 @@ from constants import DATA_VERSION
# Note: This MUST be imported after core, due to table definition orders # Note: This MUST be imported after core, due to table definition orders
# from settings import AppSettings # from settings import AppSettings
log_context.set(f"APP: {appname}") # log_context.set(f"APP: {appname}")
# client.appdata = core.data.meta.fetch_or_create(appname) # client.appdata = core.data.meta.fetch_or_create(appname)
@@ -36,7 +36,7 @@ db = Database(conf.data['args'])
async def main(): async def main():
log_action.set("Initialising") log_action_stack.set(["Initialising"])
logger.info("Initialising StudyLion") logger.info("Initialising StudyLion")
intents = discord.Intents.all() intents = discord.Intents.all()
@@ -49,6 +49,7 @@ async def main():
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate." error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error) logger.critical(error)
raise RuntimeError(error) raise RuntimeError(error)
async with LionBot( async with LionBot(
command_prefix=commands.when_mentioned, command_prefix=commands.when_mentioned,
intents=intents, intents=intents,
@@ -58,15 +59,33 @@ async def main():
initial_extensions=['modules'], initial_extensions=['modules'],
web_client=None, web_client=None,
app_ipc=shard_talk, app_ipc=shard_talk,
testing_guilds=[889875661848723456], testing_guilds=[889875661848723456, 879411098384752672],
shard_id=sharding.shard_number, shard_id=sharding.shard_number,
shard_count=sharding.shard_count shard_count=sharding.shard_count
) as lionbot: ) as lionbot:
context.get().bot = lionbot ctx_bot.set(lionbot)
@lionbot.before_invoke try:
async def before_invoke(ctx): with logging_context(context=f"APP: {appname}"):
print(ctx) logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
log_action.set("Launching")
await lionbot.start(conf.bot['TOKEN']) await lionbot.start(conf.bot['TOKEN'])
except asyncio.CancelledError:
with logging_context(context=f"APP: {appname}", action="Shutting Down"):
logger.info("StudyLion closed, shutting down.", exc_info=True)
asyncio.run(main())
def _main():
from signal import SIGINT, SIGTERM
loop = asyncio.get_event_loop()
main_task = asyncio.ensure_future(main())
for signal in [SIGINT, SIGTERM]:
loop.add_signal_handler(signal, main_task.cancel)
try:
loop.run_until_complete(main_task)
finally:
loop.close()
logging.shutdown()
if __name__ == '__main__':
_main()

View File

@@ -1,27 +1,33 @@
from typing import List, Optional, Dict from typing import List
import logging
import asyncio
import discord import discord
from discord.ext import commands from discord.utils import MISSING
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError
from aiohttp import ClientSession from aiohttp import ClientSession
from data import Database from data import Database
from .config import Conf from .config import Conf
from .logger import logging_context, log_context, log_action_stack
from .context import context
from .LionContext import LionContext
from .LionTree import LionTree
from .errors import HandledException, SafeCancellation
logger = logging.getLogger(__name__)
class LionBot(commands.Bot): class LionBot(Bot):
def __init__( def __init__(
self, self, *args, appname: str, db: Database, config: Conf,
*args, initial_extensions: List[str], web_client: ClientSession, app_ipc,
appname: str, testing_guilds: List[int] = [], **kwargs
db: Database,
config: Conf,
initial_extensions: List[str],
web_client: ClientSession,
app_ipc,
testing_guilds: List[int] = [],
**kwargs,
): ):
kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.web_client = web_client self.web_client = web_client
self.testing_guilds = testing_guilds self.testing_guilds = testing_guilds
@@ -33,6 +39,7 @@ class LionBot(commands.Bot):
self.app_ipc = app_ipc self.app_ipc = app_ipc
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}")
await self.app_ipc.connect() await self.app_ipc.connect()
for extension in self.initial_extensions: for extension in self.initial_extensions:
@@ -42,3 +49,134 @@ class LionBot(commands.Bot):
guild = discord.Object(guildid) guild = discord.Object(guildid)
self.tree.copy_global_to(guild=guild) self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild) await self.tree.sync(guild=guild)
async def add_cog(self, cog: Cog, **kwargs):
with logging_context(action=f"Attach {cog.__cog_name__}"):
logger.info(f"Attaching Cog {cog.__cog_name__}")
await super().add_cog(cog, **kwargs)
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
async def load_extension(self, name, *, package=None, **kwargs):
with logging_context(action=f"Load {name.strip('.')}"):
logger.info(f"Loading extension {name} in package {package}.")
await super().load_extension(name, package=package, **kwargs)
logger.debug(f"Loaded extension {name} in package {package}.")
async def start(self, token: str, *, reconnect: bool = True):
with logging_context(action="Login"):
await self.login(token)
with logging_context(stack=["Running"]):
await self.connect(reconnect=reconnect)
async def on_ready(self):
logger.info(
f"Logged in as {self.application.name}\n"
f"Application id {self.application.id}\n"
f"Shard Talk identifier {self.appname}\n"
"------------------------------\n"
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
"------------------------------\n"
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
"Ready to take commands!\n",
extra={'action': 'Ready'}
)
async def get_context(self, origin, /, *, cls=MISSING):
if cls is MISSING:
cls = LionContext
ctx = await super().get_context(origin, cls=cls)
context.set(ctx)
return ctx
async def on_command(self, ctx: LionContext):
logger.info(
f"Executing command '{ctx.command.qualified_name}' (from module '{ctx.cog.__cog_name__}') "
f"with arguments {ctx.args} and kwargs {ctx.kwargs}.",
extra={'with_ctx': True}
)
async def on_command_error(self, ctx, exception):
# TODO: Some of these could have more user-feedback
cmd_str = str(ctx.command)
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict()
try:
raise exception
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
original = exception.original
try:
raise original
except HandledException:
pass
except SafeCancellation:
if original.msg:
try:
await ctx.error_reply(original.msg)
except Exception:
pass
logger.debug(
f"Caught a safe cancellation: {original.details}",
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.Forbidden:
# Unknown uncaught Forbidden
try:
# Attempt a general error reply
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
except Exception:
# We can't send anything at all. Exit quietly, but log.
logger.warning(
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.HTTPException:
logger.warning(
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
pass
except Exception:
logger.exception(
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
exc_info=exception,
extra={'action': 'BotError', 'with_ctx': True}
)
error_embed = discord.Embed(title="Something went wrong!")
error_embed.description = (
"An unexpected error occurred while processing your command!\n"
"Our development team has been notified, and the issue should be fixed soon.\n"
"If the error persists, please contact our support team and give them the following number: "
f"`{ctx.interaction.id}`"
)
try:
await ctx.error_reply(embed=error_embed)
except Exception:
pass
finally:
exception.original = HandledException(exception.original)
except CheckFailure:
logger.debug(
f"Command failed check: {exception}",
extra={'action': 'BotError', 'with_ctx': True}
)
try:
await ctx.error_rely(exception.message)
except Exception:
pass
except Exception:
# Completely unknown exception outside of command invocation!
# Something is very wrong here, don't attempt user interaction.
logger.exception(
f"Caught an unknown top-level exception while executing: {cmd_str}",
exc_info=exception,
extra={'action': 'BotError', 'with_ctx': True}
)

5
bot/meta/LionCog.py Normal file
View File

@@ -0,0 +1,5 @@
from discord.ext.commands import Cog
class LionCog(Cog):
...

184
bot/meta/LionContext.py Normal file
View File

@@ -0,0 +1,184 @@
import types
import logging
from collections import namedtuple
from typing import Optional
import discord
from discord.ext.commands import Context
logger = logging.getLogger(__name__)
"""
Stuff that might be useful to implement (see cmdClient):
sent_messages cache
tasks cache
error reply
usage
interaction cache
View cache?
setting access
"""
FlatContext = namedtuple(
'FlatContext',
('message',
'interaction',
'guild',
'author',
'alias',
'prefix',
'failed')
)
class LionContext(Context):
"""
Represents the context a command is invoked under.
Extends Context to add Lion-specific methods and attributes.
Also adds several contextual wrapped utilities for simpler user during command invocation.
"""
def __repr__(self):
parts = {}
if self.interaction is not None:
parts['iid'] = self.interaction.id
parts['itype'] = f"\"{self.interaction.type.name}\""
if self.message is not None:
parts['mid'] = self.message.id
if self.author is not None:
parts['uid'] = self.author.id
parts['uname'] = f"\"{self.author.name}\""
if self.channel is not None:
parts['cid'] = self.channel.id
parts['cname'] = f"\"{self.channel.name}\""
if self.guild is not None:
parts['gid'] = self.guild.id
parts['gname'] = f"\"{self.guild.name}\""
if self.command is not None:
parts['cmd'] = f"\"{self.command.qualified_name}\""
if self.invoked_with is not None:
parts['alias'] = f"\"{self.invoked_with}\""
if self.command_failed:
parts['failed'] = self.command_failed
return "<LionContext: {}>".format(
' '.join(f"{name}={value}" for name, value in parts.items())
)
def flatten(self):
"""Flat pure-data context information, for caching and logging."""
return FlatContext(
self.message.id,
self.interaction.id if self.interaction is not None else None,
self.guild.id if self.guild is not None else None,
self.author.id if self.author is not None else None,
self.channel.id if self.channel is not None else None,
self.invoked_with,
self.prefix,
self.command_failed
)
@classmethod
def util(cls, util_func):
"""
Decorator to make a utility function available as a Context instance method.
"""
setattr(cls, util_func.__name__, util_func)
logger.debug(f"Attached context utility function: {util_func.__name__}")
return util_func
@classmethod
def wrappable_util(cls, util_func):
"""
Decorator to add a Wrappable utility function as a Context instance method.
"""
wrapped = Wrappable(util_func)
setattr(cls, util_func.__name__, wrapped)
logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
return wrapped
async def error_reply(self, content: Optional[str] = None, **kwargs):
if content and 'embed' not in kwargs:
embed = discord.Embed(
colour=discord.Colour.red(),
description=content
)
kwargs['embed'] = embed
content = None
# Expect this may be run in highly unusual circumstances.
# This should never error, or at least handle all errors.
try:
await self.reply(content=content, **kwargs)
except discord.HTTPException:
pass
except Exception:
logger.exception(
"Unknown exception in 'error_reply'.",
extra={'action': 'error_reply', 'ctx': self, 'with_ctx': True}
)
class Wrappable:
__slots__ = ('_func', 'wrappers')
def __init__(self, func):
self._func = func
self.wrappers = None
@property
def __name__(self):
return self._func.__name__
def add_wrapper(self, func, name=None):
self.wrappers = self.wrappers or {}
name = name or func.__name__
self.wrappers[name] = func
logger.debug(
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
extra={'action': "Wrap Util"}
)
def remove_wrapper(self, name):
if not self.wrappers or name not in self.wrappers:
raise ValueError(
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
)
self.wrappers.pop(name)
logger.debug(
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
extra={'action': "Unwrap Util"}
)
def __call__(self, *args, **kwargs):
if self.wrappers:
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
else:
return self._func(*args, **kwargs)
def _wrapped(self, iter_wraps):
next_wrap = next(iter_wraps, None)
if next_wrap:
def _func(*args, **kwargs):
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
else:
_func = self._func
return _func
def __get__(self, instance, cls=None):
if instance is None:
return self
else:
return types.MethodType(self, instance)
LionContext.reply = Wrappable(LionContext.reply)
# @LionContext.reply.add_wrapper
# async def think(func, ctx, *args, **kwargs):
# await ctx.channel.send("thinking")
# await func(ctx, *args, **kwargs)

79
bot/meta/LionTree.py Normal file
View File

@@ -0,0 +1,79 @@
import logging
from discord import Interaction
from discord.app_commands import CommandTree
from discord.app_commands.errors import AppCommandError, CommandInvokeError
from discord.enums import InteractionType
from discord.app_commands.namespace import Namespace
from .logger import logging_context
from .errors import SafeCancellation
logger = logging.getLogger(__name__)
class LionTree(CommandTree):
async def on_error(self, interaction, error) -> None:
try:
if isinstance(error, CommandInvokeError):
raise error.original
else:
raise error
except SafeCancellation:
# Assume this has already been handled
pass
except Exception:
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
async def _call(self, interaction):
with logging_context(context=f"iid: {interaction.id}"):
if not await self.interaction_check(interaction):
interaction.command_failed = True
return
data = interaction.data # type: ignore
type = data.get('type', 1)
if type != 1:
# Context menu command...
await self._call_context_menu(interaction, data, type)
return
command, options = self._get_app_command_options(data)
# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command
# At this point options refers to the arguments of the command
# and command refers to the class type we care about
namespace = Namespace(interaction, data.get('resolved', {}), options)
# Same pre-fill as above
interaction._cs_namespace = namespace
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
if interaction.type is InteractionType.autocomplete:
with logging_context(action=f"Acmp {command.qualified_name}"):
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
if focused is None:
raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.'
)
await command._invoke_autocomplete(interaction, focused, namespace)
return
with logging_context(action=f"Run {command.qualified_name}"):
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:
interaction.command_failed = True
await command._invoke_error_handlers(interaction, e)
await self.on_error(interaction, e)
else:
if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command)
finally:
if interaction.command_failed:
logger.debug("Command completed with errors.")
else:
logger.debug("Command completed without errors.")

View File

@@ -29,7 +29,7 @@ class configEmoji(PartialEmoji):
name=name, name=name,
fallback=PartialEmoji(name=fallback) if fallback is not None else None, fallback=PartialEmoji(name=fallback) if fallback is not None else None,
animated=bool(animated), animated=bool(animated),
id=int(id) id=int(id) if id else None
) )

View File

@@ -1,51 +1,20 @@
"""
Namespace for various global context variables.
Allows asyncio callbacks to accurately retrieve information about the current state.
"""
from typing import TYPE_CHECKING, Optional
from contextvars import ContextVar from contextvars import ContextVar
if TYPE_CHECKING:
class Context: from .LionBot import LionBot
__slots__ = ( from .LionContext import LionContext
'bot',
'interaction', 'message',
'guild', 'channel', 'author', 'user'
)
def __init__(self, **kwargs):
self.bot = kwargs.pop('bot', None)
self.interaction = interaction = kwargs.pop('interaction', None)
self.message = message = kwargs.pop('message', interaction.message if interaction is not None else None)
guild = kwargs.pop('guild', None)
channel = kwargs.pop('channel', None)
author = kwargs.pop('author', None)
if message is not None:
guild = guild or message.guild
channel = channel or message.channel
author = author or message.author
elif interaction is not None:
guild = guild or interaction.guild
channel = channel or interaction.channel
author = author or interaction.user
self.guild = guild
self.channel = channel
self.author = self.user = author
def log_string(self):
"""Markdown formatted summary for live logging."""
parts = []
if self.interaction is not None:
parts.append(f"<int id={self.interaction.id} type={self.interaction.type.name}>")
if self.message is not None:
parts.append(f"<msg id={self.message.id}>")
if self.author is not None:
parts.append(f"<user id={self.author.id} name='{self.author.name}'>")
if self.channel is not None:
parts.append(f"<chan id={self.channel.id} name='{self.channel.name}'>")
if self.guild is not None:
parts.append(f"<guild id={self.guild.id} name='{self.guild.name}'>")
return " ".join(parts)
context = ContextVar('context', default=Context()) # Contains the current command context, if applicable
context: Optional['LionContext'] = ContextVar('context', default=None)
# Contains the current LionBot instance
ctx_bot: Optional['LionBot'] = ContextVar('bot', default=None)

51
bot/meta/errors.py Normal file
View File

@@ -0,0 +1,51 @@
from typing import Optional
class SafeCancellation(Exception):
"""
Raised to safely cancel execution of the current operation.
If not caught, is expected to be propagated to the Tree and safely ignored there.
If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
The error handler should then set the `msg` to None, to avoid double handling.
Debugging information should go in `details`, to be logged by a top-level error handler.
"""
default_message = ""
def __init__(self, msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
self.msg: Optional[str] = msg if msg is not None else self.default_message
self.details: str = details if details is not None else self.msg
super().__init__(**kwargs)
class UserInputError(SafeCancellation):
"""
A SafeCancellation induced from unparseable user input.
"""
default_message = "Could not understand your input."
class UserCancelled(SafeCancellation):
"""
A SafeCancellation induced from manual user cancellation.
Usually silent.
"""
default_msg = None
class ResponseTimedOut(SafeCancellation):
"""
A SafeCancellation induced from a user interaction time-out.
"""
default_msg = "Session timed out waiting for input."
class HandledException(SafeCancellation):
"""
Sentinel class to indicate to error handlers that this exception has been handled.
Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
"""
def __init__(self, exc=None, **kwargs):
self.exc = exc
super().__init__(**kwargs)

View File

@@ -1,17 +1,19 @@
from typing import Optional from typing import Optional, TypeAlias, Any
import asyncio import asyncio
import logging import logging
import pickle import pickle
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AppClient: Address: TypeAlias = dict[str, Any]
routes = {} # route_name -> Callable[Any, Awaitable[Any]]
def __init__(self, appid, client_address, server_address):
class AppClient:
routes: dict[str, 'AppRoute'] = {} # route_name -> Callable[Any, Awaitable[Any]]
def __init__(self, appid: str, client_address: Address, server_address: Address):
self.appid = appid self.appid = appid
self.address = client_address self.address = client_address
self.server_address = server_address self.server_address = server_address

View File

@@ -4,7 +4,7 @@ import logging
import string import string
import random import random
from ..logger import log_action, log_context, log_app from ..logger import log_context, log_app, logging_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -71,7 +71,7 @@ class AppServer:
""" """
Register and hold a new client connection. Register and hold a new client connection.
""" """
log_action.set("CONN " + appid) with logging_context(action=f"CONN {appid}"):
reader, writer = connection reader, writer = connection
# Add the new client # Add the new client
self.clients[appid] = (address, connection) self.clients[appid] = (address, connection)
@@ -98,9 +98,8 @@ class AppServer:
route, args, kwargs = pickle.loads(data) route, args, kwargs = pickle.loads(data)
rqid = short_uuid() rqid = short_uuid()
log_context.set("RQID:" + rqid)
log_action.set("SERV ROUTE " + route)
with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"):
logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}") logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes: if route in self.routes:
@@ -120,6 +119,7 @@ class AppServer:
await self.broadcast('drop_peer', (), {'appid': appid}) await self.broadcast('drop_peer', (), {'appid': appid})
async def broadcast(self, route, args, kwargs): async def broadcast(self, route, args, kwargs):
with logging_context(action="broadcast"):
logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.") logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.")
payload = pickle.dumps((route, args, kwargs)) payload = pickle.dumps((route, args, kwargs))
if self.clients: if self.clients:
@@ -132,6 +132,7 @@ class AppServer:
""" """
Send a message to client `appid` along `route` with given arguments. Send a message to client `appid` along `route` with given arguments.
""" """
with logging_context(action=f"MSG {appid}"):
logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.") logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.")
if appid not in self.clients: if appid not in self.clients:
raise ValueError(f"Client '{appid}' is not connected.") raise ValueError(f"Client '{appid}' is not connected.")
@@ -157,6 +158,7 @@ class AppServer:
async def start(self, address): async def start(self, address):
log_app.set("APPSERVER") log_app.set("APPSERVER")
with logging_context(stack=["SERV"]):
server = await asyncio.start_server(self.handle_connection, **address) server = await asyncio.start_server(self.handle_connection, **address)
logger.info(f"Serving on {address}") logger.info(f"Serving on {address}")
async with server: async with server:

View File

@@ -1,6 +1,7 @@
import sys import sys
import logging import logging
import asyncio import asyncio
from typing import List
from logging.handlers import QueueListener, QueueHandler from logging.handlers import QueueListener, QueueHandler
from queue import SimpleQueue from queue import SimpleQueue
from contextlib import contextmanager from contextlib import contextmanager
@@ -16,24 +17,33 @@ from .context import context
from utils.lib import utc_now from utils.lib import utc_now
log_logger = logging.getLogger(__name__)
log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action: ContextVar[str] = ContextVar('logging_action', default='UNKNOWN ACTION') log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[])
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number)) log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
@contextmanager @contextmanager
def logging_context(context=None, action=None): def logging_context(context=None, action=None, stack=None):
if context is not None: if context is not None:
context_t = log_context.set(context) context_t = log_context.set(context)
if action is not None: if action is not None:
action_t = log_action.set(action) astack = log_action_stack.get()
log_action_stack.set(astack + [action])
if stack is not None:
actions_t = log_action_stack.set(stack)
try: try:
yield yield
finally: finally:
if context is not None: if context is not None:
log_context.reset(context_t) log_context.reset(context_t)
if stack is not None:
log_action_stack.reset(actions_t)
if action is not None: if action is not None:
log_action.reset(action_t) log_action_stack.set(astack)
RESET_SEQ = "\033[0m" RESET_SEQ = "\033[0m"
@@ -63,10 +73,10 @@ def colour_escape(fmt: str) -> str:
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' + log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
'[%(cyan)%(app)-15s%(reset)]' + '[%(cyan)%(app)-15s%(reset)]' +
'[%(cyan)%(context)-22s%(reset)]' + '[%(cyan)%(context)-24s%(reset)]' +
'[%(cyan)%(action)-22s%(reset)]' + '[%(cyan)%(actionstr)-22s%(reset)]' +
' %(bold)%(cyan)%(name)s:%(reset)' + ' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(reset)') ' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format) log_format = colour_escape(log_format)
@@ -74,7 +84,7 @@ log_format = colour_escape(log_format)
logger = logging.getLogger() logger = logging.getLogger()
log_fmt = logging.Formatter( log_fmt = logging.Formatter(
fmt=log_format, fmt=log_format,
datefmt='%Y-%m-%d %H:%M:%S' # datefmt='%Y-%m-%d %H:%M:%S'
) )
logger.setLevel(logging.NOTSET) logger.setLevel(logging.NOTSET)
@@ -89,14 +99,45 @@ class LessThanFilter(logging.Filter):
return 1 if record.levelno < self.max_level else 0 return 1 if record.levelno < self.max_level else 0
class ThreadFilter(logging.Filter):
def __init__(self, thread_name):
super().__init__("")
self.thread = thread_name
def filter(self, record):
# non-zero return means we log this message
return 1 if record.threadName == self.thread else 0
class ContextInjection(logging.Filter): class ContextInjection(logging.Filter):
def filter(self, record): def filter(self, record):
# These guards are to allow override through _extra
# And to ensure the injection is idempotent
if not hasattr(record, 'context'): if not hasattr(record, 'context'):
record.context = log_context.get() record.context = log_context.get()
if not hasattr(record, 'action'):
record.action = log_action.get() if not hasattr(record, 'actionstr'):
action_stack = log_action_stack.get()
if hasattr(record, 'action'):
action_stack = (*action_stack, record.action)
if action_stack:
record.actionstr = ''.join(action_stack)
else:
record.actionstr = "Unknown Action"
if not hasattr(record, 'app'):
record.app = log_app.get() record.app = log_app.get()
record.ctx = context.get().log_string()
if not hasattr(record, 'ctx'):
if ctx := context.get():
record.ctx = repr(ctx)
else:
record.ctx = None
if getattr(record, 'with_ctx', False) and record.ctx:
record.ctxstr = '\n' + record.ctx
else:
record.ctxstr = ""
return True return True
@@ -106,12 +147,14 @@ logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(LessThanFilter(logging.WARNING)) logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
logging_handler_out.addFilter(ContextInjection()) logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out) logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr) logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING) logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt) logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection()) logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err) logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)
class LocalQueueHandler(QueueHandler): class LocalQueueHandler(QueueHandler):
@@ -127,7 +170,7 @@ class LocalQueueHandler(QueueHandler):
class WebHookHandler(logging.StreamHandler): class WebHookHandler(logging.StreamHandler):
def __init__(self, webhook_url, batch=False, loop=None): def __init__(self, webhook_url, batch=False, loop=None):
super().__init__(self) super().__init__()
self.webhook_url = webhook_url self.webhook_url = webhook_url
self.batched = "" self.batched = ""
self.batch = batch self.batch = batch
@@ -150,9 +193,13 @@ class WebHookHandler(logging.StreamHandler):
asyncio.create_task(self.post(record)) asyncio.create_task(self.post(record))
async def post(self, record): async def post(self, record):
log_context.set("Webhook Logger")
log_action_stack.set(["Logging"])
log_app.set(record.app)
try: try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S") timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{timestamp}][{record.levelname}][{record.app}][{record.action}][{record.context}]" header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
context = f"\n# Context: {record.ctx}" if record.ctx else "" context = f"\n# Context: {record.ctx}" if record.ctx else ""
message = f"{header}\n{record.msg}{context}" message = f"{header}\n{record.msg}{context}"
@@ -164,12 +211,12 @@ class WebHookHandler(logging.StreamHandler):
# Post the log message(s) # Post the log message(s)
if self.batch: if self.batch:
if len(message) > 1000: if len(message) > 1500:
await self._send_batched_now() await self._send_batched_now()
await self._send(message, as_file=as_file) await self._send(message, as_file=as_file)
else: else:
self.batched += message self.batched += message
if len(self.batched) + len(message) > 1000: if len(self.batched) + len(message) > 1500:
await self._send_batched_now() await self._send_batched_now()
else: else:
asyncio.create_task(self._schedule_batched()) asyncio.create_task(self._schedule_batched())
@@ -209,9 +256,12 @@ class WebHookHandler(logging.StreamHandler):
if as_file or len(message) > 2000: if as_file or len(message) > 2000:
with StringIO(message) as fp: with StringIO(message) as fp:
fp.seek(0) fp.seek(0)
await webhook.send(file=File(fp, filename="logs.md")) await webhook.send(
file=File(fp, filename="logs.md"),
username=log_app.get()
)
else: else:
await webhook.send(message) await webhook.send(message, username=log_app.get())
handlers = [] handlers = []
@@ -254,6 +304,7 @@ if handlers:
qhandler = QueueHandler(queue) qhandler = QueueHandler(queue)
qhandler.setLevel(logging.INFO) qhandler.setLevel(logging.INFO)
qhandler.addFilter(ContextInjection()) qhandler.addFilter(ContextInjection())
# qhandler.addFilter(ThreadFilter('MainThread'))
logger.addHandler(qhandler) logger.addHandler(qhandler)
listener = QueueListener( listener = QueueListener(

0
bot/utils/__init__.py Normal file
View File

View File

@@ -1,6 +1,7 @@
import datetime import datetime
import iso8601 # type: ignore import iso8601 # type: ignore
import re import re
from contextvars import Context
import discord import discord
@@ -439,15 +440,6 @@ def jumpto(guildid: int, channeldid: int, messageid: int) -> str:
) )
class DotDict(dict):
"""
Dict-type allowing dot access to keys.
"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def utc_now() -> datetime.datetime: def utc_now() -> datetime.datetime:
""" """
Return the current timezone-aware utc timestamp. Return the current timezone-aware utc timestamp.
@@ -464,3 +456,8 @@ def multiple_replace(string: str, rep_dict: dict[str, str]) -> str:
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string) return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
else: else:
return string return string
def recover_context(context: Context):
for var in context:
var.set(context[var])

67
bot/utils/ui.py Normal file
View File

@@ -0,0 +1,67 @@
from typing import List, Coroutine
import asyncio
import logging
from contextvars import copy_context
import discord
from discord.ui import Modal
from .lib import recover_context
class FastModal(Modal):
def __init__(self, *items, **kwargs):
super().__init__(**kwargs)
for item in items:
self.add_item(item)
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
self._waiters: List[Coroutine[discord.Interaction]] = []
self._context = copy_context()
async def wait_for(self, check=None, timeout=None):
# Wait for _result or timeout
# If we timeout, or the view times out, raise TimeoutError
# Otherwise, return the Interaction
# This allows multiple listeners and callbacks to wait on
# TODO: Wait on the timeout as well
while True:
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
if check is not None:
if not check(result):
continue
return result
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
def wrapper(coro):
async def wrapped_callback(interaction):
if check is not None:
if not check(interaction):
return
try:
await coro(interaction, *pass_args, **pass_kwargs)
except Exception:
# TODO: Log exception
logging.exception(
f"Exception occurred executing FastModal callback '{coro.__name__}'"
)
if once:
self._waiters.remove(wrapped_callback)
self._waiters.append(wrapped_callback)
return wrapper
async def on_submit(self, interaction):
# Transitional patch to re-instantiate the current context
# Not required in py 3.11, instead pass a context parameter to create_task
recover_context(self._context)
old_result = self._result
self._result = asyncio.get_event_loop().create_future()
old_result.set_result(interaction)
for waiter in self._waiters:
asyncio.create_task(waiter(interaction))
async def on_error(self, interaction, error):
# This should never happen, since on_submit has its own error handling
# TODO: Logging
logging.error("Submit error occured in FastModal")