rewrite: Update meta.
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import List, Optional, TYPE_CHECKING
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
@@ -18,12 +18,15 @@ from .LionContext import LionContext
|
|||||||
from .LionTree import LionTree
|
from .LionTree import LionTree
|
||||||
from .errors import HandledException, SafeCancellation
|
from .errors import HandledException, SafeCancellation
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core import CoreCog
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LionBot(Bot):
|
class LionBot(Bot):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, *args, appname: str, db: Database, config: Conf,
|
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
||||||
initial_extensions: List[str], web_client: ClientSession, app_ipc,
|
initial_extensions: List[str], web_client: ClientSession, app_ipc,
|
||||||
testing_guilds: List[int] = [], **kwargs
|
testing_guilds: List[int] = [], **kwargs
|
||||||
):
|
):
|
||||||
@@ -34,9 +37,11 @@ class LionBot(Bot):
|
|||||||
self.initial_extensions = initial_extensions
|
self.initial_extensions = initial_extensions
|
||||||
self.db = db
|
self.db = db
|
||||||
self.appname = appname
|
self.appname = appname
|
||||||
|
self.shardname = shardname
|
||||||
# self.appdata = appdata
|
# self.appdata = appdata
|
||||||
self.config = config
|
self.config = config
|
||||||
self.app_ipc = app_ipc
|
self.app_ipc = app_ipc
|
||||||
|
self.core: Optional['CoreCog'] = None
|
||||||
|
|
||||||
async def setup_hook(self) -> None:
|
async def setup_hook(self) -> None:
|
||||||
log_context.set(f"APP: {self.application_id}")
|
log_context.set(f"APP: {self.application_id}")
|
||||||
@@ -72,10 +77,11 @@ class LionBot(Bot):
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Logged in as {self.application.name}\n"
|
f"Logged in as {self.application.name}\n"
|
||||||
f"Application id {self.application.id}\n"
|
f"Application id {self.application.id}\n"
|
||||||
f"Shard Talk identifier {self.appname}\n"
|
f"Shard Talk identifier {self.shardname}\n"
|
||||||
"------------------------------\n"
|
"------------------------------\n"
|
||||||
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
|
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
|
||||||
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
|
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
|
||||||
|
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
|
||||||
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
|
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
|
||||||
"------------------------------\n"
|
"------------------------------\n"
|
||||||
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
|
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
|
||||||
|
|||||||
@@ -2,4 +2,12 @@ from discord.ext.commands import Cog
|
|||||||
|
|
||||||
|
|
||||||
class LionCog(Cog):
|
class LionCog(Cog):
|
||||||
...
|
# A set of other cogs that this cog depends on
|
||||||
|
depends_on: set['LionCog'] = set()
|
||||||
|
|
||||||
|
async def _inject(self, bot, *args, **kwargs):
|
||||||
|
if self.depends_on:
|
||||||
|
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
|
||||||
|
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
|
||||||
|
|
||||||
|
return await super()._inject(bot, *args, *kwargs)
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
import types
|
import types
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Optional
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext.commands import Context
|
from discord.ext.commands import Context
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .LionBot import LionBot
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -34,7 +37,7 @@ FlatContext = namedtuple(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LionContext(Context):
|
class LionContext(Context['LionBot']):
|
||||||
"""
|
"""
|
||||||
Represents the context a command is invoked under.
|
Represents the context a command is invoked under.
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,15 @@
|
|||||||
from .LionBot import LionBot
|
from .LionBot import LionBot
|
||||||
from .config import conf
|
from .LionCog import LionCog
|
||||||
|
from .LionContext import LionContext
|
||||||
|
from .LionTree import LionTree
|
||||||
|
|
||||||
|
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
|
||||||
|
from .config import conf, configEmoji
|
||||||
from .args import args
|
from .args import args
|
||||||
from .app import appname, shard_talk
|
from .app import appname, shard_talk, appname_from_shard, shard_from_appname
|
||||||
|
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
|
||||||
|
from .context import context, ctx_bot
|
||||||
|
|
||||||
from . import sharding
|
from . import sharding
|
||||||
from . import logger
|
from . import logger
|
||||||
from . import app
|
from . import app
|
||||||
|
|||||||
@@ -1,9 +1,24 @@
|
|||||||
|
"""
|
||||||
|
appname: str
|
||||||
|
The base identifer for this application.
|
||||||
|
This identifies which services the app offers.
|
||||||
|
shardname: str
|
||||||
|
The specific name of the running application.
|
||||||
|
Only one process should be connecteded with a given appname.
|
||||||
|
For the bot apps, usually specifies the shard id and shard number.
|
||||||
|
"""
|
||||||
|
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
|
||||||
|
|
||||||
from . import sharding, conf
|
from . import sharding, conf
|
||||||
from .logger import log_app
|
from .logger import log_app
|
||||||
from .ipc.client import AppClient
|
from .ipc.client import AppClient
|
||||||
from .args import args
|
from .args import args
|
||||||
|
|
||||||
|
|
||||||
|
appname = conf.data['appid']
|
||||||
|
appid = appname # backwards compatibility
|
||||||
|
|
||||||
|
|
||||||
def appname_from_shard(shardid):
|
def appname_from_shard(shardid):
|
||||||
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
|
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
|
||||||
return appname
|
return appname
|
||||||
@@ -13,16 +28,13 @@ def shard_from_appname(appname: str):
|
|||||||
return int(appname.rsplit('_', maxsplit=1)[-1])
|
return int(appname.rsplit('_', maxsplit=1)[-1])
|
||||||
|
|
||||||
|
|
||||||
if sharding.sharded:
|
shardname = appname_from_shard(sharding.shard_number)
|
||||||
appname = appname_from_shard(sharding.shard_number)
|
|
||||||
else:
|
|
||||||
appname = conf.data['appid']
|
|
||||||
|
|
||||||
log_app.set(appname)
|
log_app.set(shardname)
|
||||||
|
|
||||||
|
|
||||||
shard_talk = AppClient(
|
shard_talk = AppClient(
|
||||||
appname,
|
shardname,
|
||||||
{'host': args.host, 'port': args.port},
|
{'host': args.host, 'port': args.port},
|
||||||
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
|
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
# Contains the current command context, if applicable
|
# Contains the current command context, if applicable
|
||||||
context: Optional['LionContext'] = ContextVar('context', default=None)
|
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
|
||||||
|
|
||||||
# Contains the current LionBot instance
|
# Contains the current LionBot instance
|
||||||
ctx_bot: Optional['LionBot'] = ContextVar('bot', default=None)
|
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from string import Template
|
||||||
|
|
||||||
|
|
||||||
class SafeCancellation(Exception):
|
class SafeCancellation(Exception):
|
||||||
@@ -12,8 +13,12 @@ class SafeCancellation(Exception):
|
|||||||
"""
|
"""
|
||||||
default_message = ""
|
default_message = ""
|
||||||
|
|
||||||
def __init__(self, msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
|
@property
|
||||||
self.msg: Optional[str] = msg if msg is not None else self.default_message
|
def msg(self):
|
||||||
|
return self._msg if self._msg is not None else self.default_message
|
||||||
|
|
||||||
|
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
|
||||||
|
self._msg: Optional[str] = _msg
|
||||||
self.details: str = details if details is not None else self.msg
|
self.details: str = details if details is not None else self.msg
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@@ -24,6 +29,14 @@ class UserInputError(SafeCancellation):
|
|||||||
"""
|
"""
|
||||||
default_message = "Could not understand your input."
|
default_message = "Could not understand your input."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def msg(self):
|
||||||
|
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
|
||||||
|
|
||||||
|
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
|
||||||
|
self.info = info
|
||||||
|
super().__init__(_msg, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class UserCancelled(SafeCancellation):
|
class UserCancelled(SafeCancellation):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
from ..logger import logging_context
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -22,6 +24,7 @@ class AppClient:
|
|||||||
|
|
||||||
self._listener: Optional[asyncio.Server] = None # Local client server
|
self._listener: Optional[asyncio.Server] = None # Local client server
|
||||||
self._server = None # Connection to the registry server
|
self._server = None # Connection to the registry server
|
||||||
|
self._keepalive = None
|
||||||
|
|
||||||
self.register_route('new_peer')(self.new_peer)
|
self.register_route('new_peer')(self.new_peer)
|
||||||
self.register_route('drop_peer')(self.drop_peer)
|
self.register_route('drop_peer')(self.drop_peer)
|
||||||
@@ -29,7 +32,7 @@ class AppClient:
|
|||||||
|
|
||||||
def register_route(self, name=None):
|
def register_route(self, name=None):
|
||||||
def wrapper(coro):
|
def wrapper(coro):
|
||||||
route = AppRoute(coro, name)
|
route = AppRoute(coro, client=self, name=name)
|
||||||
self.routes[route.name] = route
|
self.routes[route.name] = route
|
||||||
return route
|
return route
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -49,14 +52,21 @@ class AppClient:
|
|||||||
self.peers = peers
|
self.peers = peers
|
||||||
self._server = (reader, writer)
|
self._server = (reader, writer)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Could not connect to registry server. Trying again in 30 seconds.")
|
logger.exception(
|
||||||
|
"Could not connect to registry server. Trying again in 30 seconds.",
|
||||||
|
extra={'action': 'Connect'}
|
||||||
|
)
|
||||||
await asyncio.sleep(30)
|
await asyncio.sleep(30)
|
||||||
asyncio.create_task(self.server_connection())
|
asyncio.create_task(self.server_connection())
|
||||||
else:
|
else:
|
||||||
logger.info("Connected to the registry server, launching keepalive.")
|
logger.debug(
|
||||||
asyncio.create_task(self._server_keepalive())
|
"Connected to the registry server, launching keepalive.",
|
||||||
|
extra={'action': 'Connect'}
|
||||||
|
)
|
||||||
|
self._keepalive = asyncio.create_task(self._server_keepalive())
|
||||||
|
|
||||||
async def _server_keepalive(self):
|
async def _server_keepalive(self):
|
||||||
|
with logging_context(action='Keepalive'):
|
||||||
if self._server is None:
|
if self._server is None:
|
||||||
raise ValueError("Cannot keepalive non-existent server!")
|
raise ValueError("Cannot keepalive non-existent server!")
|
||||||
reader, write = self._server
|
reader, write = self._server
|
||||||
@@ -85,6 +95,7 @@ class AppClient:
|
|||||||
...
|
...
|
||||||
|
|
||||||
async def request(self, appid, payload: 'AppPayload'):
|
async def request(self, appid, payload: 'AppPayload'):
|
||||||
|
with logging_context(action=f"Req {appid}"):
|
||||||
try:
|
try:
|
||||||
if appid not in self.peers:
|
if appid not in self.peers:
|
||||||
raise ValueError(f"Peer '{appid}' not found.")
|
raise ValueError(f"Peer '{appid}' not found.")
|
||||||
@@ -104,15 +115,21 @@ class AppClient:
|
|||||||
logging.exception(f"Failed to send request to {appid}'")
|
logging.exception(f"Failed to send request to {appid}'")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def requestall(self, payload):
|
async def requestall(self, payload, except_self=True):
|
||||||
results = await asyncio.gather(*(self.request(appid, payload) for appid in self.peers))
|
with logging_context(action="Broadcast"):
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(self.request(appid, payload) for appid in self.peers if (appid != self.appid or not except_self)),
|
||||||
|
return_exceptions=True
|
||||||
|
)
|
||||||
return dict(zip(self.peers.keys(), results))
|
return dict(zip(self.peers.keys(), results))
|
||||||
|
|
||||||
async def handle_request(self, reader, writer):
|
async def handle_request(self, reader, writer):
|
||||||
|
with logging_context(action="SERV"):
|
||||||
data = await reader.read()
|
data = await reader.read()
|
||||||
loaded = pickle.loads(data)
|
loaded = pickle.loads(data)
|
||||||
route, args, kwargs = loaded
|
route, args, kwargs = loaded
|
||||||
|
|
||||||
|
with logging_context(action=f"SERV {route}"):
|
||||||
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
|
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
|
||||||
|
|
||||||
if route in self.routes:
|
if route in self.routes:
|
||||||
@@ -129,6 +146,7 @@ class AppClient:
|
|||||||
Start the local peer server.
|
Start the local peer server.
|
||||||
Connect to the address server.
|
Connect to the address server.
|
||||||
"""
|
"""
|
||||||
|
with logging_context(stack=['ShardTalk']):
|
||||||
# Start the client server
|
# Start the client server
|
||||||
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
|
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
|
||||||
|
|
||||||
@@ -150,13 +168,20 @@ class AppPayload:
|
|||||||
def encoded(self):
|
def encoded(self):
|
||||||
return pickle.dumps((self.route.name, self.args, self.kwargs))
|
return pickle.dumps((self.route.name, self.args, self.kwargs))
|
||||||
|
|
||||||
|
async def send(self, appid, **kwargs):
|
||||||
|
return await self.route._client.request(appid, self, **kwargs)
|
||||||
|
|
||||||
|
async def broadcast(self, **kwargs):
|
||||||
|
return await self.route._client.requestall(self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AppRoute:
|
class AppRoute:
|
||||||
__slots__ = ('func', 'name')
|
__slots__ = ('func', 'name', '_client')
|
||||||
|
|
||||||
def __init__(self, func, name=None):
|
def __init__(self, func, client=None, name=None):
|
||||||
self.func = func
|
self.func = func
|
||||||
self.name = name or func.__name__
|
self.name = name or func.__name__
|
||||||
|
self._client = client
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return AppPayload(self, *args, **kwargs)
|
return AppPayload(self, *args, **kwargs)
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ from logging.handlers import QueueListener, QueueHandler
|
|||||||
from queue import SimpleQueue
|
from queue import SimpleQueue
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from functools import wraps
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
|
||||||
from discord import Webhook, File
|
from discord import Webhook, File
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
@@ -46,6 +47,16 @@ def logging_context(context=None, action=None, stack=None):
|
|||||||
log_action_stack.set(astack)
|
log_action_stack.set(astack)
|
||||||
|
|
||||||
|
|
||||||
|
def log_wrap(**kwargs):
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapped(*w_args, **w_kwargs):
|
||||||
|
with logging_context(**kwargs):
|
||||||
|
return await func(*w_args, **w_kwargs)
|
||||||
|
return wrapped
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
RESET_SEQ = "\033[0m"
|
RESET_SEQ = "\033[0m"
|
||||||
COLOR_SEQ = "\033[3%dm"
|
COLOR_SEQ = "\033[3%dm"
|
||||||
BOLD_SEQ = "\033[1m"
|
BOLD_SEQ = "\033[1m"
|
||||||
|
|||||||
Reference in New Issue
Block a user