diff --git a/bot/meta/LionBot.py b/bot/meta/LionBot.py index 4e8ce455..585fd8dd 100644 --- a/bot/meta/LionBot.py +++ b/bot/meta/LionBot.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional, TYPE_CHECKING import logging import asyncio @@ -18,12 +18,15 @@ from .LionContext import LionContext from .LionTree import LionTree from .errors import HandledException, SafeCancellation +if TYPE_CHECKING: + from core import CoreCog + logger = logging.getLogger(__name__) class LionBot(Bot): def __init__( - self, *args, appname: str, db: Database, config: Conf, + self, *args, appname: str, shardname: str, db: Database, config: Conf, initial_extensions: List[str], web_client: ClientSession, app_ipc, testing_guilds: List[int] = [], **kwargs ): @@ -34,9 +37,11 @@ class LionBot(Bot): self.initial_extensions = initial_extensions self.db = db self.appname = appname + self.shardname = shardname # self.appdata = appdata self.config = config self.app_ipc = app_ipc + self.core: Optional['CoreCog'] = None async def setup_hook(self) -> None: log_context.set(f"APP: {self.application_id}") @@ -72,10 +77,11 @@ class LionBot(Bot): logger.info( f"Logged in as {self.application.name}\n" f"Application id {self.application.id}\n" - f"Shard Talk identifier {self.appname}\n" + f"Shard Talk identifier {self.shardname}\n" "------------------------------\n" f"Enabled Modules: {', '.join(self.extensions.keys())}\n" f"Loaded Cogs: {', '.join(self.cogs.keys())}\n" + f"Registered Data: {', '.join(self.db.registries.keys())}\n" f"Listening for {sum(1 for _ in self.walk_commands())} commands\n" "------------------------------\n" f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n" diff --git a/bot/meta/LionCog.py b/bot/meta/LionCog.py index f705ce01..cab7e5c0 100644 --- a/bot/meta/LionCog.py +++ b/bot/meta/LionCog.py @@ -2,4 +2,12 @@ from discord.ext.commands import Cog class LionCog(Cog): - ... + # A set of other cogs that this cog depends on + depends_on: set['LionCog'] = set() + + async def _inject(self, bot, *args, **kwargs): + if self.depends_on: + not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)} + raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}") + + return await super()._inject(bot, *args, *kwargs) diff --git a/bot/meta/LionContext.py b/bot/meta/LionContext.py index 61aa5c57..eb0ffaa8 100644 --- a/bot/meta/LionContext.py +++ b/bot/meta/LionContext.py @@ -1,11 +1,14 @@ import types import logging from collections import namedtuple -from typing import Optional +from typing import Optional, TYPE_CHECKING import discord from discord.ext.commands import Context +if TYPE_CHECKING: + from .LionBot import LionBot + logger = logging.getLogger(__name__) @@ -34,7 +37,7 @@ FlatContext = namedtuple( ) -class LionContext(Context): +class LionContext(Context['LionBot']): """ Represents the context a command is invoked under. diff --git a/bot/meta/__init__.py b/bot/meta/__init__.py index fa9d999b..31feacc4 100644 --- a/bot/meta/__init__.py +++ b/bot/meta/__init__.py @@ -1,7 +1,15 @@ from .LionBot import LionBot -from .config import conf +from .LionCog import LionCog +from .LionContext import LionContext +from .LionTree import LionTree + +from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app +from .config import conf, configEmoji from .args import args -from .app import appname, shard_talk +from .app import appname, shard_talk, appname_from_shard, shard_from_appname +from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled +from .context import context, ctx_bot + from . import sharding from . import logger from . import app diff --git a/bot/meta/app.py b/bot/meta/app.py index 42335703..aa84a3bb 100644 --- a/bot/meta/app.py +++ b/bot/meta/app.py @@ -1,9 +1,24 @@ +""" +appname: str + The base identifer for this application. + This identifies which services the app offers. +shardname: str + The specific name of the running application. + Only one process should be connecteded with a given appname. + For the bot apps, usually specifies the shard id and shard number. +""" +# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data? + from . import sharding, conf from .logger import log_app from .ipc.client import AppClient from .args import args +appname = conf.data['appid'] +appid = appname # backwards compatibility + + def appname_from_shard(shardid): appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}" return appname @@ -13,16 +28,13 @@ def shard_from_appname(appname: str): return int(appname.rsplit('_', maxsplit=1)[-1]) -if sharding.sharded: - appname = appname_from_shard(sharding.shard_number) -else: - appname = conf.data['appid'] +shardname = appname_from_shard(sharding.shard_number) -log_app.set(appname) +log_app.set(shardname) shard_talk = AppClient( - appname, + shardname, {'host': args.host, 'port': args.port}, {'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])} ) diff --git a/bot/meta/context.py b/bot/meta/context.py index 3841d4d0..75f1df23 100644 --- a/bot/meta/context.py +++ b/bot/meta/context.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: # Contains the current command context, if applicable -context: Optional['LionContext'] = ContextVar('context', default=None) +context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None) # Contains the current LionBot instance -ctx_bot: Optional['LionBot'] = ContextVar('bot', default=None) +ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None) diff --git a/bot/meta/errors.py b/bot/meta/errors.py index d4be3e60..a5d6cbf3 100644 --- a/bot/meta/errors.py +++ b/bot/meta/errors.py @@ -1,4 +1,5 @@ from typing import Optional +from string import Template class SafeCancellation(Exception): @@ -12,8 +13,12 @@ class SafeCancellation(Exception): """ default_message = "" - def __init__(self, msg: Optional[str] = None, details: Optional[str] = None, **kwargs): - self.msg: Optional[str] = msg if msg is not None else self.default_message + @property + def msg(self): + return self._msg if self._msg is not None else self.default_message + + def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs): + self._msg: Optional[str] = _msg self.details: str = details if details is not None else self.msg super().__init__(**kwargs) @@ -24,6 +29,14 @@ class UserInputError(SafeCancellation): """ default_message = "Could not understand your input." + @property + def msg(self): + return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message + + def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs): + self.info = info + super().__init__(_msg, **kwargs) + class UserCancelled(SafeCancellation): """ diff --git a/bot/meta/ipc/client.py b/bot/meta/ipc/client.py index 499b9406..a4d173ed 100644 --- a/bot/meta/ipc/client.py +++ b/bot/meta/ipc/client.py @@ -3,6 +3,8 @@ import asyncio import logging import pickle +from ..logger import logging_context + logger = logging.getLogger(__name__) @@ -22,6 +24,7 @@ class AppClient: self._listener: Optional[asyncio.Server] = None # Local client server self._server = None # Connection to the registry server + self._keepalive = None self.register_route('new_peer')(self.new_peer) self.register_route('drop_peer')(self.drop_peer) @@ -29,7 +32,7 @@ class AppClient: def register_route(self, name=None): def wrapper(coro): - route = AppRoute(coro, name) + route = AppRoute(coro, client=self, name=name) self.routes[route.name] = route return route return wrapper @@ -49,24 +52,31 @@ class AppClient: self.peers = peers self._server = (reader, writer) except Exception: - logger.exception("Could not connect to registry server. Trying again in 30 seconds.") + logger.exception( + "Could not connect to registry server. Trying again in 30 seconds.", + extra={'action': 'Connect'} + ) await asyncio.sleep(30) asyncio.create_task(self.server_connection()) else: - logger.info("Connected to the registry server, launching keepalive.") - asyncio.create_task(self._server_keepalive()) + logger.debug( + "Connected to the registry server, launching keepalive.", + extra={'action': 'Connect'} + ) + self._keepalive = asyncio.create_task(self._server_keepalive()) async def _server_keepalive(self): - if self._server is None: - raise ValueError("Cannot keepalive non-existent server!") - reader, write = self._server - try: - await reader.read() - except Exception: - logger.exception("Lost connection to address server. Reconnecting...") - else: - # Connection ended or broke - logger.info("Lost connection to address server. Reconnecting...") + with logging_context(action='Keepalive'): + if self._server is None: + raise ValueError("Cannot keepalive non-existent server!") + reader, write = self._server + try: + await reader.read() + except Exception: + logger.exception("Lost connection to address server. Reconnecting...") + else: + # Connection ended or broke + logger.info("Lost connection to address server. Reconnecting...") await asyncio.sleep(30) asyncio.create_task(self.server_connection()) @@ -85,55 +95,63 @@ class AppClient: ... async def request(self, appid, payload: 'AppPayload'): - try: - if appid not in self.peers: - raise ValueError(f"Peer '{appid}' not found.") - logger.debug(f"Sending request to app '{appid}' with payload {payload}") + with logging_context(action=f"Req {appid}"): + try: + if appid not in self.peers: + raise ValueError(f"Peer '{appid}' not found.") + logger.debug(f"Sending request to app '{appid}' with payload {payload}") - address = self.peers[appid] - reader, writer = await asyncio.open_connection(**address) + address = self.peers[appid] + reader, writer = await asyncio.open_connection(**address) - writer.write(payload.encoded()) - await writer.drain() - writer.write_eof() - result = await reader.read() - writer.close() - decoded = payload.route.decode(result) - return decoded - except Exception: - logging.exception(f"Failed to send request to {appid}'") - return None + writer.write(payload.encoded()) + await writer.drain() + writer.write_eof() + result = await reader.read() + writer.close() + decoded = payload.route.decode(result) + return decoded + except Exception: + logging.exception(f"Failed to send request to {appid}'") + return None - async def requestall(self, payload): - results = await asyncio.gather(*(self.request(appid, payload) for appid in self.peers)) - return dict(zip(self.peers.keys(), results)) + async def requestall(self, payload, except_self=True): + with logging_context(action="Broadcast"): + results = await asyncio.gather( + *(self.request(appid, payload) for appid in self.peers if (appid != self.appid or not except_self)), + return_exceptions=True + ) + return dict(zip(self.peers.keys(), results)) async def handle_request(self, reader, writer): - data = await reader.read() - loaded = pickle.loads(data) - route, args, kwargs = loaded + with logging_context(action="SERV"): + data = await reader.read() + loaded = pickle.loads(data) + route, args, kwargs = loaded - logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}") + with logging_context(action=f"SERV {route}"): + logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}") - if route in self.routes: - try: - await self.routes[route].run((reader, writer), args, kwargs) - except Exception: - logger.exception(f"Fatal exception during route '{route}'. This should never happen!") - else: - logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.") - writer.write_eof() + if route in self.routes: + try: + await self.routes[route].run((reader, writer), args, kwargs) + except Exception: + logger.exception(f"Fatal exception during route '{route}'. This should never happen!") + else: + logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.") + writer.write_eof() async def connect(self): """ Start the local peer server. Connect to the address server. """ - # Start the client server - self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True) + with logging_context(stack=['ShardTalk']): + # Start the client server + self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True) - logger.info(f"Serving on {self.address}") - await self.server_connection() + logger.info(f"Serving on {self.address}") + await self.server_connection() class AppPayload: @@ -150,13 +168,20 @@ class AppPayload: def encoded(self): return pickle.dumps((self.route.name, self.args, self.kwargs)) + async def send(self, appid, **kwargs): + return await self.route._client.request(appid, self, **kwargs) + + async def broadcast(self, **kwargs): + return await self.route._client.requestall(self, **kwargs) + class AppRoute: - __slots__ = ('func', 'name') + __slots__ = ('func', 'name', '_client') - def __init__(self, func, name=None): + def __init__(self, func, client=None, name=None): self.func = func self.name = name or func.__name__ + self._client = client def __call__(self, *args, **kwargs): return AppPayload(self, *args, **kwargs) diff --git a/bot/meta/logger.py b/bot/meta/logger.py index 0cdd7f7d..4c95d4e1 100644 --- a/bot/meta/logger.py +++ b/bot/meta/logger.py @@ -6,8 +6,9 @@ from logging.handlers import QueueListener, QueueHandler from queue import SimpleQueue from contextlib import contextmanager from io import StringIO - +from functools import wraps from contextvars import ContextVar + from discord import Webhook, File import aiohttp @@ -46,6 +47,16 @@ def logging_context(context=None, action=None, stack=None): log_action_stack.set(astack) +def log_wrap(**kwargs): + def decorator(func): + @wraps(func) + async def wrapped(*w_args, **w_kwargs): + with logging_context(**kwargs): + return await func(*w_args, **w_kwargs) + return wrapped + return decorator + + RESET_SEQ = "\033[0m" COLOR_SEQ = "\033[3%dm" BOLD_SEQ = "\033[1m"