rewrite: Add base context ContextVar.

This commit is contained in:
2022-11-07 16:08:27 +02:00
parent 872e5fd71f
commit 3445bdff3e
4 changed files with 118 additions and 12 deletions

View File

@@ -6,6 +6,7 @@ from discord.ext import commands
from meta import LionBot, conf, sharding, appname, shard_talk from meta import LionBot, conf, sharding, appname, shard_talk
from meta.logger import log_context, log_action from meta.logger import log_context, log_action
from meta.context import context
from data import Database from data import Database
@@ -57,8 +58,14 @@ async def main():
initial_extensions=['modules'], initial_extensions=['modules'],
web_client=None, web_client=None,
app_ipc=shard_talk, app_ipc=shard_talk,
testing_guilds=[889875661848723456] testing_guilds=[889875661848723456],
shard_id=sharding.shard_number,
shard_count=sharding.shard_count
) as lionbot: ) as lionbot:
context.get().bot = lionbot
@lionbot.before_invoke
async def before_invoke(ctx):
print(ctx)
log_action.set("Launching") log_action.set("Launching")
await lionbot.start(conf.bot['TOKEN']) await lionbot.start(conf.bot['TOKEN'])

51
bot/meta/context.py Normal file
View File

@@ -0,0 +1,51 @@
from contextvars import ContextVar
class Context:
__slots__ = (
'bot',
'interaction', 'message',
'guild', 'channel', 'author', 'user'
)
def __init__(self, **kwargs):
self.bot = kwargs.pop('bot', None)
self.interaction = interaction = kwargs.pop('interaction', None)
self.message = message = kwargs.pop('message', interaction.message if interaction is not None else None)
guild = kwargs.pop('guild', None)
channel = kwargs.pop('channel', None)
author = kwargs.pop('author', None)
if message is not None:
guild = guild or message.guild
channel = channel or message.channel
author = author or message.author
elif interaction is not None:
guild = guild or interaction.guild
channel = channel or interaction.channel
author = author or interaction.user
self.guild = guild
self.channel = channel
self.author = self.user = author
def log_string(self):
"""Markdown formatted summary for live logging."""
parts = []
if self.interaction is not None:
parts.append(f"<int id={self.interaction.id} type={self.interaction.type.name}>")
if self.message is not None:
parts.append(f"<msg id={self.message.id}>")
if self.author is not None:
parts.append(f"<user id={self.author.id} name='{self.author.name}'>")
if self.channel is not None:
parts.append(f"<chan id={self.channel.id} name='{self.channel.name}'>")
if self.guild is not None:
parts.append(f"<guild id={self.guild.id} name='{self.guild.name}'>")
return " ".join(parts)
context = ContextVar('context', default=Context())

View File

@@ -7,12 +7,13 @@ from contextlib import contextmanager
from io import StringIO from io import StringIO
from contextvars import ContextVar from contextvars import ContextVar
from discord import AllowedMentions, Webhook, File from discord import Webhook, File
import aiohttp import aiohttp
from .config import conf from .config import conf
from . import sharding from . import sharding
from utils.lib import split_text, utc_now from .context import context
from utils.lib import utc_now
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT') log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
@@ -41,6 +42,7 @@ BOLD_SEQ = "\033[1m"
"]]]" "]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str: def colour_escape(fmt: str) -> str:
cmap = { cmap = {
'%(black)': COLOR_SEQ % BLACK, '%(black)': COLOR_SEQ % BLACK,
@@ -94,6 +96,7 @@ class ContextInjection(logging.Filter):
if not hasattr(record, 'action'): if not hasattr(record, 'action'):
record.action = log_action.get() record.action = log_action.get()
record.app = log_app.get() record.app = log_app.get()
record.ctx = context.get().log_string()
return True return True
@@ -149,8 +152,9 @@ class WebHookHandler(logging.StreamHandler):
async def post(self, record): async def post(self, record):
try: try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S") timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{timestamp}][{record.levelname}][{record.app}][{record.action}][{record.context}]\n" header = f"[{timestamp}][{record.levelname}][{record.app}][{record.action}][{record.context}]"
message = header+record.msg context = f"\n# Context: {record.ctx}" if record.ctx else ""
message = f"{header}\n{record.msg}{context}"
if len(message) > 1900: if len(message) > 1900:
as_file = True as_file = True

View File

@@ -14,9 +14,15 @@ from enum import Enum
import discord import discord
from discord.ext import commands from discord.ext import commands
from discord.app_commands import command
from discord.ui import Modal, TextInput, View from discord.ui import Modal, TextInput, View
from discord.ui.button import button from discord.ui.button import button
from meta.logger import logging_context
from meta.app import shard_talk
from meta.context import context
from meta.context import Context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -145,7 +151,7 @@ class ExecUI(View):
async def compile(self): async def compile(self):
# Call _async # Call _async
result = await _async(self.ctx, self.code, style=self.style.value) result = await _async(self.code, style=self.style.value)
# Display output # Display output
await self.show_output(result) await self.show_output(result)
@@ -180,19 +186,31 @@ def mk_print(fp: io.StringIO) -> Callable[..., None]:
return _print return _print
async def _async(ctx: commands.Context, to_eval, style='exec'): async def _async(to_eval, style='exec'):
output = io.StringIO() output = io.StringIO()
_print = mk_print(output) _print = mk_print(output)
scope = dict(sys.modules) scope: dict[str, Any] = dict(sys.modules)
scope['__builtins__'] = builtins scope['__builtins__'] = builtins
scope.update(builtins.__dict__) scope.update(builtins.__dict__)
scope['ctx'] = ctx scope['ctx'] = ctx = context.get()
scope['bot'] = ctx.bot scope['bot'] = ctx.bot
scope['print'] = _print # type: ignore scope['print'] = _print # type: ignore
try: try:
code = compile(to_eval, f"<msg: {ctx.message.id}>", style, ast.PyCF_ALLOW_TOP_LEVEL_AWAIT) if ctx.message:
source_str = f"<msg: {ctx.message.id}>"
elif ctx.interaction:
source_str = f"<iid: {ctx.interaction.id}>"
else:
source_str = "Unknown async"
code = compile(
to_eval,
source_str,
style,
ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
)
func = types.FunctionType(code, scope) func = types.FunctionType(code, scope)
ret = func() ret = func()
@@ -203,7 +221,7 @@ async def _async(ctx: commands.Context, to_eval, style='exec'):
except Exception: except Exception:
_, exc, tb = sys.exc_info() _, exc, tb = sys.exc_info()
_print("".join(traceback.format_tb(tb))) _print("".join(traceback.format_tb(tb)))
_print(repr(exc)) _print(f"{type(exc).__name__}: {exc}")
result = output.getvalue() result = output.getvalue()
logger.info( logger.info(
@@ -219,8 +237,34 @@ class Exec(commands.Cog):
@commands.hybrid_command(name='async') @commands.hybrid_command(name='async')
async def async_cmd(self, ctx, *, string: str = None): async def async_cmd(self, ctx, *, string: str = None):
await ExecUI(ctx, string, ExecStyle.EXEC).run() context.set(Context(bot=self.bot, interaction=ctx.interaction))
with logging_context(context=f"mid:{ctx.message.id}", action="CMD ASYNC"):
logger.info("Running command")
await ExecUI(ctx, string, ExecStyle.EXEC).run()
@commands.hybrid_command(name='eval') @commands.hybrid_command(name='eval')
async def eval_cmd(self, ctx, *, string: str): async def eval_cmd(self, ctx, *, string: str):
await ExecUI(ctx, string, ExecStyle.EVAL).run() await ExecUI(ctx, string, ExecStyle.EVAL).run()
@command(name='test')
async def test_cmd(self, interaction: discord.Interaction, *, channel: discord.TextChannel = None):
if channel is None:
await interaction.response.send_message("Setting widget")
else:
await interaction.response.send_message(
f"Set channel to {channel} and then setting widget or update existing widget."
)
@command(name='shardeval')
async def shardeval_cmd(self, interaction: discord.Interaction, string: str):
await interaction.response.defer(thinking=True)
results = await shard_talk.requestall(exec_route(string))
blocks = []
for appid, result in results.items():
blocks.append(f"```md\n[{appid}]\n{result}```")
await interaction.edit_original_response(content='\n'.join(blocks))
@shard_talk.register_route('exec')
async def exec_route(string):
return await _async(string)