From 3445bdff3eadbe6cd19f8d1c3832c16ca9b1448e Mon Sep 17 00:00:00 2001 From: Conatum Date: Mon, 7 Nov 2022 16:08:27 +0200 Subject: [PATCH] rewrite: Add base `context` ContextVar. --- bot/main.py | 9 ++++- bot/meta/context.py | 51 +++++++++++++++++++++++++++ bot/meta/logger.py | 12 ++++--- bot/modules/bot_admin/exec_cog.py | 58 +++++++++++++++++++++++++++---- 4 files changed, 118 insertions(+), 12 deletions(-) create mode 100644 bot/meta/context.py diff --git a/bot/main.py b/bot/main.py index 309cb57e..d6279577 100644 --- a/bot/main.py +++ b/bot/main.py @@ -6,6 +6,7 @@ from discord.ext import commands from meta import LionBot, conf, sharding, appname, shard_talk from meta.logger import log_context, log_action +from meta.context import context from data import Database @@ -57,8 +58,14 @@ async def main(): initial_extensions=['modules'], web_client=None, app_ipc=shard_talk, - testing_guilds=[889875661848723456] + testing_guilds=[889875661848723456], + shard_id=sharding.shard_number, + shard_count=sharding.shard_count ) as lionbot: + context.get().bot = lionbot + @lionbot.before_invoke + async def before_invoke(ctx): + print(ctx) log_action.set("Launching") await lionbot.start(conf.bot['TOKEN']) diff --git a/bot/meta/context.py b/bot/meta/context.py new file mode 100644 index 00000000..4546a125 --- /dev/null +++ b/bot/meta/context.py @@ -0,0 +1,51 @@ +from contextvars import ContextVar + + +class Context: + __slots__ = ( + 'bot', + 'interaction', 'message', + 'guild', 'channel', 'author', 'user' + ) + + def __init__(self, **kwargs): + self.bot = kwargs.pop('bot', None) + + self.interaction = interaction = kwargs.pop('interaction', None) + self.message = message = kwargs.pop('message', interaction.message if interaction is not None else None) + + guild = kwargs.pop('guild', None) + channel = kwargs.pop('channel', None) + author = kwargs.pop('author', None) + + if message is not None: + guild = guild or message.guild + channel = channel or message.channel + author = author or message.author + elif interaction is not None: + guild = guild or interaction.guild + channel = channel or interaction.channel + author = author or interaction.user + + self.guild = guild + self.channel = channel + self.author = self.user = author + + def log_string(self): + """Markdown formatted summary for live logging.""" + parts = [] + if self.interaction is not None: + parts.append(f"") + if self.message is not None: + parts.append(f"") + if self.author is not None: + parts.append(f"") + if self.channel is not None: + parts.append(f"") + if self.guild is not None: + parts.append(f"") + + return " ".join(parts) + + +context = ContextVar('context', default=Context()) diff --git a/bot/meta/logger.py b/bot/meta/logger.py index fbd36589..e9dc386c 100644 --- a/bot/meta/logger.py +++ b/bot/meta/logger.py @@ -7,12 +7,13 @@ from contextlib import contextmanager from io import StringIO from contextvars import ContextVar -from discord import AllowedMentions, Webhook, File +from discord import Webhook, File import aiohttp from .config import conf from . import sharding -from utils.lib import split_text, utc_now +from .context import context +from utils.lib import utc_now log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') @@ -41,6 +42,7 @@ BOLD_SEQ = "\033[1m" "]]]" BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) + def colour_escape(fmt: str) -> str: cmap = { '%(black)': COLOR_SEQ % BLACK, @@ -94,6 +96,7 @@ class ContextInjection(logging.Filter): if not hasattr(record, 'action'): record.action = log_action.get() record.app = log_app.get() + record.ctx = context.get().log_string() return True @@ -149,8 +152,9 @@ class WebHookHandler(logging.StreamHandler): async def post(self, record): try: timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S") - header = f"[{timestamp}][{record.levelname}][{record.app}][{record.action}][{record.context}]\n" - message = header+record.msg + header = f"[{timestamp}][{record.levelname}][{record.app}][{record.action}][{record.context}]" + context = f"\n# Context: {record.ctx}" if record.ctx else "" + message = f"{header}\n{record.msg}{context}" if len(message) > 1900: as_file = True diff --git a/bot/modules/bot_admin/exec_cog.py b/bot/modules/bot_admin/exec_cog.py index 23b2058d..7912861b 100644 --- a/bot/modules/bot_admin/exec_cog.py +++ b/bot/modules/bot_admin/exec_cog.py @@ -14,9 +14,15 @@ from enum import Enum import discord from discord.ext import commands +from discord.app_commands import command from discord.ui import Modal, TextInput, View from discord.ui.button import button +from meta.logger import logging_context +from meta.app import shard_talk +from meta.context import context +from meta.context import Context + logger = logging.getLogger(__name__) @@ -145,7 +151,7 @@ class ExecUI(View): async def compile(self): # Call _async - result = await _async(self.ctx, self.code, style=self.style.value) + result = await _async(self.code, style=self.style.value) # Display output await self.show_output(result) @@ -180,19 +186,31 @@ def mk_print(fp: io.StringIO) -> Callable[..., None]: return _print -async def _async(ctx: commands.Context, to_eval, style='exec'): +async def _async(to_eval, style='exec'): output = io.StringIO() _print = mk_print(output) - scope = dict(sys.modules) + scope: dict[str, Any] = dict(sys.modules) scope['__builtins__'] = builtins scope.update(builtins.__dict__) - scope['ctx'] = ctx + scope['ctx'] = ctx = context.get() scope['bot'] = ctx.bot scope['print'] = _print # type: ignore try: - code = compile(to_eval, f"", style, ast.PyCF_ALLOW_TOP_LEVEL_AWAIT) + if ctx.message: + source_str = f"" + elif ctx.interaction: + source_str = f"" + else: + source_str = "Unknown async" + + code = compile( + to_eval, + source_str, + style, + ast.PyCF_ALLOW_TOP_LEVEL_AWAIT + ) func = types.FunctionType(code, scope) ret = func() @@ -203,7 +221,7 @@ async def _async(ctx: commands.Context, to_eval, style='exec'): except Exception: _, exc, tb = sys.exc_info() _print("".join(traceback.format_tb(tb))) - _print(repr(exc)) + _print(f"{type(exc).__name__}: {exc}") result = output.getvalue() logger.info( @@ -219,8 +237,34 @@ class Exec(commands.Cog): @commands.hybrid_command(name='async') async def async_cmd(self, ctx, *, string: str = None): - await ExecUI(ctx, string, ExecStyle.EXEC).run() + context.set(Context(bot=self.bot, interaction=ctx.interaction)) + with logging_context(context=f"mid:{ctx.message.id}", action="CMD ASYNC"): + logger.info("Running command") + await ExecUI(ctx, string, ExecStyle.EXEC).run() @commands.hybrid_command(name='eval') async def eval_cmd(self, ctx, *, string: str): await ExecUI(ctx, string, ExecStyle.EVAL).run() + + @command(name='test') + async def test_cmd(self, interaction: discord.Interaction, *, channel: discord.TextChannel = None): + if channel is None: + await interaction.response.send_message("Setting widget") + else: + await interaction.response.send_message( + f"Set channel to {channel} and then setting widget or update existing widget." + ) + + @command(name='shardeval') + async def shardeval_cmd(self, interaction: discord.Interaction, string: str): + await interaction.response.defer(thinking=True) + results = await shard_talk.requestall(exec_route(string)) + blocks = [] + for appid, result in results.items(): + blocks.append(f"```md\n[{appid}]\n{result}```") + await interaction.edit_original_response(content='\n'.join(blocks)) + + +@shard_talk.register_route('exec') +async def exec_route(string): + return await _async(string)