Compare commits
1 Commits
feat-auth
...
timerlayou
| Author | SHA1 | Date | |
|---|---|---|---|
| 87488eaf99 |
4
.gitmodules
vendored
4
.gitmodules
vendored
@@ -1,9 +1,9 @@
|
||||
[submodule "bot/gui"]
|
||||
path = src/gui
|
||||
url = https://github.com/StudyLions/StudyLion-Plugin-GUI.git
|
||||
url = git@github.com:Intery/CafeHelper-GUI.git
|
||||
[submodule "skins"]
|
||||
path = skins
|
||||
url = https://github.com/Intery/pillow-skins.git
|
||||
url = git@github.com:Intery/CafeHelper-Skins.git
|
||||
[submodule "src/modules/voicefix"]
|
||||
path = src/modules/voicefix
|
||||
url = https://github.com/Intery/StudyLion-voicefix.git
|
||||
|
||||
@@ -1485,24 +1485,6 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name
|
||||
-- }}}
|
||||
|
||||
|
||||
-- Twitch User Auth {{{
|
||||
CREATE TABLE twitch_user_auth(
|
||||
userid TEXT PRIMARY KEY,
|
||||
access_token TEXT NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
obtained_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
|
||||
CREATE TABLE twitch_user_scopes(
|
||||
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
scope TEXT
|
||||
);
|
||||
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||
|
||||
-- }}}
|
||||
|
||||
|
||||
-- Analytics Data {{{
|
||||
CREATE SCHEMA "analytics";
|
||||
|
||||
2
skins
2
skins
Submodule skins updated: d3d6a28bc9...686857321e
2
src/gui
2
src/gui
Submodule src/gui updated: c1bcb05c25...40bc140355
@@ -1,9 +0,0 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import TwitchAuthCog
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(TwitchAuthCog(bot))
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
"""
|
||||
Testing client for the twitch AuthServer.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from twitchAPI.twitch import Twitch
|
||||
from twitchAPI.oauth import UserAuthenticator
|
||||
from twitchAPI.type import AuthScope
|
||||
|
||||
from meta.config import conf
|
||||
|
||||
|
||||
URI = "http://localhost:3000/twiauth/confirm"
|
||||
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
|
||||
|
||||
async def main():
|
||||
# Load in client id and secret
|
||||
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
|
||||
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
|
||||
url = auth.return_auth_url()
|
||||
|
||||
# Post url to user
|
||||
print(url)
|
||||
|
||||
# Send listen request to server
|
||||
# Wait for listen request
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
|
||||
await ws.send_json({'state': auth.state})
|
||||
result = await ws.receive_json()
|
||||
|
||||
# Hopefully get back code, print the response
|
||||
print(f"Recieved: {result}")
|
||||
|
||||
# Authorise with code and client details
|
||||
tokens = await auth.authenticate(user_token=result['code'])
|
||||
if tokens:
|
||||
token, refresh = tokens
|
||||
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
|
||||
print(f"Authorised!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -1,86 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
import asyncio
|
||||
from contextvars import ContextVar
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
|
||||
|
||||
|
||||
class AuthServer:
|
||||
def __init__(self):
|
||||
self.listeners = {}
|
||||
|
||||
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
|
||||
args = request.query
|
||||
if 'state' not in args:
|
||||
raise web.HTTPBadRequest(text="No state provided.")
|
||||
if args['state'] not in self.listeners:
|
||||
raise web.HTTPBadRequest(text="Invalid state.")
|
||||
self.listeners[args['state']].set_result(dict(args))
|
||||
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
|
||||
|
||||
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
|
||||
_reqid = str(uuid.uuid1())
|
||||
reqid.set(_reqid)
|
||||
|
||||
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
|
||||
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
# Get the listen request data
|
||||
try:
|
||||
listen_req = await ws.receive_json(timeout=60)
|
||||
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
|
||||
if 'state' not in listen_req:
|
||||
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
|
||||
raise web.HTTPBadRequest(text="Listen request must include state string.")
|
||||
elif listen_req['state'] in self.listeners:
|
||||
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
|
||||
raise web.HTTPBadRequest(text="Invalid state string.")
|
||||
except ValueError:
|
||||
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
|
||||
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||
except TypeError:
|
||||
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
|
||||
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||
except asyncio.TimeoutError:
|
||||
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
|
||||
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
|
||||
except Exception:
|
||||
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
|
||||
raise web.HTTPInternalServerError()
|
||||
|
||||
try:
|
||||
fut = self.listeners[listen_req['state']] = asyncio.Future()
|
||||
result = await asyncio.wait_for(fut, timeout=120)
|
||||
except asyncio.TimeoutError:
|
||||
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
|
||||
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
|
||||
finally:
|
||||
self.listeners.pop(listen_req['state'], None)
|
||||
|
||||
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
|
||||
await ws.send_json(result)
|
||||
await ws.close()
|
||||
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
|
||||
|
||||
return ws
|
||||
|
||||
def main(argv):
|
||||
app = web.Application()
|
||||
server = AuthServer()
|
||||
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
|
||||
app.router.add_get("/twiauth/listen", server.handle_listen_request)
|
||||
|
||||
logger.info("App setup and configured. Starting now.")
|
||||
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
main(sys.argv)
|
||||
@@ -1,84 +0,0 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from datetime import timedelta
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
|
||||
from twitchAPI.oauth import UserAuthenticator
|
||||
from twitchAPI.twitch import AuthType, Twitch
|
||||
from twitchAPI.type import AuthScope
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
|
||||
from data.queries import ORDER
|
||||
from meta import LionCog, LionBot, CrocBot
|
||||
from meta.LionContext import LionContext
|
||||
from twitch.userflow import UserAuthFlow
|
||||
from utils.lib import utc_now
|
||||
from . import logger
|
||||
from .data import TwitchAuthData
|
||||
|
||||
|
||||
class TwitchAuthCog(LionCog):
|
||||
DEFAULT_SCOPES = []
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(TwitchAuthData())
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
# ----- Auth API -----
|
||||
|
||||
async def fetch_client_for(self, userid: int):
|
||||
...
|
||||
|
||||
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
|
||||
"""
|
||||
Checks whether the given userid is authorised.
|
||||
If 'scopes' is given, will also check the user has all of the given scopes.
|
||||
"""
|
||||
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||
if authrow:
|
||||
if scopes:
|
||||
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||
has_auth = set(map(str, scopes)).issubset(has_scopes)
|
||||
else:
|
||||
has_auth = True
|
||||
else:
|
||||
has_auth = False
|
||||
return has_auth
|
||||
|
||||
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
|
||||
"""
|
||||
Start the user authentication flow for the given userid.
|
||||
Will request the given scopes along with the default ones and any existing scopes.
|
||||
"""
|
||||
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||
existing = map(AuthScope, existing_strs)
|
||||
to_request = set(existing).union(scopes)
|
||||
return await self.start_auth(to_request)
|
||||
|
||||
async def start_auth(self, scopes = []):
|
||||
# TODO: Work out a way to just clone the current twitch object
|
||||
# Or can we otherwise build UserAuthenticator without app auth?
|
||||
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
|
||||
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
|
||||
await flow.setup()
|
||||
|
||||
return flow
|
||||
|
||||
# ----- Commands -----
|
||||
@cmds.hybrid_command(name='auth')
|
||||
async def cmd_auth(self, ctx: LionContext):
|
||||
if ctx.interaction:
|
||||
await ctx.interaction.response.defer(ephemeral=True)
|
||||
flow = await self.start_auth()
|
||||
await ctx.reply(flow.auth.return_auth_url())
|
||||
await flow.run()
|
||||
await ctx.reply("Authentication Complete!")
|
||||
@@ -1,79 +0,0 @@
|
||||
import datetime as dt
|
||||
|
||||
from data import Registry, RowModel, Table
|
||||
from data.columns import Integer, String, Timestamp
|
||||
|
||||
|
||||
class TwitchAuthData(Registry):
|
||||
class UserAuthRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE twitch_user_auth(
|
||||
userid TEXT PRIMARY KEY,
|
||||
access_token TEXT NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
obtained_at TIMESTAMPTZ
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'twitch_user_auth'
|
||||
_cache_ = {}
|
||||
|
||||
userid = Integer(primary=True)
|
||||
access_token = String()
|
||||
refresh_token = String()
|
||||
expires_at = Timestamp()
|
||||
obtained_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def update_user_auth(
|
||||
cls, userid: str, token: str, refresh: str,
|
||||
expires_at: dt.datetime, obtained_at: dt.datetime,
|
||||
scopes: list[str]
|
||||
):
|
||||
if cls._connector is None:
|
||||
raise ValueError("Attempting to use uninitialised Registry.")
|
||||
async with cls._connector.connection() as conn:
|
||||
cls._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
# Clear row for this userid
|
||||
await cls.table.delete_where(userid=userid)
|
||||
|
||||
# Insert new user row
|
||||
row = await cls.create(
|
||||
userid=userid,
|
||||
access_token=token,
|
||||
refresh_token=refresh,
|
||||
expires_at=expires_at,
|
||||
obtained_at=obtained_at
|
||||
)
|
||||
# Insert new scope rows
|
||||
if scopes:
|
||||
await TwitchAuthData.user_scopes.insert_many(
|
||||
('userid', 'scope'),
|
||||
*((userid, scope) for scope in scopes)
|
||||
)
|
||||
return row
|
||||
|
||||
@classmethod
|
||||
async def get_scopes_for(cls, userid: str) -> list[str]:
|
||||
"""
|
||||
Get a list of scopes stored for the given user.
|
||||
Will return an empty list if the user is not authenticated.
|
||||
"""
|
||||
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
|
||||
|
||||
return [row.scope for row in rows] if rows else []
|
||||
|
||||
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE twitch_user_scopes(
|
||||
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
scope TEXT
|
||||
);
|
||||
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||
"""
|
||||
user_scopes = Table('twitch_token_scopes')
|
||||
@@ -1,88 +0,0 @@
|
||||
from typing import Optional
|
||||
import datetime as dt
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
import aiohttp
|
||||
from twitchAPI.twitch import Twitch
|
||||
from twitchAPI.oauth import UserAuthenticator, validate_token
|
||||
from twitchAPI.type import AuthType
|
||||
from twitchio.client import asyncio
|
||||
|
||||
from meta.errors import SafeCancellation
|
||||
from utils.lib import utc_now
|
||||
from .data import TwitchAuthData
|
||||
from . import logger
|
||||
|
||||
class UserAuthFlow:
|
||||
auth: UserAuthenticator
|
||||
data: TwitchAuthData
|
||||
auth_ws: str
|
||||
|
||||
def __init__(self, data, auth, auth_ws):
|
||||
self.auth = auth
|
||||
self.data = data
|
||||
self.auth_ws = auth_ws
|
||||
|
||||
self._setup_done = asyncio.Event()
|
||||
self._comm_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Establishes websocket connection to the AuthServer,
|
||||
and requests listening for the given state.
|
||||
Propagates any exceptions that occur during connection setup.
|
||||
"""
|
||||
if self._setup_done.is_set():
|
||||
raise ValueError("UserAuthFlow is already set up.")
|
||||
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
|
||||
await self._setup_done.wait()
|
||||
if self._comm_task.done() and (exc := self._comm_task.exception()):
|
||||
raise exc
|
||||
|
||||
async def _communicate(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.ws_connect(self.auth_ws) as ws:
|
||||
await ws.send_json({'state': self.auth.state})
|
||||
self._setup_done.set()
|
||||
return await ws.receive_json()
|
||||
|
||||
async def run(self):
|
||||
if not self._setup_done.is_set():
|
||||
raise ValueError("Cannot run UserAuthFlow before setup.")
|
||||
if self._comm_task is None:
|
||||
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
|
||||
|
||||
result = await self._comm_task
|
||||
if result.get('error', None):
|
||||
# TODO Custom auth errors
|
||||
# This is only documented to occure when the user denies the auth
|
||||
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
|
||||
|
||||
if result.get('state', None) != self.auth.state:
|
||||
# This should never happen unless the authserver has its wires crossed somehow,
|
||||
# or the connection has been tampered with.
|
||||
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
|
||||
logger.critical(
|
||||
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
|
||||
)
|
||||
raise SafeCancellation(
|
||||
"Could not complete authentication! Invalid server response."
|
||||
)
|
||||
|
||||
# Now assume result has a valid code
|
||||
# Exchange code for an auth token and a refresh token
|
||||
# Ignore type here, authenticate returns None if a callback function has been given.
|
||||
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
|
||||
|
||||
# Fetch the associated userid and basic info
|
||||
v_result = await validate_token(token)
|
||||
userid = v_result['user_id']
|
||||
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
|
||||
|
||||
# Save auth data
|
||||
return await self.data.UserAuthRow.update_user_auth(
|
||||
userid=userid, token=token, refresh=refresh,
|
||||
expires_at=expiry, obtained_at=utc_now(),
|
||||
scopes=[scope.value for scope in self.auth.scopes]
|
||||
)
|
||||
7
tests/__init__.py
Normal file
7
tests/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# !/bin/python3
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
from src.cards import WeeklyGoalCard
|
||||
from gui.cards import WeeklyGoalCard
|
||||
|
||||
|
||||
async def get_card():
|
||||
card = await WeeklyGoalCard.generate_sample()
|
||||
with open('samples/weekly-sample.png', 'wb') as image_file:
|
||||
with open('output/weekly-sample.png', 'wb') as image_file:
|
||||
image_file.write(card.fp.read())
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
15
tests/gui/cards/pomo_sample.py
Normal file
15
tests/gui/cards/pomo_sample.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
from gui.cards import BreakTimerCard, FocusTimerCard
|
||||
|
||||
|
||||
async def get_card():
|
||||
card = await BreakTimerCard.generate_sample()
|
||||
with open('output/break_timer_sample.png', 'wb') as image_file:
|
||||
image_file.write(card.fp.read())
|
||||
card = await FocusTimerCard.generate_sample()
|
||||
with open('output/focus_timer_sample.png', 'wb') as image_file:
|
||||
image_file.write(card.fp.read())
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(get_card())
|
||||
Reference in New Issue
Block a user