from typing import List, Literal, LiteralString, Optional, TYPE_CHECKING, overload
import logging
import asyncio
from weakref import WeakValueDictionary

from constants import SCHEMA_VERSIONS

import discord
from discord.utils import MISSING
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
from aiohttp import ClientSession

from data import Database, ORDER
from utils.lib import tabulate
from babel.translator import LeoBabel
from botdata import BotData, VersionHistory

from .config import Conf
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
from .context import context
from .LionContext import LionContext
from .LionTree import LionTree
from .errors import HandledException, SafeCancellation
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus

if TYPE_CHECKING:
    from core.cog import CoreCog
    from modules.profiles.profiles.discord.cog import ProfilesCog

logger = logging.getLogger(__name__)


class LionBot(Bot):
    """
    Bot client extending `discord.ext.commands.Bot`.

    Attaches the database, config, translator and web client; registers a
    status component with a `SystemMonitor`; wraps dispatch, cog attachment
    and extension loading in logging contexts; provides per-id asyncio locks;
    and centralises command error handling.
    """

    def __init__(
        self,
        *args,
        appname: str,
        shardname: str,
        db: Database,
        config: Conf,
        initial_extensions: List[str],
        web_client: ClientSession,
        testing_guilds: Optional[List[int]] = None,
        **kwargs
    ):
        """
        Parameters
        ----------
        appname: str
            Application name, stored as `self.appname`.
        shardname: str
            Shard identifier, reported in `on_ready` and in error reports.
        db: Database
            Database handle; `BotData` is loaded from its registry.
        config: Conf
            Bot configuration object.
        initial_extensions: List[str]
            Extensions loaded in order by `setup_hook`.
        web_client: ClientSession
            Shared aiohttp session.
        testing_guilds: Optional[List[int]]
            Guild ids that receive a copy of the global app commands and are
            synced individually in `setup_hook`. `None` means no guilds.
        """
        # Use LionTree for app commands unless the caller chose another tree class.
        kwargs.setdefault('tree_cls', LionTree)
        super().__init__(*args, **kwargs)
        self.web_client = web_client
        # `None` default avoids sharing one mutable list between instances.
        self.testing_guilds = testing_guilds if testing_guilds is not None else []
        self.initial_extensions = initial_extensions
        self.db = db
        self.appname = appname
        self.shardname = shardname
        # self.appdata = appdata
        self.data: BotData = db.load_registry(BotData())
        self.config = config
        self.translator = LeoBabel()

        self.system_monitor = SystemMonitor()
        self.monitor = ComponentMonitor('LionBot', self._monitor_status)
        self.system_monitor.add_component(self.monitor)

        # Locks handed out by `idlock`; entries vanish once no caller holds the lock.
        self._locks = WeakValueDictionary()
        # Event tasks scheduled via `_schedule_event`; counted in `_monitor_status`.
        self._running_events = set()
        # Strong references to fire-and-forget background tasks (e.g. chunk
        # requests) so they are not garbage collected before they finish.
        self._background_tasks = set()

    @property
    def dbconn(self):
        """Alias for `self.db`."""
        return self.db

    @property
    def core(self):
        """The loaded 'CoreCog', if any."""
        return self.get_cog('CoreCog')

    @property
    def profiles(self):
        """The loaded 'ProfilesCog', if any."""
        return self.get_cog('ProfilesCog')

    async def _monitor_status(self):
        """
        Status callback registered with this bot's `ComponentMonitor`.

        Reports, in priority order: closed websocket, ratelimited websocket,
        not yet ready, or OK with guild count, latency and running event count.
        """
        if self.is_closed():
            level = StatusLevel.ERRORED
            info = "(ERROR) Websocket is closed"
            data = {}
        elif self.is_ws_ratelimited():
            level = StatusLevel.WAITING
            info = "(WAITING) Websocket is ratelimited"
            data = {}
        elif not self.is_ready():
            level = StatusLevel.STARTING
            info = "(STARTING) Not yet ready"
            data = {}
        else:
            level = StatusLevel.OKAY
            info = (
                "(OK) "
                "Logged in with {guild_count} guilds, "
                "websocket latency {latency}, and {events} running events."
            )
            data = {
                'guild_count': len(self.guilds),
                'latency': self.latency,
                'events': len(self._running_events),
            }
        # NOTE(review): `info` is passed twice; presumably short/long description
        # slots of ComponentStatus — confirm against .monitor.
        return ComponentStatus(level, info, info, data)

    async def setup_hook(self) -> None:
        """
        Load the initial extensions, then copy and sync global app commands to
        each testing guild handled by this shard.
        """
        log_context.set(f"APP: {self.application_id}")

        for extension in self.initial_extensions:
            await self.load_extension(extension)

        for guildid in self.testing_guilds:
            guild = discord.Object(guildid)
            # Only sync guilds that belong to this shard: (id >> 22) % shard_count.
            if not self.shard_count or (self.shard_id == ((guildid >> 22) % self.shard_count)):
                self.tree.copy_global_to(guild=guild)
                await self.tree.sync(guild=guild)

    # To make the type checker happy about fetching cogs by name
    # TODO: Move this to stubs at some point

    @overload
    def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
        ...

    @overload
    def get_cog(self, name: Literal['ProfilesCog']) -> 'ProfilesCog':
        ...

    @overload
    def get_cog(self, name: str) -> Optional[Cog]:
        ...

    def get_cog(self, name: str) -> Optional[Cog]:
        """Fetch a cog by name; the overloads above only refine the return type."""
        return super().get_cog(name)

    async def add_cog(self, cog: Cog, **kwargs):
        """Attach a cog inside a logging action named after the cog."""
        sup = super()

        @log_wrap(action=f"Attach {cog.__cog_name__}")
        async def wrapper():
            logger.info(f"Attaching Cog {cog.__cog_name__}")
            await sup.add_cog(cog, **kwargs)
            logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
        await wrapper()

    async def load_extension(self, name, *, package=None, **kwargs):
        """Load an extension inside a logging action named after the extension."""
        sup = super()

        @log_wrap(action=f"Load {name.strip('.')}")
        async def wrapper():
            logger.info(f"Loading extension {name} in package {package}.")
            await sup.load_extension(name, package=package, **kwargs)
            logger.debug(f"Loaded extension {name} in package {package}.")
        await wrapper()

    async def start(self, token: str, *, reconnect: bool = True):
        """
        Initialise bot data, check every component in SCHEMA_VERSIONS, then log
        in and connect.

        Login and connect each run as a task created inside a logging context;
        `asyncio.create_task` copies the current contextvars, so the task
        inherits that context.
        """
        await self.data.init()
        for component, req in SCHEMA_VERSIONS.items():
            await self.version_check(component, req)

        with logging_context(action="Login"):
            start_task = asyncio.create_task(self.login(token))
        await start_task

        with logging_context(stack=("Running",)):
            run_task = asyncio.create_task(self.connect(reconnect=reconnect))
        await run_task

    async def version_check(self, component: str, req_version: int):
        """
        Confirm that the latest VersionHistory row for `component` has
        `to_version == req_version`. A component with no rows counts as version 0.

        Returns True on success; raises ValueError on mismatch.
        """
        # Query the database to confirm that the given component is listed with the given version.
        # Typically done upon loading a component
        rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)

        version = rows[0].to_version if rows else 0

        if version != req_version:
            raise ValueError(
                f"Component {component} failed version check. "
                f"Has version '{version}', required version '{req_version}'"
            )
        else:
            logger.debug(
                "Component %s passed version check with version %s",
                component,
                version
            )
            return True

    def dispatch(self, event_name: str, *args, **kwargs):
        """Dispatch an event inside a logging action named after the event."""
        with logging_context(action=f"Dispatch {event_name}"):
            super().dispatch(event_name, *args, **kwargs)

    def _schedule_event(self, coro, event_name, *args, **kwargs):
        """
        Extends client._schedule_event to keep a persistent background task store.

        The task is tracked in `_running_events` until it completes and is
        returned to the caller, matching the parent method's return value.
        """
        task = super()._schedule_event(coro, event_name, *args, **kwargs)
        self._running_events.add(task)
        task.add_done_callback(self._running_events.discard)
        return task

    def idlock(self, snowflakeid):
        """
        Return the asyncio.Lock associated with `snowflakeid`, creating it if needed.

        Locks are held weakly, so a lock is discarded once no caller references it.
        """
        lock = self._locks.get(snowflakeid, None)
        if lock is None:
            # Bind the local first so a strong reference exists while storing it.
            lock = self._locks[snowflakeid] = asyncio.Lock()
        return lock

    async def on_ready(self):
        """Log a summary of the application, loaded components and guilds."""
        logger.info(
            f"Logged in as {self.application.name}\n"
            f"Application id {self.application.id}\n"
            f"Shard Talk identifier {self.shardname}\n"
            "------------------------------\n"
            f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
            f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
            f"Registered Data: {', '.join(self.db.registries.keys())}\n"
            f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
            "------------------------------\n"
            f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
            "Ready to take commands!\n",
            extra={'action': 'Ready'}
        )

    async def get_context(self, origin, /, *, cls=MISSING):
        """Build a context (LionContext by default) and publish it via the `context` var."""
        if cls is MISSING:
            cls = LionContext
        ctx = await super().get_context(origin, cls=cls)
        context.set(ctx)
        return ctx

    async def on_command(self, ctx: LionContext):
        """Log each command invocation with its cog and interaction payload."""
        logger.info(
            f"Executing command '{ctx.command.qualified_name}' "
            f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
            f"with interaction: {ctx.interaction.data if ctx.interaction else None}",
            extra={'with_ctx': True}
        )

    def _command_error_details(self, ctx: LionContext, error: BaseException) -> dict:
        """
        Build the ordered field -> value mapping for the 'Details' field of the
        unexpected-error embed.

        Guild, channel and permission entries are only added when the objects
        they need are available. This keeps report building from raising for
        a missing channel, a guild-less channel, or an author that is not a
        guild Member.
        """
        details = {}
        details['error'] = f"`{repr(error)}`"
        if ctx.interaction:
            details['interactionid'] = f"`{ctx.interaction.id}`"
        if ctx.command:
            details['cmd'] = f"`{ctx.command.qualified_name}`"

        author = ctx.author
        guild = ctx.guild
        channel = ctx.channel
        # Permissions can only be computed for guild Members; a plain User
        # has no guild_permissions / roles.
        member = author if isinstance(author, discord.Member) else None
        me = guild.me if guild else None

        if author:
            details['author'] = f"`{author.id}` -- `{author}`"
        if guild:
            details['guild'] = f"`{guild.id}` -- `{guild.name}`"
            if me is not None:
                details['my_guild_perms'] = f"`{me.guild_permissions.value}`"
            if member is not None:
                ownerstr = ' (owner)' if member.id == guild.owner_id else ''
                details['author_guild_perms'] = f"`{member.guild_permissions.value}{ownerstr}`"

        # Check for a missing channel *before* reading channel.type.
        if channel is not None:
            if channel.type is discord.enums.ChannelType.private:
                details['channel'] = "`Direct Message`"
            else:
                # Some channel objects (e.g. partial ones) have no `name`.
                details['channel'] = f"`{channel.id}` -- `{getattr(channel, 'name', None)}`"
                if me is not None:
                    details['my_channel_perms'] = f"`{channel.permissions_for(me).value}`"
                if member is not None:
                    details['author_channel_perms'] = f"`{channel.permissions_for(member).value}`"

        details['shard'] = f"`{self.shardname}`"
        details['log_stack'] = f"`{log_action_stack.get()}`"
        return details

    async def on_command_error(self, ctx, exception):
        """
        Central command error handler.

        Invocation-wrapper errors are unwrapped (up to two levels) and the
        original error is routed:
        - HandledException, CancelledError, TimeoutError: ignored.
        - TransformerError / SafeCancellation: the message is sent to the user.
        - Forbidden: a generic permissions reply; logged if that also fails.
        - HTTPException: logged.
        - Anything else: logged, reported to the user in a details embed, and
          `exception.original` is replaced with a HandledException.
        CheckFailure is reported to the user. Any other error is only logged.
        """
        # TODO: Some of these could have more user-feedback
        logger.debug(f"Handling command error for {ctx}: {exception}")
        if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
            cmd_str = ctx.command.app_command.to_dict(self.tree)
        else:
            cmd_str = str(ctx.command)
        try:
            raise exception
        except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
            try:
                if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
                    original = exception.original.original
                    raise original
                else:
                    original = exception.original
                    raise original
            except HandledException:
                pass
            except TransformerError as e:
                msg = str(e)
                if msg:
                    try:
                        await ctx.error_reply(msg)
                    except Exception:
                        pass
                logger.debug(
                    f"Caught a transformer error: {repr(e)}",
                    extra={'action': 'BotError', 'with_ctx': True}
                )
            except SafeCancellation:
                if original.msg:
                    try:
                        await ctx.error_reply(original.msg)
                    except Exception:
                        pass
                logger.debug(
                    f"Caught a safe cancellation: {original.details}",
                    extra={'action': 'BotError', 'with_ctx': True}
                )
            except discord.Forbidden:
                # Unknown uncaught Forbidden
                try:
                    # Attempt a general error reply
                    await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
                except Exception:
                    # We can't send anything at all. Exit quietly, but log.
                    logger.warning(
                        f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
                        exc_info=True,
                        extra={'action': 'BotError', 'with_ctx': True}
                    )
            except discord.HTTPException:
                logger.error(
                    f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
                    exc_info=True,
                    extra={'action': 'BotError', 'with_ctx': True}
                )
            except asyncio.CancelledError:
                pass
            except asyncio.TimeoutError:
                pass
            except Exception as e:
                logger.exception(
                    f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
                    extra={'action': 'BotError', 'with_ctx': True}
                )

                error_embed = discord.Embed(
                    title="Something went wrong!",
                    colour=discord.Colour.dark_red()
                )
                error_embed.description = (
                    "An unexpected error occurred while processing your command!\n"
                    "Our development team has been notified, and the issue will be addressed soon.\n"
                )

                details = self._command_error_details(ctx, e)
                table = '\n'.join(tabulate(*details.items()))
                error_embed.add_field(name='Details', value=table)

                try:
                    await ctx.error_reply(embed=error_embed)
                except discord.HTTPException:
                    pass
                finally:
                    # Mark as handled so later handlers do not report it again.
                    exception.original = HandledException(exception.original)
        except CheckFailure as e:
            logger.debug(
                f"Command failed check: {e}: {e.args}",
                extra={'action': 'BotError', 'with_ctx': True}
            )
            try:
                await ctx.error_reply(str(e))
            except discord.HTTPException:
                pass
        except Exception:
            # Completely unknown exception outside of command invocation!
            # Something is very wrong here, don't attempt user interaction.
            logger.exception(
                f"Caught an unknown top-level exception while executing: {cmd_str}",
                extra={'action': 'BotError', 'with_ctx': True}
            )

    def add_command(self, command):
        """Register a command unless it is marked with `_placeholder_group_`."""
        if not hasattr(command, '_placeholder_group_'):
            super().add_command(command)

    def request_chunking_for(self, guild):
        """
        Start a background member-chunk request for `guild` if it is not
        already chunked.

        Returns the created task, or None if the guild is already chunked.
        A strong reference to the task is kept until it finishes, so it cannot
        be garbage collected mid-execution.
        """
        if guild.chunked:
            return None
        task = asyncio.create_task(
            self._connection.chunk_guild(guild, wait=False, cache=True),
            name=f"Background chunkreq for {guild.id}"
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def on_interaction(self, interaction: discord.Interaction):
        """
        Adds the interaction author to guild cache if appropriate.

        This gets run a little bit late, so it is possible the interaction
        gets handled without the author being in cache.
        Also requests chunking for guilds that are not yet chunked.
        """
        guild = interaction.guild
        user = interaction.user
        if guild is not None and isinstance(user, discord.Member):
            if guild.get_member(user.id) is None:
                guild._add_member(user)
        if guild is not None and not guild.chunked:
            # Getting an interaction in the guild is a good enough reason to request chunking
            logger.info(
                "Unchunked guild requesting chunking after interaction."
            )
            self.request_chunking_for(guild)