Refactor into plugin.
This commit is contained in:
449
voicefix/cog.py
Normal file
449
voicefix/cog.py
Normal file
@@ -0,0 +1,449 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
from cachetools import FIFOCache
|
||||
|
||||
import discord
|
||||
from discord.abc import GuildChannel
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
||||
from utils.ui import Confirm
|
||||
|
||||
from . import logger
|
||||
from .data import LinkData
|
||||
|
||||
|
||||
async def prepare_attachments(attachments: list[discord.Attachment]):
|
||||
results = []
|
||||
for attach in attachments:
|
||||
try:
|
||||
as_file = await attach.to_file(spoiler=attach.is_spoiler())
|
||||
results.append(as_file)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
async def prepare_embeds(message: discord.Message):
|
||||
embeds = [embed for embed in message.embeds if embed.type == 'rich']
|
||||
if message.reference:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.dark_gray(),
|
||||
description=f"Reply to {message.reference.jump_url}"
|
||||
)
|
||||
embeds.append(embed)
|
||||
return embeds
|
||||
|
||||
|
||||
|
||||
class VoiceFixCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(LinkData())
|
||||
|
||||
# Map of linkids to list of channelids
|
||||
self.link_channels = {}
|
||||
|
||||
# Map of channelids to linkids
|
||||
self.channel_links = {}
|
||||
|
||||
# Map of channelids to initialised discord.Webhook
|
||||
self.hooks = {}
|
||||
|
||||
# Map of messageid to list of (channelid, webhookmsg) pairs, for updates
|
||||
self.message_cache = FIFOCache(maxsize=200)
|
||||
# webhook msgid -> orig msgid
|
||||
self.wmessages = FIFOCache(maxsize=600)
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
await self.reload_links()
|
||||
|
||||
async def reload_links(self):
|
||||
records = await self.data.channel_links.select_where()
|
||||
channel_links = defaultdict(set)
|
||||
link_channels = defaultdict(set)
|
||||
|
||||
for record in records:
|
||||
linkid = record['linkid']
|
||||
channelid = record['channelid']
|
||||
|
||||
channel_links[channelid].add(linkid)
|
||||
link_channels[linkid].add(channelid)
|
||||
|
||||
channelids = list(channel_links.keys())
|
||||
if channelids:
|
||||
await self.data.LinkHook.fetch_where(channelid=channelids)
|
||||
for channelid in channelids:
|
||||
# Will hit cache, so don't need any more data queries
|
||||
await self.fetch_webhook_for(channelid)
|
||||
|
||||
self.channel_links = {cid: tuple(linkids) for cid, linkids in channel_links.items()}
|
||||
self.link_channels = {lid: tuple(cids) for lid, cids in link_channels.items()}
|
||||
|
||||
logger.info(
|
||||
f"Loaded '{len(link_channels)}' channel links with '{len(self.channel_links)}' linked channels."
|
||||
)
|
||||
|
||||
@LionCog.listener('on_message')
|
||||
async def on_message(self, message: discord.Message):
|
||||
# Don't need this because everything except explicit messages are webhooks now
|
||||
# if self.bot.user and (message.author.id == self.bot.user.id):
|
||||
# return
|
||||
if message.webhook_id:
|
||||
return
|
||||
|
||||
async with self.lock:
|
||||
sent = []
|
||||
linkids = self.channel_links.get(message.channel.id, ())
|
||||
if linkids:
|
||||
for linkid in linkids:
|
||||
for channelid in self.link_channels[linkid]:
|
||||
if channelid != message.channel.id:
|
||||
if message.attachments:
|
||||
files = await prepare_attachments(message.attachments)
|
||||
else:
|
||||
files = []
|
||||
|
||||
hook = self.hooks[channelid]
|
||||
avatar = message.author.avatar or message.author.default_avatar
|
||||
msg = await hook.send(
|
||||
content=message.content,
|
||||
wait=True,
|
||||
username=message.author.display_name,
|
||||
avatar_url=avatar.url,
|
||||
embeds=await prepare_embeds(message),
|
||||
files=files,
|
||||
allowed_mentions=discord.AllowedMentions.none()
|
||||
)
|
||||
sent.append((channelid, msg))
|
||||
self.wmessages[msg.id] = message.id
|
||||
if sent:
|
||||
# For easier lookup
|
||||
self.wmessages[message.id] = message.id
|
||||
sent.append((message.channel.id, message))
|
||||
|
||||
self.message_cache[message.id] = sent
|
||||
logger.info(f"Forwarded message {message.id}")
|
||||
|
||||
|
||||
@LionCog.listener('on_message_edit')
|
||||
async def on_message_edit(self, before, after):
|
||||
async with self.lock:
|
||||
cached_sent = self.message_cache.pop(before.id, ())
|
||||
new_sent = []
|
||||
for cid, msg in cached_sent:
|
||||
try:
|
||||
if msg.id != before.id:
|
||||
msg = await msg.edit(
|
||||
content=after.content,
|
||||
embeds=await prepare_embeds(after),
|
||||
)
|
||||
new_sent.append((cid, msg))
|
||||
except discord.NotFound:
|
||||
pass
|
||||
if new_sent:
|
||||
self.message_cache[after.id] = new_sent
|
||||
|
||||
@LionCog.listener('on_message_delete')
|
||||
async def on_message_delete(self, message):
|
||||
async with self.lock:
|
||||
origid = self.wmessages.get(message.id, None)
|
||||
if origid:
|
||||
cached_sent = self.message_cache.pop(origid, ())
|
||||
for _, msg in cached_sent:
|
||||
try:
|
||||
if msg.id != message.id:
|
||||
await msg.delete()
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
@LionCog.listener('on_reaction_add')
|
||||
async def on_reaction_add(self, reaction: discord.Reaction, user: discord.User):
|
||||
async with self.lock:
|
||||
message = reaction.message
|
||||
emoji = reaction.emoji
|
||||
origid = self.wmessages.get(message.id, None)
|
||||
if origid and reaction.count == 1:
|
||||
cached_sent = self.message_cache.get(origid, ())
|
||||
for _, msg in cached_sent:
|
||||
# TODO: Would be better to have a Message and check the reactions
|
||||
try:
|
||||
if msg.id != message.id:
|
||||
await msg.add_reaction(emoji)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
async def fetch_webhook_for(self, channelid) -> discord.Webhook:
|
||||
hook = self.hooks.get(channelid, None)
|
||||
if hook is None:
|
||||
row = await self.data.LinkHook.fetch(channelid)
|
||||
if row is None:
|
||||
channel = self.bot.get_channel(channelid)
|
||||
if channel is None:
|
||||
raise ValueError("Cannot find channel to create hook.")
|
||||
hook = await channel.create_webhook(name="LabRat Channel Link")
|
||||
await self.data.LinkHook.create(
|
||||
channelid=channelid,
|
||||
webhookid=hook.id,
|
||||
token=hook.token,
|
||||
)
|
||||
else:
|
||||
hook = discord.Webhook.partial(row.webhookid, row.token, client=self.bot)
|
||||
self.hooks[channelid] = hook
|
||||
return hook
|
||||
|
||||
@cmds.hybrid_group(
|
||||
name='linker',
|
||||
description="Base command group for the channel linker"
|
||||
)
|
||||
@appcmds.default_permissions(manage_channels=True)
|
||||
async def linker_group(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@linker_group.command(
|
||||
name='link',
|
||||
description="Create a new link, or add a channel to an existing link."
|
||||
)
|
||||
@appcmds.describe(
|
||||
name="Name of the new or existing channel link.",
|
||||
channel1="First channel to add to the link.",
|
||||
channel2="Second channel to add to the link.",
|
||||
channel3="Third channel to add to the link.",
|
||||
channel4="Fourth channel to add to the link.",
|
||||
channel5="Fifth channel to add to the link.",
|
||||
channelid="Optionally add a channel by id (for e.g. cross-server links).",
|
||||
)
|
||||
async def linker_link(self, ctx: LionContext,
|
||||
name: str,
|
||||
channel1: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channel2: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channel3: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channel4: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channel5: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channelid: Optional[str] = None,
|
||||
):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
await ctx.interaction.response.defer(thinking=True)
|
||||
|
||||
# Check if link 'name' already exists, create if not
|
||||
existing = await self.data.Link.fetch_where()
|
||||
link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
|
||||
if link_row is None:
|
||||
# Create
|
||||
link_row = await self.data.Link.create(name=name)
|
||||
link_channels = set()
|
||||
created = True
|
||||
else:
|
||||
records = await self.data.channel_links.select_where(linkid=link_row.linkid)
|
||||
link_channels = {record['channelid'] for record in records}
|
||||
created = False
|
||||
|
||||
# Create webhooks and webhook rows on channels if required
|
||||
maybe_channels = [
|
||||
channel1, channel2, channel3, channel4, channel5,
|
||||
]
|
||||
if channelid and channelid.isdigit():
|
||||
channel = self.bot.get_channel(int(channelid))
|
||||
maybe_channels.append(channel)
|
||||
|
||||
channels = [channel for channel in maybe_channels if channel]
|
||||
for channel in channels:
|
||||
await self.fetch_webhook_for(channel.id)
|
||||
|
||||
# Insert or update the links
|
||||
for channel in channels:
|
||||
if channel.id not in link_channels:
|
||||
await self.data.channel_links.insert(linkid=link_row.linkid, channelid=channel.id)
|
||||
|
||||
await self.reload_links()
|
||||
|
||||
if created:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title="Link Created",
|
||||
description=(
|
||||
"Created the link **{name}** and linked channels:\n{channels}"
|
||||
).format(name=name, channels=', '.join(channel.mention for channel in channels))
|
||||
)
|
||||
else:
|
||||
channelids = self.link_channels[link_row.linkid]
|
||||
channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title="Channels Linked",
|
||||
description=(
|
||||
"Updated the link **{name}** to link the following channels:\n{channelstr}"
|
||||
).format(name=link_row.name, channelstr=channelstr)
|
||||
)
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
@linker_group.command(
|
||||
name='unlink',
|
||||
description="Destroy a link, or remove a channel from a link."
|
||||
)
|
||||
@appcmds.describe(
|
||||
name="Name of the link to destroy",
|
||||
channel="Channel to remove from the link.",
|
||||
)
|
||||
async def linker_unlink(self, ctx: LionContext,
|
||||
name: str, channel: Optional[GuildChannel] = None):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
# Get the link, error if it doesn't exist
|
||||
existing = await self.data.Link.fetch_where()
|
||||
link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
|
||||
if link_row is None:
|
||||
raise UserInputError(
|
||||
f"Link **{name}** doesn't exist!"
|
||||
)
|
||||
|
||||
link_channelids = self.link_channels.get(link_row.linkid, ())
|
||||
|
||||
if channel is not None:
|
||||
# If channel was given, remove channel from link and ack
|
||||
if channel.id not in link_channelids:
|
||||
raise UserInputError(
|
||||
f"{channel.mention} is not linked in **{link_row.name}**!"
|
||||
)
|
||||
await self.data.channel_links.delete_where(channelid=channel.id, linkid=link_row.linkid)
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title="Channel Unlinked",
|
||||
description=f"{channel.mention} has been removed from **{link_row.name}**."
|
||||
)
|
||||
else:
|
||||
# Otherwise, confirm link destroy, delete link row, and ack
|
||||
channels = ', '.join(f"<#{cid}>" for cid in link_channelids)
|
||||
confirm = Confirm(
|
||||
f"Are you sure you want to remove the link **{link_row.name}**?\nLinked channels: {channels}",
|
||||
ctx.author.id,
|
||||
)
|
||||
confirm.embed.colour = discord.Colour.red()
|
||||
try:
|
||||
result = await confirm.ask(ctx.interaction)
|
||||
except ResponseTimedOut:
|
||||
result = False
|
||||
if not result:
|
||||
raise SafeCancellation
|
||||
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title="Link removed",
|
||||
description=f"Link **{link_row.name}** removed, the following channels were unlinked:\n{channels}"
|
||||
)
|
||||
await link_row.delete()
|
||||
|
||||
await self.reload_links()
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
@linker_link.autocomplete('name')
|
||||
async def _acmpl_link_name(self, interaction: discord.Interaction, partial: str):
|
||||
"""
|
||||
Autocomplete an existing link.
|
||||
"""
|
||||
existing = await self.data.Link.fetch_where()
|
||||
names = [row.name for row in existing]
|
||||
matching = [row.name for row in existing if partial.lower() in row.name.lower()]
|
||||
if not matching:
|
||||
choice = appcmds.Choice(
|
||||
name=f"Create a new link '{partial}'",
|
||||
value=partial
|
||||
)
|
||||
choices = [choice]
|
||||
else:
|
||||
choices = [
|
||||
appcmds.Choice(
|
||||
name=f"Link {name}",
|
||||
value=name
|
||||
)
|
||||
for name in matching
|
||||
]
|
||||
return choices
|
||||
|
||||
@linker_unlink.autocomplete('name')
|
||||
async def _acmpl_unlink_name(self, interaction: discord.Interaction, partial: str):
|
||||
"""
|
||||
Autocomplete an existing link.
|
||||
"""
|
||||
existing = await self.data.Link.fetch_where()
|
||||
matching = [row.name for row in existing if partial.lower() in row.name.lower()]
|
||||
if not matching:
|
||||
choice = appcmds.Choice(
|
||||
name=f"No existing links matching '{partial}'",
|
||||
value=partial
|
||||
)
|
||||
choices = [choice]
|
||||
else:
|
||||
choices = [
|
||||
appcmds.Choice(
|
||||
name=f"Link {name}",
|
||||
value=name
|
||||
)
|
||||
for name in matching
|
||||
]
|
||||
return choices
|
||||
|
||||
@linker_group.command(
|
||||
name='links',
|
||||
description="Display the existing channel links."
|
||||
)
|
||||
async def linker_links(self, ctx: LionContext):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
await ctx.interaction.response.defer(thinking=True)
|
||||
|
||||
links = await self.data.Link.fetch_where()
|
||||
|
||||
if not links:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.light_grey(),
|
||||
title="No channel links have been set up!",
|
||||
description="Create a new link and add channels with {linker}".format(
|
||||
linker=self.bot.core.mention_cmd('linker link')
|
||||
)
|
||||
)
|
||||
else:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title=f"Channel Links in {ctx.guild.name}",
|
||||
)
|
||||
for link in links:
|
||||
channelids = self.link_channels.get(link.linkid, ())
|
||||
channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
|
||||
embed.add_field(
|
||||
name=f"Link **{link.name}**",
|
||||
value=channelstr,
|
||||
inline=False
|
||||
)
|
||||
# TODO: May want paging if over 25 links....
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
@linker_group.command(
|
||||
name="webhook",
|
||||
description='Manually configure the webhook for a given channel.'
|
||||
)
|
||||
async def linker_webhook(self, ctx: LionContext, channel: discord.abc.GuildChannel, webhook: str):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
|
||||
hook = discord.Webhook.from_url(webhook, client=self.bot)
|
||||
existing = await self.data.LinkHook.fetch(channel.id)
|
||||
if existing:
|
||||
await existing.update(webhookid=hook.id, token=hook.token)
|
||||
else:
|
||||
await self.data.LinkHook.create(
|
||||
channelid=channel.id,
|
||||
webhookid=hook.id,
|
||||
token=hook.token,
|
||||
)
|
||||
self.hooks[channel.id] = hook
|
||||
await ctx.reply(f"Webhook for {channel.mention} updated!")
|
||||
Reference in New Issue
Block a user