fix: Massively improve logging context isolation.

This commit is contained in:
2023-08-22 22:14:58 +03:00
parent d578e7471d
commit f2fb3ce344
13 changed files with 440 additions and 319 deletions

View File

@@ -110,7 +110,7 @@ class Analytics(LionCog):
duration = utc_now() - ctx.message.created_at duration = utc_now() - ctx.message.created_at
event = CommandEvent( event = CommandEvent(
appname=appname, appname=appname,
cmdname=ctx.command.name if ctx.command else 'Unknown', cmdname=ctx.command.name if ctx.command else 'Unknown', # TODO: qualified_name
cogname=ctx.cog.qualified_name if ctx.cog else None, cogname=ctx.cog.qualified_name if ctx.cog else None,
userid=ctx.author.id, userid=ctx.author.id,
created_at=utc_now(), created_at=utc_now(),

View File

@@ -5,7 +5,7 @@ from collections import namedtuple
from typing import NamedTuple, Optional, Generic, Type, TypeVar from typing import NamedTuple, Optional, Generic, Type, TypeVar
from meta.ipc import AppRoute, AppClient from meta.ipc import AppRoute, AppClient
from meta.logger import logging_context, log_wrap from meta.logger import logging_context, log_wrap, set_logging_context
from data import RowModel from data import RowModel
from .data import AnalyticsData, CommandStatus, VoiceAction, GuildAction from .data import AnalyticsData, CommandStatus, VoiceAction, GuildAction
@@ -52,39 +52,39 @@ class EventHandler(Generic[T]):
f"Queue on event handler {self.route_name} is full! Discarding event {data}" f"Queue on event handler {self.route_name} is full! Discarding event {data}"
) )
@log_wrap(action='consumer', isolate=False)
async def consumer(self): async def consumer(self):
with logging_context(action='consumer'): while True:
while True: try:
try: item = await self.queue.get()
item = await self.queue.get() self.batch.append(item)
self.batch.append(item) if len(self.batch) > self.batch_size:
if len(self.batch) > self.batch_size: await self.process_batch()
await self.process_batch() except asyncio.CancelledError:
except asyncio.CancelledError: # Try and process the last batch
# Try and process the last batch logger.info(
logger.info( f"Event handler {self.route_name} received cancellation signal! "
f"Event handler {self.route_name} received cancellation signal! " "Trying to process last batch."
"Trying to process last batch." )
) if self.batch:
if self.batch: await self.process_batch()
await self.process_batch() raise
raise except Exception:
except Exception: logger.exception(
logger.exception( f"Event handler {self.route_name} received unhandled error."
f"Event handler {self.route_name} received unhandled error." " Ignoring and continuing cautiously."
" Ignoring and continuing cautiously." )
) pass
pass
@log_wrap(action='batch', isolate=False)
async def process_batch(self): async def process_batch(self):
with logging_context(action='batch'): logger.debug("Processing Batch")
logger.debug("Processing Batch") # TODO: copy syntax might be more efficient here
# TODO: copy syntax might be more efficient here await self.model.table.insert_many(
await self.model.table.insert_many( self.struct._fields,
self.struct._fields, *map(tuple, self.batch)
*map(tuple, self.batch) )
) self.batch.clear()
self.batch.clear()
def bind(self, client: AppClient): def bind(self, client: AppClient):
""" """

View File

@@ -67,7 +67,11 @@ class AnalyticsServer:
results = await self.talk_shard_snapshot().broadcast() results = await self.talk_shard_snapshot().broadcast()
# Make sure everyone sent results and there were no exceptions (e.g. concurrency) # Make sure everyone sent results and there were no exceptions (e.g. concurrency)
if not all(result is not None and not isinstance(result, Exception) for result in results.values()): failed = not isinstance(results, dict)
failed = failed or not any(
result is None or isinstance(result, Exception) for result in results.values()
)
if failed:
# This should essentially never happen # This should essentially never happen
# Either some of the shards could not make a snapshot (e.g. Discord client issues) # Either some of the shards could not make a snapshot (e.g. Discord client issues)
# or they disconnected in the process. # or they disconnected in the process.

View File

@@ -7,7 +7,7 @@ from discord.ext import commands
from meta import LionBot, conf, sharding, appname, shard_talk from meta import LionBot, conf, sharding, appname, shard_talk
from meta.app import shardname from meta.app import shardname
from meta.logger import log_context, log_action_stack, logging_context, setup_main_logger from meta.logger import log_context, log_action_stack, setup_main_logger
from meta.context import ctx_bot from meta.context import ctx_bot
from data import Database from data import Database
@@ -30,7 +30,7 @@ db = Database(conf.data['args'])
async def main(): async def main():
log_action_stack.set(["Initialising"]) log_action_stack.set(("Initialising",))
logger.info("Initialising StudyLion") logger.info("Initialising StudyLion")
intents = discord.Intents.all() intents = discord.Intents.all()
@@ -73,12 +73,12 @@ async def main():
) as lionbot: ) as lionbot:
ctx_bot.set(lionbot) ctx_bot.set(lionbot)
try: try:
with logging_context(context=f"APP: {appname}"): log_context.set(f"APP: {appname}")
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
await lionbot.start(conf.bot['TOKEN']) await lionbot.start(conf.bot['TOKEN'])
except asyncio.CancelledError: except asyncio.CancelledError:
with logging_context(context=f"APP: {appname}", action="Shutting Down"): log_context.set(f"APP: {appname}")
logger.info("StudyLion closed, shutting down.", exc_info=True) logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
def _main(): def _main():

View File

@@ -13,7 +13,7 @@ from aiohttp import ClientSession
from data import Database from data import Database
from .config import Conf from .config import Conf
from .logger import logging_context, log_context, log_action_stack from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
from .context import context from .context import context
from .LionContext import LionContext from .LionContext import LionContext
from .LionTree import LionTree from .LionTree import LionTree
@@ -46,6 +46,7 @@ class LionBot(Bot):
self.translator = translator self.translator = translator
self._locks = WeakValueDictionary() self._locks = WeakValueDictionary()
self._running_events = set()
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}") log_context.set(f"APP: {self.application_id}")
@@ -64,27 +65,45 @@ class LionBot(Bot):
await self.tree.sync(guild=guild) await self.tree.sync(guild=guild)
async def add_cog(self, cog: Cog, **kwargs): async def add_cog(self, cog: Cog, **kwargs):
with logging_context(action=f"Attach {cog.__cog_name__}"): sup = super()
@log_wrap(action=f"Attach {cog.__cog_name__}")
async def wrapper():
logger.info(f"Attaching Cog {cog.__cog_name__}") logger.info(f"Attaching Cog {cog.__cog_name__}")
await super().add_cog(cog, **kwargs) await sup.add_cog(cog, **kwargs)
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.") logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
await wrapper()
async def load_extension(self, name, *, package=None, **kwargs): async def load_extension(self, name, *, package=None, **kwargs):
with logging_context(action=f"Load {name.strip('.')}"): sup = super()
@log_wrap(action=f"Load {name.strip('.')}")
async def wrapper():
logger.info(f"Loading extension {name} in package {package}.") logger.info(f"Loading extension {name} in package {package}.")
await super().load_extension(name, package=package, **kwargs) await sup.load_extension(name, package=package, **kwargs)
logger.debug(f"Loaded extension {name} in package {package}.") logger.debug(f"Loaded extension {name} in package {package}.")
await wrapper()
async def start(self, token: str, *, reconnect: bool = True): async def start(self, token: str, *, reconnect: bool = True):
with logging_context(action="Login"): with logging_context(action="Login"):
await self.login(token) start_task = asyncio.create_task(self.login(token))
with logging_context(stack=["Running"]): await start_task
await self.connect(reconnect=reconnect)
with logging_context(stack=("Running",)):
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
await run_task
def dispatch(self, event_name: str, *args, **kwargs): def dispatch(self, event_name: str, *args, **kwargs):
with logging_context(action=f"Dispatch {event_name}"): with logging_context(action=f"Dispatch {event_name}"):
super().dispatch(event_name, *args, **kwargs) super().dispatch(event_name, *args, **kwargs)
def _schedule_event(self, coro, event_name, *args, **kwargs):
"""
Extends client._schedule_event to keep a persistent
background task store.
"""
task = super()._schedule_event(coro, event_name, *args, **kwargs)
self._running_events.add(task)
task.add_done_callback(lambda fut: self._running_events.discard(fut))
def idlock(self, snowflakeid): def idlock(self, snowflakeid):
lock = self._locks.get(snowflakeid, None) lock = self._locks.get(snowflakeid, None)
if lock is None: if lock is None:
@@ -124,9 +143,11 @@ class LionBot(Bot):
async def on_command_error(self, ctx, exception): async def on_command_error(self, ctx, exception):
# TODO: Some of these could have more user-feedback # TODO: Some of these could have more user-feedback
cmd_str = str(ctx.command) logger.debug(f"Handling command error for {ctx}: {exception}")
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command: if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict() cmd_str = ctx.command.app_command.to_dict()
else:
cmd_str = str(ctx.command)
try: try:
raise exception raise exception
except (HybridCommandError, CommandInvokeError, appCommandInvokeError): except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
@@ -191,14 +212,14 @@ class LionBot(Bot):
pass pass
finally: finally:
exception.original = HandledException(exception.original) exception.original = HandledException(exception.original)
except CheckFailure: except CheckFailure as e:
logger.debug( logger.debug(
f"Command failed check: {exception}", f"Command failed check: {e}",
extra={'action': 'BotError', 'with_ctx': True} extra={'action': 'BotError', 'with_ctx': True}
) )
try: try:
await ctx.error_reply(exception.message) await ctx.error_reply(str(e))
except Exception: except discord.HTTPException:
pass pass
except Exception: except Exception:
# Completely unknown exception outside of command invocation! # Completely unknown exception outside of command invocation!
@@ -209,6 +230,5 @@ class LionBot(Bot):
) )
def add_command(self, command): def add_command(self, command):
if hasattr(command, '_placeholder_group_'): if not hasattr(command, '_placeholder_group_'):
return super().add_command(command)
super().add_command(command)

View File

@@ -6,13 +6,17 @@ from discord.app_commands.errors import AppCommandError, CommandInvokeError
from discord.enums import InteractionType from discord.enums import InteractionType
from discord.app_commands.namespace import Namespace from discord.app_commands.namespace import Namespace
from .logger import logging_context from .logger import logging_context, set_logging_context, log_wrap
from .errors import SafeCancellation from .errors import SafeCancellation
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LionTree(CommandTree): class LionTree(CommandTree):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._call_tasks = set()
async def on_error(self, interaction, error) -> None: async def on_error(self, interaction, error) -> None:
try: try:
if isinstance(error, CommandInvokeError): if isinstance(error, CommandInvokeError):
@@ -25,55 +29,66 @@ class LionTree(CommandTree):
except Exception: except Exception:
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'}) logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
def _from_interaction(self, interaction: Interaction) -> None:
@log_wrap(context=f"iid: {interaction.id}", isolate=False)
async def wrapper():
try:
await self._call(interaction)
except AppCommandError as e:
await self._dispatch_error(interaction, e)
task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
self._call_tasks.add(task)
task.add_done_callback(lambda fut: self._call_tasks.discard(fut))
async def _call(self, interaction): async def _call(self, interaction):
with logging_context(context=f"iid: {interaction.id}"): if not await self.interaction_check(interaction):
if not await self.interaction_check(interaction): interaction.command_failed = True
interaction.command_failed = True return
return
data = interaction.data # type: ignore data = interaction.data # type: ignore
type = data.get('type', 1) type = data.get('type', 1)
if type != 1: if type != 1:
# Context menu command... # Context menu command...
await self._call_context_menu(interaction, data, type) await self._call_context_menu(interaction, data, type)
return return
command, options = self._get_app_command_options(data) command, options = self._get_app_command_options(data)
# Pre-fill the cached slot to prevent re-computation # Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command interaction._cs_command = command
# At this point options refers to the arguments of the command # At this point options refers to the arguments of the command
# and command refers to the class type we care about # and command refers to the class type we care about
namespace = Namespace(interaction, data.get('resolved', {}), options) namespace = Namespace(interaction, data.get('resolved', {}), options)
# Same pre-fill as above # Same pre-fill as above
interaction._cs_namespace = namespace interaction._cs_namespace = namespace
# Auto complete handles the namespace differently... so at this point this is where we decide where that is. # Auto complete handles the namespace differently... so at this point this is where we decide where that is.
if interaction.type is InteractionType.autocomplete: if interaction.type is InteractionType.autocomplete:
with logging_context(action=f"Acmp {command.qualified_name}"): set_logging_context(action=f"Acmp {command.qualified_name}")
focused = next((opt['name'] for opt in options if opt.get('focused')), None) focused = next((opt['name'] for opt in options if opt.get('focused')), None)
if focused is None: if focused is None:
raise AppCommandError( raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.' 'This should not happen, but there is no focused element. This is a Discord bug.'
) )
await command._invoke_autocomplete(interaction, focused, namespace) await command._invoke_autocomplete(interaction, focused, namespace)
return return
with logging_context(action=f"Run {command.qualified_name}"): set_logging_context(action=f"Run {command.qualified_name}")
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}") logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
try: try:
await command._invoke_with_namespace(interaction, namespace) await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e: except AppCommandError as e:
interaction.command_failed = True interaction.command_failed = True
await command._invoke_error_handlers(interaction, e) await command._invoke_error_handlers(interaction, e)
await self.on_error(interaction, e) await self.on_error(interaction, e)
else: else:
if not interaction.command_failed: if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command) self.client.dispatch('app_command_completion', interaction, command)
finally: finally:
if interaction.command_failed: if interaction.command_failed:
logger.debug("Command completed with errors.") logger.debug("Command completed with errors.")
else: else:
logger.debug("Command completed without errors.") logger.debug("Command completed without errors.")

View File

@@ -3,7 +3,7 @@ import asyncio
import logging import logging
import pickle import pickle
from ..logger import logging_context from ..logger import logging_context, log_wrap, set_logging_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -99,68 +99,70 @@ class AppClient:
# TODO # TODO
... ...
@log_wrap(action="Req")
async def request(self, appid, payload: 'AppPayload', wait_for_reply=True): async def request(self, appid, payload: 'AppPayload', wait_for_reply=True):
with logging_context(action=f"Req {appid}"): set_logging_context(action=appid)
try: try:
if appid not in self.peers: if appid not in self.peers:
raise ValueError(f"Peer '{appid}' not found.") raise ValueError(f"Peer '{appid}' not found.")
logger.debug(f"Sending request to app '{appid}' with payload {payload}") logger.debug(f"Sending request to app '{appid}' with payload {payload}")
address = self.peers[appid] address = self.peers[appid]
reader, writer = await asyncio.open_connection(**address) reader, writer = await asyncio.open_connection(**address)
writer.write(payload.encoded()) writer.write(payload.encoded())
await writer.drain() await writer.drain()
writer.write_eof() writer.write_eof()
if wait_for_reply: if wait_for_reply:
result = await reader.read() result = await reader.read()
writer.close() writer.close()
decoded = payload.route.decode(result) decoded = payload.route.decode(result)
return decoded return decoded
else: else:
return None
except Exception:
logging.exception(f"Failed to send request to {appid}'")
return None return None
except Exception:
logging.exception(f"Failed to send request to {appid}'")
return None
@log_wrap(action="Broadcast")
async def requestall(self, payload, except_self=True, only_my_peers=True): async def requestall(self, payload, except_self=True, only_my_peers=True):
with logging_context(action="Broadcast"): peerlist = self.my_peers if only_my_peers else self.peers
peerlist = self.my_peers if only_my_peers else self.peers results = await asyncio.gather(
results = await asyncio.gather( *(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)),
*(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)), return_exceptions=True
return_exceptions=True )
) return dict(zip(self.peers.keys(), results))
return dict(zip(self.peers.keys(), results))
async def handle_request(self, reader, writer): async def handle_request(self, reader, writer):
with logging_context(action="SERV"): set_logging_context(action="SERV")
data = await reader.read() data = await reader.read()
loaded = pickle.loads(data) loaded = pickle.loads(data)
route, args, kwargs = loaded route, args, kwargs = loaded
with logging_context(action=f"SERV {route}"): set_logging_context(action=route)
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes: logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
try:
await self.routes[route].run((reader, writer), args, kwargs)
except Exception:
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
else:
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
writer.write_eof()
if route in self.routes:
try:
await self.routes[route].run((reader, writer), args, kwargs)
except Exception:
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
else:
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
writer.write_eof()
@log_wrap(stack=("ShardTalk",))
async def connect(self): async def connect(self):
""" """
Start the local peer server. Start the local peer server.
Connect to the address server. Connect to the address server.
""" """
with logging_context(stack=['ShardTalk']): # Start the client server
# Start the client server self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
logger.info(f"Serving on {self.address}") logger.info(f"Serving on {self.address}")
await self.server_connection() await self.server_connection()
class AppPayload: class AppPayload:

View File

@@ -4,7 +4,7 @@ import logging
import string import string
import random import random
from ..logger import log_context, log_app, logging_context, setup_main_logger from ..logger import log_context, log_app, setup_main_logger, set_logging_context, log_wrap
from ..config import conf from ..config import conf
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -75,27 +75,27 @@ class AppServer:
""" """
Register and hold a new client connection. Register and hold a new client connection.
""" """
with logging_context(action=f"CONN {appid}"): set_logging_context(action=f"CONN {appid}")
reader, writer = connection reader, writer = connection
# Add the new client # Add the new client
self.clients[appid] = (address, connection) self.clients[appid] = (address, connection)
# Send the new client a client list # Send the new client a client list
peers = self.peer_list() peers = self.peer_list()
writer.write(pickle.dumps(peers)) writer.write(pickle.dumps(peers))
writer.write(b'\n') writer.write(b'\n')
await writer.drain() await writer.drain()
# Announce the new client to everyone # Announce the new client to everyone
await self.broadcast('new_peer', (), {'appid': appid, 'address': address}) await self.broadcast('new_peer', (), {'appid': appid, 'address': address})
# Keep the connection open until socket closed or EOF (indicating client death) # Keep the connection open until socket closed or EOF (indicating client death)
try: try:
await reader.read() await reader.read()
finally: finally:
# Connection ended or it broke # Connection ended or it broke
logger.info(f"Lost client '{appid}'") logger.info(f"Lost client '{appid}'")
await self.deregister_client(appid) await self.deregister_client(appid)
async def handle_connection(self, reader, writer): async def handle_connection(self, reader, writer):
data = await reader.readline() data = await reader.readline()
@@ -103,17 +103,17 @@ class AppServer:
rqid = short_uuid() rqid = short_uuid()
with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"): set_logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}")
logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}") logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes: if route in self.routes:
# Execute route # Execute route
try: try:
await self.routes[route]((reader, writer), *args, **kwargs) await self.routes[route]((reader, writer), *args, **kwargs)
except Exception: except Exception:
logger.exception(f"AppServer recieved exception during route '{route}'") logger.exception(f"AppServer recieved exception during route '{route}'")
else: else:
logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.") logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.")
def peer_list(self): def peer_list(self):
return {appid: address for appid, (address, _) in self.clients.items()} return {appid: address for appid, (address, _) in self.clients.items()}
@@ -122,27 +122,27 @@ class AppServer:
self.clients.pop(appid, None) self.clients.pop(appid, None)
await self.broadcast('drop_peer', (), {'appid': appid}) await self.broadcast('drop_peer', (), {'appid': appid})
@log_wrap(action="broadcast")
async def broadcast(self, route, args, kwargs): async def broadcast(self, route, args, kwargs):
with logging_context(action="broadcast"): logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.")
logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.") payload = pickle.dumps((route, args, kwargs))
payload = pickle.dumps((route, args, kwargs)) if self.clients:
if self.clients: await asyncio.gather(
await asyncio.gather( *(self._send(appid, payload) for appid in self.clients),
*(self._send(appid, payload) for appid in self.clients), return_exceptions=True
return_exceptions=True )
)
async def message_client(self, appid, route, args, kwargs): async def message_client(self, appid, route, args, kwargs):
""" """
Send a message to client `appid` along `route` with given arguments. Send a message to client `appid` along `route` with given arguments.
""" """
with logging_context(action=f"MSG {appid}"): set_logging_context(action=f"MSG {appid}")
logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.") logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.")
if appid not in self.clients: if appid not in self.clients:
raise ValueError(f"Client '{appid}' is not connected.") raise ValueError(f"Client '{appid}' is not connected.")
payload = pickle.dumps((route, args, kwargs)) payload = pickle.dumps((route, args, kwargs))
return await self._send(appid, payload) return await self._send(appid, payload)
async def _send(self, appid, payload): async def _send(self, appid, payload):
""" """
@@ -162,18 +162,19 @@ class AppServer:
async def start(self, address): async def start(self, address):
log_app.set("APPSERVER") log_app.set("APPSERVER")
with logging_context(stack=["SERV"]): set_logging_context(stack=("SERV",))
server = await asyncio.start_server(self.handle_connection, **address) server = await asyncio.start_server(self.handle_connection, **address)
logger.info(f"Serving on {address}") logger.info(f"Serving on {address}")
async with server: async with server:
await server.serve_forever() await server.serve_forever()
async def start_server(): async def start_server():
setup_main_logger() setup_main_logger()
address = {'host': '127.0.0.1', 'port': '5000'} address = {'host': '127.0.0.1', 'port': '5000'}
server = AppServer() server = AppServer()
await server.start(address) task = asyncio.create_task(server.start(address))
await task
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -1,7 +1,7 @@
import sys import sys
import logging import logging
import asyncio import asyncio
from typing import List from typing import List, Optional
from logging.handlers import QueueListener, QueueHandler from logging.handlers import QueueListener, QueueHandler
import queue import queue
import multiprocessing import multiprocessing
@@ -24,40 +24,117 @@ log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[]) log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number)) log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
def set_logging_context(
context: Optional[str] = None,
action: Optional[str] = None,
stack: Optional[tuple[str, ...]] = None
):
"""
Statically set the logging context variables to the given values.
If `action` is given, pushes it onto the `log_action_stack`.
"""
if context is not None:
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
@contextmanager @contextmanager
def logging_context(context=None, action=None, stack=None): def logging_context(context=None, action=None, stack=None):
"""
Context manager for executing a block of code in a given logging context.
This context manager should only be used around synchronous code.
This is because async code *may* get cancelled or externally garbage collected,
in which case the finally block will be executed in the wrong context.
See https://github.com/python/cpython/issues/93740
This can be refactored nicely if this gets merged:
https://github.com/python/cpython/pull/99634
(It will not necessarily break on async code,
if the async code can be guaranteed to clean up in its own context.)
"""
if context is not None: if context is not None:
context_t = log_context.set(context) oldcontext = log_context.get()
if action is not None: log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get() astack = log_action_stack.get()
log_action_stack.set(astack + [action]) newstack = stack if stack is not None else astack
if stack is not None: if action is not None:
actions_t = log_action_stack.set(stack) newstack = (*newstack, action)
log_action_stack.set(newstack)
try: try:
yield yield
finally: finally:
if context is not None: if context is not None:
log_context.reset(context_t) log_context.set(oldcontext)
if stack is not None: if stack is not None or action is not None:
log_action_stack.reset(actions_t)
if action is not None:
log_action_stack.set(astack) log_action_stack.set(astack)
def log_wrap(**kwargs): def with_log_ctx(isolate=True, **kwargs):
"""
Execute a coroutine inside a given logging context.
If `isolate` is true, ensures that context does not leak
outside the coroutine.
If `isolate` is false, just statically set the context,
which will leak unless the coroutine is
called in an externally copied context.
"""
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapped(*w_args, **w_kwargs): async def wrapped(*w_args, **w_kwargs):
with logging_context(**kwargs): if isolate:
with logging_context(**kwargs):
# Task creation will synchronously copy the context
# This is gc safe
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
return await task
else:
# This will leak context changes
set_logging_context(**kwargs)
return await func(*w_args, **w_kwargs) return await func(*w_args, **w_kwargs)
return wrapped return wrapped
return decorator return decorator
# For backwards compatibility
log_wrap = with_log_ctx
def persist_task(task_collection: set):
"""
Coroutine decorator that ensures the coroutine is scheduled as a task
and added to the given task_collection for strong reference
when it is called.
This is just a hack to handle discord.py events potentially
being unexpectedly garbage collected.
Since this also implicitly schedules the coroutine as a task when it is called,
the coroutine will also be run inside an isolated context.
"""
def decorator(coro):
@wraps(coro)
async def wrapped(*w_args, **w_kwargs):
name = f"persisted-{coro.__name__}"
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
task_collection.add(task)
task.add_done_callback(lambda f: task_collection.discard(f))
await task
RESET_SEQ = "\033[0m" RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm" COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m" BOLD_SEQ = "\033[1m"
@@ -208,7 +285,7 @@ class WebHookHandler(logging.StreamHandler):
async def post(self, record): async def post(self, record):
log_context.set("Webhook Logger") log_context.set("Webhook Logger")
log_action_stack.set(["Logging"]) log_action_stack.set(("Logging",))
log_app.set(record.app) log_app.set(record.app)
try: try:

View File

@@ -717,7 +717,7 @@ class Timer:
f"Timer <tid: {channelid}> deleted. Reason given: {reason!r}" f"Timer <tid: {channelid}> deleted. Reason given: {reason!r}"
) )
@log_wrap(stack=['Timer Loop']) @log_wrap(action='Timer Loop')
async def _runloop(self): async def _runloop(self):
""" """
Main loop which controls the Main loop which controls the

View File

@@ -27,7 +27,7 @@ from data.columns import Integer, String, Timestamp, Bool
from meta import LionBot, LionCog, LionContext from meta import LionBot, LionCog, LionContext
from meta.app import shard_talk, appname_from_shard from meta.app import shard_talk, appname_from_shard
from meta.logger import log_wrap, logging_context from meta.logger import log_wrap, logging_context, set_logging_context
from babel import ctx_translator, ctx_locale from babel import ctx_translator, ctx_locale
@@ -280,6 +280,7 @@ class Reminders(LionCog):
f"Scheduled new reminders: {tuple(reminder.reminderid for reminder in reminders)}", f"Scheduled new reminders: {tuple(reminder.reminderid for reminder in reminders)}",
) )
@log_wrap(action="Send Reminder")
async def execute_reminder(self, reminderid): async def execute_reminder(self, reminderid):
""" """
Send the reminder with the given reminderid. Send the reminder with the given reminderid.
@@ -287,64 +288,65 @@ class Reminders(LionCog):
This should in general only be executed from the executor shard, This should in general only be executed from the executor shard,
through a ReminderMonitor instance. through a ReminderMonitor instance.
""" """
with logging_context(action='Send Reminder', context=f"rid: {reminderid}"): set_logging_context(context=f"rid: {reminderid}")
reminder = await self.data.Reminder.fetch(reminderid)
if reminder is None:
logger.warning(
f"Attempted to execute a reminder <rid: {reminderid}> that no longer exists!"
)
return
try: reminder = await self.data.Reminder.fetch(reminderid)
# Try and find the user if reminder is None:
userid = reminder.userid logger.warning(
if not (user := self.bot.get_user(userid)): f"Attempted to execute a reminder <rid: {reminderid}> that no longer exists!"
user = await self.bot.fetch_user(userid) )
return
# Set the locale variables try:
locale = await self.bot.get_cog('BabelCog').get_user_locale(userid) # Try and find the user
ctx_locale.set(locale) userid = reminder.userid
ctx_translator.set(self.bot.translator) if not (user := self.bot.get_user(userid)):
user = await self.bot.fetch_user(userid)
# Build the embed # Set the locale variables
embed = reminder.embed locale = await self.bot.get_cog('BabelCog').get_user_locale(userid)
ctx_locale.set(locale)
ctx_translator.set(self.bot.translator)
# Attempt to send to user # Build the embed
# TODO: Consider adding a View to this, for cancelling a repeated reminder or showing reminders embed = reminder.embed
await user.send(embed=embed)
# Update the data as required # Attempt to send to user
if reminder.interval: # TODO: Consider adding a View to this, for cancelling a repeated reminder or showing reminders
now = utc_now() await user.send(embed=embed)
# Use original reminder time to calculate repeat, avoiding drift
next_time = reminder.remind_at + dt.timedelta(seconds=reminder.interval) # Update the data as required
# Skip any expired repeats, to avoid spamming requests after downtime if reminder.interval:
# TODO: Is this actually dst safe? now = utc_now()
while next_time.timestamp() <= now.timestamp(): # Use original reminder time to calculate repeat, avoiding drift
next_time = next_time + dt.timedelta(seconds=reminder.interval) next_time = reminder.remind_at + dt.timedelta(seconds=reminder.interval)
await reminder.update(remind_at=next_time) # Skip any expired repeats, to avoid spamming requests after downtime
self.monitor.schedule_task(reminder.reminderid, reminder.timestamp) # TODO: Is this actually dst safe?
logger.debug( while next_time.timestamp() <= now.timestamp():
f"Executed reminder <rid: {reminder.reminderid}> and scheduled repeat at {next_time}." next_time = next_time + dt.timedelta(seconds=reminder.interval)
) await reminder.update(remind_at=next_time)
else: self.monitor.schedule_task(reminder.reminderid, reminder.timestamp)
await reminder.delete()
logger.debug(
f"Executed reminder <rid: {reminder.reminderid}>."
)
except discord.HTTPException as e:
await reminder.update(failed=True)
logger.debug( logger.debug(
f"Reminder <rid: {reminder.reminderid}> could not be sent: {e.text}", f"Executed reminder <rid: {reminder.reminderid}> and scheduled repeat at {next_time}."
) )
except Exception: else:
await reminder.update(failed=True) await reminder.delete()
logger.exception( logger.debug(
f"Reminder <rid: {reminder.reminderid}> failed for an unknown reason!" f"Executed reminder <rid: {reminder.reminderid}>."
) )
finally: except discord.HTTPException as e:
# Dispatch for analytics await reminder.update(failed=True)
self.bot.dispatch('reminder_sent', reminder) logger.debug(
f"Reminder <rid: {reminder.reminderid}> could not be sent: {e.text}",
)
except Exception:
await reminder.update(failed=True)
logger.exception(
f"Reminder <rid: {reminder.reminderid}> failed for an unknown reason!"
)
finally:
# Dispatch for analytics
self.bot.dispatch('reminder_sent', reminder)
@cmds.hybrid_group( @cmds.hybrid_group(
name=_p('cmd:reminders', "reminders") name=_p('cmd:reminders', "reminders")

View File

@@ -13,7 +13,7 @@ from discord.ui.button import button
from discord.ui.text_input import TextStyle, TextInput from discord.ui.text_input import TextStyle, TextInput
from meta import LionCog, LionBot, LionContext from meta import LionCog, LionBot, LionContext
from meta.logger import logging_context, log_wrap from meta.logger import logging_context, log_wrap, set_logging_context
from meta.errors import UserInputError from meta.errors import UserInputError
from meta.app import shard_talk from meta.app import shard_talk
@@ -73,8 +73,7 @@ class Blacklists(LionCog):
f"Loaded {len(self.guild_blacklist)} blacklisted guilds." f"Loaded {len(self.guild_blacklist)} blacklisted guilds."
) )
if self.bot.is_ready(): if self.bot.is_ready():
with logging_context(action="Guild Blacklist"): await self.leave_blacklisted_guilds()
await self.leave_blacklisted_guilds()
@LionCog.listener('on_ready') @LionCog.listener('on_ready')
@log_wrap(action="Guild Blacklist") @log_wrap(action="Guild Blacklist")
@@ -84,8 +83,9 @@ class Blacklists(LionCog):
guild for guild in self.bot.guilds guild for guild in self.bot.guilds
if guild.id in self.guild_blacklist if guild.id in self.guild_blacklist
] ]
if to_leave:
asyncio.gather(*(guild.leave() for guild in to_leave)) tasks = [asyncio.create_task(guild.leave()) for guild in to_leave]
await asyncio.gather(*tasks)
logger.info( logger.info(
"Left {} blacklisted guilds.".format(len(to_leave)), "Left {} blacklisted guilds.".format(len(to_leave)),
@@ -95,12 +95,12 @@ class Blacklists(LionCog):
@log_wrap(action="Check Guild Blacklist") @log_wrap(action="Check Guild Blacklist")
async def check_guild_blacklist(self, guild): async def check_guild_blacklist(self, guild):
"""Check if the given guild is in the blacklist, and leave if so.""" """Check if the given guild is in the blacklist, and leave if so."""
with logging_context(context=f"gid: {guild.id}"): if guild.id in self.guild_blacklist:
if guild.id in self.guild_blacklist: set_logging_context(context=f"gid: {guild.id}")
await guild.leave() await guild.leave()
logger.info( logger.info(
"Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id) "Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id)
) )
async def bot_check_once(self, ctx: LionContext) -> bool: # type:ignore async def bot_check_once(self, ctx: LionContext) -> bool: # type:ignore
if ctx.author.id in self.user_blacklist: if ctx.author.id in self.user_blacklist:

View File

@@ -19,7 +19,7 @@ from discord.ui import TextInput, View
from discord.ui.button import button from discord.ui.button import button
import discord.app_commands as appcmd import discord.app_commands as appcmd
from meta.logger import logging_context from meta.logger import logging_context, log_wrap
from meta.app import shard_talk from meta.app import shard_talk
from meta import conf from meta import conf
from meta.context import context, ctx_bot from meta.context import context, ctx_bot
@@ -185,54 +185,54 @@ def mk_print(fp: io.StringIO) -> Callable[..., None]:
return _print return _print
@log_wrap(action="Code Exec")
async def _async(to_eval: str, style='exec'): async def _async(to_eval: str, style='exec'):
with logging_context(action="Code Exec"): newline = '\n' * ('\n' in to_eval)
newline = '\n' * ('\n' in to_eval) logger.info(
logger.info( f"Exec code with {style}: {newline}{to_eval}"
f"Exec code with {style}: {newline}{to_eval}" )
output = io.StringIO()
_print = mk_print(output)
scope: dict[str, Any] = dict(sys.modules)
scope['__builtins__'] = builtins
scope.update(builtins.__dict__)
scope['ctx'] = ctx = context.get()
scope['bot'] = ctx_bot.get()
scope['print'] = _print # type: ignore
try:
if ctx and ctx.message:
source_str = f"<msg: {ctx.message.id}>"
elif ctx and ctx.interaction:
source_str = f"<iid: {ctx.interaction.id}>"
else:
source_str = "Unknown async"
code = compile(
to_eval,
source_str,
style,
ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
) )
func = types.FunctionType(code, scope)
output = io.StringIO() ret = func()
_print = mk_print(output) if inspect.iscoroutine(ret):
ret = await ret
if ret is not None:
_print(repr(ret))
except Exception:
_, exc, tb = sys.exc_info()
_print("".join(traceback.format_tb(tb)))
_print(f"{type(exc).__name__}: {exc}")
scope: dict[str, Any] = dict(sys.modules) result = output.getvalue().strip()
scope['__builtins__'] = builtins newline = '\n' * ('\n' in result)
scope.update(builtins.__dict__) logger.info(
scope['ctx'] = ctx = context.get() f"Exec complete, output: {newline}{result}"
scope['bot'] = ctx_bot.get() )
scope['print'] = _print # type: ignore
try:
if ctx and ctx.message:
source_str = f"<msg: {ctx.message.id}>"
elif ctx and ctx.interaction:
source_str = f"<iid: {ctx.interaction.id}>"
else:
source_str = "Unknown async"
code = compile(
to_eval,
source_str,
style,
ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
)
func = types.FunctionType(code, scope)
ret = func()
if inspect.iscoroutine(ret):
ret = await ret
if ret is not None:
_print(repr(ret))
except Exception:
_, exc, tb = sys.exc_info()
_print("".join(traceback.format_tb(tb)))
_print(f"{type(exc).__name__}: {exc}")
result = output.getvalue().strip()
newline = '\n' * ('\n' in result)
logger.info(
f"Exec complete, output: {newline}{result}"
)
return result return result