447 lines
17 KiB
Python
447 lines
17 KiB
Python
from collections import defaultdict
|
|
from typing import Optional
|
|
import asyncio
|
|
from cachetools import FIFOCache
|
|
|
|
import discord
|
|
from discord.abc import GuildChannel
|
|
from discord.ext import commands as cmds
|
|
from discord import app_commands as appcmds
|
|
|
|
from meta import LionBot, LionCog, LionContext
|
|
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
|
from utils.ui import Confirm
|
|
|
|
from . import logger
|
|
from .data import LinkData
|
|
|
|
|
|
async def prepare_attachments(attachments: list[discord.Attachment]):
|
|
results = []
|
|
for attach in attachments:
|
|
try:
|
|
as_file = await attach.to_file(spoiler=attach.is_spoiler())
|
|
results.append(as_file)
|
|
except discord.HTTPException:
|
|
pass
|
|
return results
|
|
|
|
|
|
async def prepare_embeds(message: discord.Message):
|
|
embeds = [embed for embed in message.embeds if embed.type == 'rich']
|
|
if message.reference:
|
|
embed = discord.Embed(
|
|
colour=discord.Colour.dark_gray(),
|
|
description=f"Reply to {message.reference.jump_url}"
|
|
)
|
|
embeds.append(embed)
|
|
return embeds
|
|
|
|
|
|
|
|
class VoiceFixCog(LionCog):
|
|
def __init__(self, bot: LionBot):
|
|
self.bot = bot
|
|
self.data = bot.db.load_registry(LinkData())
|
|
|
|
# Map of linkids to list of channelids
|
|
self.link_channels = {}
|
|
|
|
# Map of channelids to linkids
|
|
self.channel_links = {}
|
|
|
|
# Map of channelids to initialised discord.Webhook
|
|
self.hooks = {}
|
|
|
|
# Map of messageid to list of (channelid, webhookmsg) pairs, for updates
|
|
self.message_cache = FIFOCache(maxsize=200)
|
|
# webhook msgid -> orig msgid
|
|
self.wmessages = FIFOCache(maxsize=600)
|
|
|
|
self.lock = asyncio.Lock()
|
|
|
|
|
|
async def cog_load(self):
|
|
await self.data.init()
|
|
|
|
await self.reload_links()
|
|
|
|
async def reload_links(self):
|
|
records = await self.data.channel_links.select_where()
|
|
channel_links = defaultdict(set)
|
|
link_channels = defaultdict(set)
|
|
|
|
for record in records:
|
|
linkid = record['linkid']
|
|
channelid = record['channelid']
|
|
|
|
channel_links[channelid].add(linkid)
|
|
link_channels[linkid].add(channelid)
|
|
|
|
channelids = list(channel_links.keys())
|
|
if channelids:
|
|
await self.data.LinkHook.fetch_where(channelid=channelids)
|
|
for channelid in channelids:
|
|
# Will hit cache, so don't need any more data queries
|
|
await self.fetch_webhook_for(channelid)
|
|
|
|
self.channel_links = {cid: tuple(linkids) for cid, linkids in channel_links.items()}
|
|
self.link_channels = {lid: tuple(cids) for lid, cids in link_channels.items()}
|
|
|
|
logger.info(
|
|
f"Loaded '{len(link_channels)}' channel links with '{len(self.channel_links)}' linked channels."
|
|
)
|
|
|
|
@LionCog.listener('on_message')
|
|
async def on_message(self, message: discord.Message):
|
|
# Don't need this because everything except explicit messages are webhooks now
|
|
# if self.bot.user and (message.author.id == self.bot.user.id):
|
|
# return
|
|
if message.webhook_id:
|
|
return
|
|
|
|
async with self.lock:
|
|
sent = []
|
|
linkids = self.channel_links.get(message.channel.id, ())
|
|
if linkids:
|
|
for linkid in linkids:
|
|
for channelid in self.link_channels[linkid]:
|
|
if channelid != message.channel.id:
|
|
if message.attachments:
|
|
files = await prepare_attachments(message.attachments)
|
|
else:
|
|
files = []
|
|
|
|
hook = self.hooks[channelid]
|
|
avatar = message.author.avatar or message.author.default_avatar
|
|
msg = await hook.send(
|
|
content=message.content,
|
|
wait=True,
|
|
username=message.author.display_name,
|
|
avatar_url=avatar.url,
|
|
embeds=await prepare_embeds(message),
|
|
files=files,
|
|
allowed_mentions=discord.AllowedMentions.none()
|
|
)
|
|
sent.append((channelid, msg))
|
|
self.wmessages[msg.id] = message.id
|
|
if sent:
|
|
# For easier lookup
|
|
self.wmessages[message.id] = message.id
|
|
sent.append((message.channel.id, message))
|
|
|
|
self.message_cache[message.id] = sent
|
|
logger.info(f"Forwarded message {message.id}")
|
|
|
|
|
|
@LionCog.listener('on_message_edit')
|
|
async def on_message_edit(self, before, after):
|
|
async with self.lock:
|
|
cached_sent = self.message_cache.pop(before.id, ())
|
|
new_sent = []
|
|
for cid, msg in cached_sent:
|
|
try:
|
|
if msg.id != before.id:
|
|
msg = await msg.edit(
|
|
content=after.content,
|
|
embeds=await prepare_embeds(after),
|
|
)
|
|
new_sent.append((cid, msg))
|
|
except discord.NotFound:
|
|
pass
|
|
if new_sent:
|
|
self.message_cache[after.id] = new_sent
|
|
|
|
@LionCog.listener('on_message_delete')
|
|
async def on_message_delete(self, message):
|
|
async with self.lock:
|
|
origid = self.wmessages.get(message.id, None)
|
|
if origid:
|
|
cached_sent = self.message_cache.pop(origid, ())
|
|
for _, msg in cached_sent:
|
|
try:
|
|
if msg.id != message.id:
|
|
await msg.delete()
|
|
except discord.NotFound:
|
|
pass
|
|
|
|
@LionCog.listener('on_reaction_add')
|
|
async def on_reaction_add(self, reaction: discord.Reaction, user: discord.User):
|
|
async with self.lock:
|
|
message = reaction.message
|
|
emoji = reaction.emoji
|
|
origid = self.wmessages.get(message.id, None)
|
|
if origid and reaction.count == 1:
|
|
cached_sent = self.message_cache.get(origid, ())
|
|
for _, msg in cached_sent:
|
|
# TODO: Would be better to have a Message and check the reactions
|
|
try:
|
|
if msg.id != message.id:
|
|
await msg.add_reaction(emoji)
|
|
except discord.HTTPException:
|
|
pass
|
|
|
|
async def fetch_webhook_for(self, channelid) -> discord.Webhook:
|
|
hook = self.hooks.get(channelid, None)
|
|
if hook is None:
|
|
row = await self.data.LinkHook.fetch(channelid)
|
|
if row is None:
|
|
channel = self.bot.get_channel(channelid)
|
|
if channel is None:
|
|
raise ValueError("Cannot find channel to create hook.")
|
|
hook = await channel.create_webhook(name="LabRat Channel Link")
|
|
await self.data.LinkHook.create(
|
|
channelid=channelid,
|
|
webhookid=hook.id,
|
|
token=hook.token,
|
|
)
|
|
else:
|
|
hook = discord.Webhook.partial(row.webhookid, row.token, client=self.bot)
|
|
self.hooks[channelid] = hook
|
|
return hook
|
|
|
|
@cmds.hybrid_group(
|
|
name='linker',
|
|
description="Base command group for the channel linker"
|
|
)
|
|
@appcmds.default_permissions(manage_channels=True)
|
|
async def linker_group(self, ctx: LionContext):
|
|
...
|
|
|
|
@linker_group.command(
|
|
name='link',
|
|
description="Create a new link, or add a channel to an existing link."
|
|
)
|
|
@appcmds.describe(
|
|
name="Name of the new or existing channel link.",
|
|
channel1="First channel to add to the link.",
|
|
channel2="Second channel to add to the link.",
|
|
channel3="Third channel to add to the link.",
|
|
channel4="Fourth channel to add to the link.",
|
|
channel5="Fifth channel to add to the link.",
|
|
channelid="Optionally add a channel by id (for e.g. cross-server links).",
|
|
)
|
|
async def linker_link(self, ctx: LionContext,
|
|
name: str,
|
|
channel1: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
|
channel2: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
|
channel3: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
|
channel4: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
|
channel5: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
|
channelid: Optional[str] = None,
|
|
):
|
|
if not ctx.interaction:
|
|
return
|
|
await ctx.interaction.response.defer(thinking=True)
|
|
|
|
# Check if link 'name' already exists, create if not
|
|
existing = await self.data.Link.fetch_where()
|
|
link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
|
|
if link_row is None:
|
|
# Create
|
|
link_row = await self.data.Link.create(name=name)
|
|
link_channels = set()
|
|
created = True
|
|
else:
|
|
records = await self.data.channel_links.select_where(linkid=link_row.linkid)
|
|
link_channels = {record['channelid'] for record in records}
|
|
created = False
|
|
|
|
# Create webhooks and webhook rows on channels if required
|
|
maybe_channels = [
|
|
channel1, channel2, channel3, channel4, channel5,
|
|
]
|
|
if channelid and channelid.isdigit():
|
|
channel = self.bot.get_channel(int(channelid))
|
|
maybe_channels.append(channel)
|
|
|
|
channels = [channel for channel in maybe_channels if channel]
|
|
for channel in channels:
|
|
await self.fetch_webhook_for(channel.id)
|
|
|
|
# Insert or update the links
|
|
for channel in channels:
|
|
if channel.id not in link_channels:
|
|
await self.data.channel_links.insert(linkid=link_row.linkid, channelid=channel.id)
|
|
|
|
await self.reload_links()
|
|
|
|
if created:
|
|
embed = discord.Embed(
|
|
colour=discord.Colour.brand_green(),
|
|
title="Link Created",
|
|
description=(
|
|
"Created the link **{name}** and linked channels:\n{channels}"
|
|
).format(name=name, channels=', '.join(channel.mention for channel in channels))
|
|
)
|
|
else:
|
|
channelids = self.link_channels[link_row.linkid]
|
|
channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
|
|
embed = discord.Embed(
|
|
colour=discord.Colour.brand_green(),
|
|
title="Channels Linked",
|
|
description=(
|
|
"Updated the link **{name}** to link the following channels:\n{channelstr}"
|
|
).format(name=link_row.name, channelstr=channelstr)
|
|
)
|
|
await ctx.reply(embed=embed)
|
|
|
|
@linker_group.command(
|
|
name='unlink',
|
|
description="Destroy a link, or remove a channel from a link."
|
|
)
|
|
@appcmds.describe(
|
|
name="Name of the link to destroy",
|
|
channel="Channel to remove from the link.",
|
|
)
|
|
async def linker_unlink(self, ctx: LionContext,
|
|
name: str, channel: Optional[GuildChannel] = None):
|
|
if not ctx.interaction:
|
|
return
|
|
# Get the link, error if it doesn't exist
|
|
existing = await self.data.Link.fetch_where()
|
|
link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
|
|
if link_row is None:
|
|
raise UserInputError(
|
|
f"Link **{name}** doesn't exist!"
|
|
)
|
|
|
|
link_channelids = self.link_channels.get(link_row.linkid, ())
|
|
|
|
if channel is not None:
|
|
# If channel was given, remove channel from link and ack
|
|
if channel.id not in link_channelids:
|
|
raise UserInputError(
|
|
f"{channel.mention} is not linked in **{link_row.name}**!"
|
|
)
|
|
await self.data.channel_links.delete_where(channelid=channel.id, linkid=link_row.linkid)
|
|
embed = discord.Embed(
|
|
colour=discord.Colour.brand_green(),
|
|
title="Channel Unlinked",
|
|
description=f"{channel.mention} has been removed from **{link_row.name}**."
|
|
)
|
|
else:
|
|
# Otherwise, confirm link destroy, delete link row, and ack
|
|
channels = ', '.join(f"<#{cid}>" for cid in link_channelids)
|
|
confirm = Confirm(
|
|
f"Are you sure you want to remove the link **{link_row.name}**?\nLinked channels: {channels}",
|
|
ctx.author.id,
|
|
)
|
|
confirm.embed.colour = discord.Colour.red()
|
|
try:
|
|
result = await confirm.ask(ctx.interaction)
|
|
except ResponseTimedOut:
|
|
result = False
|
|
if not result:
|
|
raise SafeCancellation
|
|
|
|
embed = discord.Embed(
|
|
colour=discord.Colour.brand_green(),
|
|
title="Link removed",
|
|
description=f"Link **{link_row.name}** removed, the following channels were unlinked:\n{channels}"
|
|
)
|
|
await link_row.delete()
|
|
|
|
await self.reload_links()
|
|
await ctx.reply(embed=embed)
|
|
|
|
@linker_link.autocomplete('name')
|
|
async def _acmpl_link_name(self, interaction: discord.Interaction, partial: str):
|
|
"""
|
|
Autocomplete an existing link.
|
|
"""
|
|
existing = await self.data.Link.fetch_where()
|
|
names = [row.name for row in existing]
|
|
matching = [row.name for row in existing if partial.lower() in row.name.lower()]
|
|
if not matching:
|
|
choice = appcmds.Choice(
|
|
name=f"Create a new link '{partial}'",
|
|
value=partial
|
|
)
|
|
choices = [choice]
|
|
else:
|
|
choices = [
|
|
appcmds.Choice(
|
|
name=f"Link {name}",
|
|
value=name
|
|
)
|
|
for name in matching
|
|
]
|
|
return choices
|
|
|
|
@linker_unlink.autocomplete('name')
|
|
async def _acmpl_unlink_name(self, interaction: discord.Interaction, partial: str):
|
|
"""
|
|
Autocomplete an existing link.
|
|
"""
|
|
existing = await self.data.Link.fetch_where()
|
|
matching = [row.name for row in existing if partial.lower() in row.name.lower()]
|
|
if not matching:
|
|
choice = appcmds.Choice(
|
|
name=f"No existing links matching '{partial}'",
|
|
value=partial
|
|
)
|
|
choices = [choice]
|
|
else:
|
|
choices = [
|
|
appcmds.Choice(
|
|
name=f"Link {name}",
|
|
value=name
|
|
)
|
|
for name in matching
|
|
]
|
|
return choices
|
|
|
|
@linker_group.command(
|
|
name='links',
|
|
description="Display the existing channel links."
|
|
)
|
|
async def linker_links(self, ctx: LionContext):
|
|
if not ctx.interaction:
|
|
return
|
|
await ctx.interaction.response.defer(thinking=True)
|
|
|
|
links = await self.data.Link.fetch_where()
|
|
|
|
if not links:
|
|
embed = discord.Embed(
|
|
colour=discord.Colour.light_grey(),
|
|
title="No channel links have been set up!",
|
|
description="Create a new link and add channels with {linker}".format(
|
|
linker=self.bot.core.mention_cmd('linker link')
|
|
)
|
|
)
|
|
else:
|
|
embed = discord.Embed(
|
|
colour=discord.Colour.brand_green(),
|
|
title=f"Channel Links in {ctx.guild.name}",
|
|
)
|
|
for link in links:
|
|
channelids = self.link_channels.get(link.linkid, ())
|
|
channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
|
|
embed.add_field(
|
|
name=f"Link **{link.name}**",
|
|
value=channelstr,
|
|
inline=False
|
|
)
|
|
# TODO: May want paging if over 25 links....
|
|
await ctx.reply(embed=embed)
|
|
|
|
@linker_group.command(
|
|
name="webhook",
|
|
description='Manually configure the webhook for a given channel.'
|
|
)
|
|
async def linker_webhook(self, ctx: LionContext, channel: discord.abc.GuildChannel, webhook: str):
|
|
if not ctx.interaction:
|
|
return
|
|
|
|
hook = discord.Webhook.from_url(webhook, client=self.bot)
|
|
await self.data.LinkHook.table.delete_where(channelid=channel.id)
|
|
await self.data.LinkHook.create(
|
|
channelid=channel.id,
|
|
webhookid=hook.id,
|
|
token=hook.token,
|
|
)
|
|
self.hooks[channel.id] = hook
|
|
await ctx.reply(f"Webhook for {channel.mention} updated!")
|