fix: Massively improve logging context isolation.

This commit is contained in:
2023-08-22 22:14:58 +03:00
parent d578e7471d
commit f2fb3ce344
13 changed files with 440 additions and 319 deletions

View File

@@ -13,7 +13,7 @@ from aiohttp import ClientSession
from data import Database
from .config import Conf
from .logger import logging_context, log_context, log_action_stack
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
from .context import context
from .LionContext import LionContext
from .LionTree import LionTree
@@ -46,6 +46,7 @@ class LionBot(Bot):
self.translator = translator
self._locks = WeakValueDictionary()
self._running_events = set()
async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}")
@@ -64,27 +65,45 @@ class LionBot(Bot):
await self.tree.sync(guild=guild)
async def add_cog(self, cog: Cog, **kwargs):
with logging_context(action=f"Attach {cog.__cog_name__}"):
sup = super()
@log_wrap(action=f"Attach {cog.__cog_name__}")
async def wrapper():
logger.info(f"Attaching Cog {cog.__cog_name__}")
await super().add_cog(cog, **kwargs)
await sup.add_cog(cog, **kwargs)
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
await wrapper()
async def load_extension(self, name, *, package=None, **kwargs):
with logging_context(action=f"Load {name.strip('.')}"):
sup = super()
@log_wrap(action=f"Load {name.strip('.')}")
async def wrapper():
logger.info(f"Loading extension {name} in package {package}.")
await super().load_extension(name, package=package, **kwargs)
await sup.load_extension(name, package=package, **kwargs)
logger.debug(f"Loaded extension {name} in package {package}.")
await wrapper()
async def start(self, token: str, *, reconnect: bool = True):
with logging_context(action="Login"):
await self.login(token)
with logging_context(stack=["Running"]):
await self.connect(reconnect=reconnect)
start_task = asyncio.create_task(self.login(token))
await start_task
with logging_context(stack=("Running",)):
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
await run_task
def dispatch(self, event_name: str, *args, **kwargs):
with logging_context(action=f"Dispatch {event_name}"):
super().dispatch(event_name, *args, **kwargs)
def _schedule_event(self, coro, event_name, *args, **kwargs):
"""
Extends client._schedule_event to keep a persistent
background task store.
"""
task = super()._schedule_event(coro, event_name, *args, **kwargs)
self._running_events.add(task)
task.add_done_callback(lambda fut: self._running_events.discard(fut))
def idlock(self, snowflakeid):
lock = self._locks.get(snowflakeid, None)
if lock is None:
@@ -124,9 +143,11 @@ class LionBot(Bot):
async def on_command_error(self, ctx, exception):
# TODO: Some of these could have more user-feedback
cmd_str = str(ctx.command)
logger.debug(f"Handling command error for {ctx}: {exception}")
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict()
else:
cmd_str = str(ctx.command)
try:
raise exception
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
@@ -191,14 +212,14 @@ class LionBot(Bot):
pass
finally:
exception.original = HandledException(exception.original)
except CheckFailure:
except CheckFailure as e:
logger.debug(
f"Command failed check: {exception}",
f"Command failed check: {e}",
extra={'action': 'BotError', 'with_ctx': True}
)
try:
await ctx.error_reply(exception.message)
except Exception:
await ctx.error_reply(str(e))
except discord.HTTPException:
pass
except Exception:
# Completely unknown exception outside of command invocation!
@@ -209,6 +230,5 @@ class LionBot(Bot):
)
def add_command(self, command):
if hasattr(command, '_placeholder_group_'):
return
super().add_command(command)
if not hasattr(command, '_placeholder_group_'):
super().add_command(command)

View File

@@ -6,13 +6,17 @@ from discord.app_commands.errors import AppCommandError, CommandInvokeError
from discord.enums import InteractionType
from discord.app_commands.namespace import Namespace
from .logger import logging_context
from .logger import logging_context, set_logging_context, log_wrap
from .errors import SafeCancellation
logger = logging.getLogger(__name__)
class LionTree(CommandTree):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._call_tasks = set()
async def on_error(self, interaction, error) -> None:
try:
if isinstance(error, CommandInvokeError):
@@ -25,55 +29,66 @@ class LionTree(CommandTree):
except Exception:
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
def _from_interaction(self, interaction: Interaction) -> None:
@log_wrap(context=f"iid: {interaction.id}", isolate=False)
async def wrapper():
try:
await self._call(interaction)
except AppCommandError as e:
await self._dispatch_error(interaction, e)
task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
self._call_tasks.add(task)
task.add_done_callback(lambda fut: self._call_tasks.discard(fut))
async def _call(self, interaction):
with logging_context(context=f"iid: {interaction.id}"):
if not await self.interaction_check(interaction):
interaction.command_failed = True
return
if not await self.interaction_check(interaction):
interaction.command_failed = True
return
data = interaction.data # type: ignore
type = data.get('type', 1)
if type != 1:
# Context menu command...
await self._call_context_menu(interaction, data, type)
return
data = interaction.data # type: ignore
type = data.get('type', 1)
if type != 1:
# Context menu command...
await self._call_context_menu(interaction, data, type)
return
command, options = self._get_app_command_options(data)
command, options = self._get_app_command_options(data)
# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command
# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command
# At this point options refers to the arguments of the command
# and command refers to the class type we care about
namespace = Namespace(interaction, data.get('resolved', {}), options)
# At this point options refers to the arguments of the command
# and command refers to the class type we care about
namespace = Namespace(interaction, data.get('resolved', {}), options)
# Same pre-fill as above
interaction._cs_namespace = namespace
# Same pre-fill as above
interaction._cs_namespace = namespace
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
if interaction.type is InteractionType.autocomplete:
with logging_context(action=f"Acmp {command.qualified_name}"):
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
if focused is None:
raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.'
)
await command._invoke_autocomplete(interaction, focused, namespace)
return
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
if interaction.type is InteractionType.autocomplete:
set_logging_context(action=f"Acmp {command.qualified_name}")
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
if focused is None:
raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.'
)
await command._invoke_autocomplete(interaction, focused, namespace)
return
with logging_context(action=f"Run {command.qualified_name}"):
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:
interaction.command_failed = True
await command._invoke_error_handlers(interaction, e)
await self.on_error(interaction, e)
else:
if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command)
finally:
if interaction.command_failed:
logger.debug("Command completed with errors.")
else:
logger.debug("Command completed without errors.")
set_logging_context(action=f"Run {command.qualified_name}")
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:
interaction.command_failed = True
await command._invoke_error_handlers(interaction, e)
await self.on_error(interaction, e)
else:
if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command)
finally:
if interaction.command_failed:
logger.debug("Command completed with errors.")
else:
logger.debug("Command completed without errors.")

View File

@@ -3,7 +3,7 @@ import asyncio
import logging
import pickle
from ..logger import logging_context
from ..logger import logging_context, log_wrap, set_logging_context
logger = logging.getLogger(__name__)
@@ -99,68 +99,70 @@ class AppClient:
# TODO
...
@log_wrap(action="Req")
async def request(self, appid, payload: 'AppPayload', wait_for_reply=True):
with logging_context(action=f"Req {appid}"):
try:
if appid not in self.peers:
raise ValueError(f"Peer '{appid}' not found.")
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
set_logging_context(action=appid)
try:
if appid not in self.peers:
raise ValueError(f"Peer '{appid}' not found.")
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
address = self.peers[appid]
reader, writer = await asyncio.open_connection(**address)
address = self.peers[appid]
reader, writer = await asyncio.open_connection(**address)
writer.write(payload.encoded())
await writer.drain()
writer.write_eof()
if wait_for_reply:
result = await reader.read()
writer.close()
decoded = payload.route.decode(result)
return decoded
else:
return None
except Exception:
logging.exception(f"Failed to send request to {appid}'")
writer.write(payload.encoded())
await writer.drain()
writer.write_eof()
if wait_for_reply:
result = await reader.read()
writer.close()
decoded = payload.route.decode(result)
return decoded
else:
return None
except Exception:
logging.exception(f"Failed to send request to {appid}'")
return None
@log_wrap(action="Broadcast")
async def requestall(self, payload, except_self=True, only_my_peers=True):
with logging_context(action="Broadcast"):
peerlist = self.my_peers if only_my_peers else self.peers
results = await asyncio.gather(
*(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)),
return_exceptions=True
)
return dict(zip(self.peers.keys(), results))
peerlist = self.my_peers if only_my_peers else self.peers
results = await asyncio.gather(
*(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)),
return_exceptions=True
)
return dict(zip(self.peers.keys(), results))
async def handle_request(self, reader, writer):
with logging_context(action="SERV"):
data = await reader.read()
loaded = pickle.loads(data)
route, args, kwargs = loaded
set_logging_context(action="SERV")
data = await reader.read()
loaded = pickle.loads(data)
route, args, kwargs = loaded
with logging_context(action=f"SERV {route}"):
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
set_logging_context(action=route)
if route in self.routes:
try:
await self.routes[route].run((reader, writer), args, kwargs)
except Exception:
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
else:
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
writer.write_eof()
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes:
try:
await self.routes[route].run((reader, writer), args, kwargs)
except Exception:
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
else:
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
writer.write_eof()
@log_wrap(stack=("ShardTalk",))
async def connect(self):
"""
Start the local peer server.
Connect to the address server.
"""
with logging_context(stack=['ShardTalk']):
# Start the client server
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
# Start the client server
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
logger.info(f"Serving on {self.address}")
await self.server_connection()
logger.info(f"Serving on {self.address}")
await self.server_connection()
class AppPayload:

View File

@@ -4,7 +4,7 @@ import logging
import string
import random
from ..logger import log_context, log_app, logging_context, setup_main_logger
from ..logger import log_context, log_app, setup_main_logger, set_logging_context, log_wrap
from ..config import conf
logger = logging.getLogger(__name__)
@@ -75,45 +75,45 @@ class AppServer:
"""
Register and hold a new client connection.
"""
with logging_context(action=f"CONN {appid}"):
reader, writer = connection
# Add the new client
self.clients[appid] = (address, connection)
set_logging_context(action=f"CONN {appid}")
reader, writer = connection
# Add the new client
self.clients[appid] = (address, connection)
# Send the new client a client list
peers = self.peer_list()
writer.write(pickle.dumps(peers))
writer.write(b'\n')
await writer.drain()
# Send the new client a client list
peers = self.peer_list()
writer.write(pickle.dumps(peers))
writer.write(b'\n')
await writer.drain()
# Announce the new client to everyone
await self.broadcast('new_peer', (), {'appid': appid, 'address': address})
# Announce the new client to everyone
await self.broadcast('new_peer', (), {'appid': appid, 'address': address})
# Keep the connection open until socket closed or EOF (indicating client death)
try:
await reader.read()
finally:
# Connection ended or it broke
logger.info(f"Lost client '{appid}'")
await self.deregister_client(appid)
# Keep the connection open until socket closed or EOF (indicating client death)
try:
await reader.read()
finally:
# Connection ended or it broke
logger.info(f"Lost client '{appid}'")
await self.deregister_client(appid)
async def handle_connection(self, reader, writer):
data = await reader.readline()
route, args, kwargs = pickle.loads(data)
rqid = short_uuid()
set_logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}")
logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}")
with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"):
logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes:
# Execute route
try:
await self.routes[route]((reader, writer), *args, **kwargs)
except Exception:
logger.exception(f"AppServer recieved exception during route '{route}'")
else:
logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.")
if route in self.routes:
# Execute route
try:
await self.routes[route]((reader, writer), *args, **kwargs)
except Exception:
logger.exception(f"AppServer recieved exception during route '{route}'")
else:
logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.")
def peer_list(self):
return {appid: address for appid, (address, _) in self.clients.items()}
@@ -122,27 +122,27 @@ class AppServer:
self.clients.pop(appid, None)
await self.broadcast('drop_peer', (), {'appid': appid})
@log_wrap(action="broadcast")
async def broadcast(self, route, args, kwargs):
with logging_context(action="broadcast"):
logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.")
payload = pickle.dumps((route, args, kwargs))
if self.clients:
await asyncio.gather(
*(self._send(appid, payload) for appid in self.clients),
return_exceptions=True
)
logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.")
payload = pickle.dumps((route, args, kwargs))
if self.clients:
await asyncio.gather(
*(self._send(appid, payload) for appid in self.clients),
return_exceptions=True
)
async def message_client(self, appid, route, args, kwargs):
"""
Send a message to client `appid` along `route` with given arguments.
"""
with logging_context(action=f"MSG {appid}"):
logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.")
if appid not in self.clients:
raise ValueError(f"Client '{appid}' is not connected.")
set_logging_context(action=f"MSG {appid}")
logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.")
if appid not in self.clients:
raise ValueError(f"Client '{appid}' is not connected.")
payload = pickle.dumps((route, args, kwargs))
return await self._send(appid, payload)
payload = pickle.dumps((route, args, kwargs))
return await self._send(appid, payload)
async def _send(self, appid, payload):
"""
@@ -162,18 +162,19 @@ class AppServer:
async def start(self, address):
log_app.set("APPSERVER")
with logging_context(stack=["SERV"]):
server = await asyncio.start_server(self.handle_connection, **address)
logger.info(f"Serving on {address}")
async with server:
await server.serve_forever()
set_logging_context(stack=("SERV",))
server = await asyncio.start_server(self.handle_connection, **address)
logger.info(f"Serving on {address}")
async with server:
await server.serve_forever()
async def start_server():
setup_main_logger()
address = {'host': '127.0.0.1', 'port': '5000'}
server = AppServer()
await server.start(address)
task = asyncio.create_task(server.start(address))
await task
if __name__ == '__main__':

View File

@@ -1,7 +1,7 @@
import sys
import logging
import asyncio
from typing import List
from typing import List, Optional
from logging.handlers import QueueListener, QueueHandler
import queue
import multiprocessing
@@ -24,40 +24,117 @@ log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[])
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
def set_logging_context(
context: Optional[str] = None,
action: Optional[str] = None,
stack: Optional[tuple[str, ...]] = None
):
"""
Statically set the logging context variables to the given values.
If `action` is given, pushes it onto the `log_action_stack`.
"""
if context is not None:
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
@contextmanager
def logging_context(context=None, action=None, stack=None):
"""
Context manager for executing a block of code in a given logging context.
This context manager should only be used around synchronous code.
This is because async code *may* get cancelled or externally garbage collected,
in which case the finally block will be executed in the wrong context.
See https://github.com/python/cpython/issues/93740
This can be refactored nicely if this gets merged:
https://github.com/python/cpython/pull/99634
(It will not necessarily break on async code,
if the async code can be guaranteed to clean up in its own context.)
"""
if context is not None:
context_t = log_context.set(context)
if action is not None:
oldcontext = log_context.get()
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
log_action_stack.set(astack + [action])
if stack is not None:
actions_t = log_action_stack.set(stack)
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
try:
yield
finally:
if context is not None:
log_context.reset(context_t)
if stack is not None:
log_action_stack.reset(actions_t)
if action is not None:
log_context.set(oldcontext)
if stack is not None or action is not None:
log_action_stack.set(astack)
def log_wrap(**kwargs):
def with_log_ctx(isolate=True, **kwargs):
"""
Execute a coroutine inside a given logging context.
If `isolate` is true, ensures that context does not leak
outside the coroutine.
If `isolate` is false, just statically set the context,
which will leak unless the coroutine is
called in an externally copied context.
"""
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
with logging_context(**kwargs):
if isolate:
with logging_context(**kwargs):
# Task creation will synchronously copy the context
# This is gc safe
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
return await task
else:
# This will leak context changes
set_logging_context(**kwargs)
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
# For backwards compatibility
log_wrap = with_log_ctx
def persist_task(task_collection: set):
"""
Coroutine decorator that ensures the coroutine is scheduled as a task
and added to the given task_collection for strong reference
when it is called.
This is just a hack to handle discord.py events potentially
being unexpectedly garbage collected.
Since this also implicitly schedules the coroutine as a task when it is called,
the coroutine will also be run inside an isolated context.
"""
def decorator(coro):
@wraps(coro)
async def wrapped(*w_args, **w_kwargs):
name = f"persisted-{coro.__name__}"
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
task_collection.add(task)
task.add_done_callback(lambda f: task_collection.discard(f))
await task
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
@@ -208,7 +285,7 @@ class WebHookHandler(logging.StreamHandler):
async def post(self, record):
log_context.set("Webhook Logger")
log_action_stack.set(["Logging"])
log_action_stack.set(("Logging",))
log_app.set(record.app)
try: