feat(core): Channel hook manager.
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import discord.app_commands as appcmd
|
import discord.app_commands as appcmd
|
||||||
@@ -16,6 +17,7 @@ from .lion import Lions
|
|||||||
from .lion_guild import GuildConfig
|
from .lion_guild import GuildConfig
|
||||||
from .lion_member import MemberConfig
|
from .lion_member import MemberConfig
|
||||||
from .lion_user import UserConfig
|
from .lion_user import UserConfig
|
||||||
|
from .hooks import HookedChannel
|
||||||
|
|
||||||
|
|
||||||
class keydefaultdict(defaultdict):
|
class keydefaultdict(defaultdict):
|
||||||
@@ -54,6 +56,7 @@ class CoreCog(LionCog):
|
|||||||
self.app_cmd_cache: list[discord.app_commands.AppCommand] = []
|
self.app_cmd_cache: list[discord.app_commands.AppCommand] = []
|
||||||
self.cmd_name_cache: dict[str, discord.app_commands.AppCommand] = {}
|
self.cmd_name_cache: dict[str, discord.app_commands.AppCommand] = {}
|
||||||
self.mention_cache: dict[str, str] = keydefaultdict(self.mention_cmd)
|
self.mention_cache: dict[str, str] = keydefaultdict(self.mention_cmd)
|
||||||
|
self.hook_cache: WeakValueDictionary[int, HookedChannel] = WeakValueDictionary()
|
||||||
|
|
||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
# Fetch (and possibly create) core data rows.
|
# Fetch (and possibly create) core data rows.
|
||||||
@@ -91,7 +94,7 @@ class CoreCog(LionCog):
|
|||||||
cache |= subcache
|
cache |= subcache
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
def mention_cmd(self, name):
|
def mention_cmd(self, name: str):
|
||||||
"""
|
"""
|
||||||
Create an application command mention for the given names.
|
Create an application command mention for the given names.
|
||||||
|
|
||||||
@@ -103,6 +106,12 @@ class CoreCog(LionCog):
|
|||||||
mention = f"</{name}:1110834049204891730>"
|
mention = f"</{name}:1110834049204891730>"
|
||||||
return mention
|
return mention
|
||||||
|
|
||||||
|
def hooked_channel(self, channelid: int):
|
||||||
|
if (hooked := self.hook_cache.get(channelid, None)) is None:
|
||||||
|
hooked = HookedChannel(self.bot, channelid)
|
||||||
|
self.hook_cache[channelid] = hooked
|
||||||
|
return hooked
|
||||||
|
|
||||||
async def cog_unload(self):
|
async def cog_unload(self):
|
||||||
await self.bot.remove_cog(self.lions.qualified_name)
|
await self.bot.remove_cog(self.lions.qualified_name)
|
||||||
self.bot.remove_listener(self.shard_update_guilds, name='on_guild_join')
|
self.bot.remove_listener(self.shard_update_guilds, name='on_guild_join')
|
||||||
|
|||||||
106
src/core/hooks.py
Normal file
106
src/core/hooks.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from meta import LionBot
|
||||||
|
|
||||||
|
from .data import CoreData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
MISSING = discord.utils.MISSING
|
||||||
|
|
||||||
|
|
||||||
|
class HookedChannel:
|
||||||
|
def __init__(self, bot: LionBot, channelid: int):
|
||||||
|
self.bot = bot
|
||||||
|
self.channelid = channelid
|
||||||
|
|
||||||
|
self.webhook: Optional[discord.Webhook] | MISSING = None
|
||||||
|
self.data: Optional[CoreData.LionHook] = None
|
||||||
|
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channel(self) -> Optional[discord.TextChannel | discord.VoiceChannel | discord.StageChannel]:
|
||||||
|
if not self.bot.is_ready():
|
||||||
|
raise ValueError("Cannot get hooked channel before ready.")
|
||||||
|
channel = self.bot.get_channel(self.channelid)
|
||||||
|
if channel and not isinstance(channel, (discord.TextChannel, discord.VoiceChannel, discord.StageChannel)):
|
||||||
|
raise ValueError(f"Hooked channel expects GuildChannel not '{channel.__class__.__name__}'")
|
||||||
|
return channel
|
||||||
|
|
||||||
|
async def get_webhook(self) -> Optional[discord.Webhook]:
|
||||||
|
"""
|
||||||
|
Fetch the saved discord.Webhook for this channel.
|
||||||
|
|
||||||
|
Uses cached webhook if possible, but instantiates if required.
|
||||||
|
Does not create a new webhook, use `create_webhook` for that.
|
||||||
|
"""
|
||||||
|
async with self.lock:
|
||||||
|
if self.webhook is MISSING:
|
||||||
|
hook = None
|
||||||
|
elif self.webhook is None:
|
||||||
|
# Fetch webhook data
|
||||||
|
data = await CoreData.LionHook.fetch(self.channelid)
|
||||||
|
if data is not None:
|
||||||
|
# Instantiate Webhook
|
||||||
|
hook = self.webhook = data.as_webhook(client=self.bot)
|
||||||
|
else:
|
||||||
|
self.webhook = MISSING
|
||||||
|
hook = None
|
||||||
|
else:
|
||||||
|
hook = self.webhook
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
async def create_webhook(self, **creation_kwargs) -> Optional[discord.Webhook]:
|
||||||
|
"""
|
||||||
|
Create and save a new webhook in this channel.
|
||||||
|
|
||||||
|
Returns None if we could not create a new webhook.
|
||||||
|
"""
|
||||||
|
async with self.lock:
|
||||||
|
if self.webhook is not MISSING:
|
||||||
|
# Delete any existing webhook
|
||||||
|
if self.webhook is not None:
|
||||||
|
try:
|
||||||
|
await self.webhook.delete()
|
||||||
|
except discord.HTTPException as e:
|
||||||
|
logger.info(
|
||||||
|
f"Ignoring exception while refreshing webhook for {self.channelid}: {repr(e)}"
|
||||||
|
)
|
||||||
|
await self.bot.core.data.LionHook.table.delete_where(channelid=self.channelid)
|
||||||
|
self.webhook = MISSING
|
||||||
|
self.data = None
|
||||||
|
|
||||||
|
channel = self.channel
|
||||||
|
if channel is not None and channel.permissions_for(channel.guild.me).manage_webhooks:
|
||||||
|
if 'avatar' not in creation_kwargs:
|
||||||
|
avatar = self.bot.user.avatar if self.bot.user else None
|
||||||
|
creation_kwargs['avatar'] = (await avatar.to_file()).fp.read() if avatar else None
|
||||||
|
webhook = await channel.create_webhook(**creation_kwargs)
|
||||||
|
self.data = await self.bot.core.data.LionHook.create(
|
||||||
|
channelid=self.channelid,
|
||||||
|
token=webhook.token,
|
||||||
|
webhookid=webhook.id,
|
||||||
|
)
|
||||||
|
self.webhook = webhook
|
||||||
|
return webhook
|
||||||
|
|
||||||
|
async def invalidate(self, webhook: discord.Webhook):
|
||||||
|
"""
|
||||||
|
Invalidate the given webhook.
|
||||||
|
|
||||||
|
To be used when the webhook has been deleted on the Discord side.
|
||||||
|
"""
|
||||||
|
async with self.lock:
|
||||||
|
if self.webhook is not None and self.webhook is not MISSING and self.webhook.id == webhook.id:
|
||||||
|
# Webhook provided matches current webhook
|
||||||
|
# Delete current webhook
|
||||||
|
self.webhook = MISSING
|
||||||
|
self.data = None
|
||||||
|
await self.bot.core.data.LionHook.table.delete_where(webhookid=webhook.id)
|
||||||
Reference in New Issue
Block a user