11 Commits

22 changed files with 566 additions and 135 deletions

View File

@@ -1485,6 +1485,24 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name
-- }}}
-- Twitch User Auth {{{
CREATE TABLE twitch_user_auth(
userid TEXT PRIMARY KEY,
access_token TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
refresh_token TEXT NOT NULL,
obtained_at TIMESTAMPTZ
);
CREATE TABLE twitch_user_scopes(
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
scope TEXT
);
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
-- }}}
-- Analytics Data {{{
CREATE SCHEMA "analytics";

View File

@@ -80,6 +80,14 @@ async def main():
websockets.serve(sockets.root_handler, '', conf.wserver['port'])
)
crocbot = CrocBot(
config=conf,
data=db,
prefix='!',
initial_channels=conf.croccy.getlist('initial_channels'),
token=conf.croccy['token'],
)
lionbot = await stack.enter_async_context(
LionBot(
command_prefix='!',
@@ -104,26 +112,15 @@ async def main():
translator=translator,
chunk_guilds_at_startup=False,
system_monitor=system_monitor,
crocbot=crocbot,
)
)
crocbot = CrocBot(
config=conf,
data=db,
prefix='!',
initial_channels=conf.croccy.getlist('initial_channels'),
token=conf.croccy['token'],
lionbot=lionbot
)
lionbot.crocbot = crocbot
crocbot.load_module('modules')
crocstart = asyncio.create_task(start_croccy(crocbot))
lionstart = asyncio.create_task(start_lion(lionbot))
await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED)
crocstart.cancel()
lionstart.cancel()
# crocstart.cancel()
# lionstart.cancel()
async def start_lion(lionbot):
ctx_bot.set(lionbot)

View File

@@ -10,10 +10,6 @@ from data import Database
from .config import Conf
if TYPE_CHECKING:
from .LionBot import LionBot
logger = logging.getLogger(__name__)
@@ -21,12 +17,11 @@ class CrocBot(commands.Bot):
def __init__(self, *args,
config: Conf,
data: Database,
lionbot: 'LionBot', **kwargs):
**kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.data = data
self.pubsub = pubsub.PubSubPool(self)
self.lionbot = lionbot
async def event_ready(self):
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")

View File

@@ -24,6 +24,7 @@ from .errors import HandledException, SafeCancellation
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
if TYPE_CHECKING:
from meta.CrocBot import CrocBot
from core.cog import CoreCog
from core.config import ConfigCog
from tracking.voice.cog import VoiceTrackerCog
@@ -58,6 +59,7 @@ class LionBot(Bot):
initial_extensions: List[str], web_client: ClientSession, app_ipc,
testing_guilds: List[int] = [],
system_monitor: Optional[SystemMonitor] = None,
crocbot: Optional['CrocBot'] = None,
**kwargs
):
kwargs.setdefault('tree_cls', LionTree)
@@ -73,6 +75,8 @@ class LionBot(Bot):
self.app_ipc = app_ipc
self.translator = translator
self.crocbot = crocbot
self.system_monitor = system_monitor or SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
self.system_monitor.add_component(self.monitor)

View File

@@ -1,23 +1,36 @@
from typing import Any
from functools import partial
from typing import Any, Callable, Optional
from discord.ext.commands import Cog
from discord.ext import commands as cmds
from twitchio.ext import commands
from twitchio.ext.commands import Command, Bot
from twitchio.ext.commands.meta import CogEvent
class LionCog(Cog):
# A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set()
_placeholder_groups_: set[str]
_twitch_cmds_: dict[str, Command]
_twitch_events_: dict[str, CogEvent]
_twitch_events_loaded_: set[Callable]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._placeholder_groups_ = set()
cls._twitch_cmds_ = {}
cls._twitch_events_ = {}
for base in reversed(cls.__mro__):
for elem, value in base.__dict__.items():
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
cls._placeholder_groups_.add(value.name)
elif isinstance(value, Command):
cls._twitch_cmds_[value.name] = value
elif isinstance(value, CogEvent):
cls._twitch_events_[value.name] = value
def __new__(cls, *args: Any, **kwargs: Any):
# Patch to ensure no placeholder groups are in the command list
@@ -34,6 +47,51 @@ class LionCog(Cog):
return await super()._inject(bot, *args, *kwargs)
def _load_twitch_methods(self, bot: Bot):
for name, command in self._twitch_cmds_.items():
command._instance = self
command.cog = self
bot.add_command(command)
for name, event in self._twitch_events_.items():
callback = partial(event, self)
self._twitch_events_loaded_.add(callback)
bot.add_event(callback=callback, name=name)
def _unload_twitch_methods(self, bot: Bot):
for name in self._twitch_cmds_:
bot.remove_command(name)
for callback in self._twitch_events_loaded_:
bot.remove_event(callback=callback)
self._twitch_events_loaded_.clear()
@classmethod
def twitch_event(cls, event: Optional[str] = None):
def decorator(func) -> CogEvent:
event_name = event or func.__name__
return CogEvent(name=event_name, func=func, module=cls.__module__)
return decorator
async def cog_check(self, ctx): # type: ignore
"""
TwitchIO assumes cog_check is a coroutine,
so here we narrow the check to only a coroutine.
The ctx maybe either be a twitch command context or a dpy context.
"""
if isinstance(ctx, cmds.Context):
return await self.cog_check_discord(ctx)
if isinstance(ctx, commands.Context):
return await self.cog_check_twitch(ctx)
async def cog_check_discord(self, ctx: cmds.Context):
return True
async def cog_check_twitch(self, ctx: commands.Context):
return True
@classmethod
def placeholder_group(cls, group: cmds.HybridGroup):
group._placeholder_group_ = True

View File

@@ -26,20 +26,12 @@ active_discord = [
'.premium',
'.streamalerts',
'.test',
]
active_twitch = [
'.counters',
'.nowdoing',
'.shoutouts',
'.counters',
'.tagstrings',
]
def prepare(bot):
for ext in active_twitch:
bot.load_module(this_package + ext)
async def setup(bot):
for ext in active_discord:
await bot.load_extension(ext, package=this_package)

View File

@@ -4,10 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import CounterCog
def prepare(bot):
bot.add_cog(CounterCog(bot))
async def setup(bot):
from .lion_cog import CounterCog
await bot.add_cog(CounterCog(bot))

View File

@@ -3,11 +3,14 @@ from enum import Enum
from typing import Optional
from datetime import timedelta
import discord
from discord.ext import commands as cmds
import twitchio
from twitchio.ext import commands
from data.queries import ORDER
from meta import CrocBot
from meta import LionCog, LionBot, CrocBot
from utils.lib import utc_now
from . import logger
from .data import CounterData
@@ -22,10 +25,11 @@ class PERIOD(Enum):
YEAR = ('this year', 'y', 'year', 'yearly')
class CounterCog(commands.Cog):
def __init__(self, bot: CrocBot):
class CounterCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.data.load_registry(CounterData())
self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(CounterData())
self.loaded = asyncio.Event()
@@ -33,9 +37,18 @@ class CounterCog(commands.Cog):
self.counters = {}
async def cog_load(self):
self._load_twitch_methods(self.crocbot)
await self.data.init()
await self.load_counters()
self.loaded.set()
async def cog_unload(self):
self._unload_twitch_methods(self.crocbot)
async def cog_check(self, ctx):
return True
async def load_counters(self):
"""
Initialise counter name cache.
@@ -46,18 +59,6 @@ class CounterCog(commands.Cog):
f"Loaded {len(self.counters)} counters."
)
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_check(self, ctx):
await self.ensure_loaded()
return True
# Counters API
async def fetch_counter(self, counter: str) -> CounterData.Counter:
@@ -171,7 +172,7 @@ class CounterCog(commands.Cog):
if period is PERIOD.ALL:
start_time = None
elif period is PERIOD.STREAM:
streams = await self.bot.fetch_streams(user_ids=[userid])
streams = await self.crocbot.fetch_streams(user_ids=[userid])
if streams:
stream = streams[0]
start_time = stream.started_at
@@ -199,7 +200,7 @@ class CounterCog(commands.Cog):
lb = await self.leaderboard(counter, start_time=start_time)
if lb:
userids = list(lb.keys())
users = await self.bot.fetch_users(ids=userids)
users = await self.crocbot.fetch_users(ids=userids)
name_map = {user.id: user.display_name for user in users}
parts = []
for userid, total in lb.items():
@@ -283,17 +284,9 @@ class CounterCog(commands.Cog):
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
@commands.command()
async def reload(self, ctx: commands.Context, *, args: str = ''):
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
return
if not args:
await ctx.reply("Full reload not implemented yet.")
else:
try:
self.bot.reload_module(args)
except Exception:
logger.exception("Failed to reload")
await ctx.reply("Failed to reload module! Check console~")
else:
await ctx.reply("Reloaded!")
async def stuff(self, ctx: commands.Context, *, args: str = ''):
await ctx.reply(f"Stuff {args}")
@cmds.hybrid_command('water')
async def d_water_cmd(self, ctx):
await ctx.reply(repr(ctx))

View File

@@ -1,23 +0,0 @@
import asyncio
from typing import Optional
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.errors import UserInputError
from meta.logger import log_wrap
from utils.lib import utc_now
from data.conditions import NULL
from . import logger
from .data import CounterData
class CounterCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.counter_cog = bot.crocbot.get_cog('CounterCog')

View File

@@ -4,6 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import NowDoingCog
def prepare(bot):
logger.info("Preparing the nowdoing module.")
bot.add_cog(NowDoingCog(bot))
async def setup(bot):
await bot.add_cog(NowDoingCog(bot))

View File

@@ -8,7 +8,8 @@ from attr import dataclass
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from meta import CrocBot, LionCog
from meta.LionBot import LionBot
from meta.sockets import Channel, register_channel
from utils.lib import strfdelta, utc_now
from . import logger
@@ -78,10 +79,11 @@ class NowDoingChannel(Channel):
})
class NowDoingCog(commands.Cog):
def __init__(self, bot: CrocBot):
class NowDoingCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.data.load_registry(NowListData())
self.crocbot = bot.crocbot
self.data = bot.db.load_registry(NowListData())
self.channel = NowDoingChannel(self)
register_channel(self.channel.name, self.channel)
@@ -94,21 +96,19 @@ class NowDoingCog(commands.Cog):
await self.data.init()
await self.load_tasks()
self._load_twitch_methods(self.crocbot)
self.loaded.set()
async def ensure_loaded(self):
"""
Hack because lib devs decided to remove async cog loading.
"""
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_unload(self):
self.loaded.clear()
self.tasks.clear()
self._unload_twitch_methods(self.crocbot)
async def cog_check(self, ctx):
await self.ensure_loaded()
if not self.loaded.is_set():
await ctx.reply("Tasklists are still loading! Please wait a moment~")
return False
return True
async def load_tasks(self):
@@ -130,6 +130,7 @@ class NowDoingCog(commands.Cog):
@commands.command(aliases=['task', 'check'])
async def now(self, ctx: commands.Context, *, args: Optional[str] = None):
userid = int(ctx.author.id)
args = args.strip() if args else None
if args:
await self.data.Task.table.delete_where(userid=userid)
task = await self.data.Task.create(

View File

@@ -4,5 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import ShoutoutCog
def prepare(bot):
bot.add_cog(ShoutoutCog(bot))
async def setup(bot):
await bot.add_cog(ShoutoutCog(bot))

View File

@@ -4,50 +4,50 @@ from typing import Optional
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from meta import CrocBot, LionBot, LionCog
from utils.lib import replace_multiple
from . import logger
from .data import ShoutoutData
class ShoutoutCog(commands.Cog):
class ShoutoutCog(LionCog):
# Future extension: channel defaults and config
DEFAULT_SHOUTOUT = """
We think that {name} is a great streamer and you should check them out \
and drop a follow! \
They {areorwere} streaming {game} at {channel}
"""
def __init__(self, bot: CrocBot):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.data.load_registry(ShoutoutData())
self.crocbot = bot.crocbot
self.data = bot.db.load_registry(ShoutoutData())
self.loaded = asyncio.Event()
async def cog_load(self):
await self.data.init()
self._load_twitch_methods(self.crocbot)
self.loaded.set()
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_unload(self):
self.loaded.clear()
self._unload_twitch_methods(self.crocbot)
async def cog_check(self, ctx):
await self.ensure_loaded()
if not self.loaded.is_set():
await ctx.reply("Tasklists are still loading! Please wait a moment~")
return False
return True
async def format_shoutout(self, text: str, user: twitchio.User):
channels = await self.bot.fetch_channels([user.id])
channels = await self.crocbot.fetch_channels([user.id])
if channels:
channel = channels[0]
game = channel.game_name or 'Unknown'
else:
game = 'Unknown'
streams = await self.bot.fetch_streams([user.id])
streams = await self.crocbot.fetch_streams([user.id])
live = bool(streams)
mapping = {

View File

@@ -4,5 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import TagCog
def prepare(bot):
bot.add_cog(TagCog(bot))
async def setup(bot):
await bot.add_cog(TagCog(bot))

View File

@@ -6,16 +6,17 @@ import difflib
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from meta import CrocBot, LionBot, LionCog
from utils.lib import utc_now
from . import logger
from .data import TagData
class TagCog(commands.Cog):
def __init__(self, bot: CrocBot):
class TagCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.data.load_registry(TagData())
self.crocbot = bot.crocbot
self.data = bot.db.load_registry(TagData())
self.loaded = asyncio.Event()
@@ -31,19 +32,24 @@ class TagCog(commands.Cog):
self.tags.clear()
self.tags.update(tags)
logger.info(f"Loaded {len(tags)} into cache.")
async def cog_load(self):
await self.data.init()
await self.load_tags()
self._load_twitch_methods(self.crocbot)
self.loaded.set()
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
async def cog_unload(self):
self.loaded.clear()
self.tags.clear()
self._unload_twitch_methods(self.crocbot)
@commands.Cog.event('event_ready')
async def on_ready(self):
await self.ensure_loaded()
async def cog_check(self, ctx):
if not self.loaded.is_set():
await ctx.reply("Tasklists are still loading! Please wait a moment~")
return False
return True
# API

9
src/twitch/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
import logging
logger = logging.getLogger(__name__)
from .cog import TwitchAuthCog
async def setup(bot):
await bot.add_cog(TwitchAuthCog(bot))

50
src/twitch/authclient.py Normal file
View File

@@ -0,0 +1,50 @@
"""
Testing client for the twitch AuthServer.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
import asyncio
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.type import AuthScope
from meta.config import conf
URI = "http://localhost:3000/twiauth/confirm"
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
async def main():
# Load in client id and secret
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
url = auth.return_auth_url()
# Post url to user
print(url)
# Send listen request to server
# Wait for listen request
async with aiohttp.ClientSession() as session:
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
await ws.send_json({'state': auth.state})
result = await ws.receive_json()
# Hopefully get back code, print the response
print(f"Recieved: {result}")
# Authorise with code and client details
tokens = await auth.authenticate(user_token=result['code'])
if tokens:
token, refresh = tokens
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
print(f"Authorised!")
if __name__ == '__main__':
asyncio.run(main())

86
src/twitch/authserver.py Normal file
View File

@@ -0,0 +1,86 @@
import logging
import uuid
import asyncio
from contextvars import ContextVar
import aiohttp
from aiohttp import web
logger = logging.getLogger(__name__)
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
class AuthServer:
def __init__(self):
self.listeners = {}
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
args = request.query
if 'state' not in args:
raise web.HTTPBadRequest(text="No state provided.")
if args['state'] not in self.listeners:
raise web.HTTPBadRequest(text="Invalid state.")
self.listeners[args['state']].set_result(dict(args))
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
_reqid = str(uuid.uuid1())
reqid.set(_reqid)
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
ws = web.WebSocketResponse()
await ws.prepare(request)
# Get the listen request data
try:
listen_req = await ws.receive_json(timeout=60)
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
if 'state' not in listen_req:
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
raise web.HTTPBadRequest(text="Listen request must include state string.")
elif listen_req['state'] in self.listeners:
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
raise web.HTTPBadRequest(text="Invalid state string.")
except ValueError:
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except TypeError:
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
except Exception:
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
raise web.HTTPInternalServerError()
try:
fut = self.listeners[listen_req['state']] = asyncio.Future()
result = await asyncio.wait_for(fut, timeout=120)
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
finally:
self.listeners.pop(listen_req['state'], None)
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
await ws.send_json(result)
await ws.close()
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
return ws
def main(argv):
app = web.Application()
server = AuthServer()
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
app.router.add_get("/twiauth/listen", server.handle_listen_request)
logger.info("App setup and configured. Starting now.")
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
if __name__ == '__main__':
import sys
main(sys.argv)

84
src/twitch/cog.py Normal file
View File

@@ -0,0 +1,84 @@
import asyncio
from enum import Enum
from typing import Optional
from datetime import timedelta
import discord
from discord.ext import commands as cmds
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.twitch import AuthType, Twitch
from twitchAPI.type import AuthScope
import twitchio
from twitchio.ext import commands
from data.queries import ORDER
from meta import LionCog, LionBot, CrocBot
from meta.LionContext import LionContext
from twitch.userflow import UserAuthFlow
from utils.lib import utc_now
from . import logger
from .data import TwitchAuthData
class TwitchAuthCog(LionCog):
DEFAULT_SCOPES = []
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(TwitchAuthData())
async def cog_load(self):
await self.data.init()
# ----- Auth API -----
async def fetch_client_for(self, userid: int):
...
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
"""
Checks whether the given userid is authorised.
If 'scopes' is given, will also check the user has all of the given scopes.
"""
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow:
if scopes:
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
has_auth = set(map(str, scopes)).issubset(has_scopes)
else:
has_auth = True
else:
has_auth = False
return has_auth
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
"""
Start the user authentication flow for the given userid.
Will request the given scopes along with the default ones and any existing scopes.
"""
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
existing = map(AuthScope, existing_strs)
to_request = set(existing).union(scopes)
return await self.start_auth(to_request)
async def start_auth(self, scopes = []):
# TODO: Work out a way to just clone the current twitch object
# Or can we otherwise build UserAuthenticator without app auth?
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
await flow.setup()
return flow
# ----- Commands -----
@cmds.hybrid_command(name='auth')
async def cmd_auth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
flow = await self.start_auth()
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")

79
src/twitch/data.py Normal file
View File

@@ -0,0 +1,79 @@
import datetime as dt
from data import Registry, RowModel, Table
from data.columns import Integer, String, Timestamp
class TwitchAuthData(Registry):
class UserAuthRow(RowModel):
"""
Schema
------
CREATE TABLE twitch_user_auth(
userid TEXT PRIMARY KEY,
access_token TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
refresh_token TEXT NOT NULL,
obtained_at TIMESTAMPTZ
);
"""
_tablename_ = 'twitch_user_auth'
_cache_ = {}
userid = Integer(primary=True)
access_token = String()
refresh_token = String()
expires_at = Timestamp()
obtained_at = Timestamp()
@classmethod
async def update_user_auth(
cls, userid: str, token: str, refresh: str,
expires_at: dt.datetime, obtained_at: dt.datetime,
scopes: list[str]
):
if cls._connector is None:
raise ValueError("Attempting to use uninitialised Registry.")
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# Clear row for this userid
await cls.table.delete_where(userid=userid)
# Insert new user row
row = await cls.create(
userid=userid,
access_token=token,
refresh_token=refresh,
expires_at=expires_at,
obtained_at=obtained_at
)
# Insert new scope rows
if scopes:
await TwitchAuthData.user_scopes.insert_many(
('userid', 'scope'),
*((userid, scope) for scope in scopes)
)
return row
@classmethod
async def get_scopes_for(cls, userid: str) -> list[str]:
"""
Get a list of scopes stored for the given user.
Will return an empty list if the user is not authenticated.
"""
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
return [row.scope for row in rows] if rows else []
"""
Schema
------
CREATE TABLE twitch_user_scopes(
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
scope TEXT
);
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
"""
user_scopes = Table('twitch_token_scopes')

0
src/twitch/lib.py Normal file
View File

88
src/twitch/userflow.py Normal file
View File

@@ -0,0 +1,88 @@
from typing import Optional
import datetime as dt
from aiohttp import web
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator, validate_token
from twitchAPI.type import AuthType
from twitchio.client import asyncio
from meta.errors import SafeCancellation
from utils.lib import utc_now
from .data import TwitchAuthData
from . import logger
class UserAuthFlow:
auth: UserAuthenticator
data: TwitchAuthData
auth_ws: str
def __init__(self, data, auth, auth_ws):
self.auth = auth
self.data = data
self.auth_ws = auth_ws
self._setup_done = asyncio.Event()
self._comm_task: Optional[asyncio.Task] = None
async def setup(self):
"""
Establishes websocket connection to the AuthServer,
and requests listening for the given state.
Propagates any exceptions that occur during connection setup.
"""
if self._setup_done.is_set():
raise ValueError("UserAuthFlow is already set up.")
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
await self._setup_done.wait()
if self._comm_task.done() and (exc := self._comm_task.exception()):
raise exc
async def _communicate(self):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(self.auth_ws) as ws:
await ws.send_json({'state': self.auth.state})
self._setup_done.set()
return await ws.receive_json()
async def run(self):
if not self._setup_done.is_set():
raise ValueError("Cannot run UserAuthFlow before setup.")
if self._comm_task is None:
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
result = await self._comm_task
if result.get('error', None):
# TODO Custom auth errors
# This is only documented to occure when the user denies the auth
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
if result.get('state', None) != self.auth.state:
# This should never happen unless the authserver has its wires crossed somehow,
# or the connection has been tampered with.
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
logger.critical(
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
)
raise SafeCancellation(
"Could not complete authentication! Invalid server response."
)
# Now assume result has a valid code
# Exchange code for an auth token and a refresh token
# Ignore type here, authenticate returns None if a callback function has been given.
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
# Fetch the associated userid and basic info
v_result = await validate_token(token)
userid = v_result['user_id']
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
# Save auth data
return await self.data.UserAuthRow.update_user_auth(
userid=userid, token=token, refresh=refresh,
expires_at=expiry, obtained_at=utc_now(),
scopes=[scope.value for scope in self.auth.scopes]
)