rewrite: Update meta.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
@@ -18,12 +18,15 @@ from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
from .errors import HandledException, SafeCancellation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core import CoreCog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LionBot(Bot):
|
||||
def __init__(
|
||||
self, *args, appname: str, db: Database, config: Conf,
|
||||
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
||||
initial_extensions: List[str], web_client: ClientSession, app_ipc,
|
||||
testing_guilds: List[int] = [], **kwargs
|
||||
):
|
||||
@@ -34,9 +37,11 @@ class LionBot(Bot):
|
||||
self.initial_extensions = initial_extensions
|
||||
self.db = db
|
||||
self.appname = appname
|
||||
self.shardname = shardname
|
||||
# self.appdata = appdata
|
||||
self.config = config
|
||||
self.app_ipc = app_ipc
|
||||
self.core: Optional['CoreCog'] = None
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
log_context.set(f"APP: {self.application_id}")
|
||||
@@ -72,10 +77,11 @@ class LionBot(Bot):
|
||||
logger.info(
|
||||
f"Logged in as {self.application.name}\n"
|
||||
f"Application id {self.application.id}\n"
|
||||
f"Shard Talk identifier {self.appname}\n"
|
||||
f"Shard Talk identifier {self.shardname}\n"
|
||||
"------------------------------\n"
|
||||
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
|
||||
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
|
||||
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
|
||||
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
|
||||
"------------------------------\n"
|
||||
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
|
||||
|
||||
@@ -2,4 +2,12 @@ from discord.ext.commands import Cog
|
||||
|
||||
|
||||
class LionCog(Cog):
|
||||
...
|
||||
# A set of other cogs that this cog depends on
|
||||
depends_on: set['LionCog'] = set()
|
||||
|
||||
async def _inject(self, bot, *args, **kwargs):
|
||||
if self.depends_on:
|
||||
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
|
||||
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
|
||||
|
||||
return await super()._inject(bot, *args, *kwargs)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import types
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
from discord.ext.commands import Context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .LionBot import LionBot
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +37,7 @@ FlatContext = namedtuple(
|
||||
)
|
||||
|
||||
|
||||
class LionContext(Context):
|
||||
class LionContext(Context['LionBot']):
|
||||
"""
|
||||
Represents the context a command is invoked under.
|
||||
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
from .LionBot import LionBot
|
||||
from .config import conf
|
||||
from .LionCog import LionCog
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
|
||||
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
|
||||
from .config import conf, configEmoji
|
||||
from .args import args
|
||||
from .app import appname, shard_talk
|
||||
from .app import appname, shard_talk, appname_from_shard, shard_from_appname
|
||||
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
|
||||
from .context import context, ctx_bot
|
||||
|
||||
from . import sharding
|
||||
from . import logger
|
||||
from . import app
|
||||
|
||||
@@ -1,9 +1,24 @@
|
||||
"""
|
||||
appname: str
|
||||
The base identifer for this application.
|
||||
This identifies which services the app offers.
|
||||
shardname: str
|
||||
The specific name of the running application.
|
||||
Only one process should be connecteded with a given appname.
|
||||
For the bot apps, usually specifies the shard id and shard number.
|
||||
"""
|
||||
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
|
||||
|
||||
from . import sharding, conf
|
||||
from .logger import log_app
|
||||
from .ipc.client import AppClient
|
||||
from .args import args
|
||||
|
||||
|
||||
appname = conf.data['appid']
|
||||
appid = appname # backwards compatibility
|
||||
|
||||
|
||||
def appname_from_shard(shardid):
|
||||
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
|
||||
return appname
|
||||
@@ -13,16 +28,13 @@ def shard_from_appname(appname: str):
|
||||
return int(appname.rsplit('_', maxsplit=1)[-1])
|
||||
|
||||
|
||||
if sharding.sharded:
|
||||
appname = appname_from_shard(sharding.shard_number)
|
||||
else:
|
||||
appname = conf.data['appid']
|
||||
shardname = appname_from_shard(sharding.shard_number)
|
||||
|
||||
log_app.set(appname)
|
||||
log_app.set(shardname)
|
||||
|
||||
|
||||
shard_talk = AppClient(
|
||||
appname,
|
||||
shardname,
|
||||
{'host': args.host, 'port': args.port},
|
||||
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
# Contains the current command context, if applicable
|
||||
context: Optional['LionContext'] = ContextVar('context', default=None)
|
||||
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
|
||||
|
||||
# Contains the current LionBot instance
|
||||
ctx_bot: Optional['LionBot'] = ContextVar('bot', default=None)
|
||||
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
from string import Template
|
||||
|
||||
|
||||
class SafeCancellation(Exception):
|
||||
@@ -12,8 +13,12 @@ class SafeCancellation(Exception):
|
||||
"""
|
||||
default_message = ""
|
||||
|
||||
def __init__(self, msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
|
||||
self.msg: Optional[str] = msg if msg is not None else self.default_message
|
||||
@property
|
||||
def msg(self):
|
||||
return self._msg if self._msg is not None else self.default_message
|
||||
|
||||
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
|
||||
self._msg: Optional[str] = _msg
|
||||
self.details: str = details if details is not None else self.msg
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -24,6 +29,14 @@ class UserInputError(SafeCancellation):
|
||||
"""
|
||||
default_message = "Could not understand your input."
|
||||
|
||||
@property
|
||||
def msg(self):
|
||||
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
|
||||
|
||||
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
|
||||
self.info = info
|
||||
super().__init__(_msg, **kwargs)
|
||||
|
||||
|
||||
class UserCancelled(SafeCancellation):
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,8 @@ import asyncio
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
from ..logger import logging_context
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,6 +24,7 @@ class AppClient:
|
||||
|
||||
self._listener: Optional[asyncio.Server] = None # Local client server
|
||||
self._server = None # Connection to the registry server
|
||||
self._keepalive = None
|
||||
|
||||
self.register_route('new_peer')(self.new_peer)
|
||||
self.register_route('drop_peer')(self.drop_peer)
|
||||
@@ -29,7 +32,7 @@ class AppClient:
|
||||
|
||||
def register_route(self, name=None):
|
||||
def wrapper(coro):
|
||||
route = AppRoute(coro, name)
|
||||
route = AppRoute(coro, client=self, name=name)
|
||||
self.routes[route.name] = route
|
||||
return route
|
||||
return wrapper
|
||||
@@ -49,14 +52,21 @@ class AppClient:
|
||||
self.peers = peers
|
||||
self._server = (reader, writer)
|
||||
except Exception:
|
||||
logger.exception("Could not connect to registry server. Trying again in 30 seconds.")
|
||||
logger.exception(
|
||||
"Could not connect to registry server. Trying again in 30 seconds.",
|
||||
extra={'action': 'Connect'}
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
asyncio.create_task(self.server_connection())
|
||||
else:
|
||||
logger.info("Connected to the registry server, launching keepalive.")
|
||||
asyncio.create_task(self._server_keepalive())
|
||||
logger.debug(
|
||||
"Connected to the registry server, launching keepalive.",
|
||||
extra={'action': 'Connect'}
|
||||
)
|
||||
self._keepalive = asyncio.create_task(self._server_keepalive())
|
||||
|
||||
async def _server_keepalive(self):
|
||||
with logging_context(action='Keepalive'):
|
||||
if self._server is None:
|
||||
raise ValueError("Cannot keepalive non-existent server!")
|
||||
reader, write = self._server
|
||||
@@ -85,6 +95,7 @@ class AppClient:
|
||||
...
|
||||
|
||||
async def request(self, appid, payload: 'AppPayload'):
|
||||
with logging_context(action=f"Req {appid}"):
|
||||
try:
|
||||
if appid not in self.peers:
|
||||
raise ValueError(f"Peer '{appid}' not found.")
|
||||
@@ -104,15 +115,21 @@ class AppClient:
|
||||
logging.exception(f"Failed to send request to {appid}'")
|
||||
return None
|
||||
|
||||
async def requestall(self, payload):
|
||||
results = await asyncio.gather(*(self.request(appid, payload) for appid in self.peers))
|
||||
async def requestall(self, payload, except_self=True):
|
||||
with logging_context(action="Broadcast"):
|
||||
results = await asyncio.gather(
|
||||
*(self.request(appid, payload) for appid in self.peers if (appid != self.appid or not except_self)),
|
||||
return_exceptions=True
|
||||
)
|
||||
return dict(zip(self.peers.keys(), results))
|
||||
|
||||
async def handle_request(self, reader, writer):
|
||||
with logging_context(action="SERV"):
|
||||
data = await reader.read()
|
||||
loaded = pickle.loads(data)
|
||||
route, args, kwargs = loaded
|
||||
|
||||
with logging_context(action=f"SERV {route}"):
|
||||
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
|
||||
|
||||
if route in self.routes:
|
||||
@@ -129,6 +146,7 @@ class AppClient:
|
||||
Start the local peer server.
|
||||
Connect to the address server.
|
||||
"""
|
||||
with logging_context(stack=['ShardTalk']):
|
||||
# Start the client server
|
||||
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
|
||||
|
||||
@@ -150,13 +168,20 @@ class AppPayload:
|
||||
def encoded(self):
|
||||
return pickle.dumps((self.route.name, self.args, self.kwargs))
|
||||
|
||||
async def send(self, appid, **kwargs):
|
||||
return await self.route._client.request(appid, self, **kwargs)
|
||||
|
||||
async def broadcast(self, **kwargs):
|
||||
return await self.route._client.requestall(self, **kwargs)
|
||||
|
||||
|
||||
class AppRoute:
|
||||
__slots__ = ('func', 'name')
|
||||
__slots__ = ('func', 'name', '_client')
|
||||
|
||||
def __init__(self, func, name=None):
|
||||
def __init__(self, func, client=None, name=None):
|
||||
self.func = func
|
||||
self.name = name or func.__name__
|
||||
self._client = client
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return AppPayload(self, *args, **kwargs)
|
||||
|
||||
@@ -6,8 +6,9 @@ from logging.handlers import QueueListener, QueueHandler
|
||||
from queue import SimpleQueue
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
|
||||
from functools import wraps
|
||||
from contextvars import ContextVar
|
||||
|
||||
from discord import Webhook, File
|
||||
import aiohttp
|
||||
|
||||
@@ -46,6 +47,16 @@ def logging_context(context=None, action=None, stack=None):
|
||||
log_action_stack.set(astack)
|
||||
|
||||
|
||||
def log_wrap(**kwargs):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapped(*w_args, **w_kwargs):
|
||||
with logging_context(**kwargs):
|
||||
return await func(*w_args, **w_kwargs)
|
||||
return wrapped
|
||||
return decorator
|
||||
|
||||
|
||||
RESET_SEQ = "\033[0m"
|
||||
COLOR_SEQ = "\033[3%dm"
|
||||
BOLD_SEQ = "\033[1m"
|
||||
|
||||
Reference in New Issue
Block a user