Files
discord-voicefix-plugin/voicefix/cog.py

479 lines
18 KiB
Python

from collections import defaultdict
from typing import Optional
import asyncio
from cachetools import FIFOCache
import discord
from discord.abc import GuildChannel
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
from utils.ui import Confirm
from . import logger
from .data import LinkData
async def prepare_attachments(attachments: list[discord.Attachment]):
    """
    Convert message attachments into files that can be re-uploaded.

    Each attachment is converted with `Attachment.to_file`, preserving its
    spoiler flag. Conversion is best-effort: an attachment whose conversion
    raises `discord.HTTPException` is logged and skipped, so the returned
    list may be shorter than `attachments`.

    Parameters
    ----------
    attachments:
        The attachments to convert, in the order they should be returned.

    Returns
    -------
    list
        The successfully converted files, in input order.
    """
    results = []
    for attach in attachments:
        try:
            as_file = await attach.to_file(spoiler=attach.is_spoiler())
        except discord.HTTPException as exc:
            # Deliberately non-fatal: forward whatever could be retrieved,
            # but leave a trace instead of dropping the file silently.
            logger.warning(
                f"Failed to prepare attachment {attach.id} ('{attach.filename}'): {exc}"
            )
        else:
            results.append(as_file)
    return results
async def prepare_embeds(message: discord.Message):
    """
    Collect the embeds of `message` whose type is 'rich'.

    Embeds of any other type are left out. The relative order of the
    remaining embeds is preserved.
    """
    rich_embeds = []
    for candidate in message.embeds:
        if candidate.type == 'rich':
            rich_embeds.append(candidate)
    return rich_embeds
class VoiceFixCog(LionCog):
    """
    Cog which mirrors messages between groups of linked channels.

    A "link" is a named set of channels. A non-webhook message posted in a
    linked channel is re-sent through a webhook into every other channel
    sharing a link with it. Edits, deletions and first reactions are
    propagated to the mirrored copies for as long as the message remains in
    the in-memory caches.
    """

    def __init__(self, bot: LionBot):
        self.bot = bot
        self.data = bot.db.load_registry(LinkData())

        # Map of linkids to tuple of channelids
        self.link_channels = {}

        # Map of channelids to tuple of linkids
        self.channel_links = {}

        # Map of channelids to initialised discord.Webhook
        self.hooks = {}

        # Map of original messageid to list of (channelid, message) pairs, for updates.
        # The list includes the original message itself.
        self.message_cache = FIFOCache(maxsize=200)

        # Map of message id (webhook copy or original) -> original message id
        self.wmessages = FIFOCache(maxsize=600)

        # Serialises the listeners so the caches are not modified concurrently
        self.lock = asyncio.Lock()

    async def cog_load(self):
        """
        Initialise the data registry and load the current links.
        """
        await self.data.init()
        await self.reload_links()

    async def reload_links(self):
        """
        Rebuild the link/channel maps from the stored channel links.

        Also makes sure a webhook is available for every linked channel.
        A channel whose webhook cannot be prepared is logged and skipped
        rather than aborting the whole reload.
        """
        records = await self.data.channel_links.select_where()
        channel_links = defaultdict(set)
        link_channels = defaultdict(set)

        for record in records:
            linkid = record['linkid']
            channelid = record['channelid']
            channel_links[channelid].add(linkid)
            link_channels[linkid].add(channelid)

        channelids = list(channel_links.keys())
        if channelids:
            await self.data.LinkHook.fetch_where(channelid=channelids)
            for channelid in channelids:
                # Will hit cache, so don't need any more data queries
                try:
                    await self.fetch_webhook_for(channelid)
                except (ValueError, discord.HTTPException) as exc:
                    # One unreachable channel should not stop the remaining
                    # links from loading. The forwarder retries lazily.
                    logger.warning(
                        f"Could not prepare webhook for linked channel {channelid}: {exc}"
                    )

        self.channel_links = {cid: tuple(linkids) for cid, linkids in channel_links.items()}
        self.link_channels = {lid: tuple(cids) for lid, cids in link_channels.items()}

        logger.info(
            f"Loaded '{len(link_channels)}' channel links with '{len(self.channel_links)}' linked channels."
        )

    @LionCog.listener('on_message')
    async def on_message(self, message: discord.Message):
        """
        Forward a new message to every channel linked with its channel.

        Successful copies are recorded in `message_cache` and `wmessages`
        so that later edits, deletions, reactions and replies can be
        matched up. A failure for one target channel is logged and does
        not prevent forwarding to the others.
        """
        # Webhook messages are never forwarded. Messages sent by the bot
        # account itself are not filtered here.
        if message.webhook_id:
            return

        async with self.lock:
            source_id = message.channel.id

            # Two channels may share more than one link; collect each target
            # once so the message is not mirrored into it repeatedly.
            targets = []
            for linkid in self.channel_links.get(source_id, ()):
                for channelid in self.link_channels.get(linkid, ()):
                    if channelid != source_id and channelid not in targets:
                        targets.append(channelid)

            sent = []
            for channelid in targets:
                try:
                    hook = await self.fetch_webhook_for(channelid)
                except (ValueError, discord.HTTPException) as exc:
                    logger.warning(
                        f"Skipping forward of message {message.id} to channel {channelid}: {exc}"
                    )
                    continue

                # Files are prepared afresh for every target, as before.
                # NOTE(review): presumably because a discord.File cannot be
                # reused once it has been sent - confirm.
                if message.attachments:
                    files = await prepare_attachments(message.attachments)
                else:
                    files = []

                avatar = message.author.display_avatar or message.author.default_avatar
                try:
                    msg = await hook.send(
                        wait=True,
                        content=self.prepare_content(message, channelid),
                        username=f"{message.author.display_name} [#{message.channel.name}]",
                        avatar_url=avatar.url,
                        embeds=await prepare_embeds(message),
                        files=files,
                        allowed_mentions=discord.AllowedMentions.none()
                    )
                except discord.HTTPException as exc:
                    # Keep going so the copies which did succeed are still tracked.
                    logger.warning(
                        f"Failed to forward message {message.id} to channel {channelid}: {exc}"
                    )
                    continue

                sent.append((channelid, msg))
                self.wmessages[msg.id] = message.id

            if sent:
                # For easier lookup
                self.wmessages[message.id] = message.id
                sent.append((source_id, message))
                self.message_cache[message.id] = sent
                logger.info(f"Forwarded message {message.id}")

    def prepare_content(self, message, target_chid):
        """
        Build the text content for the copy of `message` sent to `target_chid`.

        When `message` has a reference, a "-# Replying to ..." line is
        appended. It links to the copy of the referenced message living in
        the target channel when that copy is still cached, and otherwise to
        the referenced message itself. The line is omitted if adding it
        would take the content past 2000 characters.

        Returns
        -------
        str
            The content to send.
        """
        forward_str = ''
        ref = message.reference
        if ref:
            # Look up message from cache in this channel
            local_copy = None
            if origid := self.wmessages.get(ref.message_id, None):
                if sent := self.message_cache.get(origid, None):
                    local_copy = next(
                        (msg for chid, msg in sent if chid == target_chid),
                        None
                    )

            # Prefer the client cache, then fall back to the message resolved
            # on the reference so the author is 'Unknown' less often.
            original = ref.cached_message
            if original is None and isinstance(ref.resolved, discord.Message):
                original = ref.resolved
            author = original.author.display_name if original else 'Unknown'

            jump_url = local_copy.jump_url if local_copy else ref.jump_url
            forward_str = f"-# Replying to {author} in {jump_url}"

        # Only add the separator when there is something to append, and only
        # if the result (including the newline) stays within 2000 characters.
        if forward_str and len(message.content) + len(forward_str) < 2000:
            return message.content + '\n' + forward_str
        return message.content

    @LionCog.listener('on_message_edit')
    async def on_message_edit(self, before, after):
        """
        Apply an edit of a forwarded message to its tracked copies.

        Copies which no longer exist are dropped from the cache. A copy
        which fails to update for another HTTP reason stays tracked with
        its previous state.
        """
        if before.webhook_id:
            return

        async with self.lock:
            cached_sent = self.message_cache.pop(before.id, ())
            new_sent = []
            for cid, msg in cached_sent:
                # The entry for the edited message itself is carried over untouched.
                if msg.id != before.id:
                    try:
                        msg = await msg.edit(
                            content=self.prepare_content(after, cid),
                            embeds=await prepare_embeds(after),
                        )
                    except discord.NotFound:
                        # The copy is gone, stop tracking it.
                        continue
                    except discord.HTTPException as exc:
                        # Keep the stale copy tracked so later edits and
                        # deletes still reach it.
                        logger.warning(
                            f"Failed to update forwarded copy {msg.id} of message {before.id}: {exc}"
                        )
                new_sent.append((cid, msg))

            if new_sent:
                self.message_cache[after.id] = new_sent

    @LionCog.listener('on_message_delete')
    async def on_message_delete(self, message):
        """
        Delete every other tracked copy when one copy of a message is deleted.

        Applies whether the deleted message was the original or one of its
        webhook copies, as both are keyed in `wmessages`.
        """
        async with self.lock:
            origid = self.wmessages.get(message.id, None)
            if not origid:
                return

            cached_sent = self.message_cache.pop(origid, ())
            for _, msg in cached_sent:
                if msg.id == message.id:
                    continue
                try:
                    await msg.delete()
                except discord.NotFound:
                    pass
                except discord.HTTPException as exc:
                    # Do not let one undeletable copy stop the rest.
                    logger.warning(
                        f"Failed to delete forwarded copy {msg.id} of message {origid}: {exc}"
                    )

    @LionCog.listener('on_reaction_add')
    async def on_reaction_add(self, reaction: discord.Reaction, user: discord.User):
        """
        Mirror the first reaction of an emoji onto the other tracked copies.

        Only acts when `reaction.count == 1`. Reactions added by the bot
        itself are ignored, since those are the mirrored reactions.
        """
        if self.bot.user and user.id == self.bot.user.id:
            # Our own mirrored reaction; re-mirroring it would only produce
            # redundant API calls.
            return

        async with self.lock:
            message = reaction.message
            emoji = reaction.emoji
            origid = self.wmessages.get(message.id, None)
            if origid and reaction.count == 1:
                cached_sent = self.message_cache.get(origid, ())
                for _, msg in cached_sent:
                    # TODO: Would be better to have a Message and check the reactions
                    try:
                        if msg.id != message.id:
                            await msg.add_reaction(emoji)
                    except discord.HTTPException:
                        pass

    async def fetch_webhook_for(self, channelid) -> discord.Webhook:
        """
        Return the webhook used to post into `channelid`.

        Looks in the in-memory map first, then in the stored `LinkHook`
        row. If neither exists, a new webhook is created on the channel and
        stored. The result is remembered in `self.hooks`.

        Raises
        ------
        ValueError
            If no webhook is stored and the channel cannot be found.
        """
        hook = self.hooks.get(channelid, None)
        if hook is None:
            row = await self.data.LinkHook.fetch(channelid)
            if row is None:
                channel = self.bot.get_channel(channelid)
                if channel is None:
                    raise ValueError("Cannot find channel to create hook.")
                hook = await channel.create_webhook(name="LabRat Channel Link")
                await self.data.LinkHook.create(
                    channelid=channelid,
                    webhookid=hook.id,
                    token=hook.token,
                )
            else:
                hook = discord.Webhook.partial(row.webhookid, row.token, client=self.bot)
            self.hooks[channelid] = hook
        return hook

    @cmds.hybrid_group(
        name='linker',
        description="Base command group for the channel linker"
    )
    @appcmds.default_permissions(manage_channels=True)
    async def linker_group(self, ctx: LionContext):
        """
        Base group for the linker commands. Does nothing itself.
        """
        ...

    @linker_group.command(
        name='link',
        description="Create a new link, or add a channel to an existing link."
    )
    @appcmds.describe(
        name="Name of the new or existing channel link.",
        channel1="First channel to add to the link.",
        channel2="Second channel to add to the link.",
        channel3="Third channel to add to the link.",
        channel4="Fourth channel to add to the link.",
        channel5="Fifth channel to add to the link.",
        channelid="Optionally add a channel by id (for e.g. cross-server links).",
    )
    async def linker_link(self, ctx: LionContext,
                          name: str,
                          channel1: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channel2: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channel3: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channel4: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channel5: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channelid: Optional[str] = None,
                          ):
        """
        Create the link `name` if needed and add the given channels to it.

        Link names are matched case-insensitively. Only runs as an
        application command (returns immediately without an interaction).
        """
        if not ctx.interaction:
            return
        await ctx.interaction.response.defer(thinking=True)

        # Check if link 'name' already exists, create if not
        existing = await self.data.Link.fetch_where()
        link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
        if link_row is None:
            # Create
            link_row = await self.data.Link.create(name=name)
            link_channels = set()
            created = True
        else:
            records = await self.data.channel_links.select_where(linkid=link_row.linkid)
            link_channels = {record['channelid'] for record in records}
            created = False

        # Create webhooks and webhook rows on channels if required
        maybe_channels = [
            channel1, channel2, channel3, channel4, channel5,
        ]
        if channelid:
            if not channelid.isdigit():
                return await ctx.error_reply("Channel id provided must be an integer!")
            if not (channel := self.bot.get_channel(int(channelid))):
                return await ctx.error_reply("Could not find a channel with the provided id!")
            maybe_channels.append(channel)

        channels = [channel for channel in maybe_channels if channel]
        for channel in channels:
            await self.fetch_webhook_for(channel.id)

        # Insert or update the links
        for channel in channels:
            if channel.id not in link_channels:
                await self.data.channel_links.insert(linkid=link_row.linkid, channelid=channel.id)

        await self.reload_links()

        if created:
            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title="Link Created",
                description=(
                    "Created the link **{name}** and linked channels:\n{channels}"
                ).format(name=name, channels=', '.join(channel.mention for channel in channels))
            )
        else:
            # An existing link may have no channels at all, in which case it
            # has no entry in the map.
            channelids = self.link_channels.get(link_row.linkid, ())
            channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title="Channels Linked",
                description=(
                    "Updated the link **{name}** to link the following channels:\n{channelstr}"
                ).format(name=link_row.name, channelstr=channelstr)
            )
        await ctx.reply(embed=embed)

    @linker_group.command(
        name='unlink',
        description="Destroy a link, or remove a channel from a link."
    )
    @appcmds.describe(
        name="Name of the link to destroy",
        channel="Channel to remove from the link.",
    )
    async def linker_unlink(self, ctx: LionContext,
                            name: str, channel: Optional[GuildChannel] = None):
        """
        Remove `channel` from the link `name`, or destroy the whole link.

        Destroying a link (no `channel` given) asks for confirmation first.

        Raises
        ------
        UserInputError
            If the link does not exist, or `channel` is not part of it.
        SafeCancellation
            If the confirmation is declined or times out.
        """
        if not ctx.interaction:
            return

        # Get the link, error if it doesn't exist
        existing = await self.data.Link.fetch_where()
        link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
        if link_row is None:
            raise UserInputError(
                f"Link **{name}** doesn't exist!"
            )

        link_channelids = self.link_channels.get(link_row.linkid, ())

        if channel is not None:
            # If channel was given, remove channel from link and ack
            if channel.id not in link_channelids:
                raise UserInputError(
                    f"{channel.mention} is not linked in **{link_row.name}**!"
                )
            await self.data.channel_links.delete_where(channelid=channel.id, linkid=link_row.linkid)
            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title="Channel Unlinked",
                description=f"{channel.mention} has been removed from **{link_row.name}**."
            )
        else:
            # Otherwise, confirm link destroy, delete link row, and ack
            channels = ', '.join(f"<#{cid}>" for cid in link_channelids)
            confirm = Confirm(
                f"Are you sure you want to remove the link **{link_row.name}**?\nLinked channels: {channels}",
                ctx.author.id,
            )
            confirm.embed.colour = discord.Colour.red()
            try:
                result = await confirm.ask(ctx.interaction)
            except ResponseTimedOut:
                result = False
            if not result:
                raise SafeCancellation

            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title="Link removed",
                description=f"Link **{link_row.name}** removed, the following channels were unlinked:\n{channels}"
            )
            await link_row.delete()

        await self.reload_links()
        await ctx.reply(embed=embed)

    @linker_link.autocomplete('name')
    async def _acmpl_link_name(self, interaction: discord.Interaction, partial: str):
        """
        Autocomplete an existing link, or offer to create a new one.
        """
        existing = await self.data.Link.fetch_where()
        matching = [row.name for row in existing if partial.lower() in row.name.lower()]
        if not matching:
            choice = appcmds.Choice(
                name=f"Create a new link '{partial}'",
                value=partial
            )
            choices = [choice]
        else:
            choices = [
                appcmds.Choice(
                    name=f"Link {name}",
                    value=name
                )
                for name in matching
            ]
        # Discord accepts at most 25 autocomplete choices
        return choices[:25]

    @linker_unlink.autocomplete('name')
    async def _acmpl_unlink_name(self, interaction: discord.Interaction, partial: str):
        """
        Autocomplete an existing link.
        """
        existing = await self.data.Link.fetch_where()
        matching = [row.name for row in existing if partial.lower() in row.name.lower()]
        if not matching:
            choice = appcmds.Choice(
                name=f"No existing links matching '{partial}'",
                value=partial
            )
            choices = [choice]
        else:
            choices = [
                appcmds.Choice(
                    name=f"Link {name}",
                    value=name
                )
                for name in matching
            ]
        # Discord accepts at most 25 autocomplete choices
        return choices[:25]

    @linker_group.command(
        name='links',
        description="Display the existing channel links."
    )
    async def linker_links(self, ctx: LionContext):
        """
        Show every link together with the channels it currently connects.
        """
        if not ctx.interaction:
            return
        await ctx.interaction.response.defer(thinking=True)

        links = await self.data.Link.fetch_where()

        if not links:
            embed = discord.Embed(
                colour=discord.Colour.light_grey(),
                title="No channel links have been set up!",
                description="Create a new link and add channels with {linker}".format(
                    linker=self.bot.core.mention_cmd('linker link')
                )
            )
        else:
            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title=f"Channel Links in {ctx.guild.name}",
            )
            # An embed holds at most 25 fields, so show the first 25 links.
            # TODO: May want paging if over 25 links....
            shown = list(links)[:25]
            for link in shown:
                channelids = self.link_channels.get(link.linkid, ())
                channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
                embed.add_field(
                    name=f"Link **{link.name}**",
                    # Field values may not be empty
                    value=channelstr or "No channels linked.",
                    inline=False
                )
            if len(links) > len(shown):
                embed.set_footer(text=f"Showing {len(shown)} of {len(links)} links.")
        await ctx.reply(embed=embed)

    @linker_group.command(
        name="webhook",
        description='Manually configure the webhook for a given channel.'
    )
    async def linker_webhook(self, ctx: LionContext, channel: discord.abc.GuildChannel, webhook: str):
        """
        Store `webhook` (a webhook URL) as the webhook used for `channel`.

        Replaces both the stored row and the in-memory webhook.

        Raises
        ------
        UserInputError
            If `webhook` is not a valid webhook URL.
        """
        if not ctx.interaction:
            return

        try:
            hook = discord.Webhook.from_url(webhook, client=self.bot)
        except ValueError:
            # Report a malformed URL to the user rather than as an internal error
            raise UserInputError(
                "The provided webhook URL is not valid!"
            ) from None

        existing = await self.data.LinkHook.fetch(channel.id)
        if existing:
            await existing.update(webhookid=hook.id, token=hook.token)
        else:
            await self.data.LinkHook.create(
                channelid=channel.id,
                webhookid=hook.id,
                token=hook.token,
            )
        self.hooks[channel.id] = hook
        await ctx.reply(f"Webhook for {channel.mention} updated!")