diff --git a/src/bot.py b/src/bot.py index 4292eefa..852530f1 100644 --- a/src/bot.py +++ b/src/bot.py @@ -80,6 +80,14 @@ async def main(): websockets.serve(sockets.root_handler, '', conf.wserver['port']) ) + crocbot = CrocBot( + config=conf, + data=db, + prefix='!', + initial_channels=conf.croccy.getlist('initial_channels'), + token=conf.croccy['token'], + ) + lionbot = await stack.enter_async_context( LionBot( command_prefix='!', @@ -104,26 +112,15 @@ async def main(): translator=translator, chunk_guilds_at_startup=False, system_monitor=system_monitor, + crocbot=crocbot, ) ) - crocbot = CrocBot( - config=conf, - data=db, - prefix='!', - initial_channels=conf.croccy.getlist('initial_channels'), - token=conf.croccy['token'], - lionbot=lionbot - ) - lionbot.crocbot = crocbot - - crocbot.load_module('modules') - crocstart = asyncio.create_task(start_croccy(crocbot)) lionstart = asyncio.create_task(start_lion(lionbot)) await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED) - crocstart.cancel() - lionstart.cancel() + # crocstart.cancel() + # lionstart.cancel() async def start_lion(lionbot): ctx_bot.set(lionbot) diff --git a/src/meta/CrocBot.py b/src/meta/CrocBot.py index ae632a99..a4742424 100644 --- a/src/meta/CrocBot.py +++ b/src/meta/CrocBot.py @@ -10,10 +10,6 @@ from data import Database from .config import Conf -if TYPE_CHECKING: - from .LionBot import LionBot - - logger = logging.getLogger(__name__) @@ -21,12 +17,11 @@ class CrocBot(commands.Bot): def __init__(self, *args, config: Conf, data: Database, - lionbot: 'LionBot', **kwargs): + **kwargs): super().__init__(*args, **kwargs) self.config = config self.data = data self.pubsub = pubsub.PubSubPool(self) - self.lionbot = lionbot async def event_ready(self): logger.info(f"Logged in as {self.nick}. User id is {self.user_id}") diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index 48a5e065..19904f9f 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -24,6 +24,7 @@ from .errors import HandledException, SafeCancellation from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus if TYPE_CHECKING: + from meta.CrocBot import CrocBot from core.cog import CoreCog from core.config import ConfigCog from tracking.voice.cog import VoiceTrackerCog @@ -58,6 +59,7 @@ class LionBot(Bot): initial_extensions: List[str], web_client: ClientSession, app_ipc, testing_guilds: List[int] = [], system_monitor: Optional[SystemMonitor] = None, + crocbot: Optional['CrocBot'] = None, **kwargs ): kwargs.setdefault('tree_cls', LionTree) @@ -73,6 +75,8 @@ class LionBot(Bot): self.app_ipc = app_ipc self.translator = translator + self.crocbot = crocbot + self.system_monitor = system_monitor or SystemMonitor() self.monitor = ComponentMonitor('LionBot', self._monitor_status) self.system_monitor.add_component(self.monitor) diff --git a/src/meta/LionCog.py b/src/meta/LionCog.py index 39ca43aa..75eea8dc 100644 --- a/src/meta/LionCog.py +++ b/src/meta/LionCog.py @@ -1,23 +1,35 @@ -from typing import Any +from functools import partial +from typing import Any, Callable, Optional from discord.ext.commands import Cog from discord.ext import commands as cmds +from twitchio.ext.commands import Command, Bot +from twitchio.ext.commands.meta import CogEvent class LionCog(Cog): # A set of other cogs that this cog depends on depends_on: set['LionCog'] = set() _placeholder_groups_: set[str] + _twitch_cmds_: dict[str, Command] + _twitch_events_: dict[str, CogEvent] + _twitch_events_loaded_: set[Callable] def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._placeholder_groups_ = set() + cls._twitch_cmds_ = {} + cls._twitch_events_ = {} for base in reversed(cls.__mro__): for elem, value in base.__dict__.items(): if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'): cls._placeholder_groups_.add(value.name) + elif isinstance(value, Command): + cls._twitch_cmds_[value.name] = value + elif isinstance(value, CogEvent): + cls._twitch_events_[value.name] = value def __new__(cls, *args: Any, **kwargs: Any): # Patch to ensure no placeholder groups are in the command list @@ -34,6 +46,33 @@ class LionCog(Cog): return await super()._inject(bot, *args, *kwargs) + def _load_twitch_methods(self, bot: Bot): + for name, command in self._twitch_cmds_.items(): + command._instance = self + command.cog = self + bot.add_command(command) + + for name, event in self._twitch_events_.items(): + callback = partial(event, self) + self._twitch_events_loaded_.add(callback) + bot.add_event(callback=callback, name=name) + + def _unload_twitch_methods(self, bot: Bot): + for name in self._twitch_cmds_: + bot.remove_command(name) + + for callback in self._twitch_events_loaded_: + bot.remove_event(callback=callback) + + self._twitch_events_loaded_.clear() + + @classmethod + def twitch_event(cls, event: Optional[str] = None): + def decorator(func) -> CogEvent: + event_name = event or func.__name__ + return CogEvent(name=event_name, func=func, module=cls.__module__) + return decorator + @classmethod def placeholder_group(cls, group: cmds.HybridGroup): group._placeholder_group_ = True diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 8eec6d09..b69e5289 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -26,12 +26,12 @@ active_discord = [ '.premium', '.streamalerts', '.test', + '.counters', ] active_twitch = [ '.nowdoing', '.shoutouts', - '.counters', '.tagstrings', ] diff --git a/src/modules/counters/__init__.py b/src/modules/counters/__init__.py index 8990eb41..30f4e6e5 100644 --- a/src/modules/counters/__init__.py +++ b/src/modules/counters/__init__.py @@ -4,10 +4,5 @@ logger = logging.getLogger(__name__) from .cog import CounterCog -def prepare(bot): - bot.add_cog(CounterCog(bot)) - async def setup(bot): - from .lion_cog import CounterCog - await bot.add_cog(CounterCog(bot)) diff --git a/src/modules/counters/cog.py b/src/modules/counters/cog.py index 0cdeba92..9b1dd032 100644 --- a/src/modules/counters/cog.py +++ b/src/modules/counters/cog.py @@ -3,11 +3,14 @@ from enum import Enum from typing import Optional from datetime import timedelta +import discord +from discord.ext import commands as cmds import twitchio from twitchio.ext import commands + from data.queries import ORDER -from meta import CrocBot +from meta import LionCog, LionBot, CrocBot from utils.lib import utc_now from . import logger from .data import CounterData @@ -22,10 +25,11 @@ class PERIOD(Enum): YEAR = ('this year', 'y', 'year', 'yearly') -class CounterCog(commands.Cog): - def __init__(self, bot: CrocBot): +class CounterCog(LionCog): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(CounterData()) + self.crocbot: CrocBot = bot.crocbot + self.data = bot.db.load_registry(CounterData()) self.loaded = asyncio.Event() @@ -33,9 +37,18 @@ class CounterCog(commands.Cog): self.counters = {} async def cog_load(self): + self._load_twitch_methods(self.crocbot) + await self.data.init() + await self.load_counters() self.loaded.set() + async def cog_unload(self): + self._unload_twitch_methods(self.crocbot) + + async def cog_check(self, ctx): + return True + async def load_counters(self): """ Initialise counter name cache. @@ -46,18 +59,6 @@ class CounterCog(commands.Cog): f"Loaded {len(self.counters)} counters." ) - async def ensure_loaded(self): - if not self.loaded.is_set(): - await self.cog_load() - - @commands.Cog.event('event_ready') # type: ignore - async def on_ready(self): - await self.ensure_loaded() - - async def cog_check(self, ctx): - await self.ensure_loaded() - return True - # Counters API async def fetch_counter(self, counter: str) -> CounterData.Counter: @@ -171,7 +172,7 @@ class CounterCog(commands.Cog): if period is PERIOD.ALL: start_time = None elif period is PERIOD.STREAM: - streams = await self.bot.fetch_streams(user_ids=[userid]) + streams = await self.crocbot.fetch_streams(user_ids=[userid]) if streams: stream = streams[0] start_time = stream.started_at @@ -199,7 +200,7 @@ class CounterCog(commands.Cog): lb = await self.leaderboard(counter, start_time=start_time) if lb: userids = list(lb.keys()) - users = await self.bot.fetch_users(ids=userids) + users = await self.crocbot.fetch_users(ids=userids) name_map = {user.id: user.display_name for user in users} parts = [] for userid, total in lb.items(): @@ -283,17 +284,9 @@ class CounterCog(commands.Cog): await ctx.reply(await self.formatted_lb('water', args, int(user.id))) @commands.command() - async def reload(self, ctx: commands.Context, *, args: str = ''): - if not (ctx.author.is_mod or ctx.author.is_broadcaster): - return - if not args: - await ctx.reply("Full reload not implemented yet.") - else: - try: - self.bot.reload_module(args) - except Exception: - logger.exception("Failed to reload") - await ctx.reply("Failed to reload module! Check console~") - else: - await ctx.reply("Reloaded!") + async def stuff(self, ctx: commands.Context, *, args: str = ''): + await ctx.reply(f"Stuff {args}") + @cmds.hybrid_command('water') + async def d_water_cmd(self, ctx): + await ctx.reply(repr(ctx)) diff --git a/src/modules/counters/lion_cog.py b/src/modules/counters/lion_cog.py deleted file mode 100644 index 123514b8..00000000 --- a/src/modules/counters/lion_cog.py +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio -from typing import Optional - -import discord -from discord.ext import commands as cmds -from discord import app_commands as appcmds - -from meta import LionBot, LionCog, LionContext -from meta.errors import UserInputError -from meta.logger import log_wrap -from utils.lib import utc_now -from data.conditions import NULL - -from . import logger -from .data import CounterData - - -class CounterCog(LionCog): - - def __init__(self, bot: LionBot): - self.bot = bot - - self.counter_cog = bot.crocbot.get_cog('CounterCog')