rewrite: Restructure to include GUI.

This commit is contained in:
2022-12-23 06:44:32 +02:00
parent 2b93354248
commit f328324747
224 changed files with 8 additions and 0 deletions

200
src/meta/LionBot.py Normal file
View File

@@ -0,0 +1,200 @@
from typing import List, Optional, TYPE_CHECKING
import logging
import asyncio
import discord
from discord.utils import MISSING
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError
from aiohttp import ClientSession
from data import Database
from .config import Conf
from .logger import logging_context, log_context, log_action_stack
from .context import context
from .LionContext import LionContext
from .LionTree import LionTree
from .errors import HandledException, SafeCancellation
if TYPE_CHECKING:
from core import CoreCog
logger = logging.getLogger(__name__)
class LionBot(Bot):
def __init__(
self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession, app_ipc,
testing_guilds: List[int] = [], translator=None, **kwargs
):
kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs)
self.web_client = web_client
self.testing_guilds = testing_guilds
self.initial_extensions = initial_extensions
self.db = db
self.appname = appname
self.shardname = shardname
# self.appdata = appdata
self.config = config
self.app_ipc = app_ipc
self.core: Optional['CoreCog'] = None
self.translator = translator
async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}")
await self.app_ipc.connect()
if self.translator is not None:
await self.tree.set_translator(self.translator)
for extension in self.initial_extensions:
await self.load_extension(extension)
for guildid in self.testing_guilds:
guild = discord.Object(guildid)
self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild)
async def add_cog(self, cog: Cog, **kwargs):
with logging_context(action=f"Attach {cog.__cog_name__}"):
logger.info(f"Attaching Cog {cog.__cog_name__}")
await super().add_cog(cog, **kwargs)
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
async def load_extension(self, name, *, package=None, **kwargs):
with logging_context(action=f"Load {name.strip('.')}"):
logger.info(f"Loading extension {name} in package {package}.")
await super().load_extension(name, package=package, **kwargs)
logger.debug(f"Loaded extension {name} in package {package}.")
async def start(self, token: str, *, reconnect: bool = True):
with logging_context(action="Login"):
await self.login(token)
with logging_context(stack=["Running"]):
await self.connect(reconnect=reconnect)
async def on_ready(self):
logger.info(
f"Logged in as {self.application.name}\n"
f"Application id {self.application.id}\n"
f"Shard Talk identifier {self.shardname}\n"
"------------------------------\n"
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
"------------------------------\n"
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
"Ready to take commands!\n",
extra={'action': 'Ready'}
)
async def get_context(self, origin, /, *, cls=MISSING):
if cls is MISSING:
cls = LionContext
ctx = await super().get_context(origin, cls=cls)
context.set(ctx)
return ctx
async def on_command(self, ctx: LionContext):
logger.info(
f"Executing command '{ctx.command.qualified_name}' "
f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
f"with arguments {ctx.args} and kwargs {ctx.kwargs}.",
extra={'with_ctx': True}
)
async def on_command_error(self, ctx, exception):
# TODO: Some of these could have more user-feedback
cmd_str = str(ctx.command)
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict()
try:
raise exception
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
try:
if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
original = exception.original.original
raise original
else:
original = exception.original
raise original
except HandledException:
pass
except SafeCancellation:
if original.msg:
try:
await ctx.error_reply(original.msg)
except Exception:
pass
logger.debug(
f"Caught a safe cancellation: {original.details}",
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.Forbidden:
# Unknown uncaught Forbidden
try:
# Attempt a general error reply
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
except Exception:
# We can't send anything at all. Exit quietly, but log.
logger.warning(
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.HTTPException:
logger.warning(
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
pass
except Exception:
logger.exception(
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
extra={'action': 'BotError', 'with_ctx': True}
)
error_embed = discord.Embed(title="Something went wrong!")
error_embed.description = (
"An unexpected error occurred while processing your command!\n"
"Our development team has been notified, and the issue should be fixed soon.\n"
"If the error persists, please contact our support team and give them the following number: "
f"`{ctx.interaction.id}`"
)
try:
await ctx.error_reply(embed=error_embed)
except Exception:
pass
finally:
exception.original = HandledException(exception.original)
except CheckFailure:
logger.debug(
f"Command failed check: {exception}",
extra={'action': 'BotError', 'with_ctx': True}
)
try:
await ctx.error_rely(exception.message)
except Exception:
pass
except Exception:
# Completely unknown exception outside of command invocation!
# Something is very wrong here, don't attempt user interaction.
logger.exception(
f"Caught an unknown top-level exception while executing: {cmd_str}",
extra={'action': 'BotError', 'with_ctx': True}
)
def add_command(self, command):
if hasattr(command, '_placeholder_group_'):
return
super().add_command(command)

58
src/meta/LionCog.py Normal file
View File

@@ -0,0 +1,58 @@
from typing import Any
from discord.ext.commands import Cog
from discord.ext import commands as cmds
class LionCog(Cog):
# A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set()
_placeholder_groups_: set[str]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._placeholder_groups_ = set()
for base in reversed(cls.__mro__):
for elem, value in base.__dict__.items():
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
cls._placeholder_groups_.add(value.name)
def __new__(cls, *args: Any, **kwargs: Any):
# Patch to ensure no placeholder groups are in the command list
self = super().__new__(cls)
self.__cog_commands__ = [
command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_
]
return self
async def _inject(self, bot, *args, **kwargs):
if self.depends_on:
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
return await super()._inject(bot, *args, *kwargs)
@classmethod
def placeholder_group(cls, group: cmds.HybridGroup):
group._placeholder_group_ = True
return group
def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup):
"""
Crossload a placeholder group's commands into the target group
"""
if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup):
raise ValueError("Placeholder and target groups my be HypridGroups.")
if placeholder_group.name not in self._placeholder_groups_:
raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.")
if target_group.app_command is None:
raise ValueError("Target group has no app_command to crossload into.")
for command in placeholder_group.commands:
placeholder_group.remove_command(command.name)
target_group.remove_command(command.name)
acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self)
command.app_command = acmd
target_group.add_command(command)

190
src/meta/LionContext.py Normal file
View File

@@ -0,0 +1,190 @@
import types
import logging
from collections import namedtuple
from typing import Optional, TYPE_CHECKING
import discord
from discord.ext.commands import Context
if TYPE_CHECKING:
from .LionBot import LionBot
from core.lion import Lion
logger = logging.getLogger(__name__)
"""
Stuff that might be useful to implement (see cmdClient):
sent_messages cache
tasks cache
error reply
usage
interaction cache
View cache?
setting access
"""
FlatContext = namedtuple(
'FlatContext',
('message',
'interaction',
'guild',
'author',
'alias',
'prefix',
'failed')
)
class LionContext(Context['LionBot']):
"""
Represents the context a command is invoked under.
Extends Context to add Lion-specific methods and attributes.
Also adds several contextual wrapped utilities for simpler user during command invocation.
"""
alion: 'Lion'
def __repr__(self):
parts = {}
if self.interaction is not None:
parts['iid'] = self.interaction.id
parts['itype'] = f"\"{self.interaction.type.name}\""
if self.message is not None:
parts['mid'] = self.message.id
if self.author is not None:
parts['uid'] = self.author.id
parts['uname'] = f"\"{self.author.name}\""
if self.channel is not None:
parts['cid'] = self.channel.id
parts['cname'] = f"\"{self.channel.name}\""
if self.guild is not None:
parts['gid'] = self.guild.id
parts['gname'] = f"\"{self.guild.name}\""
if self.command is not None:
parts['cmd'] = f"\"{self.command.qualified_name}\""
if self.invoked_with is not None:
parts['alias'] = f"\"{self.invoked_with}\""
if self.command_failed:
parts['failed'] = self.command_failed
return "<LionContext: {}>".format(
' '.join(f"{name}={value}" for name, value in parts.items())
)
def flatten(self):
"""Flat pure-data context information, for caching and logging."""
return FlatContext(
self.message.id,
self.interaction.id if self.interaction is not None else None,
self.guild.id if self.guild is not None else None,
self.author.id if self.author is not None else None,
self.channel.id if self.channel is not None else None,
self.invoked_with,
self.prefix,
self.command_failed
)
@classmethod
def util(cls, util_func):
"""
Decorator to make a utility function available as a Context instance method.
"""
setattr(cls, util_func.__name__, util_func)
logger.debug(f"Attached context utility function: {util_func.__name__}")
return util_func
@classmethod
def wrappable_util(cls, util_func):
"""
Decorator to add a Wrappable utility function as a Context instance method.
"""
wrapped = Wrappable(util_func)
setattr(cls, util_func.__name__, wrapped)
logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
return wrapped
async def error_reply(self, content: Optional[str] = None, **kwargs):
if content and 'embed' not in kwargs:
embed = discord.Embed(
colour=discord.Colour.red(),
description=content
)
kwargs['embed'] = embed
content = None
# Expect this may be run in highly unusual circumstances.
# This should never error, or at least handle all errors.
try:
await self.reply(content=content, **kwargs)
except discord.HTTPException:
pass
except Exception:
logger.exception(
"Unknown exception in 'error_reply'.",
extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
)
class Wrappable:
__slots__ = ('_func', 'wrappers')
def __init__(self, func):
self._func = func
self.wrappers = None
@property
def __name__(self):
return self._func.__name__
def add_wrapper(self, func, name=None):
self.wrappers = self.wrappers or {}
name = name or func.__name__
self.wrappers[name] = func
logger.debug(
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
extra={'action': "Wrap Util"}
)
def remove_wrapper(self, name):
if not self.wrappers or name not in self.wrappers:
raise ValueError(
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
)
self.wrappers.pop(name)
logger.debug(
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
extra={'action': "Unwrap Util"}
)
def __call__(self, *args, **kwargs):
if self.wrappers:
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
else:
return self._func(*args, **kwargs)
def _wrapped(self, iter_wraps):
next_wrap = next(iter_wraps, None)
if next_wrap:
def _func(*args, **kwargs):
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
else:
_func = self._func
return _func
def __get__(self, instance, cls=None):
if instance is None:
return self
else:
return types.MethodType(self, instance)
LionContext.reply = Wrappable(LionContext.reply)
# @LionContext.reply.add_wrapper
# async def think(func, ctx, *args, **kwargs):
# await ctx.channel.send("thinking")
# await func(ctx, *args, **kwargs)

79
src/meta/LionTree.py Normal file
View File

@@ -0,0 +1,79 @@
import logging
from discord import Interaction
from discord.app_commands import CommandTree
from discord.app_commands.errors import AppCommandError, CommandInvokeError
from discord.enums import InteractionType
from discord.app_commands.namespace import Namespace
from .logger import logging_context
from .errors import SafeCancellation
logger = logging.getLogger(__name__)
class LionTree(CommandTree):
async def on_error(self, interaction, error) -> None:
try:
if isinstance(error, CommandInvokeError):
raise error.original
else:
raise error
except SafeCancellation:
# Assume this has already been handled
pass
except Exception:
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
async def _call(self, interaction):
with logging_context(context=f"iid: {interaction.id}"):
if not await self.interaction_check(interaction):
interaction.command_failed = True
return
data = interaction.data # type: ignore
type = data.get('type', 1)
if type != 1:
# Context menu command...
await self._call_context_menu(interaction, data, type)
return
command, options = self._get_app_command_options(data)
# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command
# At this point options refers to the arguments of the command
# and command refers to the class type we care about
namespace = Namespace(interaction, data.get('resolved', {}), options)
# Same pre-fill as above
interaction._cs_namespace = namespace
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
if interaction.type is InteractionType.autocomplete:
with logging_context(action=f"Acmp {command.qualified_name}"):
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
if focused is None:
raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.'
)
await command._invoke_autocomplete(interaction, focused, namespace)
return
with logging_context(action=f"Run {command.qualified_name}"):
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:
interaction.command_failed = True
await command._invoke_error_handlers(interaction, e)
await self.on_error(interaction, e)
else:
if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command)
finally:
if interaction.command_failed:
logger.debug("Command completed with errors.")
else:
logger.debug("Command completed without errors.")

15
src/meta/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
from .LionBot import LionBot
from .LionCog import LionCog
from .LionContext import LionContext
from .LionTree import LionTree
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
from .config import conf, configEmoji
from .args import args
from .app import appname, shard_talk, appname_from_shard, shard_from_appname
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
from .context import context, ctx_bot
from . import sharding
from . import logger
from . import app

46
src/meta/app.py Normal file
View File

@@ -0,0 +1,46 @@
"""
appname: str
The base identifer for this application.
This identifies which services the app offers.
shardname: str
The specific name of the running application.
Only one process should be connecteded with a given appname.
For the bot apps, usually specifies the shard id and shard number.
"""
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
from . import sharding, conf
from .logger import log_app
from .ipc.client import AppClient
from .args import args
appname = conf.data['appid']
appid = appname # backwards compatibility
def appname_from_shard(shardid):
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
return appname
def shard_from_appname(appname: str):
return int(appname.rsplit('_', maxsplit=1)[-1])
shardname = appname_from_shard(sharding.shard_number)
log_app.set(shardname)
shard_talk = AppClient(
shardname,
appname,
{'host': args.host, 'port': args.port},
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
)
@shard_talk.register_route()
async def ping():
return "Pong!"

35
src/meta/args.py Normal file
View File

@@ -0,0 +1,35 @@
import argparse
from constants import CONFIG_FILE
# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
'--conf',
dest='config',
default=CONFIG_FILE,
help="Path to configuration file."
)
parser.add_argument(
'--shard',
dest='shard',
default=None,
type=int,
help="Shard number to run, if applicable."
)
parser.add_argument(
'--host',
dest='host',
default='127.0.0.1',
help="IP address to run the app listener on."
)
parser.add_argument(
'--port',
dest='port',
default='5001',
help="Port to run the app listener on."
)
args = parser.parse_args()

137
src/meta/config.py Normal file
View File

@@ -0,0 +1,137 @@
from discord import PartialEmoji
import configparser as cfgp
from .args import args
class configEmoji(PartialEmoji):
__slots__ = ('fallback',)
def __init__(self, *args, fallback=None, **kwargs):
super().__init__(*args, **kwargs)
self.fallback = fallback
@classmethod
def from_str(cls, emojistr: str):
"""
Parses emoji strings of one of the following forms
`<a:name:id> or fallback`
`<:name:id> or fallback`
`<a:name:id>`
`<:name:id>`
"""
splits = emojistr.rsplit(' or ', maxsplit=1)
fallback = splits[1] if len(splits) > 1 else None
emojistr = splits[0].strip('<> ')
animated, name, id = emojistr.split(':')
return cls(
name=name,
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
animated=bool(animated),
id=int(id) if id else None
)
class MapDotProxy:
"""
Allows dot access to an underlying Mappable object.
"""
__slots__ = ("_map", "_converter")
def __init__(self, mappable, converter=None):
self._map = mappable
self._converter = converter
def __getattribute__(self, key):
_map = object.__getattribute__(self, '_map')
if key == '_map':
return _map
if key in _map:
_converter = object.__getattribute__(self, '_converter')
if _converter:
return _converter(_map[key])
else:
return _map[key]
else:
return object.__getattribute__(_map, key)
def __getitem__(self, key):
return self._map.__getitem__(key)
class ConfigParser(cfgp.ConfigParser):
"""
Extension of base ConfigParser allowing optional
section option retrieval without defaults.
"""
def options(self, section, no_defaults=False, **kwargs):
if no_defaults:
try:
return list(self._sections[section].keys())
except KeyError:
raise cfgp.NoSectionError(section)
else:
return super().options(section, **kwargs)
class Conf:
def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile
self.config = ConfigParser(
converters={
"intlist": self._getintlist,
"list": self._getlist,
"emoji": configEmoji.from_str,
}
)
self.config.read(configfile)
self.section_name = section_name if section_name in self.config else 'DEFAULT'
self.default = self.config["DEFAULT"]
self.section = MapDotProxy(self.config[self.section_name])
self.bot = self.section
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
more_to_read = self.section.getlist("ALSO_READ", [])
read = set()
while more_to_read:
to_read = more_to_read.pop(0)
read.add(to_read)
self.config.read(to_read)
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
if path not in read and path not in more_to_read]
more_to_read.extend(new_paths)
self.emojis = MapDotProxy(
self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section,
converter=configEmoji.from_str
)
global conf
conf = self
def __getitem__(self, key):
return self.section[key].strip()
def __getattr__(self, section):
return self.config[section.upper()]
def get(self, name, fallback=None):
result = self.section.get(name, fallback)
return result.strip() if result else result
def _getintlist(self, value):
return [int(item.strip()) for item in value.split(',')]
def _getlist(self, value):
return [item.strip() for item in value.split(',')]
def write(self):
with open(self.configfile, 'w') as conffile:
self.config.write(conffile)
conf = Conf(args.config, 'STUDYLION')

20
src/meta/context.py Normal file
View File

@@ -0,0 +1,20 @@
"""
Namespace for various global context variables.
Allows asyncio callbacks to accurately retrieve information about the current state.
"""
from typing import TYPE_CHECKING, Optional
from contextvars import ContextVar
if TYPE_CHECKING:
from .LionBot import LionBot
from .LionContext import LionContext
# Contains the current command context, if applicable
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
# Contains the current LionBot instance
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)

64
src/meta/errors.py Normal file
View File

@@ -0,0 +1,64 @@
from typing import Optional
from string import Template
class SafeCancellation(Exception):
"""
Raised to safely cancel execution of the current operation.
If not caught, is expected to be propagated to the Tree and safely ignored there.
If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
The error handler should then set the `msg` to None, to avoid double handling.
Debugging information should go in `details`, to be logged by a top-level error handler.
"""
default_message = ""
@property
def msg(self):
return self._msg if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
self._msg: Optional[str] = _msg
self.details: str = details if details is not None else self.msg
super().__init__(**kwargs)
class UserInputError(SafeCancellation):
"""
A SafeCancellation induced from unparseable user input.
"""
default_message = "Could not understand your input."
@property
def msg(self):
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
self.info = info
super().__init__(_msg, **kwargs)
class UserCancelled(SafeCancellation):
"""
A SafeCancellation induced from manual user cancellation.
Usually silent.
"""
default_msg = None
class ResponseTimedOut(SafeCancellation):
"""
A SafeCancellation induced from a user interaction time-out.
"""
default_msg = "Session timed out waiting for input."
class HandledException(SafeCancellation):
"""
Sentinel class to indicate to error handlers that this exception has been handled.
Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
"""
def __init__(self, exc=None, **kwargs):
self.exc = exc
super().__init__(**kwargs)

2
src/meta/ipc/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .client import AppClient, AppPayload, AppRoute
from .server import AppServer

236
src/meta/ipc/client.py Normal file
View File

@@ -0,0 +1,236 @@
from typing import Optional, TypeAlias, Any
import asyncio
import logging
import pickle
from ..logger import logging_context
logger = logging.getLogger(__name__)
Address: TypeAlias = dict[str, Any]
class AppClient:
routes: dict[str, 'AppRoute'] = {} # route_name -> Callable[Any, Awaitable[Any]]
def __init__(self, appid: str, basename: str, client_address: Address, server_address: Address):
self.appid = appid # String identifier for this ShardTalk client
self.basename = basename # Prefix used to recognise app peers
self.address = client_address
self.server_address = server_address
self.peers = {appid: client_address} # appid -> address
self._listener: Optional[asyncio.Server] = None # Local client server
self._server = None # Connection to the registry server
self._keepalive = None
self.register_route('new_peer')(self.new_peer)
self.register_route('drop_peer')(self.drop_peer)
self.register_route('peer_list')(self.peer_list)
@property
def my_peers(self):
return {peerid: peer for peerid, peer in self.peers.items() if peerid.startswith(self.basename)}
def register_route(self, name=None):
def wrapper(coro):
route = AppRoute(coro, client=self, name=name)
self.routes[route.name] = route
return route
return wrapper
async def server_connection(self):
"""Establish a connection to the registry server"""
try:
reader, writer = await asyncio.open_connection(**self.server_address)
payload = ('connect', (), {'appid': self.appid, 'address': self.address})
writer.write(pickle.dumps(payload))
writer.write(b'\n')
await writer.drain()
data = await reader.readline()
peers = pickle.loads(data)
self.peers = peers
self._server = (reader, writer)
except Exception:
logger.exception(
"Could not connect to registry server. Trying again in 30 seconds.",
extra={'action': 'Connect'}
)
await asyncio.sleep(30)
asyncio.create_task(self.server_connection())
else:
logger.debug(
"Connected to the registry server, launching keepalive.",
extra={'action': 'Connect'}
)
self._keepalive = asyncio.create_task(self._server_keepalive())
async def _server_keepalive(self):
with logging_context(action='Keepalive'):
if self._server is None:
raise ValueError("Cannot keepalive non-existent server!")
reader, write = self._server
try:
await reader.read()
except Exception:
logger.exception("Lost connection to address server. Reconnecting...")
else:
# Connection ended or broke
logger.info("Lost connection to address server. Reconnecting...")
await asyncio.sleep(30)
asyncio.create_task(self.server_connection())
async def new_peer(self, appid, address):
self.peers[appid] = address
async def peer_list(self, peers):
self.peers = peers
async def drop_peer(self, appid):
self.peers.pop(appid, None)
async def close(self):
# Close connection to the server
# TODO
...
async def request(self, appid, payload: 'AppPayload', wait_for_reply=True):
with logging_context(action=f"Req {appid}"):
try:
if appid not in self.peers:
raise ValueError(f"Peer '{appid}' not found.")
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
address = self.peers[appid]
reader, writer = await asyncio.open_connection(**address)
writer.write(payload.encoded())
await writer.drain()
writer.write_eof()
if wait_for_reply:
result = await reader.read()
writer.close()
decoded = payload.route.decode(result)
return decoded
else:
return None
except Exception:
logging.exception(f"Failed to send request to {appid}'")
return None
async def requestall(self, payload, except_self=True, only_my_peers=True):
with logging_context(action="Broadcast"):
peerlist = self.my_peers if only_my_peers else self.peers
results = await asyncio.gather(
*(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)),
return_exceptions=True
)
return dict(zip(self.peers.keys(), results))
async def handle_request(self, reader, writer):
with logging_context(action="SERV"):
data = await reader.read()
loaded = pickle.loads(data)
route, args, kwargs = loaded
with logging_context(action=f"SERV {route}"):
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes:
try:
await self.routes[route].run((reader, writer), args, kwargs)
except Exception:
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
else:
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
writer.write_eof()
async def connect(self):
"""
Start the local peer server.
Connect to the address server.
"""
with logging_context(stack=['ShardTalk']):
# Start the client server
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
logger.info(f"Serving on {self.address}")
await self.server_connection()
class AppPayload:
__slots__ = ('route', 'args', 'kwargs')
def __init__(self, route, *args, **kwargs):
self.route = route
self.args = args
self.kwargs = kwargs
def __await__(self):
return self.route.execute(*self.args, **self.kwargs).__await__()
def encoded(self):
return pickle.dumps((self.route.name, self.args, self.kwargs))
async def send(self, appid, **kwargs):
return await self.route._client.request(appid, self, **kwargs)
async def broadcast(self, **kwargs):
return await self.route._client.requestall(self, **kwargs)
class AppRoute:
__slots__ = ('func', 'name', '_client')
def __init__(self, func, client=None, name=None):
self.func = func
self.name = name or func.__name__
self._client = client
def __call__(self, *args, **kwargs):
return AppPayload(self, *args, **kwargs)
def encode(self, output):
return pickle.dumps(output)
def decode(self, encoded):
# TODO: Handle exceptions here somehow
if len(encoded) > 0:
return pickle.loads(encoded)
else:
return ''
def encoder(self, func):
self.encode = func
def decoder(self, func):
self.decode = func
async def execute(self, *args, **kwargs):
"""
Execute the underlying function, with the given arguments.
"""
return await self.func(*args, **kwargs)
async def run(self, connection, args, kwargs):
"""
Run the route, with the given arguments, using the given connection.
"""
# TODO: ContextVar here for logging? Or in handle_request?
# Get encoded result
# TODO: handle exceptions in the execution process
try:
result = await self.execute(*args, **kwargs)
payload = self.encode(result)
except Exception:
logger.exception(f"Exception occured running route '{self.name}' with args: {args} and kwargs: {kwargs}")
payload = b''
_, writer = connection
writer.write(payload)
await writer.drain()
writer.close()

175
src/meta/ipc/server.py Normal file
View File

@@ -0,0 +1,175 @@
import asyncio
import pickle
import logging
import string
import random
from ..logger import log_context, log_app, logging_context
logger = logging.getLogger(__name__)
uuid_alphabet = string.ascii_lowercase + string.digits
def short_uuid():
return ''.join(random.choices(uuid_alphabet, k=10))
class AppServer:
routes = {} # route name -> bound method
def __init__(self):
self.clients = {} # AppID -> (info, connection)
self.route('ping')(self.route_ping)
self.route('whereis')(self.route_whereis)
self.route('peers')(self.route_peers)
self.route('connect')(self.client_connection)
@classmethod
def route(cls, route_name):
"""
Decorator to add a route to the server.
"""
def wrapper(coro):
cls.routes[route_name] = coro
return coro
return wrapper
async def route_ping(self, connection):
"""
Pong.
"""
reader, writer = connection
writer.write(b"Pong")
writer.write_eof()
async def route_whereis(self, connection, appid):
"""
Return an address for the given client appid.
Returns None if the client does not have a connection.
"""
reader, writer = connection
if appid in self.clients:
writer.write(pickle.dumps(self.clients[appid][0]))
else:
writer.write(b'')
writer.write_eof()
async def route_peers(self, connection):
"""
Send back a map of current peers.
"""
reader, writer = connection
peers = self.peer_list()
payload = pickle.dumps(('peer_list', (peers,)))
writer.write(payload)
writer.write_eof()
async def client_connection(self, connection, appid, address):
"""
Register and hold a new client connection.
"""
with logging_context(action=f"CONN {appid}"):
reader, writer = connection
# Add the new client
self.clients[appid] = (address, connection)
# Send the new client a client list
peers = self.peer_list()
writer.write(pickle.dumps(peers))
writer.write(b'\n')
await writer.drain()
# Announce the new client to everyone
await self.broadcast('new_peer', (), {'appid': appid, 'address': address})
# Keep the connection open until socket closed or EOF (indicating client death)
try:
await reader.read()
finally:
# Connection ended or it broke
logger.info(f"Lost client '{appid}'")
await self.deregister_client(appid)
async def handle_connection(self, reader, writer):
data = await reader.readline()
route, args, kwargs = pickle.loads(data)
rqid = short_uuid()
with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"):
logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes:
# Execute route
try:
await self.routes[route]((reader, writer), *args, **kwargs)
except Exception:
logger.exception(f"AppServer recieved exception during route '{route}'")
else:
logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.")
def peer_list(self):
return {appid: address for appid, (address, _) in self.clients.items()}
async def deregister_client(self, appid):
self.clients.pop(appid, None)
await self.broadcast('drop_peer', (), {'appid': appid})
async def broadcast(self, route, args, kwargs):
with logging_context(action="broadcast"):
logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.")
payload = pickle.dumps((route, args, kwargs))
if self.clients:
await asyncio.gather(
*(self._send(appid, payload) for appid in self.clients),
return_exceptions=True
)
async def message_client(self, appid, route, args, kwargs):
"""
Send a message to client `appid` along `route` with given arguments.
"""
with logging_context(action=f"MSG {appid}"):
logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.")
if appid not in self.clients:
raise ValueError(f"Client '{appid}' is not connected.")
payload = pickle.dumps((route, args, kwargs))
return await self._send(appid, payload)
async def _send(self, appid, payload):
"""
Send the encoded `payload` to the client `appid`.
"""
address, _ = self.clients[appid]
try:
reader, writer = await asyncio.open_connection(**address)
writer.write(payload)
writer.write_eof()
await writer.drain()
writer.close()
except Exception as ex:
# TODO: Close client if we can't connect?
logger.exception(f"Failed to send message to '{appid}'")
raise ex
async def start(self, address):
log_app.set("APPSERVER")
with logging_context(stack=["SERV"]):
server = await asyncio.start_server(self.handle_connection, **address)
logger.info(f"Serving on {address}")
async with server:
await server.serve_forever()
async def start_server():
address = {'host': '127.0.0.1', 'port': '5000'}
server = AppServer()
await server.start(address)
if __name__ == '__main__':
asyncio.run(start_server())

324
src/meta/logger.py Normal file
View File

@@ -0,0 +1,324 @@
import sys
import logging
import asyncio
from typing import List
from logging.handlers import QueueListener, QueueHandler
from queue import SimpleQueue
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar
from discord import Webhook, File
import aiohttp
from .config import conf
from . import sharding
from .context import context
from utils.lib import utc_now
log_logger = logging.getLogger(__name__)
log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[])
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
@contextmanager
def logging_context(context=None, action=None, stack=None):
if context is not None:
context_t = log_context.set(context)
if action is not None:
astack = log_action_stack.get()
log_action_stack.set(astack + [action])
if stack is not None:
actions_t = log_action_stack.set(stack)
try:
yield
finally:
if context is not None:
log_context.reset(context_t)
if stack is not None:
log_action_stack.reset(actions_t)
if action is not None:
log_action_stack.set(astack)
def log_wrap(**kwargs):
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
with logging_context(**kwargs):
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str:
cmap = {
'%(black)': COLOR_SEQ % BLACK,
'%(red)': COLOR_SEQ % RED,
'%(green)': COLOR_SEQ % GREEN,
'%(yellow)': COLOR_SEQ % YELLOW,
'%(blue)': COLOR_SEQ % BLUE,
'%(magenta)': COLOR_SEQ % MAGENTA,
'%(cyan)': COLOR_SEQ % CYAN,
'%(white)': COLOR_SEQ % WHITE,
'%(reset)': RESET_SEQ,
'%(bold)': BOLD_SEQ,
}
for key, value in cmap.items():
fmt = fmt.replace(key, value)
return fmt
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
'[%(cyan)%(app)-15s%(reset)]' +
'[%(cyan)%(context)-24s%(reset)]' +
'[%(cyan)%(actionstr)-22s%(reset)]' +
' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format)
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=log_format,
# datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
class ThreadFilter(logging.Filter):
def __init__(self, thread_name):
super().__init__("")
self.thread = thread_name
def filter(self, record):
# non-zero return means we log this message
return 1 if record.threadName == self.thread else 0
class ContextInjection(logging.Filter):
def filter(self, record):
# These guards are to allow override through _extra
# And to ensure the injection is idempotent
if not hasattr(record, 'context'):
record.context = log_context.get()
if not hasattr(record, 'actionstr'):
action_stack = log_action_stack.get()
if hasattr(record, 'action'):
action_stack = (*action_stack, record.action)
if action_stack:
record.actionstr = ''.join(action_stack)
else:
record.actionstr = "Unknown Action"
if not hasattr(record, 'app'):
record.app = log_app.get()
if not hasattr(record, 'ctx'):
if ctx := context.get():
record.ctx = repr(ctx)
else:
record.ctx = None
if getattr(record, 'with_ctx', False) and record.ctx:
record.ctxstr = '\n' + record.ctx
else:
record.ctxstr = ""
return True
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)
class LocalQueueHandler(QueueHandler):
def _emit(self, record: logging.LogRecord) -> None:
# Removed the call to self.prepare(), handle task cancellation
try:
self.enqueue(record)
except asyncio.CancelledError:
raise
except Exception:
self.handleError(record)
class WebHookHandler(logging.StreamHandler):
def __init__(self, webhook_url, batch=False, loop=None):
super().__init__()
self.webhook_url = webhook_url
self.batched = ""
self.batch = batch
self.loop = loop
self.batch_delay = 10
self.batch_task = None
self.last_batched = None
self.waiting = []
def get_loop(self):
if self.loop is None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
return self.loop
def emit(self, record):
self.get_loop().call_soon_threadsafe(self._post, record)
def _post(self, record):
asyncio.create_task(self.post(record))
async def post(self, record):
log_context.set("Webhook Logger")
log_action_stack.set(["Logging"])
log_app.set(record.app)
try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
context = f"\n# Context: {record.ctx}" if record.ctx else ""
message = f"{header}\n{record.msg}{context}"
if len(message) > 1900:
as_file = True
else:
as_file = False
message = "```md\n{}\n```".format(message)
# Post the log message(s)
if self.batch:
if len(message) > 1500:
await self._send_batched_now()
await self._send(message, as_file=as_file)
else:
self.batched += message
if len(self.batched) + len(message) > 1500:
await self._send_batched_now()
else:
asyncio.create_task(self._schedule_batched())
else:
await self._send(message, as_file=as_file)
except Exception as ex:
print(ex)
async def _schedule_batched(self):
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
# noop, don't reschedule if it is already scheduled
return
try:
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
await self.batch_task
await self._send_batched()
except asyncio.CancelledError:
return
except Exception as ex:
print(ex)
async def _send_batched_now(self):
if self.batch_task is not None and not self.batch_task.done():
self.batch_task.cancel()
self.last_batched = None
await self._send_batched()
async def _send_batched(self):
if self.batched:
batched = self.batched
self.batched = ""
await self._send(batched)
async def _send(self, message, as_file=False):
async with aiohttp.ClientSession() as session:
webhook = Webhook.from_url(self.webhook_url, session=session)
if as_file or len(message) > 2000:
with StringIO(message) as fp:
fp.seek(0)
await webhook.send(
file=File(fp, filename="logs.md"),
username=log_app.get()
)
else:
await webhook.send(message, username=log_app.get())
handlers = []
if webhook := conf.logging['general_log']:
handler = WebHookHandler(webhook, batch=True)
handlers.append(handler)
if webhook := conf.logging['error_log']:
handler = WebHookHandler(webhook, batch=False)
handler.setLevel(logging.ERROR)
handlers.append(handler)
if webhook := conf.logging['critical_log']:
handler = WebHookHandler(webhook, batch=False)
handler.setLevel(logging.CRITICAL)
handlers.append(handler)
if handlers:
# First create a separate loop to run the handlers on
import threading
def run_loop(loop):
asyncio.set_event_loop(loop)
try:
loop.run_forever()
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=lambda: run_loop(loop))
loop_thread.daemon = True
loop_thread.start()
for handler in handlers:
handler.loop = loop
queue: SimpleQueue[logging.LogRecord] = SimpleQueue()
qhandler = QueueHandler(queue)
qhandler.setLevel(logging.INFO)
qhandler.addFilter(ContextInjection())
# qhandler.addFilter(ThreadFilter('MainThread'))
logger.addHandler(qhandler)
listener = QueueListener(
queue, *handlers, respect_handler_level=True
)
listener.start()

View File

@@ -0,0 +1,34 @@
from discord import Intents
from cmdClient.cmdClient import cmdClient
from . import patches
from .interactions import InteractionType
from .config import conf
from .sharding import shard_number, shard_count
from LionContext import LionContext
# Initialise client
owners = [int(owner) for owner in conf.bot.getlist('owners')]
intents = Intents.all()
intents.presences = False
client = cmdClient(
prefix=conf.bot['prefix'],
owners=owners,
intents=intents,
shard_id=shard_number,
shard_count=shard_count,
baseContext=LionContext
)
client.conf = conf
# TODO: Could include client id here, or app id, to avoid multiple handling.
NOOP_ID = 'NOOP'
@client.add_after_event('interaction_create')
async def handle_noop_interaction(client, interaction):
if interaction.interaction_type in (InteractionType.MESSAGE_COMPONENT, InteractionType.MODAL_SUBMIT):
if interaction.custom_id == NOOP_ID:
interaction.ack()

View File

@@ -0,0 +1,110 @@
import sys
import logging
import asyncio
from discord import AllowedMentions
from cmdClient.logger import cmd_log_handler
from utils.lib import mail, split_text
from .client import client
from .config import conf
from . import sharding
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=('[{asctime}][{levelname:^8}]' +
'[SHARD {}]'.format(sharding.shard_number) +
' {message}'),
datefmt='%d/%m | %H:%M:%S',
style='{'
)
# term_handler = logging.StreamHandler(sys.stdout)
# term_handler.setFormatter(log_fmt)
# logger.addHandler(term_handler)
# logger.setLevel(logging.INFO)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
logger.setLevel(logging.NOTSET)
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logger.addHandler(logging_handler_err)
# Define the context log format and attach it to the command logger as well
@cmd_log_handler
def log(message, context="GLOBAL", level=logging.INFO, post=True):
# Add prefixes to lines for better parsing capability
lines = message.splitlines()
if len(lines) > 1:
lines = [
'' * (i == 0) + '' * (0 < i < len(lines) - 1) + '' * (i == len(lines) - 1) + line
for i, line in enumerate(lines)
]
else:
lines = ['' + message]
for line in lines:
logger.log(level, '\b[{}] {}'.format(
str(context).center(22, '='),
line
))
# Fire and forget to the channel logger, if it is set up
if post and client.is_ready():
asyncio.ensure_future(live_log(message, context, level))
# Live logger that posts to the logging channels
async def live_log(message, context, level):
if level >= logging.INFO:
if level >= logging.WARNING:
log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
else:
log_chid = conf.bot.getint('log_channel')
# Generate the log messages
if sharding.sharded:
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
else:
header = f"[{logging.getLevelName(level)}][{context}]"
if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False)
else:
blocks = [message]
if len(blocks) > 1:
blocks = [
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
]
else:
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
# Post the log messages
if log_chid:
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
# Attach logger to client, for convenience
client.log = log

35
src/meta/sharding.py Normal file
View File

@@ -0,0 +1,35 @@
from .args import args
from .config import conf
from psycopg import sql
from data.conditions import Condition, Joiner
shard_number = args.shard or 0
shard_count = conf.bot.getint('shard_count', 1)
sharded = (shard_count > 0)
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
"""
Condition constructor for filtering by shard id.
Example Usage
-------------
Query.where(_shard_condition('guildid', 10, 1))
"""
return Condition(
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
guildid=sql.Identifier(guild_column),
shard_count=sql.Literal(shard_count)
),
Joiner.EQUALS,
sql.Placeholder(),
(shard_id,)
)
# Pre-built Condition for filtering by current shard.
THIS_SHARD = SHARDID(shard_number)