From 2cf81c38e8cddf9b3fa3ce56473184ecd1a73053 Mon Sep 17 00:00:00 2001 From: Interitio Date: Fri, 6 Jun 2025 00:05:24 +1000 Subject: [PATCH] Add twitch auth module. --- requirements.txt | 2 + src/bot.py | 10 ++-- src/meta/LionBot.py | 16 +++++- src/twitch/__init__.py | 9 ++++ src/twitch/authclient.py | 50 +++++++++++++++++ src/twitch/authserver.py | 86 +++++++++++++++++++++++++++++ src/twitch/cog.py | 114 +++++++++++++++++++++++++++++++++++++++ src/twitch/data.py | 79 +++++++++++++++++++++++++++ src/twitch/lib.py | 0 src/twitch/userflow.py | 88 ++++++++++++++++++++++++++++++ 10 files changed, 450 insertions(+), 4 deletions(-) create mode 100644 src/twitch/__init__.py create mode 100644 src/twitch/authclient.py create mode 100644 src/twitch/authserver.py create mode 100644 src/twitch/cog.py create mode 100644 src/twitch/data.py create mode 100644 src/twitch/lib.py create mode 100644 src/twitch/userflow.py diff --git a/requirements.txt b/requirements.txt index 83073f4..f6562b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,5 @@ discord.py [voice] iso8601 psycopg[pool] pytz +twitchio +twitchAPI diff --git a/src/bot.py b/src/bot.py index 0b61a61..dbc99c8 100644 --- a/src/bot.py +++ b/src/bot.py @@ -4,6 +4,7 @@ import logging import aiohttp import discord from discord.ext import commands +from twitchAPI.twitch import Twitch from meta import LionBot, conf, sharding, appname from meta.app import shardname @@ -49,13 +50,15 @@ async def _data_monitor() -> ComponentStatus: async def main(): log_action_stack.set(("Initialising",)) - logger.info("Initialising StudyLion") + logger.info("Initialising LionBot") intents = discord.Intents.all() intents.members = True intents.message_content = True intents.presences = False + twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret']) + async with db.open(): version = await db.version() if version.version != DATA_VERSION: @@ -82,6 +85,7 @@ async def main(): help_command=None, proxy=conf.bot.get('proxy', None), chunk_guilds_at_startup=False, + twitch=twitch ) as lionbot: ctx_bot.set(lionbot) lionbot.system_monitor.add_component( @@ -89,11 +93,11 @@ async def main(): ) try: log_context.set(f"APP: {appname}") - logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) + logger.info("LionBot initialised, starting!", extra={'action': 'Starting'}) await lionbot.start(conf.bot['TOKEN']) except asyncio.CancelledError: log_context.set(f"APP: {appname}") - logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True) + logger.info("LionBot closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True) def _main(): diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index 32f39ee..5f8981e 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -9,6 +9,7 @@ from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError from discord.ext.commands.errors import CommandInvokeError, CheckFailure from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError from aiohttp import ClientSession +from twitchAPI.twitch import Twitch from data import Database from utils.lib import tabulate @@ -23,6 +24,8 @@ from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStat if TYPE_CHECKING: from core.cog import CoreCog + from twitch.cog import TwitchAuthCog + from modules.profiles.cog import ProfileCog logger = logging.getLogger(__name__) @@ -31,7 +34,9 @@ class LionBot(Bot): def __init__( self, *args, appname: str, shardname: str, db: Database, config: Conf, initial_extensions: List[str], web_client: ClientSession, - testing_guilds: List[int] = [], **kwargs + twitch: Twitch, + testing_guilds: List[int] = [], + **kwargs ): kwargs.setdefault('tree_cls', LionTree) super().__init__(*args, **kwargs) @@ -43,6 +48,7 @@ class LionBot(Bot): self.shardname = shardname # self.appdata = appdata self.config = config + self.twitch = twitch self.system_monitor = SystemMonitor() self.monitor = ComponentMonitor('LionBot', self._monitor_status) @@ -101,6 +107,14 @@ class LionBot(Bot): def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog': ... + @overload + def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog': + ... + + @overload + def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog': + ... + @overload def get_cog(self, name: str) -> Optional[Cog]: ... diff --git a/src/twitch/__init__.py b/src/twitch/__init__.py new file mode 100644 index 0000000..53e33e8 --- /dev/null +++ b/src/twitch/__init__.py @@ -0,0 +1,9 @@ +import logging + +logger = logging.getLogger(__name__) + +from .cog import TwitchAuthCog + +async def setup(bot): + await bot.add_cog(TwitchAuthCog(bot)) + diff --git a/src/twitch/authclient.py b/src/twitch/authclient.py new file mode 100644 index 0000000..509b080 --- /dev/null +++ b/src/twitch/authclient.py @@ -0,0 +1,50 @@ +""" +Testing client for the twitch AuthServer. +""" +import sys +import os + +sys.path.insert(0, os.path.join(os.getcwd())) +sys.path.insert(0, os.path.join(os.getcwd(), "src")) + +import asyncio +import aiohttp +from twitchAPI.twitch import Twitch +from twitchAPI.oauth import UserAuthenticator +from twitchAPI.type import AuthScope + +from meta.config import conf + + +URI = "http://localhost:3000/twiauth/confirm" +TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ] + +async def main(): + # Load in client id and secret + twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret']) + auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI) + url = auth.return_auth_url() + + # Post url to user + print(url) + + # Send listen request to server + # Wait for listen request + async with aiohttp.ClientSession() as session: + async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws: + await ws.send_json({'state': auth.state}) + result = await ws.receive_json() + + # Hopefully get back code, print the response + print(f"Recieved: {result}") + + # Authorise with code and client details + tokens = await auth.authenticate(user_token=result['code']) + if tokens: + token, refresh = tokens + await twitch.set_user_authentication(token, TARGET_SCOPE, refresh) + print(f"Authorised!") + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/twitch/authserver.py b/src/twitch/authserver.py new file mode 100644 index 0000000..b26c595 --- /dev/null +++ b/src/twitch/authserver.py @@ -0,0 +1,86 @@ +import logging +import uuid +import asyncio +from contextvars import ContextVar + +import aiohttp +from aiohttp import web + +logger = logging.getLogger(__name__) +reqid: ContextVar[str] = ContextVar('reqid', default='ROOT') + + +class AuthServer: + def __init__(self): + self.listeners = {} + + async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse: + args = request.query + if 'state' not in args: + raise web.HTTPBadRequest(text="No state provided.") + if args['state'] not in self.listeners: + raise web.HTTPBadRequest(text="Invalid state.") + self.listeners[args['state']].set_result(dict(args)) + return web.Response(text="Authorisation complete! You may now close this page and return to the application.") + + async def handle_listen_request(self, request: web.Request) -> web.StreamResponse: + _reqid = str(uuid.uuid1()) + reqid.set(_reqid) + + logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}") + + ws = web.WebSocketResponse() + await ws.prepare(request) + + # Get the listen request data + try: + listen_req = await ws.receive_json(timeout=60) + logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}") + if 'state' not in listen_req: + logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.") + raise web.HTTPBadRequest(text="Listen request must include state string.") + elif listen_req['state'] in self.listeners: + logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.") + raise web.HTTPBadRequest(text="Invalid state string.") + except ValueError: + logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.") + raise web.HTTPBadRequest(text="Request must be a JSON formatted string.") + except TypeError: + logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.") + raise web.HTTPBadRequest(text="Request must be a JSON formatted string.") + except asyncio.TimeoutError: + logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.") + raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.") + except Exception: + logger.exception(f"[reqid: {_reqid}] Unknown exception.") + raise web.HTTPInternalServerError() + + try: + fut = self.listeners[listen_req['state']] = asyncio.Future() + result = await asyncio.wait_for(fut, timeout=120) + except asyncio.TimeoutError: + logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.") + raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.") + finally: + self.listeners.pop(listen_req['state'], None) + + logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.") + await ws.send_json(result) + await ws.close() + logger.debug(f"[reqid: {_reqid}] Request completed handling.") + + return ws + +def main(argv): + app = web.Application() + server = AuthServer() + app.router.add_get("/twiauth/confirm", server.handle_twitch_callback) + app.router.add_get("/twiauth/listen", server.handle_listen_request) + + logger.info("App setup and configured. Starting now.") + web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080) + + +if __name__ == '__main__': + import sys + main(sys.argv) diff --git a/src/twitch/cog.py b/src/twitch/cog.py new file mode 100644 index 0000000..ac4f30d --- /dev/null +++ b/src/twitch/cog.py @@ -0,0 +1,114 @@ +import asyncio +from enum import Enum +from typing import Optional +from datetime import timedelta + +import discord +from discord.ext import commands as cmds + +from twitchAPI.oauth import UserAuthenticator +from twitchAPI.twitch import AuthType, Twitch +from twitchAPI.type import AuthScope +import twitchio +from twitchio.ext import commands + + +from data.queries import ORDER +from meta import LionCog, LionBot, CrocBot +from meta.LionContext import LionContext +from twitch.userflow import UserAuthFlow +from utils.lib import utc_now +from . import logger +from .data import TwitchAuthData + + +class TwitchAuthCog(LionCog): + DEFAULT_SCOPES = [] + + def __init__(self, bot: LionBot): + self.bot = bot + self.data = bot.db.load_registry(TwitchAuthData()) + + self.client_cache = {} + + async def cog_load(self): + await self.data.init() + + # ----- Auth API ----- + + async def fetch_client_for(self, userid: str): + authrow = await self.data.UserAuthRow.fetch(userid) + if authrow is None: + # TODO: Some user authentication error + self.client_cache.pop(userid, None) + raise ValueError("Requested user is not authenticated.") + if (twitch := self.client_cache.get(userid)) is None: + twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret']) + scopes = await self.data.UserAuthRow.get_scopes_for(userid) + authscopes = [AuthScope(scope) for scope in scopes] + await twitch.set_user_authentication(authrow.access_token, authscopes, authrow.refresh_token) + self.client_cache[userid] = twitch + return twitch + + async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool: + """ + Checks whether the given userid is authorised. + If 'scopes' is given, will also check the user has all of the given scopes. + """ + authrow = await self.data.UserAuthRow.fetch(userid) + if authrow: + if scopes: + has_scopes = await self.data.UserAuthRow.get_scopes_for(userid) + desired = {scope.value for scope in scopes} + has_auth = desired.issubset(has_scopes) + logger.info(f"Auth check for `{userid}`: Requested scopes {desired}, has scopes {has_scopes}. Passed: {has_auth}") + else: + has_auth = True + else: + has_auth = False + return has_auth + + async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []): + """ + Start the user authentication flow for the given userid. + Will request the given scopes along with the default ones and any existing scopes. + """ + self.client_cache.pop(userid, None) + existing_strs = await self.data.UserAuthRow.get_scopes_for(userid) + existing = map(AuthScope, existing_strs) + to_request = set(existing).union(scopes) + return await self.start_auth(to_request) + + async def start_auth(self, scopes = []): + # TODO: Work out a way to just clone the current twitch object + # Or can we otherwise build UserAuthenticator without app auth? + twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret']) + auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri']) + flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url']) + await flow.setup() + + return flow + + # ----- Commands ----- + @cmds.hybrid_command(name='auth') + async def cmd_auth(self, ctx: LionContext): + if ctx.interaction: + await ctx.interaction.response.defer(ephemeral=True) + flow = await self.start_auth() + await ctx.reply(flow.auth.return_auth_url()) + await flow.run() + await ctx.reply("Authentication Complete!") + + @cmds.hybrid_command(name='modauth') + async def cmd_modauth(self, ctx: LionContext): + if ctx.interaction: + await ctx.interaction.response.defer(ephemeral=True) + scopes = [ + AuthScope.MODERATOR_READ_FOLLOWERS, + AuthScope.CHANNEL_READ_REDEMPTIONS, + AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES, + ] + flow = await self.start_auth(scopes=scopes) + await ctx.reply(flow.auth.return_auth_url()) + await flow.run() + await ctx.reply("Authentication Complete!") diff --git a/src/twitch/data.py b/src/twitch/data.py new file mode 100644 index 0000000..9f0f474 --- /dev/null +++ b/src/twitch/data.py @@ -0,0 +1,79 @@ +import datetime as dt + +from data import Registry, RowModel, Table +from data.columns import Integer, String, Timestamp + + +class TwitchAuthData(Registry): + class UserAuthRow(RowModel): + """ + Schema + ------ + CREATE TABLE twitch_user_auth( + userid TEXT PRIMARY KEY, + access_token TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + refresh_token TEXT NOT NULL, + obtained_at TIMESTAMPTZ + ); + """ + _tablename_ = 'twitch_user_auth' + _cache_ = {} + + userid = Integer(primary=True) + access_token = String() + refresh_token = String() + expires_at = Timestamp() + obtained_at = Timestamp() + + @classmethod + async def update_user_auth( + cls, userid: str, token: str, refresh: str, + expires_at: dt.datetime, obtained_at: dt.datetime, + scopes: list[str] + ): + if cls._connector is None: + raise ValueError("Attempting to use uninitialised Registry.") + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + # Clear row for this userid + await cls.table.delete_where(userid=userid) + + # Insert new user row + row = await cls.create( + userid=userid, + access_token=token, + refresh_token=refresh, + expires_at=expires_at, + obtained_at=obtained_at + ) + # Insert new scope rows + if scopes: + await TwitchAuthData.user_scopes.insert_many( + ('userid', 'scope'), + *((userid, scope) for scope in scopes) + ) + return row + + @classmethod + async def get_scopes_for(cls, userid: str) -> list[str]: + """ + Get a list of scopes stored for the given user. + Will return an empty list if the user is not authenticated. + """ + rows = await TwitchAuthData.user_scopes.select_where(userid=userid) + + return [row['scope'] for row in rows] if rows else [] + + + """ + Schema + ------ + CREATE TABLE twitch_user_scopes( + userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE, + scope TEXT + ); + CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid); + """ + user_scopes = Table('twitch_user_scopes') diff --git a/src/twitch/lib.py b/src/twitch/lib.py new file mode 100644 index 0000000..e69de29 diff --git a/src/twitch/userflow.py b/src/twitch/userflow.py new file mode 100644 index 0000000..8a73614 --- /dev/null +++ b/src/twitch/userflow.py @@ -0,0 +1,88 @@ +from typing import Optional +import datetime as dt + +from aiohttp import web + +import aiohttp +from twitchAPI.twitch import Twitch +from twitchAPI.oauth import UserAuthenticator, validate_token +from twitchAPI.type import AuthType +from twitchio.client import asyncio + +from meta.errors import SafeCancellation +from utils.lib import utc_now +from .data import TwitchAuthData +from . import logger + +class UserAuthFlow: + auth: UserAuthenticator + data: TwitchAuthData + auth_ws: str + + def __init__(self, data, auth, auth_ws): + self.auth = auth + self.data = data + self.auth_ws = auth_ws + + self._setup_done = asyncio.Event() + self._comm_task: Optional[asyncio.Task] = None + + async def setup(self): + """ + Establishes websocket connection to the AuthServer, + and requests listening for the given state. + Propagates any exceptions that occur during connection setup. + """ + if self._setup_done.is_set(): + raise ValueError("UserAuthFlow is already set up.") + self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate') + await self._setup_done.wait() + if self._comm_task.done() and (exc := self._comm_task.exception()): + raise exc + + async def _communicate(self): + async with aiohttp.ClientSession() as session: + async with session.ws_connect(self.auth_ws) as ws: + await ws.send_json({'state': self.auth.state}) + self._setup_done.set() + return await ws.receive_json() + + async def run(self) -> TwitchAuthData.UserAuthRow: + if not self._setup_done.is_set(): + raise ValueError("Cannot run UserAuthFlow before setup.") + if self._comm_task is None: + raise ValueError("UserAuthFlow running with no comm task! This should be impossible.") + + result = await self._comm_task + if result.get('error', None): + # TODO Custom auth errors + # This is only documented to occur when the user denies the auth + raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}") + + if result.get('state', None) != self.auth.state: + # This should never happen unless the authserver has its wires crossed somehow, + # or the connection has been tampered with. + # TODO: Consider terminating for safety in this case? Or at least refusing more auth requests. + logger.critical( + f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG." + ) + raise SafeCancellation( + "Could not complete authentication! Invalid server response." + ) + + # Now assume result has a valid code + # Exchange code for an auth token and a refresh token + # Ignore type here, authenticate returns None if a callback function has been given. + token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore + + # Fetch the associated userid and basic info + v_result = await validate_token(token) + userid = v_result['user_id'] + expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in']) + + # Save auth data + return await self.data.UserAuthRow.update_user_auth( + userid=userid, token=token, refresh=refresh, + expires_at=expiry, obtained_at=utc_now(), + scopes=[scope.value for scope in self.auth.scopes] + )