diff --git a/data/schema.sql b/data/schema.sql index 504ae859..345a997e 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -1485,6 +1485,24 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name -- }}} +-- Twitch User Auth {{{ +CREATE TABLE twitch_user_auth( + userid TEXT PRIMARY KEY, + access_token TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + refresh_token TEXT NOT NULL, + obtained_at TIMESTAMPTZ +); + + +CREATE TABLE twitch_user_scopes( + userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE, + scope TEXT +); +CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid); + +-- }}} + -- Analytics Data {{{ CREATE SCHEMA "analytics"; diff --git a/src/twitch/cog.py b/src/twitch/cog.py index 5dfdc39d..b3742b0a 100644 --- a/src/twitch/cog.py +++ b/src/twitch/cog.py @@ -5,18 +5,26 @@ from datetime import timedelta import discord from discord.ext import commands as cmds + +from twitchAPI.oauth import UserAuthenticator +from twitchAPI.twitch import AuthType, Twitch +from twitchAPI.type import AuthScope import twitchio from twitchio.ext import commands from data.queries import ORDER from meta import LionCog, LionBot, CrocBot +from meta.LionContext import LionContext +from twitch.userflow import UserAuthFlow from utils.lib import utc_now from . import logger from .data import TwitchAuthData class TwitchAuthCog(LionCog): + DEFAULT_SCOPES = [] + def __init__(self, bot: LionBot): self.bot = bot self.data = bot.db.load_registry(TwitchAuthData()) @@ -29,3 +37,48 @@ class TwitchAuthCog(LionCog): async def fetch_client_for(self, userid: int): ... + async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool: + """ + Checks whether the given userid is authorised. + If 'scopes' is given, will also check the user has all of the given scopes. + """ + authrow = await self.data.UserAuthRow.fetch(userid) + if authrow: + if scopes: + has_scopes = await self.data.UserAuthRow.get_scopes_for(userid) + has_auth = set(map(str, scopes)).issubset(has_scopes) + else: + has_auth = True + else: + has_auth = False + return has_auth + + async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []): + """ + Start the user authentication flow for the given userid. + Will request the given scopes along with the default ones and any existing scopes. + """ + existing_strs = await self.data.UserAuthRow.get_scopes_for(userid) + existing = map(AuthScope, existing_strs) + to_request = set(existing).union(scopes) + return await self.start_auth(to_request) + + async def start_auth(self, scopes = []): + # TODO: Work out a way to just clone the current twitch object + # Or can we otherwise build UserAuthenticator without app auth? + twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret']) + auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri']) + flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url']) + await flow.setup() + + return flow + + # ----- Commands ----- + @cmds.hybrid_command(name='auth') + async def cmd_auth(self, ctx: LionContext): + if ctx.interaction: + await ctx.interaction.response.defer(ephemeral=True) + flow = await self.start_auth() + await ctx.reply(flow.auth.return_auth_url()) + await flow.run() + await ctx.reply("Authentication Complete!") diff --git a/src/twitch/data.py b/src/twitch/data.py index eed13589..94b45c37 100644 --- a/src/twitch/data.py +++ b/src/twitch/data.py @@ -1,4 +1,6 @@ -from data import Registry, RowModel +import datetime as dt + +from data import Registry, RowModel, Table from data.columns import Integer, String, Timestamp @@ -7,22 +9,71 @@ class TwitchAuthData(Registry): """ Schema ------ - CREATE TABLE twitch_tokens( - userid BIGINT PRIMARY KEY, + CREATE TABLE twitch_user_auth( + userid TEXT PRIMARY KEY, access_token TEXT NOT NULL, expires_at TIMESTAMPTZ NOT NULL, refresh_token TEXT NOT NULL, obtained_at TIMESTAMPTZ ); - """ - _tablename_ = 'twitch_tokens' + _tablename_ = 'twitch_user_auth' _cache_ = {} userid = Integer(primary=True) access_token = String() - expires_at = Timestamp() refresh_token = String() + expires_at = Timestamp() obtained_at = Timestamp() -# TODO: Scopes + @classmethod + async def update_user_auth( + cls, userid: str, token: str, refresh: str, + expires_at: dt.datetime, obtained_at: dt.datetime, + scopes: list[str] + ): + if cls._connector is None: + raise ValueError("Attempting to use uninitialised Registry.") + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + # Clear row for this userid + await cls.table.delete_where(userid=userid) + + # Insert new user row + row = await cls.create( + userid=userid, + access_token=token, + refresh_token=refresh, + expires_at=expires_at, + obtained_at=obtained_at + ) + # Insert new scope rows + if scopes: + await TwitchAuthData.user_scopes.insert_many( + ('userid', 'scope'), + *((userid, scope) for scope in scopes) + ) + return row + + @classmethod + async def get_scopes_for(cls, userid: str) -> list[str]: + """ + Get a list of scopes stored for the given user. + Will return an empty list if the user is not authenticated. + """ + rows = await TwitchAuthData.user_scopes.select_where(userid=userid) + + return [row.scope for row in rows] if rows else [] + + + """ + Schema + ------ + CREATE TABLE twitch_user_scopes( + userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE, + scope TEXT + ); + CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid); + """ + user_scopes = Table('twitch_token_scopes') diff --git a/src/twitch/lib.py b/src/twitch/lib.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twitch/userflow.py b/src/twitch/userflow.py index 11b7fdfd..ce7d20dc 100644 --- a/src/twitch/userflow.py +++ b/src/twitch/userflow.py @@ -1,4 +1,5 @@ from typing import Optional +import datetime as dt from aiohttp import web @@ -9,6 +10,7 @@ from twitchAPI.type import AuthType from twitchio.client import asyncio from meta.errors import SafeCancellation +from utils.lib import utc_now from .data import TwitchAuthData from . import logger @@ -52,12 +54,12 @@ class UserAuthFlow: raise ValueError("UserAuthFlow running with no comm task! This should be impossible.") result = await self._comm_task - if result['error']: + if result.get('error', None): # TODO Custom auth errors # This is only documented to occure when the user denies the auth raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}") - if result['state'] != self.auth.state: + if result.get('state', None) != self.auth.state: # This should never happen unless the authserver has its wires crossed somehow, # or the connection has been tampered with. # TODO: Consider terminating for safety in this case? Or at least refusing more auth requests. @@ -76,9 +78,11 @@ class UserAuthFlow: # Fetch the associated userid and basic info v_result = await validate_token(token) userid = v_result['user_id'] + expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in']) # Save auth data - - async def save_auth(self, userid: str, token: str, refresh: str, scopes: list[str]): - if not self.data._conn: - raise ValueError("Provided registry must be connected.") + return await self.data.UserAuthRow.update_user_auth( + userid=userid, token=token, refresh=refresh, + expires_at=expiry, obtained_at=utc_now(), + scopes=[scope.value for scope in self.auth.scopes] + )