From de02e64b15668a1f2153cac3daf07ce1c9e78440 Mon Sep 17 00:00:00 2001 From: Interitio Date: Fri, 3 Nov 2023 12:57:44 +0200 Subject: [PATCH] feat: Implement voicefix. --- data/schema.sql | 22 ++ src/meta/LionBot.py | 4 +- src/meta/LionTree.py | 4 +- src/modules/__init__.py | 1 + src/modules/voicefix/__init__.py | 7 + src/modules/voicefix/cog.py | 428 +++++++++++++++++++++++++++++++ src/modules/voicefix/data.py | 39 +++ 7 files changed, 499 insertions(+), 6 deletions(-) create mode 100644 src/modules/voicefix/__init__.py create mode 100644 src/modules/voicefix/cog.py create mode 100644 src/modules/voicefix/data.py diff --git a/data/schema.sql b/data/schema.sql index 0641eed..3d508a2 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -31,4 +31,26 @@ CREATE TABLE bot_config( ); -- }}} +-- Channel Linker {{{ + +CREATE TABLE links( + linkid SERIAL PRIMARY KEY, + name TEXT +); + +CREATE TABLE channel_webhooks( + channelid BIGINT PRIMARY KEY, + webhookid BIGINT NOT NULL, + token TEXT NOT NULL +); + +CREATE TABLE channel_links( + linkid INTEGER NOT NULL REFERENCES links (linkid) ON DELETE CASCADE, + channelid BIGINT NOT NULL REFERENCES channel_webhooks (channelid) ON DELETE CASCADE, + PRIMARY KEY (linkid, channelid) +); + + +-- }}} + -- vim: set fdm=marker: diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index 32f39ee..b7e3e65 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -260,9 +260,7 @@ class LionBot(Bot): error_embed.description = ( "An unexpected error occurred while processing your command!\n" "Our development team has been notified, and the issue will be addressed soon.\n" - "If the error persists, or you have any questions, please contact our [support team]({link}) " - "and give them the extra details below." - ).format(link=self.config.bot.support_guild) + ) details = {} details['error'] = f"`{repr(e)}`" if ctx.interaction: diff --git a/src/meta/LionTree.py b/src/meta/LionTree.py index 75a3ccf..3ce097b 100644 --- a/src/meta/LionTree.py +++ b/src/meta/LionTree.py @@ -51,9 +51,7 @@ class LionTree(CommandTree): error_embed.description = ( "An unexpected error occurred during this interaction!\n" "Our development team has been notified, and the issue will be addressed soon.\n" - "If the error persists, or you have any questions, please contact our [support team]({link}) " - "and give them the extra details below." - ).format(link=interaction.client.config.bot.support_guild) + ) details = {} details['error'] = f"`{repr(e)}`" details['interactionid'] = f"`{interaction.id}`" diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 3f6887e..099c81b 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -2,6 +2,7 @@ this_package = 'modules' active = [ '.sysadmin', + '.voicefix', ] diff --git a/src/modules/voicefix/__init__.py b/src/modules/voicefix/__init__.py new file mode 100644 index 0000000..2859cf8 --- /dev/null +++ b/src/modules/voicefix/__init__.py @@ -0,0 +1,7 @@ +import logging + +logger = logging.getLogger(__name__) + +async def setup(bot): + from .cog import VoiceFixCog + await bot.add_cog(VoiceFixCog(bot)) diff --git a/src/modules/voicefix/cog.py b/src/modules/voicefix/cog.py new file mode 100644 index 0000000..1070340 --- /dev/null +++ b/src/modules/voicefix/cog.py @@ -0,0 +1,428 @@ +from collections import defaultdict +from typing import Optional +import asyncio +from cachetools import FIFOCache + +import discord +from discord.abc import GuildChannel +from discord.ext import commands as cmds +from discord import app_commands as appcmds + +from meta import LionBot, LionCog, LionContext +from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError +from utils.ui import Confirm + +from . import logger +from .data import LinkData + + +async def prepare_attachments(attachments: list[discord.Attachment]): + results = [] + for attach in attachments: + try: + as_file = await attach.to_file(spoiler=attach.is_spoiler()) + results.append(as_file) + except discord.HTTPException: + pass + return results + + +async def prepare_embeds(message: discord.Message): + embeds = [embed for embed in message.embeds if embed.type == 'rich'] + if message.reference: + embed = discord.Embed( + colour=discord.Colour.dark_gray(), + description=f"Reply to {message.reference.jump_url}" + ) + embeds.append(embed) + return embeds + + + +class VoiceFixCog(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + self.data = bot.db.load_registry(LinkData()) + + # Map of linkids to list of channelids + self.link_channels = {} + + # Map of channelids to linkids + self.channel_links = {} + + # Map of channelids to initialised discord.Webhook + self.hooks = {} + + # Map of messageid to list of (channelid, webhookmsg) pairs, for updates + self.message_cache = FIFOCache(maxsize=200) + # webhook msgid -> orig msgid + self.wmessages = FIFOCache(maxsize=600) + + self.lock = asyncio.Lock() + + + async def cog_load(self): + await self.data.init() + + await self.reload_links() + + async def reload_links(self): + records = await self.data.channel_links.select_where() + channel_links = defaultdict(set) + link_channels = defaultdict(set) + + for record in records: + linkid = record['linkid'] + channelid = record['channelid'] + + channel_links[channelid].add(linkid) + link_channels[linkid].add(channelid) + + channelids = list(channel_links.keys()) + if channelids: + await self.data.LinkHook.fetch_where(channelid=channelids) + for channelid in channelids: + # Will hit cache, so don't need any more data queries + await self.fetch_webhook_for(channelid) + + self.channel_links = {cid: tuple(linkids) for cid, linkids in channel_links.items()} + self.link_channels = {lid: tuple(cids) for lid, cids in link_channels.items()} + + logger.info( + f"Loaded '{len(link_channels)}' channel links with '{len(self.channel_links)}' linked channels." + ) + + @LionCog.listener('on_message') + async def on_message(self, message: discord.Message): + # Don't need this because everything except explicit messages are webhooks now + # if self.bot.user and (message.author.id == self.bot.user.id): + # return + if message.webhook_id: + return + + async with self.lock: + sent = [] + linkids = self.channel_links.get(message.channel.id, ()) + if linkids: + for linkid in linkids: + for channelid in self.link_channels[linkid]: + if channelid != message.channel.id: + if message.attachments: + files = await prepare_attachments(message.attachments) + else: + files = [] + + hook = self.hooks[channelid] + avatar = message.author.avatar or message.author.default_avatar + msg = await hook.send( + content=message.content, + wait=True, + username=message.author.display_name, + avatar_url=avatar.url, + embeds=await prepare_embeds(message), + files=files, + allowed_mentions=discord.AllowedMentions.none() + ) + sent.append((channelid, msg)) + self.wmessages[msg.id] = message.id + if sent: + # For easier lookup + self.wmessages[message.id] = message.id + sent.append((message.channel.id, message)) + + self.message_cache[message.id] = sent + logger.info(f"Forwarded message {message.id}") + + + @LionCog.listener('on_message_edit') + async def on_message_edit(self, before, after): + async with self.lock: + cached_sent = self.message_cache.pop(before.id, ()) + new_sent = [] + for cid, msg in cached_sent: + try: + if msg.id != before.id: + msg = await msg.edit( + content=after.content, + embeds=await prepare_embeds(after), + ) + new_sent.append((cid, msg)) + except discord.NotFound: + pass + if new_sent: + self.message_cache[after.id] = new_sent + + @LionCog.listener('on_message_delete') + async def on_message_delete(self, message): + async with self.lock: + origid = self.wmessages.get(message.id, None) + if origid: + cached_sent = self.message_cache.pop(origid, ()) + for _, msg in cached_sent: + try: + if msg.id != message.id: + await msg.delete() + except discord.NotFound: + pass + + @LionCog.listener('on_reaction_add') + async def on_reaction_add(self, reaction: discord.Reaction, user: discord.User): + async with self.lock: + message = reaction.message + emoji = reaction.emoji + origid = self.wmessages.get(message.id, None) + if origid and reaction.count == 1: + cached_sent = self.message_cache.get(origid, ()) + for _, msg in cached_sent: + # TODO: Would be better to have a Message and check the reactions + try: + if msg.id != message.id: + await msg.add_reaction(emoji) + except discord.HTTPException: + pass + + async def fetch_webhook_for(self, channelid) -> discord.Webhook: + hook = self.hooks.get(channelid, None) + if hook is None: + row = await self.data.LinkHook.fetch(channelid) + if row is None: + channel = self.bot.get_channel(channelid) + if channel is None: + raise ValueError("Cannot find channel to create hook.") + hook = await channel.create_webhook(name="LabRat Channel Link") + await self.data.LinkHook.create( + channelid=channelid, + webhookid=hook.id, + token=hook.token, + ) + else: + hook = discord.Webhook.partial(row.webhookid, row.token, client=self.bot) + self.hooks[channelid] = hook + return hook + + @cmds.hybrid_group( + name='linker', + description="Base command group for the channel linker" + ) + @appcmds.default_permissions(manage_channels=True) + async def linker_group(self, ctx: LionContext): + ... + + @linker_group.command( + name='link', + description="Create a new link, or add a channel to an existing link." + ) + @appcmds.describe( + name="Name of the new or existing channel link.", + channel1="First channel to add to the link.", + channel2="Second channel to add to the link.", + channel3="Third channel to add to the link.", + channel4="Fourth channel to add to the link.", + channel5="Fifth channel to add to the link.", + channelid="Optionally add a channel by id (for e.g. cross-server links).", + ) + async def linker_link(self, ctx: LionContext, + name: str, + channel1: Optional[discord.TextChannel | discord.VoiceChannel] = None, + channel2: Optional[discord.TextChannel | discord.VoiceChannel] = None, + channel3: Optional[discord.TextChannel | discord.VoiceChannel] = None, + channel4: Optional[discord.TextChannel | discord.VoiceChannel] = None, + channel5: Optional[discord.TextChannel | discord.VoiceChannel] = None, + channelid: Optional[str] = None, + ): + if not ctx.interaction: + return + await ctx.interaction.response.defer(thinking=True) + + # Check if link 'name' already exists, create if not + existing = await self.data.Link.fetch_where() + link_row = next((row for row in existing if row.name.lower() == name.lower()), None) + if link_row is None: + # Create + link_row = await self.data.Link.create(name=name) + link_channels = set() + created = True + else: + records = await self.data.channel_links.select_where(linkid=link_row.linkid) + link_channels = {record['channelid'] for record in records} + created = False + + # Create webhooks and webhook rows on channels if required + maybe_channels = [ + channel1, channel2, channel3, channel4, channel5, + ] + if channelid and channelid.isdigit(): + channel = self.bot.get_channel(int(channelid)) + maybe_channels.append(channel) + + channels = [channel for channel in maybe_channels if channel] + for channel in channels: + await self.fetch_webhook_for(channel.id) + + # Insert or update the links + for channel in channels: + if channel.id not in link_channels: + await self.data.channel_links.insert(linkid=link_row.linkid, channelid=channel.id) + + await self.reload_links() + + if created: + embed = discord.Embed( + colour=discord.Colour.brand_green(), + title="Link Created", + description=( + "Created the link **{name}** and linked channels:\n{channels}" + ).format(name=name, channels=', '.join(channel.mention for channel in channels)) + ) + else: + channelids = self.link_channels[link_row.linkid] + channelstr = ', '.join(f"<#{cid}>" for cid in channelids) + embed = discord.Embed( + colour=discord.Colour.brand_green(), + title="Channels Linked", + description=( + "Updated the link **{name}** to link the following channels:\n{channelstr}" + ).format(name=link_row.name, channelstr=channelstr) + ) + await ctx.reply(embed=embed) + + @linker_group.command( + name='unlink', + description="Destroy a link, or remove a channel from a link." + ) + @appcmds.describe( + name="Name of the link to destroy", + channel="Channel to remove from the link.", + ) + async def linker_unlink(self, ctx: LionContext, + name: str, channel: Optional[GuildChannel] = None): + if not ctx.interaction: + return + # Get the link, error if it doesn't exist + existing = await self.data.Link.fetch_where() + link_row = next((row for row in existing if row.name.lower() == name.lower()), None) + if link_row is None: + raise UserInputError( + f"Link **{name}** doesn't exist!" + ) + + link_channelids = self.link_channels.get(link_row.linkid, ()) + + if channel is not None: + # If channel was given, remove channel from link and ack + if channel.id not in link_channelids: + raise UserInputError( + f"{channel.mention} is not linked in **{link_row.name}**!" + ) + await self.data.channel_links.delete_where(channelid=channel.id, linkid=link_row.linkid) + embed = discord.Embed( + colour=discord.Colour.brand_green(), + title="Channel Unlinked", + description=f"{channel.mention} has been removed from **{link_row.name}**." + ) + else: + # Otherwise, confirm link destroy, delete link row, and ack + channels = ', '.join(f"<#{cid}>" for cid in link_channelids) + confirm = Confirm( + f"Are you sure you want to remove the link **{link_row.name}**?\nLinked channels: {channels}", + ctx.author.id, + ) + confirm.embed.colour = discord.Colour.red() + try: + result = await confirm.ask(ctx.interaction) + except ResponseTimedOut: + result = False + if not result: + raise SafeCancellation + + embed = discord.Embed( + colour=discord.Colour.brand_green(), + title="Link removed", + description=f"Link **{link_row.name}** removed, the following channels were unlinked:\n{channels}" + ) + await link_row.delete() + + await self.reload_links() + await ctx.reply(embed=embed) + + @linker_link.autocomplete('name') + async def _acmpl_link_name(self, interaction: discord.Interaction, partial: str): + """ + Autocomplete an existing link. + """ + existing = await self.data.Link.fetch_where() + names = [row.name for row in existing] + matching = [row.name for row in existing if partial.lower() in row.name.lower()] + if not matching: + choice = appcmds.Choice( + name=f"Create a new link '{partial}'", + value=partial + ) + choices = [choice] + else: + choices = [ + appcmds.Choice( + name=f"Link {name}", + value=name + ) + for name in matching + ] + return choices + + @linker_unlink.autocomplete('name') + async def _acmpl_unlink_name(self, interaction: discord.Interaction, partial: str): + """ + Autocomplete an existing link. + """ + existing = await self.data.Link.fetch_where() + matching = [row.name for row in existing if partial.lower() in row.name.lower()] + if not matching: + choice = appcmds.Choice( + name=f"No existing links matching '{partial}'", + value=partial + ) + choices = [choice] + else: + choices = [ + appcmds.Choice( + name=f"Link {name}", + value=name + ) + for name in matching + ] + return choices + + @linker_group.command( + name='links', + description="Display the existing channel links." + ) + async def linker_links(self, ctx: LionContext): + if not ctx.interaction: + return + await ctx.interaction.response.defer(thinking=True) + + links = await self.data.Link.fetch_where() + + if not links: + embed = discord.Embed( + colour=discord.Colour.light_grey(), + title="No channel links have been set up!", + description="Create a new link and add channels with {linker}".format( + linker=self.bot.core.mention_cmd('linker link') + ) + ) + else: + embed = discord.Embed( + colour=discord.Colour.brand_green(), + title=f"Channel Links in {ctx.guild.name}", + ) + for link in links: + channelids = self.link_channels.get(link.linkid, ()) + channelstr = ', '.join(f"<#{cid}>" for cid in channelids) + embed.add_field( + name=f"Link **{link.name}**", + value=channelstr, + inline=False + ) + # TODO: May want paging if over 25 links.... + await ctx.reply(embed=embed) diff --git a/src/modules/voicefix/data.py b/src/modules/voicefix/data.py new file mode 100644 index 0000000..e594f49 --- /dev/null +++ b/src/modules/voicefix/data.py @@ -0,0 +1,39 @@ +from data import Registry, RowModel, Table +from data.columns import Integer, Bool, Timestamp, String + + +class LinkData(Registry): + class Link(RowModel): + """ + Schema + ------ + CREATE TABLE links( + linkid SERIAL PRIMARY KEY, + name TEXT + ); + """ + _tablename_ = 'links' + _cache_ = {} + + linkid = Integer(primary=True) + name = String() + + + channel_links = Table('channel_links') + + class LinkHook(RowModel): + """ + Schema + ------ + CREATE TABLE channel_webhooks( + channelid BIGINT PRIMARY KEY, + webhookid BIGINT NOT NULL, + token TEXT NOT NULL + ); + """ + _tablename_ = 'channel_webhooks' + _cache_ = {} + + channelid = Integer(primary=True) + webhookid = Integer() + token = String()