Initial Creation from bot template.

This commit is contained in:
2025-06-05 19:35:46 +10:00
commit 2e8d2555d5
50 changed files with 6751 additions and 0 deletions

344
src/meta/LionBot.py Normal file
View File

@@ -0,0 +1,344 @@
from typing import List, Literal, LiteralString, Optional, TYPE_CHECKING, overload
import logging
import asyncio
from weakref import WeakValueDictionary
import discord
from discord.utils import MISSING
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
from aiohttp import ClientSession
from data import Database
from utils.lib import tabulate
from .config import Conf
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
from .context import context
from .LionContext import LionContext
from .LionTree import LionTree
from .errors import HandledException, SafeCancellation
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
if TYPE_CHECKING:
from core.cog import CoreCog
logger = logging.getLogger(__name__)
class LionBot(Bot):
def __init__(
self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession,
testing_guilds: List[int] = [], **kwargs
):
kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs)
self.web_client = web_client
self.testing_guilds = testing_guilds
self.initial_extensions = initial_extensions
self.db = db
self.appname = appname
self.shardname = shardname
# self.appdata = appdata
self.config = config
self.system_monitor = SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
self.system_monitor.add_component(self.monitor)
self._locks = WeakValueDictionary()
self._running_events = set()
@property
def core(self):
return self.get_cog('CoreCog')
async def _monitor_status(self):
if self.is_closed():
level = StatusLevel.ERRORED
info = "(ERROR) Websocket is closed"
data = {}
elif self.is_ws_ratelimited():
level = StatusLevel.WAITING
info = "(WAITING) Websocket is ratelimited"
data = {}
elif not self.is_ready():
level = StatusLevel.STARTING
info = "(STARTING) Not yet ready"
data = {}
else:
level = StatusLevel.OKAY
info = (
"(OK) "
"Logged in with {guild_count} guilds, "
", websocket latency {latency}, and {events} running events."
)
data = {
'guild_count': len(self.guilds),
'latency': self.latency,
'events': len(self._running_events),
}
return ComponentStatus(level, info, info, data)
async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}")
for extension in self.initial_extensions:
await self.load_extension(extension)
for guildid in self.testing_guilds:
guild = discord.Object(guildid)
if not self.shard_count or (self.shard_id == ((guildid >> 22) % self.shard_count)):
self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild)
# To make the type checker happy about fetching cogs by name
# TODO: Move this to stubs at some point
@overload
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
...
@overload
def get_cog(self, name: str) -> Optional[Cog]:
...
def get_cog(self, name: str) -> Optional[Cog]:
return super().get_cog(name)
async def add_cog(self, cog: Cog, **kwargs):
sup = super()
@log_wrap(action=f"Attach {cog.__cog_name__}")
async def wrapper():
logger.info(f"Attaching Cog {cog.__cog_name__}")
await sup.add_cog(cog, **kwargs)
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
await wrapper()
async def load_extension(self, name, *, package=None, **kwargs):
sup = super()
@log_wrap(action=f"Load {name.strip('.')}")
async def wrapper():
logger.info(f"Loading extension {name} in package {package}.")
await sup.load_extension(name, package=package, **kwargs)
logger.debug(f"Loaded extension {name} in package {package}.")
await wrapper()
async def start(self, token: str, *, reconnect: bool = True):
with logging_context(action="Login"):
start_task = asyncio.create_task(self.login(token))
await start_task
with logging_context(stack=("Running",)):
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
await run_task
def dispatch(self, event_name: str, *args, **kwargs):
with logging_context(action=f"Dispatch {event_name}"):
super().dispatch(event_name, *args, **kwargs)
def _schedule_event(self, coro, event_name, *args, **kwargs):
"""
Extends client._schedule_event to keep a persistent
background task store.
"""
task = super()._schedule_event(coro, event_name, *args, **kwargs)
self._running_events.add(task)
task.add_done_callback(lambda fut: self._running_events.discard(fut))
def idlock(self, snowflakeid):
lock = self._locks.get(snowflakeid, None)
if lock is None:
lock = self._locks[snowflakeid] = asyncio.Lock()
return lock
async def on_ready(self):
logger.info(
f"Logged in as {self.application.name}\n"
f"Application id {self.application.id}\n"
f"Shard Talk identifier {self.shardname}\n"
"------------------------------\n"
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
"------------------------------\n"
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
"Ready to take commands!\n",
extra={'action': 'Ready'}
)
async def get_context(self, origin, /, *, cls=MISSING):
if cls is MISSING:
cls = LionContext
ctx = await super().get_context(origin, cls=cls)
context.set(ctx)
return ctx
async def on_command(self, ctx: LionContext):
logger.info(
f"Executing command '{ctx.command.qualified_name}' "
f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
f"with interaction: {ctx.interaction.data if ctx.interaction else None}",
extra={'with_ctx': True}
)
async def on_command_error(self, ctx, exception):
# TODO: Some of these could have more user-feedback
logger.debug(f"Handling command error for {ctx}: {exception}")
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict()
else:
cmd_str = str(ctx.command)
try:
raise exception
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
try:
if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
original = exception.original.original
raise original
else:
original = exception.original
raise original
except HandledException:
pass
except TransformerError as e:
msg = str(e)
if msg:
try:
await ctx.error_reply(msg)
except Exception:
pass
logger.debug(
f"Caught a transformer error: {repr(e)}",
extra={'action': 'BotError', 'with_ctx': True}
)
except SafeCancellation:
if original.msg:
try:
await ctx.error_reply(original.msg)
except Exception:
pass
logger.debug(
f"Caught a safe cancellation: {original.details}",
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.Forbidden:
# Unknown uncaught Forbidden
try:
# Attempt a general error reply
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
except Exception:
# We can't send anything at all. Exit quietly, but log.
logger.warning(
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.HTTPException:
logger.error(
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
pass
except Exception as e:
logger.exception(
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
extra={'action': 'BotError', 'with_ctx': True}
)
error_embed = discord.Embed(
title="Something went wrong!",
colour=discord.Colour.dark_red()
)
error_embed.description = (
"An unexpected error occurred while processing your command!\n"
"Our development team has been notified, and the issue will be addressed soon.\n"
"If the error persists, or you have any questions, please contact our [support team]({link}) "
"and give them the extra details below."
).format(link=self.config.bot.support_guild)
details = {}
details['error'] = f"`{repr(e)}`"
if ctx.interaction:
details['interactionid'] = f"`{ctx.interaction.id}`"
if ctx.command:
details['cmd'] = f"`{ctx.command.qualified_name}`"
if ctx.author:
details['author'] = f"`{ctx.author.id}` -- `{ctx.author}`"
if ctx.guild:
details['guild'] = f"`{ctx.guild.id}` -- `{ctx.guild.name}`"
details['my_guild_perms'] = f"`{ctx.guild.me.guild_permissions.value}`"
if ctx.author:
ownerstr = ' (owner)' if ctx.author.id == ctx.guild.owner_id else ''
details['author_guild_perms'] = f"`{ctx.author.guild_permissions.value}{ownerstr}`"
if ctx.channel.type is discord.enums.ChannelType.private:
details['channel'] = "`Direct Message`"
elif ctx.channel:
details['channel'] = f"`{ctx.channel.id}` -- `{ctx.channel.name}`"
details['my_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.guild.me).value}`"
if ctx.author:
details['author_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.author).value}`"
details['shard'] = f"`{self.shardname}`"
details['log_stack'] = f"`{log_action_stack.get()}`"
table = '\n'.join(tabulate(*details.items()))
error_embed.add_field(name='Details', value=table)
try:
await ctx.error_reply(embed=error_embed)
except discord.HTTPException:
pass
finally:
exception.original = HandledException(exception.original)
except CheckFailure as e:
logger.debug(
f"Command failed check: {e}: {e.args}",
extra={'action': 'BotError', 'with_ctx': True}
)
try:
await ctx.error_reply(str(e))
except discord.HTTPException:
pass
except Exception:
# Completely unknown exception outside of command invocation!
# Something is very wrong here, don't attempt user interaction.
logger.exception(
f"Caught an unknown top-level exception while executing: {cmd_str}",
extra={'action': 'BotError', 'with_ctx': True}
)
def add_command(self, command):
if not hasattr(command, '_placeholder_group_'):
super().add_command(command)
def request_chunking_for(self, guild):
if not guild.chunked:
return asyncio.create_task(
self._connection.chunk_guild(guild, wait=False, cache=True),
name=f"Background chunkreq for {guild.id}"
)
async def on_interaction(self, interaction: discord.Interaction):
"""
Adds the interaction author to guild cache if appropriate.
This gets run a little bit late, so it is possible the interaction gets handled
without the author being in case.
"""
guild = interaction.guild
user = interaction.user
if guild is not None and user is not None and isinstance(user, discord.Member):
if not guild.get_member(user.id):
guild._add_member(user)
if guild is not None and not guild.chunked:
# Getting an interaction in the guild is a good enough reason to request chunking
logger.info(
f"Unchunked guild <gid: {guild.id}> requesting chunking after interaction."
)
self.request_chunking_for(guild)

58
src/meta/LionCog.py Normal file
View File

@@ -0,0 +1,58 @@
from typing import Any
from discord.ext.commands import Cog
from discord.ext import commands as cmds
class LionCog(Cog):
# A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set()
_placeholder_groups_: set[str]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._placeholder_groups_ = set()
for base in reversed(cls.__mro__):
for elem, value in base.__dict__.items():
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
cls._placeholder_groups_.add(value.name)
def __new__(cls, *args: Any, **kwargs: Any):
# Patch to ensure no placeholder groups are in the command list
self = super().__new__(cls)
self.__cog_commands__ = [
command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_
]
return self
async def _inject(self, bot, *args, **kwargs):
if self.depends_on:
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
return await super()._inject(bot, *args, *kwargs)
@classmethod
def placeholder_group(cls, group: cmds.HybridGroup):
group._placeholder_group_ = True
return group
def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup):
"""
Crossload a placeholder group's commands into the target group
"""
if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup):
raise ValueError("Placeholder and target groups my be HypridGroups.")
if placeholder_group.name not in self._placeholder_groups_:
raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.")
if target_group.app_command is None:
raise ValueError("Target group has no app_command to crossload into.")
for command in placeholder_group.commands:
placeholder_group.remove_command(command.name)
target_group.remove_command(command.name)
acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self)
command.app_command = acmd
target_group.add_command(command)

195
src/meta/LionContext.py Normal file
View File

@@ -0,0 +1,195 @@
import types
import logging
from collections import namedtuple
from typing import Optional, TYPE_CHECKING
import discord
from discord.enums import ChannelType
from discord.ext.commands import Context
if TYPE_CHECKING:
from .LionBot import LionBot
logger = logging.getLogger(__name__)
"""
Stuff that might be useful to implement (see cmdClient):
sent_messages cache
tasks cache
error reply
usage
interaction cache
View cache?
setting access
"""
FlatContext = namedtuple(
'FlatContext',
('message',
'interaction',
'guild',
'author',
'channel',
'alias',
'prefix',
'failed')
)
class LionContext(Context['LionBot']):
"""
Represents the context a command is invoked under.
Extends Context to add Lion-specific methods and attributes.
Also adds several contextual wrapped utilities for simpler user during command invocation.
"""
def __repr__(self):
parts = {}
if self.interaction is not None:
parts['iid'] = self.interaction.id
parts['itype'] = f"\"{self.interaction.type.name}\""
if self.message is not None:
parts['mid'] = self.message.id
if self.author is not None:
parts['uid'] = self.author.id
parts['uname'] = f"\"{self.author.name}\""
if self.channel is not None:
parts['cid'] = self.channel.id
if self.channel.type is ChannelType.private:
parts['cname'] = f"\"{self.channel.recipient}\""
else:
parts['cname'] = f"\"{self.channel.name}\""
if self.guild is not None:
parts['gid'] = self.guild.id
parts['gname'] = f"\"{self.guild.name}\""
if self.command is not None:
parts['cmd'] = f"\"{self.command.qualified_name}\""
if self.invoked_with is not None:
parts['alias'] = f"\"{self.invoked_with}\""
if self.command_failed:
parts['failed'] = self.command_failed
return "<LionContext: {}>".format(
' '.join(f"{name}={value}" for name, value in parts.items())
)
def flatten(self):
"""Flat pure-data context information, for caching and logging."""
return FlatContext(
self.message.id,
self.interaction.id if self.interaction is not None else None,
self.guild.id if self.guild is not None else None,
self.author.id if self.author is not None else None,
self.channel.id if self.channel is not None else None,
self.invoked_with,
self.prefix,
self.command_failed
)
@classmethod
def util(cls, util_func):
"""
Decorator to make a utility function available as a Context instance method.
"""
setattr(cls, util_func.__name__, util_func)
logger.debug(f"Attached context utility function: {util_func.__name__}")
return util_func
@classmethod
def wrappable_util(cls, util_func):
"""
Decorator to add a Wrappable utility function as a Context instance method.
"""
wrapped = Wrappable(util_func)
setattr(cls, util_func.__name__, wrapped)
logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
return wrapped
async def error_reply(self, content: Optional[str] = None, **kwargs):
if content and 'embed' not in kwargs:
embed = discord.Embed(
colour=discord.Colour.red(),
description=content
)
kwargs['embed'] = embed
content = None
# Expect this may be run in highly unusual circumstances.
# This should never error, or at least handle all errors.
if self.interaction:
kwargs.setdefault('ephemeral', True)
try:
await self.reply(content=content, **kwargs)
except discord.HTTPException:
pass
except Exception:
logger.exception(
"Unknown exception in 'error_reply'.",
extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
)
class Wrappable:
__slots__ = ('_func', 'wrappers')
def __init__(self, func):
self._func = func
self.wrappers = None
@property
def __name__(self):
return self._func.__name__
def add_wrapper(self, func, name=None):
self.wrappers = self.wrappers or {}
name = name or func.__name__
self.wrappers[name] = func
logger.debug(
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
extra={'action': "Wrap Util"}
)
def remove_wrapper(self, name):
if not self.wrappers or name not in self.wrappers:
raise ValueError(
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
)
self.wrappers.pop(name)
logger.debug(
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
extra={'action': "Unwrap Util"}
)
def __call__(self, *args, **kwargs):
if self.wrappers:
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
else:
return self._func(*args, **kwargs)
def _wrapped(self, iter_wraps):
next_wrap = next(iter_wraps, None)
if next_wrap:
def _func(*args, **kwargs):
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
else:
_func = self._func
return _func
def __get__(self, instance, cls=None):
if instance is None:
return self
else:
return types.MethodType(self, instance)
LionContext.reply = Wrappable(LionContext.reply)
# @LionContext.reply.add_wrapper
# async def think(func, ctx, *args, **kwargs):
# await ctx.channel.send("thinking")
# await func(ctx, *args, **kwargs)

150
src/meta/LionTree.py Normal file
View File

@@ -0,0 +1,150 @@
import logging
import discord
from discord import Interaction
from discord.app_commands import CommandTree
from discord.app_commands.errors import AppCommandError, CommandInvokeError
from discord.enums import InteractionType
from discord.app_commands.namespace import Namespace
from utils.lib import tabulate
from .logger import logging_context, set_logging_context, log_wrap, log_action_stack
from .errors import SafeCancellation
from .config import conf
logger = logging.getLogger(__name__)
class LionTree(CommandTree):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._call_tasks = set()
async def on_error(self, interaction: discord.Interaction, error) -> None:
try:
if isinstance(error, CommandInvokeError):
raise error.original
else:
raise error
except SafeCancellation:
# Assume this has already been handled
pass
except Exception:
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
if interaction.type is not InteractionType.autocomplete:
embed = self.bugsplat(interaction, error)
await self.error_reply(interaction, embed)
async def error_reply(self, interaction, embed):
if not interaction.is_expired():
try:
if interaction.response.is_done():
await interaction.followup.send(embed=embed, ephemeral=True)
else:
await interaction.response.send_message(embed=embed, ephemeral=True)
except discord.HTTPException:
pass
def bugsplat(self, interaction, e):
error_embed = discord.Embed(title="Something went wrong!", colour=discord.Colour.red())
error_embed.description = (
"An unexpected error occurred during this interaction!\n"
"Our development team has been notified, and the issue will be addressed soon.\n"
"If the error persists, or you have any questions, please contact our [support team]({link}) "
"and give them the extra details below."
).format(link=interaction.client.config.bot.support_guild)
details = {}
details['error'] = f"`{repr(e)}`"
details['interactionid'] = f"`{interaction.id}`"
details['interactiontype'] = f"`{interaction.type}`"
if interaction.command:
details['cmd'] = f"`{interaction.command.qualified_name}`"
if interaction.user:
details['user'] = f"`{interaction.user.id}` -- `{interaction.user}`"
if interaction.guild:
details['guild'] = f"`{interaction.guild.id}` -- `{interaction.guild.name}`"
details['my_guild_perms'] = f"`{interaction.guild.me.guild_permissions.value}`"
if interaction.user:
ownerstr = ' (owner)' if interaction.user.id == interaction.guild.owner_id else ''
details['user_guild_perms'] = f"`{interaction.user.guild_permissions.value}{ownerstr}`"
if interaction.channel.type is discord.enums.ChannelType.private:
details['channel'] = "`Direct Message`"
elif interaction.channel:
details['channel'] = f"`{interaction.channel.id}` -- `{interaction.channel.name}`"
details['my_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.guild.me).value}`"
if interaction.user:
details['user_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.user).value}`"
details['shard'] = f"`{interaction.client.shardname}`"
details['log_stack'] = f"`{log_action_stack.get()}`"
table = '\n'.join(tabulate(*details.items()))
error_embed.add_field(name='Details', value=table)
return error_embed
def _from_interaction(self, interaction: Interaction) -> None:
@log_wrap(context=f"iid: {interaction.id}", isolate=False)
async def wrapper():
try:
await self._call(interaction)
except AppCommandError as e:
await self._dispatch_error(interaction, e)
task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
self._call_tasks.add(task)
task.add_done_callback(lambda fut: self._call_tasks.discard(fut))
async def _call(self, interaction):
if not await self.interaction_check(interaction):
interaction.command_failed = True
return
data = interaction.data # type: ignore
type = data.get('type', 1)
if type != 1:
# Context menu command...
await self._call_context_menu(interaction, data, type)
return
command, options = self._get_app_command_options(data)
# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command
# At this point options refers to the arguments of the command
# and command refers to the class type we care about
namespace = Namespace(interaction, data.get('resolved', {}), options)
# Same pre-fill as above
interaction._cs_namespace = namespace
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
if interaction.type is InteractionType.autocomplete:
set_logging_context(action=f"Acmp {command.qualified_name}")
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
if focused is None:
raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.'
)
try:
await command._invoke_autocomplete(interaction, focused, namespace)
except Exception as e:
await self.on_error(interaction, e)
return
set_logging_context(action=f"Run {command.qualified_name}")
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:
interaction.command_failed = True
await command._invoke_error_handlers(interaction, e)
await self.on_error(interaction, e)
else:
if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command)
finally:
if interaction.command_failed:
logger.debug("Command completed with errors.")
else:
logger.debug("Command completed without errors.")

15
src/meta/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
from .LionBot import LionBot
from .LionCog import LionCog
from .LionContext import LionContext
from .LionTree import LionTree
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
from .config import conf, configEmoji
from .args import args
from .app import appname, appname_from_shard, shard_from_appname
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
from .context import context, ctx_bot
from . import sharding
from . import logger
from . import app

32
src/meta/app.py Normal file
View File

@@ -0,0 +1,32 @@
"""
appname: str
The base identifer for this application.
This identifies which services the app offers.
shardname: str
The specific name of the running application.
Only one process should be connecteded with a given appname.
For the bot apps, usually specifies the shard id and shard number.
"""
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
from . import sharding, conf
from .logger import log_app
from .args import args
appname = conf.data['appid']
appid = appname # backwards compatibility
def appname_from_shard(shardid):
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
return appname
def shard_from_appname(appname: str):
return int(appname.rsplit('_', maxsplit=1)[-1])
shardname = appname_from_shard(sharding.shard_number)
log_app.set(shardname)

35
src/meta/args.py Normal file
View File

@@ -0,0 +1,35 @@
import argparse
from constants import CONFIG_FILE
# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
'--conf',
dest='config',
default=CONFIG_FILE,
help="Path to configuration file."
)
parser.add_argument(
'--shard',
dest='shard',
default=None,
type=int,
help="Shard number to run, if applicable."
)
parser.add_argument(
'--host',
dest='host',
default='127.0.0.1',
help="IP address to run the app listener on."
)
parser.add_argument(
'--port',
dest='port',
default='5001',
help="Port to run the app listener on."
)
args = parser.parse_args()

146
src/meta/config.py Normal file
View File

@@ -0,0 +1,146 @@
from discord import PartialEmoji
import configparser as cfgp
from .args import args
shard_number = args.shard
class configEmoji(PartialEmoji):
__slots__ = ('fallback',)
def __init__(self, *args, fallback=None, **kwargs):
super().__init__(*args, **kwargs)
self.fallback = fallback
@classmethod
def from_str(cls, emojistr: str):
"""
Parses emoji strings of one of the following forms
`<a:name:id> or fallback`
`<:name:id> or fallback`
`<a:name:id>`
`<:name:id>`
"""
splits = emojistr.rsplit(' or ', maxsplit=1)
fallback = splits[1] if len(splits) > 1 else None
emojistr = splits[0].strip('<> ')
animated, name, id = emojistr.split(':')
return cls(
name=name,
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
animated=bool(animated),
id=int(id) if id else None
)
class MapDotProxy:
"""
Allows dot access to an underlying Mappable object.
"""
__slots__ = ("_map", "_converter")
def __init__(self, mappable, converter=None):
self._map = mappable
self._converter = converter
def __getattribute__(self, key):
_map = object.__getattribute__(self, '_map')
if key == '_map':
return _map
if key in _map:
_converter = object.__getattribute__(self, '_converter')
if _converter:
return _converter(_map[key])
else:
return _map[key]
else:
return object.__getattribute__(_map, key)
def __getitem__(self, key):
return self._map.__getitem__(key)
class ConfigParser(cfgp.ConfigParser):
"""
Extension of base ConfigParser allowing optional
section option retrieval without defaults.
"""
def options(self, section, no_defaults=False, **kwargs):
if no_defaults:
try:
return list(self._sections[section].keys())
except KeyError:
raise cfgp.NoSectionError(section)
else:
return super().options(section, **kwargs)
class Conf:
def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile
self.config = ConfigParser(
converters={
"intlist": self._getintlist,
"list": self._getlist,
"emoji": configEmoji.from_str,
}
)
with open(configfile) as conff:
# Opening with read_file mainly to ensure the file exists
self.config.read_file(conff)
self.section_name = section_name if section_name in self.config else 'DEFAULT'
self.default = self.config["DEFAULT"]
self.section = MapDotProxy(self.config[self.section_name])
self.bot = self.section
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
more_to_read = self.section.getlist("ALSO_READ", [])
read = set()
while more_to_read:
to_read = more_to_read.pop(0)
read.add(to_read)
self.config.read(to_read)
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
if path not in read and path not in more_to_read]
more_to_read.extend(new_paths)
self.emojis = MapDotProxy(
self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section,
converter=configEmoji.from_str
)
global conf
conf = self
def __getitem__(self, key):
return self.section[key].strip()
def __getattr__(self, section):
name = section.upper()
shard_name = f"{name}-{shard_number}"
if shard_name in self.config:
return self.config[shard_name]
else:
return self.config[name]
def get(self, name, fallback=None):
result = self.section.get(name, fallback)
return result.strip() if result else result
def _getintlist(self, value):
return [int(item.strip()) for item in value.split(',')]
def _getlist(self, value):
return [item.strip() for item in value.split(',')]
def write(self):
with open(self.configfile, 'w') as conffile:
self.config.write(conffile)
conf = Conf(args.config, 'BOT')

20
src/meta/context.py Normal file
View File

@@ -0,0 +1,20 @@
"""
Namespace for various global context variables.
Allows asyncio callbacks to accurately retrieve information about the current state.
"""
from typing import TYPE_CHECKING, Optional
from contextvars import ContextVar
if TYPE_CHECKING:
from .LionBot import LionBot
from .LionContext import LionContext
# Contains the current command context, if applicable
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
# Contains the current LionBot instance
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)

64
src/meta/errors.py Normal file
View File

@@ -0,0 +1,64 @@
from typing import Optional
from string import Template
class SafeCancellation(Exception):
"""
Raised to safely cancel execution of the current operation.
If not caught, is expected to be propagated to the Tree and safely ignored there.
If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
The error handler should then set the `msg` to None, to avoid double handling.
Debugging information should go in `details`, to be logged by a top-level error handler.
"""
default_message = ""
@property
def msg(self):
return self._msg if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
self._msg: Optional[str] = _msg
self.details: str = details if details is not None else self.msg
super().__init__(**kwargs)
class UserInputError(SafeCancellation):
"""
A SafeCancellation induced from unparseable user input.
"""
default_message = "Could not understand your input."
@property
def msg(self):
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
self.info = info
super().__init__(_msg, **kwargs)
class UserCancelled(SafeCancellation):
"""
A SafeCancellation induced from manual user cancellation.
Usually silent.
"""
default_msg = None
class ResponseTimedOut(SafeCancellation):
"""
A SafeCancellation induced from a user interaction time-out.
"""
default_msg = "Session timed out waiting for input."
class HandledException(SafeCancellation):
"""
Sentinel class to indicate to error handlers that this exception has been handled.
Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
"""
def __init__(self, exc=None, **kwargs):
self.exc = exc
super().__init__(**kwargs)

468
src/meta/logger.py Normal file
View File

@@ -0,0 +1,468 @@
import sys
import logging
import asyncio
from typing import List, Optional
from logging.handlers import QueueListener, QueueHandler
import queue
import multiprocessing
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar
import discord
from discord import Webhook, File
import aiohttp
from .config import conf
from . import sharding
from .context import context
from utils.lib import utc_now
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
log_logger = logging.getLogger(__name__)
log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
def set_logging_context(
context: Optional[str] = None,
action: Optional[str] = None,
stack: Optional[tuple[str, ...]] = None
):
"""
Statically set the logging context variables to the given values.
If `action` is given, pushes it onto the `log_action_stack`.
"""
if context is not None:
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
@contextmanager
def logging_context(context=None, action=None, stack=None):
"""
Context manager for executing a block of code in a given logging context.
This context manager should only be used around synchronous code.
This is because async code *may* get cancelled or externally garbage collected,
in which case the finally block will be executed in the wrong context.
See https://github.com/python/cpython/issues/93740
This can be refactored nicely if this gets merged:
https://github.com/python/cpython/pull/99634
(It will not necessarily break on async code,
if the async code can be guaranteed to clean up in its own context.)
"""
if context is not None:
oldcontext = log_context.get()
log_context.set(context)
if action is not None or stack is not None:
astack = log_action_stack.get()
newstack = stack if stack is not None else astack
if action is not None:
newstack = (*newstack, action)
log_action_stack.set(newstack)
try:
yield
finally:
if context is not None:
log_context.set(oldcontext)
if stack is not None or action is not None:
log_action_stack.set(astack)
def with_log_ctx(isolate=True, **kwargs):
"""
Execute a coroutine inside a given logging context.
If `isolate` is true, ensures that context does not leak
outside the coroutine.
If `isolate` is false, just statically set the context,
which will leak unless the coroutine is
called in an externally copied context.
"""
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
if isolate:
with logging_context(**kwargs):
# Task creation will synchronously copy the context
# This is gc safe
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
return await task
else:
# This will leak context changes
set_logging_context(**kwargs)
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
# For backwards compatibility
log_wrap = with_log_ctx
def persist_task(task_collection: set):
"""
Coroutine decorator that ensures the coroutine is scheduled as a task
and added to the given task_collection for strong reference
when it is called.
This is just a hack to handle discord.py events potentially
being unexpectedly garbage collected.
Since this also implicitly schedules the coroutine as a task when it is called,
the coroutine will also be run inside an isolated context.
"""
def decorator(coro):
@wraps(coro)
async def wrapped(*w_args, **w_kwargs):
name = f"persisted-{coro.__name__}"
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
task_collection.add(task)
task.add_done_callback(lambda f: task_collection.discard(f))
await task
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str:
cmap = {
'%(black)': COLOR_SEQ % BLACK,
'%(red)': COLOR_SEQ % RED,
'%(green)': COLOR_SEQ % GREEN,
'%(yellow)': COLOR_SEQ % YELLOW,
'%(blue)': COLOR_SEQ % BLUE,
'%(magenta)': COLOR_SEQ % MAGENTA,
'%(cyan)': COLOR_SEQ % CYAN,
'%(white)': COLOR_SEQ % WHITE,
'%(reset)': RESET_SEQ,
'%(bold)': BOLD_SEQ,
}
for key, value in cmap.items():
fmt = fmt.replace(key, value)
return fmt
log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
'%(cyan)%(app)-15s%(reset)|' +
'%(cyan)%(context)-24s%(reset)|' +
'%(cyan)%(actionstr)-22s%(reset)|' +
' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format)
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=log_format,
# datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
class ExactLevelFilter(logging.Filter):
def __init__(self, target_level, name=""):
super().__init__(name)
self.target_level = target_level
def filter(self, record):
return (record.levelno == self.target_level)
class ThreadFilter(logging.Filter):
def __init__(self, thread_name):
super().__init__("")
self.thread = thread_name
def filter(self, record):
# non-zero return means we log this message
return 1 if record.threadName == self.thread else 0
class ContextInjection(logging.Filter):
def filter(self, record):
# These guards are to allow override through _extra
# And to ensure the injection is idempotent
if not hasattr(record, 'context'):
record.context = log_context.get()
if not hasattr(record, 'actionstr'):
action_stack = log_action_stack.get()
if hasattr(record, 'action'):
action_stack = (*action_stack, record.action)
if action_stack:
record.actionstr = ''.join(action_stack)
else:
record.actionstr = "Unknown Action"
if not hasattr(record, 'app'):
record.app = log_app.get()
if not hasattr(record, 'ctx'):
if ctx := context.get():
record.ctx = repr(ctx)
else:
record.ctx = None
if getattr(record, 'with_ctx', False) and record.ctx:
record.ctxstr = '\n' + record.ctx
else:
record.ctxstr = ""
return True
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)
class LocalQueueHandler(QueueHandler):
def _emit(self, record: logging.LogRecord) -> None:
# Removed the call to self.prepare(), handle task cancellation
try:
self.enqueue(record)
except asyncio.CancelledError:
raise
except Exception:
self.handleError(record)
class WebHookHandler(logging.StreamHandler):
def __init__(self, webhook_url, prefix="", batch=True, loop=None):
super().__init__()
self.webhook_url = webhook_url
self.prefix = prefix
self.batched = ""
self.batch = batch
self.loop = loop
self.batch_delay = 10
self.batch_task = None
self.last_batched = None
self.waiting = []
self.bucket = Bucket(20, 40)
self.ignored = 0
self.session = None
self.webhook = None
def get_loop(self):
if self.loop is None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
return self.loop
def emit(self, record):
self.format(record)
self.get_loop().call_soon_threadsafe(self._post, record)
def _post(self, record):
if self.session is None:
self.setup()
asyncio.create_task(self.post(record))
def setup(self):
self.session = aiohttp.ClientSession()
self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
async def post(self, record):
if record.context == 'Webhook Logger':
# Don't livelog livelog errors
# Otherwise we recurse and Cloudflare hates us
return
log_context.set("Webhook Logger")
log_action_stack.set(("Logging",))
log_app.set(record.app)
try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
context = f"\n# Context: {record.ctx}" if record.ctx else ""
message = f"{header}\n{record.msg}{context}"
if len(message) > 1900:
as_file = True
else:
as_file = False
message = "```md\n{}\n```".format(message)
# Post the log message(s)
if self.batch:
if len(message) > 1500:
await self._send_batched_now()
await self._send(message, as_file=as_file)
else:
self.batched += message
if len(self.batched) + len(message) > 1500:
await self._send_batched_now()
else:
asyncio.create_task(self._schedule_batched())
else:
await self._send(message, as_file=as_file)
except Exception as ex:
print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)
async def _schedule_batched(self):
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
# noop, don't reschedule if it is already scheduled
return
try:
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
await self.batch_task
await self._send_batched()
except asyncio.CancelledError:
return
except Exception as ex:
print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)
async def _send_batched_now(self):
if self.batch_task is not None and not self.batch_task.done():
self.batch_task.cancel()
self.last_batched = None
await self._send_batched()
async def _send_batched(self):
if self.batched:
batched = self.batched
self.batched = ""
await self._send(batched)
async def _send(self, message, as_file=False):
try:
self.bucket.request()
except BucketOverFull:
# Silently ignore
self.ignored += 1
return
except BucketFull:
logger.warning(
"Can't keep up! "
f"Ignoring records on live-logger {self.webhook.id}."
)
self.ignored += 1
return
else:
if self.ignored > 0:
logger.warning(
"Can't keep up! "
f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
)
self.ignored = 0
try:
if as_file or len(message) > 1900:
with StringIO(message) as fp:
fp.seek(0)
await self.webhook.send(
f"{self.prefix}\n`{message.splitlines()[0]}`",
file=File(fp, filename="logs.md"),
username=log_app.get()
)
else:
await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
except discord.HTTPException:
logger.exception(
"Live logger errored. Slowing down live logger."
)
self.bucket.fill()
handlers = []
if webhook := conf.logging['general_log']:
handler = WebHookHandler(webhook, batch=True)
handlers.append(handler)
if webhook := conf.logging['warning_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True)
handler.addFilter(ExactLevelFilter(logging.WARNING))
handler.setLevel(logging.WARNING)
handlers.append(handler)
if webhook := conf.logging['error_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True)
handler.setLevel(logging.ERROR)
handlers.append(handler)
if webhook := conf.logging['critical_log']:
handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False)
handler.setLevel(logging.CRITICAL)
handlers.append(handler)
def make_queue_handler(queue):
qhandler = QueueHandler(queue)
qhandler.setLevel(logging.INFO)
qhandler.addFilter(ContextInjection())
return qhandler
def setup_main_logger(multiprocess=False):
q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
if handlers:
# First create a separate loop to run the handlers on
import threading
def run_loop(loop):
asyncio.set_event_loop(loop)
try:
loop.run_forever()
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=lambda: run_loop(loop))
loop_thread.daemon = True
loop_thread.start()
for handler in handlers:
handler.loop = loop
qhandler = make_queue_handler(q)
# qhandler.addFilter(ThreadFilter('MainThread'))
logger.addHandler(qhandler)
listener = QueueListener(
q, *handlers, respect_handler_level=True
)
listener.start()
return q

139
src/meta/monitor.py Normal file
View File

@@ -0,0 +1,139 @@
import logging
import asyncio
from enum import IntEnum
from collections import deque, ChainMap
import datetime as dt
logger = logging.getLogger(__name__)
class StatusLevel(IntEnum):
ERRORED = -2
UNSURE = -1
WAITING = 0
STARTING = 1
OKAY = 2
@property
def symbol(self):
return symbols[self]
symbols = {
StatusLevel.ERRORED: '🟥',
StatusLevel.UNSURE: '🟧',
StatusLevel.WAITING: '',
StatusLevel.STARTING: '🟫',
StatusLevel.OKAY: '🟩',
}
class ComponentStatus:
def __init__(self, level: StatusLevel, short_formatstr: str, long_formatstr: str, data: dict = {}):
self.level = level
self.short_formatstr = short_formatstr
self.long_formatstr = long_formatstr
self.data = data
self.created_at = dt.datetime.now(tz=dt.timezone.utc)
def format_args(self):
extra = {
'created_at': self.created_at,
'level': self.level,
'symbol': self.level.symbol,
}
return ChainMap(extra, self.data)
@property
def short(self):
return self.short_formatstr.format(**self.format_args())
@property
def long(self):
return self.long_formatstr.format(**self.format_args())
class ComponentMonitor:
_name = None
def __init__(self, name=None, callback=None):
self._callback = callback
self.name = name or self._name
if not self.name:
raise ValueError("ComponentMonitor must have a name")
async def _make_status(self, *args, **kwargs):
if self._callback is not None:
return await self._callback(*args, **kwargs)
else:
raise NotImplementedError
async def status(self) -> ComponentStatus:
try:
status = await self._make_status()
except Exception as e:
logger.exception(
f"Status callback for component '{self.name}' failed. This should not happen."
)
status = ComponentStatus(
level=StatusLevel.UNSURE,
short_formatstr="Status callback for '{name}' failed with error '{error}'",
long_formatstr="Status callback for '{name}' failed with error '{error}'",
data={
'name': self.name,
'error': repr(e)
}
)
return status
class SystemMonitor:
def __init__(self):
self.components = {}
self.recent = deque(maxlen=10)
def add_component(self, component: ComponentMonitor):
self.components[component.name] = component
return component
async def request(self):
"""
Request status from each component.
"""
tasks = {
name: asyncio.create_task(comp.status())
for name, comp in self.components.items()
}
await asyncio.gather(*tasks.values())
status = {
name: await fut for name, fut in tasks.items()
}
self.recent.append(status)
return status
async def _format_summary(self, status_dict: dict[str, ComponentStatus]):
"""
Format a one line summary from a status dict.
"""
freq = {level: 0 for level in StatusLevel}
for status in status_dict.values():
freq[status.level] += 1
summary = '\t'.join(f"{level.symbol} {count}" for level, count in freq.items() if count)
return summary
async def _format_overview(self, status_dict: dict[str, ComponentStatus]):
"""
Format an overview (one line per component) from a status dict.
"""
lines = []
for name, status in status_dict.items():
lines.append(f"{status.level.symbol} {name}: {status.short}")
summary = await self._format_summary(status_dict)
return '\n'.join((summary, *lines))
async def get_summary(self):
return await self._format_summary(await self.request())
async def get_overview(self):
return await self._format_overview(await self.request())

35
src/meta/sharding.py Normal file
View File

@@ -0,0 +1,35 @@
from .args import args
from .config import conf
from psycopg import sql
from data.conditions import Condition, Joiner
shard_number = args.shard or 0
shard_count = conf.bot.getint('shard_count', 1)
sharded = (shard_count > 0)
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
"""
Condition constructor for filtering by shard id.
Example Usage
-------------
Query.where(_shard_condition('guildid', 10, 1))
"""
return Condition(
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
guildid=sql.Identifier(guild_column),
shard_count=sql.Literal(shard_count)
),
Joiner.EQUALS,
sql.Placeholder(),
(shard_id,)
)
# Pre-built Condition for filtering by current shard.
THIS_SHARD = SHARDID(shard_number)