From f2fb3ce344a21ed0e24f295e95eba2207aa2d2a4 Mon Sep 17 00:00:00 2001 From: Conatum Date: Tue, 22 Aug 2023 22:14:58 +0300 Subject: [PATCH] fix: Massively improve logging context isolation. --- src/analytics/cog.py | 2 +- src/analytics/events.py | 62 ++++++++--------- src/analytics/server.py | 6 +- src/bot.py | 14 ++-- src/meta/LionBot.py | 52 +++++++++----- src/meta/LionTree.py | 105 ++++++++++++++++------------- src/meta/ipc/client.py | 94 +++++++++++++------------- src/meta/ipc/server.py | 101 +++++++++++++-------------- src/meta/logger.py | 105 +++++++++++++++++++++++++---- src/modules/pomodoro/timer.py | 2 +- src/modules/reminders/cog.py | 104 ++++++++++++++-------------- src/modules/sysadmin/blacklists.py | 22 +++--- src/modules/sysadmin/exec_cog.py | 90 ++++++++++++------------- 13 files changed, 440 insertions(+), 319 deletions(-) diff --git a/src/analytics/cog.py b/src/analytics/cog.py index ae4a00fa..a0069a43 100644 --- a/src/analytics/cog.py +++ b/src/analytics/cog.py @@ -110,7 +110,7 @@ class Analytics(LionCog): duration = utc_now() - ctx.message.created_at event = CommandEvent( appname=appname, - cmdname=ctx.command.name if ctx.command else 'Unknown', + cmdname=ctx.command.name if ctx.command else 'Unknown', # TODO: qualified_name cogname=ctx.cog.qualified_name if ctx.cog else None, userid=ctx.author.id, created_at=utc_now(), diff --git a/src/analytics/events.py b/src/analytics/events.py index c1d0b4ca..fcae95b2 100644 --- a/src/analytics/events.py +++ b/src/analytics/events.py @@ -5,7 +5,7 @@ from collections import namedtuple from typing import NamedTuple, Optional, Generic, Type, TypeVar from meta.ipc import AppRoute, AppClient -from meta.logger import logging_context, log_wrap +from meta.logger import logging_context, log_wrap, set_logging_context from data import RowModel from .data import AnalyticsData, CommandStatus, VoiceAction, GuildAction @@ -52,39 +52,39 @@ class EventHandler(Generic[T]): f"Queue on event handler {self.route_name} is full! Discarding event {data}" ) + @log_wrap(action='consumer', isolate=False) async def consumer(self): - with logging_context(action='consumer'): - while True: - try: - item = await self.queue.get() - self.batch.append(item) - if len(self.batch) > self.batch_size: - await self.process_batch() - except asyncio.CancelledError: - # Try and process the last batch - logger.info( - f"Event handler {self.route_name} received cancellation signal! " - "Trying to process last batch." - ) - if self.batch: - await self.process_batch() - raise - except Exception: - logger.exception( - f"Event handler {self.route_name} received unhandled error." - " Ignoring and continuing cautiously." - ) - pass + while True: + try: + item = await self.queue.get() + self.batch.append(item) + if len(self.batch) > self.batch_size: + await self.process_batch() + except asyncio.CancelledError: + # Try and process the last batch + logger.info( + f"Event handler {self.route_name} received cancellation signal! " + "Trying to process last batch." + ) + if self.batch: + await self.process_batch() + raise + except Exception: + logger.exception( + f"Event handler {self.route_name} received unhandled error." + " Ignoring and continuing cautiously." + ) + pass + @log_wrap(action='batch', isolate=False) async def process_batch(self): - with logging_context(action='batch'): - logger.debug("Processing Batch") - # TODO: copy syntax might be more efficient here - await self.model.table.insert_many( - self.struct._fields, - *map(tuple, self.batch) - ) - self.batch.clear() + logger.debug("Processing Batch") + # TODO: copy syntax might be more efficient here + await self.model.table.insert_many( + self.struct._fields, + *map(tuple, self.batch) + ) + self.batch.clear() def bind(self, client: AppClient): """ diff --git a/src/analytics/server.py b/src/analytics/server.py index 619cb60f..94d4e3c6 100644 --- a/src/analytics/server.py +++ b/src/analytics/server.py @@ -67,7 +67,11 @@ class AnalyticsServer: results = await self.talk_shard_snapshot().broadcast() # Make sure everyone sent results and there were no exceptions (e.g. concurrency) - if not all(result is not None and not isinstance(result, Exception) for result in results.values()): + failed = not isinstance(results, dict) + failed = failed or not any( + result is None or isinstance(result, Exception) for result in results.values() + ) + if failed: # This should essentially never happen # Either some of the shards could not make a snapshot (e.g. Discord client issues) # or they disconnected in the process. diff --git a/src/bot.py b/src/bot.py index c66d1ce9..2761af82 100644 --- a/src/bot.py +++ b/src/bot.py @@ -7,7 +7,7 @@ from discord.ext import commands from meta import LionBot, conf, sharding, appname, shard_talk from meta.app import shardname -from meta.logger import log_context, log_action_stack, logging_context, setup_main_logger +from meta.logger import log_context, log_action_stack, setup_main_logger from meta.context import ctx_bot from data import Database @@ -30,7 +30,7 @@ db = Database(conf.data['args']) async def main(): - log_action_stack.set(["Initialising"]) + log_action_stack.set(("Initialising",)) logger.info("Initialising StudyLion") intents = discord.Intents.all() @@ -73,12 +73,12 @@ async def main(): ) as lionbot: ctx_bot.set(lionbot) try: - with logging_context(context=f"APP: {appname}"): - logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) - await lionbot.start(conf.bot['TOKEN']) + log_context.set(f"APP: {appname}") + logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) + await lionbot.start(conf.bot['TOKEN']) except asyncio.CancelledError: - with logging_context(context=f"APP: {appname}", action="Shutting Down"): - logger.info("StudyLion closed, shutting down.", exc_info=True) + log_context.set(f"APP: {appname}") + logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True) def _main(): diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index 57363b1d..cd73bcff 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -13,7 +13,7 @@ from aiohttp import ClientSession from data import Database from .config import Conf -from .logger import logging_context, log_context, log_action_stack +from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context from .context import context from .LionContext import LionContext from .LionTree import LionTree @@ -46,6 +46,7 @@ class LionBot(Bot): self.translator = translator self._locks = WeakValueDictionary() + self._running_events = set() async def setup_hook(self) -> None: log_context.set(f"APP: {self.application_id}") @@ -64,27 +65,45 @@ class LionBot(Bot): await self.tree.sync(guild=guild) async def add_cog(self, cog: Cog, **kwargs): - with logging_context(action=f"Attach {cog.__cog_name__}"): + sup = super() + @log_wrap(action=f"Attach {cog.__cog_name__}") + async def wrapper(): logger.info(f"Attaching Cog {cog.__cog_name__}") - await super().add_cog(cog, **kwargs) + await sup.add_cog(cog, **kwargs) logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.") + await wrapper() async def load_extension(self, name, *, package=None, **kwargs): - with logging_context(action=f"Load {name.strip('.')}"): + sup = super() + @log_wrap(action=f"Load {name.strip('.')}") + async def wrapper(): logger.info(f"Loading extension {name} in package {package}.") - await super().load_extension(name, package=package, **kwargs) + await sup.load_extension(name, package=package, **kwargs) logger.debug(f"Loaded extension {name} in package {package}.") + await wrapper() async def start(self, token: str, *, reconnect: bool = True): with logging_context(action="Login"): - await self.login(token) - with logging_context(stack=["Running"]): - await self.connect(reconnect=reconnect) + start_task = asyncio.create_task(self.login(token)) + await start_task + + with logging_context(stack=("Running",)): + run_task = asyncio.create_task(self.connect(reconnect=reconnect)) + await run_task def dispatch(self, event_name: str, *args, **kwargs): with logging_context(action=f"Dispatch {event_name}"): super().dispatch(event_name, *args, **kwargs) + def _schedule_event(self, coro, event_name, *args, **kwargs): + """ + Extends client._schedule_event to keep a persistent + background task store. + """ + task = super()._schedule_event(coro, event_name, *args, **kwargs) + self._running_events.add(task) + task.add_done_callback(lambda fut: self._running_events.discard(fut)) + def idlock(self, snowflakeid): lock = self._locks.get(snowflakeid, None) if lock is None: @@ -124,9 +143,11 @@ class LionBot(Bot): async def on_command_error(self, ctx, exception): # TODO: Some of these could have more user-feedback - cmd_str = str(ctx.command) + logger.debug(f"Handling command error for {ctx}: {exception}") if isinstance(ctx.command, HybridCommand) and ctx.command.app_command: cmd_str = ctx.command.app_command.to_dict() + else: + cmd_str = str(ctx.command) try: raise exception except (HybridCommandError, CommandInvokeError, appCommandInvokeError): @@ -191,14 +212,14 @@ class LionBot(Bot): pass finally: exception.original = HandledException(exception.original) - except CheckFailure: + except CheckFailure as e: logger.debug( - f"Command failed check: {exception}", + f"Command failed check: {e}", extra={'action': 'BotError', 'with_ctx': True} ) try: - await ctx.error_reply(exception.message) - except Exception: + await ctx.error_reply(str(e)) + except discord.HTTPException: pass except Exception: # Completely unknown exception outside of command invocation! @@ -209,6 +230,5 @@ class LionBot(Bot): ) def add_command(self, command): - if hasattr(command, '_placeholder_group_'): - return - super().add_command(command) + if not hasattr(command, '_placeholder_group_'): + super().add_command(command) diff --git a/src/meta/LionTree.py b/src/meta/LionTree.py index 05710972..9a2dd7ec 100644 --- a/src/meta/LionTree.py +++ b/src/meta/LionTree.py @@ -6,13 +6,17 @@ from discord.app_commands.errors import AppCommandError, CommandInvokeError from discord.enums import InteractionType from discord.app_commands.namespace import Namespace -from .logger import logging_context +from .logger import logging_context, set_logging_context, log_wrap from .errors import SafeCancellation logger = logging.getLogger(__name__) class LionTree(CommandTree): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._call_tasks = set() + async def on_error(self, interaction, error) -> None: try: if isinstance(error, CommandInvokeError): @@ -25,55 +29,66 @@ class LionTree(CommandTree): except Exception: logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'}) + def _from_interaction(self, interaction: Interaction) -> None: + @log_wrap(context=f"iid: {interaction.id}", isolate=False) + async def wrapper(): + try: + await self._call(interaction) + except AppCommandError as e: + await self._dispatch_error(interaction, e) + + task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker') + self._call_tasks.add(task) + task.add_done_callback(lambda fut: self._call_tasks.discard(fut)) + async def _call(self, interaction): - with logging_context(context=f"iid: {interaction.id}"): - if not await self.interaction_check(interaction): - interaction.command_failed = True - return + if not await self.interaction_check(interaction): + interaction.command_failed = True + return - data = interaction.data # type: ignore - type = data.get('type', 1) - if type != 1: - # Context menu command... - await self._call_context_menu(interaction, data, type) - return + data = interaction.data # type: ignore + type = data.get('type', 1) + if type != 1: + # Context menu command... + await self._call_context_menu(interaction, data, type) + return - command, options = self._get_app_command_options(data) + command, options = self._get_app_command_options(data) - # Pre-fill the cached slot to prevent re-computation - interaction._cs_command = command + # Pre-fill the cached slot to prevent re-computation + interaction._cs_command = command - # At this point options refers to the arguments of the command - # and command refers to the class type we care about - namespace = Namespace(interaction, data.get('resolved', {}), options) + # At this point options refers to the arguments of the command + # and command refers to the class type we care about + namespace = Namespace(interaction, data.get('resolved', {}), options) - # Same pre-fill as above - interaction._cs_namespace = namespace + # Same pre-fill as above + interaction._cs_namespace = namespace - # Auto complete handles the namespace differently... so at this point this is where we decide where that is. - if interaction.type is InteractionType.autocomplete: - with logging_context(action=f"Acmp {command.qualified_name}"): - focused = next((opt['name'] for opt in options if opt.get('focused')), None) - if focused is None: - raise AppCommandError( - 'This should not happen, but there is no focused element. This is a Discord bug.' - ) - await command._invoke_autocomplete(interaction, focused, namespace) - return + # Auto complete handles the namespace differently... so at this point this is where we decide where that is. + if interaction.type is InteractionType.autocomplete: + set_logging_context(action=f"Acmp {command.qualified_name}") + focused = next((opt['name'] for opt in options if opt.get('focused')), None) + if focused is None: + raise AppCommandError( + 'This should not happen, but there is no focused element. This is a Discord bug.' + ) + await command._invoke_autocomplete(interaction, focused, namespace) + return - with logging_context(action=f"Run {command.qualified_name}"): - logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}") - try: - await command._invoke_with_namespace(interaction, namespace) - except AppCommandError as e: - interaction.command_failed = True - await command._invoke_error_handlers(interaction, e) - await self.on_error(interaction, e) - else: - if not interaction.command_failed: - self.client.dispatch('app_command_completion', interaction, command) - finally: - if interaction.command_failed: - logger.debug("Command completed with errors.") - else: - logger.debug("Command completed without errors.") + set_logging_context(action=f"Run {command.qualified_name}") + logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}") + try: + await command._invoke_with_namespace(interaction, namespace) + except AppCommandError as e: + interaction.command_failed = True + await command._invoke_error_handlers(interaction, e) + await self.on_error(interaction, e) + else: + if not interaction.command_failed: + self.client.dispatch('app_command_completion', interaction, command) + finally: + if interaction.command_failed: + logger.debug("Command completed with errors.") + else: + logger.debug("Command completed without errors.") diff --git a/src/meta/ipc/client.py b/src/meta/ipc/client.py index 82fca1ad..be0d43ca 100644 --- a/src/meta/ipc/client.py +++ b/src/meta/ipc/client.py @@ -3,7 +3,7 @@ import asyncio import logging import pickle -from ..logger import logging_context +from ..logger import logging_context, log_wrap, set_logging_context logger = logging.getLogger(__name__) @@ -99,68 +99,70 @@ class AppClient: # TODO ... + @log_wrap(action="Req") async def request(self, appid, payload: 'AppPayload', wait_for_reply=True): - with logging_context(action=f"Req {appid}"): - try: - if appid not in self.peers: - raise ValueError(f"Peer '{appid}' not found.") - logger.debug(f"Sending request to app '{appid}' with payload {payload}") + set_logging_context(action=appid) + try: + if appid not in self.peers: + raise ValueError(f"Peer '{appid}' not found.") + logger.debug(f"Sending request to app '{appid}' with payload {payload}") - address = self.peers[appid] - reader, writer = await asyncio.open_connection(**address) + address = self.peers[appid] + reader, writer = await asyncio.open_connection(**address) - writer.write(payload.encoded()) - await writer.drain() - writer.write_eof() - if wait_for_reply: - result = await reader.read() - writer.close() - decoded = payload.route.decode(result) - return decoded - else: - return None - except Exception: - logging.exception(f"Failed to send request to {appid}'") + writer.write(payload.encoded()) + await writer.drain() + writer.write_eof() + if wait_for_reply: + result = await reader.read() + writer.close() + decoded = payload.route.decode(result) + return decoded + else: return None + except Exception: + logging.exception(f"Failed to send request to {appid}'") + return None + @log_wrap(action="Broadcast") async def requestall(self, payload, except_self=True, only_my_peers=True): - with logging_context(action="Broadcast"): - peerlist = self.my_peers if only_my_peers else self.peers - results = await asyncio.gather( - *(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)), - return_exceptions=True - ) - return dict(zip(self.peers.keys(), results)) + peerlist = self.my_peers if only_my_peers else self.peers + results = await asyncio.gather( + *(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)), + return_exceptions=True + ) + return dict(zip(self.peers.keys(), results)) async def handle_request(self, reader, writer): - with logging_context(action="SERV"): - data = await reader.read() - loaded = pickle.loads(data) - route, args, kwargs = loaded + set_logging_context(action="SERV") + data = await reader.read() + loaded = pickle.loads(data) + route, args, kwargs = loaded - with logging_context(action=f"SERV {route}"): - logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}") + set_logging_context(action=route) - if route in self.routes: - try: - await self.routes[route].run((reader, writer), args, kwargs) - except Exception: - logger.exception(f"Fatal exception during route '{route}'. This should never happen!") - else: - logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.") - writer.write_eof() + logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}") + if route in self.routes: + try: + await self.routes[route].run((reader, writer), args, kwargs) + except Exception: + logger.exception(f"Fatal exception during route '{route}'. This should never happen!") + else: + logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.") + writer.write_eof() + + @log_wrap(stack=("ShardTalk",)) async def connect(self): """ Start the local peer server. Connect to the address server. """ - with logging_context(stack=['ShardTalk']): - # Start the client server - self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True) + # Start the client server + self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True) - logger.info(f"Serving on {self.address}") - await self.server_connection() + logger.info(f"Serving on {self.address}") + await self.server_connection() class AppPayload: diff --git a/src/meta/ipc/server.py b/src/meta/ipc/server.py index 43181d8c..0ebb8627 100644 --- a/src/meta/ipc/server.py +++ b/src/meta/ipc/server.py @@ -4,7 +4,7 @@ import logging import string import random -from ..logger import log_context, log_app, logging_context, setup_main_logger +from ..logger import log_context, log_app, setup_main_logger, set_logging_context, log_wrap from ..config import conf logger = logging.getLogger(__name__) @@ -75,45 +75,45 @@ class AppServer: """ Register and hold a new client connection. """ - with logging_context(action=f"CONN {appid}"): - reader, writer = connection - # Add the new client - self.clients[appid] = (address, connection) + set_logging_context(action=f"CONN {appid}") + reader, writer = connection + # Add the new client + self.clients[appid] = (address, connection) - # Send the new client a client list - peers = self.peer_list() - writer.write(pickle.dumps(peers)) - writer.write(b'\n') - await writer.drain() + # Send the new client a client list + peers = self.peer_list() + writer.write(pickle.dumps(peers)) + writer.write(b'\n') + await writer.drain() - # Announce the new client to everyone - await self.broadcast('new_peer', (), {'appid': appid, 'address': address}) + # Announce the new client to everyone + await self.broadcast('new_peer', (), {'appid': appid, 'address': address}) - # Keep the connection open until socket closed or EOF (indicating client death) - try: - await reader.read() - finally: - # Connection ended or it broke - logger.info(f"Lost client '{appid}'") - await self.deregister_client(appid) + # Keep the connection open until socket closed or EOF (indicating client death) + try: + await reader.read() + finally: + # Connection ended or it broke + logger.info(f"Lost client '{appid}'") + await self.deregister_client(appid) async def handle_connection(self, reader, writer): data = await reader.readline() route, args, kwargs = pickle.loads(data) rqid = short_uuid() + + set_logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}") + logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}") - with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"): - logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}") - - if route in self.routes: - # Execute route - try: - await self.routes[route]((reader, writer), *args, **kwargs) - except Exception: - logger.exception(f"AppServer recieved exception during route '{route}'") - else: - logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.") + if route in self.routes: + # Execute route + try: + await self.routes[route]((reader, writer), *args, **kwargs) + except Exception: + logger.exception(f"AppServer recieved exception during route '{route}'") + else: + logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.") def peer_list(self): return {appid: address for appid, (address, _) in self.clients.items()} @@ -122,27 +122,27 @@ class AppServer: self.clients.pop(appid, None) await self.broadcast('drop_peer', (), {'appid': appid}) + @log_wrap(action="broadcast") async def broadcast(self, route, args, kwargs): - with logging_context(action="broadcast"): - logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.") - payload = pickle.dumps((route, args, kwargs)) - if self.clients: - await asyncio.gather( - *(self._send(appid, payload) for appid in self.clients), - return_exceptions=True - ) + logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.") + payload = pickle.dumps((route, args, kwargs)) + if self.clients: + await asyncio.gather( + *(self._send(appid, payload) for appid in self.clients), + return_exceptions=True + ) async def message_client(self, appid, route, args, kwargs): """ Send a message to client `appid` along `route` with given arguments. """ - with logging_context(action=f"MSG {appid}"): - logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.") - if appid not in self.clients: - raise ValueError(f"Client '{appid}' is not connected.") + set_logging_context(action=f"MSG {appid}") + logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.") + if appid not in self.clients: + raise ValueError(f"Client '{appid}' is not connected.") - payload = pickle.dumps((route, args, kwargs)) - return await self._send(appid, payload) + payload = pickle.dumps((route, args, kwargs)) + return await self._send(appid, payload) async def _send(self, appid, payload): """ @@ -162,18 +162,19 @@ class AppServer: async def start(self, address): log_app.set("APPSERVER") - with logging_context(stack=["SERV"]): - server = await asyncio.start_server(self.handle_connection, **address) - logger.info(f"Serving on {address}") - async with server: - await server.serve_forever() + set_logging_context(stack=("SERV",)) + server = await asyncio.start_server(self.handle_connection, **address) + logger.info(f"Serving on {address}") + async with server: + await server.serve_forever() async def start_server(): setup_main_logger() address = {'host': '127.0.0.1', 'port': '5000'} server = AppServer() - await server.start(address) + task = asyncio.create_task(server.start(address)) + await task if __name__ == '__main__': diff --git a/src/meta/logger.py b/src/meta/logger.py index ba76cf82..9ace46b1 100644 --- a/src/meta/logger.py +++ b/src/meta/logger.py @@ -1,7 +1,7 @@ import sys import logging import asyncio -from typing import List +from typing import List, Optional from logging.handlers import QueueListener, QueueHandler import queue import multiprocessing @@ -24,40 +24,117 @@ log_logger.propagate = False log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') -log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[]) +log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=()) log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number)) +def set_logging_context( + context: Optional[str] = None, + action: Optional[str] = None, + stack: Optional[tuple[str, ...]] = None +): + """ + Statically set the logging context variables to the given values. + + If `action` is given, pushes it onto the `log_action_stack`. + """ + if context is not None: + log_context.set(context) + if action is not None or stack is not None: + astack = log_action_stack.get() + newstack = stack if stack is not None else astack + if action is not None: + newstack = (*newstack, action) + log_action_stack.set(newstack) + @contextmanager def logging_context(context=None, action=None, stack=None): + """ + Context manager for executing a block of code in a given logging context. + + This context manager should only be used around synchronous code. + This is because async code *may* get cancelled or externally garbage collected, + in which case the finally block will be executed in the wrong context. + See https://github.com/python/cpython/issues/93740 + This can be refactored nicely if this gets merged: + https://github.com/python/cpython/pull/99634 + + (It will not necessarily break on async code, + if the async code can be guaranteed to clean up in its own context.) + """ if context is not None: - context_t = log_context.set(context) - if action is not None: + oldcontext = log_context.get() + log_context.set(context) + if action is not None or stack is not None: astack = log_action_stack.get() - log_action_stack.set(astack + [action]) - if stack is not None: - actions_t = log_action_stack.set(stack) + newstack = stack if stack is not None else astack + if action is not None: + newstack = (*newstack, action) + log_action_stack.set(newstack) try: yield finally: if context is not None: - log_context.reset(context_t) - if stack is not None: - log_action_stack.reset(actions_t) - if action is not None: + log_context.set(oldcontext) + if stack is not None or action is not None: log_action_stack.set(astack) -def log_wrap(**kwargs): +def with_log_ctx(isolate=True, **kwargs): + """ + Execute a coroutine inside a given logging context. + + If `isolate` is true, ensures that context does not leak + outside the coroutine. + + If `isolate` is false, just statically set the context, + which will leak unless the coroutine is + called in an externally copied context. + """ def decorator(func): @wraps(func) async def wrapped(*w_args, **w_kwargs): - with logging_context(**kwargs): + if isolate: + with logging_context(**kwargs): + # Task creation will synchronously copy the context + # This is gc safe + name = kwargs.get('action', f"log-wrapped-{func.__name__}") + task = asyncio.create_task(func(*w_args, **w_kwargs), name=name) + return await task + else: + # This will leak context changes + set_logging_context(**kwargs) return await func(*w_args, **w_kwargs) return wrapped return decorator +# For backwards compatibility +log_wrap = with_log_ctx + + +def persist_task(task_collection: set): + """ + Coroutine decorator that ensures the coroutine is scheduled as a task + and added to the given task_collection for strong reference + when it is called. + + This is just a hack to handle discord.py events potentially + being unexpectedly garbage collected. + + Since this also implicitly schedules the coroutine as a task when it is called, + the coroutine will also be run inside an isolated context. + """ + def decorator(coro): + @wraps(coro) + async def wrapped(*w_args, **w_kwargs): + name = f"persisted-{coro.__name__}" + task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name) + task_collection.add(task) + task.add_done_callback(lambda f: task_collection.discard(f)) + await task + + RESET_SEQ = "\033[0m" COLOR_SEQ = "\033[3%dm" BOLD_SEQ = "\033[1m" @@ -208,7 +285,7 @@ class WebHookHandler(logging.StreamHandler): async def post(self, record): log_context.set("Webhook Logger") - log_action_stack.set(["Logging"]) + log_action_stack.set(("Logging",)) log_app.set(record.app) try: diff --git a/src/modules/pomodoro/timer.py b/src/modules/pomodoro/timer.py index bbddcbf2..ef43af0d 100644 --- a/src/modules/pomodoro/timer.py +++ b/src/modules/pomodoro/timer.py @@ -717,7 +717,7 @@ class Timer: f"Timer deleted. Reason given: {reason!r}" ) - @log_wrap(stack=['Timer Loop']) + @log_wrap(action='Timer Loop') async def _runloop(self): """ Main loop which controls the diff --git a/src/modules/reminders/cog.py b/src/modules/reminders/cog.py index 429f76a6..8d88bcd2 100644 --- a/src/modules/reminders/cog.py +++ b/src/modules/reminders/cog.py @@ -27,7 +27,7 @@ from data.columns import Integer, String, Timestamp, Bool from meta import LionBot, LionCog, LionContext from meta.app import shard_talk, appname_from_shard -from meta.logger import log_wrap, logging_context +from meta.logger import log_wrap, logging_context, set_logging_context from babel import ctx_translator, ctx_locale @@ -280,6 +280,7 @@ class Reminders(LionCog): f"Scheduled new reminders: {tuple(reminder.reminderid for reminder in reminders)}", ) + @log_wrap(action="Send Reminder") async def execute_reminder(self, reminderid): """ Send the reminder with the given reminderid. @@ -287,64 +288,65 @@ class Reminders(LionCog): This should in general only be executed from the executor shard, through a ReminderMonitor instance. """ - with logging_context(action='Send Reminder', context=f"rid: {reminderid}"): - reminder = await self.data.Reminder.fetch(reminderid) - if reminder is None: - logger.warning( - f"Attempted to execute a reminder that no longer exists!" - ) - return + set_logging_context(context=f"rid: {reminderid}") - try: - # Try and find the user - userid = reminder.userid - if not (user := self.bot.get_user(userid)): - user = await self.bot.fetch_user(userid) + reminder = await self.data.Reminder.fetch(reminderid) + if reminder is None: + logger.warning( + f"Attempted to execute a reminder that no longer exists!" + ) + return - # Set the locale variables - locale = await self.bot.get_cog('BabelCog').get_user_locale(userid) - ctx_locale.set(locale) - ctx_translator.set(self.bot.translator) + try: + # Try and find the user + userid = reminder.userid + if not (user := self.bot.get_user(userid)): + user = await self.bot.fetch_user(userid) - # Build the embed - embed = reminder.embed + # Set the locale variables + locale = await self.bot.get_cog('BabelCog').get_user_locale(userid) + ctx_locale.set(locale) + ctx_translator.set(self.bot.translator) - # Attempt to send to user - # TODO: Consider adding a View to this, for cancelling a repeated reminder or showing reminders - await user.send(embed=embed) + # Build the embed + embed = reminder.embed - # Update the data as required - if reminder.interval: - now = utc_now() - # Use original reminder time to calculate repeat, avoiding drift - next_time = reminder.remind_at + dt.timedelta(seconds=reminder.interval) - # Skip any expired repeats, to avoid spamming requests after downtime - # TODO: Is this actually dst safe? - while next_time.timestamp() <= now.timestamp(): - next_time = next_time + dt.timedelta(seconds=reminder.interval) - await reminder.update(remind_at=next_time) - self.monitor.schedule_task(reminder.reminderid, reminder.timestamp) - logger.debug( - f"Executed reminder and scheduled repeat at {next_time}." - ) - else: - await reminder.delete() - logger.debug( - f"Executed reminder ." - ) - except discord.HTTPException as e: - await reminder.update(failed=True) + # Attempt to send to user + # TODO: Consider adding a View to this, for cancelling a repeated reminder or showing reminders + await user.send(embed=embed) + + # Update the data as required + if reminder.interval: + now = utc_now() + # Use original reminder time to calculate repeat, avoiding drift + next_time = reminder.remind_at + dt.timedelta(seconds=reminder.interval) + # Skip any expired repeats, to avoid spamming requests after downtime + # TODO: Is this actually dst safe? + while next_time.timestamp() <= now.timestamp(): + next_time = next_time + dt.timedelta(seconds=reminder.interval) + await reminder.update(remind_at=next_time) + self.monitor.schedule_task(reminder.reminderid, reminder.timestamp) logger.debug( - f"Reminder could not be sent: {e.text}", + f"Executed reminder and scheduled repeat at {next_time}." ) - except Exception: - await reminder.update(failed=True) - logger.exception( - f"Reminder failed for an unknown reason!" + else: + await reminder.delete() + logger.debug( + f"Executed reminder ." ) - finally: - # Dispatch for analytics - self.bot.dispatch('reminder_sent', reminder) + except discord.HTTPException as e: + await reminder.update(failed=True) + logger.debug( + f"Reminder could not be sent: {e.text}", + ) + except Exception: + await reminder.update(failed=True) + logger.exception( + f"Reminder failed for an unknown reason!" + ) + finally: + # Dispatch for analytics + self.bot.dispatch('reminder_sent', reminder) @cmds.hybrid_group( name=_p('cmd:reminders', "reminders") diff --git a/src/modules/sysadmin/blacklists.py b/src/modules/sysadmin/blacklists.py index 095dd6c2..881e2446 100644 --- a/src/modules/sysadmin/blacklists.py +++ b/src/modules/sysadmin/blacklists.py @@ -13,7 +13,7 @@ from discord.ui.button import button from discord.ui.text_input import TextStyle, TextInput from meta import LionCog, LionBot, LionContext -from meta.logger import logging_context, log_wrap +from meta.logger import logging_context, log_wrap, set_logging_context from meta.errors import UserInputError from meta.app import shard_talk @@ -73,8 +73,7 @@ class Blacklists(LionCog): f"Loaded {len(self.guild_blacklist)} blacklisted guilds." ) if self.bot.is_ready(): - with logging_context(action="Guild Blacklist"): - await self.leave_blacklisted_guilds() + await self.leave_blacklisted_guilds() @LionCog.listener('on_ready') @log_wrap(action="Guild Blacklist") @@ -84,8 +83,9 @@ class Blacklists(LionCog): guild for guild in self.bot.guilds if guild.id in self.guild_blacklist ] - - asyncio.gather(*(guild.leave() for guild in to_leave)) + if to_leave: + tasks = [asyncio.create_task(guild.leave()) for guild in to_leave] + await asyncio.gather(*tasks) logger.info( "Left {} blacklisted guilds.".format(len(to_leave)), @@ -95,12 +95,12 @@ class Blacklists(LionCog): @log_wrap(action="Check Guild Blacklist") async def check_guild_blacklist(self, guild): """Check if the given guild is in the blacklist, and leave if so.""" - with logging_context(context=f"gid: {guild.id}"): - if guild.id in self.guild_blacklist: - await guild.leave() - logger.info( - "Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id) - ) + if guild.id in self.guild_blacklist: + set_logging_context(context=f"gid: {guild.id}") + await guild.leave() + logger.info( + "Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id) + ) async def bot_check_once(self, ctx: LionContext) -> bool: # type:ignore if ctx.author.id in self.user_blacklist: diff --git a/src/modules/sysadmin/exec_cog.py b/src/modules/sysadmin/exec_cog.py index 7cd965e7..b8e652a8 100644 --- a/src/modules/sysadmin/exec_cog.py +++ b/src/modules/sysadmin/exec_cog.py @@ -19,7 +19,7 @@ from discord.ui import TextInput, View from discord.ui.button import button import discord.app_commands as appcmd -from meta.logger import logging_context +from meta.logger import logging_context, log_wrap from meta.app import shard_talk from meta import conf from meta.context import context, ctx_bot @@ -185,54 +185,54 @@ def mk_print(fp: io.StringIO) -> Callable[..., None]: return _print +@log_wrap(action="Code Exec") async def _async(to_eval: str, style='exec'): - with logging_context(action="Code Exec"): - newline = '\n' * ('\n' in to_eval) - logger.info( - f"Exec code with {style}: {newline}{to_eval}" + newline = '\n' * ('\n' in to_eval) + logger.info( + f"Exec code with {style}: {newline}{to_eval}" + ) + + output = io.StringIO() + _print = mk_print(output) + + scope: dict[str, Any] = dict(sys.modules) + scope['__builtins__'] = builtins + scope.update(builtins.__dict__) + scope['ctx'] = ctx = context.get() + scope['bot'] = ctx_bot.get() + scope['print'] = _print # type: ignore + + try: + if ctx and ctx.message: + source_str = f"" + elif ctx and ctx.interaction: + source_str = f"" + else: + source_str = "Unknown async" + + code = compile( + to_eval, + source_str, + style, + ast.PyCF_ALLOW_TOP_LEVEL_AWAIT ) + func = types.FunctionType(code, scope) - output = io.StringIO() - _print = mk_print(output) + ret = func() + if inspect.iscoroutine(ret): + ret = await ret + if ret is not None: + _print(repr(ret)) + except Exception: + _, exc, tb = sys.exc_info() + _print("".join(traceback.format_tb(tb))) + _print(f"{type(exc).__name__}: {exc}") - scope: dict[str, Any] = dict(sys.modules) - scope['__builtins__'] = builtins - scope.update(builtins.__dict__) - scope['ctx'] = ctx = context.get() - scope['bot'] = ctx_bot.get() - scope['print'] = _print # type: ignore - - try: - if ctx and ctx.message: - source_str = f"" - elif ctx and ctx.interaction: - source_str = f"" - else: - source_str = "Unknown async" - - code = compile( - to_eval, - source_str, - style, - ast.PyCF_ALLOW_TOP_LEVEL_AWAIT - ) - func = types.FunctionType(code, scope) - - ret = func() - if inspect.iscoroutine(ret): - ret = await ret - if ret is not None: - _print(repr(ret)) - except Exception: - _, exc, tb = sys.exc_info() - _print("".join(traceback.format_tb(tb))) - _print(f"{type(exc).__name__}: {exc}") - - result = output.getvalue().strip() - newline = '\n' * ('\n' in result) - logger.info( - f"Exec complete, output: {newline}{result}" - ) + result = output.getvalue().strip() + newline = '\n' * ('\n' in result) + logger.info( + f"Exec complete, output: {newline}{result}" + ) return result