Merge disc and twitchio Cogs.

This commit is contained in:
2024-09-06 10:57:07 +10:00
parent 7069c87e8e
commit b7e4acfee2
8 changed files with 81 additions and 81 deletions

View File

@@ -80,6 +80,14 @@ async def main():
websockets.serve(sockets.root_handler, '', conf.wserver['port']) websockets.serve(sockets.root_handler, '', conf.wserver['port'])
) )
crocbot = CrocBot(
config=conf,
data=db,
prefix='!',
initial_channels=conf.croccy.getlist('initial_channels'),
token=conf.croccy['token'],
)
lionbot = await stack.enter_async_context( lionbot = await stack.enter_async_context(
LionBot( LionBot(
command_prefix='!', command_prefix='!',
@@ -104,26 +112,15 @@ async def main():
translator=translator, translator=translator,
chunk_guilds_at_startup=False, chunk_guilds_at_startup=False,
system_monitor=system_monitor, system_monitor=system_monitor,
crocbot=crocbot,
) )
) )
crocbot = CrocBot(
config=conf,
data=db,
prefix='!',
initial_channels=conf.croccy.getlist('initial_channels'),
token=conf.croccy['token'],
lionbot=lionbot
)
lionbot.crocbot = crocbot
crocbot.load_module('modules')
crocstart = asyncio.create_task(start_croccy(crocbot)) crocstart = asyncio.create_task(start_croccy(crocbot))
lionstart = asyncio.create_task(start_lion(lionbot)) lionstart = asyncio.create_task(start_lion(lionbot))
await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED) await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED)
crocstart.cancel() # crocstart.cancel()
lionstart.cancel() # lionstart.cancel()
async def start_lion(lionbot): async def start_lion(lionbot):
ctx_bot.set(lionbot) ctx_bot.set(lionbot)

View File

@@ -10,10 +10,6 @@ from data import Database
from .config import Conf from .config import Conf
if TYPE_CHECKING:
from .LionBot import LionBot
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -21,12 +17,11 @@ class CrocBot(commands.Bot):
def __init__(self, *args, def __init__(self, *args,
config: Conf, config: Conf,
data: Database, data: Database,
lionbot: 'LionBot', **kwargs): **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.config = config self.config = config
self.data = data self.data = data
self.pubsub = pubsub.PubSubPool(self) self.pubsub = pubsub.PubSubPool(self)
self.lionbot = lionbot
async def event_ready(self): async def event_ready(self):
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}") logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")

View File

@@ -24,6 +24,7 @@ from .errors import HandledException, SafeCancellation
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
if TYPE_CHECKING: if TYPE_CHECKING:
from meta.CrocBot import CrocBot
from core.cog import CoreCog from core.cog import CoreCog
from core.config import ConfigCog from core.config import ConfigCog
from tracking.voice.cog import VoiceTrackerCog from tracking.voice.cog import VoiceTrackerCog
@@ -58,6 +59,7 @@ class LionBot(Bot):
initial_extensions: List[str], web_client: ClientSession, app_ipc, initial_extensions: List[str], web_client: ClientSession, app_ipc,
testing_guilds: List[int] = [], testing_guilds: List[int] = [],
system_monitor: Optional[SystemMonitor] = None, system_monitor: Optional[SystemMonitor] = None,
crocbot: Optional['CrocBot'] = None,
**kwargs **kwargs
): ):
kwargs.setdefault('tree_cls', LionTree) kwargs.setdefault('tree_cls', LionTree)
@@ -73,6 +75,8 @@ class LionBot(Bot):
self.app_ipc = app_ipc self.app_ipc = app_ipc
self.translator = translator self.translator = translator
self.crocbot = crocbot
self.system_monitor = system_monitor or SystemMonitor() self.system_monitor = system_monitor or SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status) self.monitor = ComponentMonitor('LionBot', self._monitor_status)
self.system_monitor.add_component(self.monitor) self.system_monitor.add_component(self.monitor)

View File

@@ -1,23 +1,35 @@
from typing import Any from functools import partial
from typing import Any, Callable, Optional
from discord.ext.commands import Cog from discord.ext.commands import Cog
from discord.ext import commands as cmds from discord.ext import commands as cmds
from twitchio.ext.commands import Command, Bot
from twitchio.ext.commands.meta import CogEvent
class LionCog(Cog): class LionCog(Cog):
# A set of other cogs that this cog depends on # A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set() depends_on: set['LionCog'] = set()
_placeholder_groups_: set[str] _placeholder_groups_: set[str]
_twitch_cmds_: dict[str, Command]
_twitch_events_: dict[str, CogEvent]
_twitch_events_loaded_: set[Callable]
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
cls._placeholder_groups_ = set() cls._placeholder_groups_ = set()
cls._twitch_cmds_ = {}
cls._twitch_events_ = {}
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
for elem, value in base.__dict__.items(): for elem, value in base.__dict__.items():
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'): if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
cls._placeholder_groups_.add(value.name) cls._placeholder_groups_.add(value.name)
elif isinstance(value, Command):
cls._twitch_cmds_[value.name] = value
elif isinstance(value, CogEvent):
cls._twitch_events_[value.name] = value
def __new__(cls, *args: Any, **kwargs: Any): def __new__(cls, *args: Any, **kwargs: Any):
# Patch to ensure no placeholder groups are in the command list # Patch to ensure no placeholder groups are in the command list
@@ -34,6 +46,33 @@ class LionCog(Cog):
return await super()._inject(bot, *args, *kwargs) return await super()._inject(bot, *args, *kwargs)
def _load_twitch_methods(self, bot: Bot):
for name, command in self._twitch_cmds_.items():
command._instance = self
command.cog = self
bot.add_command(command)
for name, event in self._twitch_events_.items():
callback = partial(event, self)
self._twitch_events_loaded_.add(callback)
bot.add_event(callback=callback, name=name)
def _unload_twitch_methods(self, bot: Bot):
for name in self._twitch_cmds_:
bot.remove_command(name)
for callback in self._twitch_events_loaded_:
bot.remove_event(callback=callback)
self._twitch_events_loaded_.clear()
@classmethod
def twitch_event(cls, event: Optional[str] = None):
def decorator(func) -> CogEvent:
event_name = event or func.__name__
return CogEvent(name=event_name, func=func, module=cls.__module__)
return decorator
@classmethod @classmethod
def placeholder_group(cls, group: cmds.HybridGroup): def placeholder_group(cls, group: cmds.HybridGroup):
group._placeholder_group_ = True group._placeholder_group_ = True

View File

@@ -26,12 +26,12 @@ active_discord = [
'.premium', '.premium',
'.streamalerts', '.streamalerts',
'.test', '.test',
'.counters',
] ]
active_twitch = [ active_twitch = [
'.nowdoing', '.nowdoing',
'.shoutouts', '.shoutouts',
'.counters',
'.tagstrings', '.tagstrings',
] ]

View File

@@ -4,10 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import CounterCog from .cog import CounterCog
def prepare(bot):
bot.add_cog(CounterCog(bot))
async def setup(bot): async def setup(bot):
from .lion_cog import CounterCog
await bot.add_cog(CounterCog(bot)) await bot.add_cog(CounterCog(bot))

View File

@@ -3,11 +3,14 @@ from enum import Enum
from typing import Optional from typing import Optional
from datetime import timedelta from datetime import timedelta
import discord
from discord.ext import commands as cmds
import twitchio import twitchio
from twitchio.ext import commands from twitchio.ext import commands
from data.queries import ORDER from data.queries import ORDER
from meta import CrocBot from meta import LionCog, LionBot, CrocBot
from utils.lib import utc_now from utils.lib import utc_now
from . import logger from . import logger
from .data import CounterData from .data import CounterData
@@ -22,10 +25,11 @@ class PERIOD(Enum):
YEAR = ('this year', 'y', 'year', 'yearly') YEAR = ('this year', 'y', 'year', 'yearly')
class CounterCog(commands.Cog): class CounterCog(LionCog):
def __init__(self, bot: CrocBot): def __init__(self, bot: LionBot):
self.bot = bot self.bot = bot
self.data = bot.data.load_registry(CounterData()) self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(CounterData())
self.loaded = asyncio.Event() self.loaded = asyncio.Event()
@@ -33,9 +37,18 @@ class CounterCog(commands.Cog):
self.counters = {} self.counters = {}
async def cog_load(self): async def cog_load(self):
self._load_twitch_methods(self.crocbot)
await self.data.init() await self.data.init()
await self.load_counters()
self.loaded.set() self.loaded.set()
async def cog_unload(self):
self._unload_twitch_methods(self.crocbot)
async def cog_check(self, ctx):
return True
async def load_counters(self): async def load_counters(self):
""" """
Initialise counter name cache. Initialise counter name cache.
@@ -46,18 +59,6 @@ class CounterCog(commands.Cog):
f"Loaded {len(self.counters)} counters." f"Loaded {len(self.counters)} counters."
) )
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_check(self, ctx):
await self.ensure_loaded()
return True
# Counters API # Counters API
async def fetch_counter(self, counter: str) -> CounterData.Counter: async def fetch_counter(self, counter: str) -> CounterData.Counter:
@@ -171,7 +172,7 @@ class CounterCog(commands.Cog):
if period is PERIOD.ALL: if period is PERIOD.ALL:
start_time = None start_time = None
elif period is PERIOD.STREAM: elif period is PERIOD.STREAM:
streams = await self.bot.fetch_streams(user_ids=[userid]) streams = await self.crocbot.fetch_streams(user_ids=[userid])
if streams: if streams:
stream = streams[0] stream = streams[0]
start_time = stream.started_at start_time = stream.started_at
@@ -199,7 +200,7 @@ class CounterCog(commands.Cog):
lb = await self.leaderboard(counter, start_time=start_time) lb = await self.leaderboard(counter, start_time=start_time)
if lb: if lb:
userids = list(lb.keys()) userids = list(lb.keys())
users = await self.bot.fetch_users(ids=userids) users = await self.crocbot.fetch_users(ids=userids)
name_map = {user.id: user.display_name for user in users} name_map = {user.id: user.display_name for user in users}
parts = [] parts = []
for userid, total in lb.items(): for userid, total in lb.items():
@@ -283,17 +284,9 @@ class CounterCog(commands.Cog):
await ctx.reply(await self.formatted_lb('water', args, int(user.id))) await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
@commands.command() @commands.command()
async def reload(self, ctx: commands.Context, *, args: str = ''): async def stuff(self, ctx: commands.Context, *, args: str = ''):
if not (ctx.author.is_mod or ctx.author.is_broadcaster): await ctx.reply(f"Stuff {args}")
return
if not args:
await ctx.reply("Full reload not implemented yet.")
else:
try:
self.bot.reload_module(args)
except Exception:
logger.exception("Failed to reload")
await ctx.reply("Failed to reload module! Check console~")
else:
await ctx.reply("Reloaded!")
@cmds.hybrid_command('water')
async def d_water_cmd(self, ctx):
await ctx.reply(repr(ctx))

View File

@@ -1,23 +0,0 @@
import asyncio
from typing import Optional
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.errors import UserInputError
from meta.logger import log_wrap
from utils.lib import utc_now
from data.conditions import NULL
from . import logger
from .data import CounterData
class CounterCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.counter_cog = bot.crocbot.get_cog('CounterCog')