diff --git a/requirements.txt b/requirements.txt index af59fce9..eb86670c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,6 @@ pillow python-dateutil bidict frozendict +TwitchIO +websockets +TwitchAPI diff --git a/src/bot.py b/src/bot.py index 4292eefa..852530f1 100644 --- a/src/bot.py +++ b/src/bot.py @@ -80,6 +80,14 @@ async def main(): websockets.serve(sockets.root_handler, '', conf.wserver['port']) ) + crocbot = CrocBot( + config=conf, + data=db, + prefix='!', + initial_channels=conf.croccy.getlist('initial_channels'), + token=conf.croccy['token'], + ) + lionbot = await stack.enter_async_context( LionBot( command_prefix='!', @@ -104,26 +112,15 @@ async def main(): translator=translator, chunk_guilds_at_startup=False, system_monitor=system_monitor, + crocbot=crocbot, ) ) - crocbot = CrocBot( - config=conf, - data=db, - prefix='!', - initial_channels=conf.croccy.getlist('initial_channels'), - token=conf.croccy['token'], - lionbot=lionbot - ) - lionbot.crocbot = crocbot - - crocbot.load_module('modules') - crocstart = asyncio.create_task(start_croccy(crocbot)) lionstart = asyncio.create_task(start_lion(lionbot)) await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED) - crocstart.cancel() - lionstart.cancel() + # crocstart.cancel() + # lionstart.cancel() async def start_lion(lionbot): ctx_bot.set(lionbot) diff --git a/src/meta/CrocBot.py b/src/meta/CrocBot.py index ae632a99..a4742424 100644 --- a/src/meta/CrocBot.py +++ b/src/meta/CrocBot.py @@ -10,10 +10,6 @@ from data import Database from .config import Conf -if TYPE_CHECKING: - from .LionBot import LionBot - - logger = logging.getLogger(__name__) @@ -21,12 +17,11 @@ class CrocBot(commands.Bot): def __init__(self, *args, config: Conf, data: Database, - lionbot: 'LionBot', **kwargs): + **kwargs): super().__init__(*args, **kwargs) self.config = config self.data = data self.pubsub = pubsub.PubSubPool(self) - self.lionbot = lionbot async def event_ready(self): logger.info(f"Logged in as {self.nick}. User id is {self.user_id}") diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index 48a5e065..19904f9f 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -24,6 +24,7 @@ from .errors import HandledException, SafeCancellation from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus if TYPE_CHECKING: + from meta.CrocBot import CrocBot from core.cog import CoreCog from core.config import ConfigCog from tracking.voice.cog import VoiceTrackerCog @@ -58,6 +59,7 @@ class LionBot(Bot): initial_extensions: List[str], web_client: ClientSession, app_ipc, testing_guilds: List[int] = [], system_monitor: Optional[SystemMonitor] = None, + crocbot: Optional['CrocBot'] = None, **kwargs ): kwargs.setdefault('tree_cls', LionTree) @@ -73,6 +75,8 @@ class LionBot(Bot): self.app_ipc = app_ipc self.translator = translator + self.crocbot = crocbot + self.system_monitor = system_monitor or SystemMonitor() self.monitor = ComponentMonitor('LionBot', self._monitor_status) self.system_monitor.add_component(self.monitor) diff --git a/src/meta/LionCog.py b/src/meta/LionCog.py index 39ca43aa..bc28f7c5 100644 --- a/src/meta/LionCog.py +++ b/src/meta/LionCog.py @@ -1,23 +1,36 @@ -from typing import Any +from functools import partial +from typing import Any, Callable, Optional from discord.ext.commands import Cog from discord.ext import commands as cmds +from twitchio.ext import commands +from twitchio.ext.commands import Command, Bot +from twitchio.ext.commands.meta import CogEvent class LionCog(Cog): # A set of other cogs that this cog depends on depends_on: set['LionCog'] = set() _placeholder_groups_: set[str] + _twitch_cmds_: dict[str, Command] + _twitch_events_: dict[str, CogEvent] + _twitch_events_loaded_: set[Callable] def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._placeholder_groups_ = set() + cls._twitch_cmds_ = {} + cls._twitch_events_ = {} for base in reversed(cls.__mro__): for elem, value in base.__dict__.items(): if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'): cls._placeholder_groups_.add(value.name) + elif isinstance(value, Command): + cls._twitch_cmds_[value.name] = value + elif isinstance(value, CogEvent): + cls._twitch_events_[value.name] = value def __new__(cls, *args: Any, **kwargs: Any): # Patch to ensure no placeholder groups are in the command list @@ -34,6 +47,51 @@ class LionCog(Cog): return await super()._inject(bot, *args, *kwargs) + def _load_twitch_methods(self, bot: Bot): + for name, command in self._twitch_cmds_.items(): + command._instance = self + command.cog = self + bot.add_command(command) + + for name, event in self._twitch_events_.items(): + callback = partial(event, self) + self._twitch_events_loaded_.add(callback) + bot.add_event(callback=callback, name=name) + + def _unload_twitch_methods(self, bot: Bot): + for name in self._twitch_cmds_: + bot.remove_command(name) + + for callback in self._twitch_events_loaded_: + bot.remove_event(callback=callback) + + self._twitch_events_loaded_.clear() + + @classmethod + def twitch_event(cls, event: Optional[str] = None): + def decorator(func) -> CogEvent: + event_name = event or func.__name__ + return CogEvent(name=event_name, func=func, module=cls.__module__) + return decorator + + async def cog_check(self, ctx): # type: ignore + """ + TwitchIO assumes cog_check is a coroutine, + so here we narrow the check to only a coroutine. + + The ctx maybe either be a twitch command context or a dpy context. + """ + if isinstance(ctx, cmds.Context): + return await self.cog_check_discord(ctx) + if isinstance(ctx, commands.Context): + return await self.cog_check_twitch(ctx) + + async def cog_check_discord(self, ctx: cmds.Context): + return True + + async def cog_check_twitch(self, ctx: commands.Context): + return True + @classmethod def placeholder_group(cls, group: cmds.HybridGroup): group._placeholder_group_ = True diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 8eec6d09..2061986c 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -26,20 +26,12 @@ active_discord = [ '.premium', '.streamalerts', '.test', -] - -active_twitch = [ + '.counters', '.nowdoing', '.shoutouts', - '.counters', '.tagstrings', ] - -def prepare(bot): - for ext in active_twitch: - bot.load_module(this_package + ext) - async def setup(bot): for ext in active_discord: await bot.load_extension(ext, package=this_package) diff --git a/src/modules/counters/__init__.py b/src/modules/counters/__init__.py index 8990eb41..30f4e6e5 100644 --- a/src/modules/counters/__init__.py +++ b/src/modules/counters/__init__.py @@ -4,10 +4,5 @@ logger = logging.getLogger(__name__) from .cog import CounterCog -def prepare(bot): - bot.add_cog(CounterCog(bot)) - async def setup(bot): - from .lion_cog import CounterCog - await bot.add_cog(CounterCog(bot)) diff --git a/src/modules/counters/cog.py b/src/modules/counters/cog.py index 0cdeba92..9b1dd032 100644 --- a/src/modules/counters/cog.py +++ b/src/modules/counters/cog.py @@ -3,11 +3,14 @@ from enum import Enum from typing import Optional from datetime import timedelta +import discord +from discord.ext import commands as cmds import twitchio from twitchio.ext import commands + from data.queries import ORDER -from meta import CrocBot +from meta import LionCog, LionBot, CrocBot from utils.lib import utc_now from . import logger from .data import CounterData @@ -22,10 +25,11 @@ class PERIOD(Enum): YEAR = ('this year', 'y', 'year', 'yearly') -class CounterCog(commands.Cog): - def __init__(self, bot: CrocBot): +class CounterCog(LionCog): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(CounterData()) + self.crocbot: CrocBot = bot.crocbot + self.data = bot.db.load_registry(CounterData()) self.loaded = asyncio.Event() @@ -33,9 +37,18 @@ class CounterCog(commands.Cog): self.counters = {} async def cog_load(self): + self._load_twitch_methods(self.crocbot) + await self.data.init() + await self.load_counters() self.loaded.set() + async def cog_unload(self): + self._unload_twitch_methods(self.crocbot) + + async def cog_check(self, ctx): + return True + async def load_counters(self): """ Initialise counter name cache. @@ -46,18 +59,6 @@ class CounterCog(commands.Cog): f"Loaded {len(self.counters)} counters." ) - async def ensure_loaded(self): - if not self.loaded.is_set(): - await self.cog_load() - - @commands.Cog.event('event_ready') # type: ignore - async def on_ready(self): - await self.ensure_loaded() - - async def cog_check(self, ctx): - await self.ensure_loaded() - return True - # Counters API async def fetch_counter(self, counter: str) -> CounterData.Counter: @@ -171,7 +172,7 @@ class CounterCog(commands.Cog): if period is PERIOD.ALL: start_time = None elif period is PERIOD.STREAM: - streams = await self.bot.fetch_streams(user_ids=[userid]) + streams = await self.crocbot.fetch_streams(user_ids=[userid]) if streams: stream = streams[0] start_time = stream.started_at @@ -199,7 +200,7 @@ class CounterCog(commands.Cog): lb = await self.leaderboard(counter, start_time=start_time) if lb: userids = list(lb.keys()) - users = await self.bot.fetch_users(ids=userids) + users = await self.crocbot.fetch_users(ids=userids) name_map = {user.id: user.display_name for user in users} parts = [] for userid, total in lb.items(): @@ -283,17 +284,9 @@ class CounterCog(commands.Cog): await ctx.reply(await self.formatted_lb('water', args, int(user.id))) @commands.command() - async def reload(self, ctx: commands.Context, *, args: str = ''): - if not (ctx.author.is_mod or ctx.author.is_broadcaster): - return - if not args: - await ctx.reply("Full reload not implemented yet.") - else: - try: - self.bot.reload_module(args) - except Exception: - logger.exception("Failed to reload") - await ctx.reply("Failed to reload module! Check console~") - else: - await ctx.reply("Reloaded!") + async def stuff(self, ctx: commands.Context, *, args: str = ''): + await ctx.reply(f"Stuff {args}") + @cmds.hybrid_command('water') + async def d_water_cmd(self, ctx): + await ctx.reply(repr(ctx)) diff --git a/src/modules/counters/lion_cog.py b/src/modules/counters/lion_cog.py deleted file mode 100644 index 123514b8..00000000 --- a/src/modules/counters/lion_cog.py +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio -from typing import Optional - -import discord -from discord.ext import commands as cmds -from discord import app_commands as appcmds - -from meta import LionBot, LionCog, LionContext -from meta.errors import UserInputError -from meta.logger import log_wrap -from utils.lib import utc_now -from data.conditions import NULL - -from . import logger -from .data import CounterData - - -class CounterCog(LionCog): - - def __init__(self, bot: LionBot): - self.bot = bot - - self.counter_cog = bot.crocbot.get_cog('CounterCog') diff --git a/src/modules/nowdoing/__init__.py b/src/modules/nowdoing/__init__.py index a2fad715..a0db8411 100644 --- a/src/modules/nowdoing/__init__.py +++ b/src/modules/nowdoing/__init__.py @@ -4,6 +4,5 @@ logger = logging.getLogger(__name__) from .cog import NowDoingCog -def prepare(bot): - logger.info("Preparing the nowdoing module.") - bot.add_cog(NowDoingCog(bot)) +async def setup(bot): + await bot.add_cog(NowDoingCog(bot)) diff --git a/src/modules/nowdoing/cog.py b/src/modules/nowdoing/cog.py index 068b16b9..19b12283 100644 --- a/src/modules/nowdoing/cog.py +++ b/src/modules/nowdoing/cog.py @@ -8,7 +8,8 @@ from attr import dataclass import twitchio from twitchio.ext import commands -from meta import CrocBot +from meta import CrocBot, LionCog +from meta.LionBot import LionBot from meta.sockets import Channel, register_channel from utils.lib import strfdelta, utc_now from . import logger @@ -78,10 +79,11 @@ class NowDoingChannel(Channel): }) -class NowDoingCog(commands.Cog): - def __init__(self, bot: CrocBot): +class NowDoingCog(LionCog): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(NowListData()) + self.crocbot = bot.crocbot + self.data = bot.db.load_registry(NowListData()) self.channel = NowDoingChannel(self) register_channel(self.channel.name, self.channel) @@ -94,21 +96,19 @@ class NowDoingCog(commands.Cog): await self.data.init() await self.load_tasks() + + self._load_twitch_methods(self.crocbot) self.loaded.set() - async def ensure_loaded(self): - """ - Hack because lib devs decided to remove async cog loading. - """ - if not self.loaded.is_set(): - await self.cog_load() - - @commands.Cog.event('event_ready') # type: ignore - async def on_ready(self): - await self.ensure_loaded() + async def cog_unload(self): + self.loaded.clear() + self.tasks.clear() + self._unload_twitch_methods(self.crocbot) async def cog_check(self, ctx): - await self.ensure_loaded() + if not self.loaded.is_set(): + await ctx.reply("Tasklists are still loading! Please wait a moment~") + return False return True async def load_tasks(self): @@ -130,6 +130,7 @@ class NowDoingCog(commands.Cog): @commands.command(aliases=['task', 'check']) async def now(self, ctx: commands.Context, *, args: Optional[str] = None): userid = int(ctx.author.id) + args = args.strip() if args else None if args: await self.data.Task.table.delete_where(userid=userid) task = await self.data.Task.create( diff --git a/src/modules/shoutouts/__init__.py b/src/modules/shoutouts/__init__.py index 875d0f52..4671f8c8 100644 --- a/src/modules/shoutouts/__init__.py +++ b/src/modules/shoutouts/__init__.py @@ -4,5 +4,5 @@ logger = logging.getLogger(__name__) from .cog import ShoutoutCog -def prepare(bot): - bot.add_cog(ShoutoutCog(bot)) +async def setup(bot): + await bot.add_cog(ShoutoutCog(bot)) diff --git a/src/modules/shoutouts/cog.py b/src/modules/shoutouts/cog.py index cc8343f6..1ebf56d7 100644 --- a/src/modules/shoutouts/cog.py +++ b/src/modules/shoutouts/cog.py @@ -4,50 +4,50 @@ from typing import Optional import twitchio from twitchio.ext import commands -from meta import CrocBot +from meta import CrocBot, LionBot, LionCog from utils.lib import replace_multiple from . import logger from .data import ShoutoutData -class ShoutoutCog(commands.Cog): +class ShoutoutCog(LionCog): # Future extension: channel defaults and config DEFAULT_SHOUTOUT = """ We think that {name} is a great streamer and you should check them out \ and drop a follow! \ They {areorwere} streaming {game} at {channel} """ - def __init__(self, bot: CrocBot): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(ShoutoutData()) + self.crocbot = bot.crocbot + self.data = bot.db.load_registry(ShoutoutData()) self.loaded = asyncio.Event() async def cog_load(self): await self.data.init() + self._load_twitch_methods(self.crocbot) self.loaded.set() - async def ensure_loaded(self): - if not self.loaded.is_set(): - await self.cog_load() - - @commands.Cog.event('event_ready') # type: ignore - async def on_ready(self): - await self.ensure_loaded() + async def cog_unload(self): + self.loaded.clear() + self._unload_twitch_methods(self.crocbot) async def cog_check(self, ctx): - await self.ensure_loaded() + if not self.loaded.is_set(): + await ctx.reply("Tasklists are still loading! Please wait a moment~") + return False return True async def format_shoutout(self, text: str, user: twitchio.User): - channels = await self.bot.fetch_channels([user.id]) + channels = await self.crocbot.fetch_channels([user.id]) if channels: channel = channels[0] game = channel.game_name or 'Unknown' else: game = 'Unknown' - streams = await self.bot.fetch_streams([user.id]) + streams = await self.crocbot.fetch_streams([user.id]) live = bool(streams) mapping = { diff --git a/src/modules/tagstrings/__init__.py b/src/modules/tagstrings/__init__.py index 51ee112b..c379d1b9 100644 --- a/src/modules/tagstrings/__init__.py +++ b/src/modules/tagstrings/__init__.py @@ -4,5 +4,5 @@ logger = logging.getLogger(__name__) from .cog import TagCog -def prepare(bot): - bot.add_cog(TagCog(bot)) +async def setup(bot): + await bot.add_cog(TagCog(bot)) diff --git a/src/modules/tagstrings/cog.py b/src/modules/tagstrings/cog.py index 9fbda103..f9463539 100644 --- a/src/modules/tagstrings/cog.py +++ b/src/modules/tagstrings/cog.py @@ -6,16 +6,17 @@ import difflib import twitchio from twitchio.ext import commands -from meta import CrocBot +from meta import CrocBot, LionBot, LionCog from utils.lib import utc_now from . import logger from .data import TagData -class TagCog(commands.Cog): - def __init__(self, bot: CrocBot): +class TagCog(LionCog): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(TagData()) + self.crocbot = bot.crocbot + self.data = bot.db.load_registry(TagData()) self.loaded = asyncio.Event() @@ -31,19 +32,24 @@ class TagCog(commands.Cog): self.tags.clear() self.tags.update(tags) + logger.info(f"Loaded {len(tags)} into cache.") async def cog_load(self): await self.data.init() await self.load_tags() + self._load_twitch_methods(self.crocbot) self.loaded.set() - async def ensure_loaded(self): - if not self.loaded.is_set(): - await self.cog_load() + async def cog_unload(self): + self.loaded.clear() + self.tags.clear() + self._unload_twitch_methods(self.crocbot) - @commands.Cog.event('event_ready') - async def on_ready(self): - await self.ensure_loaded() + async def cog_check(self, ctx): + if not self.loaded.is_set(): + await ctx.reply("Tasklists are still loading! Please wait a moment~") + return False + return True # API