Add twitch auth module.

This commit is contained in:
2025-06-06 00:05:24 +10:00
parent 9625dec1e4
commit 2cf81c38e8
10 changed files with 450 additions and 4 deletions

View File

@@ -4,3 +4,5 @@ discord.py [voice]
iso8601 iso8601
psycopg[pool] psycopg[pool]
pytz pytz
twitchio
twitchAPI

View File

@@ -4,6 +4,7 @@ import logging
import aiohttp import aiohttp
import discord import discord
from discord.ext import commands from discord.ext import commands
from twitchAPI.twitch import Twitch
from meta import LionBot, conf, sharding, appname from meta import LionBot, conf, sharding, appname
from meta.app import shardname from meta.app import shardname
@@ -49,13 +50,15 @@ async def _data_monitor() -> ComponentStatus:
async def main(): async def main():
log_action_stack.set(("Initialising",)) log_action_stack.set(("Initialising",))
logger.info("Initialising StudyLion") logger.info("Initialising LionBot")
intents = discord.Intents.all() intents = discord.Intents.all()
intents.members = True intents.members = True
intents.message_content = True intents.message_content = True
intents.presences = False intents.presences = False
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
async with db.open(): async with db.open():
version = await db.version() version = await db.version()
if version.version != DATA_VERSION: if version.version != DATA_VERSION:
@@ -82,6 +85,7 @@ async def main():
help_command=None, help_command=None,
proxy=conf.bot.get('proxy', None), proxy=conf.bot.get('proxy', None),
chunk_guilds_at_startup=False, chunk_guilds_at_startup=False,
twitch=twitch
) as lionbot: ) as lionbot:
ctx_bot.set(lionbot) ctx_bot.set(lionbot)
lionbot.system_monitor.add_component( lionbot.system_monitor.add_component(
@@ -89,11 +93,11 @@ async def main():
) )
try: try:
log_context.set(f"APP: {appname}") log_context.set(f"APP: {appname}")
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) logger.info("LionBot initialised, starting!", extra={'action': 'Starting'})
await lionbot.start(conf.bot['TOKEN']) await lionbot.start(conf.bot['TOKEN'])
except asyncio.CancelledError: except asyncio.CancelledError:
log_context.set(f"APP: {appname}") log_context.set(f"APP: {appname}")
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True) logger.info("LionBot closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
def _main(): def _main():

View File

@@ -9,6 +9,7 @@ from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
from aiohttp import ClientSession from aiohttp import ClientSession
from twitchAPI.twitch import Twitch
from data import Database from data import Database
from utils.lib import tabulate from utils.lib import tabulate
@@ -23,6 +24,8 @@ from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStat
if TYPE_CHECKING: if TYPE_CHECKING:
from core.cog import CoreCog from core.cog import CoreCog
from twitch.cog import TwitchAuthCog
from modules.profiles.cog import ProfileCog
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,7 +34,9 @@ class LionBot(Bot):
def __init__( def __init__(
self, *args, appname: str, shardname: str, db: Database, config: Conf, self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession, initial_extensions: List[str], web_client: ClientSession,
testing_guilds: List[int] = [], **kwargs twitch: Twitch,
testing_guilds: List[int] = [],
**kwargs
): ):
kwargs.setdefault('tree_cls', LionTree) kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -43,6 +48,7 @@ class LionBot(Bot):
self.shardname = shardname self.shardname = shardname
# self.appdata = appdata # self.appdata = appdata
self.config = config self.config = config
self.twitch = twitch
self.system_monitor = SystemMonitor() self.system_monitor = SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status) self.monitor = ComponentMonitor('LionBot', self._monitor_status)
@@ -101,6 +107,14 @@ class LionBot(Bot):
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog': def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
... ...
@overload
def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog':
...
@overload
def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog':
...
@overload @overload
def get_cog(self, name: str) -> Optional[Cog]: def get_cog(self, name: str) -> Optional[Cog]:
... ...

9
src/twitch/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
import logging
logger = logging.getLogger(__name__)
from .cog import TwitchAuthCog
async def setup(bot):
await bot.add_cog(TwitchAuthCog(bot))

50
src/twitch/authclient.py Normal file
View File

@@ -0,0 +1,50 @@
"""
Testing client for the twitch AuthServer.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
import asyncio
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.type import AuthScope
from meta.config import conf
URI = "http://localhost:3000/twiauth/confirm"
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
async def main():
# Load in client id and secret
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
url = auth.return_auth_url()
# Post url to user
print(url)
# Send listen request to server
# Wait for listen request
async with aiohttp.ClientSession() as session:
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
await ws.send_json({'state': auth.state})
result = await ws.receive_json()
# Hopefully get back code, print the response
print(f"Recieved: {result}")
# Authorise with code and client details
tokens = await auth.authenticate(user_token=result['code'])
if tokens:
token, refresh = tokens
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
print(f"Authorised!")
if __name__ == '__main__':
asyncio.run(main())

86
src/twitch/authserver.py Normal file
View File

@@ -0,0 +1,86 @@
import logging
import uuid
import asyncio
from contextvars import ContextVar
import aiohttp
from aiohttp import web
logger = logging.getLogger(__name__)
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
class AuthServer:
def __init__(self):
self.listeners = {}
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
args = request.query
if 'state' not in args:
raise web.HTTPBadRequest(text="No state provided.")
if args['state'] not in self.listeners:
raise web.HTTPBadRequest(text="Invalid state.")
self.listeners[args['state']].set_result(dict(args))
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
_reqid = str(uuid.uuid1())
reqid.set(_reqid)
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
ws = web.WebSocketResponse()
await ws.prepare(request)
# Get the listen request data
try:
listen_req = await ws.receive_json(timeout=60)
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
if 'state' not in listen_req:
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
raise web.HTTPBadRequest(text="Listen request must include state string.")
elif listen_req['state'] in self.listeners:
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
raise web.HTTPBadRequest(text="Invalid state string.")
except ValueError:
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except TypeError:
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
except Exception:
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
raise web.HTTPInternalServerError()
try:
fut = self.listeners[listen_req['state']] = asyncio.Future()
result = await asyncio.wait_for(fut, timeout=120)
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
finally:
self.listeners.pop(listen_req['state'], None)
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
await ws.send_json(result)
await ws.close()
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
return ws
def main(argv):
app = web.Application()
server = AuthServer()
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
app.router.add_get("/twiauth/listen", server.handle_listen_request)
logger.info("App setup and configured. Starting now.")
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
if __name__ == '__main__':
import sys
main(sys.argv)

114
src/twitch/cog.py Normal file
View File

@@ -0,0 +1,114 @@
import asyncio
from enum import Enum
from typing import Optional
from datetime import timedelta
import discord
from discord.ext import commands as cmds
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.twitch import AuthType, Twitch
from twitchAPI.type import AuthScope
import twitchio
from twitchio.ext import commands
from data.queries import ORDER
from meta import LionCog, LionBot, CrocBot
from meta.LionContext import LionContext
from twitch.userflow import UserAuthFlow
from utils.lib import utc_now
from . import logger
from .data import TwitchAuthData
class TwitchAuthCog(LionCog):
DEFAULT_SCOPES = []
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(TwitchAuthData())
self.client_cache = {}
async def cog_load(self):
await self.data.init()
# ----- Auth API -----
async def fetch_client_for(self, userid: str):
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow is None:
# TODO: Some user authentication error
self.client_cache.pop(userid, None)
raise ValueError("Requested user is not authenticated.")
if (twitch := self.client_cache.get(userid)) is None:
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
scopes = await self.data.UserAuthRow.get_scopes_for(userid)
authscopes = [AuthScope(scope) for scope in scopes]
await twitch.set_user_authentication(authrow.access_token, authscopes, authrow.refresh_token)
self.client_cache[userid] = twitch
return twitch
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
"""
Checks whether the given userid is authorised.
If 'scopes' is given, will also check the user has all of the given scopes.
"""
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow:
if scopes:
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
desired = {scope.value for scope in scopes}
has_auth = desired.issubset(has_scopes)
logger.info(f"Auth check for `{userid}`: Requested scopes {desired}, has scopes {has_scopes}. Passed: {has_auth}")
else:
has_auth = True
else:
has_auth = False
return has_auth
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
"""
Start the user authentication flow for the given userid.
Will request the given scopes along with the default ones and any existing scopes.
"""
self.client_cache.pop(userid, None)
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
existing = map(AuthScope, existing_strs)
to_request = set(existing).union(scopes)
return await self.start_auth(to_request)
async def start_auth(self, scopes = []):
# TODO: Work out a way to just clone the current twitch object
# Or can we otherwise build UserAuthenticator without app auth?
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
await flow.setup()
return flow
# ----- Commands -----
@cmds.hybrid_command(name='auth')
async def cmd_auth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
flow = await self.start_auth()
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")
@cmds.hybrid_command(name='modauth')
async def cmd_modauth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
scopes = [
AuthScope.MODERATOR_READ_FOLLOWERS,
AuthScope.CHANNEL_READ_REDEMPTIONS,
AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES,
]
flow = await self.start_auth(scopes=scopes)
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")

79
src/twitch/data.py Normal file
View File

@@ -0,0 +1,79 @@
import datetime as dt
from data import Registry, RowModel, Table
from data.columns import Integer, String, Timestamp
class TwitchAuthData(Registry):
class UserAuthRow(RowModel):
"""
Schema
------
CREATE TABLE twitch_user_auth(
userid TEXT PRIMARY KEY,
access_token TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
refresh_token TEXT NOT NULL,
obtained_at TIMESTAMPTZ
);
"""
_tablename_ = 'twitch_user_auth'
_cache_ = {}
userid = Integer(primary=True)
access_token = String()
refresh_token = String()
expires_at = Timestamp()
obtained_at = Timestamp()
@classmethod
async def update_user_auth(
cls, userid: str, token: str, refresh: str,
expires_at: dt.datetime, obtained_at: dt.datetime,
scopes: list[str]
):
if cls._connector is None:
raise ValueError("Attempting to use uninitialised Registry.")
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# Clear row for this userid
await cls.table.delete_where(userid=userid)
# Insert new user row
row = await cls.create(
userid=userid,
access_token=token,
refresh_token=refresh,
expires_at=expires_at,
obtained_at=obtained_at
)
# Insert new scope rows
if scopes:
await TwitchAuthData.user_scopes.insert_many(
('userid', 'scope'),
*((userid, scope) for scope in scopes)
)
return row
@classmethod
async def get_scopes_for(cls, userid: str) -> list[str]:
"""
Get a list of scopes stored for the given user.
Will return an empty list if the user is not authenticated.
"""
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
return [row['scope'] for row in rows] if rows else []
"""
Schema
------
CREATE TABLE twitch_user_scopes(
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
scope TEXT
);
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
"""
user_scopes = Table('twitch_user_scopes')

0
src/twitch/lib.py Normal file
View File

88
src/twitch/userflow.py Normal file
View File

@@ -0,0 +1,88 @@
from typing import Optional
import datetime as dt
from aiohttp import web
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator, validate_token
from twitchAPI.type import AuthType
from twitchio.client import asyncio
from meta.errors import SafeCancellation
from utils.lib import utc_now
from .data import TwitchAuthData
from . import logger
class UserAuthFlow:
auth: UserAuthenticator
data: TwitchAuthData
auth_ws: str
def __init__(self, data, auth, auth_ws):
self.auth = auth
self.data = data
self.auth_ws = auth_ws
self._setup_done = asyncio.Event()
self._comm_task: Optional[asyncio.Task] = None
async def setup(self):
"""
Establishes websocket connection to the AuthServer,
and requests listening for the given state.
Propagates any exceptions that occur during connection setup.
"""
if self._setup_done.is_set():
raise ValueError("UserAuthFlow is already set up.")
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
await self._setup_done.wait()
if self._comm_task.done() and (exc := self._comm_task.exception()):
raise exc
async def _communicate(self):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(self.auth_ws) as ws:
await ws.send_json({'state': self.auth.state})
self._setup_done.set()
return await ws.receive_json()
async def run(self) -> TwitchAuthData.UserAuthRow:
if not self._setup_done.is_set():
raise ValueError("Cannot run UserAuthFlow before setup.")
if self._comm_task is None:
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
result = await self._comm_task
if result.get('error', None):
# TODO Custom auth errors
# This is only documented to occur when the user denies the auth
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
if result.get('state', None) != self.auth.state:
# This should never happen unless the authserver has its wires crossed somehow,
# or the connection has been tampered with.
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
logger.critical(
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
)
raise SafeCancellation(
"Could not complete authentication! Invalid server response."
)
# Now assume result has a valid code
# Exchange code for an auth token and a refresh token
# Ignore type here, authenticate returns None if a callback function has been given.
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
# Fetch the associated userid and basic info
v_result = await validate_token(token)
userid = v_result['user_id']
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
# Save auth data
return await self.data.UserAuthRow.update_user_auth(
userid=userid, token=token, refresh=refresh,
expires_at=expiry, obtained_at=utc_now(),
scopes=[scope.value for scope in self.auth.scopes]
)