Compare commits
1 Commits
feat-simpl
...
feat-auth
| Author | SHA1 | Date | |
|---|---|---|---|
| d8c3d50800 |
@@ -1485,6 +1485,24 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name
|
|||||||
-- }}}
|
-- }}}
|
||||||
|
|
||||||
|
|
||||||
|
-- Twitch User Auth {{{
|
||||||
|
CREATE TABLE twitch_user_auth(
|
||||||
|
userid TEXT PRIMARY KEY,
|
||||||
|
access_token TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
refresh_token TEXT NOT NULL,
|
||||||
|
obtained_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE twitch_user_scopes(
|
||||||
|
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
scope TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||||
|
|
||||||
|
-- }}}
|
||||||
|
|
||||||
|
|
||||||
-- Analytics Data {{{
|
-- Analytics Data {{{
|
||||||
CREATE SCHEMA "analytics";
|
CREATE SCHEMA "analytics";
|
||||||
|
|||||||
@@ -5,18 +5,26 @@ from datetime import timedelta
|
|||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands as cmds
|
from discord.ext import commands as cmds
|
||||||
|
|
||||||
|
from twitchAPI.oauth import UserAuthenticator
|
||||||
|
from twitchAPI.twitch import AuthType, Twitch
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
import twitchio
|
import twitchio
|
||||||
from twitchio.ext import commands
|
from twitchio.ext import commands
|
||||||
|
|
||||||
|
|
||||||
from data.queries import ORDER
|
from data.queries import ORDER
|
||||||
from meta import LionCog, LionBot, CrocBot
|
from meta import LionCog, LionBot, CrocBot
|
||||||
|
from meta.LionContext import LionContext
|
||||||
|
from twitch.userflow import UserAuthFlow
|
||||||
from utils.lib import utc_now
|
from utils.lib import utc_now
|
||||||
from . import logger
|
from . import logger
|
||||||
from .data import TwitchAuthData
|
from .data import TwitchAuthData
|
||||||
|
|
||||||
|
|
||||||
class TwitchAuthCog(LionCog):
|
class TwitchAuthCog(LionCog):
|
||||||
|
DEFAULT_SCOPES = []
|
||||||
|
|
||||||
def __init__(self, bot: LionBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.data = bot.db.load_registry(TwitchAuthData())
|
self.data = bot.db.load_registry(TwitchAuthData())
|
||||||
@@ -29,3 +37,48 @@ class TwitchAuthCog(LionCog):
|
|||||||
async def fetch_client_for(self, userid: int):
|
async def fetch_client_for(self, userid: int):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether the given userid is authorised.
|
||||||
|
If 'scopes' is given, will also check the user has all of the given scopes.
|
||||||
|
"""
|
||||||
|
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||||
|
if authrow:
|
||||||
|
if scopes:
|
||||||
|
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
has_auth = set(map(str, scopes)).issubset(has_scopes)
|
||||||
|
else:
|
||||||
|
has_auth = True
|
||||||
|
else:
|
||||||
|
has_auth = False
|
||||||
|
return has_auth
|
||||||
|
|
||||||
|
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
|
||||||
|
"""
|
||||||
|
Start the user authentication flow for the given userid.
|
||||||
|
Will request the given scopes along with the default ones and any existing scopes.
|
||||||
|
"""
|
||||||
|
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
existing = map(AuthScope, existing_strs)
|
||||||
|
to_request = set(existing).union(scopes)
|
||||||
|
return await self.start_auth(to_request)
|
||||||
|
|
||||||
|
async def start_auth(self, scopes = []):
|
||||||
|
# TODO: Work out a way to just clone the current twitch object
|
||||||
|
# Or can we otherwise build UserAuthenticator without app auth?
|
||||||
|
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||||
|
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
|
||||||
|
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
|
||||||
|
await flow.setup()
|
||||||
|
|
||||||
|
return flow
|
||||||
|
|
||||||
|
# ----- Commands -----
|
||||||
|
@cmds.hybrid_command(name='auth')
|
||||||
|
async def cmd_auth(self, ctx: LionContext):
|
||||||
|
if ctx.interaction:
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
flow = await self.start_auth()
|
||||||
|
await ctx.reply(flow.auth.return_auth_url())
|
||||||
|
await flow.run()
|
||||||
|
await ctx.reply("Authentication Complete!")
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from data import Registry, RowModel
|
import datetime as dt
|
||||||
|
|
||||||
|
from data import Registry, RowModel, Table
|
||||||
from data.columns import Integer, String, Timestamp
|
from data.columns import Integer, String, Timestamp
|
||||||
|
|
||||||
|
|
||||||
@@ -7,22 +9,71 @@ class TwitchAuthData(Registry):
|
|||||||
"""
|
"""
|
||||||
Schema
|
Schema
|
||||||
------
|
------
|
||||||
CREATE TABLE twitch_tokens(
|
CREATE TABLE twitch_user_auth(
|
||||||
userid BIGINT PRIMARY KEY,
|
userid TEXT PRIMARY KEY,
|
||||||
access_token TEXT NOT NULL,
|
access_token TEXT NOT NULL,
|
||||||
expires_at TIMESTAMPTZ NOT NULL,
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
refresh_token TEXT NOT NULL,
|
refresh_token TEXT NOT NULL,
|
||||||
obtained_at TIMESTAMPTZ
|
obtained_at TIMESTAMPTZ
|
||||||
);
|
);
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_tablename_ = 'twitch_tokens'
|
_tablename_ = 'twitch_user_auth'
|
||||||
_cache_ = {}
|
_cache_ = {}
|
||||||
|
|
||||||
userid = Integer(primary=True)
|
userid = Integer(primary=True)
|
||||||
access_token = String()
|
access_token = String()
|
||||||
expires_at = Timestamp()
|
|
||||||
refresh_token = String()
|
refresh_token = String()
|
||||||
|
expires_at = Timestamp()
|
||||||
obtained_at = Timestamp()
|
obtained_at = Timestamp()
|
||||||
|
|
||||||
# TODO: Scopes
|
@classmethod
|
||||||
|
async def update_user_auth(
|
||||||
|
cls, userid: str, token: str, refresh: str,
|
||||||
|
expires_at: dt.datetime, obtained_at: dt.datetime,
|
||||||
|
scopes: list[str]
|
||||||
|
):
|
||||||
|
if cls._connector is None:
|
||||||
|
raise ValueError("Attempting to use uninitialised Registry.")
|
||||||
|
async with cls._connector.connection() as conn:
|
||||||
|
cls._connector.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
# Clear row for this userid
|
||||||
|
await cls.table.delete_where(userid=userid)
|
||||||
|
|
||||||
|
# Insert new user row
|
||||||
|
row = await cls.create(
|
||||||
|
userid=userid,
|
||||||
|
access_token=token,
|
||||||
|
refresh_token=refresh,
|
||||||
|
expires_at=expires_at,
|
||||||
|
obtained_at=obtained_at
|
||||||
|
)
|
||||||
|
# Insert new scope rows
|
||||||
|
if scopes:
|
||||||
|
await TwitchAuthData.user_scopes.insert_many(
|
||||||
|
('userid', 'scope'),
|
||||||
|
*((userid, scope) for scope in scopes)
|
||||||
|
)
|
||||||
|
return row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_scopes_for(cls, userid: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get a list of scopes stored for the given user.
|
||||||
|
Will return an empty list if the user is not authenticated.
|
||||||
|
"""
|
||||||
|
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
|
||||||
|
|
||||||
|
return [row.scope for row in rows] if rows else []
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE twitch_user_scopes(
|
||||||
|
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
scope TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||||
|
"""
|
||||||
|
user_scopes = Table('twitch_token_scopes')
|
||||||
|
|||||||
0
src/twitch/lib.py
Normal file
0
src/twitch/lib.py
Normal file
@@ -1,4 +1,5 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
@@ -9,6 +10,7 @@ from twitchAPI.type import AuthType
|
|||||||
from twitchio.client import asyncio
|
from twitchio.client import asyncio
|
||||||
|
|
||||||
from meta.errors import SafeCancellation
|
from meta.errors import SafeCancellation
|
||||||
|
from utils.lib import utc_now
|
||||||
from .data import TwitchAuthData
|
from .data import TwitchAuthData
|
||||||
from . import logger
|
from . import logger
|
||||||
|
|
||||||
@@ -52,12 +54,12 @@ class UserAuthFlow:
|
|||||||
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
|
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
|
||||||
|
|
||||||
result = await self._comm_task
|
result = await self._comm_task
|
||||||
if result['error']:
|
if result.get('error', None):
|
||||||
# TODO Custom auth errors
|
# TODO Custom auth errors
|
||||||
# This is only documented to occure when the user denies the auth
|
# This is only documented to occure when the user denies the auth
|
||||||
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
|
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
|
||||||
|
|
||||||
if result['state'] != self.auth.state:
|
if result.get('state', None) != self.auth.state:
|
||||||
# This should never happen unless the authserver has its wires crossed somehow,
|
# This should never happen unless the authserver has its wires crossed somehow,
|
||||||
# or the connection has been tampered with.
|
# or the connection has been tampered with.
|
||||||
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
|
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
|
||||||
@@ -76,9 +78,11 @@ class UserAuthFlow:
|
|||||||
# Fetch the associated userid and basic info
|
# Fetch the associated userid and basic info
|
||||||
v_result = await validate_token(token)
|
v_result = await validate_token(token)
|
||||||
userid = v_result['user_id']
|
userid = v_result['user_id']
|
||||||
|
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
|
||||||
|
|
||||||
# Save auth data
|
# Save auth data
|
||||||
|
return await self.data.UserAuthRow.update_user_auth(
|
||||||
async def save_auth(self, userid: str, token: str, refresh: str, scopes: list[str]):
|
userid=userid, token=token, refresh=refresh,
|
||||||
if not self.data._conn:
|
expires_at=expiry, obtained_at=utc_now(),
|
||||||
raise ValueError("Provided registry must be connected.")
|
scopes=[scope.value for scope in self.auth.scopes]
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user