From 7249e25975961ece9bee2dcc10bfa3e5a394b06b Mon Sep 17 00:00:00 2001 From: Conatum Date: Fri, 11 Nov 2022 08:04:23 +0200 Subject: [PATCH] rewrite: Core framework. --- bot/main.py | 43 +++++++--- bot/meta/LionBot.py | 164 ++++++++++++++++++++++++++++++++--- bot/meta/LionCog.py | 5 ++ bot/meta/LionContext.py | 184 ++++++++++++++++++++++++++++++++++++++++ bot/meta/LionTree.py | 79 +++++++++++++++++ bot/meta/config.py | 2 +- bot/meta/context.py | 63 ++++---------- bot/meta/errors.py | 51 +++++++++++ bot/meta/ipc/client.py | 12 +-- bot/meta/ipc/server.py | 94 ++++++++++---------- bot/meta/logger.py | 87 +++++++++++++++---- bot/utils/__init__.py | 0 bot/utils/lib.py | 15 ++-- bot/utils/ui.py | 67 +++++++++++++++ 14 files changed, 715 insertions(+), 151 deletions(-) create mode 100644 bot/meta/LionCog.py create mode 100644 bot/meta/LionContext.py create mode 100644 bot/meta/LionTree.py create mode 100644 bot/meta/errors.py create mode 100644 bot/utils/__init__.py create mode 100644 bot/utils/ui.py diff --git a/bot/main.py b/bot/main.py index d6279577..5282c26e 100644 --- a/bot/main.py +++ b/bot/main.py @@ -5,8 +5,8 @@ import discord from discord.ext import commands from meta import LionBot, conf, sharding, appname, shard_talk -from meta.logger import log_context, log_action -from meta.context import context +from meta.logger import log_context, log_action_stack, logging_context +from meta.context import ctx_bot from data import Database @@ -16,7 +16,7 @@ from constants import DATA_VERSION # Note: This MUST be imported after core, due to table definition orders # from settings import AppSettings -log_context.set(f"APP: {appname}") +# log_context.set(f"APP: {appname}") # client.appdata = core.data.meta.fetch_or_create(appname) @@ -36,7 +36,7 @@ db = Database(conf.data['args']) async def main(): - log_action.set("Initialising") + log_action_stack.set(["Initialising"]) logger.info("Initialising StudyLion") intents = discord.Intents.all() @@ -49,6 +49,7 @@ async def main(): error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate." logger.critical(error) raise RuntimeError(error) + async with LionBot( command_prefix=commands.when_mentioned, intents=intents, @@ -58,15 +59,33 @@ async def main(): initial_extensions=['modules'], web_client=None, app_ipc=shard_talk, - testing_guilds=[889875661848723456], + testing_guilds=[889875661848723456, 879411098384752672], shard_id=sharding.shard_number, shard_count=sharding.shard_count ) as lionbot: - context.get().bot = lionbot - @lionbot.before_invoke - async def before_invoke(ctx): - print(ctx) - log_action.set("Launching") - await lionbot.start(conf.bot['TOKEN']) + ctx_bot.set(lionbot) + try: + with logging_context(context=f"APP: {appname}"): + logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) + await lionbot.start(conf.bot['TOKEN']) + except asyncio.CancelledError: + with logging_context(context=f"APP: {appname}", action="Shutting Down"): + logger.info("StudyLion closed, shutting down.", exc_info=True) -asyncio.run(main()) + +def _main(): + from signal import SIGINT, SIGTERM + + loop = asyncio.get_event_loop() + main_task = asyncio.ensure_future(main()) + for signal in [SIGINT, SIGTERM]: + loop.add_signal_handler(signal, main_task.cancel) + try: + loop.run_until_complete(main_task) + finally: + loop.close() + logging.shutdown() + + +if __name__ == '__main__': + _main() diff --git a/bot/meta/LionBot.py b/bot/meta/LionBot.py index 30d4d3dc..4e8ce455 100644 --- a/bot/meta/LionBot.py +++ b/bot/meta/LionBot.py @@ -1,27 +1,33 @@ -from typing import List, Optional, Dict +from typing import List +import logging +import asyncio import discord -from discord.ext import commands +from discord.utils import MISSING +from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError +from discord.ext.commands.errors import CommandInvokeError, CheckFailure +from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError from aiohttp import ClientSession from data import Database from .config import Conf +from .logger import logging_context, log_context, log_action_stack +from .context import context +from .LionContext import LionContext +from .LionTree import LionTree +from .errors import HandledException, SafeCancellation + +logger = logging.getLogger(__name__) -class LionBot(commands.Bot): +class LionBot(Bot): def __init__( - self, - *args, - appname: str, - db: Database, - config: Conf, - initial_extensions: List[str], - web_client: ClientSession, - app_ipc, - testing_guilds: List[int] = [], - **kwargs, + self, *args, appname: str, db: Database, config: Conf, + initial_extensions: List[str], web_client: ClientSession, app_ipc, + testing_guilds: List[int] = [], **kwargs ): + kwargs.setdefault('tree_cls', LionTree) super().__init__(*args, **kwargs) self.web_client = web_client self.testing_guilds = testing_guilds @@ -33,6 +39,7 @@ class LionBot(commands.Bot): self.app_ipc = app_ipc async def setup_hook(self) -> None: + log_context.set(f"APP: {self.application_id}") await self.app_ipc.connect() for extension in self.initial_extensions: @@ -42,3 +49,134 @@ class LionBot(commands.Bot): guild = discord.Object(guildid) self.tree.copy_global_to(guild=guild) await self.tree.sync(guild=guild) + + async def add_cog(self, cog: Cog, **kwargs): + with logging_context(action=f"Attach {cog.__cog_name__}"): + logger.info(f"Attaching Cog {cog.__cog_name__}") + await super().add_cog(cog, **kwargs) + logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.") + + async def load_extension(self, name, *, package=None, **kwargs): + with logging_context(action=f"Load {name.strip('.')}"): + logger.info(f"Loading extension {name} in package {package}.") + await super().load_extension(name, package=package, **kwargs) + logger.debug(f"Loaded extension {name} in package {package}.") + + async def start(self, token: str, *, reconnect: bool = True): + with logging_context(action="Login"): + await self.login(token) + with logging_context(stack=["Running"]): + await self.connect(reconnect=reconnect) + + async def on_ready(self): + logger.info( + f"Logged in as {self.application.name}\n" + f"Application id {self.application.id}\n" + f"Shard Talk identifier {self.appname}\n" + "------------------------------\n" + f"Enabled Modules: {', '.join(self.extensions.keys())}\n" + f"Loaded Cogs: {', '.join(self.cogs.keys())}\n" + f"Listening for {sum(1 for _ in self.walk_commands())} commands\n" + "------------------------------\n" + f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n" + "Ready to take commands!\n", + extra={'action': 'Ready'} + ) + + async def get_context(self, origin, /, *, cls=MISSING): + if cls is MISSING: + cls = LionContext + ctx = await super().get_context(origin, cls=cls) + context.set(ctx) + return ctx + + async def on_command(self, ctx: LionContext): + logger.info( + f"Executing command '{ctx.command.qualified_name}' (from module '{ctx.cog.__cog_name__}') " + f"with arguments {ctx.args} and kwargs {ctx.kwargs}.", + extra={'with_ctx': True} + ) + + async def on_command_error(self, ctx, exception): + # TODO: Some of these could have more user-feedback + cmd_str = str(ctx.command) + if isinstance(ctx.command, HybridCommand) and ctx.command.app_command: + cmd_str = ctx.command.app_command.to_dict() + try: + raise exception + except (HybridCommandError, CommandInvokeError, appCommandInvokeError): + original = exception.original + try: + raise original + except HandledException: + pass + except SafeCancellation: + if original.msg: + try: + await ctx.error_reply(original.msg) + except Exception: + pass + logger.debug( + f"Caught a safe cancellation: {original.details}", + extra={'action': 'BotError', 'with_ctx': True} + ) + except discord.Forbidden: + # Unknown uncaught Forbidden + try: + # Attempt a general error reply + await ctx.reply("I don't have enough channel or server permissions to complete that command here!") + except Exception: + # We can't send anything at all. Exit quietly, but log. + logger.warning( + f"Caught an unhandled 'Forbidden' while executing: {cmd_str}", + exc_info=True, + extra={'action': 'BotError', 'with_ctx': True} + ) + except discord.HTTPException: + logger.warning( + f"Caught an unhandled 'HTTPException' while executing: {cmd_str}", + exc_info=True, + extra={'action': 'BotError', 'with_ctx': True} + ) + except asyncio.CancelledError: + pass + except asyncio.TimeoutError: + pass + except Exception: + logger.exception( + f"Caught an unknown CommandInvokeError while executing: {cmd_str}", + exc_info=exception, + extra={'action': 'BotError', 'with_ctx': True} + ) + + error_embed = discord.Embed(title="Something went wrong!") + error_embed.description = ( + "An unexpected error occurred while processing your command!\n" + "Our development team has been notified, and the issue should be fixed soon.\n" + "If the error persists, please contact our support team and give them the following number: " + f"`{ctx.interaction.id}`" + ) + + try: + await ctx.error_reply(embed=error_embed) + except Exception: + pass + finally: + exception.original = HandledException(exception.original) + except CheckFailure: + logger.debug( + f"Command failed check: {exception}", + extra={'action': 'BotError', 'with_ctx': True} + ) + try: + await ctx.error_rely(exception.message) + except Exception: + pass + except Exception: + # Completely unknown exception outside of command invocation! + # Something is very wrong here, don't attempt user interaction. + logger.exception( + f"Caught an unknown top-level exception while executing: {cmd_str}", + exc_info=exception, + extra={'action': 'BotError', 'with_ctx': True} + ) diff --git a/bot/meta/LionCog.py b/bot/meta/LionCog.py new file mode 100644 index 00000000..f705ce01 --- /dev/null +++ b/bot/meta/LionCog.py @@ -0,0 +1,5 @@ +from discord.ext.commands import Cog + + +class LionCog(Cog): + ... diff --git a/bot/meta/LionContext.py b/bot/meta/LionContext.py new file mode 100644 index 00000000..61aa5c57 --- /dev/null +++ b/bot/meta/LionContext.py @@ -0,0 +1,184 @@ +import types +import logging +from collections import namedtuple +from typing import Optional + +import discord +from discord.ext.commands import Context + + +logger = logging.getLogger(__name__) + + +""" +Stuff that might be useful to implement (see cmdClient): + sent_messages cache + tasks cache + error reply + usage + interaction cache + View cache? + setting access +""" + + +FlatContext = namedtuple( + 'FlatContext', + ('message', + 'interaction', + 'guild', + 'author', + 'alias', + 'prefix', + 'failed') +) + + +class LionContext(Context): + """ + Represents the context a command is invoked under. + + Extends Context to add Lion-specific methods and attributes. + Also adds several contextual wrapped utilities for simpler user during command invocation. + """ + def __repr__(self): + parts = {} + if self.interaction is not None: + parts['iid'] = self.interaction.id + parts['itype'] = f"\"{self.interaction.type.name}\"" + if self.message is not None: + parts['mid'] = self.message.id + if self.author is not None: + parts['uid'] = self.author.id + parts['uname'] = f"\"{self.author.name}\"" + if self.channel is not None: + parts['cid'] = self.channel.id + parts['cname'] = f"\"{self.channel.name}\"" + if self.guild is not None: + parts['gid'] = self.guild.id + parts['gname'] = f"\"{self.guild.name}\"" + if self.command is not None: + parts['cmd'] = f"\"{self.command.qualified_name}\"" + if self.invoked_with is not None: + parts['alias'] = f"\"{self.invoked_with}\"" + if self.command_failed: + parts['failed'] = self.command_failed + + return "".format( + ' '.join(f"{name}={value}" for name, value in parts.items()) + ) + + def flatten(self): + """Flat pure-data context information, for caching and logging.""" + return FlatContext( + self.message.id, + self.interaction.id if self.interaction is not None else None, + self.guild.id if self.guild is not None else None, + self.author.id if self.author is not None else None, + self.channel.id if self.channel is not None else None, + self.invoked_with, + self.prefix, + self.command_failed + ) + + @classmethod + def util(cls, util_func): + """ + Decorator to make a utility function available as a Context instance method. + """ + setattr(cls, util_func.__name__, util_func) + logger.debug(f"Attached context utility function: {util_func.__name__}") + return util_func + + @classmethod + def wrappable_util(cls, util_func): + """ + Decorator to add a Wrappable utility function as a Context instance method. + """ + wrapped = Wrappable(util_func) + setattr(cls, util_func.__name__, wrapped) + logger.debug(f"Attached wrappable context utility function: {util_func.__name__}") + return wrapped + + async def error_reply(self, content: Optional[str] = None, **kwargs): + if content and 'embed' not in kwargs: + embed = discord.Embed( + colour=discord.Colour.red(), + description=content + ) + kwargs['embed'] = embed + content = None + + # Expect this may be run in highly unusual circumstances. + # This should never error, or at least handle all errors. + try: + await self.reply(content=content, **kwargs) + except discord.HTTPException: + pass + except Exception: + logger.exception( + "Unknown exception in 'error_reply'.", + extra={'action': 'error_reply', 'ctx': self, 'with_ctx': True} + ) + + +class Wrappable: + __slots__ = ('_func', 'wrappers') + + def __init__(self, func): + self._func = func + self.wrappers = None + + @property + def __name__(self): + return self._func.__name__ + + def add_wrapper(self, func, name=None): + self.wrappers = self.wrappers or {} + name = name or func.__name__ + self.wrappers[name] = func + logger.debug( + f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.", + extra={'action': "Wrap Util"} + ) + + def remove_wrapper(self, name): + if not self.wrappers or name not in self.wrappers: + raise ValueError( + f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'" + ) + self.wrappers.pop(name) + logger.debug( + f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.", + extra={'action': "Unwrap Util"} + ) + + def __call__(self, *args, **kwargs): + if self.wrappers: + return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs) + else: + return self._func(*args, **kwargs) + + def _wrapped(self, iter_wraps): + next_wrap = next(iter_wraps, None) + if next_wrap: + def _func(*args, **kwargs): + return next_wrap(self._wrapped(iter_wraps), *args, **kwargs) + else: + _func = self._func + return _func + + def __get__(self, instance, cls=None): + if instance is None: + return self + else: + return types.MethodType(self, instance) + + +LionContext.reply = Wrappable(LionContext.reply) + + +# @LionContext.reply.add_wrapper +# async def think(func, ctx, *args, **kwargs): +# await ctx.channel.send("thinking") +# await func(ctx, *args, **kwargs) diff --git a/bot/meta/LionTree.py b/bot/meta/LionTree.py new file mode 100644 index 00000000..05710972 --- /dev/null +++ b/bot/meta/LionTree.py @@ -0,0 +1,79 @@ +import logging + +from discord import Interaction +from discord.app_commands import CommandTree +from discord.app_commands.errors import AppCommandError, CommandInvokeError +from discord.enums import InteractionType +from discord.app_commands.namespace import Namespace + +from .logger import logging_context +from .errors import SafeCancellation + +logger = logging.getLogger(__name__) + + +class LionTree(CommandTree): + async def on_error(self, interaction, error) -> None: + try: + if isinstance(error, CommandInvokeError): + raise error.original + else: + raise error + except SafeCancellation: + # Assume this has already been handled + pass + except Exception: + logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'}) + + async def _call(self, interaction): + with logging_context(context=f"iid: {interaction.id}"): + if not await self.interaction_check(interaction): + interaction.command_failed = True + return + + data = interaction.data # type: ignore + type = data.get('type', 1) + if type != 1: + # Context menu command... + await self._call_context_menu(interaction, data, type) + return + + command, options = self._get_app_command_options(data) + + # Pre-fill the cached slot to prevent re-computation + interaction._cs_command = command + + # At this point options refers to the arguments of the command + # and command refers to the class type we care about + namespace = Namespace(interaction, data.get('resolved', {}), options) + + # Same pre-fill as above + interaction._cs_namespace = namespace + + # Auto complete handles the namespace differently... so at this point this is where we decide where that is. + if interaction.type is InteractionType.autocomplete: + with logging_context(action=f"Acmp {command.qualified_name}"): + focused = next((opt['name'] for opt in options if opt.get('focused')), None) + if focused is None: + raise AppCommandError( + 'This should not happen, but there is no focused element. This is a Discord bug.' + ) + await command._invoke_autocomplete(interaction, focused, namespace) + return + + with logging_context(action=f"Run {command.qualified_name}"): + logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}") + try: + await command._invoke_with_namespace(interaction, namespace) + except AppCommandError as e: + interaction.command_failed = True + await command._invoke_error_handlers(interaction, e) + await self.on_error(interaction, e) + else: + if not interaction.command_failed: + self.client.dispatch('app_command_completion', interaction, command) + finally: + if interaction.command_failed: + logger.debug("Command completed with errors.") + else: + logger.debug("Command completed without errors.") diff --git a/bot/meta/config.py b/bot/meta/config.py index 61060544..92fef474 100644 --- a/bot/meta/config.py +++ b/bot/meta/config.py @@ -29,7 +29,7 @@ class configEmoji(PartialEmoji): name=name, fallback=PartialEmoji(name=fallback) if fallback is not None else None, animated=bool(animated), - id=int(id) + id=int(id) if id else None ) diff --git a/bot/meta/context.py b/bot/meta/context.py index 4546a125..3841d4d0 100644 --- a/bot/meta/context.py +++ b/bot/meta/context.py @@ -1,51 +1,20 @@ +""" +Namespace for various global context variables. +Allows asyncio callbacks to accurately retrieve information about the current state. +""" + + +from typing import TYPE_CHECKING, Optional + from contextvars import ContextVar - -class Context: - __slots__ = ( - 'bot', - 'interaction', 'message', - 'guild', 'channel', 'author', 'user' - ) - - def __init__(self, **kwargs): - self.bot = kwargs.pop('bot', None) - - self.interaction = interaction = kwargs.pop('interaction', None) - self.message = message = kwargs.pop('message', interaction.message if interaction is not None else None) - - guild = kwargs.pop('guild', None) - channel = kwargs.pop('channel', None) - author = kwargs.pop('author', None) - - if message is not None: - guild = guild or message.guild - channel = channel or message.channel - author = author or message.author - elif interaction is not None: - guild = guild or interaction.guild - channel = channel or interaction.channel - author = author or interaction.user - - self.guild = guild - self.channel = channel - self.author = self.user = author - - def log_string(self): - """Markdown formatted summary for live logging.""" - parts = [] - if self.interaction is not None: - parts.append(f"") - if self.message is not None: - parts.append(f"") - if self.author is not None: - parts.append(f"") - if self.channel is not None: - parts.append(f"") - if self.guild is not None: - parts.append(f"") - - return " ".join(parts) +if TYPE_CHECKING: + from .LionBot import LionBot + from .LionContext import LionContext -context = ContextVar('context', default=Context()) +# Contains the current command context, if applicable +context: Optional['LionContext'] = ContextVar('context', default=None) + +# Contains the current LionBot instance +ctx_bot: Optional['LionBot'] = ContextVar('bot', default=None) diff --git a/bot/meta/errors.py b/bot/meta/errors.py new file mode 100644 index 00000000..d4be3e60 --- /dev/null +++ b/bot/meta/errors.py @@ -0,0 +1,51 @@ +from typing import Optional + + +class SafeCancellation(Exception): + """ + Raised to safely cancel execution of the current operation. + + If not caught, is expected to be propagated to the Tree and safely ignored there. + If a `msg` is provided, a context-aware error handler should catch and send the message to the user. + The error handler should then set the `msg` to None, to avoid double handling. + Debugging information should go in `details`, to be logged by a top-level error handler. + """ + default_message = "" + + def __init__(self, msg: Optional[str] = None, details: Optional[str] = None, **kwargs): + self.msg: Optional[str] = msg if msg is not None else self.default_message + self.details: str = details if details is not None else self.msg + super().__init__(**kwargs) + + +class UserInputError(SafeCancellation): + """ + A SafeCancellation induced from unparseable user input. + """ + default_message = "Could not understand your input." + + +class UserCancelled(SafeCancellation): + """ + A SafeCancellation induced from manual user cancellation. + + Usually silent. + """ + default_msg = None + + +class ResponseTimedOut(SafeCancellation): + """ + A SafeCancellation induced from a user interaction time-out. + """ + default_msg = "Session timed out waiting for input." + + +class HandledException(SafeCancellation): + """ + Sentinel class to indicate to error handlers that this exception has been handled. + Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler. + """ + def __init__(self, exc=None, **kwargs): + self.exc = exc + super().__init__(**kwargs) diff --git a/bot/meta/ipc/client.py b/bot/meta/ipc/client.py index d9127400..499b9406 100644 --- a/bot/meta/ipc/client.py +++ b/bot/meta/ipc/client.py @@ -1,17 +1,19 @@ -from typing import Optional +from typing import Optional, TypeAlias, Any import asyncio import logging import pickle -logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -class AppClient: - routes = {} # route_name -> Callable[Any, Awaitable[Any]] +Address: TypeAlias = dict[str, Any] - def __init__(self, appid, client_address, server_address): + +class AppClient: + routes: dict[str, 'AppRoute'] = {} # route_name -> Callable[Any, Awaitable[Any]] + + def __init__(self, appid: str, client_address: Address, server_address: Address): self.appid = appid self.address = client_address self.server_address = server_address diff --git a/bot/meta/ipc/server.py b/bot/meta/ipc/server.py index 6d9545f4..3e46e900 100644 --- a/bot/meta/ipc/server.py +++ b/bot/meta/ipc/server.py @@ -4,7 +4,7 @@ import logging import string import random -from ..logger import log_action, log_context, log_app +from ..logger import log_context, log_app, logging_context logger = logging.getLogger(__name__) @@ -71,46 +71,45 @@ class AppServer: """ Register and hold a new client connection. """ - log_action.set("CONN " + appid) - reader, writer = connection - # Add the new client - self.clients[appid] = (address, connection) + with logging_context(action=f"CONN {appid}"): + reader, writer = connection + # Add the new client + self.clients[appid] = (address, connection) - # Send the new client a client list - peers = self.peer_list() - writer.write(pickle.dumps(peers)) - writer.write(b'\n') - await writer.drain() + # Send the new client a client list + peers = self.peer_list() + writer.write(pickle.dumps(peers)) + writer.write(b'\n') + await writer.drain() - # Announce the new client to everyone - await self.broadcast('new_peer', (), {'appid': appid, 'address': address}) + # Announce the new client to everyone + await self.broadcast('new_peer', (), {'appid': appid, 'address': address}) - # Keep the connection open until socket closed or EOF (indicating client death) - try: - await reader.read() - finally: - # Connection ended or it broke - logger.info(f"Lost client '{appid}'") - await self.deregister_client(appid) + # Keep the connection open until socket closed or EOF (indicating client death) + try: + await reader.read() + finally: + # Connection ended or it broke + logger.info(f"Lost client '{appid}'") + await self.deregister_client(appid) async def handle_connection(self, reader, writer): data = await reader.readline() route, args, kwargs = pickle.loads(data) rqid = short_uuid() - log_context.set("RQID:" + rqid) - log_action.set("SERV ROUTE " + route) - logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}") + with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"): + logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}") - if route in self.routes: - # Execute route - try: - await self.routes[route]((reader, writer), *args, **kwargs) - except Exception: - logger.exception(f"AppServer recieved exception during route '{route}'") - else: - logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.") + if route in self.routes: + # Execute route + try: + await self.routes[route]((reader, writer), *args, **kwargs) + except Exception: + logger.exception(f"AppServer recieved exception during route '{route}'") + else: + logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.") def peer_list(self): return {appid: address for appid, (address, _) in self.clients.items()} @@ -120,24 +119,26 @@ class AppServer: await self.broadcast('drop_peer', (), {'appid': appid}) async def broadcast(self, route, args, kwargs): - logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.") - payload = pickle.dumps((route, args, kwargs)) - if self.clients: - await asyncio.gather( - *(self._send(appid, payload) for appid in self.clients), - return_exceptions=True - ) + with logging_context(action="broadcast"): + logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.") + payload = pickle.dumps((route, args, kwargs)) + if self.clients: + await asyncio.gather( + *(self._send(appid, payload) for appid in self.clients), + return_exceptions=True + ) async def message_client(self, appid, route, args, kwargs): """ Send a message to client `appid` along `route` with given arguments. """ - logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.") - if appid not in self.clients: - raise ValueError(f"Client '{appid}' is not connected.") + with logging_context(action=f"MSG {appid}"): + logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.") + if appid not in self.clients: + raise ValueError(f"Client '{appid}' is not connected.") - payload = pickle.dumps((route, args, kwargs)) - return await self._send(appid, payload) + payload = pickle.dumps((route, args, kwargs)) + return await self._send(appid, payload) async def _send(self, appid, payload): """ @@ -157,10 +158,11 @@ class AppServer: async def start(self, address): log_app.set("APPSERVER") - server = await asyncio.start_server(self.handle_connection, **address) - logger.info(f"Serving on {address}") - async with server: - await server.serve_forever() + with logging_context(stack=["SERV"]): + server = await asyncio.start_server(self.handle_connection, **address) + logger.info(f"Serving on {address}") + async with server: + await server.serve_forever() async def start_server(): diff --git a/bot/meta/logger.py b/bot/meta/logger.py index e9dc386c..0cdd7f7d 100644 --- a/bot/meta/logger.py +++ b/bot/meta/logger.py @@ -1,6 +1,7 @@ import sys import logging import asyncio +from typing import List from logging.handlers import QueueListener, QueueHandler from queue import SimpleQueue from contextlib import contextmanager @@ -16,24 +17,33 @@ from .context import context from utils.lib import utc_now +log_logger = logging.getLogger(__name__) +log_logger.propagate = False + + log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') -log_action: ContextVar[str] = ContextVar('logging_action', default='UNKNOWN ACTION') +log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[]) log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number)) @contextmanager -def logging_context(context=None, action=None): +def logging_context(context=None, action=None, stack=None): if context is not None: context_t = log_context.set(context) if action is not None: - action_t = log_action.set(action) + astack = log_action_stack.get() + log_action_stack.set(astack + [action]) + if stack is not None: + actions_t = log_action_stack.set(stack) try: yield finally: if context is not None: log_context.reset(context_t) + if stack is not None: + log_action_stack.reset(actions_t) if action is not None: - log_action.reset(action_t) + log_action_stack.set(astack) RESET_SEQ = "\033[0m" @@ -63,10 +73,10 @@ def colour_escape(fmt: str) -> str: log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' + '[%(cyan)%(app)-15s%(reset)]' + - '[%(cyan)%(context)-22s%(reset)]' + - '[%(cyan)%(action)-22s%(reset)]' + + '[%(cyan)%(context)-24s%(reset)]' + + '[%(cyan)%(actionstr)-22s%(reset)]' + ' %(bold)%(cyan)%(name)s:%(reset)' + - ' %(white)%(message)s%(reset)') + ' %(white)%(message)s%(ctxstr)s%(reset)') log_format = colour_escape(log_format) @@ -74,7 +84,7 @@ log_format = colour_escape(log_format) logger = logging.getLogger() log_fmt = logging.Formatter( fmt=log_format, - datefmt='%Y-%m-%d %H:%M:%S' + # datefmt='%Y-%m-%d %H:%M:%S' ) logger.setLevel(logging.NOTSET) @@ -89,14 +99,45 @@ class LessThanFilter(logging.Filter): return 1 if record.levelno < self.max_level else 0 +class ThreadFilter(logging.Filter): + def __init__(self, thread_name): + super().__init__("") + self.thread = thread_name + + def filter(self, record): + # non-zero return means we log this message + return 1 if record.threadName == self.thread else 0 + + class ContextInjection(logging.Filter): def filter(self, record): + # These guards are to allow override through _extra + # And to ensure the injection is idempotent if not hasattr(record, 'context'): record.context = log_context.get() - if not hasattr(record, 'action'): - record.action = log_action.get() - record.app = log_app.get() - record.ctx = context.get().log_string() + + if not hasattr(record, 'actionstr'): + action_stack = log_action_stack.get() + if hasattr(record, 'action'): + action_stack = (*action_stack, record.action) + if action_stack: + record.actionstr = ' ➔ '.join(action_stack) + else: + record.actionstr = "Unknown Action" + + if not hasattr(record, 'app'): + record.app = log_app.get() + + if not hasattr(record, 'ctx'): + if ctx := context.get(): + record.ctx = repr(ctx) + else: + record.ctx = None + + if getattr(record, 'with_ctx', False) and record.ctx: + record.ctxstr = '\n' + record.ctx + else: + record.ctxstr = "" return True @@ -106,12 +147,14 @@ logging_handler_out.setFormatter(log_fmt) logging_handler_out.addFilter(LessThanFilter(logging.WARNING)) logging_handler_out.addFilter(ContextInjection()) logger.addHandler(logging_handler_out) +log_logger.addHandler(logging_handler_out) logging_handler_err = logging.StreamHandler(sys.stderr) logging_handler_err.setLevel(logging.WARNING) logging_handler_err.setFormatter(log_fmt) logging_handler_err.addFilter(ContextInjection()) logger.addHandler(logging_handler_err) +log_logger.addHandler(logging_handler_err) class LocalQueueHandler(QueueHandler): @@ -127,7 +170,7 @@ class LocalQueueHandler(QueueHandler): class WebHookHandler(logging.StreamHandler): def __init__(self, webhook_url, batch=False, loop=None): - super().__init__(self) + super().__init__() self.webhook_url = webhook_url self.batched = "" self.batch = batch @@ -150,9 +193,13 @@ class WebHookHandler(logging.StreamHandler): asyncio.create_task(self.post(record)) async def post(self, record): + log_context.set("Webhook Logger") + log_action_stack.set(["Logging"]) + log_app.set(record.app) + try: timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S") - header = f"[{timestamp}][{record.levelname}][{record.app}][{record.action}][{record.context}]" + header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>" context = f"\n# Context: {record.ctx}" if record.ctx else "" message = f"{header}\n{record.msg}{context}" @@ -164,12 +211,12 @@ class WebHookHandler(logging.StreamHandler): # Post the log message(s) if self.batch: - if len(message) > 1000: + if len(message) > 1500: await self._send_batched_now() await self._send(message, as_file=as_file) else: self.batched += message - if len(self.batched) + len(message) > 1000: + if len(self.batched) + len(message) > 1500: await self._send_batched_now() else: asyncio.create_task(self._schedule_batched()) @@ -209,9 +256,12 @@ class WebHookHandler(logging.StreamHandler): if as_file or len(message) > 2000: with StringIO(message) as fp: fp.seek(0) - await webhook.send(file=File(fp, filename="logs.md")) + await webhook.send( + file=File(fp, filename="logs.md"), + username=log_app.get() + ) else: - await webhook.send(message) + await webhook.send(message, username=log_app.get()) handlers = [] @@ -254,6 +304,7 @@ if handlers: qhandler = QueueHandler(queue) qhandler.setLevel(logging.INFO) qhandler.addFilter(ContextInjection()) + # qhandler.addFilter(ThreadFilter('MainThread')) logger.addHandler(qhandler) listener = QueueListener( diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bot/utils/lib.py b/bot/utils/lib.py index 5fe3d108..92a832f6 100644 --- a/bot/utils/lib.py +++ b/bot/utils/lib.py @@ -1,6 +1,7 @@ import datetime import iso8601 # type: ignore import re +from contextvars import Context import discord @@ -439,15 +440,6 @@ def jumpto(guildid: int, channeldid: int, messageid: int) -> str: ) -class DotDict(dict): - """ - Dict-type allowing dot access to keys. - """ - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - def utc_now() -> datetime.datetime: """ Return the current timezone-aware utc timestamp. @@ -464,3 +456,8 @@ def multiple_replace(string: str, rep_dict: dict[str, str]) -> str: return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string) else: return string + + +def recover_context(context: Context): + for var in context: + var.set(context[var]) diff --git a/bot/utils/ui.py b/bot/utils/ui.py new file mode 100644 index 00000000..c09fee41 --- /dev/null +++ b/bot/utils/ui.py @@ -0,0 +1,67 @@ +from typing import List, Coroutine +import asyncio +import logging +from contextvars import copy_context + +import discord +from discord.ui import Modal + +from .lib import recover_context + + +class FastModal(Modal): + def __init__(self, *items, **kwargs): + super().__init__(**kwargs) + for item in items: + self.add_item(item) + self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future() + self._waiters: List[Coroutine[discord.Interaction]] = [] + self._context = copy_context() + + async def wait_for(self, check=None, timeout=None): + # Wait for _result or timeout + # If we timeout, or the view times out, raise TimeoutError + # Otherwise, return the Interaction + # This allows multiple listeners and callbacks to wait on + # TODO: Wait on the timeout as well + while True: + result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout) + if check is not None: + if not check(result): + continue + return result + + def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}): + def wrapper(coro): + async def wrapped_callback(interaction): + if check is not None: + if not check(interaction): + return + try: + await coro(interaction, *pass_args, **pass_kwargs) + except Exception: + # TODO: Log exception + logging.exception( + f"Exception occurred executing FastModal callback '{coro.__name__}'" + ) + if once: + self._waiters.remove(wrapped_callback) + self._waiters.append(wrapped_callback) + return wrapper + + async def on_submit(self, interaction): + # Transitional patch to re-instantiate the current context + # Not required in py 3.11, instead pass a context parameter to create_task + recover_context(self._context) + + old_result = self._result + self._result = asyncio.get_event_loop().create_future() + old_result.set_result(interaction) + + for waiter in self._waiters: + asyncio.create_task(waiter(interaction)) + + async def on_error(self, interaction, error): + # This should never happen, since on_submit has its own error handling + # TODO: Logging + logging.error("Submit error occured in FastModal")