diff --git a/bot/modules/__init__.py b/bot/modules/__init__.py index 7c5e5626..859e3e62 100644 --- a/bot/modules/__init__.py +++ b/bot/modules/__init__.py @@ -1,2 +1,11 @@ +this_package = 'modules' + +active = [ + '.sysadmin', + '.test' +] + + async def setup(bot): - await bot.load_extension('modules.bot_admin') + for ext in active: + await bot.load_extension(ext, package=this_package) diff --git a/bot/modules/bot_admin/__init__.py b/bot/modules/bot_admin/__init__.py deleted file mode 100644 index f072d6f8..00000000 --- a/bot/modules/bot_admin/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .exec_cog import Exec - - -async def setup(bot): - await bot.add_cog(Exec(bot)) diff --git a/bot/modules/bot_admin/exec_cog.py b/bot/modules/bot_admin/exec_cog.py deleted file mode 100644 index 7912861b..00000000 --- a/bot/modules/bot_admin/exec_cog.py +++ /dev/null @@ -1,270 +0,0 @@ -import io -import ast -import sys -import types -import traceback -import builtins -import inspect -import asyncio -import logging - -from typing import Callable, Any, Coroutine, List, Optional - -from enum import Enum - -import discord -from discord.ext import commands -from discord.app_commands import command -from discord.ui import Modal, TextInput, View -from discord.ui.button import button - -from meta.logger import logging_context -from meta.app import shard_talk -from meta.context import context -from meta.context import Context - - -logger = logging.getLogger(__name__) - - -class FastModal(Modal): - def __init__(self, *items, **kwargs): - super().__init__(**kwargs) - for item in items: - self.add_item(item) - self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future() - self._waiters: List[Coroutine[discord.Interaction]] = [] - - async def wait_for(self, check=None, timeout=None): - # Wait for _result or timeout - # If we timeout, or the view times out, raise TimeoutError - # Otherwise, return the Interaction - # This allows multiple listeners and callbacks to wait on - # TODO: Wait on the timeout as well - while True: - result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout) - if check is not None: - if not check(result): - continue - return result - - def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}): - def wrapper(coro): - async def wrapped_callback(interaction): - if check is not None: - if not check(interaction): - return - try: - await coro(interaction, *pass_args, **pass_kwargs) - except Exception: - # TODO: Log exception - ... - if once: - self._waiters.remove(wrapped_callback) - self._waiters.append(wrapped_callback) - return wrapper - - async def on_submit(self, interaction): - old_result = self._result - self._result = asyncio.get_event_loop().create_future() - old_result.set_result(interaction) - - for waiter in self._waiters: - asyncio.create_task(waiter(interaction)) - - async def on_error(self, interaction, error): - # This should never happen, since on_submit has its own error handling - # TODO: Logging - ... - - -class ExecModal(FastModal, title="Execute"): - code: TextInput = TextInput( - label="Code to execute", - style=discord.TextStyle.long, - required=True - ) - - -class ExecStyle(Enum): - EXEC = 'exec' - EVAL = 'eval' - - -class ExecUI(View): - def __init__(self, ctx, code=None, style=ExecStyle.EXEC, ephemeral=False): - super().__init__() - - self.ctx: commands.Context = ctx - self.interaction: Optional[discord.Interaction] = ctx.interaction - self.code: Optional[str] = code - self.style: ExecStyle = style - self.ephemeral: bool = ephemeral - - self._modal: Optional[ExecModal] = None - self._msg: Optional[discord.Message] = None - - async def run(self): - if self.code is None: - if (interaction := self.interaction) is not None: - self.interaction = None - await interaction.response.send_modal(self.get_modal()) - else: - # Complain - # TODO: error_reply - await self.ctx.reply("Pls give code.") - else: - await self.interaction.response.defer(thinking=True) - await self.compile() - - @button(label="Recompile") - async def recompile_button(self, interaction, butt): - # Interaction response with modal - await interaction.response.send_modal(self.get_modal()) - - def create_modal(self) -> ExecModal: - modal = ExecModal() - - @modal.submit_callback() - async def exec_submit(interaction: discord.Interaction): - if self.interaction is None: - self.interaction = interaction - await interaction.response.defer(thinking=True) - else: - await interaction.response.defer() - - # Set code - self.code = modal.code.value - - # Call compile - await self.compile() - - return modal - - def get_modal(self): - if self._modal is None: - # Create modal - self._modal = self.create_modal() - - self._modal.code.default = self.code - return self._modal - - async def compile(self): - # Call _async - result = await _async(self.code, style=self.style.value) - - # Display output - await self.show_output(result) - - async def show_output(self, output): - # Format output - # If output message exists and not ephemeral, edit - # Otherwise, send message, add buttons - # TODO: File output - # TODO: Check this handles ephemerals properly - formatted = "```py\n{}```".format(output) - if self._msg is None: - if self.interaction is not None: - msg = await self.interaction.edit_original_response(content=formatted, view=self) - else: - # Send new message - msg = await self.ctx.reply(formatted, ephemeral=self.ephemeral, view=self) - - if not self.ephemeral: - self._msg = msg - else: - if self.interaction is not None: - await self.interaction.edit_original_response(content=formatted, view=self) - else: - # Edit message - await self._msg.edit(formatted) - - -def mk_print(fp: io.StringIO) -> Callable[..., None]: - def _print(*args, file: Any = fp, **kwargs): - return print(*args, file=file, **kwargs) - return _print - - -async def _async(to_eval, style='exec'): - output = io.StringIO() - _print = mk_print(output) - - scope: dict[str, Any] = dict(sys.modules) - scope['__builtins__'] = builtins - scope.update(builtins.__dict__) - scope['ctx'] = ctx = context.get() - scope['bot'] = ctx.bot - scope['print'] = _print # type: ignore - - try: - if ctx.message: - source_str = f"" - elif ctx.interaction: - source_str = f"" - else: - source_str = "Unknown async" - - code = compile( - to_eval, - source_str, - style, - ast.PyCF_ALLOW_TOP_LEVEL_AWAIT - ) - func = types.FunctionType(code, scope) - - ret = func() - if inspect.iscoroutine(ret): - ret = await ret - if ret is not None: - _print(repr(ret)) - except Exception: - _, exc, tb = sys.exc_info() - _print("".join(traceback.format_tb(tb))) - _print(f"{type(exc).__name__}: {exc}") - - result = output.getvalue() - logger.info( - f"Exec complete, output:\n{result}", - extra={'action': "Code Exec"} - ) - return result - - -class Exec(commands.Cog): - def __init__(self, bot): - self.bot = bot - - @commands.hybrid_command(name='async') - async def async_cmd(self, ctx, *, string: str = None): - context.set(Context(bot=self.bot, interaction=ctx.interaction)) - with logging_context(context=f"mid:{ctx.message.id}", action="CMD ASYNC"): - logger.info("Running command") - await ExecUI(ctx, string, ExecStyle.EXEC).run() - - @commands.hybrid_command(name='eval') - async def eval_cmd(self, ctx, *, string: str): - await ExecUI(ctx, string, ExecStyle.EVAL).run() - - @command(name='test') - async def test_cmd(self, interaction: discord.Interaction, *, channel: discord.TextChannel = None): - if channel is None: - await interaction.response.send_message("Setting widget") - else: - await interaction.response.send_message( - f"Set channel to {channel} and then setting widget or update existing widget." - ) - - @command(name='shardeval') - async def shardeval_cmd(self, interaction: discord.Interaction, string: str): - await interaction.response.defer(thinking=True) - results = await shard_talk.requestall(exec_route(string)) - blocks = [] - for appid, result in results.items(): - blocks.append(f"```md\n[{appid}]\n{result}```") - await interaction.edit_original_response(content='\n'.join(blocks)) - - -@shard_talk.register_route('exec') -async def exec_route(string): - return await _async(string) diff --git a/bot/modules/sysadmin/__init__.py b/bot/modules/sysadmin/__init__.py new file mode 100644 index 00000000..8a27d2da --- /dev/null +++ b/bot/modules/sysadmin/__init__.py @@ -0,0 +1,15 @@ +from .exec_cog import Exec +from .blacklists import Blacklists +from .guild_log import GuildLog +from .presence import PresenceCtrl + +from .dash import LeoSettings + + +async def setup(bot): + await bot.add_cog(LeoSettings(bot)) + + await bot.add_cog(Blacklists(bot)) + await bot.add_cog(Exec(bot)) + await bot.add_cog(GuildLog(bot)) + await bot.add_cog(PresenceCtrl(bot)) diff --git a/bot/modules/sysadmin/blacklists.py b/bot/modules/sysadmin/blacklists.py new file mode 100644 index 00000000..f7de559d --- /dev/null +++ b/bot/modules/sysadmin/blacklists.py @@ -0,0 +1,667 @@ +from typing import Optional, List +import asyncio +import logging + +from data import Table, Registry, ORDER + +import discord +from discord.abc import Messageable +from discord.ext import commands as cmds +from discord.app_commands.transformers import AppCommandOptionType +from discord.ui.select import select, Select, SelectOption +from discord.ui.button import button +from discord.ui.text_input import TextStyle, TextInput + +from meta import LionCog, LionBot, LionContext +from meta.logger import logging_context, log_wrap +from meta.errors import UserInputError +from meta.app import shard_talk + +from utils.ui import ChoicedEnum, Transformed, FastModal, LeoUI, error_handler_for, ModalRetryUI +from utils.lib import EmbedField, tabulate, MessageArgs, parse_ids, error_embed + +from wards import sys_admin + +logger = logging.getLogger(__name__) + + +class BlacklistData(Registry, name="blacklists"): + guild_blacklist = Table('global_guild_blacklist') + user_blacklist = Table('global_user_blacklist') + + +class BlacklistAction(ChoicedEnum): + ADD_USER = "Blacklist Users" + RM_USER = "UnBlacklist Users" + ADD_GUILD = "Blacklist Guilds" + RM_GUILD = "UnBlacklist Guilds" + + @property + def choice_name(self): + return self.value + + +class Blacklists(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + self.data = self.bot.db.load_registry(BlacklistData()) + + self.user_blacklist: set[int] = set() + self.guild_blacklist: set[int] = set() + + self.talk_user_blacklist = shard_talk.register_route("user blacklist")(self.load_user_blacklist) + self.talk_guild_blacklist = shard_talk.register_route("guild blacklist")(self.load_guild_blacklist) + + async def cog_load(self): + await self.data.init() + await self.load_user_blacklist() + await self.load_guild_blacklist() + + async def load_user_blacklist(self): + """Populate the user blacklist.""" + rows = await self.data.user_blacklist.select_where() + self.user_blacklist = {row['userid'] for row in rows} + logger.info( + f"Loaded {len(self.user_blacklist)} blacklisted users." + ) + + async def load_guild_blacklist(self): + """Populate the guild blacklist.""" + rows = await self.data.guild_blacklist.select_where() + self.guild_blacklist = {row['guildid'] for row in rows} + logger.info( + f"Loaded {len(self.guild_blacklist)} blacklisted guilds." + ) + if self.bot.is_ready(): + with logging_context(action="Guild Blacklist"): + await self.leave_blacklisted_guilds() + + @LionCog.listener('on_ready') + @log_wrap(action="Guild Blacklist") + async def leave_blacklisted_guilds(self): + """Leave any blacklisted guilds we are in on this shard.""" + to_leave = [ + guild for guild in self.bot.guilds + if guild.id in self.guild_blacklist + ] + + asyncio.gather(*(guild.leave() for guild in to_leave)) + + logger.info( + "Left {} blacklisted guilds.".format(len(to_leave)), + ) + + @LionCog.listener('on_guild_join') + @log_wrap(action="Check Guild Blacklist") + async def check_guild_blacklist(self, guild): + """Check if the given guild is in the blacklist, and leave if so.""" + with logging_context(context=f"gid: {guild.id}"): + if guild.id in self.guild_blacklist: + await guild.leave() + logger.info( + "Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id) + ) + + async def bot_check_once(self, ctx: LionContext) -> bool: # type:ignore + if ctx.author.id in self.user_blacklist: + logger.debug( + f"Ignoring command from blacklisted user .", + extra={'action': 'User Blacklist'} + ) + return False + else: + return True + + @log_wrap(action="User Blacklist") + async def blacklist_users(self, actorid, userids, reason): + await self.data.user_blacklist.insert_many( + ('userid', 'ownerid', 'reason'), + *((userid, actorid, reason) for userid in userids) + ) + self.user_blacklist.update(userids) + await self.talk_user_blacklist().broadcast() + + uid_str = ', '.join(f"" for userid in userids) + logger.info( + f"Owner blacklisted {uid_str} with reason: \"{reason}\"" + ) + + @log_wrap(action="User Blacklist") + async def unblacklist_users(self, actorid, userids): + await self.data.user_blacklist.delete_where(userid=userids) + self.user_blacklist.difference_update(userids) + + await self.talk_user_blacklist().broadcast() + + uid_str = ', '.join(f"" for userid in userids) + logger.info( + f"Owner removed blacklist for user(s) {uid_str}." + ) + + @log_wrap(action="Guild Blacklist") + async def blacklist_guilds(self, actorid, guildids, reason): + await self.data.guild_blacklist.insert_many( + ('guildid', 'ownerid', 'reason'), + *((guildid, actorid, reason) for guildid in guildids) + ) + self.guild_blacklist.update(guildids) + await self.talk_guild_blacklist().broadcast() + + gid_str = ', '.join(f"" for guildid in guildids) + logger.info( + f"Owner blacklisted {gid_str} with reason: \"{reason}\"" + ) + + @log_wrap(action="Guild Blacklist") + async def unblacklist_guilds(self, actorid, guildids): + await self.data.guild_blacklist.delete_where(guildid=guildids) + self.guild_blacklist.difference_update(guildids) + + await self.talk_guild_blacklist().broadcast() + + gid_str = ', '.join(f"" for guildid in guildids) + logger.info( + f"Owner removed blacklist for guild(s) {gid_str}." + ) + await self.check_guild_blacklist() + + @cmds.hybrid_command( + name="blacklist", + description="Display and modify the user and guild blacklists." + ) + @cmds.check(sys_admin) + async def blacklist_cmd( + self, + ctx: LionContext, + action: Optional[Transformed[BlacklistAction, AppCommandOptionType.string]] = None, + targets: Optional[str] = None, + reason: Optional[str] = None + ): + """ + Display and modify the user and guild blacklists. + + With no arguments, just displays the Blacklist UI. + + If `targets` are provided, they should be a comma separated list of user or guild ids. + If `action` is not specified, they are assumed to be users to blacklist. + `reason` is the reason for the blacklist. + If `targets` are provided, but `reason` is not, it will be prompted for. + """ + UI = BlacklistUI(ctx.bot, ctx, auth=[ctx.author.id]) + if not ctx.interaction: + return await ctx.error_reply("This command cannot be used as a text command.") + + if (action is None and targets is not None) or action is BlacklistAction.ADD_USER: + await UI.spawn_add_users(ctx.interaction, targets, reason) + elif action is BlacklistAction.ADD_GUILD: + await UI.spawn_add_guilds(ctx.interaction, targets, reason) + elif action is BlacklistAction.RM_USER: + if targets is None: + UI._show_remove = True + await UI.spawn() + else: + try: + userids = parse_ids(targets) + except UserInputError as ex: + await ctx.error_reply("Could not extract user id from {item}".format(**ex.info)) + else: + await UI.do_rm_users(ctx.interaction, userids) + elif action is BlacklistAction.RM_GUILD: + if targets is None: + UI._show_remove = True + UI.guild_mode = True + await UI.spawn() + else: + try: + guildids = parse_ids(targets) + except UserInputError as ex: + await ctx.error_reply("Could not extract guild id from {item}".format(**ex.info)) + else: + await UI.do_rm_guilds(ctx.interaction, guildids) + elif action is None and targets is None: + await UI.spawn() + + +class BlacklistInput(FastModal): + targets: TextInput = TextInput( + label="Userids to blacklist.", + placeholder="Comma separated ids.", + max_length=4000, + required=True + ) + + reason: TextInput = TextInput( + label="Reason for the blacklist.", + style=TextStyle.long, + max_length=4000, + required=True + ) + + @error_handler_for(UserInputError) + async def rerequest(self, interaction: discord.Interaction, error: UserInputError): + await ModalRetryUI(self, error.msg).respond_to(interaction) + + +class BlacklistUI(LeoUI): + block_len = 5 # Number of entries to show per page + + def __init__(self, bot: LionBot, dest: Messageable, auth: Optional[List[int]] = None): + super().__init__() + # Client information + self.bot = bot + self.cog: Blacklists = bot.get_cog('Blacklists') # type: ignore + if self.cog is None: + raise ValueError("Cannot run BlacklistUI without the 'Blacklists' cog.") + + # State + self.guild_mode = False # Whether we are showing guild blacklist or user blacklist + # List of current pages, as (page args, data slice) tuples + self.pages: Optional[List[tuple[MessageArgs, tuple[int, int]]]] = None + self.page_no: int = 0 # Current page we are on + self.data = None # List of data rows for this mode + + # Discord State + self.dest = dest # The destination to send or resend the UI + self.message: Optional[discord.Message] = None # Message holding the UI + + # UI State + # This is better handled by a general abstract "_extra" or layout modi interface. + # For now, just a flag for whether we show the extra remove menu. + self._show_remove = False + self.auth = auth # List of userids authorised to use the UI + + async def interaction_check(self, interaction): + if self.auth and interaction.user.id not in self.auth: + await interaction.response.send_message( + embed=error_embed("You are not authorised to use this interface!"), + ephemeral=True + ) + return False + else: + return True + + async def cleanup(self): + if self.message is not None: + try: + await self.message.edit(view=None) + except discord.HTTPException: + pass + + @button(label="ADD", row=2) + async def press_add(self, interaction, pressed): + if self.guild_mode: + await self.spawn_add_guilds(interaction) + else: + await self.spawn_add_users(interaction) + + @button(label="RM", row=2) + async def press_rm(self, interaction, pressed): + await interaction.response.defer() + self._show_remove = not self._show_remove + await self.show() + + @button(label="Switch", row=2) + async def press_switch(self, interaction, pressed): + await interaction.response.defer() + if self.guild_mode: + await self.set_user_mode() + else: + await self.set_guild_mode() + + @button(label="<", row=1) + async def press_previous(self, interaction, pressed): + await interaction.response.defer() + self.page_no -= 1 + await self.show() + + @button(label="x", row=1) + async def press_cancel(self, interaction, pressed): + await interaction.response.defer() + if self.message: + try: + await self.message.delete() + except discord.HTTPException: + pass + await self.close() + + @button(label=">", row=1) + async def press_next(self, interaction, pressed): + await interaction.response.defer() + self.page_no += 1 + await self.show() + + @select(cls=Select, max_values=block_len) + async def select_remove(self, interaction, selected): + self._show_remove = False + if not selected.values: + # Treat this as a cancel + await interaction.response.defer() + else: + # Parse the values and pass straight to the appropriate do method + # Aside from race states, should be impossible for this to raise a handled exception + # (So no need to catch UserInputError) + ids = map(int, selected.values) + if self.guild_mode: + await self.do_rm_guilds(interaction, ids) + else: + await self.do_rm_users(interaction, ids) + + @property + def current_page(self): + if not self.pages: + raise ValueError("Cannot get the current page without pages!") + self.page_no %= len(self.pages) + return self.pages[self.page_no] + + async def spawn(self): + """ + Run the UI. + """ + if self.guild_mode: + await self.set_guild_mode() + else: + await self.set_user_mode() + + async def update_data(self): + """ + Updated stored data for the current mode. + """ + if self.guild_mode: + query = self.cog.data.guild_blacklist.select_where() + query.leftjoin('guild_config', using=('guildid',)) + query.select('guildid', 'ownerid', 'reason', 'name', 'created_at') + else: + query = self.cog.data.user_blacklist.select_where() + query.leftjoin('user_config', using=('userid',)) + query.select('userid', 'ownerid', 'reason', 'name', 'created_at') + + query.order_by('created_at', ORDER.DESC) + self.data = await query + return self.data + + async def set_guild_mode(self): + """ + Change UI to guild blacklist mode. + """ + self.guild_mode = True + self.press_add.label = "Blacklist Guilds" + self.press_rm.label = "Un-Blacklist Guilds" + self.press_switch.label = "Show User List" + self.select_remove.placeholder = "Select User id to remove" + + if not self.guild_mode: + self._show_remove = False + + self.page_no = 0 + await self.refresh() + + async def set_user_mode(self): + """ + Change UI to user blacklist mode. + """ + self.press_add.label = "Blacklist Users" + self.press_rm.label = "Un-Blacklist Users" + self.press_switch.label = "Show Guild List" + self.select_remove.placeholder = "Select Guild id to remove" + + if self.guild_mode: + self._show_remove = False + + self.guild_mode = False + self.page_no = 0 + await self.refresh() + + async def show(self): + """ + Show the Blacklist UI, creating a new message if required. + """ + if len(self.pages) > 1: + self.set_layout( + (self.press_previous, self.press_cancel, self.press_next), + (self.press_add, self.press_rm, self.press_switch) + ) + else: + self.set_layout( + (self.press_add, self.press_rm, self.press_switch, self.press_cancel) + ) + page, slice = self.current_page + if self._show_remove and self.data: + key = 'guildid' if self.guild_mode else 'userid' + self.select_remove._underlying.options = [ + SelectOption(label=str(row[key]), value=str(row[key])) + for row in self.data[slice[0]:slice[1]] + ] + self.set_layout(*self._layout, (self.select_remove,)) + + self.press_rm.disabled = (not self.data) + + if self.message is not None: + self.message = await self.message.edit(**page.edit_args, view=self) + else: + self.message = await self.dest.send(**page.send_args, view=self) + + def format_user_rows(self, *rows): + fields = [] + for row in rows: + userid = row['userid'] + name = row['name'] + if user := self.bot.get_user(userid): + name = f"({user.name})" + elif oldname := row['name']: + name = f"({oldname})" + else: + name = '' + reason = row['reason'] + if len(reason) > 900: + reason = reason[:900] + '...' + table = '\n'.join(tabulate( + ("User", f"<@{userid}> {name}"), + ("Blacklisted by", f"<@{row['ownerid']}>"), + ("Blacklisted at", f""), + ("Reason", reason) + )) + fields.append(EmbedField(name=str(userid), value=table, inline=False)) + return fields + + def format_guild_rows(self, *rows): + fields = [] + for row in rows: + guildid = row['guildid'] + + name = row['name'] + if guild := self.bot.get_guild(guildid): + name = f"({guild.name})" + elif oldname := row['name']: + name = f"({oldname})" + else: + name = '' + + reason = row['reason'] + table = '\n'.join(tabulate( + ("Guild", f"`{guildid}` {name}"), + ("Blacklisted by", f"<@{row['ownerid']}>"), + ("Blacklisted at", f""), + ("Reason", reason) + )) + fields.append(EmbedField(name=str(guildid), value=table, inline=False)) + return fields + + async def make_pages(self): + """ + Format the data in `self.data`, respecting the current mode. + """ + if self.data is None: + raise ValueError("Cannot make pages without initialising first!") + + embeds = [] + slices = [] + if self.guild_mode: + title = "Guild Blacklist" + no_desc = "There are no blacklisted guilds" + formatter = self.format_guild_rows + else: + title = "User Blacklist" + no_desc = "There are no blacklisted users" + formatter = self.format_user_rows + + base_embed = discord.Embed( + title=title, + colour=discord.Colour.dark_orange() + ) + if len(self.data) == 0: + base_embed.description = no_desc + embeds.append(base_embed) + slices.append((0, 0)) + else: + fields = formatter(*self.data) + bl = self.block_len + blocks = [(fields[i:i+bl], (i, i+bl)) for i in range(0, len(fields), bl)] + n = len(blocks) + for i, (block, slice) in enumerate(blocks): + embed = base_embed.copy() + embed._fields = [field._asdict() for field in block] + if n > 1: + embed.title += f" (Page {i + 1}/{n})" + embeds.append(embed) + slices.append(slice) + + pages = [MessageArgs(embed=embed) for embed in embeds] + self.pages = list(zip(pages, slices)) + return self.pages + + async def refresh(self): + """ + Refresh the current UI message, if it exists. + Takes into account the current mode and page number. + """ + await self.update_data() + await self.make_pages() + await self.show() + + async def spawn_add_users(self, interaction: discord.Interaction, + userids: Optional[str] = None, reason: Optional[str] = None): + """Spawn the add_users modal, optionally with fields pre-filled.""" + modal = BlacklistInput(title="Blacklist users") + modal.targets.default = userids + modal.reason.default = reason + + @modal.submit_callback() + async def add_users_submit(interaction): + await self.parse_add_users(interaction, modal.targets.value, modal.reason.value) + + await interaction.response.send_modal(modal) + + async def parse_add_users(self, interaction, useridstr: str, reason: str): + """ + Parse provided userid string and reason, and pass onto do_add_users. + If they are invalid, instead raise a UserInputError. + """ + try: + userids = parse_ids(useridstr) + except UserInputError as ex: + raise UserInputError("Could not extract a user id from `$item`", info=ex.info) from None + + await self.do_add_users(interaction, userids, reason) + + async def do_add_users(self, interaction: discord.Interaction, userids: list[int], reason: str): + """ + Actually blacklist the given users and send an ack. + To be run after initial argument validation. + Updates the UI, or posts one if it doesn't exist. + """ + remaining = set(userids).difference(self.cog.user_blacklist) + if not remaining: + raise UserInputError("All provided users are already blacklisted!") + await self.cog.blacklist_users(interaction.user.id, list(remaining), reason) + embed = discord.Embed( + title="Users Blacklisted", + description=( + "You have blacklisted the following users:\n" + + (', '.join(f"`{uid}`" for uid in remaining)) + ), + colour=discord.Colour.green() + ) + await interaction.response.send_message(embed=embed, ephemeral=True) + if self.message is not None: + await self.set_user_mode() + + async def do_rm_users(self, interaction: discord.Interaction, userids: list[int]): + remaining = self.cog.user_blacklist.intersection(userids) + if not remaining: + raise UserInputError("None of these users are blacklisted") + await self.cog.unblacklist_users(interaction.user.id, list(remaining)) + embed = discord.Embed( + title="Users removed from Blacklist", + description=( + "You have removed the following users from the blacklist:\n" + + (', '.join(f"`{uid}`" for uid in remaining)) + ), + colour=discord.Colour.green() + ) + await interaction.response.send_message(embed=embed, ephemeral=True) + if self.message is not None: + await self.set_user_mode() + + async def spawn_add_guilds(self, interaction: discord.Interaction, + guildids: Optional[str] = None, reason: Optional[str] = None): + """Spawn the add_guilds modal, optionally with fields pre-filled.""" + modal = BlacklistInput(title="Blacklist guilds") + modal.targets.default = guildids + modal.reason.default = reason + + @modal.submit_callback() + async def add_guilds_submit(interaction): + await self.parse_add_guilds(interaction, modal.targets.value, modal.reason.value) + + await interaction.response.send_modal(modal) + + async def parse_add_guilds(self, interaction, guildidstr: str, reason: str): + """ + Parse provided guildid string and reason, and pass onto do_add_guilds. + If they are invalid, instead raise a UserInputError. + """ + try: + guildids = parse_ids(guildidstr) + except UserInputError as ex: + raise UserInputError("Could not extract a guild id from `$item`", info=ex.info) from None + + await self.do_add_guilds(interaction, guildids, reason) + + async def do_add_guilds(self, interaction: discord.Interaction, guildids: list[int], reason: str): + """ + Actually blacklist the given guilds and send an ack. + To be run after initial argument validation. + Updates the UI, or posts one if it doesn't exist. + """ + remaining = set(guildids).difference(self.cog.guild_blacklist) + if not remaining: + raise UserInputError("All provided guilds are already blacklisted!") + await self.cog.blacklist_guilds(interaction.user.id, list(remaining), reason) + embed = discord.Embed( + title="Guilds Blacklisted", + description=( + "You have blacklisted the following guilds:\n" + + (', '.join(f"`{gid}`" for gid in remaining)) + ), + colour=discord.Colour.green() + ) + await interaction.response.send_message(embed=embed, ephemeral=True) + if self.message is not None: + await self.set_guild_mode() + + async def do_rm_guilds(self, interaction: discord.Interaction, guildids: list[int]): + remaining = self.cog.guild_blacklist.intersection(guildids) + if not remaining: + raise UserInputError("None of these guilds are blacklisted") + await self.cog.unblacklist_guilds(interaction.user.id, list(remaining)) + embed = discord.Embed( + title="Guilds removed from Blacklist", + description=( + "You have removed the following guilds from the blacklist:\n" + + (', '.join(f"`{gid}`" for gid in remaining)) + ), + colour=discord.Colour.green() + ) + await interaction.response.send_message(embed=embed, ephemeral=True) + if self.message is not None: + await self.set_guild_mode() diff --git a/bot/modules/sysadmin/dash.py b/bot/modules/sysadmin/dash.py new file mode 100644 index 00000000..2f705d38 --- /dev/null +++ b/bot/modules/sysadmin/dash.py @@ -0,0 +1,42 @@ +""" +The dashboard shows a summary of the various registered global bot settings. +""" + +import discord +import discord.ext.commands as cmds + +from meta import LionBot, LionCog, LionContext +from meta.app import appname +from wards import sys_admin + +from settings.groups import SettingGroup + + +class LeoSettings(LionCog, group_name='leo'): + __cog_is_app_commands_group__ = True + depends = {'CoreCog'} + + def __init__(self, bot: LionBot): + self.bot = bot + + self.bot_setting_groups: list[SettingGroup] = [] + + @cmds.hybrid_command( + name='dashboard', + description="Global setting dashboard" + ) + @cmds.check(sys_admin) + async def dash_cmd(self, ctx: LionContext): + embed = discord.Embed( + title="System Admin Dashboard", + colour=discord.Colour.orange() + ) + for group in self.bot_setting_groups: + table = await group.make_setting_table(appname) + description = group.description.format(ctx=ctx, bot=ctx.bot).strip() + embed.add_field( + name=group.title.format(ctx=ctx, bot=ctx.bot), + value=f"{description}\n{table}" + ) + + await ctx.reply(embed=embed) diff --git a/bot/modules/sysadmin/exec_cog.py b/bot/modules/sysadmin/exec_cog.py new file mode 100644 index 00000000..93628c14 --- /dev/null +++ b/bot/modules/sysadmin/exec_cog.py @@ -0,0 +1,384 @@ +import io +import ast +import sys +import types +import asyncio +import traceback +import builtins +import inspect +import logging +from io import StringIO + +from typing import Callable, Any, Optional + +from enum import Enum + +import discord +from discord.ext import commands +from discord.ui import TextInput, View +from discord.ui.button import button +import discord.app_commands as appcmd + +from meta.logger import logging_context +from meta.app import shard_talk +from meta import conf +from meta.context import context, ctx_bot +from meta.LionContext import LionContext +from meta.LionCog import LionCog +from meta.LionBot import LionBot + +from utils.ui import FastModal, input + +from wards import sys_admin + + +logger = logging.getLogger(__name__) + + +class ExecModal(FastModal, title="Execute"): + code: TextInput = TextInput( + label="Code to execute", + style=discord.TextStyle.long, + required=True + ) + + +class ExecStyle(Enum): + EXEC = 'exec' + EVAL = 'eval' + + +class ExecUI(View): + def __init__(self, ctx, code=None, style=ExecStyle.EXEC, ephemeral=True) -> None: + super().__init__() + + self.ctx: LionContext = ctx + self.interaction: Optional[discord.Interaction] = ctx.interaction + self.code: Optional[str] = code + self.style: ExecStyle = style + self.ephemeral: bool = ephemeral + + self._modal: Optional[ExecModal] = None + self._msg: Optional[discord.Message] = None + + async def interaction_check(self, interaction: discord.Interaction): + """Only allow the original author to use this View""" + if interaction.user.id != self.ctx.author.id: + await interaction.response.send_message( + "You cannot use this interface!", + ephemeral=True + ) + return False + else: + return True + + async def run(self): + if self.code is None: + if (interaction := self.interaction) is not None: + self.interaction = None + await interaction.response.send_modal(self.get_modal()) + await self.wait() + else: + # Complain + # TODO: error_reply + await self.ctx.reply("Pls give code.") + else: + await self.interaction.response.defer(thinking=True, ephemeral=self.ephemeral) + await self.compile() + await self.wait() + + @button(label="Recompile") + async def recompile_button(self, interaction, butt): + # Interaction response with modal + await interaction.response.send_modal(self.get_modal()) + + @button(label="Show Source") + async def source_button(self, interaction, butt): + if len(self.code) > 1900: + # Send as file + with StringIO(self.code) as fp: + fp.seek(0) + file = discord.File(fp, filename="source.py") + await interaction.response.send_message(file=file, ephemeral=True) + else: + # Send as message + await interaction.response.send_message( + content=f"```py\n{self.code}```", + ephemeral=True + ) + + def create_modal(self) -> ExecModal: + modal = ExecModal() + + @modal.submit_callback() + async def exec_submit(interaction: discord.Interaction): + if self.interaction is None: + self.interaction = interaction + await interaction.response.defer(thinking=True) + else: + await interaction.response.defer() + + # Set code + self.code = modal.code.value + + # Call compile + await self.compile() + + return modal + + def get_modal(self): + if self._modal is None: + # Create modal + self._modal = self.create_modal() + + self._modal.code.default = self.code + return self._modal + + async def compile(self): + # Call _async + result = await _async(self.code, style=self.style.value) + + # Display output + await self.show_output(result) + + async def show_output(self, output): + # Format output + # If output message exists and not ephemeral, edit + # Otherwise, send message, add buttons + if len(output) > 1900: + # Send as file + with StringIO(output) as fp: + fp.seek(0) + args = { + 'content': None, + 'attachments': [discord.File(fp, filename="output.md")] + } + else: + args = { + 'content': f"```md\n{output}```", + 'attachments': [] + } + + if self._msg is None: + if self.interaction is not None: + msg = await self.interaction.edit_original_response(**args, view=self) + else: + # Send new message + if args['content'] is None: + args['file'] = args.pop('attachments')[0] + msg = await self.ctx.reply(**args, ephemeral=self.ephemeral, view=self) + + if not self.ephemeral: + self._msg = msg + else: + if self.interaction is not None: + await self.interaction.edit_original_response(**args, view=self) + else: + # Edit message + await self._msg.edit(**args) + + +def mk_print(fp: io.StringIO) -> Callable[..., None]: + def _print(*args, file: Any = fp, **kwargs): + return print(*args, file=file, **kwargs) + return _print + + +async def _async(to_eval: str, style='exec'): + with logging_context(action="Code Exec"): + newline = '\n' * ('\n' in to_eval) + logger.info( + f"Exec code with {style}: {newline}{to_eval}" + ) + + output = io.StringIO() + _print = mk_print(output) + + scope: dict[str, Any] = dict(sys.modules) + scope['__builtins__'] = builtins + scope.update(builtins.__dict__) + scope['ctx'] = ctx = context.get() + scope['bot'] = ctx_bot.get() + scope['print'] = _print # type: ignore + + try: + if ctx and ctx.message: + source_str = f"" + elif ctx and ctx.interaction: + source_str = f"" + else: + source_str = "Unknown async" + + code = compile( + to_eval, + source_str, + style, + ast.PyCF_ALLOW_TOP_LEVEL_AWAIT + ) + func = types.FunctionType(code, scope) + + ret = func() + if inspect.iscoroutine(ret): + ret = await ret + if ret is not None: + _print(repr(ret)) + except Exception: + _, exc, tb = sys.exc_info() + _print("".join(traceback.format_tb(tb))) + _print(f"{type(exc).__name__}: {exc}") + + result = output.getvalue().strip() + newline = '\n' * ('\n' in result) + logger.info( + f"Exec complete, output: {newline}{result}" + ) + return result + + +class Exec(LionCog): + guild_ids = conf.bot.getintlist('admin_guilds') + + def __init__(self, bot: LionBot): + self.bot = bot + + self.talk_async = shard_talk.register_route('exec')(_async) + + async def cog_check(self, ctx: LionContext) -> bool: # type: ignore + return await sys_admin(ctx) + + @commands.hybrid_command( + name='async', + description="Execute arbitrary code with Exec" + ) + @appcmd.describe( + string="Code to execute." + ) + @appcmd.guilds(*guild_ids) + async def async_cmd(self, ctx: LionContext, *, string: Optional[str] = None): + await ExecUI(ctx, string, ExecStyle.EXEC).run() + + @commands.hybrid_command( + name='eval', + description='Execute arbitrary code with Eval' + ) + @appcmd.describe( + string="Code to evaluate." + ) + @appcmd.guilds(*guild_ids) + async def eval_cmd(self, ctx: LionContext, *, string: str): + await ExecUI(ctx, string, ExecStyle.EVAL).run() + + @commands.hybrid_command( + name='asyncall', + description="Execute arbitrary code on all shards." + ) + @appcmd.describe( + string="Cross-shard code to execute. Cannot reference ctx!", + target="Target shard app name, see autocomplete for options." + ) + @appcmd.guilds(*guild_ids) + async def asyncall_cmd(self, ctx: LionContext, string: Optional[str] = None, target: Optional[str] = None): + if string is None and ctx.interaction: + try: + ctx.interaction, string = await input( + ctx.interaction, "Cross-shard execute", "Code to execute?", + style=discord.TextStyle.long + ) + except asyncio.TimeoutError: + return + if ctx.interaction: + await ctx.interaction.response.defer(thinking=True, ephemeral=True) + if target is not None: + if target not in shard_talk.peers: + embed = discord.Embed(description=f"Unknown peer {target}", colour=discord.Colour.red()) + if ctx.interaction: + await ctx.interaction.edit_original_response(embed=embed) + else: + await ctx.reply(embed=embed) + return + else: + result = await self.talk_async(string).send(target) + results = {target: result} + else: + results = await self.talk_async(string).broadcast(except_self=False) + + blocks = [f"# {appid}\n{result}" for appid, result in results.items()] + output = "\n\n".join(blocks) + if len(output) > 1900: + # Send as file + with StringIO(output) as fp: + fp.seek(0) + file = discord.File(fp, filename="output.md") # type: ignore + await ctx.reply(file=file) + else: + # Send as message + await ctx.reply(f"```md\n{output}```", ephemeral=True) + + @asyncall_cmd.autocomplete('target') + async def asyncall_target_acmpl(self, interaction: discord.Interaction, partial: str): + appids = set(shard_talk.peers.keys()) + results = [ + appcmd.Choice(name=appid, value=appid) + for appid in appids + if partial.lower() in appid.lower() + ] + if not results: + results = [ + appcmd.Choice(name=f"No peers found matching {partial}", value="None") + ] + return results + + @commands.hybrid_command( + name='reload', + description="Reload a given LionBot extension. Launches an ExecUI." + ) + @appcmd.describe( + extension="Name of the extesion to reload. See autocomplete for options." + ) + @appcmd.guilds(*guild_ids) + async def reload_cmd(self, ctx: LionContext, extension: str): + """ + This is essentially just a friendly wrapper to reload an extension. + It is equivalent to running "await bot.reload_extension(extension)" in eval, + with a slightly nicer interface through the autocomplete and error handling. + """ + if extension not in self.bot.extensions: + embed = discord.Embed(description=f"Unknown extension {extension}", colour=discord.Colour.red()) + await ctx.reply(embed=embed) + else: + # Uses an ExecUI to simplify error handling and re-execution + string = f"await bot.reload_extension('{extension}')" + await ExecUI(ctx, string, ExecStyle.EVAL).run() + + @reload_cmd.autocomplete('extension') + async def reload_extension_acmpl(self, interaction: discord.Interaction, partial: str): + keys = set(self.bot.extensions.keys()) + results = [ + appcmd.Choice(name=key, value=key) + for key in keys + if partial.lower() in key.lower() + ] + if not results: + results = [ + appcmd.Choice(name=f"No extensions found matching {partial}", value="None") + ] + return results + + @commands.hybrid_command( + name='shutdown', + description="Shutdown (or restart) the client." + ) + @appcmd.guilds(*guild_ids) + async def shutdown_cmd(self, ctx: LionContext): + """ + Shutdown the client. + Maybe do something friendly here? + """ + logger.info("Shutting down on admin request.") + await ctx.reply( + embed=discord.Embed( + description=f"Understood {ctx.author.mention}, cleaning up and shutting down!", + colour=discord.Colour.orange() + ) + ) + await self.bot.close() diff --git a/bot/modules/sysadmin/guild_log.py b/bot/modules/sysadmin/guild_log.py new file mode 100644 index 00000000..a12d24e2 --- /dev/null +++ b/bot/modules/sysadmin/guild_log.py @@ -0,0 +1,89 @@ +import datetime + +import discord +from discord import Webhook + +from meta.LionCog import LionCog +from meta.LionBot import LionBot +from meta.logger import log_wrap + + +class GuildLog(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + + @LionCog.listener('on_guild_remove') + @log_wrap(action="Log Guild Leave") + async def log_left_guild(self, guild: discord.Guild): + # Build embed + embed = discord.Embed(title="`{0.name} (ID: {0.id})`".format(guild), + colour=discord.Colour.red(), + timestamp=datetime.datetime.utcnow()) + embed.set_author(name="Left guild!") + + # Add more specific information about the guild + embed.add_field(name="Owner", value="{0.name} (ID: {0.id})".format(guild.owner), inline=False) + embed.add_field(name="Members (cached)", value="{}".format(len(guild.members)), inline=False) + embed.add_field(name="Now studying in", value="{} guilds".format(len(self.bot.guilds)), inline=False) + + # Retrieve the guild log channel and log the event + log_webhook = self.bot.config.endpoints.get("guild_log") + if log_webhook: + webhook = Webhook.from_url(log_webhook, session=self.bot.web_client) + await webhook.send(embed=embed, username=self.bot.appname) + + @LionCog.listener('on_guild_join') + @log_wrap(action="Log Guild Join") + async def log_join_guild(self, guild: discord.Guild): + owner = guild.owner + + bots = 0 + known = 0 + unknown = 0 + other_members = set(mem.id for mem in self.bot.get_all_members() if mem.guild != guild) + + for member in guild.members: + if member.bot: + bots += 1 + elif member.id in other_members: + known += 1 + else: + unknown += 1 + + mem1 = "people I know" if known != 1 else "person I know" + mem2 = "new friends" if unknown != 1 else "new friend" + mem3 = "bots" if bots != 1 else "bot" + mem4 = "total members" + known = "`{}`".format(known) + unknown = "`{}`".format(unknown) + bots = "`{}`".format(bots) + total = "`{}`".format(guild.member_count) + mem_str = "{0:<5}\t{4},\n{1:<5}\t{5},\n{2:<5}\t{6}, and\n{3:<5}\t{7}.".format( + known, + unknown, + bots, + total, + mem1, + mem2, + mem3, + mem4 + ) + created = "".format(int(guild.created_at.timestamp())) + + embed = discord.Embed( + title="`{0.name} (ID: {0.id})`".format(guild), + colour=discord.Colour.green(), + timestamp=datetime.datetime.utcnow() + ) + embed.set_author(name="Joined guild!") + + embed.add_field(name="Owner", value="{0} (ID: {0.id})".format(owner), inline=False) + embed.add_field(name="Created at", value=created, inline=False) + embed.add_field(name="Members", value=mem_str, inline=False) + embed.add_field(name="Now studying in", value="{} guilds".format(len(self.bot.guilds)), inline=False) + + # Retrieve the guild log channel and log the event + log_webhook = self.bot.config.endpoints.get("guild_log") + if log_webhook: + webhook = Webhook.from_url(log_webhook, session=self.bot.web_client) + await webhook.send(embed=embed, username=self.bot.appname) diff --git a/bot/modules/sysadmin/leo_group.py b/bot/modules/sysadmin/leo_group.py new file mode 100644 index 00000000..9ca3b5dd --- /dev/null +++ b/bot/modules/sysadmin/leo_group.py @@ -0,0 +1,68 @@ +from discord.app_commands import Group, Command +from discord.ext.commands import HybridCommand + +from meta import LionCog + + +class LeoGroup(Group, name='leo'): + """ + Base command group for all Leo system admin commands. + """ + ... + + +""" +TODO: + This will take some work to get working. + We want to be able to specify a command in a cog + as a subcommand of a group command in a different cog, + or even a different extension. + Unfortunately, this really messes with the hotloading and unloading, + and may require overriding LionCog.__new__. + + We also have to answer some implementation decisions, + such as what happens when the child command cog gets unloaded/reloaded? + What happens when the group command gets unloaded/reloaded? + + Well, if the child cog gets unloaded, it makes sense to detach the commands. + The commands should keep their binding to the defining cog, + the parent command is mainly relevant for the CommandTree, which we have control of anyway.. + + If the parent cog gets unloaded, it makes sense to unload all the subcommands, if possible. + + Now technically, it shouldn't _matter_ where the child command is defined. + The Tree is in charge (or should be) of arranging parent commands and subcommands. + The Group class should just specify some optional extra properties or wrappers + to apply to the subcommands. + So perhaps we can just extend Hybrid command to actually pass in a parent... + Or specify a _string_ as the parent, which gets mapped with a group class + if it exists.. but it doesn't need to exist. +""" + + +class LeoCog(LionCog): + """ + Abstract container cog acting as a manager for the LeoGroup above. + """ + def __init__(self, bot): + self.bot = bot + self.commands = [] + self.group = LeoGroup() + + def attach(self, *commands): + """ + Attach the given commands to the LeoGroup group. + """ + for command in commands: + if isinstance(command, Command): + # Classic app command, attach as-is + cmd = command + elif isinstance(command, HybridCommand): + cmd = command.app_command + else: + raise ValueError( + f"Command must by 'app_commands.Command' or 'commands.HybridCommand' not {cmd.__class_}" + ) + self.group.add_command(cmd) + + self.commands.extend(commands) diff --git a/bot/modules/sysadmin/presence.py b/bot/modules/sysadmin/presence.py new file mode 100644 index 00000000..71a88948 --- /dev/null +++ b/bot/modules/sysadmin/presence.py @@ -0,0 +1,375 @@ +from typing import Optional +import asyncio +import logging +from string import Template + +import discord +from discord.ext import commands as cmds +import discord.app_commands as appcmds +from discord.app_commands.transformers import AppCommandOptionType + +from meta import LionCog, LionBot, LionContext +from meta.logger import log_wrap +from meta.app import shard_talk, appname +from utils.ui import ChoicedEnum, Transformed +from utils.lib import tabulate + +from data import RowModel, Registry, RegisterEnum +from data.columns import String, Column + +from settings.data import ModelData +from settings.setting_types import EnumSetting, StringSetting +from settings.groups import SettingGroup + +from wards import sys_admin + +logger = logging.getLogger(__name__) + + +class AppActivityType(ChoicedEnum): + """ + Schema + ------ + CREATE TYPE ActivityType AS ENUM( + 'PLAYING', + 'WATCHING', + 'LISTENING', + 'STREAMING' + ); + """ + playing = ('PLAYING', 'Playing', discord.ActivityType.playing) + watching = ('WATCHING', 'Watching', discord.ActivityType.watching) + listening = ('LISTENING', 'Listening', discord.ActivityType.listening) + streaming = ('STREAMING', 'Streaming', discord.ActivityType.streaming) + + @property + def choice_name(self): + return self.value[1] + + @property + def choice_value(self): + return self.value[1] + + +class AppStatus(ChoicedEnum): + """ + Schema + ------ + CREATE TYPE OnlineStatus AS ENUM( + 'ONLINE', + 'IDLE', + 'DND', + 'OFFLINE' + ); + """ + online = ('ONLINE', 'Online', discord.Status.online) + idle = ('IDLE', 'Idle', discord.Status.idle) + dnd = ('DND', 'Do Not Disturb', discord.Status.dnd) + offline = ('OFFLINE', 'Offline/Invisible', discord.Status.offline) + + @property + def choice_name(self): + return self.value[1] + + @property + def choice_value(self): + return self.value[1] + + +class PresenceData(Registry, name='presence'): + class AppPresence(RowModel): + """ + Schema + ------ + CREATE TABLE bot_config_presence( + appname TEXT PRIMARY KEY REFERENCES bot_config(appname) ON DELETE CASCADE, + online_status OnlineStatus, + activity_type ActivityType, + activity_name Text + ); + """ + _tablename_ = 'bot_config_presence' + _cache_ = {} + + appname = String(primary=True) + online_status: Column[AppStatus] = Column() + activity_type: Column[AppActivityType] = Column() + activity_name = String() + + AppActivityType = RegisterEnum(AppActivityType, name="ActivityType") + AppStatus = RegisterEnum(AppStatus, name='OnlineStatus') + + +class PresenceSettings(SettingGroup): + """ + Control the bot status and activity. + """ + _title = "Presence Settings ({bot.core.cmd_name_cache[presence].mention})" + + class PresenceStatus(ModelData, EnumSetting[str, AppStatus]): + display_name = 'online_status' + desc = "Bot status indicator" + long_desc = "Whether the bot account displays as online, idle, dnd, or offline." + accepts = "One of 'online', 'idle', 'dnd', or 'offline'." + + _model = PresenceData.AppPresence + _column = PresenceData.AppPresence.online_status.name + _create_row = True + + _enum = AppStatus + _outputs = {item: item.value[1] for item in _enum} + _inputs = {item.name: item for item in _enum} + _default = AppStatus.online + + class PresenceType(ModelData, EnumSetting[str, AppActivityType]): + display_name = 'activity_type' + desc = "Type of presence activity" + long_desc = "Whether the bot activity is shown as 'Listening', 'Playing', or 'Watching'." + accepts = "One of 'listening', 'playing', 'watching', or 'streaming'." + + _model = PresenceData.AppPresence + _column = PresenceData.AppPresence.activity_type.name + _create_row = True + + _enum = AppActivityType + _outputs = {item: item.value[1] for item in _enum} + _inputs = {item.name: item for item in _enum} + _default = AppActivityType.watching + + class PresenceName(ModelData, StringSetting[str]): + display_name = 'activity_name' + desc = "Name of the presence activity" + long_desc = "Presence activity name." + accepts = "Any string." + + _model = PresenceData.AppPresence + _column = PresenceData.AppPresence.activity_name.name + _create_row = True + _default = "$in_vc students in $voice_channels study rooms!" + + +class PresenceCtrl(LionCog): + depends = {'CoreCog', 'LeoSettings'} + + # Only update every 60 seconds at most + ratelimit = 60 + + # Update at least every 300 seconds regardless of events + interval = 300 + + # Possible substitution keys, and the events that listen to them + keys = { + '$in_vc': {'on_voice_state_update'}, + '$voice_channels': {'on_channel_add', 'on_channel_remove'}, + '$shard_members': {'on_member_join', 'on_member_leave'}, + '$shard_guilds': {'on_guild_join', 'on_guild_leave'} + } + + default_format = "$in_vc students in $voice_channels study rooms!" + default_activity = discord.ActivityType.watching + default_status = discord.Status.online + + def __init__(self, bot: LionBot): + self.bot = bot + self.data = bot.db.load_registry(PresenceData()) + self.settings = PresenceSettings() + + self.activity_type: discord.ActivityType = self.default_activity + self.activity_format: str = self.default_format + self.status: discord.Status = self.default_status + + self._listening: set = set() + self._tick = asyncio.Event() + self._loop_task: Optional[asyncio.Task] = None + + self.talk_reload_presence = shard_talk.register_route("reload presence")(self.reload_presence) + + async def cog_load(self): + await self.data.init() + if (leo_setting_cog := self.bot.get_cog('LeoSettings')) is not None: + leo_setting_cog.bot_setting_groups.append(self.settings) + + await self.reload_presence() + self.update_listeners() + self._loop_task = asyncio.create_task(self.presence_loop()) + await self.tick() + + async def cog_unload(self): + """ + De-register the event listeners, and cancel the presence update loop. + """ + if (leo_setting_cog := self.bot.get_cog('LeoSettings')) is not None: + leo_setting_cog.bot_setting_groups.remove(self.settings) + + if self._loop_task is not None and not self._loop_task.done(): + self._loop_task.cancel("Unloading") + + for event in self._listening: + self.bot.remove_listener(self.tick, event) + self._listening.discard(event) + + def update_listeners(self): + # Build the list of events that should trigger status updates + # Un-register any current listeners we don't need + # Re-register any new listeners we need + new_listeners = set() + for key, events in self.keys.items(): + if key in self.activity_format: + new_listeners.update(events) + to_remove = self._listening.difference(new_listeners) + to_add = new_listeners.difference(self._listening) + + for event in to_remove: + self.bot.remove_listener(self.tick, event) + for event in to_add: + self.bot.add_listener(self.tick, event) + + self._listening = new_listeners + + async def reload_presence(self) -> None: + # Reload the presence information from the appconfig table + # TODO: When botconfig is done, these should load from settings, instead of directly from data + self.data.AppPresence._cache_.pop(appname, None) + self.activity_type = (await self.settings.PresenceType.get(appname)).value.value[2] + self.activity_format = (await self.settings.PresenceName.get(appname)).value + self.status = (await self.settings.PresenceStatus.get(appname)).value.value[2] + + async def set_presence(self, activity: Optional[discord.BaseActivity], status: Optional[discord.Status]): + """ + Globally change the client presence and save the new presence information. + """ + # TODO: Waiting on botconfig settings + self.activity_type = activity.type if activity else None + self.activity_name = activity.name if activity else None + self.status = status or self.status + await self.talk_reload_presence().broadcast(except_self=False) + + async def format_activity(self, form: str) -> str: + """ + Format the given string. + """ + subs = { + 'shard_members': sum(1 for _ in self.bot.get_all_members()), + 'shard_guilds': sum(1 for _ in self.bot.guilds) + } + if '$in_vc' in form: + # TODO: Waiting on study module data + subs['in_vc'] = sum(1 for m in self.bot.get_all_members() if m.voice and m.voice.channel) + if '$voice_channels' in form: + # TODO: Waiting on study module data + subs['voice_channels'] = sum(1 for c in self.bot.get_all_channels() if c.type == discord.ChannelType.voice) + + return Template(form).safe_substitute(subs) + + async def tick(self, *args, **kwargs): + """ + Request a presence update when next possible. + Arbitrary arguments allow this to be used as a generic event listener. + """ + self._tick.set() + + @log_wrap(action="Presence Update") + async def _do_presence_update(self): + try: + activity_name = await self.format_activity(self.activity_format) + await self.bot.change_presence( + activity=discord.Activity( + type=self.activity_type, + name=activity_name + ), + status=self.status + ) + logger.debug( + "Set status to '%s' with activity '%s' \"%s\"", + str(self.status), str(self.activity_type), str(activity_name) + ) + except Exception: + logger.exception( + "Unhandled exception occurred while updating client presence. Ignoring." + ) + + @log_wrap(stack=["Presence", "Loop"]) + async def presence_loop(self): + """ + Request a client presence update when possible. + """ + await self.bot.wait_until_ready() + logger.debug("Launching presence update loop.") + try: + while True: + # Wait for the wakeup event + try: + await asyncio.wait_for(self._tick.wait(), timeout=self.interval) + except asyncio.TimeoutError: + pass + + # Clear the wakeup event + self._tick.clear() + + # Run the presence update + await self._do_presence_update() + + # Wait for the delay + await asyncio.sleep(self.ratelimit) + except asyncio.CancelledError: + logger.debug("Closing client presence update loop.") + except Exception: + logger.exception( + "Unhandled exception occurred running client presence update loop. Closing loop." + ) + + @cmds.hybrid_command( + name="presence", + description="Globally set the bot status and activity." + ) + @cmds.check(sys_admin) + @appcmds.describe( + status="Online status (online | idle | dnd | offline)", + type="Activity type (watching | listening | playing | streaming)", + string="Activity name, supports substitutions $in_vc, $voice_channels, $shard_guilds, $shard_members" + ) + async def presence_cmd( + self, + ctx: LionContext, + status: Optional[Transformed[AppStatus, AppCommandOptionType.string]] = None, + type: Optional[Transformed[AppActivityType, AppCommandOptionType.string]] = None, + string: Optional[str] = None + ): + """ + Modify the client online status and activity. + + Discord makes no guarantees as to which combination of activity type and arguments actually work. + """ + colours = { + discord.Status.online: discord.Colour.green(), + discord.Status.idle: discord.Colour.orange(), + discord.Status.dnd: discord.Colour.red(), + discord.Status.offline: discord.Colour.light_grey() + } + + if any((status, type, string)): + # TODO: Batch? + if status is not None: + await self.settings.PresenceStatus(appname, status).write() + if type is not None: + await self.settings.PresenceType(appname, type).write() + if string is not None: + await self.settings.PresenceName(appname, string).write() + + await self.talk_reload_presence().broadcast(except_self=False) + await self._do_presence_update() + + current_name = await self.format_activity(self.activity_format) + table = '\n'.join( + tabulate( + ('Status', self.status.name), + ('Activity', f"{self.activity_type.name} {current_name}"), + ) + ) + await ctx.reply( + embed=discord.Embed( + title="Current Presence", + description=table, + colour=colours[self.status] + ) + ) diff --git a/bot/wards.py b/bot/wards.py new file mode 100644 index 00000000..1ea4da38 --- /dev/null +++ b/bot/wards.py @@ -0,0 +1,9 @@ +from meta.LionContext import LionContext + + +async def sys_admin(ctx: LionContext) -> bool: + """ + Checks whether the context author is listed in the configuration file as a bot admin. + """ + admins = ctx.bot.config.bot.getintlist('admins') + return ctx.author.id in admins