rewrite: Restructure to include GUI.
This commit is contained in:
200
src/meta/LionBot.py
Normal file
200
src/meta/LionBot.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord.utils import MISSING
|
||||
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
||||
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
||||
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from data import Database
|
||||
|
||||
from .config import Conf
|
||||
from .logger import logging_context, log_context, log_action_stack
|
||||
from .context import context
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
from .errors import HandledException, SafeCancellation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core import CoreCog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LionBot(Bot):
    """Discord Bot subclass wiring in the Lion infrastructure.

    Carries the database, configuration, app IPC client and translator,
    installs LionTree as the default command tree, and centralises
    structured logging and command error handling.
    """

    def __init__(
        self, *args, appname: str, shardname: str, db: Database, config: Conf,
        initial_extensions: List[str], web_client: ClientSession, app_ipc,
        testing_guilds: Optional[List[int]] = None, translator=None, **kwargs
    ):
        kwargs.setdefault('tree_cls', LionTree)
        super().__init__(*args, **kwargs)
        self.web_client = web_client
        # BUGFIX: was a mutable default argument ([]) shared across instances.
        self.testing_guilds = testing_guilds if testing_guilds is not None else []
        self.initial_extensions = initial_extensions
        self.db = db
        self.appname = appname
        self.shardname = shardname
        # self.appdata = appdata
        self.config = config
        self.app_ipc = app_ipc
        self.core: Optional['CoreCog'] = None
        self.translator = translator

    async def setup_hook(self) -> None:
        """One-time async setup.

        Connects the app IPC client, installs the translator (if any),
        loads the initial extensions, and copies/syncs the global command
        tree into each configured testing guild.
        """
        log_context.set(f"APP: {self.application_id}")
        await self.app_ipc.connect()

        if self.translator is not None:
            await self.tree.set_translator(self.translator)

        for extension in self.initial_extensions:
            await self.load_extension(extension)

        for guildid in self.testing_guilds:
            guild = discord.Object(guildid)
            self.tree.copy_global_to(guild=guild)
            await self.tree.sync(guild=guild)

    async def add_cog(self, cog: Cog, **kwargs):
        # Wrap cog attachment in a logging context for traceability.
        with logging_context(action=f"Attach {cog.__cog_name__}"):
            logger.info(f"Attaching Cog {cog.__cog_name__}")
            await super().add_cog(cog, **kwargs)
            logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")

    async def load_extension(self, name, *, package=None, **kwargs):
        # As add_cog, but for extension modules.
        with logging_context(action=f"Load {name.strip('.')}"):
            logger.info(f"Loading extension {name} in package {package}.")
            await super().load_extension(name, package=package, **kwargs)
            logger.debug(f"Loaded extension {name} in package {package}.")

    async def start(self, token: str, *, reconnect: bool = True):
        # Split login and the gateway connection into separate logging contexts.
        with logging_context(action="Login"):
            await self.login(token)
        with logging_context(stack=["Running"]):
            await self.connect(reconnect=reconnect)

    async def on_ready(self):
        logger.info(
            f"Logged in as {self.application.name}\n"
            f"Application id {self.application.id}\n"
            f"Shard Talk identifier {self.shardname}\n"
            "------------------------------\n"
            f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
            f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
            f"Registered Data: {', '.join(self.db.registries.keys())}\n"
            f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
            "------------------------------\n"
            f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
            "Ready to take commands!\n",
            extra={'action': 'Ready'}
        )

    async def get_context(self, origin, /, *, cls=MISSING):
        """Build a LionContext by default and publish it to the `context` contextvar."""
        if cls is MISSING:
            cls = LionContext
        ctx = await super().get_context(origin, cls=cls)
        context.set(ctx)
        return ctx

    async def on_command(self, ctx: LionContext):
        logger.info(
            f"Executing command '{ctx.command.qualified_name}' "
            f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
            f"with arguments {ctx.args} and kwargs {ctx.kwargs}.",
            extra={'with_ctx': True}
        )

    async def on_command_error(self, ctx, exception):
        """Centralised command error handler.

        Re-raises the given exception to dispatch on its type, unwraps
        (possibly doubly-wrapped) invoke errors, gives user feedback where
        appropriate, and marks handled originals with HandledException so
        upstream handlers do not double-report.
        """
        # TODO: Some of these could have more user-feedback
        cmd_str = str(ctx.command)
        if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
            cmd_str = ctx.command.app_command.to_dict()
        try:
            raise exception
        except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
            try:
                if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
                    # Hybrid commands can wrap the true error twice; unwrap both layers.
                    original = exception.original.original
                    raise original
                else:
                    original = exception.original
                    raise original
            except HandledException:
                # Already handled lower down; nothing more to do.
                pass
            except SafeCancellation:
                if original.msg:
                    try:
                        await ctx.error_reply(original.msg)
                    except Exception:
                        pass
                logger.debug(
                    f"Caught a safe cancellation: {original.details}",
                    extra={'action': 'BotError', 'with_ctx': True}
                )
            except discord.Forbidden:
                # Unknown uncaught Forbidden
                try:
                    # Attempt a general error reply
                    await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
                except Exception:
                    # We can't send anything at all. Exit quietly, but log.
                    logger.warning(
                        f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
                        exc_info=True,
                        extra={'action': 'BotError', 'with_ctx': True}
                    )
            except discord.HTTPException:
                logger.warning(
                    f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
                    exc_info=True,
                    extra={'action': 'BotError', 'with_ctx': True}
                )
            except asyncio.CancelledError:
                pass
            except asyncio.TimeoutError:
                pass
            except Exception:
                logger.exception(
                    f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
                    extra={'action': 'BotError', 'with_ctx': True}
                )

                error_embed = discord.Embed(title="Something went wrong!")
                # BUGFIX: ctx.interaction is None for prefix invocations; fall back
                # to the message id so this handler cannot itself raise here.
                ref_id = ctx.interaction.id if ctx.interaction is not None else ctx.message.id
                error_embed.description = (
                    "An unexpected error occurred while processing your command!\n"
                    "Our development team has been notified, and the issue should be fixed soon.\n"
                    "If the error persists, please contact our support team and give them the following number: "
                    f"`{ref_id}`"
                )

                try:
                    await ctx.error_reply(embed=error_embed)
                except Exception:
                    pass
            finally:
                # Mark as handled so upstream (e.g. the tree) ignores it.
                exception.original = HandledException(exception.original)
        except CheckFailure:
            logger.debug(
                f"Command failed check: {exception}",
                extra={'action': 'BotError', 'with_ctx': True}
            )
            try:
                # BUGFIX: was `ctx.error_rely(exception.message)` — a misspelled
                # method and a `.message` attribute CheckFailure does not define;
                # both errors were swallowed below, so no feedback was ever sent.
                await ctx.error_reply(str(exception))
            except Exception:
                pass
        except Exception:
            # Completely unknown exception outside of command invocation!
            # Something is very wrong here, don't attempt user interaction.
            logger.exception(
                f"Caught an unknown top-level exception while executing: {cmd_str}",
                extra={'action': 'BotError', 'with_ctx': True}
            )

    def add_command(self, command):
        # Placeholder groups (see LionCog) must not be registered directly.
        if hasattr(command, '_placeholder_group_'):
            return
        super().add_command(command)
|
||||
58
src/meta/LionCog.py
Normal file
58
src/meta/LionCog.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import Any
|
||||
|
||||
from discord.ext.commands import Cog
|
||||
from discord.ext import commands as cmds
|
||||
|
||||
|
||||
class LionCog(Cog):
    """Cog subclass supporting dependency checks and placeholder hybrid groups.

    Placeholder groups let one cog declare a group skeleton whose commands are
    crossloaded into another cog's real group at load time.
    """
    # A set of other cogs that this cog depends on.
    # NOTE(review): compared by name via bot.get_cog() below, so members are
    # effectively cog-name strings despite the 'LionCog' annotation — confirm.
    depends_on: set['LionCog'] = set()
    # Names of hybrid groups marked as placeholders on this class
    _placeholder_groups_: set[str]

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        cls._placeholder_groups_ = set()

        # Collect placeholder group names from the whole MRO so subclasses inherit them.
        for base in reversed(cls.__mro__):
            for elem, value in base.__dict__.items():
                if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
                    cls._placeholder_groups_.add(value.name)

    def __new__(cls, *args: Any, **kwargs: Any):
        # Patch to ensure no placeholder groups are in the command list
        self = super().__new__(cls)
        self.__cog_commands__ = [
            command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_
        ]
        return self

    async def _inject(self, bot, *args, **kwargs):
        """Hook into cog injection to verify declared dependencies are loaded.

        Raises ValueError if any dependency is missing.
        """
        if self.depends_on:
            not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
            # BUGFIX: previously raised whenever depends_on was non-empty,
            # even when every dependency was present.
            if not_found:
                raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")

        # BUGFIX: was `*kwargs`, which unpacked only the dict's keys as
        # positional arguments instead of forwarding the keyword arguments.
        return await super()._inject(bot, *args, **kwargs)

    @classmethod
    def placeholder_group(cls, group: cmds.HybridGroup):
        """Decorator marking a hybrid group as a placeholder (not registered directly)."""
        group._placeholder_group_ = True
        return group

    def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup):
        """
        Crossload a placeholder group's commands into the target group
        """
        if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup):
            # BUGFIX: corrected garbled error message ("my be HypridGroups").
            raise ValueError("Placeholder and target groups must be HybridGroups.")
        if placeholder_group.name not in self._placeholder_groups_:
            raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.")
        if target_group.app_command is None:
            raise ValueError("Target group has no app_command to crossload into.")

        for command in placeholder_group.commands:
            placeholder_group.remove_command(command.name)
            target_group.remove_command(command.name)
            # Rebind the app command under the target group, bound to this cog.
            acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self)
            command.app_command = acmd
            target_group.add_command(command)
|
||||
190
src/meta/LionContext.py
Normal file
190
src/meta/LionContext.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import types
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
from discord.ext.commands import Context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .LionBot import LionBot
|
||||
from core.lion import Lion
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
"""
|
||||
Stuff that might be useful to implement (see cmdClient):
|
||||
sent_messages cache
|
||||
tasks cache
|
||||
error reply
|
||||
usage
|
||||
interaction cache
|
||||
View cache?
|
||||
setting access
|
||||
"""
|
||||
|
||||
|
||||
# Flat, pure-data snapshot of a command context (ids and primitives only),
# suitable for caching and logging.
# BUGFIX: LionContext.flatten passes 8 positional values (including the
# channel id), but only 7 fields were declared, making every flatten() call
# raise TypeError. Added the missing 'channel' field in position.
FlatContext = namedtuple(
    'FlatContext',
    ('message',
     'interaction',
     'guild',
     'author',
     'channel',
     'alias',
     'prefix',
     'failed')
)
|
||||
|
||||
|
||||
class LionContext(Context['LionBot']):
    """
    Represents the context a command is invoked under.

    Extends Context to add Lion-specific methods and attributes.
    Also adds several contextual wrapped utilities for simpler user during command invocation.
    """
    # Lion (member-data wrapper) for this invocation.
    # NOTE(review): not assigned anywhere in this class — confirm it is set by
    # an external check/hook before commands rely on it.
    alion: 'Lion'

    def __repr__(self):
        # Compact diagnostic form: include only the attributes that are set.
        parts = {}
        if self.interaction is not None:
            parts['iid'] = self.interaction.id
            parts['itype'] = f"\"{self.interaction.type.name}\""
        if self.message is not None:
            parts['mid'] = self.message.id
        if self.author is not None:
            parts['uid'] = self.author.id
            parts['uname'] = f"\"{self.author.name}\""
        if self.channel is not None:
            parts['cid'] = self.channel.id
            parts['cname'] = f"\"{self.channel.name}\""
        if self.guild is not None:
            parts['gid'] = self.guild.id
            parts['gname'] = f"\"{self.guild.name}\""
        if self.command is not None:
            parts['cmd'] = f"\"{self.command.qualified_name}\""
        if self.invoked_with is not None:
            parts['alias'] = f"\"{self.invoked_with}\""
        if self.command_failed:
            parts['failed'] = self.command_failed

        return "<LionContext: {}>".format(
            ' '.join(f"{name}={value}" for name, value in parts.items())
        )

    def flatten(self):
        """Flat pure-data context information, for caching and logging."""
        # NOTE(review): verify the number of values passed here matches the
        # fields declared on FlatContext.
        return FlatContext(
            self.message.id,
            self.interaction.id if self.interaction is not None else None,
            self.guild.id if self.guild is not None else None,
            self.author.id if self.author is not None else None,
            self.channel.id if self.channel is not None else None,
            self.invoked_with,
            self.prefix,
            self.command_failed
        )

    @classmethod
    def util(cls, util_func):
        """
        Decorator to make a utility function available as a Context instance method.
        """
        setattr(cls, util_func.__name__, util_func)
        logger.debug(f"Attached context utility function: {util_func.__name__}")
        return util_func

    @classmethod
    def wrappable_util(cls, util_func):
        """
        Decorator to add a Wrappable utility function as a Context instance method.
        """
        wrapped = Wrappable(util_func)
        setattr(cls, util_func.__name__, wrapped)
        logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
        return wrapped

    async def error_reply(self, content: Optional[str] = None, **kwargs):
        # Wrap plain text content in a red embed, unless an embed was supplied.
        if content and 'embed' not in kwargs:
            embed = discord.Embed(
                colour=discord.Colour.red(),
                description=content
            )
            kwargs['embed'] = embed
            content = None

        # Expect this may be run in highly unusual circumstances.
        # This should never error, or at least handle all errors.
        try:
            await self.reply(content=content, **kwargs)
        except discord.HTTPException:
            pass
        except Exception:
            logger.exception(
                "Unknown exception in 'error_reply'.",
                extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
            )
|
||||
|
||||
|
||||
class Wrappable:
    """Callable wrapper supporting a removable, ordered stack of middleware.

    Each registered wrapper receives the next callable in the chain as its
    first argument, followed by the original call arguments. Also acts as a
    descriptor so it can replace methods on a class.
    """
    __slots__ = ('_func', 'wrappers')

    def __init__(self, func):
        self._func = func
        self.wrappers = None

    @property
    def __name__(self):
        return self._func.__name__

    def add_wrapper(self, func, name=None):
        # Lazily create the wrapper registry on first use.
        if self.wrappers is None:
            self.wrappers = {}
        key = name if name is not None else func.__name__
        self.wrappers[key] = func
        logger.debug(
            f"Added wrapper '{key}' to Wrappable '{self._func.__name__}'.",
            extra={'action': "Wrap Util"}
        )

    def remove_wrapper(self, name):
        registry = self.wrappers
        if not registry or name not in registry:
            raise ValueError(
                f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
            )
        registry.pop(name)
        logger.debug(
            f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
            extra={'action': "Unwrap Util"}
        )

    def __call__(self, *args, **kwargs):
        if not self.wrappers:
            return self._func(*args, **kwargs)
        return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)

    def _wrapped(self, iter_wraps):
        # Recursively build the chain: each wrapper is handed the remainder
        # of the chain as its first argument; the base case is the raw func.
        current = next(iter_wraps, None)
        if current is None:
            return self._func

        def _chained(*args, **kwargs):
            return current(self._wrapped(iter_wraps), *args, **kwargs)
        return _chained

    def __get__(self, instance, cls=None):
        # Descriptor protocol: bind like an ordinary method on attribute access.
        if instance is None:
            return self
        return types.MethodType(self, instance)
|
||||
|
||||
|
||||
# Make Context.reply wrappable, so extensions can install middleware
# around message sending (see Wrappable.add_wrapper).
LionContext.reply = Wrappable(LionContext.reply)
|
||||
|
||||
|
||||
# @LionContext.reply.add_wrapper
|
||||
# async def think(func, ctx, *args, **kwargs):
|
||||
# await ctx.channel.send("thinking")
|
||||
# await func(ctx, *args, **kwargs)
|
||||
79
src/meta/LionTree.py
Normal file
79
src/meta/LionTree.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
|
||||
from discord import Interaction
|
||||
from discord.app_commands import CommandTree
|
||||
from discord.app_commands.errors import AppCommandError, CommandInvokeError
|
||||
from discord.enums import InteractionType
|
||||
from discord.app_commands.namespace import Namespace
|
||||
|
||||
from .logger import logging_context
|
||||
from .errors import SafeCancellation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LionTree(CommandTree):
    """CommandTree subclass adding logging contexts and safe error handling."""

    async def on_error(self, interaction, error) -> None:
        # Re-raise to dispatch on the unwrapped exception type.
        try:
            if isinstance(error, CommandInvokeError):
                raise error.original
            else:
                raise error
        except SafeCancellation:
            # Assume this has already been handled
            pass
        except Exception:
            logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})

    async def _call(self, interaction):
        # Re-implementation of CommandTree._call adding logging contexts around
        # each stage. NOTE(review): relies on discord.py private internals
        # (_get_app_command_options, _cs_command, _invoke_with_namespace, ...);
        # verify against the pinned discord.py version when upgrading.
        with logging_context(context=f"iid: {interaction.id}"):
            if not await self.interaction_check(interaction):
                interaction.command_failed = True
                return

            data = interaction.data  # type: ignore
            type = data.get('type', 1)
            if type != 1:
                # Context menu command...
                await self._call_context_menu(interaction, data, type)
                return

            command, options = self._get_app_command_options(data)

            # Pre-fill the cached slot to prevent re-computation
            interaction._cs_command = command

            # At this point options refers to the arguments of the command
            # and command refers to the class type we care about
            namespace = Namespace(interaction, data.get('resolved', {}), options)

            # Same pre-fill as above
            interaction._cs_namespace = namespace

            # Auto complete handles the namespace differently... so at this point this is where we decide where that is.
            if interaction.type is InteractionType.autocomplete:
                with logging_context(action=f"Acmp {command.qualified_name}"):
                    focused = next((opt['name'] for opt in options if opt.get('focused')), None)
                    if focused is None:
                        raise AppCommandError(
                            'This should not happen, but there is no focused element. This is a Discord bug.'
                        )
                    await command._invoke_autocomplete(interaction, focused, namespace)
                return

            with logging_context(action=f"Run {command.qualified_name}"):
                logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
                try:
                    await command._invoke_with_namespace(interaction, namespace)
                except AppCommandError as e:
                    interaction.command_failed = True
                    # Give command-local handlers first shot, then the tree handler.
                    await command._invoke_error_handlers(interaction, e)
                    await self.on_error(interaction, e)
                else:
                    if not interaction.command_failed:
                        self.client.dispatch('app_command_completion', interaction, command)
                finally:
                    if interaction.command_failed:
                        logger.debug("Command completed with errors.")
                    else:
                        logger.debug("Command completed without errors.")
|
||||
15
src/meta/__init__.py
Normal file
15
src/meta/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .LionBot import LionBot
|
||||
from .LionCog import LionCog
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
|
||||
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
|
||||
from .config import conf, configEmoji
|
||||
from .args import args
|
||||
from .app import appname, shard_talk, appname_from_shard, shard_from_appname
|
||||
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
|
||||
from .context import context, ctx_bot
|
||||
|
||||
from . import sharding
|
||||
from . import logger
|
||||
from . import app
|
||||
46
src/meta/app.py
Normal file
46
src/meta/app.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
appname: str
|
||||
The base identifier for this application.
|
||||
This identifies which services the app offers.
|
||||
shardname: str
|
||||
The specific name of the running application.
|
||||
Only one process should be connected with a given appname.
|
||||
For the bot apps, usually specifies the shard id and shard number.
|
||||
"""
|
||||
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
|
||||
|
||||
from . import sharding, conf
|
||||
from .logger import log_app
|
||||
from .ipc.client import AppClient
|
||||
from .args import args
|
||||
|
||||
|
||||
appname = conf.data['appid']
|
||||
appid = appname # backwards compatibility
|
||||
|
||||
|
||||
def appname_from_shard(shardid):
    """Build the canonical appname for a shard: `<appid>_<count>_<shard>` (zero-padded)."""
    base = conf.data['appid']
    return f"{base}_{sharding.shard_count:02}_{shardid:02}"
|
||||
|
||||
|
||||
def shard_from_appname(appname: str) -> int:
    """Extract the shard number from a sharded appname (the final '_'-separated field)."""
    *_, tail = appname.rsplit('_', maxsplit=1)
    return int(tail)
|
||||
|
||||
|
||||
# Unique identifier for this running process (appname + shard numbers).
shardname = appname_from_shard(sharding.shard_number)

# Tag all subsequent log records with this app's identity.
log_app.set(shardname)


# Inter-app IPC client: listens on the CLI-provided address and registers
# with the central registry server from configuration.
shard_talk = AppClient(
    shardname,
    appname,
    {'host': args.host, 'port': args.port},
    {'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
)


@shard_talk.register_route()
async def ping():
    # Trivial liveness route for peers to probe.
    return "Pong!"
|
||||
35
src/meta/args.py
Normal file
35
src/meta/args.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import argparse
|
||||
|
||||
from constants import CONFIG_FILE
|
||||
|
||||
# ------------------------------
|
||||
# Parsed commandline arguments
|
||||
# ------------------------------
|
||||
parser = argparse.ArgumentParser()
parser.add_argument(
    '--conf',
    dest='config',
    default=CONFIG_FILE,
    help="Path to configuration file."
)
parser.add_argument(
    '--shard',
    dest='shard',
    default=None,
    type=int,
    help="Shard number to run, if applicable."
)
parser.add_argument(
    '--host',
    dest='host',
    default='127.0.0.1',
    help="IP address to run the app listener on."
)
parser.add_argument(
    '--port',
    dest='port',
    # NOTE(review): default is a string and no type=int is given, so
    # args.port is a str — confirm consumers expect that.
    default='5001',
    help="Port to run the app listener on."
)

# Parsed at import time: importing this module consumes sys.argv.
args = parser.parse_args()
|
||||
137
src/meta/config.py
Normal file
137
src/meta/config.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from discord import PartialEmoji
|
||||
import configparser as cfgp
|
||||
|
||||
from .args import args
|
||||
|
||||
|
||||
class configEmoji(PartialEmoji):
    """PartialEmoji carrying an optional unicode fallback parsed from config."""
    __slots__ = ('fallback',)

    def __init__(self, *args, fallback=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.fallback = fallback

    @classmethod
    def from_str(cls, emojistr: str):
        """
        Parses emoji strings of one of the following forms
            `<a:name:id> or fallback`
            `<:name:id> or fallback`
            `<a:name:id>`
            `<:name:id>`
        """
        # Split off the optional fallback on the rightmost ' or '.
        splits = emojistr.rsplit(' or ', maxsplit=1)
        fallback = None if len(splits) == 1 else splits[1]

        # Strip the surrounding <> and whitespace, then split the three fields.
        core = splits[0].strip('<> ')
        animated, name, id = core.split(':')

        fallback_emoji = PartialEmoji(name=fallback) if fallback is not None else None
        return cls(
            name=name,
            fallback=fallback_emoji,
            animated=bool(animated),
            id=int(id) if id else None
        )
|
||||
|
||||
|
||||
class MapDotProxy:
    """
    Allows dot access to an underlying Mappable object.

    Keys present in the mapping win over the mapping's own attributes;
    unknown names fall through to the mapping itself.
    """
    __slots__ = ("_map", "_converter")

    def __init__(self, mappable, converter=None):
        self._map = mappable
        self._converter = converter

    def __getattribute__(self, key):
        underlying = object.__getattribute__(self, '_map')
        if key == '_map':
            return underlying
        if key not in underlying:
            # Not a mapping key: delegate to the mapping's own attributes.
            return object.__getattribute__(underlying, key)
        convert = object.__getattribute__(self, '_converter')
        value = underlying[key]
        return convert(value) if convert else value

    def __getitem__(self, key):
        return self._map.__getitem__(key)
|
||||
|
||||
|
||||
class ConfigParser(cfgp.ConfigParser):
    """
    Extension of the stdlib ConfigParser allowing section options to be
    listed without the DEFAULT-section fallbacks.
    """
    def options(self, section, no_defaults=False, **kwargs):
        if not no_defaults:
            return super().options(section, **kwargs)
        try:
            section_map = self._sections[section]
        except KeyError:
            raise cfgp.NoSectionError(section)
        return list(section_map.keys())
|
||||
|
||||
|
||||
class Conf:
    """INI-backed configuration with dot access, custom converters, and
    recursive file inclusion via "ALSO_READ" keys.

    Instantiating also publishes the instance as the module-level `conf`
    singleton.
    """
    def __init__(self, configfile, section_name="DEFAULT"):
        self.configfile = configfile

        # Parser exposing getintlist/getlist/getemoji converter methods.
        self.config = ConfigParser(
            converters={
                "intlist": self._getintlist,
                "list": self._getlist,
                "emoji": configEmoji.from_str,
            }
        )
        self.config.read(configfile)

        # Fall back to DEFAULT if the requested section does not exist.
        self.section_name = section_name if section_name in self.config else 'DEFAULT'

        self.default = self.config["DEFAULT"]
        self.section = MapDotProxy(self.config[self.section_name])
        self.bot = self.section

        # Config file recursion, read in configuration files specified in every "ALSO_READ" key.
        more_to_read = self.section.getlist("ALSO_READ", [])
        read = set()
        while more_to_read:
            to_read = more_to_read.pop(0)
            read.add(to_read)
            self.config.read(to_read)
            # Newly-read files may themselves list more files; queue unseen ones.
            new_paths = [path for path in self.section.getlist("ALSO_READ", [])
                         if path not in read and path not in more_to_read]
            more_to_read.extend(new_paths)

        # Emoji lookups prefer a dedicated EMOJIS section when present.
        self.emojis = MapDotProxy(
            self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section,
            converter=configEmoji.from_str
        )

        # Publish this instance as the module-level singleton.
        global conf
        conf = self

    def __getitem__(self, key):
        # Stripped string access on the primary section.
        return self.section[key].strip()

    def __getattr__(self, section):
        # Attribute access resolves to the upper-cased parser section.
        return self.config[section.upper()]

    def get(self, name, fallback=None):
        result = self.section.get(name, fallback)
        return result.strip() if result else result

    def _getintlist(self, value):
        # Converter: "1, 2,3" -> [1, 2, 3]
        return [int(item.strip()) for item in value.split(',')]

    def _getlist(self, value):
        # Converter: comma-separated string -> list of stripped strings
        return [item.strip() for item in value.split(',')]

    def write(self):
        # Persist the current configuration back to the original file.
        with open(self.configfile, 'w') as conffile:
            self.config.write(conffile)
|
||||
|
||||
|
||||
# Global configuration singleton, loaded from the CLI-specified config file.
conf = Conf(args.config, 'STUDYLION')
|
||||
20
src/meta/context.py
Normal file
20
src/meta/context.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Namespace for various global context variables.
|
||||
Allows asyncio callbacks to accurately retrieve information about the current state.
|
||||
"""
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .LionBot import LionBot
|
||||
from .LionContext import LionContext
|
||||
|
||||
|
||||
# Contains the current command context, if applicable
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)

# Contains the current LionBot instance, set at application startup
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)
|
||||
64
src/meta/errors.py
Normal file
64
src/meta/errors.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Optional
|
||||
from string import Template
|
||||
|
||||
|
||||
class SafeCancellation(Exception):
    """
    Raised to safely cancel execution of the current operation.

    If not caught, is expected to be propagated to the Tree and safely ignored there.
    If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
    The error handler should then set the `msg` to None, to avoid double handling.
    Debugging information should go in `details`, to be logged by a top-level error handler.
    """
    # Fallback user-facing message when none is given at construction.
    default_message = ""

    @property
    def msg(self):
        if self._msg is None:
            return self.default_message
        return self._msg

    def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
        self._msg: Optional[str] = _msg
        # Debugging detail defaults to the user-facing message.
        self.details: str = self.msg if details is None else details
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class UserInputError(SafeCancellation):
    """
    A SafeCancellation induced from unparseable user input.
    """
    default_message = "Could not understand your input."

    @property
    def msg(self):
        # Substitute `info` into the message template when one was provided.
        if self._msg is None:
            return self.default_message
        return Template(self._msg).substitute(**self.info)

    def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
        self.info = info
        super().__init__(_msg, **kwargs)
|
||||
|
||||
|
||||
class UserCancelled(SafeCancellation):
    """
    A SafeCancellation induced from manual user cancellation.

    Usually silent.
    """
    # BUGFIX: was `default_msg`, an attribute name the base class `msg`
    # property never reads (it reads `default_message`), so the override
    # had no effect.
    default_message = None
|
||||
|
||||
|
||||
class ResponseTimedOut(SafeCancellation):
    """
    A SafeCancellation induced from a user interaction time-out.
    """
    # BUGFIX: was `default_msg`, which the base class `msg` property never
    # reads (it reads `default_message`) — the timeout message was never
    # actually shown to users.
    default_message = "Session timed out waiting for input."
|
||||
|
||||
|
||||
class HandledException(SafeCancellation):
    """
    Sentinel class to indicate to error handlers that this exception has been handled.
    Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
    """
    def __init__(self, exc=None, **kwargs):
        # The original exception that was handled, kept for logging/inspection.
        self.exc = exc
        super().__init__(**kwargs)
|
||||
2
src/meta/ipc/__init__.py
Normal file
2
src/meta/ipc/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .client import AppClient, AppPayload, AppRoute
|
||||
from .server import AppServer
|
||||
236
src/meta/ipc/client.py
Normal file
236
src/meta/ipc/client.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from typing import Optional, TypeAlias, Any
|
||||
import asyncio
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
from ..logger import logging_context
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Address: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
class AppClient:
|
||||
routes: dict[str, 'AppRoute'] = {} # route_name -> Callable[Any, Awaitable[Any]]
|
||||
|
||||
    def __init__(self, appid: str, basename: str, client_address: Address, server_address: Address):
        """IPC client identified by `appid`, listening on `client_address` and
        registering with the registry server at `server_address`."""
        self.appid = appid  # String identifier for this ShardTalk client
        self.basename = basename  # Prefix used to recognise app peers
        self.address = client_address
        self.server_address = server_address

        self.peers = {appid: client_address}  # appid -> address

        self._listener: Optional[asyncio.Server] = None  # Local client server
        self._server = None  # Connection to the registry server
        self._keepalive = None  # Task watching the registry connection

        # Built-in peer-management routes.
        # NOTE(review): `routes` is a class-level dict, so these registrations
        # are shared by every AppClient instance — confirm that is intended.
        self.register_route('new_peer')(self.new_peer)
        self.register_route('drop_peer')(self.drop_peer)
        self.register_route('peer_list')(self.peer_list)
|
||||
|
||||
@property
|
||||
def my_peers(self):
|
||||
return {peerid: peer for peerid, peer in self.peers.items() if peerid.startswith(self.basename)}
|
||||
|
||||
def register_route(self, name=None):
|
||||
def wrapper(coro):
|
||||
route = AppRoute(coro, client=self, name=name)
|
||||
self.routes[route.name] = route
|
||||
return route
|
||||
return wrapper
|
||||
|
||||
async def server_connection(self):
|
||||
"""Establish a connection to the registry server"""
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(**self.server_address)
|
||||
|
||||
payload = ('connect', (), {'appid': self.appid, 'address': self.address})
|
||||
writer.write(pickle.dumps(payload))
|
||||
writer.write(b'\n')
|
||||
await writer.drain()
|
||||
|
||||
data = await reader.readline()
|
||||
peers = pickle.loads(data)
|
||||
self.peers = peers
|
||||
self._server = (reader, writer)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Could not connect to registry server. Trying again in 30 seconds.",
|
||||
extra={'action': 'Connect'}
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
asyncio.create_task(self.server_connection())
|
||||
else:
|
||||
logger.debug(
|
||||
"Connected to the registry server, launching keepalive.",
|
||||
extra={'action': 'Connect'}
|
||||
)
|
||||
self._keepalive = asyncio.create_task(self._server_keepalive())
|
||||
|
||||
async def _server_keepalive(self):
|
||||
with logging_context(action='Keepalive'):
|
||||
if self._server is None:
|
||||
raise ValueError("Cannot keepalive non-existent server!")
|
||||
reader, write = self._server
|
||||
try:
|
||||
await reader.read()
|
||||
except Exception:
|
||||
logger.exception("Lost connection to address server. Reconnecting...")
|
||||
else:
|
||||
# Connection ended or broke
|
||||
logger.info("Lost connection to address server. Reconnecting...")
|
||||
await asyncio.sleep(30)
|
||||
asyncio.create_task(self.server_connection())
|
||||
|
||||
async def new_peer(self, appid, address):
|
||||
self.peers[appid] = address
|
||||
|
||||
async def peer_list(self, peers):
|
||||
self.peers = peers
|
||||
|
||||
async def drop_peer(self, appid):
|
||||
self.peers.pop(appid, None)
|
||||
|
||||
async def close(self):
|
||||
# Close connection to the server
|
||||
# TODO
|
||||
...
|
||||
|
||||
async def request(self, appid, payload: 'AppPayload', wait_for_reply=True):
|
||||
with logging_context(action=f"Req {appid}"):
|
||||
try:
|
||||
if appid not in self.peers:
|
||||
raise ValueError(f"Peer '{appid}' not found.")
|
||||
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
|
||||
|
||||
address = self.peers[appid]
|
||||
reader, writer = await asyncio.open_connection(**address)
|
||||
|
||||
writer.write(payload.encoded())
|
||||
await writer.drain()
|
||||
writer.write_eof()
|
||||
if wait_for_reply:
|
||||
result = await reader.read()
|
||||
writer.close()
|
||||
decoded = payload.route.decode(result)
|
||||
return decoded
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logging.exception(f"Failed to send request to {appid}'")
|
||||
return None
|
||||
|
||||
async def requestall(self, payload, except_self=True, only_my_peers=True):
|
||||
with logging_context(action="Broadcast"):
|
||||
peerlist = self.my_peers if only_my_peers else self.peers
|
||||
results = await asyncio.gather(
|
||||
*(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)),
|
||||
return_exceptions=True
|
||||
)
|
||||
return dict(zip(self.peers.keys(), results))
|
||||
|
||||
async def handle_request(self, reader, writer):
|
||||
with logging_context(action="SERV"):
|
||||
data = await reader.read()
|
||||
loaded = pickle.loads(data)
|
||||
route, args, kwargs = loaded
|
||||
|
||||
with logging_context(action=f"SERV {route}"):
|
||||
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
|
||||
|
||||
if route in self.routes:
|
||||
try:
|
||||
await self.routes[route].run((reader, writer), args, kwargs)
|
||||
except Exception:
|
||||
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
|
||||
else:
|
||||
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
|
||||
writer.write_eof()
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
Start the local peer server.
|
||||
Connect to the address server.
|
||||
"""
|
||||
with logging_context(stack=['ShardTalk']):
|
||||
# Start the client server
|
||||
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
|
||||
|
||||
logger.info(f"Serving on {self.address}")
|
||||
await self.server_connection()
|
||||
|
||||
|
||||
class AppPayload:
    """
    A bound route invocation: an AppRoute together with its call arguments.

    May be awaited to execute the route locally, or shipped to one or all
    peers via `send`/`broadcast`.
    """

    __slots__ = ('route', 'args', 'kwargs')

    def __init__(self, route, *args, **kwargs):
        self.route = route
        self.args = args
        self.kwargs = kwargs

    def __await__(self):
        # Awaiting the payload executes the route locally.
        coro = self.route.execute(*self.args, **self.kwargs)
        return coro.__await__()

    def encoded(self):
        """Serialise as a (route_name, args, kwargs) pickle for the wire."""
        wire_form = (self.route.name, self.args, self.kwargs)
        return pickle.dumps(wire_form)

    async def send(self, appid, **kwargs):
        """Send this payload to the single peer `appid`."""
        client = self.route._client
        return await client.request(appid, self, **kwargs)

    async def broadcast(self, **kwargs):
        """Send this payload to every known peer."""
        client = self.route._client
        return await client.requestall(self, **kwargs)
|
||||
|
||||
|
||||
class AppRoute:
    """
    A named coroutine endpoint, executable locally or over the IPC wire.

    Calling the route binds arguments into an AppPayload; `run` executes it
    against an open connection and writes the encoded result back.
    """

    __slots__ = ('func', 'name', '_client', '_encoder', '_decoder')

    def __init__(self, func, client=None, name=None):
        self.func = func
        self.name = name or func.__name__
        self._client = client
        # Optional custom wire (de)serialisers, installed via encoder()/decoder().
        self._encoder = None
        self._decoder = None

    def __call__(self, *args, **kwargs):
        """Bind call arguments, producing an awaitable/sendable AppPayload."""
        return AppPayload(self, *args, **kwargs)

    def encode(self, output):
        """Serialise a route result for the wire (default: pickle)."""
        if self._encoder is not None:
            return self._encoder(output)
        return pickle.dumps(output)

    def decode(self, encoded):
        """Deserialise a wire result (default: unpickle; b'' -> '')."""
        # TODO: Handle exceptions here somehow
        if self._decoder is not None:
            return self._decoder(encoded)
        if len(encoded) > 0:
            return pickle.loads(encoded)
        else:
            # Empty response (e.g. remote error) decodes to the empty string.
            return ''

    def encoder(self, func):
        """
        Decorator: install a custom result encoder.

        (Previously this assigned `self.encode`, which raised AttributeError
        because the class declares __slots__ without that name; the override
        now lives in a dedicated slot, and the function is returned so plain
        decorator usage behaves normally.)
        """
        self._encoder = func
        return func

    def decoder(self, func):
        """Decorator: install a custom result decoder (see `encoder`)."""
        self._decoder = func
        return func

    async def execute(self, *args, **kwargs):
        """
        Execute the underlying function, with the given arguments.
        """
        return await self.func(*args, **kwargs)

    async def run(self, connection, args, kwargs):
        """
        Run the route, with the given arguments, using the given connection.

        The encoded result (or b'' if execution failed) is written back to
        the peer and the connection is closed.
        """
        # TODO: ContextVar here for logging? Or in handle_request?
        # TODO: handle exceptions in the execution process
        try:
            result = await self.execute(*args, **kwargs)
            payload = self.encode(result)
        except Exception:
            logger.exception(f"Exception occured running route '{self.name}' with args: {args} and kwargs: {kwargs}")
            payload = b''
        _, writer = connection
        writer.write(payload)
        await writer.drain()
        writer.close()
|
||||
175
src/meta/ipc/server.py
Normal file
175
src/meta/ipc/server.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import asyncio
|
||||
import pickle
|
||||
import logging
|
||||
import string
|
||||
import random
|
||||
|
||||
from ..logger import log_context, log_app, logging_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Alphabet used for generated request ids: lowercase letters and digits.
uuid_alphabet = string.ascii_lowercase + string.digits


def short_uuid():
    """Return a short random identifier: 10 lowercase-alphanumeric characters."""
    chars = random.choices(uuid_alphabet, k=10)
    return ''.join(chars)
|
||||
|
||||
|
||||
class AppServer:
    """
    Central registry server for the app IPC mesh.

    Tracks connected AppClients, hands out the peer address map, and relays
    peer join/leave announcements to every connected client.
    """

    # route name -> bound method
    # NOTE(review): class-level dict shared by all instances, and __init__
    # registers *bound* methods into it — with multiple AppServer instances
    # the last one constructed would win. Confirm single-instance use.
    routes = {}

    def __init__(self):
        self.clients = {}  # AppID -> (info, connection)

        # Register the built-in routes against this instance.
        self.route('ping')(self.route_ping)
        self.route('whereis')(self.route_whereis)
        self.route('peers')(self.route_peers)
        self.route('connect')(self.client_connection)

    @classmethod
    def route(cls, route_name):
        """
        Decorator to add a route to the server.
        """
        def wrapper(coro):
            cls.routes[route_name] = coro
            return coro
        return wrapper

    async def route_ping(self, connection):
        """
        Pong.
        """
        reader, writer = connection
        writer.write(b"Pong")
        writer.write_eof()

    async def route_whereis(self, connection, appid):
        """
        Return an address for the given client appid.
        Returns None if the client does not have a connection.
        """
        reader, writer = connection
        if appid in self.clients:
            # clients[appid] is (address, connection); send the address only.
            writer.write(pickle.dumps(self.clients[appid][0]))
        else:
            # Empty body signals "unknown appid" to the requester.
            writer.write(b'')
        writer.write_eof()

    async def route_peers(self, connection):
        """
        Send back a map of current peers.
        """
        reader, writer = connection
        peers = self.peer_list()
        # NOTE(review): this is a 2-tuple, while broadcast()/message_client()
        # send (route, args, kwargs) 3-tuples and AppClient.handle_request
        # unpacks three elements — confirm the intended wire format here.
        payload = pickle.dumps(('peer_list', (peers,)))
        writer.write(payload)
        writer.write_eof()

    async def client_connection(self, connection, appid, address):
        """
        Register and hold a new client connection.

        The connection is kept open for the lifetime of the client; EOF or an
        error on it is treated as client death and triggers deregistration.
        """
        with logging_context(action=f"CONN {appid}"):
            reader, writer = connection
            # Add the new client
            self.clients[appid] = (address, connection)

            # Send the new client a client list
            peers = self.peer_list()
            writer.write(pickle.dumps(peers))
            writer.write(b'\n')
            await writer.drain()

            # Announce the new client to everyone
            await self.broadcast('new_peer', (), {'appid': appid, 'address': address})

            # Keep the connection open until socket closed or EOF (indicating client death)
            try:
                await reader.read()
            finally:
                # Connection ended or it broke
                logger.info(f"Lost client '{appid}'")
                await self.deregister_client(appid)

    async def handle_connection(self, reader, writer):
        """Dispatch one incoming request line to its registered route."""
        data = await reader.readline()
        route, args, kwargs = pickle.loads(data)

        # Short request id, used only to correlate log lines.
        rqid = short_uuid()

        with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"):
            logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}")

            if route in self.routes:
                # Execute route
                try:
                    await self.routes[route]((reader, writer), *args, **kwargs)
                except Exception:
                    logger.exception(f"AppServer recieved exception during route '{route}'")
            else:
                logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.")

    def peer_list(self):
        """Return the current appid -> address map of connected clients."""
        return {appid: address for appid, (address, _) in self.clients.items()}

    async def deregister_client(self, appid):
        """Forget a client and tell the remaining peers it is gone."""
        self.clients.pop(appid, None)
        await self.broadcast('drop_peer', (), {'appid': appid})

    async def broadcast(self, route, args, kwargs):
        """Send (route, args, kwargs) to every connected client, best-effort."""
        with logging_context(action="broadcast"):
            logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.")
            payload = pickle.dumps((route, args, kwargs))
            if self.clients:
                # return_exceptions=True: one unreachable client must not
                # abort the broadcast to the others.
                await asyncio.gather(
                    *(self._send(appid, payload) for appid in self.clients),
                    return_exceptions=True
                )

    async def message_client(self, appid, route, args, kwargs):
        """
        Send a message to client `appid` along `route` with given arguments.

        Raises ValueError if the client is not currently connected.
        """
        with logging_context(action=f"MSG {appid}"):
            logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.")
            if appid not in self.clients:
                raise ValueError(f"Client '{appid}' is not connected.")

            payload = pickle.dumps((route, args, kwargs))
            return await self._send(appid, payload)

    async def _send(self, appid, payload):
        """
        Send the encoded `payload` to the client `appid`.

        Opens a fresh connection to the client's listener for each message.
        """
        address, _ = self.clients[appid]
        try:
            reader, writer = await asyncio.open_connection(**address)
            writer.write(payload)
            writer.write_eof()
            await writer.drain()
            writer.close()
        except Exception as ex:
            # TODO: Close client if we can't connect?
            logger.exception(f"Failed to send message to '{appid}'")
            raise ex

    async def start(self, address):
        """Serve the registry forever on `address` (host/port kwargs)."""
        log_app.set("APPSERVER")
        with logging_context(stack=["SERV"]):
            server = await asyncio.start_server(self.handle_connection, **address)
            logger.info(f"Serving on {address}")
            async with server:
                await server.serve_forever()
|
||||
|
||||
|
||||
async def start_server():
    """Entrypoint: run a standalone registry AppServer on localhost:5000."""
    bind = {'host': '127.0.0.1', 'port': '5000'}
    registry = AppServer()
    await registry.start(bind)


if __name__ == '__main__':
    asyncio.run(start_server())
|
||||
324
src/meta/logger.py
Normal file
324
src/meta/logger.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List
|
||||
from logging.handlers import QueueListener, QueueHandler
|
||||
from queue import SimpleQueue
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from functools import wraps
|
||||
from contextvars import ContextVar
|
||||
|
||||
from discord import Webhook, File
|
||||
import aiohttp
|
||||
|
||||
from .config import conf
|
||||
from . import sharding
|
||||
from .context import context
|
||||
from utils.lib import utc_now
|
||||
|
||||
|
||||
log_logger = logging.getLogger(__name__)
|
||||
log_logger.propagate = False
|
||||
|
||||
|
||||
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
|
||||
log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[])
|
||||
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
|
||||
|
||||
|
||||
@contextmanager
def logging_context(context=None, action=None, stack=None):
    """
    Temporarily override the logging contextvars for the enclosed block.

    Parameters
    ----------
    context: Optional[str]
        Replacement value for the `log_context` var.
    action: Optional[str]
        Pushed onto the end of the current action stack.
    stack: Optional[List[str]]
        Complete replacement for the action stack. Applied after `action`,
        so if both are given the pushed action is immediately overwritten
        while `stack` is in effect.

    All overrides are undone on exit, in reverse order of application.
    """
    if context is not None:
        context_t = log_context.set(context)
    if action is not None:
        # Remember the pre-push stack so it can be restored on exit.
        astack = log_action_stack.get()
        log_action_stack.set(astack + [action])
    if stack is not None:
        actions_t = log_action_stack.set(stack)
    try:
        yield
    finally:
        if context is not None:
            log_context.reset(context_t)
        if stack is not None:
            log_action_stack.reset(actions_t)
        if action is not None:
            # Restore the stack as it was before the action was pushed.
            log_action_stack.set(astack)
|
||||
|
||||
|
||||
def log_wrap(**kwargs):
    """
    Decorator factory: run the wrapped coroutine inside `logging_context(**kwargs)`.

    The keyword arguments are passed straight through to `logging_context`
    on every invocation of the decorated coroutine.
    """
    def decorator(func):
        @wraps(func)
        async def wrapped(*call_args, **call_kwargs):
            with logging_context(**kwargs):
                return await func(*call_args, **call_kwargs)
        return wrapped
    return decorator
|
||||
|
||||
|
||||
# ANSI escape sequences used to build the coloured console log format.
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)


def colour_escape(fmt: str) -> str:
    """
    Replace %(colour) placeholders in a logging format string with ANSI codes.

    Supported placeholders are the eight standard colours (%(black) ...
    %(white)) plus %(reset) and %(bold). Unknown placeholders are left as-is.
    """
    cmap = {
        '%(black)': COLOR_SEQ % BLACK,
        '%(red)': COLOR_SEQ % RED,
        '%(green)': COLOR_SEQ % GREEN,
        '%(yellow)': COLOR_SEQ % YELLOW,
        '%(blue)': COLOR_SEQ % BLUE,
        '%(magenta)': COLOR_SEQ % MAGENTA,
        '%(cyan)': COLOR_SEQ % CYAN,
        '%(white)': COLOR_SEQ % WHITE,
        '%(reset)': RESET_SEQ,
        '%(bold)': BOLD_SEQ,
    }
    for key, value in cmap.items():
        fmt = fmt.replace(key, value)
    return fmt
|
||||
|
||||
|
||||
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
|
||||
'[%(cyan)%(app)-15s%(reset)]' +
|
||||
'[%(cyan)%(context)-24s%(reset)]' +
|
||||
'[%(cyan)%(actionstr)-22s%(reset)]' +
|
||||
' %(bold)%(cyan)%(name)s:%(reset)' +
|
||||
' %(white)%(message)s%(ctxstr)s%(reset)')
|
||||
log_format = colour_escape(log_format)
|
||||
|
||||
|
||||
# Setup the logger
|
||||
logger = logging.getLogger()
|
||||
log_fmt = logging.Formatter(
|
||||
fmt=log_format,
|
||||
# datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class LessThanFilter(logging.Filter):
    """Logging filter that passes only records strictly below a given level."""

    def __init__(self, exclusive_maximum, name=""):
        super().__init__(name)
        # Records at or above this level are rejected.
        self.max_level = exclusive_maximum

    def filter(self, record):
        # non-zero return means we log this message
        return int(record.levelno < self.max_level)
|
||||
|
||||
|
||||
class ThreadFilter(logging.Filter):
    """Logging filter that passes only records emitted from a named thread."""

    def __init__(self, thread_name):
        super().__init__("")
        # Only records whose threadName matches are kept.
        self.thread = thread_name

    def filter(self, record):
        # non-zero return means we log this message
        return int(record.threadName == self.thread)
|
||||
|
||||
|
||||
class ContextInjection(logging.Filter):
    """
    Filter that injects the current contextvar state (context string, action
    stack, app name, command context) into each LogRecord, so formatters can
    reference %(context)s, %(actionstr)s, %(app)s and %(ctxstr)s.
    """

    def filter(self, record):
        # These guards are to allow override through _extra
        # And to ensure the injection is idempotent
        if not hasattr(record, 'context'):
            record.context = log_context.get()

        if not hasattr(record, 'actionstr'):
            action_stack = log_action_stack.get()
            if hasattr(record, 'action'):
                # A per-record 'action' extra extends the ambient stack.
                action_stack = (*action_stack, record.action)
            if action_stack:
                record.actionstr = ' ➔ '.join(action_stack)
            else:
                record.actionstr = "Unknown Action"

        if not hasattr(record, 'app'):
            record.app = log_app.get()

        if not hasattr(record, 'ctx'):
            # `context` here is the command-context contextvar, distinct from
            # the `log_context` string above.
            if ctx := context.get():
                record.ctx = repr(ctx)
            else:
                record.ctx = None

        # The ctx repr is only appended to the message when explicitly
        # requested via the 'with_ctx' extra.
        if getattr(record, 'with_ctx', False) and record.ctx:
            record.ctxstr = '\n' + record.ctx
        else:
            record.ctxstr = ""
        return True
|
||||
|
||||
|
||||
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||
logging_handler_out.setLevel(logging.DEBUG)
|
||||
logging_handler_out.setFormatter(log_fmt)
|
||||
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
|
||||
logging_handler_out.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_out)
|
||||
log_logger.addHandler(logging_handler_out)
|
||||
|
||||
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||
logging_handler_err.setLevel(logging.WARNING)
|
||||
logging_handler_err.setFormatter(log_fmt)
|
||||
logging_handler_err.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_err)
|
||||
log_logger.addHandler(logging_handler_err)
|
||||
|
||||
|
||||
class LocalQueueHandler(QueueHandler):
    """
    QueueHandler variant that enqueues records as-is, without prepare().

    NOTE(review): this defines `_emit`, not `emit` — QueueHandler.emit will
    not call it, so as written this override appears unused; confirm intent.
    """

    def _emit(self, record: logging.LogRecord) -> None:
        # Removed the call to self.prepare(), handle task cancellation
        try:
            self.enqueue(record)
        except asyncio.CancelledError:
            # Propagate cancellation rather than swallowing it below.
            raise
        except Exception:
            self.handleError(record)
|
||||
|
||||
|
||||
class WebHookHandler(logging.StreamHandler):
    """
    Logging handler that posts records to a Discord webhook.

    Posting happens on a dedicated asyncio loop (assigned externally or
    created lazily), so `emit` is safe to call from any thread. With
    `batch=True`, small messages accumulate and are flushed either when the
    batch grows large enough or after `batch_delay` seconds.
    """

    def __init__(self, webhook_url, batch=False, loop=None):
        super().__init__()
        self.webhook_url = webhook_url
        self.batched = ""        # Accumulated batched message text
        self.batch = batch       # Whether to batch small messages
        self.loop = loop         # Event loop used for posting
        self.batch_delay = 10    # Seconds before a pending batch is flushed
        self.batch_task = None   # Delayed-flush sleep task, if scheduled
        self.last_batched = None
        self.waiting = []        # NOTE(review): appears unused in this class

    def get_loop(self):
        """Return the posting loop, creating one lazily if unset."""
        if self.loop is None:
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
        return self.loop

    def emit(self, record):
        # Hop onto the posting loop; emit() may run on any thread.
        self.get_loop().call_soon_threadsafe(self._post, record)

    def _post(self, record):
        # Runs on the posting loop; fire-and-forget the actual send.
        asyncio.create_task(self.post(record))

    async def post(self, record):
        """Format `record` and deliver it to the webhook (possibly batched)."""
        # Re-seed the logging contextvars: we are now on the logging loop.
        log_context.set("Webhook Logger")
        log_action_stack.set(["Logging"])
        log_app.set(record.app)

        try:
            # NOTE(review): `timestamp` is computed but never used below.
            timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
            header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
            # Local name shadows the module-level `context` import here.
            context = f"\n# Context: {record.ctx}" if record.ctx else ""
            message = f"{header}\n{record.msg}{context}"

            # Long messages are uploaded as a file attachment; short ones are
            # wrapped in an md code block.
            if len(message) > 1900:
                as_file = True
            else:
                as_file = False
                message = "```md\n{}\n```".format(message)

            # Post the log message(s)
            if self.batch:
                if len(message) > 1500:
                    # Too big to batch: flush what we have, then send alone.
                    await self._send_batched_now()
                    await self._send(message, as_file=as_file)
                else:
                    self.batched += message
                    if len(self.batched) + len(message) > 1500:
                        # Batch is full enough; flush immediately.
                        await self._send_batched_now()
                    else:
                        # Otherwise ensure a delayed flush is scheduled.
                        asyncio.create_task(self._schedule_batched())
            else:
                await self._send(message, as_file=as_file)
        except Exception as ex:
            # Last resort: the logging pipeline itself must never raise.
            print(ex)

    async def _schedule_batched(self):
        """Flush the batch after `batch_delay` seconds, unless already scheduled."""
        if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
            # noop, don't reschedule if it is already scheduled
            return
        try:
            self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
            await self.batch_task
            await self._send_batched()
        except asyncio.CancelledError:
            # _send_batched_now() flushed early and cancelled us; nothing to do.
            return
        except Exception as ex:
            print(ex)

    async def _send_batched_now(self):
        """Cancel any delayed flush and send the current batch immediately."""
        if self.batch_task is not None and not self.batch_task.done():
            self.batch_task.cancel()
        self.last_batched = None
        await self._send_batched()

    async def _send_batched(self):
        # Swap the buffer out before awaiting, so concurrent posts append
        # to a fresh batch.
        if self.batched:
            batched = self.batched
            self.batched = ""
            await self._send(batched)

    async def _send(self, message, as_file=False):
        """Deliver `message` to the webhook, as plain text or a file upload."""
        async with aiohttp.ClientSession() as session:
            webhook = Webhook.from_url(self.webhook_url, session=session)
            if as_file or len(message) > 2000:
                with StringIO(message) as fp:
                    fp.seek(0)
                    await webhook.send(
                        file=File(fp, filename="logs.md"),
                        username=log_app.get()
                    )
            else:
                await webhook.send(message, username=log_app.get())
|
||||
|
||||
|
||||
# Construct a webhook handler for each configured logging channel.
handlers = []
if webhook := conf.logging['general_log']:
    handler = WebHookHandler(webhook, batch=True)
    handlers.append(handler)

if webhook := conf.logging['error_log']:
    handler = WebHookHandler(webhook, batch=False)
    handler.setLevel(logging.ERROR)
    handlers.append(handler)

if webhook := conf.logging['critical_log']:
    handler = WebHookHandler(webhook, batch=False)
    handler.setLevel(logging.CRITICAL)
    handlers.append(handler)

if handlers:
    # First create a separate loop to run the handlers on
    import threading

    def run_loop(loop):
        # Thread target: run the posting loop forever, then clean up
        # async generators and close the loop on shutdown.
        asyncio.set_event_loop(loop)
        try:
            loop.run_forever()
        finally:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()

    loop = asyncio.new_event_loop()
    loop_thread = threading.Thread(target=lambda: run_loop(loop))
    loop_thread.daemon = True
    loop_thread.start()

    # All webhook handlers share the background posting loop.
    for handler in handlers:
        handler.loop = loop

    # Record flow: root logger -> queue -> background QueueListener -> webhooks.
    queue: SimpleQueue[logging.LogRecord] = SimpleQueue()

    qhandler = QueueHandler(queue)
    qhandler.setLevel(logging.INFO)
    qhandler.addFilter(ContextInjection())
    # qhandler.addFilter(ThreadFilter('MainThread'))
    logger.addHandler(qhandler)

    listener = QueueListener(
        queue, *handlers, respect_handler_level=True
    )
    listener.start()
|
||||
34
src/meta/pending-rewrite/client.py
Normal file
34
src/meta/pending-rewrite/client.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from discord import Intents
|
||||
from cmdClient.cmdClient import cmdClient
|
||||
|
||||
from . import patches
|
||||
from .interactions import InteractionType
|
||||
from .config import conf
|
||||
from .sharding import shard_number, shard_count
|
||||
from LionContext import LionContext
|
||||
|
||||
|
||||
# Initialise client
|
||||
owners = [int(owner) for owner in conf.bot.getlist('owners')]
|
||||
intents = Intents.all()
|
||||
intents.presences = False
|
||||
client = cmdClient(
|
||||
prefix=conf.bot['prefix'],
|
||||
owners=owners,
|
||||
intents=intents,
|
||||
shard_id=shard_number,
|
||||
shard_count=shard_count,
|
||||
baseContext=LionContext
|
||||
)
|
||||
client.conf = conf
|
||||
|
||||
|
||||
# TODO: Could include client id here, or app id, to avoid multiple handling.
# Custom id used for components that should be acknowledged but do nothing.
NOOP_ID = 'NOOP'


@client.add_after_event('interaction_create')
async def handle_noop_interaction(client, interaction):
    """Acknowledge component/modal interactions carrying the NOOP custom id."""
    if interaction.interaction_type in (InteractionType.MESSAGE_COMPONENT, InteractionType.MODAL_SUBMIT):
        if interaction.custom_id == NOOP_ID:
            # NOTE(review): ack() is not awaited — confirm it is synchronous
            # (or internally schedules itself) in this client library.
            interaction.ack()
|
||||
110
src/meta/pending-rewrite/logger.py
Normal file
110
src/meta/pending-rewrite/logger.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from discord import AllowedMentions
|
||||
|
||||
from cmdClient.logger import cmd_log_handler
|
||||
|
||||
from utils.lib import mail, split_text
|
||||
|
||||
from .client import client
|
||||
from .config import conf
|
||||
from . import sharding
|
||||
|
||||
|
||||
# Setup the logger
|
||||
logger = logging.getLogger()
|
||||
log_fmt = logging.Formatter(
|
||||
fmt=('[{asctime}][{levelname:^8}]' +
|
||||
'[SHARD {}]'.format(sharding.shard_number) +
|
||||
' {message}'),
|
||||
datefmt='%d/%m | %H:%M:%S',
|
||||
style='{'
|
||||
)
|
||||
# term_handler = logging.StreamHandler(sys.stdout)
|
||||
# term_handler.setFormatter(log_fmt)
|
||||
# logger.addHandler(term_handler)
|
||||
# logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class LessThanFilter(logging.Filter):
    """
    Logging filter that passes only records strictly below `exclusive_maximum`.

    NOTE(review): duplicate of the LessThanFilter in src/meta/logger.py.
    """

    def __init__(self, exclusive_maximum, name=""):
        super(LessThanFilter, self).__init__(name)
        self.max_level = exclusive_maximum

    def filter(self, record):
        # non-zero return means we log this message
        return 1 if record.levelno < self.max_level else 0
|
||||
|
||||
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||
logging_handler_out.setLevel(logging.DEBUG)
|
||||
logging_handler_out.setFormatter(log_fmt)
|
||||
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
|
||||
logger.addHandler(logging_handler_out)
|
||||
|
||||
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||
logging_handler_err.setLevel(logging.WARNING)
|
||||
logging_handler_err.setFormatter(log_fmt)
|
||||
logger.addHandler(logging_handler_err)
|
||||
|
||||
|
||||
# Define the context log format and attach it to the command logger as well
|
||||
@cmd_log_handler
def log(message, context="GLOBAL", level=logging.INFO, post=True):
    """
    Log `message` under `context` at `level`, optionally mirroring it to the
    configured Discord log channel via `live_log`.

    Multi-line messages receive box-drawing prefixes so that each emitted
    line remains attributable to one logical message.
    """
    # Add prefixes to lines for better parsing capability
    lines = message.splitlines()
    if len(lines) > 1:
        # First line gets '┌ ', middle lines '│ ', last line '└ '.
        lines = [
            '┌ ' * (i == 0) + '│ ' * (0 < i < len(lines) - 1) + '└ ' * (i == len(lines) - 1) + line
            for i, line in enumerate(lines)
        ]
    else:
        lines = ['─ ' + message]

    for line in lines:
        logger.log(level, '\b[{}] {}'.format(
            str(context).center(22, '='),
            line
        ))

    # Fire and forget to the channel logger, if it is set up
    if post and client.is_ready():
        asyncio.ensure_future(live_log(message, context, level))
|
||||
|
||||
|
||||
# Live logger that posts to the logging channels
|
||||
async def live_log(message, context, level):
    """
    Post a formatted log message to the configured Discord log channel(s).

    Only INFO-and-above records are posted; WARNING-and-above prefer the
    dedicated error channel when one is configured.
    """
    if level >= logging.INFO:
        if level >= logging.WARNING:
            log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
        else:
            log_chid = conf.bot.getint('log_channel')

        # Generate the log messages
        if sharding.sharded:
            header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
        else:
            header = f"[{logging.getLevelName(level)}][{context}]"

        # Split long messages so each block fits in a Discord message.
        if len(message) > 1900:
            blocks = split_text(message, blocksize=1900, code=False)
        else:
            blocks = [message]

        if len(blocks) > 1:
            # Multi-part messages get an [i/n] counter in the header.
            blocks = [
                "```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
            ]
        else:
            blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]

        # Post the log messages
        if log_chid:
            [await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
|
||||
|
||||
|
||||
# Attach logger to client, for convenience
|
||||
client.log = log
|
||||
35
src/meta/sharding.py
Normal file
35
src/meta/sharding.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from .args import args
|
||||
from .config import conf
|
||||
|
||||
from psycopg import sql
|
||||
from data.conditions import Condition, Joiner
|
||||
|
||||
|
||||
# Shard index of this process (0 when unsharded or unspecified).
shard_number = args.shard or 0

# Total number of shards configured for the bot (defaults to 1).
shard_count = conf.bot.getint('shard_count', 1)

# Whether we are actually running multiple shards.
# (Was `shard_count > 0`, which is always true since the count defaults to 1,
# so consumers such as the shard-header logging branch always fired.)
sharded = (shard_count > 1)
|
||||
|
||||
|
||||
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
    """
    Condition constructor for filtering by shard id.

    Builds the SQL predicate `(guild_column >> 22) % shard_count = shard_id`,
    mirroring Discord's shard-assignment formula on guild snowflakes.

    Example Usage
    -------------
    Query.where(SHARDID(10, 'guildid', 1))
    """
    return Condition(
        # %% escapes the modulo operator for the psycopg placeholder format.
        sql.SQL("({guildid} >> 22) %% {shard_count}").format(
            guildid=sql.Identifier(guild_column),
            shard_count=sql.Literal(shard_count)
        ),
        Joiner.EQUALS,
        sql.Placeholder(),
        (shard_id,)
    )


# Pre-built Condition for filtering by current shard.
THIS_SHARD = SHARDID(shard_number)
|
||||
Reference in New Issue
Block a user