generated from HoloTech/discord-bot-template
Initial commit
This commit is contained in:
373
src/meta/LionBot.py
Normal file
373
src/meta/LionBot.py
Normal file
@@ -0,0 +1,373 @@
|
||||
from typing import List, Literal, LiteralString, Optional, TYPE_CHECKING, overload
|
||||
import logging
|
||||
import asyncio
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from constants import SCHEMA_VERSIONS
|
||||
import discord
|
||||
from discord.utils import MISSING
|
||||
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
||||
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
||||
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from data import Database
|
||||
from utils.lib import tabulate
|
||||
from babel.translator import LeoBabel
|
||||
from botdata import BotData, VersionHistory
|
||||
|
||||
from .config import Conf
|
||||
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
|
||||
from .context import context
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
from .errors import HandledException, SafeCancellation
|
||||
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.cog import CoreCog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LionBot(Bot):
|
||||
def __init__(
|
||||
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
||||
initial_extensions: List[str], web_client: ClientSession,
|
||||
testing_guilds: List[int] = [], **kwargs
|
||||
):
|
||||
kwargs.setdefault('tree_cls', LionTree)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.web_client = web_client
|
||||
self.testing_guilds = testing_guilds
|
||||
self.initial_extensions = initial_extensions
|
||||
self.db = db
|
||||
self.appname = appname
|
||||
self.shardname = shardname
|
||||
# self.appdata = appdata
|
||||
self.data: BotData = db.load_registry(BotData())
|
||||
self.config = config
|
||||
self.translator = LeoBabel()
|
||||
|
||||
self.system_monitor = SystemMonitor()
|
||||
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
||||
self.system_monitor.add_component(self.monitor)
|
||||
|
||||
self._locks = WeakValueDictionary()
|
||||
self._running_events = set()
|
||||
|
||||
@property
|
||||
def dbconn(self):
|
||||
return self.db
|
||||
|
||||
@property
|
||||
def core(self):
|
||||
return self.get_cog('CoreCog')
|
||||
|
||||
async def _monitor_status(self):
|
||||
if self.is_closed():
|
||||
level = StatusLevel.ERRORED
|
||||
info = "(ERROR) Websocket is closed"
|
||||
data = {}
|
||||
elif self.is_ws_ratelimited():
|
||||
level = StatusLevel.WAITING
|
||||
info = "(WAITING) Websocket is ratelimited"
|
||||
data = {}
|
||||
elif not self.is_ready():
|
||||
level = StatusLevel.STARTING
|
||||
info = "(STARTING) Not yet ready"
|
||||
data = {}
|
||||
else:
|
||||
level = StatusLevel.OKAY
|
||||
info = (
|
||||
"(OK) "
|
||||
"Logged in with {guild_count} guilds, "
|
||||
", websocket latency {latency}, and {events} running events."
|
||||
)
|
||||
data = {
|
||||
'guild_count': len(self.guilds),
|
||||
'latency': self.latency,
|
||||
'events': len(self._running_events),
|
||||
}
|
||||
return ComponentStatus(level, info, info, data)
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
log_context.set(f"APP: {self.application_id}")
|
||||
|
||||
for extension in self.initial_extensions:
|
||||
await self.load_extension(extension)
|
||||
|
||||
for guildid in self.testing_guilds:
|
||||
guild = discord.Object(guildid)
|
||||
if not self.shard_count or (self.shard_id == ((guildid >> 22) % self.shard_count)):
|
||||
self.tree.copy_global_to(guild=guild)
|
||||
await self.tree.sync(guild=guild)
|
||||
|
||||
# To make the type checker happy about fetching cogs by name
|
||||
# TODO: Move this to stubs at some point
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
...
|
||||
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
return super().get_cog(name)
|
||||
|
||||
async def add_cog(self, cog: Cog, **kwargs):
|
||||
sup = super()
|
||||
@log_wrap(action=f"Attach {cog.__cog_name__}")
|
||||
async def wrapper():
|
||||
logger.info(f"Attaching Cog {cog.__cog_name__}")
|
||||
await sup.add_cog(cog, **kwargs)
|
||||
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
|
||||
await wrapper()
|
||||
|
||||
async def load_extension(self, name, *, package=None, **kwargs):
|
||||
sup = super()
|
||||
@log_wrap(action=f"Load {name.strip('.')}")
|
||||
async def wrapper():
|
||||
logger.info(f"Loading extension {name} in package {package}.")
|
||||
await sup.load_extension(name, package=package, **kwargs)
|
||||
logger.debug(f"Loaded extension {name} in package {package}.")
|
||||
await wrapper()
|
||||
|
||||
async def start(self, token: str, *, reconnect: bool = True):
|
||||
await self.data.init()
|
||||
for component, req in SCHEMA_VERSIONS.items():
|
||||
await self.version_check(component, req)
|
||||
|
||||
with logging_context(action="Login"):
|
||||
start_task = asyncio.create_task(self.login(token))
|
||||
await start_task
|
||||
|
||||
with logging_context(stack=("Running",)):
|
||||
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
|
||||
await run_task
|
||||
|
||||
async def version_check(self, component: str, req_version: int):
|
||||
# Query the database to confirm that the given component is listed with the given version.
|
||||
# Typically done upon loading a component
|
||||
rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
|
||||
|
||||
version = rows[0].to_version if rows else 0
|
||||
|
||||
if version != req_version:
|
||||
raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
|
||||
else:
|
||||
logger.debug(
|
||||
"Component %s passed version check with version %s",
|
||||
component,
|
||||
version
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def dispatch(self, event_name: str, *args, **kwargs):
|
||||
with logging_context(action=f"Dispatch {event_name}"):
|
||||
super().dispatch(event_name, *args, **kwargs)
|
||||
|
||||
def _schedule_event(self, coro, event_name, *args, **kwargs):
|
||||
"""
|
||||
Extends client._schedule_event to keep a persistent
|
||||
background task store.
|
||||
"""
|
||||
task = super()._schedule_event(coro, event_name, *args, **kwargs)
|
||||
self._running_events.add(task)
|
||||
task.add_done_callback(lambda fut: self._running_events.discard(fut))
|
||||
|
||||
def idlock(self, snowflakeid):
|
||||
lock = self._locks.get(snowflakeid, None)
|
||||
if lock is None:
|
||||
lock = self._locks[snowflakeid] = asyncio.Lock()
|
||||
return lock
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info(
|
||||
f"Logged in as {self.application.name}\n"
|
||||
f"Application id {self.application.id}\n"
|
||||
f"Shard Talk identifier {self.shardname}\n"
|
||||
"------------------------------\n"
|
||||
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
|
||||
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
|
||||
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
|
||||
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
|
||||
"------------------------------\n"
|
||||
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
|
||||
"Ready to take commands!\n",
|
||||
extra={'action': 'Ready'}
|
||||
)
|
||||
|
||||
async def get_context(self, origin, /, *, cls=MISSING):
|
||||
if cls is MISSING:
|
||||
cls = LionContext
|
||||
ctx = await super().get_context(origin, cls=cls)
|
||||
context.set(ctx)
|
||||
return ctx
|
||||
|
||||
async def on_command(self, ctx: LionContext):
|
||||
logger.info(
|
||||
f"Executing command '{ctx.command.qualified_name}' "
|
||||
f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
|
||||
f"with interaction: {ctx.interaction.data if ctx.interaction else None}",
|
||||
extra={'with_ctx': True}
|
||||
)
|
||||
|
||||
async def on_command_error(self, ctx, exception):
|
||||
# TODO: Some of these could have more user-feedback
|
||||
logger.debug(f"Handling command error for {ctx}: {exception}")
|
||||
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
|
||||
cmd_str = ctx.command.app_command.to_dict(self.tree)
|
||||
else:
|
||||
cmd_str = str(ctx.command)
|
||||
try:
|
||||
raise exception
|
||||
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
|
||||
try:
|
||||
if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
|
||||
original = exception.original.original
|
||||
raise original
|
||||
else:
|
||||
original = exception.original
|
||||
raise original
|
||||
except HandledException:
|
||||
pass
|
||||
except TransformerError as e:
|
||||
msg = str(e)
|
||||
if msg:
|
||||
try:
|
||||
await ctx.error_reply(msg)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug(
|
||||
f"Caught a transformer error: {repr(e)}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except SafeCancellation:
|
||||
if original.msg:
|
||||
try:
|
||||
await ctx.error_reply(original.msg)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug(
|
||||
f"Caught a safe cancellation: {original.details}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except discord.Forbidden:
|
||||
# Unknown uncaught Forbidden
|
||||
try:
|
||||
# Attempt a general error reply
|
||||
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
|
||||
except Exception:
|
||||
# We can't send anything at all. Exit quietly, but log.
|
||||
logger.warning(
|
||||
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
|
||||
exc_info=True,
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except discord.HTTPException:
|
||||
logger.error(
|
||||
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
|
||||
exc_info=True,
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
|
||||
error_embed = discord.Embed(
|
||||
title="Something went wrong!",
|
||||
colour=discord.Colour.dark_red()
|
||||
)
|
||||
error_embed.description = (
|
||||
"An unexpected error occurred while processing your command!\n"
|
||||
"Our development team has been notified, and the issue will be addressed soon.\n"
|
||||
)
|
||||
details = {}
|
||||
details['error'] = f"`{repr(e)}`"
|
||||
if ctx.interaction:
|
||||
details['interactionid'] = f"`{ctx.interaction.id}`"
|
||||
if ctx.command:
|
||||
details['cmd'] = f"`{ctx.command.qualified_name}`"
|
||||
if ctx.author:
|
||||
details['author'] = f"`{ctx.author.id}` -- `{ctx.author}`"
|
||||
if ctx.guild:
|
||||
details['guild'] = f"`{ctx.guild.id}` -- `{ctx.guild.name}`"
|
||||
details['my_guild_perms'] = f"`{ctx.guild.me.guild_permissions.value}`"
|
||||
if ctx.author:
|
||||
ownerstr = ' (owner)' if ctx.author.id == ctx.guild.owner_id else ''
|
||||
details['author_guild_perms'] = f"`{ctx.author.guild_permissions.value}{ownerstr}`"
|
||||
if ctx.channel.type is discord.enums.ChannelType.private:
|
||||
details['channel'] = "`Direct Message`"
|
||||
elif ctx.channel:
|
||||
details['channel'] = f"`{ctx.channel.id}` -- `{ctx.channel.name}`"
|
||||
details['my_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.guild.me).value}`"
|
||||
if ctx.author:
|
||||
details['author_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.author).value}`"
|
||||
details['shard'] = f"`{self.shardname}`"
|
||||
details['log_stack'] = f"`{log_action_stack.get()}`"
|
||||
|
||||
table = '\n'.join(tabulate(*details.items()))
|
||||
error_embed.add_field(name='Details', value=table)
|
||||
|
||||
try:
|
||||
await ctx.error_reply(embed=error_embed)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
finally:
|
||||
exception.original = HandledException(exception.original)
|
||||
except CheckFailure as e:
|
||||
logger.debug(
|
||||
f"Command failed check: {e}: {e.args}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
try:
|
||||
await ctx.error_reply(str(e))
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
except Exception:
|
||||
# Completely unknown exception outside of command invocation!
|
||||
# Something is very wrong here, don't attempt user interaction.
|
||||
logger.exception(
|
||||
f"Caught an unknown top-level exception while executing: {cmd_str}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
|
||||
def add_command(self, command):
|
||||
if not hasattr(command, '_placeholder_group_'):
|
||||
super().add_command(command)
|
||||
|
||||
def request_chunking_for(self, guild):
|
||||
if not guild.chunked:
|
||||
return asyncio.create_task(
|
||||
self._connection.chunk_guild(guild, wait=False, cache=True),
|
||||
name=f"Background chunkreq for {guild.id}"
|
||||
)
|
||||
|
||||
async def on_interaction(self, interaction: discord.Interaction):
|
||||
"""
|
||||
Adds the interaction author to guild cache if appropriate.
|
||||
|
||||
This gets run a little bit late, so it is possible the interaction gets handled
|
||||
without the author being in case.
|
||||
"""
|
||||
guild = interaction.guild
|
||||
user = interaction.user
|
||||
if guild is not None and user is not None and isinstance(user, discord.Member):
|
||||
if not guild.get_member(user.id):
|
||||
guild._add_member(user)
|
||||
if guild is not None and not guild.chunked:
|
||||
# Getting an interaction in the guild is a good enough reason to request chunking
|
||||
logger.info(
|
||||
f"Unchunked guild <gid: {guild.id}> requesting chunking after interaction."
|
||||
)
|
||||
self.request_chunking_for(guild)
|
||||
Reference in New Issue
Block a user