From b7e4acfee2cbde62950e1ab9886c12d436adcf28 Mon Sep 17 00:00:00 2001 From: Interitio Date: Fri, 6 Sep 2024 10:57:07 +1000 Subject: [PATCH 01/10] Merge disc and twitchio Cogs. --- src/bot.py | 25 +++++++-------- src/meta/CrocBot.py | 7 +--- src/meta/LionBot.py | 4 +++ src/meta/LionCog.py | 41 +++++++++++++++++++++++- src/modules/__init__.py | 2 +- src/modules/counters/__init__.py | 5 --- src/modules/counters/cog.py | 55 ++++++++++++++------------------ src/modules/counters/lion_cog.py | 23 ------------- 8 files changed, 81 insertions(+), 81 deletions(-) delete mode 100644 src/modules/counters/lion_cog.py diff --git a/src/bot.py b/src/bot.py index 4292eefa..852530f1 100644 --- a/src/bot.py +++ b/src/bot.py @@ -80,6 +80,14 @@ async def main(): websockets.serve(sockets.root_handler, '', conf.wserver['port']) ) + crocbot = CrocBot( + config=conf, + data=db, + prefix='!', + initial_channels=conf.croccy.getlist('initial_channels'), + token=conf.croccy['token'], + ) + lionbot = await stack.enter_async_context( LionBot( command_prefix='!', @@ -104,26 +112,15 @@ async def main(): translator=translator, chunk_guilds_at_startup=False, system_monitor=system_monitor, + crocbot=crocbot, ) ) - crocbot = CrocBot( - config=conf, - data=db, - prefix='!', - initial_channels=conf.croccy.getlist('initial_channels'), - token=conf.croccy['token'], - lionbot=lionbot - ) - lionbot.crocbot = crocbot - - crocbot.load_module('modules') - crocstart = asyncio.create_task(start_croccy(crocbot)) lionstart = asyncio.create_task(start_lion(lionbot)) await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED) - crocstart.cancel() - lionstart.cancel() + # crocstart.cancel() + # lionstart.cancel() async def start_lion(lionbot): ctx_bot.set(lionbot) diff --git a/src/meta/CrocBot.py b/src/meta/CrocBot.py index ae632a99..a4742424 100644 --- a/src/meta/CrocBot.py +++ b/src/meta/CrocBot.py @@ -10,10 +10,6 @@ from data import Database from .config import Conf -if TYPE_CHECKING: - from .LionBot import LionBot - - logger = logging.getLogger(__name__) @@ -21,12 +17,11 @@ class CrocBot(commands.Bot): def __init__(self, *args, config: Conf, data: Database, - lionbot: 'LionBot', **kwargs): + **kwargs): super().__init__(*args, **kwargs) self.config = config self.data = data self.pubsub = pubsub.PubSubPool(self) - self.lionbot = lionbot async def event_ready(self): logger.info(f"Logged in as {self.nick}. User id is {self.user_id}") diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index 48a5e065..19904f9f 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -24,6 +24,7 @@ from .errors import HandledException, SafeCancellation from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus if TYPE_CHECKING: + from meta.CrocBot import CrocBot from core.cog import CoreCog from core.config import ConfigCog from tracking.voice.cog import VoiceTrackerCog @@ -58,6 +59,7 @@ class LionBot(Bot): initial_extensions: List[str], web_client: ClientSession, app_ipc, testing_guilds: List[int] = [], system_monitor: Optional[SystemMonitor] = None, + crocbot: Optional['CrocBot'] = None, **kwargs ): kwargs.setdefault('tree_cls', LionTree) @@ -73,6 +75,8 @@ class LionBot(Bot): self.app_ipc = app_ipc self.translator = translator + self.crocbot = crocbot + self.system_monitor = system_monitor or SystemMonitor() self.monitor = ComponentMonitor('LionBot', self._monitor_status) self.system_monitor.add_component(self.monitor) diff --git a/src/meta/LionCog.py b/src/meta/LionCog.py index 39ca43aa..75eea8dc 100644 --- a/src/meta/LionCog.py +++ b/src/meta/LionCog.py @@ -1,23 +1,35 @@ -from typing import Any +from functools import partial +from typing import Any, Callable, Optional from discord.ext.commands import Cog from discord.ext import commands as cmds +from twitchio.ext.commands import Command, Bot +from twitchio.ext.commands.meta import CogEvent class LionCog(Cog): # A set of other cogs that this cog depends on depends_on: set['LionCog'] = set() _placeholder_groups_: set[str] + _twitch_cmds_: dict[str, Command] + _twitch_events_: dict[str, CogEvent] + _twitch_events_loaded_: set[Callable] def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._placeholder_groups_ = set() + cls._twitch_cmds_ = {} + cls._twitch_events_ = {} for base in reversed(cls.__mro__): for elem, value in base.__dict__.items(): if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'): cls._placeholder_groups_.add(value.name) + elif isinstance(value, Command): + cls._twitch_cmds_[value.name] = value + elif isinstance(value, CogEvent): + cls._twitch_events_[value.name] = value def __new__(cls, *args: Any, **kwargs: Any): # Patch to ensure no placeholder groups are in the command list @@ -34,6 +46,33 @@ class LionCog(Cog): return await super()._inject(bot, *args, *kwargs) + def _load_twitch_methods(self, bot: Bot): + for name, command in self._twitch_cmds_.items(): + command._instance = self + command.cog = self + bot.add_command(command) + + for name, event in self._twitch_events_.items(): + callback = partial(event, self) + self._twitch_events_loaded_.add(callback) + bot.add_event(callback=callback, name=name) + + def _unload_twitch_methods(self, bot: Bot): + for name in self._twitch_cmds_: + bot.remove_command(name) + + for callback in self._twitch_events_loaded_: + bot.remove_event(callback=callback) + + self._twitch_events_loaded_.clear() + + @classmethod + def twitch_event(cls, event: Optional[str] = None): + def decorator(func) -> CogEvent: + event_name = event or func.__name__ + return CogEvent(name=event_name, func=func, module=cls.__module__) + return decorator + @classmethod def placeholder_group(cls, group: cmds.HybridGroup): group._placeholder_group_ = True diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 8eec6d09..b69e5289 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -26,12 +26,12 @@ active_discord = [ '.premium', '.streamalerts', '.test', + '.counters', ] active_twitch = [ '.nowdoing', '.shoutouts', - '.counters', '.tagstrings', ] diff --git a/src/modules/counters/__init__.py b/src/modules/counters/__init__.py index 8990eb41..30f4e6e5 100644 --- a/src/modules/counters/__init__.py +++ b/src/modules/counters/__init__.py @@ -4,10 +4,5 @@ logger = logging.getLogger(__name__) from .cog import CounterCog -def prepare(bot): - bot.add_cog(CounterCog(bot)) - async def setup(bot): - from .lion_cog import CounterCog - await bot.add_cog(CounterCog(bot)) diff --git a/src/modules/counters/cog.py b/src/modules/counters/cog.py index 0cdeba92..9b1dd032 100644 --- a/src/modules/counters/cog.py +++ b/src/modules/counters/cog.py @@ -3,11 +3,14 @@ from enum import Enum from typing import Optional from datetime import timedelta +import discord +from discord.ext import commands as cmds import twitchio from twitchio.ext import commands + from data.queries import ORDER -from meta import CrocBot +from meta import LionCog, LionBot, CrocBot from utils.lib import utc_now from . import logger from .data import CounterData @@ -22,10 +25,11 @@ class PERIOD(Enum): YEAR = ('this year', 'y', 'year', 'yearly') -class CounterCog(commands.Cog): - def __init__(self, bot: CrocBot): +class CounterCog(LionCog): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(CounterData()) + self.crocbot: CrocBot = bot.crocbot + self.data = bot.db.load_registry(CounterData()) self.loaded = asyncio.Event() @@ -33,9 +37,18 @@ class CounterCog(commands.Cog): self.counters = {} async def cog_load(self): + self._load_twitch_methods(self.crocbot) + await self.data.init() + await self.load_counters() self.loaded.set() + async def cog_unload(self): + self._unload_twitch_methods(self.crocbot) + + async def cog_check(self, ctx): + return True + async def load_counters(self): """ Initialise counter name cache. @@ -46,18 +59,6 @@ class CounterCog(commands.Cog): f"Loaded {len(self.counters)} counters." ) - async def ensure_loaded(self): - if not self.loaded.is_set(): - await self.cog_load() - - @commands.Cog.event('event_ready') # type: ignore - async def on_ready(self): - await self.ensure_loaded() - - async def cog_check(self, ctx): - await self.ensure_loaded() - return True - # Counters API async def fetch_counter(self, counter: str) -> CounterData.Counter: @@ -171,7 +172,7 @@ class CounterCog(commands.Cog): if period is PERIOD.ALL: start_time = None elif period is PERIOD.STREAM: - streams = await self.bot.fetch_streams(user_ids=[userid]) + streams = await self.crocbot.fetch_streams(user_ids=[userid]) if streams: stream = streams[0] start_time = stream.started_at @@ -199,7 +200,7 @@ class CounterCog(commands.Cog): lb = await self.leaderboard(counter, start_time=start_time) if lb: userids = list(lb.keys()) - users = await self.bot.fetch_users(ids=userids) + users = await self.crocbot.fetch_users(ids=userids) name_map = {user.id: user.display_name for user in users} parts = [] for userid, total in lb.items(): @@ -283,17 +284,9 @@ class CounterCog(commands.Cog): await ctx.reply(await self.formatted_lb('water', args, int(user.id))) @commands.command() - async def reload(self, ctx: commands.Context, *, args: str = ''): - if not (ctx.author.is_mod or ctx.author.is_broadcaster): - return - if not args: - await ctx.reply("Full reload not implemented yet.") - else: - try: - self.bot.reload_module(args) - except Exception: - logger.exception("Failed to reload") - await ctx.reply("Failed to reload module! Check console~") - else: - await ctx.reply("Reloaded!") + async def stuff(self, ctx: commands.Context, *, args: str = ''): + await ctx.reply(f"Stuff {args}") + @cmds.hybrid_command('water') + async def d_water_cmd(self, ctx): + await ctx.reply(repr(ctx)) diff --git a/src/modules/counters/lion_cog.py b/src/modules/counters/lion_cog.py deleted file mode 100644 index 123514b8..00000000 --- a/src/modules/counters/lion_cog.py +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio -from typing import Optional - -import discord -from discord.ext import commands as cmds -from discord import app_commands as appcmds - -from meta import LionBot, LionCog, LionContext -from meta.errors import UserInputError -from meta.logger import log_wrap -from utils.lib import utc_now -from data.conditions import NULL - -from . import logger -from .data import CounterData - - -class CounterCog(LionCog): - - def __init__(self, bot: LionBot): - self.bot = bot - - self.counter_cog = bot.crocbot.get_cog('CounterCog') From 41f755795f029f086730049c71a2760a9b4852af Mon Sep 17 00:00:00 2001 From: Interitio Date: Fri, 6 Sep 2024 10:59:47 +1000 Subject: [PATCH 02/10] feat: Start twitch user auth module. --- src/twitch/__init__.py | 9 +++++++++ src/twitch/authserver.py | 7 +++++++ src/twitch/cog.py | 31 +++++++++++++++++++++++++++++++ src/twitch/data.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+) create mode 100644 src/twitch/__init__.py create mode 100644 src/twitch/authserver.py create mode 100644 src/twitch/cog.py create mode 100644 src/twitch/data.py diff --git a/src/twitch/__init__.py b/src/twitch/__init__.py new file mode 100644 index 00000000..53e33e89 --- /dev/null +++ b/src/twitch/__init__.py @@ -0,0 +1,9 @@ +import logging + +logger = logging.getLogger(__name__) + +from .cog import TwitchAuthCog + +async def setup(bot): + await bot.add_cog(TwitchAuthCog(bot)) + diff --git a/src/twitch/authserver.py b/src/twitch/authserver.py new file mode 100644 index 00000000..e87c433c --- /dev/null +++ b/src/twitch/authserver.py @@ -0,0 +1,7 @@ +""" +We want to open an aiohttp server and listen on a configured port. +When we get a request, we validate it to be 'of twitch form', +parse out the error or access token, state, etc, and then pass that information on. + +Passing on maybe done through webhook server? +""" diff --git a/src/twitch/cog.py b/src/twitch/cog.py new file mode 100644 index 00000000..5dfdc39d --- /dev/null +++ b/src/twitch/cog.py @@ -0,0 +1,31 @@ +import asyncio +from enum import Enum +from typing import Optional +from datetime import timedelta + +import discord +from discord.ext import commands as cmds +import twitchio +from twitchio.ext import commands + + +from data.queries import ORDER +from meta import LionCog, LionBot, CrocBot +from utils.lib import utc_now +from . import logger +from .data import TwitchAuthData + + +class TwitchAuthCog(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + self.data = bot.db.load_registry(TwitchAuthData()) + + async def cog_load(self): + await self.data.init() + + # ----- Auth API ----- + + async def fetch_client_for(self, userid: int): + ... + diff --git a/src/twitch/data.py b/src/twitch/data.py new file mode 100644 index 00000000..eed13589 --- /dev/null +++ b/src/twitch/data.py @@ -0,0 +1,28 @@ +from data import Registry, RowModel +from data.columns import Integer, String, Timestamp + + +class TwitchAuthData(Registry): + class UserAuthRow(RowModel): + """ + Schema + ------ + CREATE TABLE twitch_tokens( + userid BIGINT PRIMARY KEY, + access_token TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + refresh_token TEXT NOT NULL, + obtained_at TIMESTAMPTZ + ); + + """ + _tablename_ = 'twitch_tokens' + _cache_ = {} + + userid = Integer(primary=True) + access_token = String() + expires_at = Timestamp() + refresh_token = String() + obtained_at = Timestamp() + +# TODO: Scopes From 5216107db05aaf22551853b48a1a1f4d2296e26d Mon Sep 17 00:00:00 2001 From: Interitio Date: Fri, 6 Sep 2024 11:00:36 +1000 Subject: [PATCH 03/10] routine: Update requirements. --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index af59fce9..eb86670c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,6 @@ pillow python-dateutil bidict frozendict +TwitchIO +websockets +TwitchAPI From 99bb1958a8e64fc42973b5ec5a38c950fe4e9b9d Mon Sep 17 00:00:00 2001 From: Interitio Date: Sun, 8 Sep 2024 16:17:34 +1000 Subject: [PATCH 04/10] (LionCog): Split check types. --- src/meta/LionCog.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/meta/LionCog.py b/src/meta/LionCog.py index 75eea8dc..bc28f7c5 100644 --- a/src/meta/LionCog.py +++ b/src/meta/LionCog.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Optional from discord.ext.commands import Cog from discord.ext import commands as cmds +from twitchio.ext import commands from twitchio.ext.commands import Command, Bot from twitchio.ext.commands.meta import CogEvent @@ -73,6 +74,24 @@ class LionCog(Cog): return CogEvent(name=event_name, func=func, module=cls.__module__) return decorator + async def cog_check(self, ctx): # type: ignore + """ + TwitchIO assumes cog_check is a coroutine, + so here we narrow the check to only a coroutine. + + The ctx maybe either be a twitch command context or a dpy context. + """ + if isinstance(ctx, cmds.Context): + return await self.cog_check_discord(ctx) + if isinstance(ctx, commands.Context): + return await self.cog_check_twitch(ctx) + + async def cog_check_discord(self, ctx: cmds.Context): + return True + + async def cog_check_twitch(self, ctx: commands.Context): + return True + @classmethod def placeholder_group(cls, group: cmds.HybridGroup): group._placeholder_group_ = True From 85c7aeb3b6c797594ec2e77498042df6ba148e28 Mon Sep 17 00:00:00 2001 From: Interitio Date: Sun, 8 Sep 2024 16:35:37 +1000 Subject: [PATCH 05/10] (nowdoing): Migrate to merged LionCog. --- src/modules/__init__.py | 2 +- src/modules/nowdoing/__init__.py | 5 ++--- src/modules/nowdoing/cog.py | 31 ++++++++++++++++--------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/modules/__init__.py b/src/modules/__init__.py index b69e5289..83d889d8 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -27,10 +27,10 @@ active_discord = [ '.streamalerts', '.test', '.counters', + '.nowdoing', ] active_twitch = [ - '.nowdoing', '.shoutouts', '.tagstrings', ] diff --git a/src/modules/nowdoing/__init__.py b/src/modules/nowdoing/__init__.py index a2fad715..a0db8411 100644 --- a/src/modules/nowdoing/__init__.py +++ b/src/modules/nowdoing/__init__.py @@ -4,6 +4,5 @@ logger = logging.getLogger(__name__) from .cog import NowDoingCog -def prepare(bot): - logger.info("Preparing the nowdoing module.") - bot.add_cog(NowDoingCog(bot)) +async def setup(bot): + await bot.add_cog(NowDoingCog(bot)) diff --git a/src/modules/nowdoing/cog.py b/src/modules/nowdoing/cog.py index 068b16b9..19b12283 100644 --- a/src/modules/nowdoing/cog.py +++ b/src/modules/nowdoing/cog.py @@ -8,7 +8,8 @@ from attr import dataclass import twitchio from twitchio.ext import commands -from meta import CrocBot +from meta import CrocBot, LionCog +from meta.LionBot import LionBot from meta.sockets import Channel, register_channel from utils.lib import strfdelta, utc_now from . import logger @@ -78,10 +79,11 @@ class NowDoingChannel(Channel): }) -class NowDoingCog(commands.Cog): - def __init__(self, bot: CrocBot): +class NowDoingCog(LionCog): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(NowListData()) + self.crocbot = bot.crocbot + self.data = bot.db.load_registry(NowListData()) self.channel = NowDoingChannel(self) register_channel(self.channel.name, self.channel) @@ -94,21 +96,19 @@ class NowDoingCog(commands.Cog): await self.data.init() await self.load_tasks() + + self._load_twitch_methods(self.crocbot) self.loaded.set() - async def ensure_loaded(self): - """ - Hack because lib devs decided to remove async cog loading. - """ - if not self.loaded.is_set(): - await self.cog_load() - - @commands.Cog.event('event_ready') # type: ignore - async def on_ready(self): - await self.ensure_loaded() + async def cog_unload(self): + self.loaded.clear() + self.tasks.clear() + self._unload_twitch_methods(self.crocbot) async def cog_check(self, ctx): - await self.ensure_loaded() + if not self.loaded.is_set(): + await ctx.reply("Tasklists are still loading! Please wait a moment~") + return False return True async def load_tasks(self): @@ -130,6 +130,7 @@ class NowDoingCog(commands.Cog): @commands.command(aliases=['task', 'check']) async def now(self, ctx: commands.Context, *, args: Optional[str] = None): userid = int(ctx.author.id) + args = args.strip() if args else None if args: await self.data.Task.table.delete_where(userid=userid) task = await self.data.Task.create( From 75ab3d58cbae16e98f4e390341d8a42ea2dfbd70 Mon Sep 17 00:00:00 2001 From: Interitio Date: Sun, 8 Sep 2024 16:47:59 +1000 Subject: [PATCH 06/10] (shoutouts): Migrate to merged LionCog. --- src/modules/__init__.py | 2 +- src/modules/shoutouts/__init__.py | 4 ++-- src/modules/shoutouts/cog.py | 28 ++++++++++++++-------------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 83d889d8..203eac9e 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -28,10 +28,10 @@ active_discord = [ '.test', '.counters', '.nowdoing', + '.shoutouts', ] active_twitch = [ - '.shoutouts', '.tagstrings', ] diff --git a/src/modules/shoutouts/__init__.py b/src/modules/shoutouts/__init__.py index 875d0f52..4671f8c8 100644 --- a/src/modules/shoutouts/__init__.py +++ b/src/modules/shoutouts/__init__.py @@ -4,5 +4,5 @@ logger = logging.getLogger(__name__) from .cog import ShoutoutCog -def prepare(bot): - bot.add_cog(ShoutoutCog(bot)) +async def setup(bot): + await bot.add_cog(ShoutoutCog(bot)) diff --git a/src/modules/shoutouts/cog.py b/src/modules/shoutouts/cog.py index cc8343f6..1ebf56d7 100644 --- a/src/modules/shoutouts/cog.py +++ b/src/modules/shoutouts/cog.py @@ -4,50 +4,50 @@ from typing import Optional import twitchio from twitchio.ext import commands -from meta import CrocBot +from meta import CrocBot, LionBot, LionCog from utils.lib import replace_multiple from . import logger from .data import ShoutoutData -class ShoutoutCog(commands.Cog): +class ShoutoutCog(LionCog): # Future extension: channel defaults and config DEFAULT_SHOUTOUT = """ We think that {name} is a great streamer and you should check them out \ and drop a follow! \ They {areorwere} streaming {game} at {channel} """ - def __init__(self, bot: CrocBot): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(ShoutoutData()) + self.crocbot = bot.crocbot + self.data = bot.db.load_registry(ShoutoutData()) self.loaded = asyncio.Event() async def cog_load(self): await self.data.init() + self._load_twitch_methods(self.crocbot) self.loaded.set() - async def ensure_loaded(self): - if not self.loaded.is_set(): - await self.cog_load() - - @commands.Cog.event('event_ready') # type: ignore - async def on_ready(self): - await self.ensure_loaded() + async def cog_unload(self): + self.loaded.clear() + self._unload_twitch_methods(self.crocbot) async def cog_check(self, ctx): - await self.ensure_loaded() + if not self.loaded.is_set(): + await ctx.reply("Tasklists are still loading! Please wait a moment~") + return False return True async def format_shoutout(self, text: str, user: twitchio.User): - channels = await self.bot.fetch_channels([user.id]) + channels = await self.crocbot.fetch_channels([user.id]) if channels: channel = channels[0] game = channel.game_name or 'Unknown' else: game = 'Unknown' - streams = await self.bot.fetch_streams([user.id]) + streams = await self.crocbot.fetch_streams([user.id]) live = bool(streams) mapping = { From 970661fe05995ad4981f6253fcdbf89ed29afd61 Mon Sep 17 00:00:00 2001 From: Interitio Date: Sun, 8 Sep 2024 16:55:55 +1000 Subject: [PATCH 07/10] (tags): Migrated to merged LionCog. --- src/modules/__init__.py | 8 -------- src/modules/tagstrings/__init__.py | 4 ++-- src/modules/tagstrings/cog.py | 26 ++++++++++++++++---------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 203eac9e..2061986c 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -29,17 +29,9 @@ active_discord = [ '.counters', '.nowdoing', '.shoutouts', -] - -active_twitch = [ '.tagstrings', ] - -def prepare(bot): - for ext in active_twitch: - bot.load_module(this_package + ext) - async def setup(bot): for ext in active_discord: await bot.load_extension(ext, package=this_package) diff --git a/src/modules/tagstrings/__init__.py b/src/modules/tagstrings/__init__.py index 51ee112b..c379d1b9 100644 --- a/src/modules/tagstrings/__init__.py +++ b/src/modules/tagstrings/__init__.py @@ -4,5 +4,5 @@ logger = logging.getLogger(__name__) from .cog import TagCog -def prepare(bot): - bot.add_cog(TagCog(bot)) +async def setup(bot): + await bot.add_cog(TagCog(bot)) diff --git a/src/modules/tagstrings/cog.py b/src/modules/tagstrings/cog.py index 9fbda103..f9463539 100644 --- a/src/modules/tagstrings/cog.py +++ b/src/modules/tagstrings/cog.py @@ -6,16 +6,17 @@ import difflib import twitchio from twitchio.ext import commands -from meta import CrocBot +from meta import CrocBot, LionBot, LionCog from utils.lib import utc_now from . import logger from .data import TagData -class TagCog(commands.Cog): - def __init__(self, bot: CrocBot): +class TagCog(LionCog): + def __init__(self, bot: LionBot): self.bot = bot - self.data = bot.data.load_registry(TagData()) + self.crocbot = bot.crocbot + self.data = bot.db.load_registry(TagData()) self.loaded = asyncio.Event() @@ -31,19 +32,24 @@ class TagCog(commands.Cog): self.tags.clear() self.tags.update(tags) + logger.info(f"Loaded {len(tags)} into cache.") async def cog_load(self): await self.data.init() await self.load_tags() + self._load_twitch_methods(self.crocbot) self.loaded.set() - async def ensure_loaded(self): - if not self.loaded.is_set(): - await self.cog_load() + async def cog_unload(self): + self.loaded.clear() + self.tags.clear() + self._unload_twitch_methods(self.crocbot) - @commands.Cog.event('event_ready') - async def on_ready(self): - await self.ensure_loaded() + async def cog_check(self, ctx): + if not self.loaded.is_set(): + await ctx.reply("Tasklists are still loading! Please wait a moment~") + return False + return True # API From 44d6d7749448cb1098c229a795dbdb11b9dd0cd4 Mon Sep 17 00:00:00 2001 From: Interitio Date: Mon, 23 Sep 2024 15:56:18 +1000 Subject: [PATCH 08/10] feat(twitch): Add authentication server. --- src/twitch/authclient.py | 50 ++++++++++++++++++++++ src/twitch/authserver.py | 91 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 135 insertions(+), 6 deletions(-) create mode 100644 src/twitch/authclient.py diff --git a/src/twitch/authclient.py b/src/twitch/authclient.py new file mode 100644 index 00000000..509b080c --- /dev/null +++ b/src/twitch/authclient.py @@ -0,0 +1,50 @@ +""" +Testing client for the twitch AuthServer. +""" +import sys +import os + +sys.path.insert(0, os.path.join(os.getcwd())) +sys.path.insert(0, os.path.join(os.getcwd(), "src")) + +import asyncio +import aiohttp +from twitchAPI.twitch import Twitch +from twitchAPI.oauth import UserAuthenticator +from twitchAPI.type import AuthScope + +from meta.config import conf + + +URI = "http://localhost:3000/twiauth/confirm" +TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ] + +async def main(): + # Load in client id and secret + twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret']) + auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI) + url = auth.return_auth_url() + + # Post url to user + print(url) + + # Send listen request to server + # Wait for listen request + async with aiohttp.ClientSession() as session: + async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws: + await ws.send_json({'state': auth.state}) + result = await ws.receive_json() + + # Hopefully get back code, print the response + print(f"Recieved: {result}") + + # Authorise with code and client details + tokens = await auth.authenticate(user_token=result['code']) + if tokens: + token, refresh = tokens + await twitch.set_user_authentication(token, TARGET_SCOPE, refresh) + print(f"Authorised!") + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/twitch/authserver.py b/src/twitch/authserver.py index e87c433c..b26c5953 100644 --- a/src/twitch/authserver.py +++ b/src/twitch/authserver.py @@ -1,7 +1,86 @@ -""" -We want to open an aiohttp server and listen on a configured port. -When we get a request, we validate it to be 'of twitch form', -parse out the error or access token, state, etc, and then pass that information on. +import logging +import uuid +import asyncio +from contextvars import ContextVar -Passing on maybe done through webhook server? -""" +import aiohttp +from aiohttp import web + +logger = logging.getLogger(__name__) +reqid: ContextVar[str] = ContextVar('reqid', default='ROOT') + + +class AuthServer: + def __init__(self): + self.listeners = {} + + async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse: + args = request.query + if 'state' not in args: + raise web.HTTPBadRequest(text="No state provided.") + if args['state'] not in self.listeners: + raise web.HTTPBadRequest(text="Invalid state.") + self.listeners[args['state']].set_result(dict(args)) + return web.Response(text="Authorisation complete! You may now close this page and return to the application.") + + async def handle_listen_request(self, request: web.Request) -> web.StreamResponse: + _reqid = str(uuid.uuid1()) + reqid.set(_reqid) + + logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}") + + ws = web.WebSocketResponse() + await ws.prepare(request) + + # Get the listen request data + try: + listen_req = await ws.receive_json(timeout=60) + logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}") + if 'state' not in listen_req: + logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.") + raise web.HTTPBadRequest(text="Listen request must include state string.") + elif listen_req['state'] in self.listeners: + logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.") + raise web.HTTPBadRequest(text="Invalid state string.") + except ValueError: + logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.") + raise web.HTTPBadRequest(text="Request must be a JSON formatted string.") + except TypeError: + logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.") + raise web.HTTPBadRequest(text="Request must be a JSON formatted string.") + except asyncio.TimeoutError: + logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.") + raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.") + except Exception: + logger.exception(f"[reqid: {_reqid}] Unknown exception.") + raise web.HTTPInternalServerError() + + try: + fut = self.listeners[listen_req['state']] = asyncio.Future() + result = await asyncio.wait_for(fut, timeout=120) + except asyncio.TimeoutError: + logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.") + raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.") + finally: + self.listeners.pop(listen_req['state'], None) + + logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.") + await ws.send_json(result) + await ws.close() + logger.debug(f"[reqid: {_reqid}] Request completed handling.") + + return ws + +def main(argv): + app = web.Application() + server = AuthServer() + app.router.add_get("/twiauth/confirm", server.handle_twitch_callback) + app.router.add_get("/twiauth/listen", server.handle_listen_request) + + logger.info("App setup and configured. Starting now.") + web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080) + + +if __name__ == '__main__': + import sys + main(sys.argv) From caa907b6d9db4759ea030c0dafa37aecbba6e1de Mon Sep 17 00:00:00 2001 From: Interitio Date: Mon, 23 Sep 2024 15:56:51 +1000 Subject: [PATCH 09/10] feat(twitch): Add UserAuthFlow for user auth. --- src/twitch/userflow.py | 84 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 src/twitch/userflow.py diff --git a/src/twitch/userflow.py b/src/twitch/userflow.py new file mode 100644 index 00000000..11b7fdfd --- /dev/null +++ b/src/twitch/userflow.py @@ -0,0 +1,84 @@ +from typing import Optional + +from aiohttp import web + +import aiohttp +from twitchAPI.twitch import Twitch +from twitchAPI.oauth import UserAuthenticator, validate_token +from twitchAPI.type import AuthType +from twitchio.client import asyncio + +from meta.errors import SafeCancellation +from .data import TwitchAuthData +from . import logger + +class UserAuthFlow: + auth: UserAuthenticator + data: TwitchAuthData + auth_ws: str + + def __init__(self, data, auth, auth_ws): + self.auth = auth + self.data = data + self.auth_ws = auth_ws + + self._setup_done = asyncio.Event() + self._comm_task: Optional[asyncio.Task] = None + + async def setup(self): + """ + Establishes websocket connection to the AuthServer, + and requests listening for the given state. + Propagates any exceptions that occur during connection setup. + """ + if self._setup_done.is_set(): + raise ValueError("UserAuthFlow is already set up.") + self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate') + await self._setup_done.wait() + if self._comm_task.done() and (exc := self._comm_task.exception()): + raise exc + + async def _communicate(self): + async with aiohttp.ClientSession() as session: + async with session.ws_connect(self.auth_ws) as ws: + await ws.send_json({'state': self.auth.state}) + self._setup_done.set() + return await ws.receive_json() + + async def run(self): + if not self._setup_done.is_set(): + raise ValueError("Cannot run UserAuthFlow before setup.") + if self._comm_task is None: + raise ValueError("UserAuthFlow running with no comm task! This should be impossible.") + + result = await self._comm_task + if result['error']: + # TODO Custom auth errors + # This is only documented to occure when the user denies the auth + raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}") + + if result['state'] != self.auth.state: + # This should never happen unless the authserver has its wires crossed somehow, + # or the connection has been tampered with. + # TODO: Consider terminating for safety in this case? Or at least refusing more auth requests. + logger.critical( + f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG." + ) + raise SafeCancellation( + "Could not complete authentication! Invalid server response." + ) + + # Now assume result has a valid code + # Exchange code for an auth token and a refresh token + # Ignore type here, authenticate returns None if a callback function has been given. + token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore + + # Fetch the associated userid and basic info + v_result = await validate_token(token) + userid = v_result['user_id'] + + # Save auth data + + async def save_auth(self, userid: str, token: str, refresh: str, scopes: list[str]): + if not self.data._conn: + raise ValueError("Provided registry must be connected.") From 9c738ecb9136cf546aec50269f4b6d86fa143a9b Mon Sep 17 00:00:00 2001 From: Interitio Date: Wed, 25 Sep 2024 02:34:10 +1000 Subject: [PATCH 10/10] feat(twitch): Add basic user authentication flow. --- data/schema.sql | 18 ++++++++++++ src/bot.py | 1 + src/twitch/cog.py | 53 ++++++++++++++++++++++++++++++++++ src/twitch/data.py | 65 +++++++++++++++++++++++++++++++++++++----- src/twitch/lib.py | 0 src/twitch/userflow.py | 16 +++++++---- 6 files changed, 140 insertions(+), 13 deletions(-) create mode 100644 src/twitch/lib.py diff --git a/data/schema.sql b/data/schema.sql index 504ae859..345a997e 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -1485,6 +1485,24 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name -- }}} +-- Twitch User Auth {{{ +CREATE TABLE twitch_user_auth( + userid TEXT PRIMARY KEY, + access_token TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + refresh_token TEXT NOT NULL, + obtained_at TIMESTAMPTZ +); + + +CREATE TABLE twitch_user_scopes( + userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE, + scope TEXT +); +CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid); + +-- }}} + -- Analytics Data {{{ CREATE SCHEMA "analytics"; diff --git a/src/bot.py b/src/bot.py index 852530f1..db0b312e 100644 --- a/src/bot.py +++ b/src/bot.py @@ -98,6 +98,7 @@ async def main(): config=conf, initial_extensions=[ 'utils', 'core', 'analytics', + 'twitch', 'modules', 'babel', 'tracking.voice', 'tracking.text', diff --git a/src/twitch/cog.py b/src/twitch/cog.py index 5dfdc39d..b3742b0a 100644 --- a/src/twitch/cog.py +++ b/src/twitch/cog.py @@ -5,18 +5,26 @@ from datetime import timedelta import discord from discord.ext import commands as cmds + +from twitchAPI.oauth import UserAuthenticator +from twitchAPI.twitch import AuthType, Twitch +from twitchAPI.type import AuthScope import twitchio from twitchio.ext import commands from data.queries import ORDER from meta import LionCog, LionBot, CrocBot +from meta.LionContext import LionContext +from twitch.userflow import UserAuthFlow from utils.lib import utc_now from . import logger from .data import TwitchAuthData class TwitchAuthCog(LionCog): + DEFAULT_SCOPES = [] + def __init__(self, bot: LionBot): self.bot = bot self.data = bot.db.load_registry(TwitchAuthData()) @@ -29,3 +37,48 @@ class TwitchAuthCog(LionCog): async def fetch_client_for(self, userid: int): ... + async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool: + """ + Checks whether the given userid is authorised. + If 'scopes' is given, will also check the user has all of the given scopes. + """ + authrow = await self.data.UserAuthRow.fetch(userid) + if authrow: + if scopes: + has_scopes = await self.data.UserAuthRow.get_scopes_for(userid) + has_auth = set(map(str, scopes)).issubset(has_scopes) + else: + has_auth = True + else: + has_auth = False + return has_auth + + async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []): + """ + Start the user authentication flow for the given userid. + Will request the given scopes along with the default ones and any existing scopes. + """ + existing_strs = await self.data.UserAuthRow.get_scopes_for(userid) + existing = map(AuthScope, existing_strs) + to_request = set(existing).union(scopes) + return await self.start_auth(to_request) + + async def start_auth(self, scopes = []): + # TODO: Work out a way to just clone the current twitch object + # Or can we otherwise build UserAuthenticator without app auth? + twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret']) + auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri']) + flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url']) + await flow.setup() + + return flow + + # ----- Commands ----- + @cmds.hybrid_command(name='auth') + async def cmd_auth(self, ctx: LionContext): + if ctx.interaction: + await ctx.interaction.response.defer(ephemeral=True) + flow = await self.start_auth() + await ctx.reply(flow.auth.return_auth_url()) + await flow.run() + await ctx.reply("Authentication Complete!") diff --git a/src/twitch/data.py b/src/twitch/data.py index eed13589..94b45c37 100644 --- a/src/twitch/data.py +++ b/src/twitch/data.py @@ -1,4 +1,6 @@ -from data import Registry, RowModel +import datetime as dt + +from data import Registry, RowModel, Table from data.columns import Integer, String, Timestamp @@ -7,22 +9,71 @@ class TwitchAuthData(Registry): """ Schema ------ - CREATE TABLE twitch_tokens( - userid BIGINT PRIMARY KEY, + CREATE TABLE twitch_user_auth( + userid TEXT PRIMARY KEY, access_token TEXT NOT NULL, expires_at TIMESTAMPTZ NOT NULL, refresh_token TEXT NOT NULL, obtained_at TIMESTAMPTZ ); - """ - _tablename_ = 'twitch_tokens' + _tablename_ = 'twitch_user_auth' _cache_ = {} userid = Integer(primary=True) access_token = String() - expires_at = Timestamp() refresh_token = String() + expires_at = Timestamp() obtained_at = Timestamp() -# TODO: Scopes + @classmethod + async def update_user_auth( + cls, userid: str, token: str, refresh: str, + expires_at: dt.datetime, obtained_at: dt.datetime, + scopes: list[str] + ): + if cls._connector is None: + raise ValueError("Attempting to use uninitialised Registry.") + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + # Clear row for this userid + await cls.table.delete_where(userid=userid) + + # Insert new user row + row = await cls.create( + userid=userid, + access_token=token, + refresh_token=refresh, + expires_at=expires_at, + obtained_at=obtained_at + ) + # Insert new scope rows + if scopes: + await TwitchAuthData.user_scopes.insert_many( + ('userid', 'scope'), + *((userid, scope) for scope in scopes) + ) + return row + + @classmethod + async def get_scopes_for(cls, userid: str) -> list[str]: + """ + Get a list of scopes stored for the given user. + Will return an empty list if the user is not authenticated. + """ + rows = await TwitchAuthData.user_scopes.select_where(userid=userid) + + return [row.scope for row in rows] if rows else [] + + + """ + Schema + ------ + CREATE TABLE twitch_user_scopes( + userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE, + scope TEXT + ); + CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid); + """ + user_scopes = Table('twitch_token_scopes') diff --git a/src/twitch/lib.py b/src/twitch/lib.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twitch/userflow.py b/src/twitch/userflow.py index 11b7fdfd..ce7d20dc 100644 --- a/src/twitch/userflow.py +++ b/src/twitch/userflow.py @@ -1,4 +1,5 @@ from typing import Optional +import datetime as dt from aiohttp import web @@ -9,6 +10,7 @@ from twitchAPI.type import AuthType from twitchio.client import asyncio from meta.errors import SafeCancellation +from utils.lib import utc_now from .data import TwitchAuthData from . import logger @@ -52,12 +54,12 @@ class UserAuthFlow: raise ValueError("UserAuthFlow running with no comm task! This should be impossible.") result = await self._comm_task - if result['error']: + if result.get('error', None): # TODO Custom auth errors # This is only documented to occure when the user denies the auth raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}") - if result['state'] != self.auth.state: + if result.get('state', None) != self.auth.state: # This should never happen unless the authserver has its wires crossed somehow, # or the connection has been tampered with. # TODO: Consider terminating for safety in this case? Or at least refusing more auth requests. @@ -76,9 +78,11 @@ class UserAuthFlow: # Fetch the associated userid and basic info v_result = await validate_token(token) userid = v_result['user_id'] + expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in']) # Save auth data - - async def save_auth(self, userid: str, token: str, refresh: str, scopes: list[str]): - if not self.data._conn: - raise ValueError("Provided registry must be connected.") + return await self.data.UserAuthRow.update_user_auth( + userid=userid, token=token, refresh=refresh, + expires_at=expiry, obtained_at=utc_now(), + scopes=[scope.value for scope in self.auth.scopes] + )