rewrite: Update meta.

This commit is contained in:
2022-11-18 08:43:54 +02:00
parent 56f66ec7d4
commit ebece5256a
9 changed files with 156 additions and 70 deletions

View File

@@ -1,4 +1,4 @@
from typing import List from typing import List, Optional, TYPE_CHECKING
import logging import logging
import asyncio import asyncio
@@ -18,12 +18,15 @@ from .LionContext import LionContext
from .LionTree import LionTree from .LionTree import LionTree
from .errors import HandledException, SafeCancellation from .errors import HandledException, SafeCancellation
if TYPE_CHECKING:
from core import CoreCog
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LionBot(Bot): class LionBot(Bot):
def __init__( def __init__(
self, *args, appname: str, db: Database, config: Conf, self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession, app_ipc, initial_extensions: List[str], web_client: ClientSession, app_ipc,
testing_guilds: List[int] = [], **kwargs testing_guilds: List[int] = [], **kwargs
): ):
@@ -34,9 +37,11 @@ class LionBot(Bot):
self.initial_extensions = initial_extensions self.initial_extensions = initial_extensions
self.db = db self.db = db
self.appname = appname self.appname = appname
self.shardname = shardname
# self.appdata = appdata # self.appdata = appdata
self.config = config self.config = config
self.app_ipc = app_ipc self.app_ipc = app_ipc
self.core: Optional['CoreCog'] = None
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}") log_context.set(f"APP: {self.application_id}")
@@ -72,10 +77,11 @@ class LionBot(Bot):
logger.info( logger.info(
f"Logged in as {self.application.name}\n" f"Logged in as {self.application.name}\n"
f"Application id {self.application.id}\n" f"Application id {self.application.id}\n"
f"Shard Talk identifier {self.appname}\n" f"Shard Talk identifier {self.shardname}\n"
"------------------------------\n" "------------------------------\n"
f"Enabled Modules: {', '.join(self.extensions.keys())}\n" f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n" f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n" f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
"------------------------------\n" "------------------------------\n"
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n" f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"

View File

@@ -2,4 +2,12 @@ from discord.ext.commands import Cog
class LionCog(Cog): class LionCog(Cog):
... # A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set()
async def _inject(self, bot, *args, **kwargs):
if self.depends_on:
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
return await super()._inject(bot, *args, *kwargs)

View File

@@ -1,11 +1,14 @@
import types import types
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Optional from typing import Optional, TYPE_CHECKING
import discord import discord
from discord.ext.commands import Context from discord.ext.commands import Context
if TYPE_CHECKING:
from .LionBot import LionBot
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -34,7 +37,7 @@ FlatContext = namedtuple(
) )
class LionContext(Context): class LionContext(Context['LionBot']):
""" """
Represents the context a command is invoked under. Represents the context a command is invoked under.

View File

@@ -1,7 +1,15 @@
from .LionBot import LionBot from .LionBot import LionBot
from .config import conf from .LionCog import LionCog
from .LionContext import LionContext
from .LionTree import LionTree
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
from .config import conf, configEmoji
from .args import args from .args import args
from .app import appname, shard_talk from .app import appname, shard_talk, appname_from_shard, shard_from_appname
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
from .context import context, ctx_bot
from . import sharding from . import sharding
from . import logger from . import logger
from . import app from . import app

View File

@@ -1,9 +1,24 @@
"""
appname: str
The base identifer for this application.
This identifies which services the app offers.
shardname: str
The specific name of the running application.
Only one process should be connecteded with a given appname.
For the bot apps, usually specifies the shard id and shard number.
"""
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
from . import sharding, conf from . import sharding, conf
from .logger import log_app from .logger import log_app
from .ipc.client import AppClient from .ipc.client import AppClient
from .args import args from .args import args
appname = conf.data['appid']
appid = appname # backwards compatibility
def appname_from_shard(shardid): def appname_from_shard(shardid):
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}" appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
return appname return appname
@@ -13,16 +28,13 @@ def shard_from_appname(appname: str):
return int(appname.rsplit('_', maxsplit=1)[-1]) return int(appname.rsplit('_', maxsplit=1)[-1])
if sharding.sharded: shardname = appname_from_shard(sharding.shard_number)
appname = appname_from_shard(sharding.shard_number)
else:
appname = conf.data['appid']
log_app.set(appname) log_app.set(shardname)
shard_talk = AppClient( shard_talk = AppClient(
appname, shardname,
{'host': args.host, 'port': args.port}, {'host': args.host, 'port': args.port},
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])} {'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
) )

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
# Contains the current command context, if applicable # Contains the current command context, if applicable
context: Optional['LionContext'] = ContextVar('context', default=None) context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
# Contains the current LionBot instance # Contains the current LionBot instance
ctx_bot: Optional['LionBot'] = ContextVar('bot', default=None) ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)

View File

@@ -1,4 +1,5 @@
from typing import Optional from typing import Optional
from string import Template
class SafeCancellation(Exception): class SafeCancellation(Exception):
@@ -12,8 +13,12 @@ class SafeCancellation(Exception):
""" """
default_message = "" default_message = ""
def __init__(self, msg: Optional[str] = None, details: Optional[str] = None, **kwargs): @property
self.msg: Optional[str] = msg if msg is not None else self.default_message def msg(self):
return self._msg if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
self._msg: Optional[str] = _msg
self.details: str = details if details is not None else self.msg self.details: str = details if details is not None else self.msg
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -24,6 +29,14 @@ class UserInputError(SafeCancellation):
""" """
default_message = "Could not understand your input." default_message = "Could not understand your input."
@property
def msg(self):
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
self.info = info
super().__init__(_msg, **kwargs)
class UserCancelled(SafeCancellation): class UserCancelled(SafeCancellation):
""" """

View File

@@ -3,6 +3,8 @@ import asyncio
import logging import logging
import pickle import pickle
from ..logger import logging_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,6 +24,7 @@ class AppClient:
self._listener: Optional[asyncio.Server] = None # Local client server self._listener: Optional[asyncio.Server] = None # Local client server
self._server = None # Connection to the registry server self._server = None # Connection to the registry server
self._keepalive = None
self.register_route('new_peer')(self.new_peer) self.register_route('new_peer')(self.new_peer)
self.register_route('drop_peer')(self.drop_peer) self.register_route('drop_peer')(self.drop_peer)
@@ -29,7 +32,7 @@ class AppClient:
def register_route(self, name=None): def register_route(self, name=None):
def wrapper(coro): def wrapper(coro):
route = AppRoute(coro, name) route = AppRoute(coro, client=self, name=name)
self.routes[route.name] = route self.routes[route.name] = route
return route return route
return wrapper return wrapper
@@ -49,14 +52,21 @@ class AppClient:
self.peers = peers self.peers = peers
self._server = (reader, writer) self._server = (reader, writer)
except Exception: except Exception:
logger.exception("Could not connect to registry server. Trying again in 30 seconds.") logger.exception(
"Could not connect to registry server. Trying again in 30 seconds.",
extra={'action': 'Connect'}
)
await asyncio.sleep(30) await asyncio.sleep(30)
asyncio.create_task(self.server_connection()) asyncio.create_task(self.server_connection())
else: else:
logger.info("Connected to the registry server, launching keepalive.") logger.debug(
asyncio.create_task(self._server_keepalive()) "Connected to the registry server, launching keepalive.",
extra={'action': 'Connect'}
)
self._keepalive = asyncio.create_task(self._server_keepalive())
async def _server_keepalive(self): async def _server_keepalive(self):
with logging_context(action='Keepalive'):
if self._server is None: if self._server is None:
raise ValueError("Cannot keepalive non-existent server!") raise ValueError("Cannot keepalive non-existent server!")
reader, write = self._server reader, write = self._server
@@ -85,6 +95,7 @@ class AppClient:
... ...
async def request(self, appid, payload: 'AppPayload'): async def request(self, appid, payload: 'AppPayload'):
with logging_context(action=f"Req {appid}"):
try: try:
if appid not in self.peers: if appid not in self.peers:
raise ValueError(f"Peer '{appid}' not found.") raise ValueError(f"Peer '{appid}' not found.")
@@ -104,15 +115,21 @@ class AppClient:
logging.exception(f"Failed to send request to {appid}'") logging.exception(f"Failed to send request to {appid}'")
return None return None
async def requestall(self, payload): async def requestall(self, payload, except_self=True):
results = await asyncio.gather(*(self.request(appid, payload) for appid in self.peers)) with logging_context(action="Broadcast"):
results = await asyncio.gather(
*(self.request(appid, payload) for appid in self.peers if (appid != self.appid or not except_self)),
return_exceptions=True
)
return dict(zip(self.peers.keys(), results)) return dict(zip(self.peers.keys(), results))
async def handle_request(self, reader, writer): async def handle_request(self, reader, writer):
with logging_context(action="SERV"):
data = await reader.read() data = await reader.read()
loaded = pickle.loads(data) loaded = pickle.loads(data)
route, args, kwargs = loaded route, args, kwargs = loaded
with logging_context(action=f"SERV {route}"):
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}") logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes: if route in self.routes:
@@ -129,6 +146,7 @@ class AppClient:
Start the local peer server. Start the local peer server.
Connect to the address server. Connect to the address server.
""" """
with logging_context(stack=['ShardTalk']):
# Start the client server # Start the client server
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True) self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
@@ -150,13 +168,20 @@ class AppPayload:
def encoded(self): def encoded(self):
return pickle.dumps((self.route.name, self.args, self.kwargs)) return pickle.dumps((self.route.name, self.args, self.kwargs))
async def send(self, appid, **kwargs):
return await self.route._client.request(appid, self, **kwargs)
async def broadcast(self, **kwargs):
return await self.route._client.requestall(self, **kwargs)
class AppRoute: class AppRoute:
__slots__ = ('func', 'name') __slots__ = ('func', 'name', '_client')
def __init__(self, func, name=None): def __init__(self, func, client=None, name=None):
self.func = func self.func = func
self.name = name or func.__name__ self.name = name or func.__name__
self._client = client
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return AppPayload(self, *args, **kwargs) return AppPayload(self, *args, **kwargs)

View File

@@ -6,8 +6,9 @@ from logging.handlers import QueueListener, QueueHandler
from queue import SimpleQueue from queue import SimpleQueue
from contextlib import contextmanager from contextlib import contextmanager
from io import StringIO from io import StringIO
from functools import wraps
from contextvars import ContextVar from contextvars import ContextVar
from discord import Webhook, File from discord import Webhook, File
import aiohttp import aiohttp
@@ -46,6 +47,16 @@ def logging_context(context=None, action=None, stack=None):
log_action_stack.set(astack) log_action_stack.set(astack)
def log_wrap(**kwargs):
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
with logging_context(**kwargs):
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
RESET_SEQ = "\033[0m" RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm" COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m" BOLD_SEQ = "\033[1m"