rewrite: Update meta.

This commit is contained in:
2022-11-18 08:43:54 +02:00
parent 56f66ec7d4
commit ebece5256a
9 changed files with 156 additions and 70 deletions

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional, TYPE_CHECKING
import logging
import asyncio
@@ -18,12 +18,15 @@ from .LionContext import LionContext
from .LionTree import LionTree
from .errors import HandledException, SafeCancellation
if TYPE_CHECKING:
from core import CoreCog
logger = logging.getLogger(__name__)
class LionBot(Bot):
def __init__(
self, *args, appname: str, db: Database, config: Conf,
self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession, app_ipc,
testing_guilds: List[int] = [], **kwargs
):
@@ -34,9 +37,11 @@ class LionBot(Bot):
self.initial_extensions = initial_extensions
self.db = db
self.appname = appname
self.shardname = shardname
# self.appdata = appdata
self.config = config
self.app_ipc = app_ipc
self.core: Optional['CoreCog'] = None
async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}")
@@ -72,10 +77,11 @@ class LionBot(Bot):
logger.info(
f"Logged in as {self.application.name}\n"
f"Application id {self.application.id}\n"
f"Shard Talk identifier {self.appname}\n"
f"Shard Talk identifier {self.shardname}\n"
"------------------------------\n"
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
"------------------------------\n"
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"

View File

@@ -2,4 +2,12 @@ from discord.ext.commands import Cog
class LionCog(Cog):
...
# A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set()
async def _inject(self, bot, *args, **kwargs):
if self.depends_on:
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
return await super()._inject(bot, *args, *kwargs)

View File

@@ -1,11 +1,14 @@
import types
import logging
from collections import namedtuple
from typing import Optional
from typing import Optional, TYPE_CHECKING
import discord
from discord.ext.commands import Context
if TYPE_CHECKING:
from .LionBot import LionBot
logger = logging.getLogger(__name__)
@@ -34,7 +37,7 @@ FlatContext = namedtuple(
)
class LionContext(Context):
class LionContext(Context['LionBot']):
"""
Represents the context a command is invoked under.

View File

@@ -1,7 +1,15 @@
from .LionBot import LionBot
from .config import conf
from .LionCog import LionCog
from .LionContext import LionContext
from .LionTree import LionTree
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
from .config import conf, configEmoji
from .args import args
from .app import appname, shard_talk
from .app import appname, shard_talk, appname_from_shard, shard_from_appname
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
from .context import context, ctx_bot
from . import sharding
from . import logger
from . import app

View File

@@ -1,9 +1,24 @@
"""
appname: str
The base identifer for this application.
This identifies which services the app offers.
shardname: str
The specific name of the running application.
Only one process should be connecteded with a given appname.
For the bot apps, usually specifies the shard id and shard number.
"""
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
from . import sharding, conf
from .logger import log_app
from .ipc.client import AppClient
from .args import args
appname = conf.data['appid']
appid = appname # backwards compatibility
def appname_from_shard(shardid):
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
return appname
@@ -13,16 +28,13 @@ def shard_from_appname(appname: str):
return int(appname.rsplit('_', maxsplit=1)[-1])
if sharding.sharded:
appname = appname_from_shard(sharding.shard_number)
else:
appname = conf.data['appid']
shardname = appname_from_shard(sharding.shard_number)
log_app.set(appname)
log_app.set(shardname)
shard_talk = AppClient(
appname,
shardname,
{'host': args.host, 'port': args.port},
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
)

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
# Contains the current command context, if applicable
context: Optional['LionContext'] = ContextVar('context', default=None)
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
# Contains the current LionBot instance
ctx_bot: Optional['LionBot'] = ContextVar('bot', default=None)
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)

View File

@@ -1,4 +1,5 @@
from typing import Optional
from string import Template
class SafeCancellation(Exception):
@@ -12,8 +13,12 @@ class SafeCancellation(Exception):
"""
default_message = ""
def __init__(self, msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
self.msg: Optional[str] = msg if msg is not None else self.default_message
@property
def msg(self):
return self._msg if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
self._msg: Optional[str] = _msg
self.details: str = details if details is not None else self.msg
super().__init__(**kwargs)
@@ -24,6 +29,14 @@ class UserInputError(SafeCancellation):
"""
default_message = "Could not understand your input."
@property
def msg(self):
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
self.info = info
super().__init__(_msg, **kwargs)
class UserCancelled(SafeCancellation):
"""

View File

@@ -3,6 +3,8 @@ import asyncio
import logging
import pickle
from ..logger import logging_context
logger = logging.getLogger(__name__)
@@ -22,6 +24,7 @@ class AppClient:
self._listener: Optional[asyncio.Server] = None # Local client server
self._server = None # Connection to the registry server
self._keepalive = None
self.register_route('new_peer')(self.new_peer)
self.register_route('drop_peer')(self.drop_peer)
@@ -29,7 +32,7 @@ class AppClient:
def register_route(self, name=None):
def wrapper(coro):
route = AppRoute(coro, name)
route = AppRoute(coro, client=self, name=name)
self.routes[route.name] = route
return route
return wrapper
@@ -49,24 +52,31 @@ class AppClient:
self.peers = peers
self._server = (reader, writer)
except Exception:
logger.exception("Could not connect to registry server. Trying again in 30 seconds.")
logger.exception(
"Could not connect to registry server. Trying again in 30 seconds.",
extra={'action': 'Connect'}
)
await asyncio.sleep(30)
asyncio.create_task(self.server_connection())
else:
logger.info("Connected to the registry server, launching keepalive.")
asyncio.create_task(self._server_keepalive())
logger.debug(
"Connected to the registry server, launching keepalive.",
extra={'action': 'Connect'}
)
self._keepalive = asyncio.create_task(self._server_keepalive())
async def _server_keepalive(self):
if self._server is None:
raise ValueError("Cannot keepalive non-existent server!")
reader, write = self._server
try:
await reader.read()
except Exception:
logger.exception("Lost connection to address server. Reconnecting...")
else:
# Connection ended or broke
logger.info("Lost connection to address server. Reconnecting...")
with logging_context(action='Keepalive'):
if self._server is None:
raise ValueError("Cannot keepalive non-existent server!")
reader, write = self._server
try:
await reader.read()
except Exception:
logger.exception("Lost connection to address server. Reconnecting...")
else:
# Connection ended or broke
logger.info("Lost connection to address server. Reconnecting...")
await asyncio.sleep(30)
asyncio.create_task(self.server_connection())
@@ -85,55 +95,63 @@ class AppClient:
...
async def request(self, appid, payload: 'AppPayload'):
try:
if appid not in self.peers:
raise ValueError(f"Peer '{appid}' not found.")
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
with logging_context(action=f"Req {appid}"):
try:
if appid not in self.peers:
raise ValueError(f"Peer '{appid}' not found.")
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
address = self.peers[appid]
reader, writer = await asyncio.open_connection(**address)
address = self.peers[appid]
reader, writer = await asyncio.open_connection(**address)
writer.write(payload.encoded())
await writer.drain()
writer.write_eof()
result = await reader.read()
writer.close()
decoded = payload.route.decode(result)
return decoded
except Exception:
logging.exception(f"Failed to send request to {appid}'")
return None
writer.write(payload.encoded())
await writer.drain()
writer.write_eof()
result = await reader.read()
writer.close()
decoded = payload.route.decode(result)
return decoded
except Exception:
logging.exception(f"Failed to send request to {appid}'")
return None
async def requestall(self, payload):
results = await asyncio.gather(*(self.request(appid, payload) for appid in self.peers))
return dict(zip(self.peers.keys(), results))
async def requestall(self, payload, except_self=True):
with logging_context(action="Broadcast"):
results = await asyncio.gather(
*(self.request(appid, payload) for appid in self.peers if (appid != self.appid or not except_self)),
return_exceptions=True
)
return dict(zip(self.peers.keys(), results))
async def handle_request(self, reader, writer):
data = await reader.read()
loaded = pickle.loads(data)
route, args, kwargs = loaded
with logging_context(action="SERV"):
data = await reader.read()
loaded = pickle.loads(data)
route, args, kwargs = loaded
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
with logging_context(action=f"SERV {route}"):
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes:
try:
await self.routes[route].run((reader, writer), args, kwargs)
except Exception:
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
else:
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
writer.write_eof()
if route in self.routes:
try:
await self.routes[route].run((reader, writer), args, kwargs)
except Exception:
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
else:
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
writer.write_eof()
async def connect(self):
"""
Start the local peer server.
Connect to the address server.
"""
# Start the client server
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
with logging_context(stack=['ShardTalk']):
# Start the client server
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
logger.info(f"Serving on {self.address}")
await self.server_connection()
logger.info(f"Serving on {self.address}")
await self.server_connection()
class AppPayload:
@@ -150,13 +168,20 @@ class AppPayload:
def encoded(self):
return pickle.dumps((self.route.name, self.args, self.kwargs))
async def send(self, appid, **kwargs):
return await self.route._client.request(appid, self, **kwargs)
async def broadcast(self, **kwargs):
return await self.route._client.requestall(self, **kwargs)
class AppRoute:
__slots__ = ('func', 'name')
__slots__ = ('func', 'name', '_client')
def __init__(self, func, name=None):
def __init__(self, func, client=None, name=None):
self.func = func
self.name = name or func.__name__
self._client = client
def __call__(self, *args, **kwargs):
return AppPayload(self, *args, **kwargs)

View File

@@ -6,8 +6,9 @@ from logging.handlers import QueueListener, QueueHandler
from queue import SimpleQueue
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar
from discord import Webhook, File
import aiohttp
@@ -46,6 +47,16 @@ def logging_context(context=None, action=None, stack=None):
log_action_stack.set(astack)
def log_wrap(**kwargs):
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
with logging_context(**kwargs):
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"