Initial Creation from bot template.
This commit is contained in:
344
src/meta/LionBot.py
Normal file
344
src/meta/LionBot.py
Normal file
@@ -0,0 +1,344 @@
|
||||
from typing import List, Literal, LiteralString, Optional, TYPE_CHECKING, overload
|
||||
import logging
|
||||
import asyncio
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import discord
|
||||
from discord.utils import MISSING
|
||||
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
||||
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
||||
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from data import Database
|
||||
from utils.lib import tabulate
|
||||
|
||||
from .config import Conf
|
||||
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
|
||||
from .context import context
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
from .errors import HandledException, SafeCancellation
|
||||
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.cog import CoreCog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LionBot(Bot):
|
||||
def __init__(
|
||||
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
||||
initial_extensions: List[str], web_client: ClientSession,
|
||||
testing_guilds: List[int] = [], **kwargs
|
||||
):
|
||||
kwargs.setdefault('tree_cls', LionTree)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.web_client = web_client
|
||||
self.testing_guilds = testing_guilds
|
||||
self.initial_extensions = initial_extensions
|
||||
self.db = db
|
||||
self.appname = appname
|
||||
self.shardname = shardname
|
||||
# self.appdata = appdata
|
||||
self.config = config
|
||||
|
||||
self.system_monitor = SystemMonitor()
|
||||
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
||||
self.system_monitor.add_component(self.monitor)
|
||||
|
||||
self._locks = WeakValueDictionary()
|
||||
self._running_events = set()
|
||||
|
||||
@property
|
||||
def core(self):
|
||||
return self.get_cog('CoreCog')
|
||||
|
||||
async def _monitor_status(self):
|
||||
if self.is_closed():
|
||||
level = StatusLevel.ERRORED
|
||||
info = "(ERROR) Websocket is closed"
|
||||
data = {}
|
||||
elif self.is_ws_ratelimited():
|
||||
level = StatusLevel.WAITING
|
||||
info = "(WAITING) Websocket is ratelimited"
|
||||
data = {}
|
||||
elif not self.is_ready():
|
||||
level = StatusLevel.STARTING
|
||||
info = "(STARTING) Not yet ready"
|
||||
data = {}
|
||||
else:
|
||||
level = StatusLevel.OKAY
|
||||
info = (
|
||||
"(OK) "
|
||||
"Logged in with {guild_count} guilds, "
|
||||
", websocket latency {latency}, and {events} running events."
|
||||
)
|
||||
data = {
|
||||
'guild_count': len(self.guilds),
|
||||
'latency': self.latency,
|
||||
'events': len(self._running_events),
|
||||
}
|
||||
return ComponentStatus(level, info, info, data)
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
log_context.set(f"APP: {self.application_id}")
|
||||
|
||||
for extension in self.initial_extensions:
|
||||
await self.load_extension(extension)
|
||||
|
||||
for guildid in self.testing_guilds:
|
||||
guild = discord.Object(guildid)
|
||||
if not self.shard_count or (self.shard_id == ((guildid >> 22) % self.shard_count)):
|
||||
self.tree.copy_global_to(guild=guild)
|
||||
await self.tree.sync(guild=guild)
|
||||
|
||||
# To make the type checker happy about fetching cogs by name
|
||||
# TODO: Move this to stubs at some point
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
...
|
||||
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
return super().get_cog(name)
|
||||
|
||||
async def add_cog(self, cog: Cog, **kwargs):
|
||||
sup = super()
|
||||
@log_wrap(action=f"Attach {cog.__cog_name__}")
|
||||
async def wrapper():
|
||||
logger.info(f"Attaching Cog {cog.__cog_name__}")
|
||||
await sup.add_cog(cog, **kwargs)
|
||||
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
|
||||
await wrapper()
|
||||
|
||||
async def load_extension(self, name, *, package=None, **kwargs):
|
||||
sup = super()
|
||||
@log_wrap(action=f"Load {name.strip('.')}")
|
||||
async def wrapper():
|
||||
logger.info(f"Loading extension {name} in package {package}.")
|
||||
await sup.load_extension(name, package=package, **kwargs)
|
||||
logger.debug(f"Loaded extension {name} in package {package}.")
|
||||
await wrapper()
|
||||
|
||||
async def start(self, token: str, *, reconnect: bool = True):
|
||||
with logging_context(action="Login"):
|
||||
start_task = asyncio.create_task(self.login(token))
|
||||
await start_task
|
||||
|
||||
with logging_context(stack=("Running",)):
|
||||
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
|
||||
await run_task
|
||||
|
||||
def dispatch(self, event_name: str, *args, **kwargs):
|
||||
with logging_context(action=f"Dispatch {event_name}"):
|
||||
super().dispatch(event_name, *args, **kwargs)
|
||||
|
||||
def _schedule_event(self, coro, event_name, *args, **kwargs):
|
||||
"""
|
||||
Extends client._schedule_event to keep a persistent
|
||||
background task store.
|
||||
"""
|
||||
task = super()._schedule_event(coro, event_name, *args, **kwargs)
|
||||
self._running_events.add(task)
|
||||
task.add_done_callback(lambda fut: self._running_events.discard(fut))
|
||||
|
||||
def idlock(self, snowflakeid):
|
||||
lock = self._locks.get(snowflakeid, None)
|
||||
if lock is None:
|
||||
lock = self._locks[snowflakeid] = asyncio.Lock()
|
||||
return lock
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info(
|
||||
f"Logged in as {self.application.name}\n"
|
||||
f"Application id {self.application.id}\n"
|
||||
f"Shard Talk identifier {self.shardname}\n"
|
||||
"------------------------------\n"
|
||||
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
|
||||
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
|
||||
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
|
||||
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
|
||||
"------------------------------\n"
|
||||
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
|
||||
"Ready to take commands!\n",
|
||||
extra={'action': 'Ready'}
|
||||
)
|
||||
|
||||
async def get_context(self, origin, /, *, cls=MISSING):
|
||||
if cls is MISSING:
|
||||
cls = LionContext
|
||||
ctx = await super().get_context(origin, cls=cls)
|
||||
context.set(ctx)
|
||||
return ctx
|
||||
|
||||
async def on_command(self, ctx: LionContext):
|
||||
logger.info(
|
||||
f"Executing command '{ctx.command.qualified_name}' "
|
||||
f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
|
||||
f"with interaction: {ctx.interaction.data if ctx.interaction else None}",
|
||||
extra={'with_ctx': True}
|
||||
)
|
||||
|
||||
async def on_command_error(self, ctx, exception):
|
||||
# TODO: Some of these could have more user-feedback
|
||||
logger.debug(f"Handling command error for {ctx}: {exception}")
|
||||
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
|
||||
cmd_str = ctx.command.app_command.to_dict()
|
||||
else:
|
||||
cmd_str = str(ctx.command)
|
||||
try:
|
||||
raise exception
|
||||
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
|
||||
try:
|
||||
if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
|
||||
original = exception.original.original
|
||||
raise original
|
||||
else:
|
||||
original = exception.original
|
||||
raise original
|
||||
except HandledException:
|
||||
pass
|
||||
except TransformerError as e:
|
||||
msg = str(e)
|
||||
if msg:
|
||||
try:
|
||||
await ctx.error_reply(msg)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug(
|
||||
f"Caught a transformer error: {repr(e)}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except SafeCancellation:
|
||||
if original.msg:
|
||||
try:
|
||||
await ctx.error_reply(original.msg)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug(
|
||||
f"Caught a safe cancellation: {original.details}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except discord.Forbidden:
|
||||
# Unknown uncaught Forbidden
|
||||
try:
|
||||
# Attempt a general error reply
|
||||
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
|
||||
except Exception:
|
||||
# We can't send anything at all. Exit quietly, but log.
|
||||
logger.warning(
|
||||
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
|
||||
exc_info=True,
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except discord.HTTPException:
|
||||
logger.error(
|
||||
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
|
||||
exc_info=True,
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
|
||||
error_embed = discord.Embed(
|
||||
title="Something went wrong!",
|
||||
colour=discord.Colour.dark_red()
|
||||
)
|
||||
error_embed.description = (
|
||||
"An unexpected error occurred while processing your command!\n"
|
||||
"Our development team has been notified, and the issue will be addressed soon.\n"
|
||||
"If the error persists, or you have any questions, please contact our [support team]({link}) "
|
||||
"and give them the extra details below."
|
||||
).format(link=self.config.bot.support_guild)
|
||||
details = {}
|
||||
details['error'] = f"`{repr(e)}`"
|
||||
if ctx.interaction:
|
||||
details['interactionid'] = f"`{ctx.interaction.id}`"
|
||||
if ctx.command:
|
||||
details['cmd'] = f"`{ctx.command.qualified_name}`"
|
||||
if ctx.author:
|
||||
details['author'] = f"`{ctx.author.id}` -- `{ctx.author}`"
|
||||
if ctx.guild:
|
||||
details['guild'] = f"`{ctx.guild.id}` -- `{ctx.guild.name}`"
|
||||
details['my_guild_perms'] = f"`{ctx.guild.me.guild_permissions.value}`"
|
||||
if ctx.author:
|
||||
ownerstr = ' (owner)' if ctx.author.id == ctx.guild.owner_id else ''
|
||||
details['author_guild_perms'] = f"`{ctx.author.guild_permissions.value}{ownerstr}`"
|
||||
if ctx.channel.type is discord.enums.ChannelType.private:
|
||||
details['channel'] = "`Direct Message`"
|
||||
elif ctx.channel:
|
||||
details['channel'] = f"`{ctx.channel.id}` -- `{ctx.channel.name}`"
|
||||
details['my_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.guild.me).value}`"
|
||||
if ctx.author:
|
||||
details['author_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.author).value}`"
|
||||
details['shard'] = f"`{self.shardname}`"
|
||||
details['log_stack'] = f"`{log_action_stack.get()}`"
|
||||
|
||||
table = '\n'.join(tabulate(*details.items()))
|
||||
error_embed.add_field(name='Details', value=table)
|
||||
|
||||
try:
|
||||
await ctx.error_reply(embed=error_embed)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
finally:
|
||||
exception.original = HandledException(exception.original)
|
||||
except CheckFailure as e:
|
||||
logger.debug(
|
||||
f"Command failed check: {e}: {e.args}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
try:
|
||||
await ctx.error_reply(str(e))
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
except Exception:
|
||||
# Completely unknown exception outside of command invocation!
|
||||
# Something is very wrong here, don't attempt user interaction.
|
||||
logger.exception(
|
||||
f"Caught an unknown top-level exception while executing: {cmd_str}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
|
||||
def add_command(self, command):
|
||||
if not hasattr(command, '_placeholder_group_'):
|
||||
super().add_command(command)
|
||||
|
||||
def request_chunking_for(self, guild):
|
||||
if not guild.chunked:
|
||||
return asyncio.create_task(
|
||||
self._connection.chunk_guild(guild, wait=False, cache=True),
|
||||
name=f"Background chunkreq for {guild.id}"
|
||||
)
|
||||
|
||||
async def on_interaction(self, interaction: discord.Interaction):
|
||||
"""
|
||||
Adds the interaction author to guild cache if appropriate.
|
||||
|
||||
This gets run a little bit late, so it is possible the interaction gets handled
|
||||
without the author being in case.
|
||||
"""
|
||||
guild = interaction.guild
|
||||
user = interaction.user
|
||||
if guild is not None and user is not None and isinstance(user, discord.Member):
|
||||
if not guild.get_member(user.id):
|
||||
guild._add_member(user)
|
||||
if guild is not None and not guild.chunked:
|
||||
# Getting an interaction in the guild is a good enough reason to request chunking
|
||||
logger.info(
|
||||
f"Unchunked guild <gid: {guild.id}> requesting chunking after interaction."
|
||||
)
|
||||
self.request_chunking_for(guild)
|
||||
58
src/meta/LionCog.py
Normal file
58
src/meta/LionCog.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import Any
|
||||
|
||||
from discord.ext.commands import Cog
|
||||
from discord.ext import commands as cmds
|
||||
|
||||
|
||||
class LionCog(Cog):
|
||||
# A set of other cogs that this cog depends on
|
||||
depends_on: set['LionCog'] = set()
|
||||
_placeholder_groups_: set[str]
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
cls._placeholder_groups_ = set()
|
||||
|
||||
for base in reversed(cls.__mro__):
|
||||
for elem, value in base.__dict__.items():
|
||||
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
|
||||
cls._placeholder_groups_.add(value.name)
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any):
|
||||
# Patch to ensure no placeholder groups are in the command list
|
||||
self = super().__new__(cls)
|
||||
self.__cog_commands__ = [
|
||||
command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_
|
||||
]
|
||||
return self
|
||||
|
||||
async def _inject(self, bot, *args, **kwargs):
|
||||
if self.depends_on:
|
||||
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
|
||||
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
|
||||
|
||||
return await super()._inject(bot, *args, *kwargs)
|
||||
|
||||
@classmethod
|
||||
def placeholder_group(cls, group: cmds.HybridGroup):
|
||||
group._placeholder_group_ = True
|
||||
return group
|
||||
|
||||
def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup):
|
||||
"""
|
||||
Crossload a placeholder group's commands into the target group
|
||||
"""
|
||||
if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup):
|
||||
raise ValueError("Placeholder and target groups my be HypridGroups.")
|
||||
if placeholder_group.name not in self._placeholder_groups_:
|
||||
raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.")
|
||||
if target_group.app_command is None:
|
||||
raise ValueError("Target group has no app_command to crossload into.")
|
||||
|
||||
for command in placeholder_group.commands:
|
||||
placeholder_group.remove_command(command.name)
|
||||
target_group.remove_command(command.name)
|
||||
acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self)
|
||||
command.app_command = acmd
|
||||
target_group.add_command(command)
|
||||
195
src/meta/LionContext.py
Normal file
195
src/meta/LionContext.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import types
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
from discord.enums import ChannelType
|
||||
from discord.ext.commands import Context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .LionBot import LionBot
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
"""
|
||||
Stuff that might be useful to implement (see cmdClient):
|
||||
sent_messages cache
|
||||
tasks cache
|
||||
error reply
|
||||
usage
|
||||
interaction cache
|
||||
View cache?
|
||||
setting access
|
||||
"""
|
||||
|
||||
|
||||
FlatContext = namedtuple(
|
||||
'FlatContext',
|
||||
('message',
|
||||
'interaction',
|
||||
'guild',
|
||||
'author',
|
||||
'channel',
|
||||
'alias',
|
||||
'prefix',
|
||||
'failed')
|
||||
)
|
||||
|
||||
|
||||
class LionContext(Context['LionBot']):
|
||||
"""
|
||||
Represents the context a command is invoked under.
|
||||
|
||||
Extends Context to add Lion-specific methods and attributes.
|
||||
Also adds several contextual wrapped utilities for simpler user during command invocation.
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
parts = {}
|
||||
if self.interaction is not None:
|
||||
parts['iid'] = self.interaction.id
|
||||
parts['itype'] = f"\"{self.interaction.type.name}\""
|
||||
if self.message is not None:
|
||||
parts['mid'] = self.message.id
|
||||
if self.author is not None:
|
||||
parts['uid'] = self.author.id
|
||||
parts['uname'] = f"\"{self.author.name}\""
|
||||
if self.channel is not None:
|
||||
parts['cid'] = self.channel.id
|
||||
if self.channel.type is ChannelType.private:
|
||||
parts['cname'] = f"\"{self.channel.recipient}\""
|
||||
else:
|
||||
parts['cname'] = f"\"{self.channel.name}\""
|
||||
if self.guild is not None:
|
||||
parts['gid'] = self.guild.id
|
||||
parts['gname'] = f"\"{self.guild.name}\""
|
||||
if self.command is not None:
|
||||
parts['cmd'] = f"\"{self.command.qualified_name}\""
|
||||
if self.invoked_with is not None:
|
||||
parts['alias'] = f"\"{self.invoked_with}\""
|
||||
if self.command_failed:
|
||||
parts['failed'] = self.command_failed
|
||||
|
||||
return "<LionContext: {}>".format(
|
||||
' '.join(f"{name}={value}" for name, value in parts.items())
|
||||
)
|
||||
|
||||
def flatten(self):
|
||||
"""Flat pure-data context information, for caching and logging."""
|
||||
return FlatContext(
|
||||
self.message.id,
|
||||
self.interaction.id if self.interaction is not None else None,
|
||||
self.guild.id if self.guild is not None else None,
|
||||
self.author.id if self.author is not None else None,
|
||||
self.channel.id if self.channel is not None else None,
|
||||
self.invoked_with,
|
||||
self.prefix,
|
||||
self.command_failed
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def util(cls, util_func):
|
||||
"""
|
||||
Decorator to make a utility function available as a Context instance method.
|
||||
"""
|
||||
setattr(cls, util_func.__name__, util_func)
|
||||
logger.debug(f"Attached context utility function: {util_func.__name__}")
|
||||
return util_func
|
||||
|
||||
@classmethod
|
||||
def wrappable_util(cls, util_func):
|
||||
"""
|
||||
Decorator to add a Wrappable utility function as a Context instance method.
|
||||
"""
|
||||
wrapped = Wrappable(util_func)
|
||||
setattr(cls, util_func.__name__, wrapped)
|
||||
logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
|
||||
return wrapped
|
||||
|
||||
async def error_reply(self, content: Optional[str] = None, **kwargs):
|
||||
if content and 'embed' not in kwargs:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.red(),
|
||||
description=content
|
||||
)
|
||||
kwargs['embed'] = embed
|
||||
content = None
|
||||
|
||||
# Expect this may be run in highly unusual circumstances.
|
||||
# This should never error, or at least handle all errors.
|
||||
if self.interaction:
|
||||
kwargs.setdefault('ephemeral', True)
|
||||
try:
|
||||
await self.reply(content=content, **kwargs)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Unknown exception in 'error_reply'.",
|
||||
extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
|
||||
)
|
||||
|
||||
|
||||
class Wrappable:
|
||||
__slots__ = ('_func', 'wrappers')
|
||||
|
||||
def __init__(self, func):
|
||||
self._func = func
|
||||
self.wrappers = None
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return self._func.__name__
|
||||
|
||||
def add_wrapper(self, func, name=None):
|
||||
self.wrappers = self.wrappers or {}
|
||||
name = name or func.__name__
|
||||
self.wrappers[name] = func
|
||||
logger.debug(
|
||||
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
|
||||
extra={'action': "Wrap Util"}
|
||||
)
|
||||
|
||||
def remove_wrapper(self, name):
|
||||
if not self.wrappers or name not in self.wrappers:
|
||||
raise ValueError(
|
||||
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
|
||||
)
|
||||
self.wrappers.pop(name)
|
||||
logger.debug(
|
||||
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
|
||||
extra={'action': "Unwrap Util"}
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.wrappers:
|
||||
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
|
||||
else:
|
||||
return self._func(*args, **kwargs)
|
||||
|
||||
def _wrapped(self, iter_wraps):
|
||||
next_wrap = next(iter_wraps, None)
|
||||
if next_wrap:
|
||||
def _func(*args, **kwargs):
|
||||
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
|
||||
else:
|
||||
_func = self._func
|
||||
return _func
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
return types.MethodType(self, instance)
|
||||
|
||||
|
||||
LionContext.reply = Wrappable(LionContext.reply)
|
||||
|
||||
|
||||
# @LionContext.reply.add_wrapper
|
||||
# async def think(func, ctx, *args, **kwargs):
|
||||
# await ctx.channel.send("thinking")
|
||||
# await func(ctx, *args, **kwargs)
|
||||
150
src/meta/LionTree.py
Normal file
150
src/meta/LionTree.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import logging
|
||||
|
||||
import discord
|
||||
from discord import Interaction
|
||||
from discord.app_commands import CommandTree
|
||||
from discord.app_commands.errors import AppCommandError, CommandInvokeError
|
||||
from discord.enums import InteractionType
|
||||
from discord.app_commands.namespace import Namespace
|
||||
|
||||
from utils.lib import tabulate
|
||||
|
||||
from .logger import logging_context, set_logging_context, log_wrap, log_action_stack
|
||||
from .errors import SafeCancellation
|
||||
from .config import conf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LionTree(CommandTree):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._call_tasks = set()
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error) -> None:
|
||||
try:
|
||||
if isinstance(error, CommandInvokeError):
|
||||
raise error.original
|
||||
else:
|
||||
raise error
|
||||
except SafeCancellation:
|
||||
# Assume this has already been handled
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
|
||||
if interaction.type is not InteractionType.autocomplete:
|
||||
embed = self.bugsplat(interaction, error)
|
||||
await self.error_reply(interaction, embed)
|
||||
|
||||
async def error_reply(self, interaction, embed):
|
||||
if not interaction.is_expired():
|
||||
try:
|
||||
if interaction.response.is_done():
|
||||
await interaction.followup.send(embed=embed, ephemeral=True)
|
||||
else:
|
||||
await interaction.response.send_message(embed=embed, ephemeral=True)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
def bugsplat(self, interaction, e):
|
||||
error_embed = discord.Embed(title="Something went wrong!", colour=discord.Colour.red())
|
||||
error_embed.description = (
|
||||
"An unexpected error occurred during this interaction!\n"
|
||||
"Our development team has been notified, and the issue will be addressed soon.\n"
|
||||
"If the error persists, or you have any questions, please contact our [support team]({link}) "
|
||||
"and give them the extra details below."
|
||||
).format(link=interaction.client.config.bot.support_guild)
|
||||
details = {}
|
||||
details['error'] = f"`{repr(e)}`"
|
||||
details['interactionid'] = f"`{interaction.id}`"
|
||||
details['interactiontype'] = f"`{interaction.type}`"
|
||||
if interaction.command:
|
||||
details['cmd'] = f"`{interaction.command.qualified_name}`"
|
||||
if interaction.user:
|
||||
details['user'] = f"`{interaction.user.id}` -- `{interaction.user}`"
|
||||
if interaction.guild:
|
||||
details['guild'] = f"`{interaction.guild.id}` -- `{interaction.guild.name}`"
|
||||
details['my_guild_perms'] = f"`{interaction.guild.me.guild_permissions.value}`"
|
||||
if interaction.user:
|
||||
ownerstr = ' (owner)' if interaction.user.id == interaction.guild.owner_id else ''
|
||||
details['user_guild_perms'] = f"`{interaction.user.guild_permissions.value}{ownerstr}`"
|
||||
if interaction.channel.type is discord.enums.ChannelType.private:
|
||||
details['channel'] = "`Direct Message`"
|
||||
elif interaction.channel:
|
||||
details['channel'] = f"`{interaction.channel.id}` -- `{interaction.channel.name}`"
|
||||
details['my_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.guild.me).value}`"
|
||||
if interaction.user:
|
||||
details['user_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.user).value}`"
|
||||
details['shard'] = f"`{interaction.client.shardname}`"
|
||||
details['log_stack'] = f"`{log_action_stack.get()}`"
|
||||
|
||||
table = '\n'.join(tabulate(*details.items()))
|
||||
error_embed.add_field(name='Details', value=table)
|
||||
return error_embed
|
||||
|
||||
def _from_interaction(self, interaction: Interaction) -> None:
|
||||
@log_wrap(context=f"iid: {interaction.id}", isolate=False)
|
||||
async def wrapper():
|
||||
try:
|
||||
await self._call(interaction)
|
||||
except AppCommandError as e:
|
||||
await self._dispatch_error(interaction, e)
|
||||
|
||||
task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
|
||||
self._call_tasks.add(task)
|
||||
task.add_done_callback(lambda fut: self._call_tasks.discard(fut))
|
||||
|
||||
async def _call(self, interaction):
|
||||
if not await self.interaction_check(interaction):
|
||||
interaction.command_failed = True
|
||||
return
|
||||
|
||||
data = interaction.data # type: ignore
|
||||
type = data.get('type', 1)
|
||||
if type != 1:
|
||||
# Context menu command...
|
||||
await self._call_context_menu(interaction, data, type)
|
||||
return
|
||||
|
||||
command, options = self._get_app_command_options(data)
|
||||
|
||||
# Pre-fill the cached slot to prevent re-computation
|
||||
interaction._cs_command = command
|
||||
|
||||
# At this point options refers to the arguments of the command
|
||||
# and command refers to the class type we care about
|
||||
namespace = Namespace(interaction, data.get('resolved', {}), options)
|
||||
|
||||
# Same pre-fill as above
|
||||
interaction._cs_namespace = namespace
|
||||
|
||||
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
|
||||
if interaction.type is InteractionType.autocomplete:
|
||||
set_logging_context(action=f"Acmp {command.qualified_name}")
|
||||
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
|
||||
if focused is None:
|
||||
raise AppCommandError(
|
||||
'This should not happen, but there is no focused element. This is a Discord bug.'
|
||||
)
|
||||
try:
|
||||
await command._invoke_autocomplete(interaction, focused, namespace)
|
||||
except Exception as e:
|
||||
await self.on_error(interaction, e)
|
||||
return
|
||||
|
||||
set_logging_context(action=f"Run {command.qualified_name}")
|
||||
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
|
||||
try:
|
||||
await command._invoke_with_namespace(interaction, namespace)
|
||||
except AppCommandError as e:
|
||||
interaction.command_failed = True
|
||||
await command._invoke_error_handlers(interaction, e)
|
||||
await self.on_error(interaction, e)
|
||||
else:
|
||||
if not interaction.command_failed:
|
||||
self.client.dispatch('app_command_completion', interaction, command)
|
||||
finally:
|
||||
if interaction.command_failed:
|
||||
logger.debug("Command completed with errors.")
|
||||
else:
|
||||
logger.debug("Command completed without errors.")
|
||||
15
src/meta/__init__.py
Normal file
15
src/meta/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .LionBot import LionBot
|
||||
from .LionCog import LionCog
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
|
||||
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
|
||||
from .config import conf, configEmoji
|
||||
from .args import args
|
||||
from .app import appname, appname_from_shard, shard_from_appname
|
||||
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
|
||||
from .context import context, ctx_bot
|
||||
|
||||
from . import sharding
|
||||
from . import logger
|
||||
from . import app
|
||||
32
src/meta/app.py
Normal file
32
src/meta/app.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
appname: str
|
||||
The base identifer for this application.
|
||||
This identifies which services the app offers.
|
||||
shardname: str
|
||||
The specific name of the running application.
|
||||
Only one process should be connecteded with a given appname.
|
||||
For the bot apps, usually specifies the shard id and shard number.
|
||||
"""
|
||||
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
|
||||
|
||||
from . import sharding, conf
|
||||
from .logger import log_app
|
||||
from .args import args
|
||||
|
||||
|
||||
appname = conf.data['appid']
|
||||
appid = appname # backwards compatibility
|
||||
|
||||
|
||||
def appname_from_shard(shardid):
|
||||
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
|
||||
return appname
|
||||
|
||||
|
||||
def shard_from_appname(appname: str):
|
||||
return int(appname.rsplit('_', maxsplit=1)[-1])
|
||||
|
||||
|
||||
shardname = appname_from_shard(sharding.shard_number)
|
||||
|
||||
log_app.set(shardname)
|
||||
35
src/meta/args.py
Normal file
35
src/meta/args.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import argparse
|
||||
|
||||
from constants import CONFIG_FILE
|
||||
|
||||
# ------------------------------
|
||||
# Parsed commandline arguments
|
||||
# ------------------------------
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--conf',
|
||||
dest='config',
|
||||
default=CONFIG_FILE,
|
||||
help="Path to configuration file."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--shard',
|
||||
dest='shard',
|
||||
default=None,
|
||||
type=int,
|
||||
help="Shard number to run, if applicable."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--host',
|
||||
dest='host',
|
||||
default='127.0.0.1',
|
||||
help="IP address to run the app listener on."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--port',
|
||||
dest='port',
|
||||
default='5001',
|
||||
help="Port to run the app listener on."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
146
src/meta/config.py
Normal file
146
src/meta/config.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from discord import PartialEmoji
|
||||
import configparser as cfgp
|
||||
|
||||
from .args import args
|
||||
|
||||
shard_number = args.shard
|
||||
|
||||
class configEmoji(PartialEmoji):
|
||||
__slots__ = ('fallback',)
|
||||
|
||||
def __init__(self, *args, fallback=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fallback = fallback
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, emojistr: str):
|
||||
"""
|
||||
Parses emoji strings of one of the following forms
|
||||
`<a:name:id> or fallback`
|
||||
`<:name:id> or fallback`
|
||||
`<a:name:id>`
|
||||
`<:name:id>`
|
||||
"""
|
||||
splits = emojistr.rsplit(' or ', maxsplit=1)
|
||||
|
||||
fallback = splits[1] if len(splits) > 1 else None
|
||||
emojistr = splits[0].strip('<> ')
|
||||
animated, name, id = emojistr.split(':')
|
||||
return cls(
|
||||
name=name,
|
||||
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
|
||||
animated=bool(animated),
|
||||
id=int(id) if id else None
|
||||
)
|
||||
|
||||
|
||||
class MapDotProxy:
|
||||
"""
|
||||
Allows dot access to an underlying Mappable object.
|
||||
"""
|
||||
__slots__ = ("_map", "_converter")
|
||||
|
||||
def __init__(self, mappable, converter=None):
|
||||
self._map = mappable
|
||||
self._converter = converter
|
||||
|
||||
def __getattribute__(self, key):
|
||||
_map = object.__getattribute__(self, '_map')
|
||||
if key == '_map':
|
||||
return _map
|
||||
if key in _map:
|
||||
_converter = object.__getattribute__(self, '_converter')
|
||||
if _converter:
|
||||
return _converter(_map[key])
|
||||
else:
|
||||
return _map[key]
|
||||
else:
|
||||
return object.__getattribute__(_map, key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._map.__getitem__(key)
|
||||
|
||||
|
||||
class ConfigParser(cfgp.ConfigParser):
|
||||
"""
|
||||
Extension of base ConfigParser allowing optional
|
||||
section option retrieval without defaults.
|
||||
"""
|
||||
def options(self, section, no_defaults=False, **kwargs):
|
||||
if no_defaults:
|
||||
try:
|
||||
return list(self._sections[section].keys())
|
||||
except KeyError:
|
||||
raise cfgp.NoSectionError(section)
|
||||
else:
|
||||
return super().options(section, **kwargs)
|
||||
|
||||
|
||||
class Conf:
|
||||
def __init__(self, configfile, section_name="DEFAULT"):
|
||||
self.configfile = configfile
|
||||
|
||||
self.config = ConfigParser(
|
||||
converters={
|
||||
"intlist": self._getintlist,
|
||||
"list": self._getlist,
|
||||
"emoji": configEmoji.from_str,
|
||||
}
|
||||
)
|
||||
|
||||
with open(configfile) as conff:
|
||||
# Opening with read_file mainly to ensure the file exists
|
||||
self.config.read_file(conff)
|
||||
|
||||
self.section_name = section_name if section_name in self.config else 'DEFAULT'
|
||||
|
||||
self.default = self.config["DEFAULT"]
|
||||
self.section = MapDotProxy(self.config[self.section_name])
|
||||
self.bot = self.section
|
||||
|
||||
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
|
||||
more_to_read = self.section.getlist("ALSO_READ", [])
|
||||
read = set()
|
||||
while more_to_read:
|
||||
to_read = more_to_read.pop(0)
|
||||
read.add(to_read)
|
||||
self.config.read(to_read)
|
||||
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
|
||||
if path not in read and path not in more_to_read]
|
||||
more_to_read.extend(new_paths)
|
||||
|
||||
self.emojis = MapDotProxy(
|
||||
self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section,
|
||||
converter=configEmoji.from_str
|
||||
)
|
||||
|
||||
global conf
|
||||
conf = self
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.section[key].strip()
|
||||
|
||||
def __getattr__(self, section):
|
||||
name = section.upper()
|
||||
shard_name = f"{name}-{shard_number}"
|
||||
if shard_name in self.config:
|
||||
return self.config[shard_name]
|
||||
else:
|
||||
return self.config[name]
|
||||
|
||||
def get(self, name, fallback=None):
|
||||
result = self.section.get(name, fallback)
|
||||
return result.strip() if result else result
|
||||
|
||||
def _getintlist(self, value):
|
||||
return [int(item.strip()) for item in value.split(',')]
|
||||
|
||||
def _getlist(self, value):
|
||||
return [item.strip() for item in value.split(',')]
|
||||
|
||||
def write(self):
|
||||
with open(self.configfile, 'w') as conffile:
|
||||
self.config.write(conffile)
|
||||
|
||||
|
||||
conf = Conf(args.config, 'BOT')
|
||||
20
src/meta/context.py
Normal file
20
src/meta/context.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Namespace for various global context variables.
|
||||
Allows asyncio callbacks to accurately retrieve information about the current state.
|
||||
"""
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .LionBot import LionBot
|
||||
from .LionContext import LionContext
|
||||
|
||||
|
||||
# Contains the current command context, if applicable
|
||||
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
|
||||
|
||||
# Contains the current LionBot instance
|
||||
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)
|
||||
64
src/meta/errors.py
Normal file
64
src/meta/errors.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Optional
|
||||
from string import Template
|
||||
|
||||
|
||||
class SafeCancellation(Exception):
|
||||
"""
|
||||
Raised to safely cancel execution of the current operation.
|
||||
|
||||
If not caught, is expected to be propagated to the Tree and safely ignored there.
|
||||
If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
|
||||
The error handler should then set the `msg` to None, to avoid double handling.
|
||||
Debugging information should go in `details`, to be logged by a top-level error handler.
|
||||
"""
|
||||
default_message = ""
|
||||
|
||||
@property
|
||||
def msg(self):
|
||||
return self._msg if self._msg is not None else self.default_message
|
||||
|
||||
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
|
||||
self._msg: Optional[str] = _msg
|
||||
self.details: str = details if details is not None else self.msg
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class UserInputError(SafeCancellation):
|
||||
"""
|
||||
A SafeCancellation induced from unparseable user input.
|
||||
"""
|
||||
default_message = "Could not understand your input."
|
||||
|
||||
@property
|
||||
def msg(self):
|
||||
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
|
||||
|
||||
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
|
||||
self.info = info
|
||||
super().__init__(_msg, **kwargs)
|
||||
|
||||
|
||||
class UserCancelled(SafeCancellation):
|
||||
"""
|
||||
A SafeCancellation induced from manual user cancellation.
|
||||
|
||||
Usually silent.
|
||||
"""
|
||||
default_msg = None
|
||||
|
||||
|
||||
class ResponseTimedOut(SafeCancellation):
|
||||
"""
|
||||
A SafeCancellation induced from a user interaction time-out.
|
||||
"""
|
||||
default_msg = "Session timed out waiting for input."
|
||||
|
||||
|
||||
class HandledException(SafeCancellation):
|
||||
"""
|
||||
Sentinel class to indicate to error handlers that this exception has been handled.
|
||||
Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
|
||||
"""
|
||||
def __init__(self, exc=None, **kwargs):
|
||||
self.exc = exc
|
||||
super().__init__(**kwargs)
|
||||
468
src/meta/logger.py
Normal file
468
src/meta/logger.py
Normal file
@@ -0,0 +1,468 @@
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from logging.handlers import QueueListener, QueueHandler
|
||||
import queue
|
||||
import multiprocessing
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from functools import wraps
|
||||
from contextvars import ContextVar
|
||||
|
||||
import discord
|
||||
from discord import Webhook, File
|
||||
import aiohttp
|
||||
|
||||
from .config import conf
|
||||
from . import sharding
|
||||
from .context import context
|
||||
from utils.lib import utc_now
|
||||
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
|
||||
|
||||
|
||||
log_logger = logging.getLogger(__name__)
|
||||
log_logger.propagate = False
|
||||
|
||||
|
||||
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
|
||||
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
|
||||
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
|
||||
|
||||
def set_logging_context(
|
||||
context: Optional[str] = None,
|
||||
action: Optional[str] = None,
|
||||
stack: Optional[tuple[str, ...]] = None
|
||||
):
|
||||
"""
|
||||
Statically set the logging context variables to the given values.
|
||||
|
||||
If `action` is given, pushes it onto the `log_action_stack`.
|
||||
"""
|
||||
if context is not None:
|
||||
log_context.set(context)
|
||||
if action is not None or stack is not None:
|
||||
astack = log_action_stack.get()
|
||||
newstack = stack if stack is not None else astack
|
||||
if action is not None:
|
||||
newstack = (*newstack, action)
|
||||
log_action_stack.set(newstack)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def logging_context(context=None, action=None, stack=None):
|
||||
"""
|
||||
Context manager for executing a block of code in a given logging context.
|
||||
|
||||
This context manager should only be used around synchronous code.
|
||||
This is because async code *may* get cancelled or externally garbage collected,
|
||||
in which case the finally block will be executed in the wrong context.
|
||||
See https://github.com/python/cpython/issues/93740
|
||||
This can be refactored nicely if this gets merged:
|
||||
https://github.com/python/cpython/pull/99634
|
||||
|
||||
(It will not necessarily break on async code,
|
||||
if the async code can be guaranteed to clean up in its own context.)
|
||||
"""
|
||||
if context is not None:
|
||||
oldcontext = log_context.get()
|
||||
log_context.set(context)
|
||||
if action is not None or stack is not None:
|
||||
astack = log_action_stack.get()
|
||||
newstack = stack if stack is not None else astack
|
||||
if action is not None:
|
||||
newstack = (*newstack, action)
|
||||
log_action_stack.set(newstack)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if context is not None:
|
||||
log_context.set(oldcontext)
|
||||
if stack is not None or action is not None:
|
||||
log_action_stack.set(astack)
|
||||
|
||||
|
||||
def with_log_ctx(isolate=True, **kwargs):
|
||||
"""
|
||||
Execute a coroutine inside a given logging context.
|
||||
|
||||
If `isolate` is true, ensures that context does not leak
|
||||
outside the coroutine.
|
||||
|
||||
If `isolate` is false, just statically set the context,
|
||||
which will leak unless the coroutine is
|
||||
called in an externally copied context.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapped(*w_args, **w_kwargs):
|
||||
if isolate:
|
||||
with logging_context(**kwargs):
|
||||
# Task creation will synchronously copy the context
|
||||
# This is gc safe
|
||||
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
|
||||
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
|
||||
return await task
|
||||
else:
|
||||
# This will leak context changes
|
||||
set_logging_context(**kwargs)
|
||||
return await func(*w_args, **w_kwargs)
|
||||
return wrapped
|
||||
return decorator
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
log_wrap = with_log_ctx
|
||||
|
||||
|
||||
def persist_task(task_collection: set):
|
||||
"""
|
||||
Coroutine decorator that ensures the coroutine is scheduled as a task
|
||||
and added to the given task_collection for strong reference
|
||||
when it is called.
|
||||
|
||||
This is just a hack to handle discord.py events potentially
|
||||
being unexpectedly garbage collected.
|
||||
|
||||
Since this also implicitly schedules the coroutine as a task when it is called,
|
||||
the coroutine will also be run inside an isolated context.
|
||||
"""
|
||||
def decorator(coro):
|
||||
@wraps(coro)
|
||||
async def wrapped(*w_args, **w_kwargs):
|
||||
name = f"persisted-{coro.__name__}"
|
||||
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
|
||||
task_collection.add(task)
|
||||
task.add_done_callback(lambda f: task_collection.discard(f))
|
||||
await task
|
||||
|
||||
|
||||
RESET_SEQ = "\033[0m"
|
||||
COLOR_SEQ = "\033[3%dm"
|
||||
BOLD_SEQ = "\033[1m"
|
||||
"]]]"
|
||||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
|
||||
|
||||
|
||||
def colour_escape(fmt: str) -> str:
|
||||
cmap = {
|
||||
'%(black)': COLOR_SEQ % BLACK,
|
||||
'%(red)': COLOR_SEQ % RED,
|
||||
'%(green)': COLOR_SEQ % GREEN,
|
||||
'%(yellow)': COLOR_SEQ % YELLOW,
|
||||
'%(blue)': COLOR_SEQ % BLUE,
|
||||
'%(magenta)': COLOR_SEQ % MAGENTA,
|
||||
'%(cyan)': COLOR_SEQ % CYAN,
|
||||
'%(white)': COLOR_SEQ % WHITE,
|
||||
'%(reset)': RESET_SEQ,
|
||||
'%(bold)': BOLD_SEQ,
|
||||
}
|
||||
for key, value in cmap.items():
|
||||
fmt = fmt.replace(key, value)
|
||||
return fmt
|
||||
|
||||
|
||||
log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
|
||||
'%(cyan)%(app)-15s%(reset)|' +
|
||||
'%(cyan)%(context)-24s%(reset)|' +
|
||||
'%(cyan)%(actionstr)-22s%(reset)|' +
|
||||
' %(bold)%(cyan)%(name)s:%(reset)' +
|
||||
' %(white)%(message)s%(ctxstr)s%(reset)')
|
||||
log_format = colour_escape(log_format)
|
||||
|
||||
|
||||
# Setup the logger
|
||||
logger = logging.getLogger()
|
||||
log_fmt = logging.Formatter(
|
||||
fmt=log_format,
|
||||
# datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class LessThanFilter(logging.Filter):
|
||||
def __init__(self, exclusive_maximum, name=""):
|
||||
super(LessThanFilter, self).__init__(name)
|
||||
self.max_level = exclusive_maximum
|
||||
|
||||
def filter(self, record):
|
||||
# non-zero return means we log this message
|
||||
return 1 if record.levelno < self.max_level else 0
|
||||
|
||||
class ExactLevelFilter(logging.Filter):
|
||||
def __init__(self, target_level, name=""):
|
||||
super().__init__(name)
|
||||
self.target_level = target_level
|
||||
|
||||
def filter(self, record):
|
||||
return (record.levelno == self.target_level)
|
||||
|
||||
|
||||
class ThreadFilter(logging.Filter):
|
||||
def __init__(self, thread_name):
|
||||
super().__init__("")
|
||||
self.thread = thread_name
|
||||
|
||||
def filter(self, record):
|
||||
# non-zero return means we log this message
|
||||
return 1 if record.threadName == self.thread else 0
|
||||
|
||||
|
||||
class ContextInjection(logging.Filter):
|
||||
def filter(self, record):
|
||||
# These guards are to allow override through _extra
|
||||
# And to ensure the injection is idempotent
|
||||
if not hasattr(record, 'context'):
|
||||
record.context = log_context.get()
|
||||
|
||||
if not hasattr(record, 'actionstr'):
|
||||
action_stack = log_action_stack.get()
|
||||
if hasattr(record, 'action'):
|
||||
action_stack = (*action_stack, record.action)
|
||||
if action_stack:
|
||||
record.actionstr = ' ➔ '.join(action_stack)
|
||||
else:
|
||||
record.actionstr = "Unknown Action"
|
||||
|
||||
if not hasattr(record, 'app'):
|
||||
record.app = log_app.get()
|
||||
|
||||
if not hasattr(record, 'ctx'):
|
||||
if ctx := context.get():
|
||||
record.ctx = repr(ctx)
|
||||
else:
|
||||
record.ctx = None
|
||||
|
||||
if getattr(record, 'with_ctx', False) and record.ctx:
|
||||
record.ctxstr = '\n' + record.ctx
|
||||
else:
|
||||
record.ctxstr = ""
|
||||
return True
|
||||
|
||||
|
||||
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||
logging_handler_out.setLevel(logging.DEBUG)
|
||||
logging_handler_out.setFormatter(log_fmt)
|
||||
logging_handler_out.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_out)
|
||||
log_logger.addHandler(logging_handler_out)
|
||||
|
||||
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||
logging_handler_err.setLevel(logging.WARNING)
|
||||
logging_handler_err.setFormatter(log_fmt)
|
||||
logging_handler_err.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_err)
|
||||
log_logger.addHandler(logging_handler_err)
|
||||
|
||||
|
||||
class LocalQueueHandler(QueueHandler):
|
||||
def _emit(self, record: logging.LogRecord) -> None:
|
||||
# Removed the call to self.prepare(), handle task cancellation
|
||||
try:
|
||||
self.enqueue(record)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class WebHookHandler(logging.StreamHandler):
|
||||
def __init__(self, webhook_url, prefix="", batch=True, loop=None):
|
||||
super().__init__()
|
||||
self.webhook_url = webhook_url
|
||||
self.prefix = prefix
|
||||
self.batched = ""
|
||||
self.batch = batch
|
||||
self.loop = loop
|
||||
self.batch_delay = 10
|
||||
self.batch_task = None
|
||||
self.last_batched = None
|
||||
self.waiting = []
|
||||
|
||||
self.bucket = Bucket(20, 40)
|
||||
self.ignored = 0
|
||||
|
||||
self.session = None
|
||||
self.webhook = None
|
||||
|
||||
def get_loop(self):
|
||||
if self.loop is None:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
return self.loop
|
||||
|
||||
def emit(self, record):
|
||||
self.format(record)
|
||||
self.get_loop().call_soon_threadsafe(self._post, record)
|
||||
|
||||
def _post(self, record):
|
||||
if self.session is None:
|
||||
self.setup()
|
||||
asyncio.create_task(self.post(record))
|
||||
|
||||
def setup(self):
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
|
||||
|
||||
async def post(self, record):
|
||||
if record.context == 'Webhook Logger':
|
||||
# Don't livelog livelog errors
|
||||
# Otherwise we recurse and Cloudflare hates us
|
||||
return
|
||||
log_context.set("Webhook Logger")
|
||||
log_action_stack.set(("Logging",))
|
||||
log_app.set(record.app)
|
||||
|
||||
try:
|
||||
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
|
||||
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
|
||||
context = f"\n# Context: {record.ctx}" if record.ctx else ""
|
||||
message = f"{header}\n{record.msg}{context}"
|
||||
|
||||
if len(message) > 1900:
|
||||
as_file = True
|
||||
else:
|
||||
as_file = False
|
||||
message = "```md\n{}\n```".format(message)
|
||||
|
||||
# Post the log message(s)
|
||||
if self.batch:
|
||||
if len(message) > 1500:
|
||||
await self._send_batched_now()
|
||||
await self._send(message, as_file=as_file)
|
||||
else:
|
||||
self.batched += message
|
||||
if len(self.batched) + len(message) > 1500:
|
||||
await self._send_batched_now()
|
||||
else:
|
||||
asyncio.create_task(self._schedule_batched())
|
||||
else:
|
||||
await self._send(message, as_file=as_file)
|
||||
except Exception as ex:
|
||||
print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)
|
||||
|
||||
async def _schedule_batched(self):
|
||||
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
|
||||
# noop, don't reschedule if it is already scheduled
|
||||
return
|
||||
try:
|
||||
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
|
||||
await self.batch_task
|
||||
await self._send_batched()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as ex:
|
||||
print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)
|
||||
|
||||
async def _send_batched_now(self):
|
||||
if self.batch_task is not None and not self.batch_task.done():
|
||||
self.batch_task.cancel()
|
||||
self.last_batched = None
|
||||
await self._send_batched()
|
||||
|
||||
async def _send_batched(self):
|
||||
if self.batched:
|
||||
batched = self.batched
|
||||
self.batched = ""
|
||||
await self._send(batched)
|
||||
|
||||
async def _send(self, message, as_file=False):
|
||||
try:
|
||||
self.bucket.request()
|
||||
except BucketOverFull:
|
||||
# Silently ignore
|
||||
self.ignored += 1
|
||||
return
|
||||
except BucketFull:
|
||||
logger.warning(
|
||||
"Can't keep up! "
|
||||
f"Ignoring records on live-logger {self.webhook.id}."
|
||||
)
|
||||
self.ignored += 1
|
||||
return
|
||||
else:
|
||||
if self.ignored > 0:
|
||||
logger.warning(
|
||||
"Can't keep up! "
|
||||
f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
|
||||
)
|
||||
self.ignored = 0
|
||||
|
||||
try:
|
||||
if as_file or len(message) > 1900:
|
||||
with StringIO(message) as fp:
|
||||
fp.seek(0)
|
||||
await self.webhook.send(
|
||||
f"{self.prefix}\n`{message.splitlines()[0]}`",
|
||||
file=File(fp, filename="logs.md"),
|
||||
username=log_app.get()
|
||||
)
|
||||
else:
|
||||
await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
|
||||
except discord.HTTPException:
|
||||
logger.exception(
|
||||
"Live logger errored. Slowing down live logger."
|
||||
)
|
||||
self.bucket.fill()
|
||||
|
||||
|
||||
handlers = []
|
||||
if webhook := conf.logging['general_log']:
|
||||
handler = WebHookHandler(webhook, batch=True)
|
||||
handlers.append(handler)
|
||||
|
||||
if webhook := conf.logging['warning_log']:
|
||||
handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True)
|
||||
handler.addFilter(ExactLevelFilter(logging.WARNING))
|
||||
handler.setLevel(logging.WARNING)
|
||||
handlers.append(handler)
|
||||
|
||||
if webhook := conf.logging['error_log']:
|
||||
handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True)
|
||||
handler.setLevel(logging.ERROR)
|
||||
handlers.append(handler)
|
||||
|
||||
if webhook := conf.logging['critical_log']:
|
||||
handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False)
|
||||
handler.setLevel(logging.CRITICAL)
|
||||
handlers.append(handler)
|
||||
|
||||
|
||||
def make_queue_handler(queue):
|
||||
qhandler = QueueHandler(queue)
|
||||
qhandler.setLevel(logging.INFO)
|
||||
qhandler.addFilter(ContextInjection())
|
||||
return qhandler
|
||||
|
||||
|
||||
def setup_main_logger(multiprocess=False):
|
||||
q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
|
||||
if handlers:
|
||||
# First create a separate loop to run the handlers on
|
||||
import threading
|
||||
|
||||
def run_loop(loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_forever()
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.close()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop_thread = threading.Thread(target=lambda: run_loop(loop))
|
||||
loop_thread.daemon = True
|
||||
loop_thread.start()
|
||||
|
||||
for handler in handlers:
|
||||
handler.loop = loop
|
||||
|
||||
qhandler = make_queue_handler(q)
|
||||
# qhandler.addFilter(ThreadFilter('MainThread'))
|
||||
logger.addHandler(qhandler)
|
||||
|
||||
listener = QueueListener(
|
||||
q, *handlers, respect_handler_level=True
|
||||
)
|
||||
listener.start()
|
||||
return q
|
||||
139
src/meta/monitor.py
Normal file
139
src/meta/monitor.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import logging
|
||||
import asyncio
|
||||
from enum import IntEnum
|
||||
from collections import deque, ChainMap
|
||||
import datetime as dt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StatusLevel(IntEnum):
|
||||
ERRORED = -2
|
||||
UNSURE = -1
|
||||
WAITING = 0
|
||||
STARTING = 1
|
||||
OKAY = 2
|
||||
|
||||
@property
|
||||
def symbol(self):
|
||||
return symbols[self]
|
||||
|
||||
|
||||
symbols = {
|
||||
StatusLevel.ERRORED: '🟥',
|
||||
StatusLevel.UNSURE: '🟧',
|
||||
StatusLevel.WAITING: '⬜',
|
||||
StatusLevel.STARTING: '🟫',
|
||||
StatusLevel.OKAY: '🟩',
|
||||
}
|
||||
|
||||
|
||||
class ComponentStatus:
|
||||
def __init__(self, level: StatusLevel, short_formatstr: str, long_formatstr: str, data: dict = {}):
|
||||
self.level = level
|
||||
self.short_formatstr = short_formatstr
|
||||
self.long_formatstr = long_formatstr
|
||||
self.data = data
|
||||
self.created_at = dt.datetime.now(tz=dt.timezone.utc)
|
||||
|
||||
def format_args(self):
|
||||
extra = {
|
||||
'created_at': self.created_at,
|
||||
'level': self.level,
|
||||
'symbol': self.level.symbol,
|
||||
}
|
||||
return ChainMap(extra, self.data)
|
||||
|
||||
@property
|
||||
def short(self):
|
||||
return self.short_formatstr.format(**self.format_args())
|
||||
|
||||
@property
|
||||
def long(self):
|
||||
return self.long_formatstr.format(**self.format_args())
|
||||
|
||||
|
||||
class ComponentMonitor:
|
||||
_name = None
|
||||
|
||||
def __init__(self, name=None, callback=None):
|
||||
self._callback = callback
|
||||
self.name = name or self._name
|
||||
if not self.name:
|
||||
raise ValueError("ComponentMonitor must have a name")
|
||||
|
||||
async def _make_status(self, *args, **kwargs):
|
||||
if self._callback is not None:
|
||||
return await self._callback(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
async def status(self) -> ComponentStatus:
|
||||
try:
|
||||
status = await self._make_status()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Status callback for component '{self.name}' failed. This should not happen."
|
||||
)
|
||||
status = ComponentStatus(
|
||||
level=StatusLevel.UNSURE,
|
||||
short_formatstr="Status callback for '{name}' failed with error '{error}'",
|
||||
long_formatstr="Status callback for '{name}' failed with error '{error}'",
|
||||
data={
|
||||
'name': self.name,
|
||||
'error': repr(e)
|
||||
}
|
||||
)
|
||||
return status
|
||||
|
||||
|
||||
class SystemMonitor:
|
||||
def __init__(self):
|
||||
self.components = {}
|
||||
self.recent = deque(maxlen=10)
|
||||
|
||||
def add_component(self, component: ComponentMonitor):
|
||||
self.components[component.name] = component
|
||||
return component
|
||||
|
||||
async def request(self):
|
||||
"""
|
||||
Request status from each component.
|
||||
"""
|
||||
tasks = {
|
||||
name: asyncio.create_task(comp.status())
|
||||
for name, comp in self.components.items()
|
||||
}
|
||||
await asyncio.gather(*tasks.values())
|
||||
status = {
|
||||
name: await fut for name, fut in tasks.items()
|
||||
}
|
||||
self.recent.append(status)
|
||||
return status
|
||||
|
||||
async def _format_summary(self, status_dict: dict[str, ComponentStatus]):
|
||||
"""
|
||||
Format a one line summary from a status dict.
|
||||
"""
|
||||
freq = {level: 0 for level in StatusLevel}
|
||||
for status in status_dict.values():
|
||||
freq[status.level] += 1
|
||||
|
||||
summary = '\t'.join(f"{level.symbol} {count}" for level, count in freq.items() if count)
|
||||
return summary
|
||||
|
||||
async def _format_overview(self, status_dict: dict[str, ComponentStatus]):
|
||||
"""
|
||||
Format an overview (one line per component) from a status dict.
|
||||
"""
|
||||
lines = []
|
||||
for name, status in status_dict.items():
|
||||
lines.append(f"{status.level.symbol} {name}: {status.short}")
|
||||
summary = await self._format_summary(status_dict)
|
||||
return '\n'.join((summary, *lines))
|
||||
|
||||
async def get_summary(self):
|
||||
return await self._format_summary(await self.request())
|
||||
|
||||
async def get_overview(self):
|
||||
return await self._format_overview(await self.request())
|
||||
35
src/meta/sharding.py
Normal file
35
src/meta/sharding.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from .args import args
|
||||
from .config import conf
|
||||
|
||||
from psycopg import sql
|
||||
from data.conditions import Condition, Joiner
|
||||
|
||||
|
||||
shard_number = args.shard or 0
|
||||
|
||||
shard_count = conf.bot.getint('shard_count', 1)
|
||||
|
||||
sharded = (shard_count > 0)
|
||||
|
||||
|
||||
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
|
||||
"""
|
||||
Condition constructor for filtering by shard id.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
Query.where(_shard_condition('guildid', 10, 1))
|
||||
"""
|
||||
return Condition(
|
||||
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
|
||||
guildid=sql.Identifier(guild_column),
|
||||
shard_count=sql.Literal(shard_count)
|
||||
),
|
||||
Joiner.EQUALS,
|
||||
sql.Placeholder(),
|
||||
(shard_id,)
|
||||
)
|
||||
|
||||
|
||||
# Pre-built Condition for filtering by current shard.
|
||||
THIS_SHARD = SHARDID(shard_number)
|
||||
Reference in New Issue
Block a user